benchmark_flce.py
import sys
sys.path.append('.')
import argparse
import csv
import json
import os
import time
from typing import Any, Callable, Dict, List, Optional, Union
from collections import OrderedDict
from dataclasses import asdict, dataclass
from importlib.metadata import version
from itertools import zip_longest
import numpy as np
import torch
import triton
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
from chunked_lce import CompiledFusedLinearCrossEntropyLoss
LIGER_KERNEL_VERSION = version("liger-kernel")
QUANTILES = [0.5, 0.2, 0.8]
@dataclass
class SingleBenchmarkRunInput:
x: Union[int, float]
kernel_provider: str
kernel_operation_mode: Optional[str] = ""
extra_benchmark_config: Optional[Dict[str, Any]] = None
@dataclass
class SingleBenchmarkRunOutput:
# 20th percentile
y_20: float
# 50th percentile (median)
y_50: float
# 80th percentile
y_80: float
@dataclass
class BenchmarkData:
"""
BenchmarkData is a dataclass to store the benchmark data for a completed benchmark
run on all x-values for a given kernel/kernel operation mode/metric/extra_benchmark_config.
"""
kernel_name: str
kernel_provider: str
metric_name: str
metric_unit: str
gpu_name: str
x_name: str
x_label: str
x_values: List[float]
y_values_50: List[float]
y_values_20: List[float]
y_values_80: List[float]
timestamp: str
kernel_operation_mode: Optional[str] = None
extra_benchmark_config_str: Optional[str] = None
liger_version: str = LIGER_KERNEL_VERSION
@dataclass
class BenchmarkDataCSVRow:
# The ordering of field names here will be the order of columns in the CSV
kernel_name: str
kernel_provider: str
kernel_operation_mode: Union[str, None]
metric_name: str
metric_unit: str
x_name: str
x_label: str
x_value: float
y_value_50: float
y_value_20: float
y_value_80: float
extra_benchmark_config_str: Union[str, None]
gpu_name: str
timestamp: str
liger_version: str
def _test_memory(
func: Callable,
_iter: int = 10,
quantiles: Optional[List[float]] = None,
return_mode="mean",
) -> Union[float, List[float]]:
assert return_mode in ["min", "max", "mean", "median"]
total_mem = []
for _ in range(_iter):
torch.cuda.memory.reset_peak_memory_stats()
func()
# Convert to MB
mem = torch.cuda.max_memory_allocated() / 2**20
total_mem.append(mem)
total_mem = torch.tensor(total_mem, dtype=torch.float)
if quantiles is not None:
quantiles_data = torch.quantile(
total_mem, torch.tensor(quantiles, dtype=torch.float)
).tolist()
if len(quantiles_data) == 1:
quantiles_data = quantiles_data[0]
return quantiles_data
return getattr(torch, return_mode)(total_mem).item()
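# Example (hypothetical call, assuming a CUDA model `model` and input `x` already exist):
# run the closure 10 times and read back the 20th/50th/80th percentile of peak
# allocated memory in MiB, matching how QUANTILES is unpacked later in this script:
#
#   mem_50, mem_20, mem_80 = _test_memory(lambda: model(x), _iter=10, quantiles=QUANTILES)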
def get_current_file_directory() -> str:
"""
Returns the directory path of the current Python file.
"""
# Get the absolute path of the current file
current_file_path = os.path.abspath(__file__)
# Get the directory path of the current file
return os.path.dirname(current_file_path)
def sleep(seconds):
def decorator(function):
def wrapper(*args, **kwargs):
time.sleep(seconds)
return function(*args, **kwargs)
return wrapper
return decorator
def _print_benchmarking_banner(metric_name: str, kernel_name: str):
print("**************************************")
print(f" BENCHMARKING {metric_name.upper()} for {kernel_name.upper()}")
print("**************************************")
def get_formatted_time():
return time.strftime("%Y-%m-%d %H:%M:%S")
def get_gpu_name():
"""
Returns the current GPU name, formatted to serve as a directory name
"""
if torch.cuda.is_available():
gpu_name = torch.cuda.get_device_name(torch.cuda.current_device())
return gpu_name
else:
raise Exception("Benchmarks can only be run on GPU.")
def update_benchmark_data_csv(
benchmark_data_list: List[BenchmarkData],
filename: str = "all_benchmark_data.csv",
overwrite: bool = True,
):
"""
Update the CSV file with the new benchmark data. If the file does not exist, create it.
If an entry already exists for the benchmark, then overwrite it if `overwrite` is True.
"""
def create_unique_key(row):
# This unique key is used to determine if a benchmark run already exists in the CSV
# If the key is the same, then the benchmark run already exists and will optionally
# be overwritten. Otherwise, it is considered a new benchmark run and appended.
return (
row["kernel_name"],
row["kernel_provider"],
row["kernel_operation_mode"] if row["kernel_operation_mode"] else "",
row["metric_name"],
row["x_name"],
str(row["x_value"]),
(
row["extra_benchmark_config_str"]
if row["extra_benchmark_config_str"]
else ""
),
row["gpu_name"],
)
fieldnames = BenchmarkDataCSVRow.__annotations__.keys()
# Make filename path relative to current file
# filename_abs_path = os.path.join('raw_data', filename)
filename_abs_path = filename
file_exists = os.path.isfile(filename_abs_path)
# Read existing data into a list of dicts
existing_data = []
if file_exists:
with open(filename_abs_path, mode="r") as file:
reader = csv.DictReader(file)
for row in reader:
existing_data.append(row)
existing_data_dict = OrderedDict(
(create_unique_key(row), row) for row in existing_data
)
for benchmark_data in benchmark_data_list:
benchmark_data_dict = asdict(benchmark_data)
x_values = benchmark_data_dict.pop("x_values")
y_values_50 = benchmark_data_dict.pop("y_values_50")
y_values_20 = benchmark_data_dict.pop("y_values_20")
y_values_80 = benchmark_data_dict.pop("y_values_80")
# Need to convert benchmark_data into multiple rows based on x_values and y_values
for x_value, y_value_50, y_value_20, y_value_80 in zip_longest(
x_values, y_values_50, y_values_20, y_values_80
):
row = BenchmarkDataCSVRow(
x_value=x_value,
y_value_50=y_value_50,
y_value_20=y_value_20,
y_value_80=y_value_80,
**benchmark_data_dict,
)
row_dict = asdict(row)
row_key = create_unique_key(row_dict)
if row_key in existing_data_dict:
if overwrite:
# If overwriting, update the row
existing_data_dict[row_key] = row_dict
else:
# If not overwriting, skip this row
pass
else:
existing_data_dict[row_key] = row_dict
with open(filename_abs_path, mode="w", newline="") as file:
writer = csv.DictWriter(file, fieldnames=fieldnames)
writer.writeheader()
for row in existing_data_dict.values():
writer.writerow(row)
class CustomEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, torch.dtype):
return str(obj)
return super().default(obj)
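# CustomEncoder lets json.dumps serialize benchmark configs that contain a torch.dtype,
# e.g. (illustrative values):
#   json.dumps({"H": 4096, "dtype": torch.bfloat16}, cls=CustomEncoder)
#   == '{"H": 4096, "dtype": "torch.bfloat16"}'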
def print_benchmark_data(benchmark_data_list: List[BenchmarkData]) -> None:
print("********** Benchmark Data **********")
formatted_list = [obj.__dict__ for obj in benchmark_data_list]
print(json.dumps(formatted_list, indent=2))
def run_benchmarks(
bench_test_fn: Callable,
kernel_name: str,
metric_name: str,
metric_unit: str,
x_name: str,
x_label: str,
x_values: List[Union[float, int]],
kernel_providers: List[str],
kernel_operation_modes: Optional[List[str]] = [None],
extra_benchmark_configs: Optional[List[Dict[str, Any]]] = None,
overwrite: bool = False,
):
"""
Run benchmarks given a bench_test_fn that takes in a SingleBenchmarkRunInput as input and
saves data to the CSV file.
Args:
- bench_test_fn: The benchmark test function to run. This function should take in a
SingleBenchmarkRunInput as input and return a SingleBenchmarkRunOutput.
- kernel_name: The name of the kernel being benchmarked (e.g. "swiglu")
- metric_name: The name of the metric being benchmarked (e.g. "speed" or "memory")
- metric_unit: The unit of the metric being benchmarked (e.g. "ms" or "MB")
- x_name: The name of the x-axis (e.g. "T" for sequence length)
- x_label: The label of the x-axis (e.g. "sequence length")
- x_values: The list of x-values to run the benchmark on (e.g. [2**i for i in range(10, 14)])
- kernel_providers: The list of kernel providers to run the benchmark on (e.g. ["liger", "huggingface"])
- kernel_operation_modes: The list of kernel operation modes to run the benchmark on (e.g. ["full", "backward"])
- extra_benchmark_configs: The list of extra benchmark configurations to run the benchmark on.
- overwrite: Whether to overwrite the existing benchmark data entry if it already exists.
"""
assert len(kernel_operation_modes) >= 1
assert len(kernel_providers) >= 1
_print_benchmarking_banner(metric_name=metric_name, kernel_name=kernel_name)
gpu_name = get_gpu_name()
benchmark_data_list = []
for extra_benchmark_config in extra_benchmark_configs:
for kernel_operation_mode in kernel_operation_modes:
for kernel_provider in kernel_providers:
y_values_50 = []
y_values_20 = []
y_values_80 = []
for x in x_values:
print(f'{extra_benchmark_config} - {kernel_operation_mode} - {kernel_provider} - {x}')
single_benchmark_run_input = SingleBenchmarkRunInput(
x=x,
kernel_provider=kernel_provider,
kernel_operation_mode=kernel_operation_mode,
extra_benchmark_config=extra_benchmark_config,
)
benchmark_result: SingleBenchmarkRunOutput = bench_test_fn(
single_benchmark_run_input
)
y_values_50.append(benchmark_result.y_50)
y_values_20.append(benchmark_result.y_20)
y_values_80.append(benchmark_result.y_80)
benchmark_run_data = BenchmarkData(
kernel_name=kernel_name,
kernel_operation_mode=kernel_operation_mode,
kernel_provider=kernel_provider,
metric_name=metric_name,
metric_unit=metric_unit,
gpu_name=gpu_name,
x_name=x_name,
x_label=x_label,
x_values=x_values,
y_values_50=y_values_50,
y_values_20=y_values_20,
y_values_80=y_values_80,
extra_benchmark_config_str=json.dumps(
extra_benchmark_config, cls=CustomEncoder
),
timestamp=get_formatted_time(),
liger_version=LIGER_KERNEL_VERSION,
)
benchmark_data_list.append(benchmark_run_data)
print_benchmark_data(benchmark_data_list)
update_benchmark_data_csv(
benchmark_data_list=benchmark_data_list, overwrite=overwrite
)
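# Minimal sketch of the bench_test_fn contract expected by run_benchmarks (a hypothetical
# constant-valued benchmark, for illustration only; the real benchmark functions are
# defined further below):
#
#   def dummy_bench_fn(run_input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
#       # run_input.x, .kernel_provider, .kernel_operation_mode and
#       # .extra_benchmark_config describe one (x value, provider, mode) combination
#       return SingleBenchmarkRunOutput(y_20=1.0, y_50=1.0, y_80=1.0)
#
#   run_benchmarks(
#       bench_test_fn=dummy_bench_fn,
#       kernel_name="dummy",
#       metric_name="speed",
#       metric_unit="ms",
#       x_name="BT",
#       x_label="B x T",
#       x_values=[1024, 2048],
#       kernel_providers=["liger"],
#       extra_benchmark_configs=[{}],
#   )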
def parse_benchmark_script_args():
parser = argparse.ArgumentParser(description="Benchmarking script for Liger-Kernel")
# Add an optional --overwrite flag
parser.add_argument(
"--overwrite",
action="store_true",
help="Flag to overwrite existing benchmark data with current run.",
)
args = parser.parse_args()
return args
class TorchLMHeadCE(torch.nn.Module):
"""Ground truth implementation of the linear fused with torch based cross entropy loss.
:param H: hidden size
:param V: vocab size
:param ignore_index: index to ignore
:param reduction: reduction method
"""
def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=False, dtype=dtype
)
self.ce_loss = torch.nn.CrossEntropyLoss(
ignore_index=ignore_index, reduction="mean"
)
def forward(self, x, y):
logits = self.lin(x)
return self.ce_loss(logits, y)
class CompiledTorchLMHeadCE(torch.nn.Module):
def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
bias: bool = False,
ignore_index: int = -100,
):
super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=bias, dtype=dtype
)
self.ce_loss = CompiledFusedLinearCrossEntropyLoss()
def forward(self, x, y, softcap_value=None):
return self.ce_loss(self.lin.weight, x, y, self.lin.bias, softcap_value)
class LigerLMHeadCE(torch.nn.Module):
def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=False, dtype=dtype
)
self.ce_loss = LigerFusedLinearCrossEntropyLoss(
ignore_index=ignore_index, reduction="mean"
)
def forward(self, x, y):
return self.ce_loss(self.lin.weight, x, y)
#############################################################################
# Test the memory consumption of the linear fused cross entropy loss
#############################################################################
def bench_memory_fused_linear_cross_entropy(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
BT = input.x
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider
device = "cuda"
torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
compiled_torch_lm_head_ce = CompiledTorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
_input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1)
def fwd():
if provider == "liger":
return liger_lm_head_ce(_input, target)
elif provider == "huggingface":
return torch_lm_head_ce(_input, target)
elif provider == "compile":
return compiled_torch_lm_head_ce(_input, target)
def full():
y = fwd()
y.backward()
mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
#############################################################################
# Test the speed of the fused linear cross entropy loss
#############################################################################
def bench_speed_fused_linear_cross_entropy(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
BT = input.x
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider
mode = input.kernel_operation_mode
device = "cuda"
torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
compiled_torch_lm_head_ce = CompiledTorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
_input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1)
def fwd():
if provider == "liger":
return liger_lm_head_ce(_input, target)
elif provider == "huggingface":
return torch_lm_head_ce(_input, target)
elif provider == "compile":
return compiled_torch_lm_head_ce(_input, target)
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
rep=100,
quantiles=QUANTILES,
)
elif mode == "backward":
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
grad_to_none=[_input],
rep=100,
quantiles=QUANTILES,
)
elif mode == "full":
def full():
y = fwd()
y.backward()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
rep=100,
quantiles=QUANTILES,
)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
# fast (note: this first assignment is immediately overridden by the noisy sampling below)
noisy_points = [2**i for i in range(12, 16)]
# point_sample_num = 10 + 2
noisy_points = [2**12] + (np.random.randint(2**12, 2**15 + 1, 10) + np.random.randint(0, 1001, 10)).tolist() + [2**15]
# noisy_points.sort()
common_configs = {
"kernel_name": "FLCE",
"x_name": "BT",
"x_label": "B x T",
"x_values": noisy_points,
"kernel_providers": ["compile", "liger", "huggingface"],
"extra_benchmark_configs": [
{"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16} # llama3 config
],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_fused_linear_cross_entropy,
kernel_operation_modes=["forward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs
)
run_benchmarks(
bench_test_fn=bench_memory_fused_linear_cross_entropy,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs
)
# usage: after running benchmark_flce.py, run:
# python benchmarks_visualizer.py --kernel-name FLCE --metric-name memory --kernel-operation-mode full
# python benchmarks_visualizer.py --kernel-name FLCE --metric-name speed --kernel-operation-mode full
benchmarks_visualizer.py
import json
import os
from argparse import ArgumentParser
from dataclasses import dataclass
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
DATA_PATH = "all_benchmark_data.csv"
VISUALIZATIONS_PATH = "visualizations/"
@dataclass
class VisualizationsConfig:
"""
Configuration for the visualizations script.
Args:
kernel_name (str): Kernel name whose benchmark data will be visualized (must match the kernel_name column in the CSV, e.g. "FLCE")
metric_name (str): Metric name to visualize (speed/memory)
kernel_operation_mode (str): Kernel operation mode to visualize (forward/backward/full). Defaults to "full"
display (bool): Display the visualization. Defaults to False
overwrite (bool): Overwrite an existing visualization; if none exists, this flag has no effect since one is always created and saved. Defaults to False
"""
kernel_name: str
metric_name: str
kernel_operation_mode: str = "full"
display: bool = False
overwrite: bool = False
def parse_args() -> VisualizationsConfig:
"""Parse command line arguments into a configuration object.
Returns:
VisualizationsConfig: Configuration object for the visualizations script.
"""
parser = ArgumentParser()
parser.add_argument(
"--kernel-name", type=str, required=True, help="Kernel name to benchmark"
)
parser.add_argument(
"--metric-name",
type=str,
required=True,
help="Metric name to visualize (speed/memory)",
)
parser.add_argument(
"--kernel-operation-mode",
type=str,
required=True,
help="Kernel operation mode to visualize (forward/backward/full)",
)
parser.add_argument(
"--display", action="store_true", help="Display the visualization"
)
parser.add_argument(
"--overwrite",
action="store_true",
help="Overwrite existing visualization, if none exist this flag has no effect as one are always created",
)
args = parser.parse_args()
return VisualizationsConfig(**dict(args._get_kwargs()))
def load_data(config: VisualizationsConfig) -> pd.DataFrame:
"""Loads the benchmark data from the CSV file and filters it based on the configuration.
Args:
config (VisualizationsConfig): Configuration object for the visualizations script.
Raises:
ValueError: If no data is found for the given filters.
Returns:
pd.DataFrame: Filtered benchmark dataframe.
"""
df = pd.read_csv(DATA_PATH)
df["extra_benchmark_config"] = df["extra_benchmark_config_str"].apply(json.loads)
filtered_df = df[
(df["kernel_name"] == config.kernel_name)
& (df["metric_name"] == config.metric_name)
& (df["kernel_operation_mode"] == config.kernel_operation_mode)
# Use this to filter by extra benchmark configuration property
# & (data['extra_benchmark_config'].apply(lambda x: x.get('H') == 4096))
# FIXME: maybe add a way to filter using some configuration, except of hardcoding it
]
if filtered_df.empty:
raise ValueError("No data found for the given filters")
return filtered_df
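# One possible way to address the FIXME above without hardcoding the filter (a hypothetical
# helper, not part of the original script): filter on the parsed `extra_benchmark_config`
# column returned by load_data, e.g.
#
#   def filter_by_extra_config(df: pd.DataFrame, key: str, value) -> pd.DataFrame:
#       return df[df["extra_benchmark_config"].apply(lambda cfg: cfg.get(key) == value)]
#
#   # keep only the H=4096 (llama3-sized) runs:
#   # df = filter_by_extra_config(df, "H", 4096)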
def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
"""Plots the benchmark data, saving the result if needed.
Args:
df (pd.DataFrame): Filtered benchmark dataframe.
config (VisualizationsConfig): Configuration object for the visualizations script.
"""
xlabel = df["x_label"].iloc[0]
ylabel = f"{config.metric_name} ({df['metric_unit'].iloc[0]})"
# Sort by "kernel_provider" to ensure consistent color assignment
df = df.sort_values(by="kernel_provider")
plt.figure(figsize=(10, 6))
sns.set(style="whitegrid")
ax = sns.lineplot(
data=df,
x="x_value",
y="y_value_50",
hue="kernel_provider",
marker="o",
palette="tab10",
errorbar=("ci", None),
)
# Seaborn can't plot pre-computed error bars, so we need to do it manually
lines = ax.get_lines()
colors = [line.get_color() for line in lines]
for (_, group_data), color in zip(df.groupby("kernel_provider"), colors):
# for i, row in group_data.iterrows():
y_error_lower = group_data["y_value_50"] - group_data["y_value_20"]
y_error_upper = group_data["y_value_80"] - group_data["y_value_50"]
y_error = [y_error_lower, y_error_upper]
plt.errorbar(
group_data["x_value"],
group_data["y_value_50"],
yerr=y_error,
fmt="o",
color=color,
capsize=5,
)
plt.legend(title="Kernel Provider")
plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.tight_layout()
out_path = os.path.join(
VISUALIZATIONS_PATH, f"{config.kernel_name}_{config.metric_name}.png"
)
if config.display:
plt.show()
if config.overwrite or not os.path.exists(
out_path
): # Save the plot if it doesn't exist or if we want to overwrite it
os.makedirs(VISUALIZATIONS_PATH, exist_ok=True)
plt.savefig(out_path)
plt.close()
def main():
config = parse_args()
df = load_data(config)
plot_data(df, config)
if __name__ == "__main__":
main()
chunked_lce.py
# inspired by https://gist.github.com/Chillee/22cd93e11b887db1f596ab754d60a899
# discussion: https://github.com/linkedin/Liger-Kernel/issues/227
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
def cdiv(x: int, y: int):
return (x + y - 1) // y
def next_power_of_2(n: int):
"""Return the smallest power of 2 greater than or equal to n"""
n -= 1
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n |= n >> 32
n += 1
return n
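# Worked examples for the two helpers above (hand-checked values):
#   cdiv(10, 3)           == 4     # ceil(10 / 3)
#   next_power_of_2(1000) == 1024
#   next_power_of_2(1024) == 1024  # powers of two are returned unchanged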
class ChunkedCE(torch.autograd.Function):
@staticmethod
def forward(ctx, _input, weight, target, bias=None, softcap_value=None, chunk_size=None):
BT, H = _input.shape
if chunk_size is None:
chunk_size = 1024
if chunk_size == 'auto':
V = weight.shape[0]
inc_factor = cdiv(V, H) # (V + H - 1) // H
chunk_size = next_power_of_2(cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor
chunks = cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size
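# Worked example of the 'auto' sizing above with the llama3-like shapes used in the
# benchmark (e.g. BT=8192, H=4096, V=128256):
#   inc_factor = cdiv(128256, 4096)              = 32
#   chunk_size = next_power_of_2(cdiv(8192, 32)) = next_power_of_2(256) = 256
#   chunks     = cdiv(8192, 256)                 = 32
# i.e. the vocab-to-hidden ratio caps how large each materialized logits chunk can get.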
total_n_non_ignore = (target != -100).sum()
def compute_loss(input_chunk, weight, bias, target):
# if bias is not None:
# logits = torch.addmm(bias, input_chunk, weight.t())
# else:
# logits = torch.matmul(input_chunk, weight.t())
# F.linear is more memory efficient than the addmm/matmul path above when bias is set
logits = F.linear(input_chunk, weight, bias)
if softcap_value is not None:
logits = torch.tanh(logits / softcap_value) * softcap_value
logits = logits.float()
loss = F.cross_entropy(logits, target)
return loss
grad_weight = torch.zeros_like(weight)
grad_input = torch.zeros_like(_input)
grad_bias = torch.zeros_like(bias) if bias is not None else None
loss_acc = torch.zeros((), device=_input.device)
@torch.compile(dynamic=True, options={"shape_padding": True})
def accumulate_chunk(input_chunk, target_chunk):
if bias is not None:
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias), chunk_loss = torch.func.grad_and_value(compute_loss, argnums=(0, 1, 2))(input_chunk, weight, bias, target_chunk)
else:
(chunk_grad_input, chunk_grad_weight), chunk_loss = torch.func.grad_and_value(compute_loss, argnums=(0, 1))(input_chunk, weight, None, target_chunk)
chunk_grad_bias = None
n_non_ignore = (target_chunk != -100).sum().item()
grad_weight.add_(chunk_grad_weight * n_non_ignore)
if grad_bias is not None: grad_bias.add_(chunk_grad_bias * n_non_ignore)
loss_acc.add_(chunk_loss * n_non_ignore)
return chunk_grad_input * n_non_ignore
accu_len = 0
for chunk_id in range(chunks):
start_idx = chunk_id * chunk_size
end_idx = min((chunk_id + 1) * chunk_size, BT)
input_chunk = _input[start_idx: end_idx]
target_chunk = target[start_idx: end_idx]
grad_input[accu_len: accu_len + input_chunk.shape[0]] = accumulate_chunk(input_chunk, target_chunk)
accu_len += input_chunk.shape[0]
ctx.save_for_backward(
grad_input / total_n_non_ignore,
grad_weight / total_n_non_ignore,
grad_bias / total_n_non_ignore if grad_bias is not None else None,
)
return loss_acc / total_n_non_ignore
@staticmethod
def backward(ctx, grad_output):
# Gradients were already computed (and normalized) during forward; they are simply
# returned here. This implicitly assumes grad_output == 1, i.e. the loss is the final
# scalar being backpropagated, which is how it is used in this script.
(grad_input, grad_weight, grad_bias) = ctx.saved_tensors
return (grad_input, grad_weight, None, grad_bias, None, None)
class CompiledFusedLinearCrossEntropyLoss(CrossEntropyLoss):
def __init__(self, *args, **kwargs):
super(CompiledFusedLinearCrossEntropyLoss, self).__init__(*args, **kwargs)
def forward(self, lin_weight, _input, target, bias=None, softcap_value=None, chunk_size=None):
return ChunkedCE.apply(
_input, lin_weight, target, bias, softcap_value, chunk_size
)
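# Optional correctness check (a minimal sketch added for illustration, not part of the
# original gist): compares ChunkedCE against a plain, unchunked linear + F.cross_entropy
# reference on small float32 shapes. Call it manually on a CUDA machine to verify the
# chunked loss and gradients; the helper name and shapes are arbitrary.
def _sanity_check_chunked_ce(B=2, T=128, D=256, V=512, chunk_size=32):
    torch.manual_seed(0)
    lin_ref = nn.Linear(D, V, bias=True).cuda()
    lin_chk = nn.Linear(D, V, bias=True).cuda()
    lin_chk.load_state_dict(lin_ref.state_dict())
    x_ref = torch.randn(B * T, D, device="cuda", requires_grad=True)
    x_chk = x_ref.detach().clone().requires_grad_(True)
    y = torch.randint(0, V, (B * T,), device="cuda")
    # unchunked reference
    loss_ref = F.cross_entropy(lin_ref(x_ref), y)
    loss_ref.backward()
    # chunked version under test
    loss_chk = ChunkedCE.apply(x_chk, lin_chk.weight, y, lin_chk.bias, None, chunk_size)
    loss_chk.backward()
    assert torch.allclose(loss_chk, loss_ref, rtol=1e-4, atol=1e-4)
    assert torch.allclose(x_chk.grad, x_ref.grad, rtol=1e-4, atol=1e-4)
    assert torch.allclose(lin_chk.weight.grad, lin_ref.weight.grad, rtol=1e-4, atol=1e-4)
    assert torch.allclose(lin_chk.bias.grad, lin_ref.bias.grad, rtol=1e-4, atol=1e-4)
    print("ChunkedCE matches the unchunked reference on this toy problem.")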
if __name__ == "__main__":
torch.set_default_device('cuda')
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
B, T, D, V = 4, 2048, 4096, 128256
B, T, D, V = 4, 4096, 3584, 256000
model = nn.Linear(D, V, bias=True).to(torch.bfloat16)
x = torch.randn(B, T, D, requires_grad=True, dtype=torch.bfloat16)
label = torch.randint(0, V, (B, T)).to(torch.int64)
def f(m, x, label):
out = F.cross_entropy(m(x).view(-1, V), label.view(-1))
out.backward()
return out
def chunked_f(m, x, label, chunk_size=None):
out = ChunkedCE.apply(x.view(-1, D), m.weight, label.view(-1), m.bias, None, chunk_size)
out.backward()
return out
def ligerf(m, x, label):
out = LigerFusedLinearCrossEntropyFunction.apply(x.view(-1, D), m.weight, label.view(-1), m.bias)
out.backward()
return out
def bench(f, name=None, iters=100, warmup=5, display=True, profile=False, profile_mem=False):
from triton.testing import do_bench
for _ in range(warmup):
f()
if profile_mem:
torch.cuda.memory._record_memory_history()
f()
torch.cuda.memory._dump_snapshot(f"{name if name is not None else 'memory'}.pickle")
if profile:
with torch.profiler.profile() as prof:
f()
prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")
torch.cuda.reset_peak_memory_stats()
ms_per_iter = do_bench(lambda: f())
if name is None:
res = ms_per_iter
else:
res = f"{name}: {ms_per_iter:.3f}ms"
if display:
print(res)
print("Peak mem: ", torch.cuda.max_memory_allocated()/1e9)
print()
return res
opt_f = torch.compile(f)
# bench(lambda: f(model, x, label), name='eager lce (non-chunked)')
# bench(lambda: opt_f(model, x, label), name='compile lce (non-chunked)')
bench(lambda: ligerf(model, x, label), name='liger lce')
bench(lambda: chunked_f(model, x, label, chunk_size=1024), name='compile lce (chunk 1024)')
# bench(lambda: chunked_f(model, x, label, chunk_size='auto'), name='compile lce (chunk auto)')