benchmark_flce.py
import sys
sys.path.append('.')
import argparse
import csv
import json
import os
import time
from typing import Any, Callable, Dict, List, Optional, Union
from collections import OrderedDict
from dataclasses import asdict, dataclass
from importlib.metadata import version
from itertools import zip_longest

import numpy as np
import torch
import triton
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

from chunked_lce import CompiledFusedLinearCrossEntropyLoss

LIGER_KERNEL_VERSION = version("liger-kernel")

QUANTILES = [0.5, 0.2, 0.8]

@dataclass
class SingleBenchmarkRunInput:
    x: Union[int, float]
    kernel_provider: str
    kernel_operation_mode: Optional[str] = ""
    extra_benchmark_config: Optional[Dict[str, Any]] = None


@dataclass
class SingleBenchmarkRunOutput:
    # 20th percentile
    y_20: float
    # 50th percentile (median)
    y_50: float
    # 80th percentile
    y_80: float


@dataclass
class BenchmarkData:
    """
    BenchmarkData is a dataclass to store the benchmark data for a completed benchmark
    run on all x-values for a given kernel/kernel operation mode/metric/extra_benchmark_config
    """

    kernel_name: str
    kernel_provider: str
    metric_name: str
    metric_unit: str
    gpu_name: str
    x_name: str
    x_label: str
    x_values: List[float]
    y_values_50: List[float]
    y_values_20: List[float]
    y_values_80: List[float]
    timestamp: str
    kernel_operation_mode: Optional[str] = None
    extra_benchmark_config_str: Optional[str] = None
    liger_version: str = LIGER_KERNEL_VERSION


@dataclass
class BenchmarkDataCSVRow:
    # The ordering of field names here will be the order of columns in the CSV
    kernel_name: str
    kernel_provider: str
    kernel_operation_mode: Union[str, None]
    metric_name: str
    metric_unit: str
    x_name: str
    x_label: str
    x_value: float
    y_value_50: float
    y_value_20: float
    y_value_80: float
    extra_benchmark_config_str: Union[str, None]
    gpu_name: str
    timestamp: str
    liger_version: str

def _test_memory(
    func: Callable,
    _iter: int = 10,
    quantiles: Optional[List[float]] = None,
    return_mode="mean",
) -> Union[float, List[float]]:
    assert return_mode in ["min", "max", "mean", "median"]
    total_mem = []
    for _ in range(_iter):
        torch.cuda.memory.reset_peak_memory_stats()
        func()
        # Convert to MB
        mem = torch.cuda.max_memory_allocated() / 2**20
        total_mem.append(mem)

    total_mem = torch.tensor(total_mem, dtype=torch.float)
    if quantiles is not None:
        quantiles_data = torch.quantile(
            total_mem, torch.tensor(quantiles, dtype=torch.float)
        ).tolist()
        if len(quantiles_data) == 1:
            quantiles_data = quantiles_data[0]
        return quantiles_data
    return getattr(torch, return_mode)(total_mem).item()


def get_current_file_directory() -> str:
    """
    Returns the directory path of the current Python file.
    """
    # Get the absolute path of the current file
    current_file_path = os.path.abspath(__file__)
    # Get the directory path of the current file
    return os.path.dirname(current_file_path)


def sleep(seconds):
    def decorator(function):
        def wrapper(*args, **kwargs):
            time.sleep(seconds)
            return function(*args, **kwargs)

        return wrapper

    return decorator

def _print_benchmarking_banner(metric_name: str, kernel_name: str):
    print("**************************************")
    print(f" BENCHMARKING {metric_name.upper()} for {kernel_name.upper()}")
    print("**************************************")


def get_formatted_time():
    return time.strftime("%Y-%m-%d %H:%M:%S")


def get_gpu_name():
    """
    Returns the current GPU name, formatted to serve as a directory name
    """
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(torch.cuda.current_device())
        return gpu_name
    else:
        raise Exception("Benchmarks can only be run on GPU.")


def update_benchmark_data_csv(
    benchmark_data_list: List[BenchmarkData],
    filename: str = "all_benchmark_data.csv",
    overwrite: bool = True,
):
    """
    Update the CSV file with the new benchmark data. If the file does not exist, create it.
    If an entry already exists for the benchmark, then overwrite it if `overwrite` is True.
    """

    def create_unique_key(row):
        # This unique key is used to determine if a benchmark run already exists in the CSV.
        # If the key is the same, then the benchmark run already exists and will optionally
        # be overwritten. Otherwise, it is considered a new benchmark run and appended.
        return (
            row["kernel_name"],
            row["kernel_provider"],
            row["kernel_operation_mode"] if row["kernel_operation_mode"] else "",
            row["metric_name"],
            row["x_name"],
            str(row["x_value"]),
            (
                row["extra_benchmark_config_str"]
                if row["extra_benchmark_config_str"]
                else ""
            ),
            row["gpu_name"],
        )

    fieldnames = BenchmarkDataCSVRow.__annotations__.keys()

    # Make filename path relative to current file
    # filename_abs_path = os.path.join('raw_data', filename)
    filename_abs_path = filename
    file_exists = os.path.isfile(filename_abs_path)

    # Read existing data into a list of dicts
    existing_data = []
    if file_exists:
        with open(filename_abs_path, mode="r") as file:
            reader = csv.DictReader(file)
            for row in reader:
                existing_data.append(row)

    existing_data_dict = OrderedDict(
        (create_unique_key(row), row) for row in existing_data
    )

    for benchmark_data in benchmark_data_list:
        benchmark_data_dict = asdict(benchmark_data)
        x_values = benchmark_data_dict.pop("x_values")
        y_values_50 = benchmark_data_dict.pop("y_values_50")
        y_values_20 = benchmark_data_dict.pop("y_values_20")
        y_values_80 = benchmark_data_dict.pop("y_values_80")

        # Need to convert benchmark_data into multiple rows based on x_values and y_values
        for x_value, y_value_50, y_value_20, y_value_80 in zip_longest(
            x_values, y_values_50, y_values_20, y_values_80
        ):
            row = BenchmarkDataCSVRow(
                x_value=x_value,
                y_value_50=y_value_50,
                y_value_20=y_value_20,
                y_value_80=y_value_80,
                **benchmark_data_dict,
            )
            row_dict = asdict(row)
            row_key = create_unique_key(row_dict)
            if row_key in existing_data_dict:
                if overwrite:
                    # If overwriting, update the row
                    existing_data_dict[row_key] = row_dict
                else:
                    # If not overwriting, skip this row
                    pass
            else:
                existing_data_dict[row_key] = row_dict

    with open(filename_abs_path, mode="w", newline="") as file:
        writer = csv.DictWriter(file, fieldnames=fieldnames)
        writer.writeheader()
        for row in existing_data_dict.values():
            writer.writerow(row)

class CustomEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, torch.dtype):
            return str(obj)
        # super().default already binds self; passing it explicitly again would raise
        return super().default(obj)

def print_benchmark_data(benchmark_data_list: List[BenchmarkData]) -> None:
    print("********** Benchmark Data **********")
    formatted_list = [obj.__dict__ for obj in benchmark_data_list]
    print(json.dumps(formatted_list, indent=2))

def run_benchmarks(
    bench_test_fn: Callable,
    kernel_name: str,
    metric_name: str,
    metric_unit: str,
    x_name: str,
    x_label: str,
    x_values: List[Union[float, int]],
    kernel_providers: List[str],
    kernel_operation_modes: Optional[List[str]] = [None],
    extra_benchmark_configs: Optional[List[Dict[str, Any]]] = None,
    overwrite: bool = False,
):
    """
    Run benchmarks given a bench_test_fn that takes in a SingleBenchmarkRunInput as input and
    saves data to the CSV file. A minimal sketch of the bench_test_fn contract is shown in the
    comment block after this function.

    Args:
    - bench_test_fn: The benchmark test function to run. This function should take in a
        SingleBenchmarkRunInput as input and return a SingleBenchmarkRunOutput.
    - kernel_name: The name of the kernel being benchmarked (e.g. "swiglu")
    - metric_name: The name of the metric being benchmarked (e.g. "speed" or "memory")
    - metric_unit: The unit of the metric being benchmarked (e.g. "ms" or "MB")
    - x_name: The name of the x-axis (e.g. "T" for sequence length)
    - x_label: The label of the x-axis (e.g. "sequence length")
    - x_values: The list of x-values to run the benchmark on (e.g. [2**i for i in range(10, 14)])
    - kernel_providers: The list of kernel providers to run the benchmark on (e.g. ["liger", "huggingface"])
    - kernel_operation_modes: The list of kernel operation modes to run the benchmark on (e.g. ["full", "backward"])
    - extra_benchmark_configs: The list of extra benchmark configurations to run the benchmark on.
    - overwrite: Whether to overwrite the existing benchmark data entry if it already exists.
    """
    assert len(kernel_operation_modes) >= 1
    assert len(kernel_providers) >= 1

    _print_benchmarking_banner(metric_name=metric_name, kernel_name=kernel_name)

    gpu_name = get_gpu_name()
    benchmark_data_list = []
    for extra_benchmark_config in extra_benchmark_configs:
        for kernel_operation_mode in kernel_operation_modes:
            for kernel_provider in kernel_providers:
                y_values_50 = []
                y_values_20 = []
                y_values_80 = []

                for x in x_values:
                    print(f'{extra_benchmark_config} - {kernel_operation_mode} - {kernel_provider} - {x}')
                    single_benchmark_run_input = SingleBenchmarkRunInput(
                        x=x,
                        kernel_provider=kernel_provider,
                        kernel_operation_mode=kernel_operation_mode,
                        extra_benchmark_config=extra_benchmark_config,
                    )
                    benchmark_result: SingleBenchmarkRunOutput = bench_test_fn(
                        single_benchmark_run_input
                    )
                    y_values_50.append(benchmark_result.y_50)
                    y_values_20.append(benchmark_result.y_20)
                    y_values_80.append(benchmark_result.y_80)

                benchmark_run_data = BenchmarkData(
                    kernel_name=kernel_name,
                    kernel_operation_mode=kernel_operation_mode,
                    kernel_provider=kernel_provider,
                    metric_name=metric_name,
                    metric_unit=metric_unit,
                    gpu_name=gpu_name,
                    x_name=x_name,
                    x_label=x_label,
                    x_values=x_values,
                    y_values_50=y_values_50,
                    y_values_20=y_values_20,
                    y_values_80=y_values_80,
                    extra_benchmark_config_str=json.dumps(
                        extra_benchmark_config, cls=CustomEncoder
                    ),
                    timestamp=get_formatted_time(),
                    liger_version=LIGER_KERNEL_VERSION,
                )
                benchmark_data_list.append(benchmark_run_data)

    print_benchmark_data(benchmark_data_list)

    update_benchmark_data_csv(
        benchmark_data_list=benchmark_data_list, overwrite=overwrite
    )

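# A minimal sketch of the bench_test_fn contract expected by run_benchmarks above.
# Illustrative only: `bench_dummy` and `some_kernel_fn` are hypothetical names, not part of
# this file; the real implementations are bench_speed_fused_linear_cross_entropy and
# bench_memory_fused_linear_cross_entropy below.
#
#     def bench_dummy(run_input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
#         ms_50, ms_20, ms_80 = triton.testing.do_bench(
#             lambda: some_kernel_fn(run_input.x), quantiles=QUANTILES
#         )
#         return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)
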
def parse_benchmark_script_args():
    parser = argparse.ArgumentParser(description="Benchmarking script for Liger-Kernel")

    # Add an optional --overwrite flag
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Flag to overwrite existing benchmark data with current run.",
    )

    args = parser.parse_args()
    return args


class TorchLMHeadCE(torch.nn.Module):
    """Ground truth implementation of the linear fused with torch based cross entropy loss.

    :param H: hidden size
    :param V: vocab size
    :param ignore_index: index to ignore
    :param reduction: reduction method
    """

    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=False, dtype=dtype
        )
        self.ce_loss = torch.nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction="mean"
        )

    def forward(self, x, y):
        logits = self.lin(x)
        return self.ce_loss(logits, y)


class CompiledTorchLMHeadCE(torch.nn.Module):
    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        bias: bool = False,
        ignore_index: int = -100,
    ):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=bias, dtype=dtype
        )
        self.ce_loss = CompiledFusedLinearCrossEntropyLoss()

    def forward(self, x, y, softcap_value=None):
        return self.ce_loss(self.lin.weight, x, y, self.lin.bias, softcap_value)


class LigerLMHeadCE(torch.nn.Module):
    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=False, dtype=dtype
        )
        self.ce_loss = LigerFusedLinearCrossEntropyLoss(
            ignore_index=ignore_index, reduction="mean"
        )

    def forward(self, x, y):
        return self.ce_loss(self.lin.weight, x, y)

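# Provider strings used by the benchmark functions below and the modules they dispatch to:
#   - "huggingface": TorchLMHeadCE, an eager nn.Linear + nn.CrossEntropyLoss baseline
#   - "compile":     CompiledTorchLMHeadCE, the chunked torch.compile loss from chunked_lce.py
#   - "liger":       LigerLMHeadCE, the Triton LigerFusedLinearCrossEntropyLoss kernel
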
#############################################################################
# Test the memory consumption of the linear fused cross entropy loss
#############################################################################


def bench_memory_fused_linear_cross_entropy(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    BT = input.x
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    provider = input.kernel_provider

    device = "cuda"
    torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
    compiled_torch_lm_head_ce = CompiledTorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
    liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)

    _input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device)
    target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1)

    def fwd():
        if provider == "liger":
            return liger_lm_head_ce(_input, target)
        elif provider == "huggingface":
            return torch_lm_head_ce(_input, target)
        elif provider == "compile":
            return compiled_torch_lm_head_ce(_input, target)

    def full():
        y = fwd()
        y.backward()

    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


# #############################################################################
# # Test the speed of the fused linear cross entropy loss
# #############################################################################


def bench_speed_fused_linear_cross_entropy(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    BT = input.x
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    device = "cuda"
    torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
    compiled_torch_lm_head_ce = CompiledTorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
    liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)

    _input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device)
    target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1)

    def fwd():
        if provider == "liger":
            return liger_lm_head_ce(_input, target)
        elif provider == "huggingface":
            return torch_lm_head_ce(_input, target)
        elif provider == "compile":
            return compiled_torch_lm_head_ce(_input, target)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            grad_to_none=[_input],
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = fwd()
            y.backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            rep=100,
            quantiles=QUANTILES,
        )

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )

if __name__ == "__main__":
    args = parse_benchmark_script_args()

    # fast
    noisy_points = [2**i for i in range(12, 16)]
    # point_sample_num = 10 + 2
    noisy_points = [2**12] + (np.random.randint(2**12, 2**15 + 1, 10) + np.random.randint(0, 1001, 10)).tolist() + [2**15]
    # noisy_points.sort()

    common_configs = {
        "kernel_name": "FLCE",
        "x_name": "BT",
        "x_label": "B x T",
        "x_values": noisy_points,
        "kernel_providers": ["compile", "liger", "huggingface"],
        "extra_benchmark_configs": [
            {"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}  # llama3 config
        ],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_fused_linear_cross_entropy,
        kernel_operation_modes=["forward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs
    )
    run_benchmarks(
        bench_test_fn=bench_memory_fused_linear_cross_entropy,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs
    )
benchmarks_visualizer.py
# usage: after running benchmark_flce.py, run:
# python benchmarks_visualizer.py --kernel-name FLCE --metric-name memory --kernel-operation-mode full
# python benchmarks_visualizer.py --kernel-name FLCE --metric-name speed --kernel-operation-mode full
import json
import os
from argparse import ArgumentParser
from dataclasses import dataclass

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

DATA_PATH = "all_benchmark_data.csv"
VISUALIZATIONS_PATH = "visualizations/"


@dataclass
class VisualizationsConfig:
    """
    Configuration for the visualizations script.

    Args:
        kernel_name (str): Kernel name to visualize (matches the `kernel_name` column in the CSV)
        metric_name (str): Metric name to visualize (speed/memory)
        kernel_operation_mode (str): Kernel operation mode to visualize (forward/backward/full). Defaults to "full"
        display (bool): Display the visualization. Defaults to False
        overwrite (bool): Overwrite an existing visualization; if none exists, this flag has no effect
            because one is always created and saved. Defaults to False
    """

    kernel_name: str
    metric_name: str
    kernel_operation_mode: str = "full"
    display: bool = False
    overwrite: bool = False


def parse_args() -> VisualizationsConfig:
    """Parse command line arguments into a configuration object.

    Returns:
        VisualizationsConfig: Configuration object for the visualizations script.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--kernel-name", type=str, required=True, help="Kernel name to visualize"
    )
    parser.add_argument(
        "--metric-name",
        type=str,
        required=True,
        help="Metric name to visualize (speed/memory)",
    )
    parser.add_argument(
        "--kernel-operation-mode",
        type=str,
        required=True,
        help="Kernel operation mode to visualize (forward/backward/full)",
    )
    parser.add_argument(
        "--display", action="store_true", help="Display the visualization"
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite an existing visualization; if none exists, this flag has no effect because one is always created",
    )

    args = parser.parse_args()

    return VisualizationsConfig(**dict(args._get_kwargs()))

def load_data(config: VisualizationsConfig) -> pd.DataFrame:
    """Loads the benchmark data from the CSV file and filters it based on the configuration.

    Args:
        config (VisualizationsConfig): Configuration object for the visualizations script.

    Raises:
        ValueError: If no data is found for the given filters.

    Returns:
        pd.DataFrame: Filtered benchmark dataframe.
    """
    df = pd.read_csv(DATA_PATH)
    df["extra_benchmark_config"] = df["extra_benchmark_config_str"].apply(json.loads)

    filtered_df = df[
        (df["kernel_name"] == config.kernel_name)
        & (df["metric_name"] == config.metric_name)
        & (df["kernel_operation_mode"] == config.kernel_operation_mode)
        # Use this to filter by an extra benchmark configuration property, e.g.:
        # & (df["extra_benchmark_config"].apply(lambda x: x.get("H") == 4096))
        # FIXME: maybe add a way to filter using some configuration, instead of hardcoding it
    ]

    if filtered_df.empty:
        raise ValueError("No data found for the given filters")

    return filtered_df

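# One possible (untested) sketch for the FIXME above: accept key=value pairs on the CLI and
# apply them to the already-filtered dataframe. `config_filters` is a hypothetical list of
# (key, value) string pairs, not an existing attribute of VisualizationsConfig:
#
#     for key, value in config_filters:
#         filtered_df = filtered_df[
#             filtered_df["extra_benchmark_config"].apply(lambda c: str(c.get(key)) == value)
#         ]
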
def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
    """Plots the benchmark data, saving the result if needed.

    Args:
        df (pd.DataFrame): Filtered benchmark dataframe.
        config (VisualizationsConfig): Configuration object for the visualizations script.
    """
    xlabel = df["x_label"].iloc[0]
    ylabel = f"{config.metric_name} ({df['metric_unit'].iloc[0]})"
    # Sort by "kernel_provider" to ensure consistent color assignment
    df = df.sort_values(by="kernel_provider")

    plt.figure(figsize=(10, 6))
    sns.set(style="whitegrid")
    ax = sns.lineplot(
        data=df,
        x="x_value",
        y="y_value_50",
        hue="kernel_provider",
        marker="o",
        palette="tab10",
        errorbar=("ci", None),
    )

    # Seaborn can't plot pre-computed error bars, so we need to do it manually
    lines = ax.get_lines()
    colors = [line.get_color() for line in lines]

    for (_, group_data), color in zip(df.groupby("kernel_provider"), colors):
        # for i, row in group_data.iterrows():
        y_error_lower = group_data["y_value_50"] - group_data["y_value_20"]
        y_error_upper = group_data["y_value_80"] - group_data["y_value_50"]
        y_error = [y_error_lower, y_error_upper]

        plt.errorbar(
            group_data["x_value"],
            group_data["y_value_50"],
            yerr=y_error,
            fmt="o",
            color=color,
            capsize=5,
        )

    plt.legend(title="Kernel Provider")
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.tight_layout()

    out_path = os.path.join(
        VISUALIZATIONS_PATH, f"{config.kernel_name}_{config.metric_name}.png"
    )

    if config.display:
        plt.show()
    if config.overwrite or not os.path.exists(
        out_path
    ):  # Save the plot if it doesn't exist or if we want to overwrite it
        os.makedirs(VISUALIZATIONS_PATH, exist_ok=True)
        plt.savefig(out_path)
    plt.close()


def main():
    config = parse_args()
    df = load_data(config)
    plot_data(df, config)


if __name__ == "__main__":
    main()
chunked_lce.py
# inspired by https://gist.github.com/Chillee/22cd93e11b887db1f596ab754d60a899
# discussion: https://github.com/linkedin/Liger-Kernel/issues/227
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss


def cdiv(x: int, y: int):
    return (x + y - 1) // y


def next_power_of_2(n: int):
    """Return the smallest power of 2 greater than or equal to n"""
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n |= n >> 32
    n += 1
    return n

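# Example of the bit-twiddling above (illustrative values): next_power_of_2(300) starts from
# n = 299 (0b100101011); the successive OR-shifts fill in every lower bit, giving 511
# (0b111111111), and adding 1 yields 512. The initial n -= 1 keeps exact powers of two fixed,
# e.g. next_power_of_2(256) == 256 rather than 512.
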
class ChunkedCE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, _input, weight, target, bias=None, softcap_value=None, chunk_size=None):
        BT, H = _input.shape
        if chunk_size is None:
            chunk_size = 1024
        if chunk_size == 'auto':
            V = weight.shape[0]
            inc_factor = cdiv(V, H)  # (V + H - 1) // H
            chunk_size = next_power_of_2(cdiv(BT, inc_factor))  # (BT + inc_factor - 1) // inc_factor
        chunks = cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size
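        # Worked example of the 'auto' sizing, assuming the llama3-style shapes from the
        # __main__ block below (B=4, T=2048, D=H=4096, V=128256, so BT=8192):
        #   inc_factor = cdiv(128256, 4096) = 32
        #   chunk_size = next_power_of_2(cdiv(8192, 32)) = next_power_of_2(256) = 256
        #   chunks     = cdiv(8192, 256) = 32
        # i.e. the logits are materialized 256 rows at a time instead of all 8192 at once.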
        total_n_non_ignore = (target != -100).sum()

        def compute_loss(input_chunk, weight, bias, target):
            # if bias is not None:
            #     logits = torch.addmm(bias, input_chunk, weight.t())
            # else:
            #     logits = torch.matmul(input_chunk, weight.t())
            # more memory efficient when bias is set
            logits = F.linear(input_chunk, weight, bias)
            if softcap_value is not None:
                logits = torch.tanh(logits / softcap_value) * softcap_value
            logits = logits.float()
            loss = F.cross_entropy(logits, target)
            return loss

        grad_weight = torch.zeros_like(weight)
        grad_input = torch.zeros_like(_input)
        grad_bias = torch.zeros_like(bias) if bias is not None else None
        loss_acc = torch.zeros((), device=_input.device)
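        # Each call to accumulate_chunk below computes the *mean* loss (and grads) over one
        # chunk via torch.func.grad_and_value, then re-weights it by that chunk's number of
        # non-ignored targets. Summing the weighted terms and dividing by total_n_non_ignore
        # at the end recovers the global mean:
        #   global_mean = sum_c(mean_c * n_c) / sum_c(n_c)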
        @torch.compile(dynamic=True, options={"shape_padding": True})
        def accumulate_chunk(input_chunk, target_chunk):
            if bias is not None:
                (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), chunk_loss = torch.func.grad_and_value(
                    compute_loss, argnums=(0, 1, 2)
                )(input_chunk, weight, bias, target_chunk)
            else:
                (chunk_grad_input, chunk_grad_weight), chunk_loss = torch.func.grad_and_value(
                    compute_loss, argnums=(0, 1)
                )(input_chunk, weight, None, target_chunk)
                chunk_grad_bias = None
            n_non_ignore = (target_chunk != -100).sum().item()
            grad_weight.add_(chunk_grad_weight * n_non_ignore)
            if grad_bias is not None:
                grad_bias.add_(chunk_grad_bias * n_non_ignore)
            loss_acc.add_(chunk_loss * n_non_ignore)
            return chunk_grad_input * n_non_ignore

        accu_len = 0
        for chunk_id in range(chunks):
            start_idx = chunk_id * chunk_size
            end_idx = min((chunk_id + 1) * chunk_size, BT)
            input_chunk = _input[start_idx:end_idx]
            target_chunk = target[start_idx:end_idx]
            grad_input[accu_len:accu_len + input_chunk.shape[0]] = accumulate_chunk(input_chunk, target_chunk)
            accu_len += input_chunk.shape[0]

        ctx.save_for_backward(
            grad_input / total_n_non_ignore,
            grad_weight / total_n_non_ignore,
            grad_bias / total_n_non_ignore if grad_bias is not None else None,
        )
        return loss_acc / total_n_non_ignore

    @staticmethod
    def backward(ctx, grad_output):
        (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
        return (grad_input, grad_weight, None, grad_bias, None, None)

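# Thin wrapper around ChunkedCE with an nn.Module-style interface. Note that it subclasses
# CrossEntropyLoss for familiarity only: any reduction/ignore_index/label_smoothing arguments
# passed to the constructor are not forwarded to ChunkedCE, which hard-codes mean reduction
# and ignore_index=-100.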
class CompiledFusedLinearCrossEntropyLoss(CrossEntropyLoss):
    def __init__(self, *args, **kwargs):
        super(CompiledFusedLinearCrossEntropyLoss, self).__init__(*args, **kwargs)

    def forward(self, lin_weight, _input, target, bias=None, softcap_value=None, chunk_size=None):
        return ChunkedCE.apply(
            _input, lin_weight, target, bias, softcap_value, chunk_size
        )

if __name__ == "__main__":
    torch.set_default_device('cuda')
    from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction

    B, T, D, V = 4, 2048, 4096, 128256
    B, T, D, V = 4, 4096, 3584, 256000  # overrides the line above

    model = nn.Linear(D, V, bias=True).to(torch.bfloat16)
    x = torch.randn(B, T, D, requires_grad=True, dtype=torch.bfloat16)
    label = torch.randint(0, V, (B, T)).to(torch.int64)

    def f(m, x, label):
        out = F.cross_entropy(m(x).view(-1, V), label.view(-1))
        out.backward()
        return out

    def chunked_f(m, x, label, chunk_size=None):
        out = ChunkedCE.apply(x.view(-1, D), m.weight, label.view(-1), m.bias, None, chunk_size)
        out.backward()
        return out

    def ligerf(m, x, label):
        out = LigerFusedLinearCrossEntropyFunction.apply(x.view(-1, D), m.weight, label.view(-1), m.bias)
        out.backward()
        return out

    def bench(f, name=None, iters=100, warmup=5, display=True, profile=False, profile_mem=False):
        from triton.testing import do_bench

        for _ in range(warmup):
            f()

        if profile_mem:
            torch.cuda.memory._record_memory_history()
            f()
            torch.cuda.memory._dump_snapshot(f"{name if name is not None else 'memory'}.pickle")
        if profile:
            with torch.profiler.profile() as prof:
                f()
            prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")

        torch.cuda.reset_peak_memory_stats()
        ms_per_iter = do_bench(lambda: f())
        if name is None:
            res = ms_per_iter
        else:
            res = f"{name}: {ms_per_iter:.3f}ms"
        if display:
            print(res)
            print("Peak mem: ", torch.cuda.max_memory_allocated() / 1e9)
            print()
        return res

    opt_f = torch.compile(f)
    # bench(lambda: f(model, x, label), name='eager lce (non-chunked)')
    # bench(lambda: opt_f(model, x, label), name='compile lce (non-chunked)')
    bench(lambda: ligerf(model, x, label), name='liger lce')
    bench(lambda: chunked_f(model, x, label, chunk_size=1024), name='compile lce (chunk 1024)')
    # bench(lambda: chunked_f(model, x, label, chunk_size='auto'), name='compile lce (chunk auto)')