liger code
Created December 15, 2024 09:47
├── benchmark | |
├── __init__.py | |
├── benchmarks_visualizer.py | |
└── scripts | |
│ ├── __init__.py | |
│ ├── benchmark_cpo_loss.py | |
│ ├── benchmark_cross_entropy.py | |
│ ├── benchmark_dpo_loss.py | |
│ ├── benchmark_embedding.py | |
│ ├── benchmark_fused_linear_cross_entropy.py | |
│ ├── benchmark_fused_linear_jsd.py | |
│ ├── benchmark_geglu.py | |
│ ├── benchmark_group_norm.py | |
│ ├── benchmark_jsd.py | |
│ ├── benchmark_kl_div.py | |
│ ├── benchmark_layer_norm.py | |
│ ├── benchmark_orpo_loss.py | |
│ ├── benchmark_qwen2vl_mrope.py | |
│ ├── benchmark_rms_norm.py | |
│ ├── benchmark_rope.py | |
│ ├── benchmark_simpo_loss.py | |
│ ├── benchmark_swiglu.py | |
│ └── utils.py | |
├── dev | |
└── modal | |
│ ├── tests.py | |
│ └── tests_bwd.py | |
├── examples | |
├── alignment | |
│ └── run_orpo.py | |
├── huggingface | |
│ ├── callback.py | |
│ ├── launch_on_modal.py | |
│ ├── training.py | |
│ └── training_multimodal.py | |
├── lightning | |
│ └── training.py | |
└── medusa | |
│ ├── callback.py | |
│ ├── medusa_util.py | |
│ └── train.py | |
├── setup.py | |
├── src | |
└── liger_kernel | |
│ ├── __init__.py | |
│ ├── chunked_loss | |
│ ├── __init__.py | |
│ ├── cpo_loss.py | |
│ ├── dpo_loss.py | |
│ ├── functional.py | |
│ ├── fused_linear_distillation.py | |
│ ├── fused_linear_preference.py | |
│ ├── orpo_loss.py | |
│ └── simpo_loss.py | |
│ ├── env_report.py | |
│ ├── ops | |
│ ├── __init__.py | |
│ ├── cross_entropy.py | |
│ ├── experimental | |
│ │ ├── embedding.py | |
│ │ └── mm_int8int2.py | |
│ ├── fused_linear_cross_entropy.py | |
│ ├── fused_linear_jsd.py | |
│ ├── geglu.py | |
│ ├── group_norm.py | |
│ ├── jsd.py | |
│ ├── kl_div.py | |
│ ├── layer_norm.py | |
│ ├── qwen2vl_mrope.py | |
│ ├── rms_norm.py | |
│ ├── rope.py | |
│ ├── swiglu.py | |
│ └── utils.py | |
│ ├── transformers | |
│ ├── __init__.py | |
│ ├── auto_model.py | |
│ ├── cross_entropy.py | |
│ ├── experimental | |
│ │ └── embedding.py | |
│ ├── functional.py | |
│ ├── fused_linear_cross_entropy.py | |
│ ├── fused_linear_jsd.py | |
│ ├── geglu.py | |
│ ├── group_norm.py | |
│ ├── jsd.py | |
│ ├── kl_div.py | |
│ ├── layer_norm.py | |
│ ├── model | |
│ │ ├── __init__.py | |
│ │ ├── gemma.py | |
│ │ ├── gemma2.py | |
│ │ ├── llama.py | |
│ │ ├── mistral.py | |
│ │ ├── mixtral.py | |
│ │ ├── mllama.py | |
│ │ ├── phi3.py | |
│ │ ├── qwen2.py | |
│ │ └── qwen2_vl.py | |
│ ├── monkey_patch.py | |
│ ├── qwen2vl_mrope.py | |
│ ├── rms_norm.py | |
│ ├── rope.py | |
│ ├── swiglu.py | |
│ ├── trainer | |
│ │ ├── __init__.py | |
│ │ └── orpo_trainer.py | |
│ └── trainer_integration.py | |
│ ├── triton | |
│ ├── __init__.py | |
│ └── monkey_patch.py | |
│ └── utils.py | |
└── test | |
├── __init__.py | |
├── chunked_loss | |
├── __init__.py | |
├── test_cpo_loss.py | |
├── test_dpo_loss.py | |
├── test_orpo_loss.py | |
└── test_simpo_loss.py | |
├── conftest.py | |
├── convergence | |
├── __init__.py | |
├── test_mini_models.py | |
├── test_mini_models_multimodal.py | |
└── test_mini_models_with_logits.py | |
├── resources | |
└── scripts | |
│ └── generate_tokenized_dataset.py | |
├── transformers | |
├── test_auto_model.py | |
├── test_cross_entropy.py | |
├── test_embedding.py | |
├── test_fused_linear_cross_entropy.py | |
├── test_fused_linear_jsd.py | |
├── test_geglu.py | |
├── test_group_norm.py | |
├── test_jsd.py | |
├── test_kl_div.py | |
├── test_layer_norm.py | |
├── test_mm_int8int2.py | |
├── test_monkey_patch.py | |
├── test_qwen2vl_mrope.py | |
├── test_rms_norm.py | |
├── test_rope.py | |
├── test_swiglu.py | |
├── test_trainer_integration.py | |
└── test_transformers.py | |
├── triton | |
└── test_triton_monkey_patch.py | |
└── utils.py | |
/benchmark/__init__.py: | |
-------------------------------------------------------------------------------- | |
https://raw.githubusercontent.com/linkedin/Liger-Kernel/0bb6c72d03f0fc570b9e954c42b723ab7e0c315f/benchmark/__init__.py | 
-------------------------------------------------------------------------------- | |
/benchmark/benchmarks_visualizer.py: | |
-------------------------------------------------------------------------------- | |
1 | import json | |
2 | import os | |
3 | from argparse import ArgumentParser | |
4 | from dataclasses import dataclass | |
5 | | |
6 | import matplotlib.pyplot as plt | |
7 | import pandas as pd | |
8 | import seaborn as sns | |
9 | | |
10 | DATA_PATH = "data/all_benchmark_data.csv" | |
11 | VISUALIZATIONS_PATH = "visualizations/" | |
12 | | |
13 | | |
14 | @dataclass | |
15 | class VisualizationsConfig: | |
16 | """ | |
17 | Configuration for the visualizations script. | |
18 | | |
19 | Args: | |
20 | kernel_name (str): Kernel name to visualize (as produced by `scripts/benchmark_{kernel_name}.py`). | 
21 | metric_name (str): Metric name to visualize (speed/memory) | |
22 | kernel_operation_mode (str): Kernel operation mode to visualize (forward/backward/full). Defaults to "full" | |
23 | display (bool): Display the visualization. Defaults to False | |
24 | overwrite (bool): Overwrite the existing visualization. If none exists, this flag has no effect, since one is always created and saved. Defaults to False | 
25 | | |
26 | """ | |
27 | | |
28 | kernel_name: str | |
29 | metric_name: str | |
30 | kernel_operation_mode: str = "full" | |
31 | display: bool = False | |
32 | overwrite: bool = False | |
33 | | |
34 | | |
35 | def parse_args() -> VisualizationsConfig: | |
36 | """Parse command line arguments into a configuration object. | |
37 | | |
38 | Returns: | |
39 | VisualizationsConfig: Configuration object for the visualizations script. | |
40 | """ | |
41 | parser = ArgumentParser() | |
42 | parser.add_argument( | |
43 | "--kernel-name", type=str, required=True, help="Kernel name to benchmark" | |
44 | ) | |
45 | parser.add_argument( | |
46 | "--metric-name", | |
47 | type=str, | |
48 | required=True, | |
49 | help="Metric name to visualize (speed/memory)", | |
50 | ) | |
51 | parser.add_argument( | |
52 | "--kernel-operation-mode", | |
53 | type=str, | |
54 | required=True, | |
55 | help="Kernel operation mode to visualize (forward/backward/full)", | |
56 | ) | |
57 | parser.add_argument( | |
58 | "--display", action="store_true", help="Display the visualization" | |
59 | ) | |
60 | parser.add_argument( | |
61 | "--overwrite", | |
62 | action="store_true", | |
63 | help="Overwrite existing visualization, if none exist this flag has no effect as one are always created", | |
64 | ) | |
65 | | |
66 | args = parser.parse_args() | |
67 | | |
68 | return VisualizationsConfig(**dict(args._get_kwargs())) | |
69 | | |
70 | | |
71 | def load_data(config: VisualizationsConfig) -> pd.DataFrame: | |
72 | """Loads the benchmark data from the CSV file and filters it based on the configuration. | |
73 | | |
74 | Args: | |
75 | config (VisualizationsConfig): Configuration object for the visualizations script. | |
76 | | |
77 | Raises: | |
78 | ValueError: If no data is found for the given filters. | |
79 | | |
80 | Returns: | |
81 | pd.DataFrame: Filtered benchmark dataframe. | |
82 | """ | |
83 | df = pd.read_csv(DATA_PATH) | |
84 | df["extra_benchmark_config"] = df["extra_benchmark_config_str"].apply(json.loads) | |
85 | | |
86 | filtered_df = df[ | |
87 | (df["kernel_name"] == config.kernel_name) | |
88 | & (df["metric_name"] == config.metric_name) | |
89 | & (df["kernel_operation_mode"] == config.kernel_operation_mode) | |
90 | # Use this to filter by extra benchmark configuration property | |
91 | # & (data['extra_benchmark_config'].apply(lambda x: x.get('H') == 4096)) | |
92 | # FIXME: maybe add a way to filter using some configuration, instead of hardcoding it | 
93 | ] | |
94 | | |
95 | if filtered_df.empty: | |
96 | raise ValueError("No data found for the given filters") | |
97 | | |
98 | return filtered_df | |
99 | | |
100 | | |
101 | def plot_data(df: pd.DataFrame, config: VisualizationsConfig): | |
102 | """Plots the benchmark data, saving the result if needed. | |
103 | | |
104 | Args: | |
105 | df (pd.DataFrame): Filtered benchmark dataframe. | |
106 | config (VisualizationsConfig): Configuration object for the visualizations script. | |
107 | """ | |
108 | xlabel = df["x_label"].iloc[0] | |
109 | ylabel = f"{config.metric_name} ({df['metric_unit'].iloc[0]})" | |
110 | # Sort by "kernel_provider" to ensure consistent color assignment | |
111 | df = df.sort_values(by="kernel_provider") | |
112 | | |
113 | plt.figure(figsize=(10, 6)) | |
114 | sns.set(style="whitegrid") | |
115 | ax = sns.lineplot( | |
116 | data=df, | |
117 | x="x_value", | |
118 | y="y_value_50", | |
119 | hue="kernel_provider", | |
120 | marker="o", | |
121 | palette="tab10", | |
122 | errorbar=("ci", None), | |
123 | ) | |
124 | | |
125 | # Seaborn can't plot pre-computed error bars, so we need to do it manually | |
126 | lines = ax.get_lines() | |
127 | colors = [line.get_color() for line in lines] | |
128 | | |
129 | for (_, group_data), color in zip(df.groupby("kernel_provider"), colors): | |
130 | # for i, row in group_data.iterrows(): | |
131 | y_error_lower = group_data["y_value_50"] - group_data["y_value_20"] | |
132 | y_error_upper = group_data["y_value_80"] - group_data["y_value_50"] | |
133 | y_error = [y_error_lower, y_error_upper] | |
134 | | |
135 | plt.errorbar( | |
136 | group_data["x_value"], | |
137 | group_data["y_value_50"], | |
138 | yerr=y_error, | |
139 | fmt="o", | |
140 | color=color, | |
141 | capsize=5, | |
142 | ) | |
143 | plt.legend(title="Kernel Provider") | |
144 | plt.xlabel(xlabel) | |
145 | plt.ylabel(ylabel) | |
146 | plt.tight_layout() | |
147 | | |
148 | out_path = os.path.join( | |
149 | VISUALIZATIONS_PATH, f"{config.kernel_name}_{config.metric_name}.png" | |
150 | ) | |
151 | | |
152 | if config.display: | |
153 | plt.show() | |
154 | if config.overwrite or not os.path.exists( | |
155 | out_path | |
156 | ): # Save the plot if it doesn't exist or if we want to overwrite it | |
157 | os.makedirs(VISUALIZATIONS_PATH, exist_ok=True) | |
158 | plt.savefig(out_path) | |
159 | plt.close() | |
160 | | |
161 | | |
162 | def main(): | |
163 | config = parse_args() | |
164 | df = load_data(config) | |
165 | plot_data(df, config) | |
166 | | |
167 | | |
168 | if __name__ == "__main__": | |
169 | main() | |
170 | | |
-------------------------------------------------------------------------------- | |
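A minimal usage sketch (not part of the repo): the visualizer above can be driven directly from Python by constructing VisualizationsConfig and reusing load_data/plot_data. This assumes the working directory is benchmark/ and that data/all_benchmark_data.csv already holds results written by one of the benchmark scripts below; the config values are illustrative.

# Hypothetical sketch reusing the helpers defined in benchmarks_visualizer.py.
from benchmarks_visualizer import VisualizationsConfig, load_data, plot_data

config = VisualizationsConfig(
    kernel_name="cross_entropy",    # matches "kernel_name" in benchmark_cross_entropy.py
    metric_name="speed",            # "speed" or "memory"
    kernel_operation_mode="full",   # "forward", "backward", or "full"
    display=False,
    overwrite=True,
)
df = load_data(config)    # raises ValueError if no rows match the filters
plot_data(df, config)     # writes visualizations/cross_entropy_speed.png

The same result comes from the CLI: python benchmarks_visualizer.py --kernel-name cross_entropy --metric-name speed --kernel-operation-mode full --overwrite
-------------------------------------------------------------------------------- | 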
/benchmark/scripts/__init__.py: | |
-------------------------------------------------------------------------------- | |
https://raw.githubusercontent.com/linkedin/Liger-Kernel/0bb6c72d03f0fc570b9e954c42b723ab7e0c315f/benchmark/scripts/__init__.py | 
-------------------------------------------------------------------------------- | |
/benchmark/scripts/benchmark_cpo_loss.py: | |
-------------------------------------------------------------------------------- | |
1 | import os | |
2 | import sys | |
3 | | |
4 | import torch | |
5 | import triton | |
6 | from utils import ( | |
7 | QUANTILES, | |
8 | SingleBenchmarkRunInput, | |
9 | SingleBenchmarkRunOutput, | |
10 | _test_memory, | |
11 | parse_benchmark_script_args, | |
12 | run_benchmarks, | |
13 | ) | |
14 | | |
15 | from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction | |
16 | from liger_kernel.utils import infer_device | |
17 | | |
18 | device = infer_device() | |
19 | | |
20 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) | |
21 | | |
22 | | |
23 | class TorchLMHeadCPO(torch.nn.Module): | |
24 | """Ground truth implementation of the linear fused with torch based cross entropy loss. | |
25 | | |
26 | :param H: hidden size | |
27 | :param V: vocab size | |
28 | :param ignore_index: index to ignore | |
29 | :param reduction: reduction method | |
30 | """ | |
31 | | |
32 | def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): | |
33 | from test.chunked_loss.test_cpo_loss import HFCPOLoss | |
34 | | |
35 | super().__init__() | |
36 | self.lin = torch.nn.Linear( | |
37 | in_features=H, out_features=V, bias=False, dtype=dtype | |
38 | ) | |
39 | self.cpo_loss = HFCPOLoss().get_batch_loss_metrics | |
40 | | |
41 | def forward(self, x, y): | |
42 | return self.cpo_loss(x, self.lin.weight, y) | |
43 | | |
44 | | |
45 | class LigerLMHeadCPO(torch.nn.Module): | |
46 | def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): | |
47 | super().__init__() | |
48 | self.lin = torch.nn.Linear( | |
49 | in_features=H, out_features=V, bias=False, dtype=dtype | |
50 | ) | |
51 | self.cpo_loss = LigerFusedLinearCPOFunction.apply | |
52 | | |
53 | def forward(self, x, y): | |
54 | return self.cpo_loss(x, self.lin.weight, y) | |
55 | | |
56 | | |
57 | ############################################################################# | |
58 | # Test the memory consumption of the fused linear CPO loss | 
59 | ############################################################################# | |
60 | | |
61 | | |
62 | def bench_memory_fused_linear_cpo_loss( | |
63 | input: SingleBenchmarkRunInput, | |
64 | ) -> SingleBenchmarkRunOutput: | |
65 | B = input.x | |
66 | T = input.extra_benchmark_config["T"] | |
67 | H = input.extra_benchmark_config["H"] | |
68 | V = input.extra_benchmark_config["V"] | |
69 | dtype = input.extra_benchmark_config["dtype"] | |
70 | provider = input.kernel_provider | |
71 | | |
72 | torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device) | |
73 | liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device) | |
74 | | |
75 | _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) | |
76 | target = torch.randint(V, (B, T), dtype=torch.long, device=device) | |
77 | | |
78 | def fwd(): | |
79 | if provider == "liger": | |
80 | return liger_lm_head_cpo(_input, target) | |
81 | elif provider == "huggingface": | |
82 | return torch_lm_head_cpo(_input, target) | |
83 | | |
84 | def full(): | |
85 | y = fwd() | |
86 | y.backward() | |
87 | | |
88 | mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) | |
89 | return SingleBenchmarkRunOutput( | |
90 | y_20=mem_20, | |
91 | y_50=mem_50, | |
92 | y_80=mem_80, | |
93 | ) | |
94 | | |
95 | | |
96 | # ############################################################################# | |
97 | # # Test the speed of the fused linear CPO loss | 
98 | # ############################################################################# | |
99 | | |
100 | | |
101 | def bench_speed_fused_linear_cpo_loss( | |
102 | input: SingleBenchmarkRunInput, | |
103 | ) -> SingleBenchmarkRunOutput: | |
104 | B = input.x | |
105 | T = input.extra_benchmark_config["T"] | |
106 | H = input.extra_benchmark_config["H"] | |
107 | V = input.extra_benchmark_config["V"] | |
108 | dtype = input.extra_benchmark_config["dtype"] | |
109 | provider = input.kernel_provider | |
110 | mode = input.kernel_operation_mode | |
111 | | |
112 | torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device) | |
113 | liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device) | |
114 | | |
115 | _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) | |
116 | target = torch.randint(V, (B, T), dtype=torch.long, device=device) | |
117 | | |
118 | def fwd(): | |
119 | if provider == "liger": | |
120 | return liger_lm_head_cpo(_input, target) | |
121 | elif provider == "huggingface": | |
122 | return torch_lm_head_cpo(_input, target) | |
123 | | |
124 | if mode == "forward": | |
125 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
126 | fwd, | |
127 | rep=100, | |
128 | quantiles=QUANTILES, | |
129 | ) | |
130 | elif mode == "backward": | |
131 | y = fwd() | |
132 | | |
133 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
134 | lambda: y.backward(retain_graph=True), | |
135 | grad_to_none=[_input], | |
136 | rep=100, | |
137 | quantiles=QUANTILES, | |
138 | ) | |
139 | elif mode == "full": | |
140 | | |
141 | def full(): | |
142 | y = fwd() | |
143 | y.backward() | |
144 | | |
145 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
146 | full, | |
147 | rep=100, | |
148 | quantiles=QUANTILES, | |
149 | ) | |
150 | return SingleBenchmarkRunOutput( | |
151 | y_20=ms_20, | |
152 | y_50=ms_50, | |
153 | y_80=ms_80, | |
154 | ) | |
155 | | |
156 | | |
157 | if __name__ == "__main__": | |
158 | args = parse_benchmark_script_args() | |
159 | | |
160 | common_configs = { | |
161 | "kernel_name": "fused_linear_cpo_loss", | |
162 | "x_name": "B", | |
163 | "x_label": "B", | |
164 | "x_values": [2**i for i in range(1, 5)], | |
165 | "kernel_providers": ["liger", "huggingface"], | |
166 | "extra_benchmark_configs": [ | |
167 | { | |
168 | "T": 1024, | |
169 | "H": 4096, | |
170 | "V": 128256, | |
171 | "mode": "forward", | |
172 | "dtype": torch.bfloat16, | |
173 | } | |
174 | ], | |
175 | "overwrite": args.overwrite, | |
176 | } | |
177 | | |
178 | run_benchmarks( | |
179 | bench_test_fn=bench_speed_fused_linear_cpo_loss, | |
180 | kernel_operation_modes=["forward", "full"], | |
181 | metric_name="speed", | |
182 | metric_unit="ms", | |
183 | **common_configs | |
184 | ) | |
185 | run_benchmarks( | |
186 | bench_test_fn=bench_memory_fused_linear_cpo_loss, | |
187 | kernel_operation_modes=["full"], | |
188 | metric_name="memory", | |
189 | metric_unit="MB", | |
190 | **common_configs | |
191 | ) | |
192 | | |
-------------------------------------------------------------------------------- | |
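A rough sketch of the Liger path exercised above; the three-argument apply call simply mirrors LigerLMHeadCPO.forward, the shapes are illustrative, and a Triton-capable GPU device is assumed.

# Sketch mirroring LigerLMHeadCPO and the benchmark's full() path above.
import torch
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
from liger_kernel.utils import infer_device

device = infer_device()
B, T, H, V = 2, 128, 1024, 32000                       # small illustrative sizes
x = torch.randn(B, T, H, dtype=torch.bfloat16, device=device, requires_grad=True)
weight = torch.randn(V, H, dtype=torch.bfloat16, device=device, requires_grad=True)  # LM-head weight, shape (V, H)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)

loss = LigerFusedLinearCPOFunction.apply(x, weight, target)
loss.backward()                                        # gradients flow to x and weight
-------------------------------------------------------------------------------- | 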
/benchmark/scripts/benchmark_cross_entropy.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch | |
2 | import triton | |
3 | from torch.nn import CrossEntropyLoss | |
4 | from utils import ( | |
5 | QUANTILES, | |
6 | SingleBenchmarkRunInput, | |
7 | SingleBenchmarkRunOutput, | |
8 | _test_memory, | |
9 | parse_benchmark_script_args, | |
10 | run_benchmarks, | |
11 | ) | |
12 | | |
13 | from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss | |
14 | from liger_kernel.utils import infer_device | |
15 | | |
16 | device = infer_device() | |
17 | | |
18 | | |
19 | def bench_memory_cross_entropy( | |
20 | input: SingleBenchmarkRunInput, | |
21 | ) -> SingleBenchmarkRunOutput: | |
22 | torch_ce = CrossEntropyLoss() | |
23 | liger_ce = LigerCrossEntropyLoss() | |
24 | | |
25 | V = input.x | |
26 | provider = input.kernel_provider | |
27 | B = input.extra_benchmark_config["B"] | |
28 | T = input.extra_benchmark_config["T"] | |
29 | | |
30 | _input = torch.randn(B * T, V, requires_grad=True, device=device) | |
31 | target = torch.randint(V, (B * T, 1), device=device).squeeze(1) | |
32 | | |
33 | def fwd(): | |
34 | if provider == "liger": | |
35 | return liger_ce(_input, target) | |
36 | else: | |
37 | return torch_ce(_input, target) | |
38 | | |
39 | def full(): | |
40 | y = fwd() | |
41 | y.backward() | |
42 | | |
43 | mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) | |
44 | return SingleBenchmarkRunOutput( | |
45 | y_20=mem_20, | |
46 | y_50=mem_50, | |
47 | y_80=mem_80, | |
48 | ) | |
49 | | |
50 | | |
51 | def bench_speed_cross_entropy( | |
52 | input: SingleBenchmarkRunInput, | |
53 | ) -> SingleBenchmarkRunOutput: | |
54 | torch_ce = CrossEntropyLoss() | |
55 | liger_ce = LigerCrossEntropyLoss() | |
56 | | |
57 | V = input.x | |
58 | provider = input.kernel_provider | |
59 | mode = input.kernel_operation_mode | |
60 | B = input.extra_benchmark_config["B"] | |
61 | T = input.extra_benchmark_config["T"] | |
62 | | |
63 | _input = torch.randn(B * T, V, requires_grad=True, device=device) | |
64 | target = torch.randint(V, (B * T, 1), device=device).squeeze(1) | |
65 | | |
66 | def fwd(): | |
67 | if provider == "liger": | |
68 | return liger_ce(_input, target) | |
69 | else: | |
70 | return torch_ce(_input, target) | |
71 | | |
72 | if mode == "forward": | |
73 | ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES) | |
74 | elif mode == "backward": | |
75 | y = fwd() | |
76 | | |
77 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
78 | lambda: y.backward(retain_graph=True), | |
79 | grad_to_none=[_input], | |
80 | rep=100, | |
81 | quantiles=QUANTILES, | |
82 | ) | |
83 | elif mode == "full": | |
84 | | |
85 | def full(): | |
86 | y = fwd() | |
87 | y.backward() | |
88 | | |
89 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
90 | full, rep=100, quantiles=QUANTILES | |
91 | ) | |
92 | | |
93 | return SingleBenchmarkRunOutput( | |
94 | y_20=ms_20, | |
95 | y_50=ms_50, | |
96 | y_80=ms_80, | |
97 | ) | |
98 | | |
99 | | |
100 | if __name__ == "__main__": | |
101 | args = parse_benchmark_script_args() | |
102 | | |
103 | common_configs = { | |
104 | "kernel_name": "cross_entropy", | |
105 | "x_name": "V", | |
106 | "x_label": "vocab size", | |
107 | "x_values": [2**i for i in range(12, 18)], | |
108 | "kernel_providers": ["liger", "huggingface"], | |
109 | "extra_benchmark_configs": [{"B": 8, "T": 2048}], | |
110 | "overwrite": args.overwrite, | |
111 | } | |
112 | | |
113 | run_benchmarks( | |
114 | bench_test_fn=bench_speed_cross_entropy, | |
115 | kernel_operation_modes=["forward", "full"], | |
116 | metric_name="speed", | |
117 | metric_unit="ms", | |
118 | **common_configs | |
119 | ) | |
120 | run_benchmarks( | |
121 | bench_test_fn=bench_memory_cross_entropy, | |
122 | kernel_operation_modes=["full"], | |
123 | metric_name="memory", | |
124 | metric_unit="MB", | |
125 | **common_configs | |
126 | ) | |
127 | | |
-------------------------------------------------------------------------------- | |
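For reference, a minimal drop-in sketch of the Liger path benchmarked above; sizes are illustrative and a Triton-capable device is assumed.

# Sketch, not from the repo: LigerCrossEntropyLoss used like torch.nn.CrossEntropyLoss.
import torch
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.utils import infer_device

device = infer_device()
N, V = 16, 32000                                       # (batch*seq, vocab size)
logits = torch.randn(N, V, device=device, requires_grad=True)
target = torch.randint(V, (N,), device=device)

loss = LigerCrossEntropyLoss()(logits, target)
loss.backward()
print(loss.item(), logits.grad.shape)                  # scalar loss, grad of shape (N, V)
-------------------------------------------------------------------------------- | 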
/benchmark/scripts/benchmark_dpo_loss.py: | |
-------------------------------------------------------------------------------- | |
1 | from test.chunked_loss.test_dpo_loss import HF_DPO_Loss | |
2 | | |
3 | import torch | |
4 | import triton | |
5 | from utils import ( | |
6 | QUANTILES, | |
7 | SingleBenchmarkRunInput, | |
8 | SingleBenchmarkRunOutput, | |
9 | _test_memory, | |
10 | parse_benchmark_script_args, | |
11 | run_benchmarks, | |
12 | ) | |
13 | | |
14 | from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction | |
15 | from liger_kernel.utils import infer_device | |
16 | | |
17 | device = infer_device() | |
18 | | |
19 | | |
20 | class TorchDPOLoss(torch.nn.Module): | |
21 | def __init__( | |
22 | self, | |
23 | H: int, | |
24 | V: int, | |
25 | dtype: torch.dtype, | |
26 | beta: float = 0.1, | |
27 | ignore_index: int = -100, | |
28 | bias: bool = False, | |
29 | ): | |
30 | super().__init__() | |
31 | self.lin = torch.nn.Linear( | |
32 | in_features=H, out_features=V, bias=bias, dtype=dtype | |
33 | ) | |
34 | self.dpo_loss = HF_DPO_Loss(beta=beta, ignore_index=ignore_index) | |
35 | | |
36 | def forward(self, x, target): | |
37 | return self.dpo_loss.get_batch_loss_metrics( | |
38 | x, | |
39 | self.lin.weight, | |
40 | target, | |
41 | self.lin.bias if hasattr(self.lin, "bias") else None, | |
42 | ) | |
43 | | |
44 | | |
45 | class LigerDPOLoss(torch.nn.Module): | |
46 | def __init__( | |
47 | self, | |
48 | H: int, | |
49 | V: int, | |
50 | dtype: torch.dtype, | |
51 | beta: float = 0.1, | |
52 | ignore_index: int = -100, | |
53 | bias: bool = False, | |
54 | ): | |
55 | super().__init__() | |
56 | self.lin = torch.nn.Linear( | |
57 | in_features=H, out_features=V, bias=bias, dtype=dtype | |
58 | ) | |
59 | self.beta = beta | |
60 | self.ignore_index = ignore_index | |
61 | | |
62 | def forward(self, x, target): | |
63 | return LigerFusedLinearDPOFunction.apply( | |
64 | x, | |
65 | self.lin.weight, | |
66 | target, | |
67 | self.lin.bias if hasattr(self.lin, "bias") else None, | |
68 | self.ignore_index, | |
69 | self.beta, | |
70 | True, | |
71 | ) | |
72 | | |
73 | | |
74 | def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: | |
75 | B = input.x | |
76 | T = input.extra_benchmark_config["T"] | |
77 | H = input.extra_benchmark_config["H"] | |
78 | V = input.extra_benchmark_config["V"] | |
79 | dtype = input.extra_benchmark_config["dtype"] | |
80 | bias = input.extra_benchmark_config["bias"] | |
81 | beta = input.extra_benchmark_config["beta"] | |
82 | ignore_index = input.extra_benchmark_config["ignore_index"] | |
83 | provider = input.kernel_provider | |
84 | | |
85 | torch_dpo_loss = TorchDPOLoss( | |
86 | H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias | |
87 | ).to(device) | |
88 | liger_dpo_loss = LigerDPOLoss( | |
89 | H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias | |
90 | ).to(device) | |
91 | | |
92 | # Input shape: [B, T, H] | |
93 | _input = torch.randn(B, T, H, device=device, dtype=dtype) | |
94 | # Target shape: [B, T] | |
95 | target = torch.randint(V, (B, T), dtype=torch.long, device=device) | |
96 | | |
97 | # Add ignore_index tokens to simulate padding | |
98 | num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() | |
99 | indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] | |
100 | target.view(-1)[indices_to_assign] = ignore_index | |
101 | | |
102 | def fwd(): | |
103 | if provider == "liger": | |
104 | return liger_dpo_loss(_input, target) | |
105 | elif provider == "huggingface": | |
106 | return torch_dpo_loss(_input, target) | |
107 | | |
108 | def full(): | |
109 | y = fwd() | |
110 | y.backward() | |
111 | | |
112 | mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) | |
113 | return SingleBenchmarkRunOutput( | |
114 | y_20=mem_20, | |
115 | y_50=mem_50, | |
116 | y_80=mem_80, | |
117 | ) | |
118 | | |
119 | | |
120 | def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: | |
121 | B = input.x | |
122 | T = input.extra_benchmark_config["T"] | |
123 | H = input.extra_benchmark_config["H"] | |
124 | V = input.extra_benchmark_config["V"] | |
125 | dtype = input.extra_benchmark_config["dtype"] | |
126 | bias = input.extra_benchmark_config["bias"] | |
127 | beta = input.extra_benchmark_config["beta"] | |
128 | ignore_index = input.extra_benchmark_config["ignore_index"] | |
129 | provider = input.kernel_provider | |
130 | mode = input.kernel_operation_mode | |
131 | | |
132 | torch_dpo_loss = TorchDPOLoss( | |
133 | H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias | |
134 | ).to(device) | |
135 | liger_dpo_loss = LigerDPOLoss( | |
136 | H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias | |
137 | ).to(device) | |
138 | | |
139 | # Input shape: [B, T, H] | |
140 | _input = torch.randn(B, T, H, device=device, dtype=dtype) | |
141 | | |
142 | # Target shape: [B, T] | |
143 | target = torch.randint(V, (B, T), device=device, dtype=torch.long) | |
144 | | |
145 | # Add ignore_index tokens | |
146 | num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() | |
147 | indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] | |
148 | target.view(-1)[indices_to_assign] = ignore_index | |
149 | | |
150 | def fwd(): | |
151 | if provider == "liger": | |
152 | return liger_dpo_loss(_input, target) | |
153 | elif provider == "huggingface": | |
154 | return torch_dpo_loss(_input, target) | |
155 | | |
156 | if mode == "forward": | |
157 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
158 | fwd, | |
159 | rep=100, | |
160 | quantiles=QUANTILES, | |
161 | ) | |
162 | elif mode == "backward": | |
163 | y = fwd() | |
164 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
165 | lambda: y.backward(retain_graph=True), | |
166 | grad_to_none=[_input], | |
167 | rep=100, | |
168 | quantiles=QUANTILES, | |
169 | ) | |
170 | elif mode == "full": | |
171 | | |
172 | def full(): | |
173 | y = fwd() | |
174 | y.backward() | |
175 | | |
176 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
177 | full, | |
178 | rep=100, | |
179 | quantiles=QUANTILES, | |
180 | ) | |
181 | | |
182 | return SingleBenchmarkRunOutput( | |
183 | y_20=ms_20, | |
184 | y_50=ms_50, | |
185 | y_80=ms_80, | |
186 | ) | |
187 | | |
188 | | |
189 | if __name__ == "__main__": | |
190 | args = parse_benchmark_script_args() | |
191 | | |
192 | common_configs = { | |
193 | "kernel_name": "dpo_loss", | |
194 | "x_name": "B", | |
195 | "x_label": "Batch Size (B)", | |
196 | "x_values": [2**i for i in range(1, 6)], | |
197 | "kernel_providers": ["liger", "huggingface"], | |
198 | "extra_benchmark_configs": [ | |
199 | { | |
200 | "T": 512, | |
201 | "H": 1024, | |
202 | "V": 128256, | |
203 | "mode": "forward", | |
204 | "dtype": torch.bfloat16, | |
205 | "bias": True, | |
206 | "beta": 0.1, | |
207 | "ignore_index": 42, | |
208 | } | |
209 | ], | |
210 | "overwrite": args.overwrite, | |
211 | } | |
212 | | |
213 | run_benchmarks( | |
214 | bench_test_fn=bench_speed_dpo_loss, | |
215 | kernel_operation_modes=["forward", "full"], | |
216 | metric_name="speed", | |
217 | metric_unit="ms", | |
218 | **common_configs | |
219 | ) | |
220 | | |
221 | run_benchmarks( | |
222 | bench_test_fn=bench_memory_dpo_loss, | |
223 | kernel_operation_modes=["full"], | |
224 | metric_name="memory", | |
225 | metric_unit="MB", | |
226 | **common_configs | |
227 | ) | |
228 | | |
-------------------------------------------------------------------------------- | |
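A sketch of the functional call mirrored from LigerDPOLoss.forward above; the argument order, including the trailing True flag, is taken verbatim from the benchmark, shapes are illustrative, and a GPU device is assumed.

# Sketch mirroring LigerDPOLoss and the benchmark's full() path above.
import torch
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
from liger_kernel.utils import infer_device

device = infer_device()
B, T, H, V = 2, 128, 1024, 32000
ignore_index, beta = -100, 0.1
x = torch.randn(B, T, H, dtype=torch.bfloat16, device=device, requires_grad=True)
weight = torch.randn(V, H, dtype=torch.bfloat16, device=device, requires_grad=True)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)

loss = LigerFusedLinearDPOFunction.apply(
    x, weight, target, None, ignore_index, beta, True  # bias=None, then the flags as in the benchmark
)
loss.backward()
-------------------------------------------------------------------------------- | 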
/benchmark/scripts/benchmark_embedding.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch | |
2 | import triton | |
3 | from torch.nn import Embedding | |
4 | from utils import ( | |
5 | QUANTILES, | |
6 | SingleBenchmarkRunInput, | |
7 | SingleBenchmarkRunOutput, | |
8 | _test_memory, | |
9 | parse_benchmark_script_args, | |
10 | run_benchmarks, | |
11 | ) | |
12 | | |
13 | from liger_kernel.transformers.experimental.embedding import LigerEmbedding | |
14 | from liger_kernel.utils import infer_device | |
15 | | |
16 | device = infer_device() | |
17 | | |
18 | # NOTE: For torch compile, we will just use default inductor settings. No further customization | |
19 | # is needed. | |
20 | | |
21 | | |
22 | def bench_speed_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: | |
23 | V = input.x | |
24 | provider = input.kernel_provider | |
25 | mode = input.kernel_operation_mode | |
26 | | |
27 | B = input.extra_benchmark_config["B"] | |
28 | T = input.extra_benchmark_config["T"] | |
29 | D = input.extra_benchmark_config["D"] | |
30 | dtype = input.extra_benchmark_config["dtype"] | |
31 | | |
32 | torch_emb = Embedding(V, D).to(device).to(dtype) | |
33 | liger_emb = LigerEmbedding(V, D).to(device).to(dtype) | |
34 | torch_compile_emb = torch.compile(torch_emb) | |
35 | | |
36 | input_ids = torch.randint(0, V, (B, T), device=device) | |
37 | | |
38 | def fwd(): | |
39 | if provider == "liger": | |
40 | return liger_emb(input_ids) | |
41 | elif provider == "torch_compile": | |
42 | return torch_compile_emb(input_ids) | |
43 | else: | |
44 | return torch_emb(input_ids) | |
45 | | |
46 | def full(): | |
47 | output = fwd() | |
48 | output.backward(torch.randn_like(output)) | |
49 | | |
50 | if mode == "forward": | |
51 | ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100) | |
52 | elif mode == "full": | |
53 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
54 | full, quantiles=QUANTILES, rep=100 | |
55 | ) | |
56 | return SingleBenchmarkRunOutput( | |
57 | y_20=ms_20, | |
58 | y_50=ms_50, | |
59 | y_80=ms_80, | |
60 | ) | |
61 | | |
62 | | |
63 | def bench_memory_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: | |
64 | V = input.x | |
65 | provider = input.kernel_provider | |
66 | | |
67 | B = input.extra_benchmark_config["B"] | |
68 | T = input.extra_benchmark_config["T"] | |
69 | D = input.extra_benchmark_config["D"] | |
70 | dtype = input.extra_benchmark_config["dtype"] | |
71 | | |
72 | torch_emb = Embedding(V, D).to(device).to(dtype) | |
73 | liger_emb = LigerEmbedding(V, D).to(device).to(dtype) | |
74 | torch_compile_emb = torch.compile(torch_emb) | |
75 | | |
76 | input_ids = torch.randint(0, V, (B, T), device=device) | |
77 | | |
78 | def fwd(): | |
79 | if provider == "liger": | |
80 | return liger_emb(input_ids) | |
81 | elif provider == "torch_compile": | |
82 | return torch_compile_emb(input_ids) | |
83 | else: | |
84 | return torch_emb(input_ids) | |
85 | | |
86 | def full(): | |
87 | output = fwd() | |
88 | output.backward(torch.randn_like(output)) | |
89 | | |
90 | mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) | |
91 | return SingleBenchmarkRunOutput( | |
92 | y_20=mem_20, | |
93 | y_50=mem_50, | |
94 | y_80=mem_80, | |
95 | ) | |
96 | | |
97 | | |
98 | if __name__ == "__main__": | |
99 | args = parse_benchmark_script_args() | |
100 | | |
101 | common_configs = { | |
102 | "kernel_name": "embedding", | |
103 | "x_name": "V", | |
104 | "x_label": "embedding dimension", | |
105 | "x_values": [2**i for i in range(10, 18)], | |
106 | "kernel_providers": ["liger", "huggingface", "torch_compile"], | |
107 | "extra_benchmark_configs": [ | |
108 | # BERT | |
109 | {"B": 32, "T": 512, "D": 768, "dtype": torch.float32}, | |
110 | # Llama | |
111 | {"B": 8, "T": 2048, "D": 4096, "dtype": torch.float32}, | |
112 | ], | |
113 | "overwrite": args.overwrite, | |
114 | } | |
115 | | |
116 | run_benchmarks( | |
117 | bench_test_fn=bench_speed_embedding, | |
118 | kernel_operation_modes=["forward", "full"], | |
119 | metric_name="speed", | |
120 | metric_unit="ms", | |
121 | **common_configs | |
122 | ) | |
123 | run_benchmarks( | |
124 | bench_test_fn=bench_memory_embedding, | |
125 | kernel_operation_modes=["full"], | |
126 | metric_name="memory", | |
127 | metric_unit="MB", | |
128 | **common_configs | |
129 | ) | |
130 | | |
-------------------------------------------------------------------------------- | |
/benchmark/scripts/benchmark_fused_linear_cross_entropy.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch | |
2 | import triton | |
3 | from utils import ( | |
4 | QUANTILES, | |
5 | SingleBenchmarkRunInput, | |
6 | SingleBenchmarkRunOutput, | |
7 | _test_memory, | |
8 | parse_benchmark_script_args, | |
9 | run_benchmarks, | |
10 | ) | |
11 | | |
12 | from liger_kernel.transformers.fused_linear_cross_entropy import ( | |
13 | LigerFusedLinearCrossEntropyLoss, | |
14 | ) | |
15 | from liger_kernel.utils import infer_device | |
16 | | |
17 | device = infer_device() | |
18 | | |
19 | | |
20 | class TorchLMHeadCE(torch.nn.Module): | |
21 | """Ground truth implementation of the linear fused with torch based cross entropy loss. | |
22 | | |
23 | :param H: hidden size | |
24 | :param V: vocab size | |
25 | :param ignore_index: index to ignore | |
26 | :param reduction: reduction method | |
27 | """ | |
28 | | |
29 | def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): | |
30 | super().__init__() | |
31 | self.lin = torch.nn.Linear( | |
32 | in_features=H, out_features=V, bias=False, dtype=dtype | |
33 | ) | |
34 | self.ce_loss = torch.nn.CrossEntropyLoss( | |
35 | ignore_index=ignore_index, reduction="mean" | |
36 | ) | |
37 | | |
38 | def forward(self, x, y): | |
39 | logits = self.lin(x) | |
40 | return self.ce_loss(logits, y) | |
41 | | |
42 | | |
43 | class LigerLMHeadCE(torch.nn.Module): | |
44 | def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): | |
45 | super().__init__() | |
46 | self.lin = torch.nn.Linear( | |
47 | in_features=H, out_features=V, bias=False, dtype=dtype | |
48 | ) | |
49 | self.ce_loss = LigerFusedLinearCrossEntropyLoss( | |
50 | ignore_index=ignore_index, reduction="mean" | |
51 | ) | |
52 | | |
53 | def forward(self, x, y): | |
54 | return self.ce_loss(self.lin.weight, x, y) | |
55 | | |
56 | | |
57 | ############################################################################# | |
58 | # Test the memory consumption of the linear fused cross entropy loss | |
59 | ############################################################################# | |
60 | | |
61 | | |
62 | def bench_memory_fused_linear_cross_entropy( | |
63 | input: SingleBenchmarkRunInput, | |
64 | ) -> SingleBenchmarkRunOutput: | |
65 | BT = input.x | |
66 | H = input.extra_benchmark_config["H"] | |
67 | V = input.extra_benchmark_config["V"] | |
68 | dtype = input.extra_benchmark_config["dtype"] | |
69 | provider = input.kernel_provider | |
70 | | |
71 | torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device) | |
72 | liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device) | |
73 | | |
74 | _input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device) | |
75 | target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1) | |
76 | | |
77 | def fwd(): | |
78 | if provider == "liger": | |
79 | return liger_lm_head_ce(_input, target) | |
80 | elif provider == "huggingface": | |
81 | return torch_lm_head_ce(_input, target) | |
82 | | |
83 | def full(): | |
84 | y = fwd() | |
85 | y.backward() | |
86 | | |
87 | mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) | |
88 | return SingleBenchmarkRunOutput( | |
89 | y_20=mem_20, | |
90 | y_50=mem_50, | |
91 | y_80=mem_80, | |
92 | ) | |
93 | | |
94 | | |
95 | # ############################################################################# | |
96 | # # Test the speed of the fused linear cross entropy loss | |
97 | # ############################################################################# | |
98 | | |
99 | | |
100 | def bench_speed_fused_linear_cross_entropy( | |
101 | input: SingleBenchmarkRunInput, | |
102 | ) -> SingleBenchmarkRunOutput: | |
103 | BT = input.x | |
104 | H = input.extra_benchmark_config["H"] | |
105 | V = input.extra_benchmark_config["V"] | |
106 | dtype = input.extra_benchmark_config["dtype"] | |
107 | provider = input.kernel_provider | |
108 | mode = input.kernel_operation_mode | |
109 | | |
110 | torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device) | |
111 | liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device) | |
112 | | |
113 | _input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device) | |
114 | target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1) | |
115 | | |
116 | def fwd(): | |
117 | if provider == "liger": | |
118 | return liger_lm_head_ce(_input, target) | |
119 | elif provider == "huggingface": | |
120 | return torch_lm_head_ce(_input, target) | |
121 | | |
122 | if mode == "forward": | |
123 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
124 | fwd, | |
125 | rep=100, | |
126 | quantiles=QUANTILES, | |
127 | ) | |
128 | elif mode == "backward": | |
129 | y = fwd() | |
130 | | |
131 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
132 | lambda: y.backward(retain_graph=True), | |
133 | grad_to_none=[_input], | |
134 | rep=100, | |
135 | quantiles=QUANTILES, | |
136 | ) | |
137 | elif mode == "full": | |
138 | | |
139 | def full(): | |
140 | y = fwd() | |
141 | y.backward() | |
142 | | |
143 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
144 | full, | |
145 | rep=100, | |
146 | quantiles=QUANTILES, | |
147 | ) | |
148 | return SingleBenchmarkRunOutput( | |
149 | y_20=ms_20, | |
150 | y_50=ms_50, | |
151 | y_80=ms_80, | |
152 | ) | |
153 | | |
154 | | |
155 | if __name__ == "__main__": | |
156 | args = parse_benchmark_script_args() | |
157 | | |
158 | common_configs = { | |
159 | "kernel_name": "fused_linear_cross_entropy", | |
160 | "x_name": "BT", | |
161 | "x_label": "B x T", | |
162 | "x_values": [2**i for i in range(12, 16)], | |
163 | "kernel_providers": ["liger", "huggingface"], | |
164 | "extra_benchmark_configs": [ | |
165 | {"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16} | |
166 | ], | |
167 | "overwrite": args.overwrite, | |
168 | } | |
169 | | |
170 | run_benchmarks( | |
171 | bench_test_fn=bench_speed_fused_linear_cross_entropy, | |
172 | kernel_operation_modes=["forward", "full"], | |
173 | metric_name="speed", | |
174 | metric_unit="ms", | |
175 | **common_configs | |
176 | ) | |
177 | run_benchmarks( | |
178 | bench_test_fn=bench_memory_fused_linear_cross_entropy, | |
179 | kernel_operation_modes=["full"], | |
180 | metric_name="memory", | |
181 | metric_unit="MB", | |
182 | **common_configs | |
183 | ) | |
184 | | |
-------------------------------------------------------------------------------- | |
/benchmark/scripts/benchmark_geglu.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch | |
2 | import triton | |
3 | from transformers.models.llama.configuration_llama import LlamaConfig | |
4 | from transformers.models.llama.modeling_llama import LlamaMLP | |
5 | from utils import ( | |
6 | QUANTILES, | |
7 | SingleBenchmarkRunInput, | |
8 | SingleBenchmarkRunOutput, | |
9 | _test_memory, | |
10 | parse_benchmark_script_args, | |
11 | run_benchmarks, | |
12 | ) | |
13 | | |
14 | from liger_kernel.transformers.geglu import LigerGEGLUMLP | |
15 | from liger_kernel.utils import infer_device | |
16 | | |
17 | device = infer_device() | |
18 | | |
19 | | |
20 | def bench_speed_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: | |
21 | seq_len = input.x | |
22 | bsz = input.extra_benchmark_config["bsz"] | |
23 | hidden_size = input.extra_benchmark_config["hidden_size"] | |
24 | intermediate_size = input.extra_benchmark_config["intermediate_size"] | |
25 | hidden_act = input.extra_benchmark_config["hidden_act"] | |
26 | dtype = input.extra_benchmark_config["dtype"] | |
27 | provider = input.kernel_provider | |
28 | mode = input.kernel_operation_mode | |
29 | | |
30 | llama_config = LlamaConfig( | |
31 | hidden_size=hidden_size, | |
32 | intermediate_size=intermediate_size, | |
33 | hidden_act=hidden_act, | |
34 | ) | |
35 | | |
36 | x_shape = (bsz, seq_len, hidden_size) | |
37 | | |
38 | # initialize input | |
39 | x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True) | |
40 | | |
41 | if provider == "liger": | |
42 | layer = LigerGEGLUMLP(config=llama_config).to(device).to(dtype) | |
43 | elif provider == "huggingface": | |
44 | layer = LlamaMLP(config=llama_config).to(device).to(dtype) | |
45 | else: | |
46 | raise ValueError(f"Invalid provider: {provider} for GEGLU") | |
47 | | |
48 | def fwd(): | |
49 | return layer(x) | |
50 | | |
51 | if mode == "forward": | |
52 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
53 | fwd, | |
54 | grad_to_none=[x], | |
55 | rep=10, | |
56 | quantiles=QUANTILES, | |
57 | ) | |
58 | elif mode == "backward": | |
59 | do = torch.randn_like(x) | |
60 | y = fwd() | |
61 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
62 | lambda: y.backward(do, retain_graph=True), | |
63 | grad_to_none=[x], | |
64 | rep=10, | |
65 | quantiles=QUANTILES, | |
66 | ) | |
67 | else: | |
68 | | |
69 | def full(): | |
70 | y = fwd() | |
71 | y.backward(torch.randn_like(y), retain_graph=True) | |
72 | | |
73 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
74 | full, | |
75 | grad_to_none=[x], | |
76 | rep=10, | |
77 | quantiles=QUANTILES, | |
78 | ) | |
79 | | |
80 | return SingleBenchmarkRunOutput( | |
81 | y_20=ms_20, | |
82 | y_50=ms_50, | |
83 | y_80=ms_80, | |
84 | ) | |
85 | | |
86 | | |
87 | def bench_memory_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: | |
88 | seq_len = input.x | |
89 | bsz = input.extra_benchmark_config["bsz"] | |
90 | hidden_size = input.extra_benchmark_config["hidden_size"] | |
91 | intermediate_size = input.extra_benchmark_config["intermediate_size"] | |
92 | hidden_act = input.extra_benchmark_config["hidden_act"] | |
93 | dtype = input.extra_benchmark_config["dtype"] | |
94 | provider = input.kernel_provider | |
95 | mode = input.kernel_operation_mode | |
96 | | |
97 | llama_config = LlamaConfig( | |
98 | hidden_size=hidden_size, | |
99 | intermediate_size=intermediate_size, | |
100 | hidden_act=hidden_act, | |
101 | ) | |
102 | | |
103 | x_shape = (bsz, seq_len, hidden_size) | |
104 | # initialize input | |
105 | x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True) | |
106 | | |
107 | if provider == "liger": | |
108 | layer = LigerGEGLUMLP(config=llama_config).to(device).to(dtype) | |
109 | elif provider == "huggingface": | |
110 | layer = LlamaMLP(config=llama_config).to(device).to(dtype) | |
111 | else: | |
112 | raise ValueError(f"Invalid provider: {provider} for GEGLU") | |
113 | | |
114 | def fwd(): | |
115 | return layer(x) | |
116 | | |
117 | def full(): | |
118 | y = fwd() | |
119 | y.backward(torch.randn_like(y), retain_graph=True) | |
120 | | |
121 | if mode == "forward": | |
122 | mem_50, mem_20, mem_80 = _test_memory( | |
123 | fwd, | |
124 | quantiles=QUANTILES, | |
125 | ) | |
126 | elif mode == "backward": | |
127 | do = torch.randn_like(x) | |
128 | y = fwd() | |
129 | mem_50, mem_20, mem_80 = _test_memory( | |
130 | lambda: y.backward(do, retain_graph=True), | |
131 | quantiles=QUANTILES, | |
132 | ) | |
133 | else: | |
134 | mem_50, mem_20, mem_80 = _test_memory( | |
135 | full, | |
136 | quantiles=QUANTILES, | |
137 | ) | |
138 | | |
139 | return SingleBenchmarkRunOutput( | |
140 | y_20=mem_20, | |
141 | y_50=mem_50, | |
142 | y_80=mem_80, | |
143 | ) | |
144 | | |
145 | | |
146 | if __name__ == "__main__": | |
147 | args = parse_benchmark_script_args() | |
148 | | |
149 | common_configs = { | |
150 | "kernel_name": "geglu", | |
151 | "x_name": "T", | |
152 | "x_label": "sequence length", | |
153 | "x_values": [2**i for i in range(10, 14)], | |
154 | "kernel_providers": ["liger", "huggingface"], | |
155 | "extra_benchmark_configs": [ | |
156 | { | |
157 | "bsz": 8, | |
158 | "hidden_size": 4096, | |
159 | "intermediate_size": 11008, | |
160 | "hidden_act": "gelu_pytorch_tanh", | |
161 | "dtype": torch.bfloat16, | |
162 | } | |
163 | ], | |
164 | "overwrite": args.overwrite, | |
165 | } | |
166 | | |
167 | run_benchmarks( | |
168 | bench_test_fn=bench_speed_geglu, | |
169 | kernel_operation_modes=["full", "forward", "backward"], | |
170 | metric_name="speed", | |
171 | metric_unit="ms", | |
172 | **common_configs, | |
173 | ) | |
174 | run_benchmarks( | |
175 | bench_test_fn=bench_memory_geglu, | |
176 | kernel_operation_modes=["full", "forward", "backward"], | |
177 | metric_name="memory", | |
178 | metric_unit="MB", | |
179 | **common_configs, | |
180 | ) | |
181 | | |
-------------------------------------------------------------------------------- | |
/benchmark/scripts/benchmark_group_norm.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch | |
2 | import triton | |
3 | from utils import ( | |
4 | QUANTILES, | |
5 | SingleBenchmarkRunInput, | |
6 | SingleBenchmarkRunOutput, | |
7 | _test_memory, | |
8 | parse_benchmark_script_args, | |
9 | run_benchmarks, | |
10 | ) | |
11 | | |
12 | from liger_kernel.transformers.group_norm import LigerGroupNorm | |
13 | from liger_kernel.utils import infer_device | |
14 | | |
15 | device = infer_device() | |
16 | | |
17 | | |
18 | def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: | |
19 | C = input.x | |
20 | provider = input.kernel_provider | |
21 | mode = input.kernel_operation_mode | |
22 | extra_benchmark_config = input.extra_benchmark_config | |
23 | M = extra_benchmark_config["M"] | |
24 | H = extra_benchmark_config["H"] | |
25 | channels_per_group = extra_benchmark_config["channels_per_group"] | |
26 | eps = extra_benchmark_config["eps"] | |
27 | dtype = extra_benchmark_config["dtype"] | |
28 | | |
29 | x_shape = (M, C, H) | |
30 | triton_ln = LigerGroupNorm( | |
31 | num_channels=C, num_groups=C // channels_per_group, eps=eps | |
32 | ).to(device) | |
33 | torch_ln = torch.nn.GroupNorm( | |
34 | num_groups=C // channels_per_group, num_channels=C, eps=eps | |
35 | ).to(device) | |
36 | | |
37 | x = torch.randn(x_shape, dtype=dtype, device=device) | |
38 | dy = torch.randn_like(x) | |
39 | x.requires_grad_(True) | |
40 | | |
41 | def y_fwd(): | |
42 | if provider == "liger": | |
43 | return triton_ln(x) | |
44 | if provider == "huggingface": | |
45 | return torch_ln(x) | |
46 | | |
47 | if mode == "forward": | |
48 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
49 | y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500 | |
50 | ) | |
51 | elif mode == "backward": | |
52 | y = y_fwd() | |
53 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
54 | lambda: y.backward(dy, retain_graph=True), | |
55 | quantiles=QUANTILES, | |
56 | grad_to_none=[x], | |
57 | rep=500, | |
58 | ) | |
59 | elif mode == "full": | |
60 | | |
61 | def full(): | |
62 | y = y_fwd() | |
63 | y.backward(dy, retain_graph=True) | |
64 | | |
65 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
66 | full, quantiles=QUANTILES, grad_to_none=[x], rep=500 | |
67 | ) | |
68 | | |
69 | return SingleBenchmarkRunOutput( | |
70 | y_20=ms_20, | |
71 | y_50=ms_50, | |
72 | y_80=ms_80, | |
73 | ) | |
74 | | |
75 | | |
76 | def bench_memory_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: | |
77 | C = input.x | |
78 | provider = input.kernel_provider | |
79 | extra_benchmark_config = input.extra_benchmark_config | |
80 | M = extra_benchmark_config["M"] | |
81 | H = extra_benchmark_config["H"] | |
82 | channels_per_group = extra_benchmark_config["channels_per_group"] | |
83 | eps = extra_benchmark_config["eps"] | |
84 | dtype = extra_benchmark_config["dtype"] | |
85 | | |
86 | x_shape = (M, C, H) | |
87 | triton_ln = LigerGroupNorm( | |
88 | num_channels=C, num_groups=C // channels_per_group, eps=eps | |
89 | ).to(device) | |
90 | torch_ln = torch.nn.GroupNorm( | |
91 | num_groups=C // channels_per_group, num_channels=C, eps=eps | |
92 | ).to(device) | |
93 | | |
94 | x = torch.randn(x_shape, dtype=dtype, device=device) | |
95 | dy = torch.randn_like(x) | |
96 | x.requires_grad_(True) | |
97 | | |
98 | def y_fwd(): | |
99 | if provider == "liger": | |
100 | return triton_ln(x) | |
101 | if provider == "huggingface": | |
102 | return torch_ln(x) | |
103 | | |
104 | def full(): | |
105 | y = y_fwd() | |
106 | y.backward(dy, retain_graph=True) | |
107 | | |
108 | mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) | |
109 | return SingleBenchmarkRunOutput( | |
110 | y_20=mem_20, | |
111 | y_50=mem_50, | |
112 | y_80=mem_80, | |
113 | ) | |
114 | | |
115 | | |
116 | if __name__ == "__main__": | |
117 | args = parse_benchmark_script_args() | |
118 | | |
119 | common_configs = { | |
120 | "kernel_name": "group_norm", | |
121 | "x_name": "C", | |
122 | "x_label": "num_channels", | |
123 | "x_values": [2**i for i in range(5, 12)], | |
124 | "kernel_providers": ["liger", "huggingface"], | |
125 | "extra_benchmark_configs": [ | |
126 | { | |
127 | "M": 128, | |
128 | "H": 512, | |
129 | "channels_per_group": 4, | |
130 | "dtype": torch.float32, | |
131 | "eps": 1e-6, | |
132 | } | |
133 | ], | |
134 | "overwrite": args.overwrite, | |
135 | } | |
136 | | |
137 | run_benchmarks( | |
138 | bench_test_fn=bench_speed_group_norm, | |
139 | kernel_operation_modes=["forward", "full", "backward"], | |
140 | metric_name="speed", | |
141 | metric_unit="ms", | |
142 | **common_configs | |
143 | ) | |
144 | run_benchmarks( | |
145 | bench_test_fn=bench_memory_group_norm, | |
146 | kernel_operation_modes=["full", "forward", "backward"], | |
147 | metric_name="memory", | |
148 | metric_unit="MB", | |
149 | **common_configs | |
150 | ) | |
151 | | |
-------------------------------------------------------------------------------- | |
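A sketch of the drop-in usage benchmarked above: LigerGroupNorm takes num_channels, num_groups, and eps like torch.nn.GroupNorm and operates on (batch, channels, hidden) inputs. Sizes are illustrative and a GPU device is assumed.

# Sketch, not from the repo: LigerGroupNorm next to torch.nn.GroupNorm on the same input.
import torch
from liger_kernel.transformers.group_norm import LigerGroupNorm
from liger_kernel.utils import infer_device

device = infer_device()
M, C, H = 8, 64, 512                                   # batch, channels, hidden size
x = torch.randn(M, C, H, device=device, requires_grad=True)

liger_gn = LigerGroupNorm(num_channels=C, num_groups=C // 4, eps=1e-6).to(device)
torch_gn = torch.nn.GroupNorm(num_groups=C // 4, num_channels=C, eps=1e-6).to(device)

y = liger_gn(x)
y.backward(torch.randn_like(y))
# torch_gn(x) produces the reference output with the same shape
-------------------------------------------------------------------------------- | 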
/benchmark/scripts/benchmark_jsd.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch | |
2 | import triton | |
3 | from utils import ( | |
4 | QUANTILES, | |
5 | SingleBenchmarkRunInput, | |
6 | SingleBenchmarkRunOutput, | |
7 | _test_memory, | |
8 | parse_benchmark_script_args, | |
9 | run_benchmarks, | |
10 | ) | |
11 | | |
12 | from liger_kernel.transformers.jsd import LigerJSD | |
13 | from liger_kernel.utils import infer_device | |
14 | | |
15 | device = infer_device() | |
16 | | |
17 | | |
18 | class TorchJSD(torch.nn.Module): | |
19 | def __init__( | |
20 | self, | |
21 | beta: float = 0.5, | |
22 | ignore_index: int = -100, | |
23 | dtype: torch.dtype = torch.float, | |
24 | ): | |
25 | super(TorchJSD, self).__init__() | |
26 | self.kl = torch.nn.KLDivLoss(reduction="none", log_target=True) | |
27 | self.beta = beta | |
28 | self.ignore_index = ignore_index | |
29 | self.dtype = dtype | |
30 | | |
31 | def forward( | |
32 | self, | |
33 | log_q: torch.Tensor, # input | |
34 | log_p: torch.Tensor, # target | |
35 | label=None, | |
36 | ): | |
37 | log_p, log_q = log_p.to(torch.float), log_q.to(torch.float) | |
38 | log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1)) | |
39 | m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta) | |
40 | loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + ( | |
41 | 1 - self.beta | |
42 | ) * self.kl(torch.log(m), log_q).sum(dim=-1) | |
43 | | |
44 | if label is not None: | |
45 | loss = torch.where(label != self.ignore_index, loss, 0.0) | |
46 | n_non_ignore = (label != self.ignore_index).sum().item() | |
47 | if n_non_ignore == 0: | |
48 | loss = 0.0 | |
49 | else: | |
50 | loss = (loss / n_non_ignore).sum() | |
51 | else: | |
52 | loss = (loss / log_q.shape[0]).sum() | |
53 | return loss.to(self.dtype) | |
54 | | |
55 | | |
56 | def bench_speed_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: | |
57 | V = input.x | |
58 | B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] | |
59 | torch_jsd = TorchJSD() | |
60 | liger_jsd = LigerJSD() | |
61 | | |
62 | _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax( | |
63 | dim=-1 | |
64 | ) | |
65 | target = torch.randn(B * T, V, device=device).log_softmax(dim=-1) | |
66 | | |
67 | def fwd(): | |
68 | if input.kernel_provider == "liger": | |
69 | return liger_jsd(_input, target) | |
70 | else: | |
71 | return torch_jsd(_input, target) | |
72 | | |
73 | if input.kernel_operation_mode == "forward": | |
74 | ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100) | |
75 | elif input.kernel_operation_mode == "backward": | |
76 | y = fwd() | |
77 | | |
78 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
79 | lambda: y.backward(retain_graph=True), | |
80 | quantiles=QUANTILES, | |
81 | grad_to_none=[_input], | |
82 | rep=100, | |
83 | ) | |
84 | elif input.kernel_operation_mode == "full": | |
85 | | |
86 | def full(): | |
87 | y = fwd() | |
88 | y.backward(retain_graph=True) | |
89 | | |
90 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
91 | full, quantiles=QUANTILES, rep=100 | |
92 | ) | |
93 | return SingleBenchmarkRunOutput( | |
94 | y_20=ms_20, | |
95 | y_50=ms_50, | |
96 | y_80=ms_80, | |
97 | ) | |
98 | | |
99 | | |
100 | def bench_memory_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: | |
101 | torch_jsd = TorchJSD() | |
102 | liger_jsd = LigerJSD() | |
103 | | |
104 | V = input.x | |
105 | B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] | |
106 | | |
107 | _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax( | |
108 | dim=-1 | |
109 | ) | |
110 | target = torch.randn(B * T, V, device=device).log_softmax(dim=-1) | |
111 | | |
112 | def fwd(): | |
113 | if input.kernel_provider == "liger": | |
114 | return liger_jsd(_input, target) | |
115 | else: | |
116 | return torch_jsd(_input, target) | |
117 | | |
118 | def full(): | |
119 | y = fwd() | |
120 | y.backward(retain_graph=True) | |
121 | | |
122 | mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) | |
123 | | |
124 | return SingleBenchmarkRunOutput( | |
125 | y_20=mem_20, | |
126 | y_50=mem_50, | |
127 | y_80=mem_80, | |
128 | ) | |
129 | | |
130 | | |
131 | if __name__ == "__main__": | |
132 | args = parse_benchmark_script_args() | |
133 | common_args = { | |
134 | "kernel_name": "jsd", | |
135 | "x_name": "V", | |
136 | "x_label": "vocab size", | |
137 | "x_values": [2**i for i in range(12, 18)], | |
138 | "kernel_providers": ["liger", "torch"], | |
139 | "extra_benchmark_configs": [{"B": 4, "T": 2048}], | |
140 | "overwrite": args.overwrite, | |
141 | } | |
142 | | |
143 | run_benchmarks( | |
144 | bench_test_fn=bench_memory_jsd, | |
145 | kernel_operation_modes=["full"], | |
146 | metric_name="memory", | |
147 | metric_unit="MB", | |
148 | **common_args, | |
149 | ) | |
150 | | |
151 | run_benchmarks( | |
152 | bench_test_fn=bench_speed_jsd, | |
153 | kernel_operation_modes=["forward", "full"], | |
154 | metric_name="speed", | |
155 | metric_unit="ms", | |
156 | **common_args, | |
157 | ) | |
158 | | |
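The (ms_50, ms_20, ms_80) unpacking above works because triton.testing.do_bench returns one timing per requested quantile, in the order given. A minimal standalone sketch of that pattern, assuming QUANTILES is defined as [0.5, 0.2, 0.8] in the shared utils module (an assumption, not shown here) and that a CUDA device and Triton are available:

import torch
import triton

QUANTILES = [0.5, 0.2, 0.8]  # assumed to match benchmark/scripts/utils.py

def dummy_op():
    a = torch.randn(1024, 1024, device="cuda")
    return a @ a

# With quantiles=..., do_bench returns one time (ms) per quantile in the given order,
# which is why the scripts unpack the result as (median, p20, p80).
ms_50, ms_20, ms_80 = triton.testing.do_bench(dummy_op, quantiles=QUANTILES, rep=100)
print(f"median={ms_50:.3f} ms  p20={ms_20:.3f} ms  p80={ms_80:.3f} ms")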
-------------------------------------------------------------------------------- | |
/benchmark/scripts/benchmark_kl_div.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch | |
2 | import torch.nn as nn | |
3 | import triton | |
4 | from utils import ( | |
5 | QUANTILES, | |
6 | SingleBenchmarkRunInput, | |
7 | SingleBenchmarkRunOutput, | |
8 | _test_memory, | |
9 | parse_benchmark_script_args, | |
10 | run_benchmarks, | |
11 | ) | |
12 | | |
13 | from liger_kernel.transformers.kl_div import LigerKLDIVLoss | |
14 | from liger_kernel.utils import infer_device | |
15 | | |
16 | device = infer_device() | |
17 | | |
18 | S, E = 12, 18 | |
19 | | |
20 | | |
21 | def bench_speed_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: | |
22 | reduction = "batchmean" | |
23 | V = input.x | |
24 | B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] | |
25 | torch_kl_div = nn.KLDivLoss(reduction=reduction) | |
26 | liger_kl_div = LigerKLDIVLoss(reduction=reduction) | |
27 | | |
28 | _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax( | |
29 | dim=-1 | |
30 | ) | |
31 | target = torch.randn(B * T, V, device=device).softmax(dim=-1) | |
32 | | |
33 | def fwd(): | |
34 | if input.kernel_provider == "liger": | |
35 | return liger_kl_div(_input, target) | |
36 | else: | |
37 | return torch_kl_div(_input, target) | |
38 | | |
39 | if input.kernel_operation_mode == "forward": | |
40 | ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100) | |
41 | elif input.kernel_operation_mode == "backward": | |
42 | y = fwd() | |
43 | | |
44 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
45 | lambda: y.backward(retain_graph=True), | |
46 | quantiles=QUANTILES, | |
47 | grad_to_none=[_input], | |
48 | rep=100, | |
49 | ) | |
50 | elif input.kernel_operation_mode == "full": | |
51 | | |
52 | def full(): | |
53 | y = fwd() | |
54 | y.backward(retain_graph=True) | |
55 | | |
56 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
57 | full, quantiles=QUANTILES, rep=100 | |
58 | ) | |
59 | return SingleBenchmarkRunOutput( | |
60 | y_20=ms_20, | |
61 | y_50=ms_50, | |
62 | y_80=ms_80, | |
63 | ) | |
64 | | |
65 | | |
66 | def bench_memory_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: | |
67 | reduction = "batchmean" | |
68 | torch_kl_div = nn.KLDivLoss(reduction=reduction) | |
69 | liger_kl_div = LigerKLDIVLoss(reduction=reduction) | |
70 | | |
71 | V = input.x | |
72 | B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] | |
73 | | |
74 | _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax( | |
75 | dim=-1 | |
76 | ) | |
77 | target = torch.randn(B * T, V, device=device).softmax(dim=-1) | |
78 | | |
79 | def fwd(): | |
80 | if input.kernel_provider == "liger": | |
81 | return liger_kl_div(_input, target) | |
82 | else: | |
83 | return torch_kl_div(_input, target) | |
84 | | |
85 | def full(): | |
86 | y = fwd() | |
87 | y.backward(retain_graph=True) | |
88 | | |
89 | mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) | |
90 | | |
91 | return SingleBenchmarkRunOutput( | |
92 | y_20=mem_20, | |
93 | y_50=mem_50, | |
94 | y_80=mem_80, | |
95 | ) | |
96 | | |
97 | | |
98 | if __name__ == "__main__": | |
99 | args = parse_benchmark_script_args() | |
100 | common_args = { | |
101 | "kernel_name": "kl_div", | |
102 | "x_name": "V", | |
103 | "x_label": "vocab size", | |
104 | "x_values": [2**i for i in range(12, 18)], | |
105 | "kernel_providers": ["liger", "torch"], | |
106 | "extra_benchmark_configs": [{"B": 8, "T": 512}], | |
107 | "overwrite": args.overwrite, | |
108 | } | |
109 | | |
110 | run_benchmarks( | |
111 | bench_test_fn=bench_memory_kldiv, | |
112 | kernel_operation_modes=["full"], | |
113 | metric_name="memory", | |
114 | metric_unit="MB", | |
115 | **common_args, | |
116 | ) | |
117 | | |
118 | run_benchmarks( | |
119 | bench_test_fn=bench_speed_kldiv, | |
120 | kernel_operation_modes=["forward", "full"], | |
121 | metric_name="speed", | |
122 | metric_unit="ms", | |
123 | **common_args, | |
124 | ) | |
125 | | |
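The script feeds log-probabilities as the input and probabilities as the target, which is the convention nn.KLDivLoss expects with its default log_target=False; "batchmean" divides the summed divergence by the batch size. A small CPU sketch (illustrative sizes, not from the benchmark config) checking that convention against the closed form:

import torch
import torch.nn as nn

B, V = 8, 32
log_q = torch.randn(B, V).log_softmax(dim=-1)  # input: log-probabilities
p = torch.randn(B, V).softmax(dim=-1)          # target: probabilities

kl_torch = nn.KLDivLoss(reduction="batchmean")(log_q, p)
kl_manual = (p * (p.log() - log_q)).sum() / B  # sum of p * (log p - log q), divided by batch size
assert torch.allclose(kl_torch, kl_manual, atol=1e-6)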
-------------------------------------------------------------------------------- | |
/benchmark/scripts/benchmark_layer_norm.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch | |
2 | import triton | |
3 | from utils import ( | |
4 | QUANTILES, | |
5 | SingleBenchmarkRunInput, | |
6 | SingleBenchmarkRunOutput, | |
7 | _test_memory, | |
8 | parse_benchmark_script_args, | |
9 | run_benchmarks, | |
10 | ) | |
11 | | |
12 | from liger_kernel.transformers.layer_norm import LigerLayerNorm | |
13 | from liger_kernel.utils import infer_device | |
14 | | |
15 | device = infer_device() | |
16 | | |
17 | | |
18 | def bench_speed_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: | |
19 | N = input.x | |
20 | provider = input.kernel_provider | |
21 | mode = input.kernel_operation_mode | |
22 | extra_benchmark_config = input.extra_benchmark_config | |
23 | M = extra_benchmark_config["M"] | |
24 | eps = extra_benchmark_config["eps"] | |
25 | dtype = extra_benchmark_config["dtype"] | |
26 | | |
27 | x_shape = (M, N) | |
28 | triton_ln = LigerLayerNorm(hidden_size=N, eps=eps).to(device) | |
29 | torch_ln = torch.nn.LayerNorm(N, eps=eps).to(device) | |
30 | | |
31 | x = torch.randn(x_shape, dtype=dtype, device=device) | |
32 | dy = torch.randn_like(x) | |
33 | x.requires_grad_(True) | |
34 | | |
35 | def y_fwd(): | |
36 | if provider == "liger": | |
37 | return triton_ln(x) | |
38 | if provider == "huggingface": | |
39 | return torch_ln(x) | |
40 | | |
41 | if mode == "forward": | |
42 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
43 | y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500 | |
44 | ) | |
45 | elif mode == "backward": | |
46 | y = y_fwd() | |
47 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
48 | lambda: y.backward(dy, retain_graph=True), | |
49 | quantiles=QUANTILES, | |
50 | grad_to_none=[x], | |
51 | rep=500, | |
52 | ) | |
53 | elif mode == "full": | |
54 | | |
55 | def full(): | |
56 | y = y_fwd() | |
57 | y.backward(dy, retain_graph=True) | |
58 | | |
59 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
60 | full, quantiles=QUANTILES, grad_to_none=[x], rep=500 | |
61 | ) | |
62 | | |
63 | return SingleBenchmarkRunOutput( | |
64 | y_20=ms_20, | |
65 | y_50=ms_50, | |
66 | y_80=ms_80, | |
67 | ) | |
68 | | |
69 | | |
70 | def bench_memory_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: | |
71 | N = input.x | |
72 | provider = input.kernel_provider | |
73 | dtype = input.extra_benchmark_config["dtype"] | |
74 | M = input.extra_benchmark_config["M"] | |
75 | eps = input.extra_benchmark_config["eps"] | |
76 | | |
77 | x_shape = (M, N) | |
78 | | |
79 | triton_ln = LigerLayerNorm(hidden_size=N, eps=eps).to(device) | |
80 | torch_ln = torch.nn.LayerNorm(N, eps=eps).to(device) | |
81 | | |
82 | x = torch.randn(x_shape, dtype=dtype, device=device) | |
83 | dy = torch.randn_like(x) | |
84 | x.requires_grad_(True) | |
85 | | |
86 | def y_fwd(): | |
87 | if provider == "liger": | |
88 | return triton_ln(x) | |
89 | if provider == "huggingface": | |
90 | return torch_ln(x) | |
91 | | |
92 | def full(): | |
93 | y = y_fwd() | |
94 | y.backward(dy, retain_graph=True) | |
95 | | |
96 | mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) | |
97 | return SingleBenchmarkRunOutput( | |
98 | y_20=mem_20, | |
99 | y_50=mem_50, | |
100 | y_80=mem_80, | |
101 | ) | |
102 | | |
103 | | |
104 | if __name__ == "__main__": | |
105 | args = parse_benchmark_script_args() | |
106 | | |
107 | common_configs = { | |
108 | "kernel_name": "layer_norm", | |
109 | "x_name": "N", | |
110 | "x_label": "hidden size", | |
111 | "x_values": [2**i for i in range(10, 15)], | |
112 | "kernel_providers": ["liger", "huggingface"], | |
113 | "extra_benchmark_configs": [{"M": 4096, "dtype": torch.float32, "eps": 1e-6}], | |
114 | "overwrite": args.overwrite, | |
115 | } | |
116 | | |
117 | run_benchmarks( | |
118 | bench_test_fn=bench_speed_layer_norm, | |
119 | kernel_operation_modes=["forward", "full"], | |
120 | metric_name="speed", | |
121 | metric_unit="ms", | |
122 | **common_configs | |
123 | ) | |
124 | run_benchmarks( | |
125 | bench_test_fn=bench_memory_layer_norm, | |
126 | kernel_operation_modes=["full"], | |
127 | metric_name="memory", | |
128 | metric_unit="MB", | |
129 | **common_configs | |
130 | ) | |
131 | | |
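Both providers compute the same normalization, y = (x - mean) / sqrt(var + eps) * weight + bias, with the biased variance taken over the last dimension. A small CPU sketch (illustrative sizes) checking torch.nn.LayerNorm against that formula:

import torch

M, N, eps = 4, 1024, 1e-6
x = torch.randn(M, N)
ln = torch.nn.LayerNorm(N, eps=eps)

mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)  # LayerNorm uses the biased variance
y_manual = (x - mean) / torch.sqrt(var + eps) * ln.weight + ln.bias
assert torch.allclose(ln(x), y_manual, atol=1e-5)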
-------------------------------------------------------------------------------- | |
/benchmark/scripts/benchmark_orpo_loss.py: | |
-------------------------------------------------------------------------------- | |
1 | import os | |
2 | import sys | |
3 | | |
4 | import torch | |
5 | import triton | |
6 | from utils import ( | |
7 | QUANTILES, | |
8 | SingleBenchmarkRunInput, | |
9 | SingleBenchmarkRunOutput, | |
10 | _test_memory, | |
11 | parse_benchmark_script_args, | |
12 | run_benchmarks, | |
13 | ) | |
14 | | |
15 | from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction | |
16 | from liger_kernel.utils import infer_device | |
17 | | |
18 | device = infer_device() | |
19 | | |
20 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) | |
21 | | |
22 | | |
23 | class TorchLMHeadORPO(torch.nn.Module): | |
24 | """Ground truth implementation of the linear fused with torch based cross entropy loss. | |
25 | | |
26 | :param H: hidden size | |
27 | :param V: vocab size | |
28 | :param ignore_index: index to ignore | |
29 | :param dtype: parameter dtype of the linear layer | |
30 | """ | |
31 | | |
32 | def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): | |
33 | from test.chunked_loss.test_orpo_loss import HF_ORPO_Loss | |
34 | | |
35 | super().__init__() | |
36 | self.lin = torch.nn.Linear( | |
37 | in_features=H, out_features=V, bias=False, dtype=dtype | |
38 | ) | |
39 | self.orpo_loss = HF_ORPO_Loss().get_batch_loss_metrics | |
40 | | |
41 | def forward(self, x, y): | |
42 | return self.orpo_loss(x, self.lin.weight, y) | |
43 | | |
44 | | |
45 | class LigerLMHeadORPO(torch.nn.Module): | |
46 | def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): | |
47 | super().__init__() | |
48 | self.lin = torch.nn.Linear( | |
49 | in_features=H, out_features=V, bias=False, dtype=dtype | |
50 | ) | |
51 | self.orpo_loss = LigerFusedLinearORPOFunction.apply | |
52 | | |
53 | def forward(self, x, y): | |
54 | return self.orpo_loss(x, self.lin.weight, y) | |
55 | | |
56 | | |
57 | ############################################################################# | |
58 | # Test the memory consumption of the fused linear ORPO loss | |
59 | ############################################################################# | |
60 | | |
61 | | |
62 | def bench_memory_fused_linear_orpo_loss( | |
63 | input: SingleBenchmarkRunInput, | |
64 | ) -> SingleBenchmarkRunOutput: | |
65 | B = input.x | |
66 | T = input.extra_benchmark_config["T"] | |
67 | H = input.extra_benchmark_config["H"] | |
68 | V = input.extra_benchmark_config["V"] | |
69 | dtype = input.extra_benchmark_config["dtype"] | |
70 | provider = input.kernel_provider | |
71 | | |
72 | torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device) | |
73 | liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device) | |
74 | | |
75 | _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) | |
76 | target = torch.randint(V, (B, T), dtype=torch.long, device=device) | |
77 | | |
78 | def fwd(): | |
79 | if provider == "liger": | |
80 | return liger_lm_head_orpo(_input, target) | |
81 | elif provider == "huggingface": | |
82 | return torch_lm_head_orpo(_input, target) | |
83 | | |
84 | def full(): | |
85 | y = fwd() | |
86 | y.backward() | |
87 | | |
88 | mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) | |
89 | return SingleBenchmarkRunOutput( | |
90 | y_20=mem_20, | |
91 | y_50=mem_50, | |
92 | y_80=mem_80, | |
93 | ) | |
94 | | |
95 | | |
96 | # ############################################################################# | |
97 | # # Test the speed of the fused linear ORPO loss | |
98 | # ############################################################################# | |
99 | | |
100 | | |
101 | def bench_speed_fused_linear_orpo_loss( | |
102 | input: SingleBenchmarkRunInput, | |
103 | ) -> SingleBenchmarkRunOutput: | |
104 | B = input.x | |
105 | T = input.extra_benchmark_config["T"] | |
106 | H = input.extra_benchmark_config["H"] | |
107 | V = input.extra_benchmark_config["V"] | |
108 | dtype = input.extra_benchmark_config["dtype"] | |
109 | provider = input.kernel_provider | |
110 | mode = input.kernel_operation_mode | |
111 | | |
112 | torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device) | |
113 | liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device) | |
114 | | |
115 | _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) | |
116 | target = torch.randint(V, (B, T), dtype=torch.long, device=device) | |
117 | | |
118 | def fwd(): | |
119 | if provider == "liger": | |
120 | return liger_lm_head_orpo(_input, target) | |
121 | elif provider == "huggingface": | |
122 | return torch_lm_head_orpo(_input, target) | |
123 | | |
124 | if mode == "forward": | |
125 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
126 | fwd, | |
127 | rep=100, | |
128 | quantiles=QUANTILES, | |
129 | ) | |
130 | elif mode == "backward": | |
131 | y = fwd() | |
132 | | |
133 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
134 | lambda: y.backward(retain_graph=True), | |
135 | grad_to_none=[_input], | |
136 | rep=100, | |
137 | quantiles=QUANTILES, | |
138 | ) | |
139 | elif mode == "full": | |
140 | | |
141 | def full(): | |
142 | y = fwd() | |
143 | y.backward() | |
144 | | |
145 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
146 | full, | |
147 | rep=100, | |
148 | quantiles=QUANTILES, | |
149 | ) | |
150 | return SingleBenchmarkRunOutput( | |
151 | y_20=ms_20, | |
152 | y_50=ms_50, | |
153 | y_80=ms_80, | |
154 | ) | |
155 | | |
156 | | |
157 | if __name__ == "__main__": | |
158 | args = parse_benchmark_script_args() | |
159 | | |
160 | common_configs = { | |
161 | "kernel_name": "fused_linear_orpo_loss", | |
162 | "x_name": "B", | |
163 | "x_label": "B", | |
164 | "x_values": [2**i for i in range(1, 5)], | |
165 | "kernel_providers": ["liger", "huggingface"], | |
166 | "extra_benchmark_configs": [ | |
167 | { | |
168 | "T": 1024, | |
169 | "H": 4096, | |
170 | "V": 128256, | |
171 | "mode": "forward", | |
172 | "dtype": torch.bfloat16, | |
173 | } | |
174 | ], | |
175 | "overwrite": args.overwrite, | |
176 | } | |
177 | | |
178 | run_benchmarks( | |
179 | bench_test_fn=bench_speed_fused_linear_orpo_loss, | |
180 | kernel_operation_modes=["forward", "full"], | |
181 | metric_name="speed", | |
182 | metric_unit="ms", | |
183 | **common_configs | |
184 | ) | |
185 | run_benchmarks( | |
186 | bench_test_fn=bench_memory_fused_linear_orpo_loss, | |
187 | kernel_operation_modes=["full"], | |
188 | metric_name="memory", | |
189 | metric_unit="MB", | |
190 | **common_configs | |
191 | ) | |
192 | | |
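The memory benchmark sweeps B because the non-fused baseline has to materialize the full (B*T, V) logits tensor for this head size, which dominates peak memory, while the chunked Liger path avoids it. A rough back-of-envelope sketch for the largest configuration above (logits activation only; gradients and other buffers are ignored):

B, T, V = 16, 1024, 128256   # largest B in x_values, T and V from extra_benchmark_configs
bytes_per_elem = 2           # bfloat16
logits_bytes = B * T * V * bytes_per_elem
print(f"{logits_bytes / 2**30:.1f} GiB")  # ~3.9 GiB for a single logits tensor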
-------------------------------------------------------------------------------- | |
/benchmark/scripts/benchmark_qwen2vl_mrope.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch | |
2 | import triton | |
3 | from transformers.models.qwen2_vl.modeling_qwen2_vl import ( | |
4 | Qwen2VLRotaryEmbedding, | |
5 | apply_multimodal_rotary_pos_emb, | |
6 | ) | |
7 | from utils import ( | |
8 | QUANTILES, | |
9 | SingleBenchmarkRunInput, | |
10 | SingleBenchmarkRunOutput, | |
11 | _test_memory, | |
12 | parse_benchmark_script_args, | |
13 | run_benchmarks, | |
14 | ) | |
15 | | |
16 | from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb | |
17 | from liger_kernel.utils import infer_device | |
18 | | |
19 | device = infer_device() | |
20 | | |
21 | | |
22 | def bench_speed_qwen2vl_mrope( | |
23 | input: SingleBenchmarkRunInput, | |
24 | ) -> SingleBenchmarkRunOutput: | |
25 | provider = input.kernel_provider | |
26 | mode = input.kernel_operation_mode | |
27 | | |
28 | extra_benchmark_config = input.extra_benchmark_config | |
29 | num_q_heads = extra_benchmark_config["num_q_heads"] | |
30 | num_kv_heads = extra_benchmark_config["num_kv_heads"] | |
31 | dtype = extra_benchmark_config["dtype"] | |
32 | | |
33 | # x can be either hidden_size or seq_len | |
34 | hidden_size = ( | |
35 | extra_benchmark_config["hidden_size"] | |
36 | if "hidden_size" in extra_benchmark_config | |
37 | else input.x | |
38 | ) | |
39 | seq_len = ( | |
40 | extra_benchmark_config["seq_len"] | |
41 | if "seq_len" in extra_benchmark_config | |
42 | else input.x | |
43 | ) | |
44 | | |
45 | head_dim = hidden_size // num_q_heads | |
46 | rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device) | |
47 | q = torch.randn( | |
48 | (1, seq_len, num_q_heads, head_dim), | |
49 | device=device, | |
50 | requires_grad=True, | |
51 | dtype=dtype, | |
52 | ).transpose(1, 2) | |
53 | k = torch.randn( | |
54 | (1, seq_len, num_kv_heads, head_dim), | |
55 | device=device, | |
56 | requires_grad=True, | |
57 | dtype=dtype, | |
58 | ).transpose(1, 2) | |
59 | dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like( | |
60 | k, device=device | |
61 | ) | |
62 | pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1) | |
63 | cos, sin = rotary_emb(k, pos_ids) | |
64 | | |
65 | mrope_section_hw = head_dim * 3 // 16 | |
66 | mrope_section = [ | |
67 | head_dim // 2 - 2 * mrope_section_hw, | |
68 | mrope_section_hw, | |
69 | mrope_section_hw, | |
70 | ] | |
71 | | |
72 | def fwd(): | |
73 | if provider == "liger": | |
74 | return liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section) | |
75 | elif provider == "huggingface": | |
76 | return apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section) | |
77 | else: | |
78 | raise ValueError(f"Invalid provider: {provider} for M-RoPE embedding") | |
79 | | |
80 | if mode == "forward": | |
81 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
82 | fwd, | |
83 | grad_to_none=[q, k], | |
84 | rep=400, | |
85 | quantiles=QUANTILES, | |
86 | ) | |
87 | elif mode == "backward": | |
88 | q_out, k_out = fwd() | |
89 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
90 | lambda: torch.autograd.grad( | |
91 | (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True | |
92 | ), | |
93 | grad_to_none=[q, k], | |
94 | rep=400, | |
95 | quantiles=QUANTILES, | |
96 | ) | |
97 | elif mode == "full": | |
98 | | |
99 | def full(): | |
100 | q_out, k_out = fwd() | |
101 | torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True) | |
102 | | |
103 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
104 | full, | |
105 | grad_to_none=[q, k], | |
106 | rep=400, | |
107 | quantiles=QUANTILES, | |
108 | ) | |
109 | return SingleBenchmarkRunOutput( | |
110 | y_20=ms_20, | |
111 | y_50=ms_50, | |
112 | y_80=ms_80, | |
113 | ) | |
114 | | |
115 | | |
116 | def bench_memory_qwen2vl_mrope( | |
117 | input: SingleBenchmarkRunInput, | |
118 | ) -> SingleBenchmarkRunOutput: | |
119 | provider = input.kernel_provider | |
120 | | |
121 | extra_benchmark_config = input.extra_benchmark_config | |
122 | num_q_heads = extra_benchmark_config["num_q_heads"] | |
123 | num_kv_heads = extra_benchmark_config["num_kv_heads"] | |
124 | dtype = extra_benchmark_config["dtype"] | |
125 | | |
126 | # x can be either hidden_size or seq_len | |
127 | hidden_size = ( | |
128 | extra_benchmark_config["hidden_size"] | |
129 | if "hidden_size" in extra_benchmark_config | |
130 | else input.x | |
131 | ) | |
132 | seq_len = ( | |
133 | extra_benchmark_config["seq_len"] | |
134 | if "seq_len" in extra_benchmark_config | |
135 | else input.x | |
136 | ) | |
137 | | |
138 | head_dim = hidden_size // num_q_heads | |
139 | rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device) | |
140 | q = torch.randn( | |
141 | (1, seq_len, num_q_heads, head_dim), | |
142 | device=device, | |
143 | requires_grad=True, | |
144 | dtype=dtype, | |
145 | ).transpose(1, 2) | |
146 | k = torch.randn( | |
147 | (1, seq_len, num_kv_heads, head_dim), | |
148 | device=device, | |
149 | requires_grad=True, | |
150 | dtype=dtype, | |
151 | ).transpose(1, 2) | |
152 | dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like( | |
153 | k, device=device | |
154 | ) | |
155 | pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1) | |
156 | cos, sin = rotary_emb(k, pos_ids) | |
157 | | |
158 | mrope_section_hw = head_dim * 3 // 16 | |
159 | mrope_section = [ | |
160 | head_dim // 2 - 2 * mrope_section_hw, | |
161 | mrope_section_hw, | |
162 | mrope_section_hw, | |
163 | ] | |
164 | | |
165 | def full(): | |
166 | if provider == "liger": | |
167 | q_out, k_out = liger_multimodal_rotary_pos_emb( | |
168 | q, k, cos, sin, mrope_section | |
169 | ) | |
170 | else: | |
171 | q_out, k_out = apply_multimodal_rotary_pos_emb( | |
172 | q, k, cos, sin, mrope_section | |
173 | ) | |
174 | torch.autograd.grad( | |
175 | (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True | |
176 | ) | |
177 | | |
178 | mem_50, mem_20, mem_80 = _test_memory( | |
179 | full, | |
180 | quantiles=QUANTILES, | |
181 | ) | |
182 | return SingleBenchmarkRunOutput( | |
183 | y_20=mem_20, | |
184 | y_50=mem_50, | |
185 | y_80=mem_80, | |
186 | ) | |
187 | | |
188 | | |
189 | if __name__ == "__main__": | |
190 | args = parse_benchmark_script_args() | |
191 | | |
192 | common_configs_varying_hidden_size = { | |
193 | "kernel_name": "qwen2vl_mrope", | |
194 | "x_name": "H", | |
195 | "x_label": "hidden size", | |
196 | "x_values": [32 * (2**i) for i in range(4, 10, 2)], | |
197 | "kernel_providers": ["liger", "huggingface"], | |
198 | "extra_benchmark_configs": [ | |
199 | { | |
200 | "dtype": torch.bfloat16, | |
201 | "seq_len": 2048, | |
202 | "num_q_heads": 32, | |
203 | "num_kv_heads": 8, | |
204 | } | |
205 | ], | |
206 | "overwrite": args.overwrite, | |
207 | } | |
208 | run_benchmarks( | |
209 | bench_test_fn=bench_speed_qwen2vl_mrope, | |
210 | kernel_operation_modes=["forward", "backward", "full"], | |
211 | metric_name="speed", | |
212 | metric_unit="ms", | |
213 | **common_configs_varying_hidden_size, | |
214 | ) | |
215 | run_benchmarks( | |
216 | bench_test_fn=bench_memory_qwen2vl_mrope, | |
217 | kernel_operation_modes=["full"], | |
218 | metric_name="memory", | |
219 | metric_unit="MB", | |
220 | **common_configs_varying_hidden_size, | |
221 | ) | |
222 | | |
223 | common_configs_varying_seq_len = { | |
224 | "kernel_name": "qwen2vl_mrope", | |
225 | "x_name": "T", | |
226 | "x_label": "sequence length", | |
227 | "x_values": [2**i for i in range(10, 15)], | |
228 | "kernel_providers": ["liger", "huggingface"], | |
229 | "extra_benchmark_configs": [ | |
230 | { | |
231 | "dtype": torch.bfloat16, | |
232 | "hidden_size": 8192, | |
233 | "num_q_heads": 32, | |
234 | "num_kv_heads": 8, | |
235 | } | |
236 | ], | |
237 | "overwrite": args.overwrite, | |
238 | } | |
239 | run_benchmarks( | |
240 | bench_test_fn=bench_speed_qwen2vl_mrope, | |
241 | kernel_operation_modes=["forward", "backward", "full"], | |
242 | metric_name="speed", | |
243 | metric_unit="ms", | |
244 | **common_configs_varying_seq_len, | |
245 | ) | |
246 | run_benchmarks( | |
247 | bench_test_fn=bench_memory_qwen2vl_mrope, | |
248 | kernel_operation_modes=["full"], | |
249 | metric_name="memory", | |
250 | metric_unit="MB", | |
251 | **common_configs_varying_seq_len, | |
252 | ) | |
253 | | |
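mrope_section splits the rotary half of each head dimension into temporal, height, and width components, so the three entries have to sum to head_dim // 2. A sketch of that arithmetic for the sequence-length sweep above (hidden_size=8192, num_q_heads=32):

hidden_size, num_q_heads = 8192, 32
head_dim = hidden_size // num_q_heads   # 256
mrope_section_hw = head_dim * 3 // 16   # 48
mrope_section = [head_dim // 2 - 2 * mrope_section_hw, mrope_section_hw, mrope_section_hw]
print(mrope_section)                        # [32, 48, 48]
assert sum(mrope_section) == head_dim // 2  # the sections tile the rotary half of each head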
-------------------------------------------------------------------------------- | |
/benchmark/scripts/benchmark_rms_norm.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch | |
2 | import torch.nn as nn | |
3 | import triton | |
4 | from utils import ( | |
5 | QUANTILES, | |
6 | SingleBenchmarkRunInput, | |
7 | SingleBenchmarkRunOutput, | |
8 | _test_memory, | |
9 | parse_benchmark_script_args, | |
10 | run_benchmarks, | |
11 | ) | |
12 | | |
13 | from liger_kernel.transformers.rms_norm import LigerRMSNorm | |
14 | from liger_kernel.utils import infer_device | |
15 | | |
16 | device = infer_device() | |
17 | | |
18 | | |
19 | class LlamaRMSNorm(nn.Module): | |
20 | def __init__(self, hidden_size, eps=1e-6): | |
21 | """ | |
22 | LlamaRMSNorm is equivalent to T5LayerNorm | |
23 | """ | |
24 | super().__init__() | |
25 | self.weight = nn.Parameter(torch.ones(hidden_size)) | |
26 | self.variance_epsilon = eps | |
27 | | |
28 | def forward(self, hidden_states): | |
29 | input_dtype = hidden_states.dtype | |
30 | hidden_states = hidden_states.to(torch.float32) | |
31 | variance = hidden_states.pow(2).mean(-1, keepdim=True) | |
32 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) | |
33 | return self.weight * hidden_states.to(input_dtype) | |
34 | | |
35 | | |
36 | def bench_speed_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: | |
37 | N = input.x | |
38 | provider = input.kernel_provider | |
39 | mode = input.kernel_operation_mode | |
40 | | |
41 | extra_benchmark_config = input.extra_benchmark_config | |
42 | M = extra_benchmark_config["M"] | |
43 | eps = extra_benchmark_config["eps"] | |
44 | dtype = extra_benchmark_config["dtype"] | |
45 | | |
46 | x_shape = (M, N) | |
47 | | |
48 | triton_rms = LigerRMSNorm(hidden_size=N, eps=eps).to(device) | |
49 | llama_rms = LlamaRMSNorm(hidden_size=N, eps=eps).to(device) | |
50 | | |
51 | x = torch.randn(x_shape, dtype=dtype, device=device) | |
52 | dy = torch.randn_like(x) | |
53 | x.requires_grad_(True) | |
54 | | |
55 | # utility functions | |
56 | | |
57 | def y_fwd(): | |
58 | if provider == "liger": | |
59 | return triton_rms(x) | |
60 | | |
61 | if provider == "huggingface": | |
62 | return llama_rms(x) | |
63 | | |
64 | if mode == "forward": | |
65 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
66 | y_fwd, | |
67 | grad_to_none=[x], | |
68 | rep=500, | |
69 | quantiles=QUANTILES, | |
70 | ) | |
71 | elif mode == "backward": | |
72 | y = y_fwd() | |
73 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
74 | lambda: y.backward(dy, retain_graph=True), | |
75 | grad_to_none=[x], | |
76 | rep=500, | |
77 | quantiles=QUANTILES, | |
78 | ) | |
79 | elif mode == "full": | |
80 | | |
81 | def full(): | |
82 | y = y_fwd() | |
83 | y.backward(dy, retain_graph=True) | |
84 | | |
85 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
86 | full, | |
87 | grad_to_none=[x], | |
88 | rep=500, | |
89 | quantiles=QUANTILES, | |
90 | ) | |
91 | | |
92 | return SingleBenchmarkRunOutput( | |
93 | y_20=ms_20, | |
94 | y_50=ms_50, | |
95 | y_80=ms_80, | |
96 | ) | |
97 | | |
98 | | |
99 | def bench_memory_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: | |
100 | N = input.x | |
101 | provider = input.kernel_provider | |
102 | | |
103 | extra_benchmark_config = input.extra_benchmark_config | |
104 | M = extra_benchmark_config["M"] | |
105 | eps = extra_benchmark_config["eps"] | |
106 | dtype = extra_benchmark_config["dtype"] | |
107 | | |
108 | x_shape = (M, N) | |
109 | | |
110 | triton_rms = LigerRMSNorm(hidden_size=N, eps=eps).to(device) | |
111 | llama_rms = LlamaRMSNorm(hidden_size=N, eps=eps).to(device) | |
112 | | |
113 | x = torch.randn(x_shape, dtype=dtype, device=device) | |
114 | dy = torch.randn_like(x) | |
115 | x.requires_grad_(True) | |
116 | | |
117 | # utility functions | |
118 | def y_fwd(): | |
119 | if provider == "liger": | |
120 | return triton_rms(x) | |
121 | if provider == "huggingface": | |
122 | return llama_rms(x) | |
123 | | |
124 | def full(): | |
125 | y = y_fwd() | |
126 | y.backward(dy, retain_graph=True) | |
127 | | |
128 | mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) | |
129 | | |
130 | return SingleBenchmarkRunOutput( | |
131 | y_20=mem_20, | |
132 | y_50=mem_50, | |
133 | y_80=mem_80, | |
134 | ) | |
135 | | |
136 | | |
137 | if __name__ == "__main__": | |
138 | args = parse_benchmark_script_args() | |
139 | | |
140 | common_configs = { | |
141 | "kernel_name": "rms_norm", | |
142 | "x_name": "H", | |
143 | "x_label": "hidden size", | |
144 | "x_values": [2**i for i in range(10, 16)], | |
145 | "kernel_providers": ["liger", "huggingface"], | |
146 | "extra_benchmark_configs": [{"M": 2048, "dtype": torch.bfloat16, "eps": 1e-6}], | |
147 | "overwrite": args.overwrite, | |
148 | } | |
149 | | |
150 | run_benchmarks( | |
151 | bench_test_fn=bench_speed_rms_norm, | |
152 | kernel_operation_modes=["forward", "full", "backward"], | |
153 | metric_name="speed", | |
154 | metric_unit="ms", | |
155 | **common_configs | |
156 | ) | |
157 | run_benchmarks( | |
158 | bench_test_fn=bench_memory_rms_norm, | |
159 | kernel_operation_modes=["full"], | |
160 | metric_name="memory", | |
161 | metric_unit="MB", | |
162 | **common_configs | |
163 | ) | |
164 | | |
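LlamaRMSNorm above normalizes by the root mean square of the hidden dimension instead of subtracting a mean: y = x / sqrt(mean(x^2) + eps) * weight. A small CPU sketch (illustrative sizes; intended to run in the same scope so the LlamaRMSNorm class above is available) checking the module against that one-liner:

import torch

M, N, eps = 4, 1024, 1e-6
x = torch.randn(M, N)
rms = LlamaRMSNorm(hidden_size=N, eps=eps)  # class defined in this benchmark script

y_manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * rms.weight
assert torch.allclose(rms(x), y_manual, atol=1e-6)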
-------------------------------------------------------------------------------- | |
/benchmark/scripts/benchmark_rope.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch | |
2 | import triton | |
3 | from transformers.models.llama.modeling_llama import ( | |
4 | LlamaRotaryEmbedding, | |
5 | apply_rotary_pos_emb, | |
6 | ) | |
7 | from utils import ( | |
8 | QUANTILES, | |
9 | SingleBenchmarkRunInput, | |
10 | SingleBenchmarkRunOutput, | |
11 | _test_memory, | |
12 | parse_benchmark_script_args, | |
13 | run_benchmarks, | |
14 | ) | |
15 | | |
16 | from liger_kernel.transformers.rope import liger_rotary_pos_emb | |
17 | from liger_kernel.utils import infer_device | |
18 | | |
19 | device = infer_device() | |
20 | | |
21 | | |
22 | def bench_speed_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: | |
23 | provider = input.kernel_provider | |
24 | mode = input.kernel_operation_mode | |
25 | | |
26 | extra_benchmark_config = input.extra_benchmark_config | |
27 | num_q_heads = extra_benchmark_config["num_q_heads"] | |
28 | num_kv_heads = extra_benchmark_config["num_kv_heads"] | |
29 | dtype = extra_benchmark_config["dtype"] | |
30 | | |
31 | # x can be either hidden_size or seq_len | |
32 | hidden_size = ( | |
33 | extra_benchmark_config["hidden_size"] | |
34 | if "hidden_size" in extra_benchmark_config | |
35 | else input.x | |
36 | ) | |
37 | seq_len = ( | |
38 | extra_benchmark_config["seq_len"] | |
39 | if "seq_len" in extra_benchmark_config | |
40 | else input.x | |
41 | ) | |
42 | | |
43 | head_dim = hidden_size // num_q_heads | |
44 | rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) | |
45 | q = torch.randn( | |
46 | (1, seq_len, num_q_heads, head_dim), | |
47 | device=device, | |
48 | requires_grad=True, | |
49 | dtype=dtype, | |
50 | ).transpose(1, 2) | |
51 | k = torch.randn( | |
52 | (1, seq_len, num_kv_heads, head_dim), | |
53 | device=device, | |
54 | requires_grad=True, | |
55 | dtype=dtype, | |
56 | ).transpose(1, 2) | |
57 | dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like( | |
58 | k, device=device | |
59 | ) | |
60 | pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) | |
61 | cos, sin = rotary_emb(k, pos_ids) | |
62 | | |
63 | def fwd(): | |
64 | if provider == "liger": | |
65 | return liger_rotary_pos_emb(q, k, cos, sin, pos_ids) | |
66 | elif provider == "huggingface": | |
67 | return apply_rotary_pos_emb(q, k, cos, sin, pos_ids) | |
68 | else: | |
69 | raise ValueError(f"Invalid provider: {provider} for RoPE embedding") | |
70 | | |
71 | if mode == "forward": | |
72 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
73 | fwd, | |
74 | grad_to_none=[q, k], | |
75 | rep=400, | |
76 | quantiles=QUANTILES, | |
77 | ) | |
78 | elif mode == "backward": | |
79 | q_out, k_out = fwd() | |
80 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
81 | lambda: torch.autograd.grad( | |
82 | (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True | |
83 | ), | |
84 | grad_to_none=[q, k], | |
85 | rep=400, | |
86 | quantiles=QUANTILES, | |
87 | ) | |
88 | elif mode == "full": | |
89 | | |
90 | def full(): | |
91 | q_out, k_out = fwd() | |
92 | torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True) | |
93 | | |
94 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
95 | full, | |
96 | grad_to_none=[q, k], | |
97 | rep=400, | |
98 | quantiles=QUANTILES, | |
99 | ) | |
100 | return SingleBenchmarkRunOutput( | |
101 | y_20=ms_20, | |
102 | y_50=ms_50, | |
103 | y_80=ms_80, | |
104 | ) | |
105 | | |
106 | | |
107 | def bench_memory_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: | |
108 | provider = input.kernel_provider | |
109 | | |
110 | extra_benchmark_config = input.extra_benchmark_config | |
111 | num_q_heads = extra_benchmark_config["num_q_heads"] | |
112 | num_kv_heads = extra_benchmark_config["num_kv_heads"] | |
113 | dtype = extra_benchmark_config["dtype"] | |
114 | | |
115 | # x can be either hidden_size or seq_len | |
116 | hidden_size = ( | |
117 | extra_benchmark_config["hidden_size"] | |
118 | if "hidden_size" in extra_benchmark_config | |
119 | else input.x | |
120 | ) | |
121 | seq_len = ( | |
122 | extra_benchmark_config["seq_len"] | |
123 | if "seq_len" in extra_benchmark_config | |
124 | else input.x | |
125 | ) | |
126 | | |
127 | head_dim = hidden_size // num_q_heads | |
128 | rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) | |
129 | q = torch.randn( | |
130 | (1, seq_len, num_q_heads, head_dim), | |
131 | device=device, | |
132 | requires_grad=True, | |
133 | dtype=dtype, | |
134 | ).transpose(1, 2) | |
135 | k = torch.randn( | |
136 | (1, seq_len, num_kv_heads, head_dim), | |
137 | device=device, | |
138 | requires_grad=True, | |
139 | dtype=dtype, | |
140 | ).transpose(1, 2) | |
141 | dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like( | |
142 | k, device=device | |
143 | ) | |
144 | pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) | |
145 | cos, sin = rotary_emb(k, pos_ids) | |
146 | | |
147 | def full(): | |
148 | if provider == "liger": | |
149 | q_out, k_out = liger_rotary_pos_emb(q, k, cos, sin, pos_ids) | |
150 | else: | |
151 | q_out, k_out = apply_rotary_pos_emb(q, k, cos, sin, pos_ids) | |
152 | torch.autograd.grad( | |
153 | (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True | |
154 | ) | |
155 | | |
156 | mem_50, mem_20, mem_80 = _test_memory( | |
157 | full, | |
158 | quantiles=QUANTILES, | |
159 | ) | |
160 | return SingleBenchmarkRunOutput( | |
161 | y_20=mem_20, | |
162 | y_50=mem_50, | |
163 | y_80=mem_80, | |
164 | ) | |
165 | | |
166 | | |
167 | if __name__ == "__main__": | |
168 | args = parse_benchmark_script_args() | |
169 | | |
170 | common_configs_varying_hidden_size = { | |
171 | "kernel_name": "rope", | |
172 | "x_name": "H", | |
173 | "x_label": "hidden size", | |
174 | "x_values": [32 * (2**i) for i in range(4, 10, 2)], | |
175 | "kernel_providers": ["liger", "huggingface"], | |
176 | "extra_benchmark_configs": [ | |
177 | { | |
178 | "dtype": torch.bfloat16, | |
179 | "seq_len": 2048, | |
180 | "num_q_heads": 32, | |
181 | "num_kv_heads": 8, | |
182 | } | |
183 | ], | |
184 | "overwrite": args.overwrite, | |
185 | } | |
186 | run_benchmarks( | |
187 | bench_test_fn=bench_speed_rope, | |
188 | kernel_operation_modes=["forward", "backward", "full"], | |
189 | metric_name="speed", | |
190 | metric_unit="ms", | |
191 | **common_configs_varying_hidden_size, | |
192 | ) | |
193 | run_benchmarks( | |
194 | bench_test_fn=bench_memory_rope, | |
195 | kernel_operation_modes=["full"], | |
196 | metric_name="memory", | |
197 | metric_unit="MB", | |
198 | **common_configs_varying_hidden_size, | |
199 | ) | |
200 | | |
201 | common_configs_varying_seq_len = { | |
202 | "kernel_name": "rope", | |
203 | "x_name": "T", | |
204 | "x_label": "sequence length", | |
205 | "x_values": [2**i for i in range(10, 15)], | |
206 | "kernel_providers": ["liger", "huggingface"], | |
207 | "extra_benchmark_configs": [ | |
208 | { | |
209 | "dtype": torch.bfloat16, | |
210 | "hidden_size": 8192, | |
211 | "num_q_heads": 32, | |
212 | "num_kv_heads": 8, | |
213 | } | |
214 | ], | |
215 | "overwrite": args.overwrite, | |
216 | } | |
217 | run_benchmarks( | |
218 | bench_test_fn=bench_speed_rope, | |
219 | kernel_operation_modes=["forward", "backward", "full"], | |
220 | metric_name="speed", | |
221 | metric_unit="ms", | |
222 | **common_configs_varying_seq_len, | |
223 | ) | |
224 | run_benchmarks( | |
225 | bench_test_fn=bench_memory_rope, | |
226 | kernel_operation_modes=["full"], | |
227 | metric_name="memory", | |
228 | metric_unit="MB", | |
229 | **common_configs_varying_seq_len, | |
230 | ) | |
231 | | |
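Both providers apply the rotate_half formulation of RoPE, q_embed = q * cos + rotate_half(q) * sin, with cos and sin broadcast over the head axis. A standalone CPU sketch of that formulation (base 10000 and the sizes are illustrative, not taken from the configs above):

import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

seq_len, heads, head_dim = 16, 4, 64
q = torch.randn(1, heads, seq_len, head_dim)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)                  # (seq_len, head_dim)
cos, sin = emb.cos()[None, None], emb.sin()[None, None]  # broadcast to (1, 1, T, D)
q_embed = q * cos + rotate_half(q) * sin
print(q_embed.shape)  # torch.Size([1, 4, 16, 64])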
-------------------------------------------------------------------------------- | |
/benchmark/scripts/benchmark_simpo_loss.py: | |
-------------------------------------------------------------------------------- | |
1 | import os | |
2 | import sys | |
3 | | |
4 | import torch | |
5 | import triton | |
6 | from utils import ( | |
7 | QUANTILES, | |
8 | SingleBenchmarkRunInput, | |
9 | SingleBenchmarkRunOutput, | |
10 | _test_memory, | |
11 | parse_benchmark_script_args, | |
12 | run_benchmarks, | |
13 | ) | |
14 | | |
15 | from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction | |
16 | from liger_kernel.utils import infer_device | |
17 | | |
18 | device = infer_device() | |
19 | | |
20 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) | |
21 | | |
22 | | |
23 | class TorchLMHeadSimPO(torch.nn.Module): | |
24 | """Ground truth implementation of the linear fused with torch based cross entropy loss. | |
25 | | |
26 | :param H: hidden size | |
27 | :param V: vocab size | |
28 | :param ignore_index: index to ignore | |
29 | :param dtype: parameter dtype of the linear layer | |
30 | """ | |
31 | | |
32 | def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): | |
33 | from test.chunked_loss.test_cpo_loss import HFCPOLoss | |
34 | | |
35 | super().__init__() | |
36 | self.lin = torch.nn.Linear( | |
37 | in_features=H, out_features=V, bias=False, dtype=dtype | |
38 | ) | |
39 | self.simpo_loss = HFCPOLoss(loss_type="simpo").get_batch_loss_metrics | |
40 | | |
41 | def forward(self, x, y): | |
42 | return self.simpo_loss(x, self.lin.weight, y) | |
43 | | |
44 | | |
45 | class LigerLMHeadSimPO(torch.nn.Module): | |
46 | def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): | |
47 | super().__init__() | |
48 | self.lin = torch.nn.Linear( | |
49 | in_features=H, out_features=V, bias=False, dtype=dtype | |
50 | ) | |
51 | self.simpo_loss = LigerFusedLinearSimPOFunction.apply | |
52 | | |
53 | def forward(self, x, y): | |
54 | return self.simpo_loss(x, self.lin.weight, y) | |
55 | | |
56 | | |
57 | ############################################################################# | |
58 | # Test the memory consumption of the fused linear SimPO loss | |
59 | ############################################################################# | |
60 | | |
61 | | |
62 | def bench_memory_fused_linear_simpo_loss( | |
63 | input: SingleBenchmarkRunInput, | |
64 | ) -> SingleBenchmarkRunOutput: | |
65 | B = input.x | |
66 | T = input.extra_benchmark_config["T"] | |
67 | H = input.extra_benchmark_config["H"] | |
68 | V = input.extra_benchmark_config["V"] | |
69 | dtype = input.extra_benchmark_config["dtype"] | |
70 | provider = input.kernel_provider | |
71 | | |
72 | torch_lm_head_simpo = TorchLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) | |
73 | liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) | |
74 | | |
75 | _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) | |
76 | target = torch.randint(V, (B, T), dtype=torch.long, device=device) | |
77 | | |
78 | def fwd(): | |
79 | if provider == "liger": | |
80 | return liger_lm_head_simpo(_input, target) | |
81 | elif provider == "huggingface": | |
82 | return torch_lm_head_simpo(_input, target) | |
83 | | |
84 | def full(): | |
85 | y = fwd() | |
86 | y.backward() | |
87 | | |
88 | mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) | |
89 | return SingleBenchmarkRunOutput( | |
90 | y_20=mem_20, | |
91 | y_50=mem_50, | |
92 | y_80=mem_80, | |
93 | ) | |
94 | | |
95 | | |
96 | # ############################################################################# | |
97 | # # Test the speed of the fused linear SimPO loss | |
98 | # ############################################################################# | |
99 | | |
100 | | |
101 | def bench_speed_fused_linear_simpo_loss( | |
102 | input: SingleBenchmarkRunInput, | |
103 | ) -> SingleBenchmarkRunOutput: | |
104 | B = input.x | |
105 | T = input.extra_benchmark_config["T"] | |
106 | H = input.extra_benchmark_config["H"] | |
107 | V = input.extra_benchmark_config["V"] | |
108 | dtype = input.extra_benchmark_config["dtype"] | |
109 | provider = input.kernel_provider | |
110 | mode = input.kernel_operation_mode | |
111 | | |
112 | torch_lm_head_simpo = TorchLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) | |
113 | liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) | |
114 | | |
115 | _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) | |
116 | target = torch.randint(V, (B, T), dtype=torch.long, device=device) | |
117 | | |
118 | def fwd(): | |
119 | if provider == "liger": | |
120 | return liger_lm_head_simpo(_input, target) | |
121 | elif provider == "huggingface": | |
122 | return torch_lm_head_simpo(_input, target) | |
123 | | |
124 | if mode == "forward": | |
125 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
126 | fwd, | |
127 | rep=100, | |
128 | quantiles=QUANTILES, | |
129 | ) | |
130 | elif mode == "backward": | |
131 | y = fwd() | |
132 | | |
133 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
134 | lambda: y.backward(retain_graph=True), | |
135 | grad_to_none=[_input], | |
136 | rep=100, | |
137 | quantiles=QUANTILES, | |
138 | ) | |
139 | elif mode == "full": | |
140 | | |
141 | def full(): | |
142 | y = fwd() | |
143 | y.backward() | |
144 | | |
145 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
146 | full, | |
147 | rep=100, | |
148 | quantiles=QUANTILES, | |
149 | ) | |
150 | return SingleBenchmarkRunOutput( | |
151 | y_20=ms_20, | |
152 | y_50=ms_50, | |
153 | y_80=ms_80, | |
154 | ) | |
155 | | |
156 | | |
157 | if __name__ == "__main__": | |
158 | args = parse_benchmark_script_args() | |
159 | | |
160 | common_configs = { | |
161 | "kernel_name": "fused_linear_simpo_loss", | |
162 | "x_name": "B", | |
163 | "x_label": "B", | |
164 | "x_values": [2**i for i in range(1, 5)], | |
165 | "kernel_providers": ["liger", "huggingface"], | |
166 | "extra_benchmark_configs": [ | |
167 | { | |
168 | "T": 1024, | |
169 | "H": 4096, | |
170 | "V": 128256, | |
171 | "mode": "forward", | |
172 | "dtype": torch.bfloat16, | |
173 | } | |
174 | ], | |
175 | "overwrite": args.overwrite, | |
176 | } | |
177 | | |
178 | run_benchmarks( | |
179 | bench_test_fn=bench_speed_fused_linear_simpo_loss, | |
180 | kernel_operation_modes=["forward", "full"], | |
181 | metric_name="speed", | |
182 | metric_unit="ms", | |
183 | **common_configs | |
184 | ) | |
185 | run_benchmarks( | |
186 | bench_test_fn=bench_memory_fused_linear_simpo_loss, | |
187 | kernel_operation_modes=["full"], | |
188 | metric_name="memory", | |
189 | metric_unit="MB", | |
190 | **common_configs | |
191 | ) | |
192 | | |
-------------------------------------------------------------------------------- | |
/benchmark/scripts/benchmark_swiglu.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch | |
2 | import triton | |
3 | from transformers.models.llama.configuration_llama import LlamaConfig | |
4 | from transformers.models.llama.modeling_llama import LlamaMLP | |
5 | from utils import ( | |
6 | QUANTILES, | |
7 | SingleBenchmarkRunInput, | |
8 | SingleBenchmarkRunOutput, | |
9 | _test_memory, | |
10 | parse_benchmark_script_args, | |
11 | run_benchmarks, | |
12 | ) | |
13 | | |
14 | from liger_kernel.transformers.swiglu import LigerSwiGLUMLP | |
15 | from liger_kernel.utils import infer_device | |
16 | | |
17 | device = infer_device() | |
18 | | |
19 | | |
20 | def bench_speed_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: | |
21 | seq_len = input.x | |
22 | provider = input.kernel_provider | |
23 | mode = input.kernel_operation_mode | |
24 | | |
25 | extra_benchmark_config = input.extra_benchmark_config | |
26 | bsz = extra_benchmark_config["B"] | |
27 | hidden_size = extra_benchmark_config["hidden_size"] | |
28 | dtype = extra_benchmark_config["dtype"] | |
29 | intermediate_size = extra_benchmark_config["intermediate_size"] | |
30 | hidden_act = extra_benchmark_config["hidden_act"] | |
31 | | |
32 | llama_config = LlamaConfig( | |
33 | hidden_size=hidden_size, | |
34 | intermediate_size=intermediate_size, | |
35 | hidden_act=hidden_act, | |
36 | ) | |
37 | | |
38 | x_shape = (bsz, seq_len, hidden_size) | |
39 | | |
40 | # initialize input | |
41 | x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True) | |
42 | | |
43 | if provider == "liger": | |
44 | layer = LigerSwiGLUMLP(config=llama_config).to(device).to(dtype) | |
45 | elif provider == "huggingface": | |
46 | layer = LlamaMLP(config=llama_config).to(device).to(dtype) | |
47 | else: | |
48 | raise ValueError(f"Invalid provider: {provider} for SwiGLU") | |
49 | | |
50 | def fwd(): | |
51 | return layer(x) | |
52 | | |
53 | if mode == "forward": | |
54 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
55 | fwd, | |
56 | grad_to_none=[x], | |
57 | quantiles=QUANTILES, | |
58 | rep=10, | |
59 | ) | |
60 | elif mode == "backward": | |
61 | do = torch.randn_like(x) | |
62 | y = fwd() | |
63 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
64 | lambda: y.backward(do, retain_graph=True), | |
65 | grad_to_none=[x], | |
66 | quantiles=QUANTILES, | |
67 | rep=10, | |
68 | ) | |
69 | else: | |
70 | | |
71 | def full(): | |
72 | y = fwd() | |
73 | y.backward(torch.randn_like(y), retain_graph=True) | |
74 | | |
75 | ms_50, ms_20, ms_80 = triton.testing.do_bench( | |
76 | full, | |
77 | grad_to_none=[x], | |
78 | quantiles=QUANTILES, | |
79 | rep=10, | |
80 | ) | |
81 | | |
82 | return SingleBenchmarkRunOutput( | |
83 | y_20=ms_20, | |
84 | y_50=ms_50, | |
85 | y_80=ms_80, | |
86 | ) | |
87 | | |
88 | | |
89 | def bench_memory_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: | |
90 | seq_len = input.x | |
91 | provider = input.kernel_provider | |
92 | mode = input.kernel_operation_mode | |
93 | | |
94 | extra_benchmark_config = input.extra_benchmark_config | |
95 | bsz = extra_benchmark_config["B"] | |
96 | hidden_size = extra_benchmark_config["hidden_size"] | |
97 | dtype = extra_benchmark_config["dtype"] | |
98 | intermediate_size = extra_benchmark_config["intermediate_size"] | |
99 | hidden_act = extra_benchmark_config["hidden_act"] | |
100 | | |
101 | llama_config = LlamaConfig( | |
102 | hidden_size=hidden_size, | |
103 | intermediate_size=intermediate_size, | |
104 | hidden_act=hidden_act, | |
105 | ) | |
106 | | |
107 | x_shape = (bsz, seq_len, hidden_size) | |
108 | | |
109 | # initialize input | |
110 | x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True) | |
111 | | |
112 | if provider == "liger": | |
113 | layer = LigerSwiGLUMLP(config=llama_config).to(device).to(dtype) | |
114 | elif provider == "huggingface": | |
115 | layer = LlamaMLP(config=llama_config).to(device).to(dtype) | |
116 | else: | |
117 | raise ValueError(f"Invalid provider: {provider} for SwiGLU") | |
118 | | |
119 | def fwd(): | |
120 | return layer(x) | |
121 | | |
122 | def full(): | |
123 | y = fwd() | |
124 | y.backward(torch.randn_like(y), retain_graph=True) | |
125 | | |
126 | if mode == "forward": | |
127 | mem_50, mem_20, mem_80 = _test_memory(fwd, quantiles=QUANTILES) | |
128 | elif mode == "backward": | |
129 | do = torch.randn_like(x) | |
130 | y = fwd() | |
131 | mem_50, mem_20, mem_80 = _test_memory( | |
132 | lambda: y.backward(do, retain_graph=True), quantiles=QUANTILES | |
133 | ) | |
134 | else: | |
135 | mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) | |
136 | | |
137 | return SingleBenchmarkRunOutput( | |
138 | y_20=mem_20, | |
139 | y_50=mem_50, | |
140 | y_80=mem_80, | |
141 | ) | |
142 | | |
143 | | |
144 | if __name__ == "__main__": | |
145 | args = parse_benchmark_script_args() | |
146 | | |
147 | common_configs = { | |
148 | "kernel_name": "swiglu", | |
149 | "x_name": "T", | |
150 | "x_label": "sequence length", | |
151 | "x_values": [2**i for i in range(10, 14)], | |
152 | "kernel_providers": ["liger", "huggingface"], | |
153 | "extra_benchmark_configs": [ | |
154 | { | |
155 | "B": 4, | |
156 | "hidden_size": 4096, | |
157 | "dtype": torch.bfloat16, | |
158 | "intermediate_size": 11008, | |
159 | "hidden_act": "silu", | |
160 | } | |
161 | ], | |
162 | "overwrite": args.overwrite, | |
163 | } | |
164 | | |
165 | run_benchmarks( | |
166 | bench_test_fn=bench_speed_swiglu, | |
167 | kernel_operation_modes=["forward"], | |
168 | metric_name="speed", | |
169 | metric_unit="ms", | |
170 | **common_configs, | |
171 | ) | |
172 | run_benchmarks( | |
173 | bench_test_fn=bench_memory_swiglu, | |
174 | kernel_operation_modes=["full"], | |
175 | metric_name="memory", | |
176 | metric_unit="MB", | |
177 | **common_configs, | |
178 | ) | |
179 | | |
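LigerSwiGLUMLP and LlamaMLP compute the same mapping, down_proj(silu(gate_proj(x)) * up_proj(x)); the Liger version fuses the element-wise SiLU-and-multiply step into a Triton kernel. A small CPU sketch of the computation (module names follow LlamaMLP; sizes are illustrative):

import torch
import torch.nn.functional as F

hidden, inter = 64, 172
x = torch.randn(2, 8, hidden)
gate_proj = torch.nn.Linear(hidden, inter, bias=False)
up_proj = torch.nn.Linear(hidden, inter, bias=False)
down_proj = torch.nn.Linear(inter, hidden, bias=False)

y = down_proj(F.silu(gate_proj(x)) * up_proj(x))  # SwiGLU MLP
print(y.shape)  # torch.Size([2, 8, 64])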
-------------------------------------------------------------------------------- | |
/dev/modal/tests.py: | |
-------------------------------------------------------------------------------- | |
1 | from pathlib import Path | |
2 | | |
3 | import modal | |
4 | | |
5 | ROOT_PATH = Path(__file__).parent.parent.parent | |
6 | REMOTE_ROOT_PATH = "/root/liger-kernel" | |
7 | PYTHON_VERSION = "3.12" | |
8 | | |
9 | image = modal.Image.debian_slim(python_version=PYTHON_VERSION).pip_install("uv") | |
10 | | |
11 | app = modal.App("liger_tests", image=image) | |
12 | | |
13 | # mount: add local files to the remote container | |
14 | repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH) | |
15 | | |
16 | | |
17 | @app.function(gpu="A10G", mounts=[repo], timeout=60 * 15) | |
18 | def liger_tests(): | |
19 | import subprocess | |
20 | | |
21 | subprocess.run( | |
22 | ["uv pip install -e '.[dev]' --system"], | |
23 | check=True, | |
24 | shell=True, | |
25 | cwd=REMOTE_ROOT_PATH, | |
26 | ) | |
27 | subprocess.run(["make test"], check=True, shell=True, cwd=REMOTE_ROOT_PATH) | |
28 | subprocess.run( | |
29 | ["make test-convergence"], check=True, shell=True, cwd=REMOTE_ROOT_PATH | |
30 | ) | |
31 | | |
-------------------------------------------------------------------------------- | |
/dev/modal/tests_bwd.py: | |
-------------------------------------------------------------------------------- | |
1 | from pathlib import Path | |
2 | | |
3 | import modal | |
4 | | |
5 | ROOT_PATH = Path(__file__).parent.parent.parent | |
6 | REMOTE_ROOT_PATH = "/root/liger-kernel" | |
7 | PYTHON_VERSION = "3.12" | |
8 | | |
9 | image = modal.Image.debian_slim(python_version=PYTHON_VERSION).pip_install("uv") | |
10 | | |
11 | app = modal.App("liger_tests_bwd", image=image) | |
12 | | |
13 | # mount: add local files to the remote container | |
14 | repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH) | |
15 | | |
16 | | |
17 | @app.function(gpu="A10G", mounts=[repo], timeout=60 * 15) | |
18 | def liger_bwd_tests(): | |
19 | import subprocess | |
20 | | |
21 | subprocess.run( | |
22 | ["uv pip install -e '.[dev]' --system"], | |
23 | check=True, | |
24 | shell=True, | |
25 | cwd=REMOTE_ROOT_PATH, | |
26 | ) | |
27 | # force install transformers==4.44.2 | |
28 | subprocess.run( | |
29 | ["uv pip install transformers==4.44.2 --system"], | |
30 | check=True, | |
31 | shell=True, | |
32 | cwd=REMOTE_ROOT_PATH, | |
33 | ) | |
34 | subprocess.run(["make test"], check=True, shell=True, cwd=REMOTE_ROOT_PATH) | |
35 | subprocess.run( | |
36 | ["make test-convergence"], check=True, shell=True, cwd=REMOTE_ROOT_PATH | |
37 | ) | |
38 | | |
-------------------------------------------------------------------------------- | |
/examples/alignment/run_orpo.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch | |
2 | from datasets import load_dataset | |
3 | from transformers import AutoModelForCausalLM, AutoTokenizer | |
4 | from trl import ORPOConfig # noqa: F401 | |
5 | | |
6 | from liger_kernel.transformers.trainer import LigerORPOTrainer # noqa: F401 | |
7 | | |
8 | model = AutoModelForCausalLM.from_pretrained( | |
9 | "meta-llama/Llama-3.2-1B-Instruct", | |
10 | torch_dtype=torch.bfloat16, | |
11 | ) | |
12 | | |
13 | tokenizer = AutoTokenizer.from_pretrained( | |
14 | "meta-llama/Llama-3.2-1B-Instruct", | |
15 | max_length=512, | |
16 | padding="max_length", | |
17 | ) | |
18 | tokenizer.pad_token = tokenizer.eos_token | |
19 | | |
20 | train_dataset = load_dataset("trl-lib/tldr-preference", split="train") | |
21 | | |
22 | training_args = ORPOConfig( | |
23 | output_dir="Llama3.2_1B_Instruct", | |
24 | beta=0.1, | |
25 | max_length=128, | |
26 | per_device_train_batch_size=32, | |
27 | max_steps=100, | |
28 | save_strategy="no", | |
29 | ) | |
30 | | |
31 | trainer = LigerORPOTrainer( | |
32 | model=model, args=training_args, tokenizer=tokenizer, train_dataset=train_dataset | |
33 | ) | |
34 | | |
35 | trainer.train() | |
36 | | |
-------------------------------------------------------------------------------- | |
/examples/huggingface/launch_on_modal.py: | |
-------------------------------------------------------------------------------- | |
1 | """ | |
2 | launch_on_modal.py | |
3 | | |
4 | This tool is designed to launch scripts using Modal. | |
5 | | |
6 | It sets up the necessary environment, including GPU resources and python dependencies, | |
7 | and executes the specified training script remotely. | |
8 | | |
9 | ### Setup and Usage | |
10 | ```bash | |
11 | pip install modal | |
12 | modal setup # authenticate with Modal | |
13 | export HF_TOKEN="your_huggingface_token" # if using a gated model such as llama3 | |
14 | modal run launch_on_modal.py --script "run_qwen2_vl.sh" | |
15 | ``` | |
16 | | |
17 | ### Caveats | |
18 | This tool is intended as an easy on-ramp to using Liger-Kernel for fine-tuning LLMs and | |
19 | VLMs - it is a reproducible way to run benchmarks and example scripts. However, it is not | |
20 | the best way to develop a model on Modal, as it re-downloads the model and dataset each | |
21 | time it is run. For iterative development, consider using `modal.Volume` to cache the | |
22 | model and dataset between runs. | |
23 | """ | |
24 | | |
25 | import os | |
26 | | |
27 | import modal | |
28 | from modal import gpu | |
29 | | |
30 | TWO_HOURS = 2 * 60 * 60 | |
31 | SIXTEEN_GB = 16 * 1024 | |
32 | | |
33 | app = modal.App("liger-example") | |
34 | | |
35 | image = ( | |
36 | modal.Image.debian_slim() | |
37 | .pip_install_from_requirements("requirements.txt") | |
38 | .copy_local_dir(".", "/root") | |
39 | ) | |
40 | | |
41 | if "HF_TOKEN" not in os.environ: | |
42 | print("HF_TOKEN not found in environment variables, using an empty token.") | |
43 | hf_token_secret = modal.Secret.from_dict({"HF_TOKEN": os.environ.get("HF_TOKEN", "")}) | |
44 | | |
45 | | |
46 | @app.function( | |
47 | gpu=gpu.A100(count=4, size="80GB"), | |
48 | image=image, | |
49 | timeout=TWO_HOURS, | |
50 | memory=SIXTEEN_GB, | |
51 | secrets=[hf_token_secret], | |
52 | ) | |
53 | def launch_script(script: str): | |
54 | import subprocess | |
55 | | |
56 | script_path = f"/root/{script}" | |
57 | os.chmod(script_path, 0o755) # make script executable | |
58 | | |
59 | print(f"Running script: {script_path}") | |
60 | subprocess.run([script_path], check=True, cwd="/root", env=os.environ.copy()) | |
61 | | |
62 | | |
63 | @app.local_entrypoint() | |
64 | def main(script: str): | |
65 | """ | |
66 | Launch a script remotely on modal. | |
67 | ```bash | |
68 | export HF_TOKEN="your_huggingface_token" # if using a gated model such as llama3 | |
69 | modal run --detach launch_on_modal.py --script "run_qwen2_vl.sh" | |
70 | ``` | |
71 | """ | |
72 | launch_script.remote(script=script) | |
73 | | |
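The module docstring above recommends `modal.Volume` to cache the model and dataset between runs instead of re-downloading them each time. A minimal sketch of that caching pattern, assuming the standard Modal API; the volume name, mount path, and function names here are illustrative and not part of the example script:

```python
# Hypothetical sketch: persist Hugging Face downloads in a Modal volume so
# repeated runs reuse the cache instead of re-downloading model and dataset.
import modal

app = modal.App("liger-example-cached")
hf_cache = modal.Volume.from_name("hf-cache", create_if_missing=True)


@app.function(
    image=modal.Image.debian_slim().pip_install("transformers", "datasets"),
    volumes={"/hf-cache": hf_cache},  # mount the persistent volume in the container
    timeout=2 * 60 * 60,
)
def launch_cached(script: str):
    import os
    import subprocess

    os.environ["HF_HOME"] = "/hf-cache"  # point HF libraries at the mounted cache
    subprocess.run([f"/root/{script}"], check=True, cwd="/root", env=os.environ.copy())
    hf_cache.commit()  # flush newly downloaded files back to the volume
```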
-------------------------------------------------------------------------------- | |
/examples/huggingface/training.py: | |
-------------------------------------------------------------------------------- | |
1 | from dataclasses import dataclass | |
2 | | |
3 | import datasets | |
4 | import torch | |
5 | import transformers | |
6 | from callback import EfficiencyCallback | |
7 | from trl import DataCollatorForCompletionOnlyLM, SFTTrainer | |
8 | | |
9 | from liger_kernel.transformers import AutoLigerKernelForCausalLM | |
10 | | |
11 | | |
12 | @dataclass | |
13 | class CustomArguments: | |
14 | model_name: str = "meta-llama/Meta-Llama-3-8B" | |
15 | dataset: str = "tatsu-lab/alpaca" | |
16 | max_seq_length: int = 512 | |
17 | use_liger: bool = False | |
18 | | |
19 | | |
20 | def formatting_prompts_func(example): | |
21 | return example["text"] | |
22 | | |
23 | | |
24 | def train(): | |
25 | parser = transformers.HfArgumentParser( | |
26 | (transformers.TrainingArguments, CustomArguments) | |
27 | ) | |
28 | training_args, custom_args = parser.parse_args_into_dataclasses() | |
29 | tokenizer = transformers.AutoTokenizer.from_pretrained( | |
30 | custom_args.model_name, | |
31 | padding_side="left", | |
32 | truncation_side="left", | |
33 | ) | |
34 | tokenizer.pad_token = tokenizer.eos_token | |
35 | | |
36 | dataset = datasets.load_dataset(custom_args.dataset)["train"].train_test_split( | |
37 | test_size=0.1 | |
38 | ) | |
39 | train_dataset = dataset["train"] | |
40 | eval_dataset = dataset["test"] | |
41 | response_prompt = tokenizer.encode("### Response:\n", add_special_tokens=False) | |
42 | collator = DataCollatorForCompletionOnlyLM( | |
43 | tokenizer=tokenizer, | |
44 | response_template=response_prompt, | |
45 | pad_to_multiple_of=16, | |
46 | ) | |
47 | | |
48 | if custom_args.use_liger: | |
49 | model = AutoLigerKernelForCausalLM.from_pretrained( | |
50 | custom_args.model_name, | |
51 | trust_remote_code=True, | |
52 | use_cache=False, | |
53 | torch_dtype=torch.bfloat16, | |
54 | # These args will get passed to the appropriate apply_liger_kernel_to_* function | |
55 | # to override the default settings | |
56 | # cross_entropy=True, | |
57 | # fused_linear_cross_entropy=False, | |
58 | ) | |
59 | else: | |
60 | model = transformers.AutoModelForCausalLM.from_pretrained( | |
61 | custom_args.model_name, | |
62 | trust_remote_code=True, | |
63 | use_cache=False, | |
64 | torch_dtype=torch.bfloat16, | |
65 | ) | |
66 | | |
67 | trainer = SFTTrainer( | |
68 | model=model, | |
69 | args=training_args, | |
70 | data_collator=collator, | |
71 | max_seq_length=custom_args.max_seq_length, | |
72 | train_dataset=train_dataset, | |
73 | eval_dataset=eval_dataset, | |
74 | formatting_func=formatting_prompts_func, | |
75 | callbacks=[EfficiencyCallback()], | |
76 | ) | |
77 | trainer.train() | |
78 | | |
79 | | |
80 | if __name__ == "__main__": | |
81 | train() | |
82 | | |
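Because the script feeds both `transformers.TrainingArguments` and its own `CustomArguments` through one `HfArgumentParser`, the same CLI flags are split between the two dataclasses automatically. A small sketch of that split using `parse_dict` with illustrative values (the dataclass mirrors the one defined in the script above):

```python
# Illustrative only: HfArgumentParser routes each key to whichever dataclass
# declares that field, here TrainingArguments vs. the script's CustomArguments.
from dataclasses import dataclass

import transformers


@dataclass
class CustomArguments:
    model_name: str = "meta-llama/Meta-Llama-3-8B"
    dataset: str = "tatsu-lab/alpaca"
    max_seq_length: int = 512
    use_liger: bool = False


parser = transformers.HfArgumentParser((transformers.TrainingArguments, CustomArguments))
training_args, custom_args = parser.parse_dict(
    {
        "output_dir": "alpaca_finetune",          # TrainingArguments
        "per_device_train_batch_size": 4,          # TrainingArguments
        "num_train_epochs": 1,                     # TrainingArguments
        "use_liger": True,                         # CustomArguments
    }
)
print(training_args.per_device_train_batch_size, custom_args.use_liger)
```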
-------------------------------------------------------------------------------- | |
/examples/huggingface/training_multimodal.py: | |
-------------------------------------------------------------------------------- | |
1 | import os | |
2 | from dataclasses import dataclass | |
3 | | |
4 | import datasets | |
5 | import torch | |
6 | import transformers | |
7 | from callback import EfficiencyCallback | |
8 | from datasets import Image as ImageFeature | |
9 | from trl import SFTTrainer | |
10 | | |
11 | from liger_kernel.transformers import monkey_patch | |
12 | | |
13 | | |
14 | @dataclass | |
15 | class CustomArguments: | |
16 | model_name: str = "Qwen/Qwen2-VL-2B-Instruct" | |
17 | dataset: str = "HuggingFaceM4/the_cauldron" | |
18 | dataset_subset: str = "ai2d" | |
19 | dataset_split: str = "train" | |
20 | max_seq_length: int = 512 | |
21 | dataset_text_field: str = "texts" | |
22 | use_liger: bool = False | |
23 | | |
24 | | |
25 | def construct_model_and_processor(model_name: str, use_liger: bool) -> torch.nn.Module: | |
26 | if "Qwen2-VL" in model_name: | |
27 | from transformers import Qwen2VLForConditionalGeneration | |
28 | | |
29 | # These settings are used to reduce the memory footprint of the Qwen2-VL model, | |
30 | # which supports training/inferences on images in their native resolution. Large | |
31 | # images -> many visual tokens (a max of 16384) -> large memory consumption. | |
32 | # If fine-tuning for a real-world application, consider these values carefully. | |
33 | min_visual_tokens_per_image = 256 | |
34 | max_visual_tokens_per_image = 256 | |
35 | | |
36 | processor = transformers.AutoProcessor.from_pretrained( | |
37 | model_name, | |
38 | padding_side="left", | |
39 | truncation_side="left", | |
40 | min_pixels=min_visual_tokens_per_image * 28 * 28, # patch size is 14x14 | |
41 | max_pixels=max_visual_tokens_per_image * 28 * 28, # 4 patches / token | |
42 | ) | |
43 | processor.tokenizer.pad_token = processor.tokenizer.eos_token | |
44 | image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>") | |
45 | | |
46 | if use_liger: | |
47 | print("Applying Liger Kernel to Qwen2-VL model") | |
48 | monkey_patch.apply_liger_kernel_to_qwen2_vl( | |
49 | # These args can be used to override the default Liger settings | |
50 | # cross_entropy=True, | |
51 | # fused_linear_cross_entropy=False, | |
52 | ) | |
53 | | |
54 | model = Qwen2VLForConditionalGeneration.from_pretrained( | |
55 | pretrained_model_name_or_path=model_name, | |
56 | use_cache=False, | |
57 | torch_dtype=torch.bfloat16, | |
58 | low_cpu_mem_usage=True, | |
59 | attn_implementation="sdpa", | |
60 | ) | |
61 | return model, processor, image_token_id | |
62 | | |
63 | raise NotImplementedError(f"Model {model_name} not supported") | |
64 | | |
65 | | |
66 | def _validate_and_extract_the_cauldron(examples) -> dict[str, list]: | |
67 | batch_texts = [] | |
68 | batch_images = [] | |
69 | for images, texts in zip(examples["images"], examples["texts"]): | |
70 | if not images: | |
71 | raise ValueError("No image found in example from the_cauldron dataset") | |
72 | if len(images) > 1: | |
73 | raise ValueError("Only one image per example is supported") | |
74 | batch_texts.extend(texts) | |
75 | batch_images.extend([images[0]] * len(texts)) | |
76 | return {"texts": batch_texts, "images": batch_images} | |
77 | | |
78 | | |
79 | def _format_for_convo(example, tokenizer): | |
80 | # cauldron data is already in message format {"user": ..., "assistant": ...} | |
81 | text = example["texts"] | |
82 | messages = [ | |
83 | { | |
84 | "role": "user", | |
85 | "content": [{"type": "image"}, {"type": "text", "text": text["user"]}], | |
86 | }, | |
87 | {"role": "assistant", "content": [{"type": "text", "text": text["assistant"]}]}, | |
88 | ] | |
89 | text = tokenizer.apply_chat_template(messages, tokenize=False) | |
90 | return {"texts": text} | |
91 | | |
92 | | |
93 | def train(): | |
94 | parser = transformers.HfArgumentParser( | |
95 | (transformers.TrainingArguments, CustomArguments) | |
96 | ) | |
97 | training_args, custom_args = parser.parse_args_into_dataclasses() | |
98 | training_args.remove_unused_columns = False # required to not drop the image column | |
99 | training_args.dataset_kwargs = {"skip_prepare_dataset": True} | |
100 | | |
101 | model, processor, image_token_id = construct_model_and_processor( | |
102 | custom_args.model_name, custom_args.use_liger | |
103 | ) | |
104 | | |
105 | dataset = ( | |
106 | datasets.load_dataset( | |
107 | custom_args.dataset, | |
108 | custom_args.dataset_subset, | |
109 | split=custom_args.dataset_split, | |
110 | ) | |
111 | .map( | |
112 | _validate_and_extract_the_cauldron, | |
113 | batched=True, | |
114 | num_proc=min(os.cpu_count(), 16), | |
115 | desc="Extracting text and images", | |
116 | ) | |
117 | .map( | |
118 | _format_for_convo, | |
119 | fn_kwargs={"tokenizer": processor.tokenizer}, | |
120 | desc="Formatting for convo", | |
121 | ) | |
122 | .cast_column("images", ImageFeature()) | |
123 | .train_test_split(test_size=0.1) | |
124 | ) | |
125 | | |
126 | train_dataset = dataset["train"] | |
127 | eval_dataset = dataset["test"] | |
128 | | |
129 | def collate_fn(examples): | |
130 | """ | |
131 | Taken directly from the TRL documentation with minor modifications: | |
132 | https://huggingface.co/docs/trl/en/sft_trainer#a-custom-collator-for-processing-multi-modal-data | |
133 | | |
134 | Modifications: | |
135 | 1. `apply_chat_template` is used to preprocess the texts before training begins (see above) | |
136 | 2. `example["messages"]` -> `example["texts"]` to conform with the_cauldron dataset schema | |
137 | 3. Ignoring image tokens in the loss computation | |
138 | """ | |
139 | # Get the texts and images | |
140 | texts = [example["texts"] for example in examples] | |
141 | images = [example["images"] for example in examples] | |
142 | | |
143 | # Tokenize the texts and process the images | |
144 | batch = processor(text=texts, images=images, return_tensors="pt", padding=True) | |
145 | | |
146 | # The labels are the input_ids, and we mask the padding tokens in the loss computation | |
147 | labels = batch["input_ids"].clone() | |
148 | labels[labels == processor.tokenizer.pad_token_id] = -100 | |
149 | | |
150 | # Ignore the image token index in the loss computation | |
151 | labels[labels == image_token_id] = -100 | |
152 | batch["labels"] = labels | |
153 | | |
154 | return batch | |
155 | | |
156 | trainer = SFTTrainer( | |
157 | model=model, | |
158 | args=training_args, | |
159 | data_collator=collate_fn, | |
160 | max_seq_length=custom_args.max_seq_length, | |
161 | dataset_text_field=custom_args.dataset_text_field, | |
162 | train_dataset=train_dataset, | |
163 | eval_dataset=eval_dataset, | |
164 | tokenizer=processor.tokenizer, | |
165 | callbacks=[EfficiencyCallback()], | |
166 | ) | |
167 | trainer.train() | |
168 | | |
169 | | |
170 | if __name__ == "__main__": | |
171 | train() | |
172 | | |
-------------------------------------------------------------------------------- | |
/setup.py: | |
-------------------------------------------------------------------------------- | |
1 | # setup.py | |
2 | | |
3 | import subprocess | |
4 | from typing import Literal | |
5 | | |
6 | from setuptools import setup | |
7 | | |
8 | | |
9 | def get_default_dependencies(): | |
10 | """Determine the appropriate dependencies based on detected hardware.""" | |
11 | platform = get_platform() | |
12 | | |
13 | if platform in ["cuda", "cpu"]: | |
14 | return [ | |
15 | "torch>=2.1.2", | |
16 | "triton>=2.3.1", | |
17 | ] | |
18 | elif platform == "rocm": | |
19 | return [ | |
20 | "torch>=2.6.0.dev", | |
21 | "triton>=3.0.0", | |
22 | ] | |
23 | | |
24 | | |
25 | def get_optional_dependencies(): | |
26 | """Get optional dependency groups.""" | |
27 | return { | |
28 | "dev": [ | |
29 | "transformers>=4.44.2", | |
30 | "matplotlib>=3.7.2", | |
31 | "flake8>=4.0.1.1", | |
32 | "black>=24.4.2", | |
33 | "isort>=5.13.2", | |
34 | "pytest>=7.1.2", | |
35 | "pytest-xdist", | |
36 | "pytest-rerunfailures", | |
37 | "datasets>=2.19.2", | |
38 | "seaborn", | |
39 | ] | |
40 | } | |
41 | | |
42 | | |
43 | # TODO: add intel XPU | |
44 | def get_platform() -> Literal["cuda", "rocm", "cpu"]: | |
45 | """ | |
46 | Detect whether the system has NVIDIA or AMD GPU without torch dependency. | |
47 | """ | |
48 | # Try nvidia-smi first | |
49 | try: | |
50 | subprocess.run(["nvidia-smi"], check=True) | |
51 | print("NVIDIA GPU detected") | |
52 | return "cuda" | |
53 | except (subprocess.SubprocessError, FileNotFoundError): | |
54 | # If nvidia-smi fails, check for ROCm | |
55 | try: | |
56 | subprocess.run(["rocm-smi"], check=True) | |
57 | print("ROCm GPU detected") | |
58 | return "rocm" | |
59 | except (subprocess.SubprocessError, FileNotFoundError): | |
60 | print("No GPU detected") | |
61 | return "cpu" | |
62 | | |
63 | | |
64 | setup( | |
65 | name="liger_kernel", | |
66 | package_dir={"": "src"}, | |
67 | packages=["liger_kernel"], | |
68 | install_requires=get_default_dependencies(), | |
69 | extras_require=get_optional_dependencies(), | |
70 | ) | |
71 | | |
-------------------------------------------------------------------------------- | |
/src/liger_kernel/__init__.py: | |
-------------------------------------------------------------------------------- | |
https://raw.githubusercontent.com/linkedin/Liger-Kernel/0bb6c72d03f0fc570b9e954c42b723ab7e0c315f/src/liger_kernel/__init__.py | |

-------------------------------------------------------------------------------- | |
/src/liger_kernel/chunked_loss/__init__.py: | |
-------------------------------------------------------------------------------- | |
1 | from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401 | |
2 | from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401 | |
3 | from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401 | |
4 | from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401 | |
5 | | |
-------------------------------------------------------------------------------- | |
/src/liger_kernel/chunked_loss/cpo_loss.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch | |
2 | import torch.nn.functional as F | |
3 | | |
4 | from liger_kernel.chunked_loss.fused_linear_preference import ( | |
5 | LigerFusedLinearPreferenceBase, | |
6 | ) | |
7 | | |
8 | | |
9 | class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase): | |
10 | | |
11 | @staticmethod | |
12 | def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1): | |
13 | """ | |
14 | Paper: https://arxiv.org/pdf/2401.08417 | |
15 | | |
16 | Formula: | |
17 | L(π_θ; U) = -E_(x,y_w,y_l)~D[log σ(β log π_θ(y_w|x) - β log π_θ(y_l|x))] | |
18 | | |
19 | Where: | |
20 | - π_θ(y|x): Policy (model) probability | |
21 | - y_w: Chosen sequence | |
22 | - y_l: Rejected sequence | |
23 | - σ: Sigmoid function | |
24 | - β: Temperature parameter | |
25 | - E: Expected value over the dataset D | |
26 | - D: Dataset of preferences | |
27 | | |
28 | Args: | |
29 | chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,). | |
30 | rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,). | |
31 | full_target (torch.Tensor): Non chunked full target tensor | |
32 | beta (float): Weight for the CPO loss | |
33 | """ | |
34 | logits = beta * (chosen_logps - rejected_logps) | |
35 | loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2) | |
36 | return loss | |
37 | | |
38 | @staticmethod | |
39 | def forward( | |
40 | ctx, | |
41 | _input, | |
42 | weight, | |
43 | target, | |
44 | bias=None, | |
45 | ignore_index=-100, | |
46 | beta=0.1, | |
47 | alpha=1.0, | |
48 | compute_nll_loss=True, | |
49 | compiled=True, | |
50 | ): | |
51 | return LigerFusedLinearPreferenceBase.forward( | |
52 | ctx, | |
53 | _input, | |
54 | weight, | |
55 | target, | |
56 | bias, | |
57 | loss_fn=LigerFusedLinearCPOFunction.preference_loss_fn, | |
58 | ignore_index=ignore_index, | |
59 | alpha=alpha, | |
60 | beta=beta, | |
61 | compute_nll_loss=compute_nll_loss, | |
62 | compiled=compiled, | |
63 | ) | |
64 | | |
65 | @staticmethod | |
66 | def backward(ctx, *grad_output): | |
67 | grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] | |
68 | return *grads, None, None, None, None, None | |
69 | | |
70 | | |
71 | class LigerFusedLinearCPOLoss(torch.nn.Module): | |
72 | """ | |
73 | Fused linear layer with CPO loss. | |
74 | """ | |
75 | | |
76 | def __init__( | |
77 | self, | |
78 | ignore_index: int = -100, | |
79 | beta: float = 0.1, | |
80 | alpha: float = 1.0, | |
81 | compute_nll_loss: bool = True, | |
82 | compiled: bool = True, | |
83 | ): | |
84 | """ | |
85 | Args: | |
86 | ignore_index (int): Index to ignore in the loss. | |
87 | beta (float): Weight for the CPO loss. | |
88 | """ | |
89 | super().__init__() | |
90 | self.ignore_index = ignore_index | |
91 | self.beta = beta | |
92 | self.alpha = alpha | |
93 | self.compute_nll_loss = compute_nll_loss | |
94 | self.compiled = compiled | |
95 | | |
96 | def forward(self, lin_weight, _input, target, bias=None): | |
97 | return LigerFusedLinearCPOFunction.apply( | |
98 | _input, | |
99 | lin_weight, | |
100 | target, | |
101 | bias, | |
102 | self.ignore_index, | |
103 | self.beta, | |
104 | self.alpha, | |
105 | self.compute_nll_loss, | |
106 | self.compiled, | |
107 | ) | |
108 | | |
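The formula in `preference_loss_fn`'s docstring can be sanity-checked with a tiny plain-PyTorch computation of the same preference term; this sketch mirrors only the math above, not the fused/chunked linear path, and uses toy values:

```python
# Reference for the CPO preference term: beta * (log p(chosen) - log p(rejected))
# through log-sigmoid, averaged over preference pairs (full_target stacks
# chosen rows followed by rejected rows, hence the division by half its length).
import torch
import torch.nn.functional as F

beta = 0.1
chosen_logps = torch.tensor([-1.2, -0.8])    # avg log-probs, shape (batch,)
rejected_logps = torch.tensor([-2.0, -1.5])
num_pairs = 2                                 # full_target.shape[0] // 2

logits = beta * (chosen_logps - rejected_logps)
loss = F.logsigmoid(logits).sum() / num_pairs
print(loss)
```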
-------------------------------------------------------------------------------- | |
/src/liger_kernel/chunked_loss/dpo_loss.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch | |
2 | import torch.nn.functional as F | |
3 | | |
4 | from liger_kernel.chunked_loss.fused_linear_preference import ( | |
5 | LigerFusedLinearPreferenceBase, | |
6 | ) | |
7 | | |
8 | | |
9 | class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase): | |
10 | | |
11 | @staticmethod | |
12 | def preference_loss_fn( | |
13 | chosen_logps, | |
14 | rejected_logps, | |
15 | full_target, | |
16 | ref_chosen_logps=None, | |
17 | ref_rejected_logps=None, | |
18 | beta=0.1, | |
19 | ): | |
20 | """ | |
21 | Paper: https://arxiv.org/pdf/2305.18290 | |
22 | | |
23 | Formula: | |
24 | L_DPO = -E[ log_sigmoid( β * (log(π(y_w|x)/π_ref(y_w|x)) - log(π(y_l|x)/π_ref(y_l|x))) ) ] | |
25 | | |
26 | Where: | |
27 | - π(y|x): Policy (model) probability | |
28 | - π_ref(y|x): Reference model probability | |
29 | - y_w: Chosen sequence | |
30 | - y_l: Rejected sequence | |
31 | - β: Weight for the direct preference loss | |
32 | - E: Expected value over the dataset | |
33 | | |
34 | Args: | |
35 | chosen_logps: Log probabilities of chosen tokens (batch_size,) | |
36 | rejected_logps: Log probabilities of rejected tokens (batch_size,) | |
37 | full_target: Non chunked full target tensor | |
38 | ref_chosen_logps: Reference log probs of chosen tokens (batch_size,) | |
39 | ref_rejected_logps: Reference log probs of rejected tokens (batch_size,) | |
40 | beta: Weight for the direct preference loss | |
41 | """ | |
42 | | |
43 | if ref_chosen_logps is None: | |
44 | ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device) | |
45 | if ref_rejected_logps is None: | |
46 | ref_rejected_logps = torch.tensor(0.0, device=rejected_logps.device) | |
47 | | |
48 | chosen_logratios = chosen_logps - ref_chosen_logps | |
49 | rejected_logratios = rejected_logps - ref_rejected_logps | |
50 | | |
51 | logits_diff = beta * (chosen_logratios - rejected_logratios) | |
52 | loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2) | |
53 | return loss | |
54 | | |
55 | @staticmethod | |
56 | def forward( | |
57 | ctx, | |
58 | _input, | |
59 | weight, | |
60 | target, | |
61 | bias=None, | |
62 | ref_input=None, | |
63 | ref_weight=None, | |
64 | ref_bias=None, | |
65 | ignore_index=-100, | |
66 | beta=0.1, | |
67 | compute_nll_loss=True, | |
68 | compiled=True, | |
69 | use_ref_model=True, | |
70 | ): | |
71 | return LigerFusedLinearPreferenceBase.forward( | |
72 | ctx=ctx, | |
73 | _input=_input, | |
74 | weight=weight, | |
75 | target=target, | |
76 | bias=bias, | |
77 | loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn, | |
78 | ignore_index=ignore_index, | |
79 | beta=beta, | |
80 | compute_nll_loss=compute_nll_loss, | |
81 | compiled=compiled, | |
82 | use_ref_model=use_ref_model, | |
83 | ref_input=ref_input, | |
84 | ref_weight=ref_weight, | |
85 | ref_bias=ref_bias, | |
86 | ) | |
87 | | |
88 | @staticmethod | |
89 | def backward(ctx, *grad_output): | |
90 | grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] | |
91 | return *grads, None, None, None, None, None, None, None, None | |
92 | | |
93 | | |
94 | class LigerFusedLinearDPOLoss(torch.nn.Module): | |
95 | """ | |
96 | Fused linear layer with DPO loss. | |
97 | """ | |
98 | | |
99 | def __init__( | |
100 | self, | |
101 | ignore_index: int = -100, | |
102 | beta: float = 0.1, | |
103 | compute_nll_loss: bool = True, | |
104 | compiled: bool = True, | |
105 | use_ref_model: bool = False, | |
106 | ): | |
107 | """ | |
108 | Args: | |
109 | ignore_index (int): Index to ignore in the loss. | |
110 | beta (float): Weight for the direct preference loss. | |
111 | compute_nll_loss (bool): Whether to compute the NLL loss. | |
112 | compiled (bool): Whether to use the torch compiled kernel. | |
113 | use_ref_model (bool): Whether to use a reference model for the DPO loss. | |
114 | """ | |
115 | super().__init__() | |
116 | self.ignore_index = ignore_index | |
117 | self.beta = beta | |
118 | self.compute_nll_loss = compute_nll_loss | |
119 | self.compiled = compiled | |
120 | self.use_ref_model = use_ref_model | |
121 | | |
122 | def forward( | |
123 | self, | |
124 | lin_weight, | |
125 | _input, | |
126 | target, | |
127 | bias=None, | |
128 | ref_input=None, | |
129 | ref_weight=None, | |
130 | ref_bias=None, | |
131 | ): | |
132 | return LigerFusedLinearDPOFunction.apply( | |
133 | _input, | |
134 | lin_weight, | |
135 | target, | |
136 | bias, | |
137 | ref_input, | |
138 | ref_weight, | |
139 | ref_bias, | |
140 | self.ignore_index, | |
141 | self.beta, | |
142 | self.compute_nll_loss, | |
143 | self.compiled, | |
144 | self.use_ref_model, | |
145 | ) | |
146 | | |
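As with CPO, the DPO term itself is easy to check in plain PyTorch against the docstring formula; when no reference model is used, the reference log-probabilities default to zero exactly as in `preference_loss_fn`. Toy values, math only:

```python
# Reference for the DPO preference term: -log_sigmoid(beta * (chosen log-ratio
# minus rejected log-ratio)), averaged over preference pairs.
import torch
import torch.nn.functional as F

beta = 0.1
chosen_logps = torch.tensor([-1.2, -0.8])
rejected_logps = torch.tensor([-2.0, -1.5])
ref_chosen_logps = torch.tensor([-1.0, -0.9])     # frozen reference model
ref_rejected_logps = torch.tensor([-1.8, -1.4])
num_pairs = 2

chosen_logratios = chosen_logps - ref_chosen_logps
rejected_logratios = rejected_logps - ref_rejected_logps
loss = -F.logsigmoid(beta * (chosen_logratios - rejected_logratios)).sum() / num_pairs
print(loss)
```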
-------------------------------------------------------------------------------- | |
/src/liger_kernel/chunked_loss/functional.py: | |
-------------------------------------------------------------------------------- | |
1 | from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction | |
2 | from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction | |
3 | from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction | |
4 | from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction | |
5 | | |
6 | liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply | |
7 | liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply | |
8 | liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply | |
9 | liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply | |
10 | | |
-------------------------------------------------------------------------------- | |
/src/liger_kernel/chunked_loss/orpo_loss.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch | |
2 | import torch.nn.functional as F | |
3 | | |
4 | from liger_kernel.chunked_loss.fused_linear_preference import ( | |
5 | LigerFusedLinearPreferenceBase, | |
6 | ) | |
7 | | |
8 | | |
9 | class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase): | |
10 | | |
11 | @staticmethod | |
12 | def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1): | |
13 | """ | |
14 | Paper: https://arxiv.org/pdf/2403.07691 | |
15 | | |
16 | Formula: | |
17 | Compute odds-ratio loss: L_OR = -log(σ(log(odds_θ(y_w|x) / odds_θ(y_l|x)))) | |
18 | where odds_θ(y|x) = P_θ(y|x) / (1 - P_θ(y|x)) | |
19 | | |
20 | Where: | |
21 | - P_θ(y|x): Policy (model) probability | |
22 | - y_w: Chosen sequence | |
23 | - y_l: Rejected sequence | |
24 | - σ: Sigmoid function | |
25 | - β: Weight for the odds ratio loss | |
26 | - odds_θ: Odds function for the policy | |
27 | | |
28 | Args: | |
29 | chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,). | |
30 | rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,). | |
31 | full_target (torch.Tensor): Non chunked full target tensor | |
32 | beta (float): Weight for the odds ratio loss. | |
33 | """ | |
34 | log_odds = (chosen_logps - rejected_logps) - ( | |
35 | torch.log1p(-torch.exp(chosen_logps)) | |
36 | - torch.log1p(-torch.exp(rejected_logps)) | |
37 | ) | |
38 | ratio = F.logsigmoid(log_odds) | |
39 | loss = beta * ratio.sum() / (full_target.shape[0] // 2) | |
40 | | |
41 | chosen_rewards = beta * chosen_logps | |
42 | rejected_rewards = beta * rejected_logps | |
43 | | |
44 | log_odds_ratio = torch.sum(ratio) / (full_target.shape[0] // 2) | |
45 | log_odds_chosen = torch.sum(log_odds) / (full_target.shape[0] // 2) | |
46 | | |
47 | return loss, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen | |
48 | | |
49 | @staticmethod | |
50 | def forward( | |
51 | ctx, | |
52 | _input, | |
53 | weight, | |
54 | target, | |
55 | bias=None, | |
56 | ignore_index=-100, | |
57 | beta=0.1, | |
58 | compute_nll_loss=True, | |
59 | compiled=True, | |
60 | ): | |
61 | return LigerFusedLinearPreferenceBase.forward( | |
62 | ctx=ctx, | |
63 | _input=_input, | |
64 | weight=weight, | |
65 | target=target, | |
66 | bias=bias, | |
67 | loss_fn=LigerFusedLinearORPOFunction.preference_loss_fn, | |
68 | ignore_index=ignore_index, | |
69 | beta=beta, | |
70 | compute_nll_loss=compute_nll_loss, | |
71 | compiled=compiled, | |
72 | ) | |
73 | | |
74 | @staticmethod | |
75 | def backward(ctx, *grad_output): | |
76 | grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] | |
77 | return *grads, None, None, None, None | |
78 | | |
79 | | |
80 | class LigerFusedLinearORPOLoss(torch.nn.Module): | |
81 | """ | |
82 | Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss. | |
83 | """ | |
84 | | |
85 | def __init__( | |
86 | self, | |
87 | ignore_index: int = -100, | |
88 | beta: float = 0.1, | |
89 | compute_nll_loss: bool = True, | |
90 | compiled: bool = True, | |
91 | ): | |
92 | """ | |
93 | Args: | |
94 | ignore_index (int): Index to ignore in the loss. | |
95 | beta (float): Weight for the odds ratio loss. | |
96 | """ | |
97 | super().__init__() | |
98 | self.ignore_index = ignore_index | |
99 | self.beta = beta | |
100 | self.compute_nll_loss = compute_nll_loss | |
101 | self.compiled = compiled | |
102 | | |
103 | def forward(self, lin_weight, _input, target, bias=None): | |
104 | return LigerFusedLinearORPOFunction.apply( | |
105 | _input, | |
106 | lin_weight, | |
107 | target, | |
108 | bias, | |
109 | self.ignore_index, | |
110 | self.beta, | |
111 | self.compute_nll_loss, | |
112 | self.compiled, | |
113 | ) | |
114 | | |
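The odds-ratio term works on average per-token log-probabilities, so `exp(logp)` stays below 1 and `log1p(-exp(logp))` is well defined. A small reference computation of the log-odds term above, with toy values and without the fused linear path:

```python
# Reference for the ORPO log-odds term: log(odds(chosen)) - log(odds(rejected)),
# where odds(y) = p(y) / (1 - p(y)) and p comes from average log-probabilities.
import torch
import torch.nn.functional as F

beta = 0.1
chosen_logps = torch.tensor([-0.5, -0.7])     # must be < 0 so exp(.) < 1
rejected_logps = torch.tensor([-1.5, -1.2])
num_pairs = 2

log_odds = (chosen_logps - rejected_logps) - (
    torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
)
ratio = F.logsigmoid(log_odds)
loss = beta * ratio.sum() / num_pairs
print(loss, log_odds)
```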
-------------------------------------------------------------------------------- | |
/src/liger_kernel/chunked_loss/simpo_loss.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch | |
2 | import torch.nn.functional as F | |
3 | | |
4 | from liger_kernel.chunked_loss.fused_linear_preference import ( | |
5 | LigerFusedLinearPreferenceBase, | |
6 | ) | |
7 | | |
8 | | |
9 | class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase): | |
10 | | |
11 | @staticmethod | |
12 | def preference_loss_fn( | |
13 | chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5 | |
14 | ): | |
15 | """ | |
16 | Paper: https://arxiv.org/pdf/2405.14734 | |
17 | | |
18 | Formula: | |
19 | L_SimPO(π_θ) = -E [log σ(β/|y_w| log π_θ(y_w|x) - β/|y_l| log π_θ(y_l|x) - γ)] | |
20 | | |
21 | Where: | |
22 | - π_θ(y|x): Policy (model) probability | |
23 | - y_w: Chosen sequence | |
24 | - y_l: Rejected sequence | |
25 | - |y_w|, |y_l|: Sequence lengths | |
26 | - σ: Sigmoid function | |
27 | - β: beta weight | |
28 | - γ: gamma margin term | |
29 | | |
30 | Args: | |
31 | chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,). | |
32 | rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,). | |
33 | full_target: Non chunked full target tensor | |
34 | beta (float): beta weight | |
35 | gamma (float): gamma margin term | |
36 | """ | |
37 | logits = beta * (chosen_logps - rejected_logps) - gamma | |
38 | loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2) | |
39 | return loss | |
40 | | |
41 | @staticmethod | |
42 | def forward( | |
43 | ctx, | |
44 | _input, | |
45 | weight, | |
46 | target, | |
47 | bias=None, | |
48 | ignore_index=-100, | |
49 | beta=0.1, | |
50 | alpha=1.0, | |
51 | compute_nll_loss=False, | |
52 | compiled=True, | |
53 | gamma=0.5, | |
54 | ): | |
55 | return LigerFusedLinearPreferenceBase.forward( | |
56 | ctx, | |
57 | _input, | |
58 | weight, | |
59 | target, | |
60 | bias, | |
61 | loss_fn=LigerFusedLinearSimPOFunction.preference_loss_fn, | |
62 | compute_nll_loss=compute_nll_loss, | |
63 | ignore_index=ignore_index, | |
64 | alpha=alpha, | |
65 | beta=beta, | |
66 | compiled=compiled, | |
67 | gamma=gamma, | |
68 | ) | |
69 | | |
70 | @staticmethod | |
71 | def backward(ctx, *grad_output): | |
72 | grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] | |
73 | return *grads, None, None, None, None, None, None | |
74 | | |
75 | | |
76 | class LigerFusedLinearSimPOLoss(torch.nn.Module): | |
77 | """ | |
78 | Fused linear layer with SimPO loss. | |
79 | """ | |
80 | | |
81 | def __init__( | |
82 | self, | |
83 | ignore_index: int = -100, | |
84 | beta: float = 0.1, | |
85 | alpha: float = 1.0, | |
86 | compute_nll_loss: bool = True, | |
87 | compiled: bool = True, | |
88 | gamma: float = 0.5, | |
89 | ): | |
90 | """ | |
91 | Args: | |
92 | ignore_index (int): Index to ignore in the loss. | |
93 | beta (float): Weight for the SimPO loss. | |
94 | """ | |
95 | super().__init__() | |
96 | self.ignore_index = ignore_index | |
97 | self.beta = beta | |
98 | self.alpha = alpha | |
99 | self.compute_nll_loss = compute_nll_loss | |
100 | self.compiled = compiled | |
101 | self.gamma = gamma | |
102 | | |
103 | def forward(self, lin_weight, _input, target, bias=None): | |
104 | return LigerFusedLinearSimPOFunction.apply( | |
105 | _input, | |
106 | lin_weight, | |
107 | target, | |
108 | bias, | |
109 | self.ignore_index, | |
110 | self.beta, | |
111 | self.alpha, | |
112 | self.compute_nll_loss, | |
113 | self.compiled, | |
114 | self.gamma, | |
115 | ) | |
116 | | |
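SimPO differs from the CPO term only by the margin γ subtracted from the scaled log-probability difference (no reference model, length-normalized log-probs). A toy check of the term in `preference_loss_fn` above, math only:

```python
# Reference for the SimPO preference term: beta * (chosen - rejected) - gamma,
# passed through log-sigmoid and averaged over preference pairs.
import torch
import torch.nn.functional as F

beta, gamma = 0.1, 0.5
chosen_logps = torch.tensor([-1.0, -0.6])     # length-normalized avg log-probs
rejected_logps = torch.tensor([-1.8, -1.4])
num_pairs = 2

logits = beta * (chosen_logps - rejected_logps) - gamma
loss = F.logsigmoid(logits).sum() / num_pairs
print(loss)
```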
-------------------------------------------------------------------------------- | |
/src/liger_kernel/env_report.py: | |
-------------------------------------------------------------------------------- | |
1 | import platform | |
2 | import sys | |
3 | from importlib.metadata import version | |
4 | | |
5 | | |
6 | def print_env_report(): | |
7 | """ | |
8 | | |
9 | Prints a report of the environment. Useful for debugging and reproducibility. | |
10 | Usage: | |
11 | ``` | |
12 | python -m liger_kernel.env_report | |
13 | ``` | |
14 | | |
15 | """ | |
16 | print("Environment Report:") | |
17 | print("-------------------") | |
18 | print(f"Operating System: {platform.platform()}") | |
19 | print(f"Python version: {sys.version.split()[0]}") | |
20 | | |
21 | try: | |
22 | print(f"Liger Kernel version: {version('liger-kernel')}") | |
23 | except ImportError: | |
24 | print("Liger Kernel: Not installed") | |
25 | | |
26 | try: | |
27 | import torch | |
28 | | |
29 | print(f"PyTorch version: {torch.__version__}") | |
30 | cuda_version = ( | |
31 | torch.version.cuda if torch.cuda.is_available() else "Not available" | |
32 | ) | |
33 | print(f"CUDA version: {cuda_version}") | |
34 | hip_version = ( | |
35 | torch.version.hip | |
36 | if torch.cuda.is_available() and torch.version.hip | |
37 | else "Not available" | |
38 | ) | |
39 | print(f"HIP(ROCm) version: {hip_version}") | |
40 | | |
41 | except ImportError: | |
42 | print("PyTorch: Not installed") | |
43 | print("CUDA version: Unable to query") | |
44 | print("HIP(ROCm) version: Unable to query") | |
45 | | |
46 | try: | |
47 | import triton | |
48 | | |
49 | print(f"Triton version: {triton.__version__}") | |
50 | except ImportError: | |
51 | print("Triton: Not installed") | |
52 | | |
53 | try: | |
54 | import transformers | |
55 | | |
56 | print(f"Transformers version: {transformers.__version__}") | |
57 | except ImportError: | |
58 | print("Transformers: Not installed") | |
59 | | |
60 | try: | |
61 | xpu_version = ( | |
62 | torch.version.xpu if torch.xpu.is_available() else "XPU Not Available" | |
63 | ) | |
64 | print(f"XPU version: {xpu_version}") | |
65 | except (ImportError, NameError, AttributeError): | |
66 | print("XPU version: Unable to query") | |
67 | | |
68 | | |
69 | if __name__ == "__main__": | |
70 | print_env_report() | |
71 | | |
-------------------------------------------------------------------------------- | |
/src/liger_kernel/ops/__init__.py: | |
-------------------------------------------------------------------------------- | |
https://raw.githubusercontent.com/linkedin/Liger-Kernel/0bb6c72d03f0fc570b9e954c42b723ab7e0c315f/src/liger_kernel/ops/__init__.py | |
-------------------------------------------------------------------------------- | |
/src/liger_kernel/ops/experimental/embedding.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch | |
2 | import triton | |
3 | import triton.language as tl | |
4 | | |
5 | from liger_kernel.ops.utils import ensure_contiguous | |
6 | | |
7 | | |
8 | @triton.jit | |
9 | def embedding_forward_kernel( | |
10 | embeddings_ptr, | |
11 | indices_ptr, | |
12 | output_ptr, | |
13 | n_elements, | |
14 | embedding_dim: tl.constexpr, | |
15 | BLOCK_SIZE_M: tl.constexpr, | |
16 | BLOCK_SIZE_N: tl.constexpr, | |
17 | ): | |
18 | pid_m = tl.program_id(0) | |
19 | pid_n = tl.program_id(1) | |
20 | | |
21 | start_m = pid_m * BLOCK_SIZE_M | |
22 | start_n = pid_n * BLOCK_SIZE_N | |
23 | offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M) | |
24 | mask_m = offsets_m < n_elements | |
25 | indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0) | |
26 | offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N) | |
27 | mask_n = offsets_n < embedding_dim | |
28 | | |
29 | embedding_offsets = indices[:, None] * embedding_dim + offsets_n[None, :] | |
30 | embeddings = tl.load( | |
31 | embeddings_ptr + embedding_offsets, | |
32 | mask=mask_m[:, None] & mask_n[None, :], | |
33 | other=0.0, | |
34 | ) | |
35 | | |
36 | output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :] | |
37 | tl.store( | |
38 | output_ptr + output_offsets, embeddings, mask=mask_m[:, None] & mask_n[None, :] | |
39 | ) | |
40 | | |
41 | | |
42 | @triton.jit | |
43 | def embedding_backward_kernel( | |
44 | grad_output_ptr, | |
45 | grad_weight_ptr, | |
46 | indices_ptr, | |
47 | n_elements, | |
48 | embedding_dim: tl.constexpr, | |
49 | BLOCK_SIZE_M: tl.constexpr, | |
50 | BLOCK_SIZE_N: tl.constexpr, | |
51 | ): | |
52 | pid_m = tl.program_id(0) | |
53 | pid_n = tl.program_id(1) | |
54 | | |
55 | start_m = pid_m * BLOCK_SIZE_M | |
56 | start_n = pid_n * BLOCK_SIZE_N | |
57 | offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M) | |
58 | mask_m = offsets_m < n_elements | |
59 | indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0) | |
60 | offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N) | |
61 | mask_n = offsets_n < embedding_dim | |
62 | | |
63 | grad_output = tl.load( | |
64 | grad_output_ptr + offsets_m[:, None] * embedding_dim + offsets_n[None, :], | |
65 | mask=mask_m[:, None] & mask_n[None, :], | |
66 | other=0.0, | |
67 | ) | |
68 | | |
69 | grad_weight_offsets = indices[:, None] * embedding_dim + offsets_n[None, :] | |
70 | | |
71 | tl.atomic_add( | |
72 | grad_weight_ptr + grad_weight_offsets, | |
73 | grad_output, | |
74 | mask=mask_m[:, None] & mask_n[None, :], | |
75 | ) | |
76 | | |
77 | | |
78 | class LigerEmbeddingFunction(torch.autograd.Function): | |
79 | @staticmethod | |
80 | @ensure_contiguous | |
81 | def forward(ctx, embeddings: torch.Tensor, indices: torch.Tensor): | |
82 | ori_shape = indices.shape | |
83 | indices = indices.view(-1) | |
84 | output = torch.empty( | |
85 | indices.shape[0], | |
86 | embeddings.shape[1], | |
87 | device=indices.device, | |
88 | dtype=embeddings.dtype, | |
89 | ) | |
90 | | |
91 | n_elements = indices.numel() | |
92 | embedding_dim = embeddings.shape[1] | |
93 | | |
94 | BLOCK_SIZE_M = triton.next_power_of_2(min(128, embedding_dim)) | |
95 | BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim)) | |
96 | grid = ( | |
97 | triton.cdiv(n_elements, BLOCK_SIZE_M), | |
98 | triton.cdiv(embedding_dim, BLOCK_SIZE_N), | |
99 | ) | |
100 | | |
101 | embedding_forward_kernel[grid]( | |
102 | embeddings, | |
103 | indices, | |
104 | output, | |
105 | n_elements, | |
106 | embedding_dim=embedding_dim, | |
107 | BLOCK_SIZE_M=BLOCK_SIZE_M, | |
108 | BLOCK_SIZE_N=BLOCK_SIZE_N, | |
109 | ) | |
110 | | |
111 | ctx.save_for_backward(indices, embeddings) | |
112 | | |
113 | return output.view(*ori_shape, -1) | |
114 | | |
115 | @staticmethod | |
116 | @ensure_contiguous | |
117 | def backward(ctx, grad_output: torch.Tensor): | |
118 | indices, embedding_table = ctx.saved_tensors | |
119 | grad_output = grad_output.contiguous().view(-1, embedding_table.shape[1]) | |
120 | | |
121 | grad_weight = torch.zeros_like(embedding_table) | |
122 | | |
123 | n_elements = indices.numel() | |
124 | embedding_dim = embedding_table.shape[1] | |
125 | | |
126 | BLOCK_SIZE_M = triton.next_power_of_2(min(128, embedding_dim)) | |
127 | BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim)) | |
128 | grid = ( | |
129 | triton.cdiv(n_elements, BLOCK_SIZE_M), | |
130 | triton.cdiv(embedding_dim, BLOCK_SIZE_N), | |
131 | ) | |
132 | | |
133 | embedding_backward_kernel[grid]( | |
134 | grad_output, | |
135 | grad_weight, | |
136 | indices, | |
137 | n_elements, | |
138 | embedding_dim=embedding_dim, | |
139 | BLOCK_SIZE_M=BLOCK_SIZE_M, | |
140 | BLOCK_SIZE_N=BLOCK_SIZE_N, | |
141 | ) | |
142 | | |
143 | return grad_weight, None | |
144 | | |
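A quick way to exercise the experimental embedding kernels is to compare the autograd function against `torch.nn.functional.embedding`. A sketch, assuming a CUDA device with Triton available (the kernels above do not run on CPU); sizes are arbitrary:

```python
# Sanity-check sketch for LigerEmbeddingFunction against F.embedding.
import torch
import torch.nn.functional as F

from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction

weight = torch.randn(1000, 64, device="cuda", requires_grad=True)
indices = torch.randint(0, 1000, (8, 16), device="cuda")

out = LigerEmbeddingFunction.apply(weight, indices)   # shape (8, 16, 64)
ref = F.embedding(indices, weight)
print(torch.allclose(out, ref))

out.sum().backward()   # grad_weight accumulated via atomic adds in the backward kernel
print(weight.grad.shape)
```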
-------------------------------------------------------------------------------- | |
/src/liger_kernel/ops/geglu.py: | |
-------------------------------------------------------------------------------- | |
1 | import operator | |
2 | | |
3 | import torch | |
4 | import triton | |
5 | import triton.language as tl | |
6 | | |
7 | from liger_kernel.ops.utils import ( | |
8 | calculate_settings, | |
9 | compare_version, | |
10 | ensure_contiguous, | |
11 | ) | |
12 | | |
13 | if compare_version("triton", operator.ge, "3.0.0"): | |
14 | try: | |
15 | # typical import path with dispatch available | |
16 | from triton.language.extra.libdevice import tanh | |
17 | except ModuleNotFoundError: | |
18 | # for working with NGC containers | |
19 | from triton.language.extra.cuda.libdevice import tanh | |
20 | else: | |
21 | from triton.language.math import tanh | |
22 | | |
23 | | |
24 | @triton.jit | |
25 | def _geglu_tanh_forward_kernel( | |
26 | a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr | |
27 | ): | |
28 | program_id = tl.program_id(0).to(tl.int64) | |
29 | | |
30 | # locate start index | |
31 | a += program_id * stride | |
32 | b += program_id * stride | |
33 | c += program_id * stride | |
34 | | |
35 | col_offsets = tl.arange(0, BLOCK_SIZE) | |
36 | mask = col_offsets < n_cols | |
37 | a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32) | |
38 | b_row = tl.load(b + col_offsets, mask=mask, other=0) | |
39 | | |
40 | # tanh approximation form of GELU is computed with: | |
41 | # 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3))) | |
42 | sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi) | |
43 | a_cubed = a_row * a_row * a_row | |
44 | tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed) | |
45 | tanh_result = tanh(tanh_arg) | |
46 | geglu_a = 0.5 * a_row * (1 + tanh_result) | |
47 | c_row = geglu_a * b_row | |
48 | tl.store(c + col_offsets, c_row, mask=mask) | |
49 | | |
50 | | |
51 | @triton.jit | |
52 | def _geglu_tanh_backward_kernel( | |
53 | dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr | |
54 | ): | |
55 | program_id = tl.program_id(0).to(tl.int64) | |
56 | | |
57 | # locate start index | |
58 | dc += program_id * stride | |
59 | a += program_id * stride | |
60 | b += program_id * stride | |
61 | | |
62 | col_offsets = tl.arange(0, BLOCK_SIZE) | |
63 | mask = col_offsets < n_cols | |
64 | | |
65 | dc_row = tl.load(dc + col_offsets, mask=mask, other=0) | |
66 | a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32) | |
67 | b_row = tl.load(b + col_offsets, mask=mask, other=0) | |
68 | | |
69 | # recomputation to save memory | |
70 | sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi) | |
71 | a_cubed = a_row * a_row * a_row | |
72 | tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed) | |
73 | tanh_result = tanh(tanh_arg) | |
74 | geglu_a = 0.5 * a_row * (1 + tanh_result) | |
75 | | |
76 | db_row = dc_row * geglu_a | |
77 | | |
78 | # Gradient w.r.t. a can be computed with: | |
79 | # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2))) | |
80 | # where z = sqrt(2/pi) * (a + 0.044715 * a^3) | |
81 | term1 = 0.5 * (1 + tanh_result) | |
82 | tanh_sq = tanh_result * tanh_result | |
83 | term2 = ( | |
84 | 0.5 | |
85 | * a_row | |
86 | * (1 - tanh_sq) | |
87 | * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row)) | |
88 | ) | |
89 | da_row = dc_row * b_row * (term1 + term2) | |
90 | | |
91 | tl.store(a + col_offsets, da_row, mask=mask) | |
92 | tl.store(b + col_offsets, db_row, mask=mask) | |
93 | | |
94 | | |
95 | def geglu_forward(a, b): | |
96 | ori_shape = a.shape | |
97 | | |
98 | n_cols = ori_shape[-1] | |
99 | a = a.view(-1, n_cols) | |
100 | b = b.view(-1, n_cols) | |
101 | c = torch.empty_like(a) | |
102 | n_rows = a.shape[0] | |
103 | | |
104 | BLOCK_SIZE, num_warps = calculate_settings(n_cols) | |
105 | | |
106 | _geglu_tanh_forward_kernel[(n_rows,)]( | |
107 | a, | |
108 | b, | |
109 | c, | |
110 | c.stride(-2), | |
111 | n_cols=n_cols, | |
112 | BLOCK_SIZE=BLOCK_SIZE, | |
113 | num_warps=num_warps, | |
114 | ) | |
115 | return a, b, c.view(*ori_shape) | |
116 | | |
117 | | |
118 | def geglu_backward(a, b, dc): | |
119 | ori_shape = dc.shape | |
120 | n_cols = ori_shape[-1] | |
121 | dc = dc.view(-1, n_cols) | |
122 | n_rows = dc.shape[0] | |
123 | | |
124 | BLOCK_SIZE, num_warps = calculate_settings(n_cols) | |
125 | | |
126 | _geglu_tanh_backward_kernel[(n_rows,)]( | |
127 | dc, | |
128 | a, | |
129 | b, | |
130 | dc.stride(-2), | |
131 | n_cols=n_cols, | |
132 | BLOCK_SIZE=BLOCK_SIZE, | |
133 | num_warps=num_warps, | |
134 | ) | |
135 | | |
136 | return a.view(*ori_shape), b.view(*ori_shape) | |
137 | | |
138 | | |
139 | class LigerGELUMulFunction(torch.autograd.Function): | |
140 | @staticmethod | |
141 | @ensure_contiguous | |
142 | def forward(ctx, a, b): | |
143 | a, b, c = geglu_forward(a, b) | |
144 | ctx.save_for_backward(a, b) | |
145 | return c | |
146 | | |
147 | @staticmethod | |
148 | @ensure_contiguous | |
149 | def backward(ctx, dc): | |
150 | a, b = ctx.saved_tensors | |
151 | a, b = geglu_backward(a, b, dc) | |
152 | return a, b | |
153 | | |
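The kernels above gate with the tanh-approximate GELU, so the forward result should match `F.gelu(a, approximate="tanh") * b`. A small comparison sketch, assuming a CUDA device and float32 inputs; clones keep the originals intact because the backward kernel writes gradients in place into the saved activations:

```python
# Forward-pass check: LigerGELUMulFunction vs. PyTorch's tanh-approximate GELU gate.
import torch
import torch.nn.functional as F

from liger_kernel.ops.geglu import LigerGELUMulFunction

a = torch.randn(4, 2048, device="cuda")
b = torch.randn(4, 2048, device="cuda")

c = LigerGELUMulFunction.apply(a.clone(), b.clone())
ref = F.gelu(a, approximate="tanh") * b
print(torch.allclose(c, ref, atol=1e-5))
```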
-------------------------------------------------------------------------------- | |
/src/liger_kernel/ops/jsd.py: | |
-------------------------------------------------------------------------------- | |
1 | from typing import Optional | |
2 | | |
3 | import torch | |
4 | import triton | |
5 | import triton.language as tl | |
6 | | |
7 | from liger_kernel.ops.utils import ensure_contiguous | |
8 | | |
9 | | |
10 | @triton.jit | |
11 | def _jsd_kernel( | |
12 | X_ptr, # input in logspace, X = log Q | |
13 | X_stride, | |
14 | Y_ptr, # ground truth in logspace, Y = log P | |
15 | Y_stride, | |
16 | loss_ptr, | |
17 | loss_stride, | |
18 | dX_ptr, | |
19 | dX_stride, | |
20 | label_ptr, | |
21 | beta: tl.constexpr, | |
22 | n_non_ignore: int, | |
23 | ignore_index: tl.constexpr, | |
24 | n_cols, | |
25 | BLOCK_SIZE: tl.constexpr, | |
26 | HAS_LABEL: tl.constexpr, | |
27 | ): | |
28 | # JSD(P || Q) = (KL(P || M) + KL(Q || M)) / 2, M = (1/2) * (P + Q) = (1/2) * (e ^ Y + e ^ X) | |
29 | # = sum(P * log P + Q * log Q - 2 * M * log M) / 2 | |
30 | # = sum(e ^ Y * Y + e ^ X * X - 2 * M * log M) / 2 | |
31 | # grad_x_i = 0.5 * Q * (X - log_M) | |
32 | pid = tl.program_id(0).to(tl.int64) | |
33 | X_ptr += pid * X_stride | |
34 | dX_ptr += pid * dX_stride | |
35 | Y_ptr += pid * Y_stride | |
36 | loss_ptr += pid * loss_stride | |
37 | label_ptr += pid | |
38 | | |
39 | if HAS_LABEL: | |
40 | label = tl.load(label_ptr) | |
41 | if label == ignore_index: | |
42 | for i in range(0, n_cols, BLOCK_SIZE): | |
43 | offsets = i + tl.arange(0, BLOCK_SIZE) | |
44 | tl.store(dX_ptr + offsets, 0.0, mask=offsets < n_cols) | |
45 | return | |
46 | | |
47 | for i in range(0, n_cols, BLOCK_SIZE): | |
48 | offsets = i + tl.arange(0, BLOCK_SIZE) | |
49 | mask = offsets < n_cols | |
50 | X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32) | |
51 | Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32) | |
52 | | |
53 | if beta == 0.0: # forward KL | |
54 | Y_prob = tl.exp(Y) | |
55 | loss = Y_prob * (Y - X) | |
56 | dX = -Y_prob | |
57 | elif beta == 1.0: | |
58 | X_prob = tl.exp(X) | |
59 | loss = X_prob * (X - Y) | |
60 | dX = loss + X_prob | |
61 | else: | |
62 | Q = tl.exp(X) | |
63 | P = tl.exp(Y) | |
64 | M = beta * P + (1 - beta) * Q | |
65 | log_M = tl.log(M) | |
66 | | |
67 | loss = beta * P * Y + (1 - beta) * Q * X - M * log_M | |
68 | dX = (1 - beta) * Q * (X - log_M) | |
69 | | |
70 | loss = loss / n_non_ignore | |
71 | dX = dX / n_non_ignore | |
72 | tl.store(loss_ptr + offsets, loss, mask=mask) | |
73 | tl.store(dX_ptr + offsets, dX, mask=mask) | |
74 | | |
75 | | |
76 | MAX_FUSED_SIZE = 65536 | |
77 | | |
78 | | |
79 | def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label): | |
80 | BT, V = _input.shape | |
81 | n_rows = BT | |
82 | BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) | |
83 | # non reduction loss | |
84 | loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device) | |
85 | dX = torch.empty_like(_input) | |
86 | | |
87 | if has_label: | |
88 | n_non_ignore = (shift_labels != ignore_index).sum().item() | |
89 | else: | |
90 | n_non_ignore = BT | |
91 | | |
92 | _jsd_kernel[(n_rows,)]( | |
93 | X_ptr=_input, # input in logspace, X = log Q | |
94 | X_stride=_input.stride(-2), | |
95 | Y_ptr=target, # ground truth in logspace, Y = log P | |
96 | Y_stride=target.stride(-2), | |
97 | loss_ptr=loss, | |
98 | loss_stride=loss.stride(-2), | |
99 | dX_ptr=dX, | |
100 | dX_stride=dX.stride(-2), | |
101 | label_ptr=( | |
102 | shift_labels if has_label else torch.empty(1, device=_input.device) | |
103 | ), # dummy ptr if no label | |
104 | beta=beta, | |
105 | n_non_ignore=n_non_ignore, | |
106 | ignore_index=ignore_index, | |
107 | n_cols=V, | |
108 | BLOCK_SIZE=BLOCK_SIZE, | |
109 | HAS_LABEL=has_label, | |
110 | ) | |
111 | | |
112 | loss = torch.sum(loss) | |
113 | return loss.to(_input.dtype), dX | |
114 | | |
115 | | |
116 | def jsd_backward(dX, grad_output): | |
117 | # If jsd is the last layer, grad_output is 1.0. Skip the mul to save time | |
118 | if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): | |
119 | return dX | |
120 | else: | |
121 | return grad_output * dX | |
122 | | |
123 | | |
124 | class LigerJSDFunction(torch.autograd.Function): | |
125 | r""" | |
126 | This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence. | |
127 | .. math:: | |
128 | JSD(\beta)(P || Q) | |
129 | = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q)) | |
130 | | |
131 | .. note:: | |
132 | As all the other losses in PyTorch, this function expects the first argument, | |
133 | :attr:`_input`, to be the predictions, the output of the student model, in log-space | |
134 | and the second, :attr:`target`, to be the observations, the output of the teacher model, in log-space. | |
135 | This differs from the standard mathematical notation :math:`JSD(P || Q)` where | |
136 | :math:`P` denotes the teacher model and :math:`Q` denotes the student model. | |
137 | """ | |
138 | | |
139 | @staticmethod | |
140 | @ensure_contiguous | |
141 | def forward( | |
142 | ctx, | |
143 | _input: torch.Tensor, | |
144 | target: torch.Tensor, | |
145 | shift_labels: Optional[torch.Tensor] = None, | |
146 | beta: float = 0.5, | |
147 | ignore_index: int = -100, | |
148 | ) -> torch.Tensor: | |
149 | """ | |
150 | Args: | |
151 | _input (torch.Tensor): predict values with shape (BT, V) in logspace | |
152 | target (torch.Tensor): ground truth values with shape (BT, V) in logspace | |
153 | shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1]. | |
154 | beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5` | |
155 | ignore_index (int): the index to ignore. Default: -100 | |
156 | | |
157 | Returns: | |
158 | loss (torch.Tensor): generalized JSD | |
159 | """ | |
160 | has_label = False | |
161 | if shift_labels is not None: | |
162 | assert shift_labels.shape == ( | |
163 | _input.shape[0], | |
164 | ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}" | |
165 | shift_labels = shift_labels.contiguous() | |
166 | has_label = True | |
167 | | |
168 | loss, dX = jsd_forward( | |
169 | _input, target, shift_labels, beta, ignore_index, has_label | |
170 | ) | |
171 | ctx.save_for_backward(dX) | |
172 | return loss | |
173 | | |
174 | @staticmethod | |
175 | @ensure_contiguous | |
176 | def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: | |
177 | (dX,) = ctx.saved_tensors | |
178 | dX = jsd_backward(dX, grad_output) | |
179 | return ( | |
180 | dX, | |
181 | None, | |
182 | None, | |
183 | None, | |
184 | None, | |
185 | ) | |
186 | | |
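For β strictly between 0 and 1 (the endpoints are special-cased in the kernel to forward and reverse KL), the class docstring's generalized JSD can be written in a few lines of plain PyTorch as a reference: a β-weighted pair of KL terms against the mixture M = βP + (1-β)Q, averaged over rows. A sketch with random toy inputs in log-space, as the note above requires:

```python
# Plain-PyTorch reference for generalized JSD(beta):
# beta * KL(P || M) + (1 - beta) * KL(Q || M), with M = beta * P + (1 - beta) * Q.
# X = log Q (student), Y = log P (teacher); mean over the batch dimension.
import torch


def generalized_jsd(student_logp: torch.Tensor, teacher_logp: torch.Tensor, beta: float = 0.5):
    q = student_logp.exp()
    p = teacher_logp.exp()
    m = beta * p + (1 - beta) * q
    kl_p_m = (p * (teacher_logp - m.log())).sum(dim=-1)
    kl_q_m = (q * (student_logp - m.log())).sum(dim=-1)
    return (beta * kl_p_m + (1 - beta) * kl_q_m).mean()


student = torch.randn(4, 10).log_softmax(dim=-1)
teacher = torch.randn(4, 10).log_softmax(dim=-1)
print(generalized_jsd(student, teacher, beta=0.5))
```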
-------------------------------------------------------------------------------- | |
/src/liger_kernel/ops/swiglu.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch | |
2 | import triton | |
3 | import triton.language as tl | |
4 | | |
5 | from liger_kernel.ops.utils import calculate_settings, ensure_contiguous | |
6 | | |
7 | | |
8 | @triton.jit | |
9 | def silu(x): | |
10 | return x * tl.sigmoid(x) | |
11 | | |
12 | | |
13 | @triton.jit | |
14 | def _swiglu_forward_kernel( | |
15 | a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr | |
16 | ): | |
17 | program_id = tl.program_id(0).to(tl.int64) | |
18 | | |
19 | # locate start index | |
20 | a_ptr += program_id * stride | |
21 | b_ptr += program_id * stride | |
22 | c_ptr += program_id * stride | |
23 | | |
24 | col_offsets = tl.arange(0, BLOCK_SIZE) | |
25 | mask = col_offsets < n_cols | |
26 | | |
27 | # sigmoid requires type float32 | |
28 | a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32) | |
29 | b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0) | |
30 | c_row = silu(a_row) * b_row | |
31 | tl.store(c_ptr + col_offsets, c_row, mask=mask) | |
32 | | |
33 | | |
34 | @triton.jit | |
35 | def _swiglu_backward_kernel( | |
36 | dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr | |
37 | ): | |
38 | program_id = tl.program_id(0).to(tl.int64) | |
39 | | |
40 | # locate start index | |
41 | dc_ptr += program_id * stride | |
42 | a_ptr += program_id * stride | |
43 | b_ptr += program_id * stride | |
44 | | |
45 | col_offsets = tl.arange(0, BLOCK_SIZE) | |
46 | mask = col_offsets < n_cols | |
47 | | |
48 | dc_row = tl.load(dc_ptr + col_offsets, mask=mask, other=0) | |
49 | # sigmoid requires type float32 | |
50 | a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32) | |
51 | b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0) | |
52 | | |
53 | # recomputation to save memory | |
54 | sig_a = tl.sigmoid(a_row) | |
55 | silu_a = a_row * sig_a | |
56 | db_row = dc_row * silu_a | |
57 | da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row | |
58 | | |
59 | tl.store(a_ptr + col_offsets, da_row, mask=mask) | |
60 | tl.store(b_ptr + col_offsets, db_row, mask=mask) | |
61 | | |
62 | | |
63 | def swiglu_forward(a, b): | |
64 | ori_shape = a.shape | |
65 | | |
66 | n_cols = ori_shape[-1] | |
67 | a = a.view(-1, n_cols) | |
68 | b = b.view(-1, n_cols) | |
69 | c = torch.empty_like(a) | |
70 | n_rows = a.shape[0] | |
71 | | |
72 | BLOCK_SIZE, num_warps = calculate_settings(n_cols) | |
73 | | |
74 | _swiglu_forward_kernel[(n_rows,)]( | |
75 | a, | |
76 | b, | |
77 | c, | |
78 | c.stride(-2), | |
79 | n_cols=n_cols, | |
80 | BLOCK_SIZE=BLOCK_SIZE, | |
81 | num_warps=num_warps, | |
82 | ) | |
83 | return a, b, c.view(*ori_shape) | |
84 | | |
85 | | |
86 | def swiglu_backward(a, b, dc): | |
87 | | |
88 | ori_shape = dc.shape | |
89 | n_cols = ori_shape[-1] | |
90 | dc = dc.view(-1, n_cols) | |
91 | n_rows = dc.shape[0] | |
92 | | |
93 | BLOCK_SIZE, num_warps = calculate_settings(n_cols) | |
94 | | |
95 | _swiglu_backward_kernel[(n_rows,)]( | |
96 | dc, | |
97 | a, | |
98 | b, | |
99 | dc.stride(-2), | |
100 | n_cols=n_cols, | |
101 | BLOCK_SIZE=BLOCK_SIZE, | |
102 | num_warps=num_warps, | |
103 | ) | |
104 | return a.view(*ori_shape), b.view(*ori_shape) | |
105 | | |
106 | | |
107 | class LigerSiLUMulFunction(torch.autograd.Function): | |
108 | @staticmethod | |
109 | @ensure_contiguous | |
110 | def forward(ctx, a, b): | |
111 | a, b, c = swiglu_forward(a, b) | |
112 | ctx.save_for_backward(a, b) | |
113 | return c | |
114 | | |
115 | @staticmethod | |
116 | @ensure_contiguous | |
117 | def backward(ctx, dc): | |
118 | a, b = ctx.saved_tensors | |
119 | a, b = swiglu_backward(a, b, dc) | |
120 | return a, b | |
121 | | |
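The forward pass here is simply `silu(a) * b`, so it can be validated against `F.silu`. A sketch under the same assumptions as the GeGLU check (CUDA device, float32 inputs, clones to protect the saved activations):

```python
# Forward-pass check: LigerSiLUMulFunction computes silu(a) * b.
import torch
import torch.nn.functional as F

from liger_kernel.ops.swiglu import LigerSiLUMulFunction

a = torch.randn(4, 2048, device="cuda")
b = torch.randn(4, 2048, device="cuda")

c = LigerSiLUMulFunction.apply(a.clone(), b.clone())
ref = F.silu(a) * b
print(torch.allclose(c, ref, atol=1e-5))
```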
-------------------------------------------------------------------------------- | |
/src/liger_kernel/ops/utils.py: | |
-------------------------------------------------------------------------------- | |
1 | """ | |
2 | This file incorporates code from Unsloth licensed under the Apache License, Version 2.0. | |
2 | See the original Unsloth repository at https://github.com/unslothai/unsloth. | |
4 | | |
5 | The following line | |
6 | https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23 | |
7 | is based on code from Unsloth, located at: | |
8 | https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43 | |
9 | | |
10 | Modifications made by Yanning Chen, 2024. | |
11 | """ | |
12 | | |
13 | import functools | |
14 | import importlib | |
15 | import operator | |
16 | from typing import Callable | |
17 | | |
18 | import torch | |
19 | import triton | |
20 | import triton.language as tl | |
21 | from packaging.version import Version | |
22 | | |
23 | from liger_kernel.utils import infer_device | |
24 | | |
25 | | |
26 | def is_hip() -> bool: | |
27 | return torch.version.hip is not None | |
28 | | |
29 | | |
30 | def ensure_contiguous(fn): | |
31 | @functools.wraps(fn) | |
32 | def wrapper(ctx, *args, **kwargs): | |
33 | def maybe_to_contiguous(x): | |
34 | return x.contiguous() if isinstance(x, torch.Tensor) else x | |
35 | | |
36 | args = [maybe_to_contiguous(arg) for arg in args] | |
37 | kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()} | |
38 | return fn(ctx, *args, **kwargs) | |
39 | | |
40 | return wrapper | |
41 | | |
42 | | |
43 | def calculate_settings(n): | |
44 | # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43 | |
45 | | |
46 | MAX_FUSED_SIZE = 65536 | |
47 | BLOCK_SIZE = triton.next_power_of_2(n) | |
48 | if BLOCK_SIZE > MAX_FUSED_SIZE: | |
49 | raise RuntimeError( | |
50 | f"Cannot launch Triton kernel since n = {n} exceeds " | |
51 | f"the recommended Triton blocksize = {MAX_FUSED_SIZE}." | |
52 | ) | |
53 | | |
54 | num_warps = 4 | |
55 | if BLOCK_SIZE >= 32768: | |
56 | num_warps = 32 if not is_hip() else 16 | |
57 | elif BLOCK_SIZE >= 8192: | |
58 | num_warps = 16 | |
59 | elif BLOCK_SIZE >= 2048: | |
60 | num_warps = 8 | |
61 | return BLOCK_SIZE, num_warps | |
62 | | |
63 | | |
64 | def compare_version(package: str, operator: Callable, target: str): | |
65 | try: | |
66 | pkg = importlib.import_module(package) | |
67 | except ImportError: | |
68 | return False | |
69 | pkg_version = Version(pkg.__version__) | |
70 | return operator(pkg_version, Version(target)) | |
71 | | |
72 | | |
73 | def get_amp_custom_fwd_bwd() -> Callable: | |
74 | device = infer_device() | |
75 | if compare_version("torch", operator.ge, "2.4.0"): | |
76 | return ( | |
77 | functools.partial(torch.amp.custom_fwd, device_type=device), | |
78 | functools.partial(torch.amp.custom_bwd, device_type=device), | |
79 | ) | |
80 | return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd | |
81 | | |
82 | | |
83 | amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd() | |
84 | | |
85 | | |
86 | torch_to_triton_dtype = { | |
87 | torch.float32: tl.float32, | |
88 | torch.float16: tl.float16, | |
89 | torch.bfloat16: tl.bfloat16, | |
90 | } | |
91 | | |
92 | | |
93 | @triton.jit | |
94 | def element_mul_kernel( | |
95 | X_ptr, | |
96 | X_stride, | |
97 | grad_output_ptr, | |
98 | n_cols, | |
99 | BLOCK_SIZE: tl.constexpr, | |
100 | ): | |
101 | """ | |
102 | This function multiplies each element of the tensor pointed to by X_ptr with the value pointed to by grad_output_ptr. | |
103 | The multiplication is performed in-place on the tensor pointed to by X_ptr. | |
104 | | |
105 | Parameters: | |
106 | X_ptr: Pointer to the input tensor. | |
107 | X_stride (int): The stride of the input tensor. | |
108 | grad_output_ptr: Pointer to the gradient output value. | |
109 | n_cols (int): The number of columns in the input tensor. | |
110 | BLOCK_SIZE (int): The block size for Triton operations. | |
111 | """ | |
112 | | |
113 | # Get the program ID and convert it to int64 to avoid overflow | |
114 | program_id = tl.program_id(0).to(tl.int64) | |
115 | | |
116 | # Locate the start index | |
117 | X_ptr += program_id * X_stride | |
118 | | |
119 | # Load the gradient output value | |
120 | grad_output = tl.load(grad_output_ptr) | |
121 | | |
122 | # Perform the element-wise multiplication | |
123 | for i in range(0, n_cols, BLOCK_SIZE): | |
124 | X_offsets = i + tl.arange(0, BLOCK_SIZE) | |
125 | X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols) | |
126 | tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) | |
127 | | |
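A minimal sketch (not part of the repository) of how the two most widely used helpers above are consumed: `calculate_settings` picks a Triton block size and warp count for a row of width `n`, and `ensure_contiguous` makes gradients contiguous before they reach a kernel. `MyRowOp` is a hypothetical stand-in for the real ops.

```python
import torch

from liger_kernel.ops.utils import calculate_settings, ensure_contiguous

# Launch parameters for a 4096-wide row:
# next_power_of_2(4096) = 4096, and 4096 >= 2048 -> 8 warps.
BLOCK_SIZE, num_warps = calculate_settings(4096)


class MyRowOp(torch.autograd.Function):  # hypothetical op, for illustration only
    @staticmethod
    def forward(ctx, x):
        return x * 2.0

    @staticmethod
    @ensure_contiguous  # incoming gradients are made contiguous before use
    def backward(ctx, grad_output):
        return grad_output * 2.0


y = MyRowOp.apply(torch.randn(4, 4, requires_grad=True))
y.sum().backward()
```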
-------------------------------------------------------------------------------- | |
/src/liger_kernel/transformers/__init__.py: | |
-------------------------------------------------------------------------------- | |
1 | from liger_kernel.transformers.auto_model import ( # noqa: F401 | |
2 | AutoLigerKernelForCausalLM, | |
3 | ) | |
4 | from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss # noqa: F401 | |
5 | from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: F401 | |
6 | LigerFusedLinearCrossEntropyLoss, | |
7 | ) | |
8 | from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401 | |
9 | from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401 | |
10 | from liger_kernel.transformers.jsd import LigerJSD # noqa: F401 | |
11 | from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401 | |
12 | from liger_kernel.transformers.monkey_patch import ( # noqa: F401 | |
13 | _apply_liger_kernel, | |
14 | _apply_liger_kernel_to_instance, | |
15 | apply_liger_kernel_to_gemma, | |
16 | apply_liger_kernel_to_gemma2, | |
17 | apply_liger_kernel_to_llama, | |
18 | apply_liger_kernel_to_mistral, | |
19 | apply_liger_kernel_to_mixtral, | |
20 | apply_liger_kernel_to_mllama, | |
21 | apply_liger_kernel_to_phi3, | |
22 | apply_liger_kernel_to_qwen2, | |
23 | apply_liger_kernel_to_qwen2_vl, | |
24 | ) | |
25 | from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401 | |
26 | from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401 | |
27 | from liger_kernel.transformers.swiglu import ( # noqa: F401 | |
28 | LigerBlockSparseTop2MLP, | |
29 | LigerPhi3SwiGLUMLP, | |
30 | LigerSwiGLUMLP, | |
31 | ) | |
32 | | |
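A minimal usage sketch for the patching entry points re-exported above, assuming a local Llama checkpoint (the path below is a placeholder): call the `apply_liger_kernel_to_*` function before instantiating the model so the patched modules are picked up.

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_llama

# Patch HF's Llama modeling code in place (RoPE, RMSNorm, SwiGLU,
# fused linear cross entropy, ...) using the default options.
apply_liger_kernel_to_llama()

model = AutoModelForCausalLM.from_pretrained("path/to/llama-checkpoint")
```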
-------------------------------------------------------------------------------- | |
/src/liger_kernel/transformers/auto_model.py: | |
-------------------------------------------------------------------------------- | |
1 | import inspect | |
2 | | |
3 | from transformers import AutoConfig, AutoModelForCausalLM | |
4 | | |
5 | from liger_kernel.transformers.monkey_patch import ( | |
6 | MODEL_TYPE_TO_APPLY_LIGER_FN, | |
7 | _apply_liger_kernel, | |
8 | ) | |
9 | | |
10 | | |
11 | def _get_model_config(model_dir, **model_init_kwargs): | |
12 | config = AutoConfig.from_pretrained(model_dir, **model_init_kwargs) | |
13 | return config | |
14 | | |
15 | | |
16 | class AutoLigerKernelForCausalLM(AutoModelForCausalLM): | |
17 | """ | |
18 | This class is a drop-in replacement for AutoModelForCausalLM that applies the Liger Kernel to the model | |
19 | if applicable. | |
20 | """ | |
21 | | |
22 | @classmethod | |
23 | def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): | |
24 | model_config = _get_model_config(pretrained_model_name_or_path, **kwargs) | |
25 | | |
26 | # Determine the model type and apply the Liger Kernel if applicable | |
27 | # Note: _apply_liger_kernel will only pass relevant kwargs to the apply_liger_kernel_to_* function | |
28 | model_type = model_config.model_type | |
29 | | |
30 | _apply_liger_kernel(model_type, **kwargs) | |
31 | | |
32 | # Filter out kwargs that were passed to the apply_liger_* function, which will cause | |
33 | # model initialization errors otherwise | |
34 | apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type] | |
35 | apply_fn_signature = inspect.signature(apply_fn) | |
36 | | |
37 | applicable_kwargs = { | |
38 | key: value | |
39 | for key, value in kwargs.items() | |
40 | if key not in apply_fn_signature.parameters | |
41 | } | |
42 | | |
43 | return super().from_pretrained( | |
44 | pretrained_model_name_or_path, *model_args, **applicable_kwargs | |
45 | ) | |
46 | | |
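A minimal sketch of the class above, assuming the checkpoint's `model_type` is one of the supported entries in `MODEL_TYPE_TO_APPLY_LIGER_FN` (e.g. llama, mistral, qwen2); the path is a placeholder.

```python
from liger_kernel.transformers import AutoLigerKernelForCausalLM

# Drop-in replacement for AutoModelForCausalLM: the matching
# apply_liger_kernel_to_* patch is applied based on config.model_type
# before the weights are loaded.
model = AutoLigerKernelForCausalLM.from_pretrained("path/to/causal-lm-checkpoint")
```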
-------------------------------------------------------------------------------- | |
/src/liger_kernel/transformers/cross_entropy.py: | |
-------------------------------------------------------------------------------- | |
1 | from typing import Optional | |
2 | | |
3 | import torch | |
4 | | |
5 | from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction | |
6 | | |
7 | | |
8 | class LigerCrossEntropyLoss(torch.nn.Module): | |
9 | def __init__( | |
10 | self, | |
11 | ignore_index: int = -100, | |
12 | lse_square_scale: float = 0.0, | |
13 | label_smoothing: float = 0.0, | |
14 | reduction: str = "mean", | |
15 | softcap: Optional[float] = None, | |
16 | return_z_loss: bool = False, | |
17 | ): | |
18 | super().__init__() | |
19 | assert (label_smoothing >= 0) and ( | |
20 | label_smoothing <= 1 | |
21 | ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}" | |
25 | assert reduction in { | |
26 | "mean", | |
27 | "sum", | |
28 | "none", | |
29 | }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}" | |
30 | assert ( | |
31 | softcap is None or softcap > 0 | |
32 | ), f"softcap must be greater than 0.0 or None. Got: {softcap}" | |
33 | self.ignore_index = ignore_index | |
34 | self.lse_square_scale = lse_square_scale | |
35 | self.label_smoothing = label_smoothing | |
36 | self.reduction = reduction | |
37 | self.softcap = softcap | |
38 | self.return_z_loss = return_z_loss | |
39 | | |
40 | def forward(self, _input: torch.Tensor, target: torch.Tensor): | |
41 | loss, z_loss = LigerCrossEntropyFunction.apply( | |
42 | _input, | |
43 | target, | |
44 | self.ignore_index, | |
45 | self.lse_square_scale, | |
46 | self.label_smoothing, | |
47 | self.reduction, | |
48 | self.softcap, | |
49 | self.return_z_loss, | |
50 | ) | |
51 | if not self.return_z_loss: | |
52 | return loss | |
53 | return loss, z_loss | |
54 | | |
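A minimal sketch of using the loss module above; sizes are illustrative and a GPU-backed device is assumed since the op launches Triton kernels.

```python
import torch

from liger_kernel.transformers import LigerCrossEntropyLoss
from liger_kernel.utils import infer_device

device = infer_device()  # "cuda", "xpu", or "cpu"

BT, V = 8, 32000  # flattened batch*sequence tokens, vocab size
logits = torch.randn(BT, V, device=device, requires_grad=True)
target = torch.randint(0, V, (BT,), device=device)

loss_fn = LigerCrossEntropyLoss(label_smoothing=0.1)
loss = loss_fn(logits, target)
loss.backward()
```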
-------------------------------------------------------------------------------- | |
/src/liger_kernel/transformers/experimental/embedding.py: | |
-------------------------------------------------------------------------------- | |
1 | from typing import Optional | |
2 | | |
3 | import torch | |
4 | import torch.nn as nn | |
5 | | |
6 | from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction | |
7 | | |
8 | | |
9 | class LigerEmbedding(nn.Module): | |
10 | def __init__( | |
11 | self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None | |
12 | ): | |
13 | super().__init__() | |
14 | self.num_embeddings = num_embeddings | |
15 | self.embedding_dim = embedding_dim | |
16 | self.padding_idx = padding_idx | |
17 | self.weight = nn.Parameter(torch.randn(num_embeddings, embedding_dim)) | |
18 | | |
19 | if padding_idx is not None: | |
20 | with torch.no_grad(): | |
21 | self.weight[padding_idx].fill_(0) | |
22 | | |
23 | def forward(self, indices): | |
24 | embedded = LigerEmbeddingFunction.apply(self.weight, indices) | |
25 | if self.padding_idx is not None: | |
26 | embedded = embedded.clone() | |
27 | embedded[indices == self.padding_idx] = 0 | |
28 | return embedded | |
29 | | |
-------------------------------------------------------------------------------- | |
/src/liger_kernel/transformers/functional.py: | |
-------------------------------------------------------------------------------- | |
1 | from typing import Optional | |
2 | | |
3 | from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction | |
4 | from liger_kernel.ops.fused_linear_cross_entropy import ( | |
5 | LigerFusedLinearCrossEntropyFunction, | |
6 | ) | |
7 | from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction | |
8 | from liger_kernel.ops.geglu import LigerGELUMulFunction | |
9 | from liger_kernel.ops.group_norm import LigerGroupNormFunction | |
10 | from liger_kernel.ops.jsd import LigerJSDFunction | |
11 | from liger_kernel.ops.kl_div import LigerKLDivLossFunction | |
12 | from liger_kernel.ops.layer_norm import LigerLayerNormFunction | |
13 | from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction | |
14 | from liger_kernel.ops.rms_norm import LigerRMSNormFunction | |
15 | from liger_kernel.ops.rope import LigerRopeFunction | |
16 | from liger_kernel.ops.swiglu import LigerSiLUMulFunction | |
17 | | |
18 | | |
19 | # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html | |
20 | # `weight`, `size_average`, and `reduce` are placeholders and not implemented yet | |
21 | def liger_cross_entropy( | |
22 | input, | |
23 | target, | |
24 | weight=None, | |
25 | size_average=None, | |
26 | ignore_index: int = -100, | |
27 | reduce=None, | |
28 | reduction: str = "mean", | |
29 | label_smoothing: float = 0.0, | |
30 | lse_square_scale: float = 0.0, | |
31 | softcap: Optional[float] = None, | |
32 | return_z_loss: bool = False, | |
33 | ): | |
34 | loss, z_loss = LigerCrossEntropyFunction.apply( | |
35 | input, | |
36 | target, | |
37 | ignore_index, | |
38 | lse_square_scale, | |
39 | label_smoothing, | |
40 | reduction, | |
41 | softcap, | |
42 | return_z_loss, | |
43 | ) | |
44 | if not return_z_loss: | |
45 | return loss | |
46 | return loss, z_loss | |
47 | | |
48 | | |
49 | def liger_fused_linear_cross_entropy( | |
50 | input, | |
51 | weight, | |
52 | target, | |
53 | bias=None, | |
54 | ignore_index: int = -100, | |
55 | lse_square_scale: float = 0.0, | |
56 | label_smoothing: float = 0.0, | |
57 | reduction: str = "mean", | |
58 | softcap: Optional[float] = None, | |
59 | ): | |
60 | return LigerFusedLinearCrossEntropyFunction.apply( | |
61 | input, | |
62 | weight, | |
63 | target, | |
64 | bias, | |
65 | ignore_index, | |
66 | lse_square_scale, | |
67 | label_smoothing, | |
68 | reduction, | |
69 | softcap, | |
70 | ) | |
71 | | |
72 | | |
73 | def liger_fused_linear_jsd( | |
74 | student_input, | |
75 | student_weight, | |
76 | teacher_input, | |
77 | teacher_weight, | |
78 | shift_labels=None, | |
79 | jsd_beta: float = 0.5, | |
80 | ignore_index: int = -100, | |
81 | temperature: float = 1.0, | |
82 | ): | |
83 | return LigerFusedLinearJSDFunction.apply( | |
84 | student_input, | |
85 | student_weight, | |
86 | teacher_input, | |
87 | teacher_weight, | |
88 | shift_labels, | |
89 | jsd_beta, | |
90 | ignore_index, | |
91 | temperature, | |
92 | ) | |
93 | | |
94 | | |
95 | def liger_geglu(a, b): | |
96 | return LigerGELUMulFunction.apply(a, b) | |
97 | | |
98 | | |
99 | def liger_group_norm( | |
100 | X, | |
101 | affine_scaling_weight, | |
102 | affine_shifting_bias, | |
103 | num_channels, | |
104 | num_groups, | |
105 | eps, | |
106 | ): | |
107 | return LigerGroupNormFunction.apply( | |
108 | X, | |
109 | affine_scaling_weight, | |
110 | affine_shifting_bias, | |
111 | num_channels, | |
112 | num_groups, | |
113 | eps, | |
114 | ) | |
115 | | |
116 | | |
117 | def liger_jsd( | |
118 | input, | |
119 | target, | |
120 | shift_labels=None, | |
121 | beta: float = 0.5, | |
122 | ignore_index: int = -100, | |
123 | ): | |
124 | return LigerJSDFunction.apply( | |
125 | input, | |
126 | target, | |
127 | shift_labels, | |
128 | beta, | |
129 | ignore_index, | |
130 | ) | |
131 | | |
132 | | |
133 | # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.kl_div.html#torch.nn.functional.kl_div | |
134 | # `size_average` and `reduce` are deprecated in the torch API and are placeholders here | |
135 | def liger_kl_div( | |
136 | input, | |
137 | target, | |
138 | size_average: bool = True, | |
139 | reduce: bool = True, | |
140 | reduction: str = "mean", | |
141 | log_target: bool = False, | |
142 | eps: float = 1e-10, | |
143 | ): | |
144 | # Note: the default reduction in torch is `mean`, but it is `batchmean` in Liger | |
145 | return LigerKLDivLossFunction.apply( | |
146 | input, | |
147 | target, | |
148 | reduction, | |
149 | log_target, | |
150 | eps, | |
151 | ) | |
152 | | |
153 | | |
154 | def liger_layer_norm(X, W, B, eps): | |
155 | return LigerLayerNormFunction.apply(X, W, B, eps) | |
156 | | |
157 | | |
158 | def liger_qwen2vl_mrope(q, k, cos, sin, mrope_section, unsqueeze_dim=1): | |
159 | return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim) | |
160 | | |
161 | | |
162 | def liger_rms_norm( | |
163 | X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True | |
164 | ): | |
165 | return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place) | |
166 | | |
167 | | |
168 | def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): | |
169 | return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim) | |
170 | | |
171 | | |
172 | def liger_swiglu(a, b): | |
173 | return LigerSiLUMulFunction.apply(a, b) | |
174 | | |
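A minimal sketch of the functional API above; sizes are illustrative and a GPU-backed device is assumed for the Triton kernels.

```python
import torch

from liger_kernel.transformers.functional import liger_cross_entropy, liger_swiglu
from liger_kernel.utils import infer_device

device = infer_device()

# Fused SiLU(a) * b
a = torch.randn(4, 1024, device=device, requires_grad=True)
b = torch.randn(4, 1024, device=device, requires_grad=True)
out = liger_swiglu(a, b)
out.sum().backward()

# Cross entropy with the torch.nn.functional.cross_entropy-style signature
logits = torch.randn(8, 32000, device=device, requires_grad=True)
target = torch.randint(0, 32000, (8,), device=device)
loss = liger_cross_entropy(logits, target, label_smoothing=0.1)
loss.backward()
```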
-------------------------------------------------------------------------------- | |
/src/liger_kernel/transformers/fused_linear_cross_entropy.py: | |
-------------------------------------------------------------------------------- | |
1 | from typing import Optional | |
2 | | |
3 | import torch | |
4 | | |
5 | from liger_kernel.ops.fused_linear_cross_entropy import ( | |
6 | LigerFusedLinearCrossEntropyFunction, | |
7 | ) | |
8 | | |
9 | | |
10 | class LigerFusedLinearCrossEntropyLoss(torch.nn.Module): | |
11 | def __init__( | |
12 | self, | |
13 | ignore_index: int = -100, | |
14 | lse_square_scale: float = 0.0, | |
15 | label_smoothing: float = 0.0, | |
16 | reduction: str = "mean", | |
17 | softcap: Optional[float] = None, | |
18 | ): | |
19 | super().__init__() | |
20 | assert (label_smoothing >= 0) and ( | |
21 | label_smoothing <= 1 | |
22 | ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}" | |
23 | assert reduction in { | |
24 | "mean", | |
25 | "sum", | |
26 | "none", | |
27 | }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}" | |
28 | assert ( | |
29 | softcap is None or softcap > 0 | |
30 | ), f"softcap must be greater than 0.0 or None. Got: {softcap}" | |
31 | self.ignore_index = ignore_index | |
32 | self.lse_square_scale = lse_square_scale | |
33 | self.label_smoothing = label_smoothing | |
34 | self.reduction = reduction | |
35 | self.softcap = softcap | |
36 | | |
37 | def forward(self, lin_weight, _input, target, bias=None): | |
38 | return LigerFusedLinearCrossEntropyFunction.apply( | |
39 | _input, | |
40 | lin_weight, | |
41 | target, | |
42 | bias, | |
43 | self.ignore_index, | |
44 | self.lse_square_scale, | |
45 | self.label_smoothing, | |
46 | self.reduction, | |
47 | self.softcap, | |
48 | ) | |
49 | | |
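A minimal sketch of the fused loss above; sizes are illustrative and a GPU-backed device is assumed. Note the argument order in `forward`: the `lm_head` weight comes first, then the hidden states, then the targets.

```python
import torch

from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss
from liger_kernel.utils import infer_device

device = infer_device()

BT, H, V = 16, 4096, 32000  # illustrative sizes
hidden = torch.randn(BT, H, device=device, requires_grad=True)
lm_head = torch.nn.Linear(H, V, bias=False, device=device)
labels = torch.randint(0, V, (BT,), device=device)

loss_fn = LigerFusedLinearCrossEntropyLoss()
loss = loss_fn(lm_head.weight, hidden, labels)  # weight first, then inputs
loss.backward()
```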
-------------------------------------------------------------------------------- | |
/src/liger_kernel/transformers/fused_linear_jsd.py: | |
-------------------------------------------------------------------------------- | |
1 | from typing import Optional | |
2 | | |
3 | import torch | |
4 | | |
5 | from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction | |
6 | | |
7 | | |
8 | class LigerFusedLinearJSD(torch.nn.Module): | |
9 | r"""Fusing the last linear layer with generalized JSD | |
10 | | |
11 | Handle the forward and backward pass of the final linear layer via JSD by avoiding | |
12 | the materialization of the large logits tensor. | |
13 | | |
14 | Args: | |
15 | jsd_beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5` | |
16 | ignore_index (int): The index to ignore in the target. Default: `-100` | |
17 | temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0` | |
18 | | |
19 | Shape: | |
20 | - student_input: :math:`(BT, H)`, where B is batch size, T is sequence length, H is hidden dimension. | |
21 | - student_weight: :math:`(V, H)`, where V is vocab size. | |
22 | - teacher_input: :math:`(BT, H')`, where H' is hidden dimension of the teacher model. | |
23 | - teacher_weight: :math:`(V, H')`, where hidden size H and H' can be different. | |
24 | - shift_labels: :math:`(BT,)` | |
25 | - Output: a scalar. | |
26 | | |
27 | Examples: | |
28 | ```python | |
29 | >>> (B, T, H_s, H_t, V) = (2, 2, 3, 5, 10) | |
30 | >>> fused_jsd = LigerFusedLinearJSD(jsd_beta=0.1, temperature=2.0) | |
31 | >>> # generate inputs and weights | |
32 | >>> student_input = torch.rand(B * T, H_s, device="cuda", requires_grad=True) | |
33 | >>> student_lin = torch.nn.Linear(H_s, V, bias=False, device="cuda") | |
34 | >>> # teacher input doesn't require grad, hidden_dim can be different from student's | |
35 | >>> teacher_input = torch.rand(B * T, H_t, device="cuda") | |
36 | >>> teacher_lin = torch.nn.Linear(H_t, V, bias=False, device="cuda") | |
37 | >>> output = fused_jsd(student_input, student_lin.weight, teacher_input, teacher_lin.weight) | |
38 | >>> output.backward() | |
39 | >>> | |
40 | >>> # Example with labels for supervised fine-tuning (SFT) context: | |
41 | >>> | |
42 | >>> # Assume hidden_states, lm_heads and corresponding labels are given | |
43 | >>> student_lm_head = torch.nn.Linear(H_s, V, bias=False) | |
44 | >>> student_hidden_states = torch.randn(B * T, H_s, requires_grad=True).log_softmax(dim=-1) | |
45 | >>> teacher_lm_head = torch.nn.Linear(H_t, V, bias=False) | |
46 | >>> teacher_hidden_states = torch.randn(B * T, H_t).log_softmax(dim=-1) | |
47 | >>> labels = torch.randint(0, V, (B * T,), dtype=torch.long) | |
48 | >>> | |
49 | >>> # Shift so that tokens < n predict n | |
50 | >>> shift_student_hidden_states = student_hidden_states[..., :-1, :].contiguous() | |
51 | >>> shift_teacher_hidden_states = teacher_hidden_states[..., :-1, :].contiguous() | |
52 | >>> shift_labels = labels[..., 1:].contiguous() | |
53 | >>> | |
54 | >>> # Flatten tokens | |
55 | >>> shift_student_hidden_states = shift_student_hidden_states.view(-1, H_s) | |
56 | >>> shift_teacher_hidden_states = shift_teacher_hidden_states.view(-1, H_t) | |
57 | >>> shift_labels = shift_labels.view(-1) | |
58 | >>> | |
59 | >>> # Calculate loss | |
60 | >>> loss_fct = LigerFusedLinearJSD(jsd_beta=0.1) | |
61 | >>> loss = loss_fct( | |
62 | >>> shift_student_hidden_states, | |
63 | >>> student_lm_head.weight, | |
64 | >>> shift_teacher_hidden_states, | |
65 | >>> teacher_lm_head.weight, | |
66 | >>> shift_labels | |
67 | >>> ) | |
68 | ``` | |
69 | """ | |
70 | | |
71 | def __init__(self, jsd_beta=0.5, ignore_index=-100, temperature=1.0): | |
72 | super().__init__() | |
73 | assert temperature != 0, "temperature cannot be 0." | |
74 | self.jsd_beta = jsd_beta | |
75 | self.temperature = temperature | |
76 | self.ignore_index = ignore_index | |
77 | | |
78 | def forward( | |
79 | self, | |
80 | student_input: torch.Tensor, | |
81 | student_weight: torch.Tensor, | |
82 | teacher_input: torch.Tensor, | |
83 | teacher_weight: torch.Tensor, | |
84 | shift_labels: Optional[torch.LongTensor], | |
85 | ): | |
86 | return LigerFusedLinearJSDFunction.apply( | |
87 | student_input, | |
88 | student_weight, | |
89 | teacher_input, | |
90 | teacher_weight, | |
91 | shift_labels, | |
92 | self.jsd_beta, | |
93 | self.ignore_index, | |
94 | self.temperature, | |
95 | ) | |
96 | | |
-------------------------------------------------------------------------------- | |
/src/liger_kernel/transformers/geglu.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch.nn as nn | |
2 | | |
3 | from liger_kernel.ops.geglu import LigerGELUMulFunction | |
4 | | |
5 | | |
6 | class LigerGEGLUMLP(nn.Module): | |
7 | def __init__(self, config): | |
8 | super().__init__() | |
9 | self.config = config | |
10 | self.hidden_size = config.hidden_size | |
11 | self.intermediate_size = config.intermediate_size | |
12 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) | |
13 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) | |
14 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) | |
15 | # TODO: support exact GELU | |
16 | # Right now Gemma 1, 1.1 and 2 models are all using `gelu_pytorch_tanh` | |
17 | # https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/gemma/modeling_gemma.py#L175 | |
18 | # https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/activations.py#L46 | |
19 | # So we can safely assume the tanh approximation form is used all the time | |
20 | | |
21 | def forward(self, x): | |
22 | | |
23 | return self.down_proj( | |
24 | LigerGELUMulFunction.apply(self.gate_proj(x), self.up_proj(x)) | |
25 | ) | |
26 | | |
-------------------------------------------------------------------------------- | |
/src/liger_kernel/transformers/group_norm.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch | |
2 | import torch.nn as nn | |
3 | | |
4 | from liger_kernel.ops.group_norm import LigerGroupNormFunction | |
5 | | |
6 | | |
7 | class LigerGroupNorm(nn.Module): | |
8 | def __init__(self, num_channels, num_groups, eps=1e-6, bias=False, init_fn="ones"): | |
9 | """ | |
10 | A Group Normalization layer. | |
11 | Args: | |
12 | num_channels (int): Number of channels in the input tensor. | |
13 | num_groups (int): Number of groups to divide the channels into. | |
14 | eps (float, optional): A value added to the denominator for numerical stability. Default: 1e-6. | |
15 | bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``False``. | |
16 | init_fn (str, optional): Initialization function for the learnable parameters. Default: "ones". | |
17 | """ | |
18 | super().__init__() | |
19 | assert init_fn in [ | |
20 | "ones", | |
21 | "zeros", | |
22 | ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}" | |
23 | | |
24 | assert ( | |
25 | num_channels % num_groups == 0 | |
26 | ), f"Number of channels {num_channels} must be divisible by num_groups {num_groups}" | |
27 | self.num_channels = num_channels | |
28 | self.num_groups = num_groups | |
29 | self.eps = eps | |
30 | self.weight = nn.Parameter( | |
31 | torch.ones(num_channels) if init_fn == "ones" else torch.zeros(num_channels) | |
32 | ) | |
33 | self.bias = nn.Parameter( | |
34 | torch.randn(num_channels) if bias else torch.zeros(num_channels) | |
35 | ) | |
36 | self.variance_epsilon = eps | |
37 | | |
38 | def forward(self, hidden_states): | |
39 | # hidden_states: (batch_size, num_channels, *) | |
40 | assert ( | |
41 | hidden_states.dim() >= 3 | |
42 | ), f"Input must have at least 3 dimensions, got {hidden_states.dim()}" | |
43 | assert ( | |
44 | hidden_states.size(1) == self.num_channels | |
45 | ), f"Input tensor must have {self.num_channels} channels, got {hidden_states.size(1)}" | |
46 | return LigerGroupNormFunction.apply( | |
47 | hidden_states, | |
48 | self.weight, | |
49 | self.bias, | |
50 | self.num_channels, | |
51 | self.num_groups, | |
52 | self.variance_epsilon, | |
53 | ) | |
54 | | |
55 | def extra_repr(self): | |
56 | return f"num_channels={self.num_channels}, num_groups={self.num_groups}, eps={self.eps}" | |
57 | | |
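A minimal sketch of the layer above, assuming a GPU-backed device; the input must be at least 3-D with channels in dimension 1.

```python
import torch

from liger_kernel.transformers.group_norm import LigerGroupNorm
from liger_kernel.utils import infer_device

device = infer_device()

x = torch.randn(4, 32, 128, device=device)  # (batch, channels, features)
gn = LigerGroupNorm(num_channels=32, num_groups=8).to(device)
y = gn(x)
```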
-------------------------------------------------------------------------------- | |
/src/liger_kernel/transformers/jsd.py: | |
-------------------------------------------------------------------------------- | |
1 | from typing import Optional | |
2 | | |
3 | import torch | |
4 | | |
5 | from liger_kernel.ops.jsd import LigerJSDFunction | |
6 | | |
7 | | |
8 | class LigerJSD(torch.nn.Module): | |
9 | r"""The generalized Jensen-Shannon Divergence. | |
10 | .. math:: | |
11 | JSD(\beta)(P || Q) | |
12 | = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q)) | |
13 | .. note:: | |
14 | As all the other losses in PyTorch, this function expects the first argument, | |
15 | :attr:`log_q`, to be the predictions, the output of the student model in log-space, | |
16 | and the second, :attr:`log_p`, to be the observations, the output of the teacher model in log-space. | |
17 | This differs from the standard mathematical notation :math:`JSD(P || Q)` where | |
18 | :math:`P` denotes the teacher model and :math:`Q` denotes the student model. | |
19 | | |
20 | Args: | |
21 | beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5` | |
22 | ignore_index (int): The index to ignore in the target. Default: `-100` | |
23 | | |
24 | Shape: | |
25 | - Input: :math:`(BT, V)`, where B is batch size, T is sequence length, V is vocab size. | |
26 | - Target: :math:`(BT, V)`, same shape as the input. | |
27 | - shift_labels (Optional): :math:`(BT,)` | |
28 | - Output: a scalar. | |
29 | | |
30 | Examples: | |
31 | ```python | |
32 | >>> (B, T, V) = (2, 2, 5) | |
33 | >>> jsd = LigerJSD(beta=0.1) | |
34 | >>> # input should be a distribution in the log space | |
35 | >>> input = torch.randn(B * T, V, requires_grad=True).log_softmax(dim=-1) | |
36 | >>> target = torch.randn(B * T, V).log_softmax(dim=-1) | |
37 | >>> output = jsd(input, target) | |
38 | >>> | |
39 | >>> # Example with labels for supervised fine-tuning (SFT) context | |
40 | >>> # Assume logits and corresponding labels are given | |
41 | >>> student_logits = torch.randn(B * T, V, requires_grad=True).log_softmax(dim=-1) | |
42 | >>> teacher_logits = torch.randn(B * T, V).log_softmax(dim=-1) | |
43 | >>> labels = torch.randint(0, V, (B * T,), dtype=torch.long) | |
44 | >>> # Shift so that tokens < n predict n | |
45 | >>> shift_student_logits = student_logits[..., :-1, :].contiguous() | |
46 | >>> shift_teacher_logits = teacher_logits[..., :-1, :].contiguous() | |
47 | >>> shift_labels = labels[..., 1:].contiguous() | |
48 | >>> # Flatten tokens | |
49 | >>> shift_student_logits = shift_student_logits.view(-1, V) | |
50 | >>> shift_teacher_logits = shift_teacher_logits.view(-1, V) | |
51 | >>> shift_labels = shift_labels.view(-1) | |
52 | >>> # Calculate loss | |
53 | >>> loss_fct = LigerJSD(beta=0.1) | |
54 | >>> loss = loss_fct(shift_student_logits, shift_teacher_logits, shift_labels) | |
55 | | |
56 | ``` | |
57 | """ | |
58 | | |
59 | def __init__(self, beta: float = 0.5, ignore_index: int = -100): | |
60 | super().__init__() | |
61 | self.beta = beta | |
62 | self.ignore_index = ignore_index | |
63 | | |
64 | def forward( | |
65 | self, | |
66 | log_q: torch.Tensor, | |
67 | log_p: torch.Tensor, | |
68 | shift_labels: Optional[torch.LongTensor] = None, | |
69 | ): | |
70 | return LigerJSDFunction.apply( | |
71 | log_q, log_p, shift_labels, self.beta, self.ignore_index | |
72 | ) | |
73 | | |
-------------------------------------------------------------------------------- | |
/src/liger_kernel/transformers/kl_div.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch.nn as nn | |
2 | | |
3 | from liger_kernel.ops.kl_div import LigerKLDivLossFunction | |
4 | | |
5 | | |
6 | class LigerKLDIVLoss(nn.KLDivLoss): | |
7 | def __init__(self, eps: float = 1e-10, *args, **kwargs): | |
8 | super(LigerKLDIVLoss, self).__init__(*args, **kwargs) | |
9 | self.eps = eps | |
10 | | |
11 | def forward(self, y_pred, y_true): | |
12 | return LigerKLDivLossFunction.apply( | |
13 | y_pred, y_true, self.reduction, self.log_target, self.eps | |
14 | ) | |
15 | | |
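A minimal sketch of the loss above, assuming a GPU-backed device; `y_pred` is expected in log-space and, with the default `log_target=False`, `y_true` in probability space.

```python
import torch

from liger_kernel.transformers.kl_div import LigerKLDIVLoss
from liger_kernel.utils import infer_device

device = infer_device()

BT, V = 8, 128
y_pred = torch.randn(BT, V, device=device, requires_grad=True).log_softmax(dim=-1)
y_true = torch.randn(BT, V, device=device).softmax(dim=-1)

kl = LigerKLDIVLoss(reduction="batchmean")
loss = kl(y_pred, y_true)
loss.backward()
```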
-------------------------------------------------------------------------------- | |
/src/liger_kernel/transformers/layer_norm.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch | |
2 | import torch.nn as nn | |
3 | | |
4 | from liger_kernel.ops.layer_norm import LigerLayerNormFunction | |
5 | | |
6 | | |
7 | class LigerLayerNorm(nn.Module): | |
8 | def __init__(self, hidden_size, eps=1e-6, bias=False, init_fn="ones"): | |
9 | super().__init__() | |
10 | assert init_fn in [ | |
11 | "ones", | |
12 | "zeros", | |
13 | ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}" | |
14 | self.hidden_size = hidden_size | |
15 | self.eps = eps | |
16 | self.weight = nn.Parameter( | |
17 | torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size) | |
18 | ) | |
19 | self.bias = nn.Parameter( | |
20 | torch.randn(hidden_size) if bias else torch.zeros(hidden_size) | |
21 | ) | |
22 | self.variance_epsilon = eps | |
23 | | |
24 | def forward(self, hidden_states): | |
25 | return LigerLayerNormFunction.apply( | |
26 | hidden_states, self.weight, self.bias, self.variance_epsilon | |
27 | ) | |
28 | | |
29 | def extra_repr(self): | |
30 | return f"{self.hidden_size}, eps={self.eps}" | |
31 | | |
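A minimal sketch of the layer above, assuming a GPU-backed device; sizes are illustrative.

```python
import torch

from liger_kernel.transformers import LigerLayerNorm
from liger_kernel.utils import infer_device

device = infer_device()

x = torch.randn(2, 128, 4096, device=device)
ln = LigerLayerNorm(hidden_size=4096, eps=1e-6).to(device)
y = ln(x)
```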
-------------------------------------------------------------------------------- | |
/src/liger_kernel/transformers/model/__init__.py: | |
-------------------------------------------------------------------------------- | |
https://raw.githubusercontent.com/linkedin/Liger-Kernel/0bb6c72d03f0fc570b9e954c42b723ab7e0c315f/src/liger_kernel/transformers/model/__init__.py | |
-------------------------------------------------------------------------------- | |
/src/liger_kernel/transformers/model/mistral.py: | |
-------------------------------------------------------------------------------- | |
1 | from typing import List, Optional, Tuple, Union | |
2 | | |
3 | import torch | |
4 | from torch.nn import CrossEntropyLoss | |
5 | from transformers.cache_utils import Cache | |
6 | from transformers.modeling_outputs import CausalLMOutputWithPast | |
7 | from transformers.models.mistral.modeling_mistral import ( | |
8 | _CONFIG_FOR_DOC, | |
9 | MISTRAL_INPUTS_DOCSTRING, | |
10 | ) | |
11 | from transformers.utils import ( | |
12 | add_start_docstrings_to_model_forward, | |
13 | replace_return_docstrings, | |
14 | ) | |
15 | | |
16 | from liger_kernel.transformers.fused_linear_cross_entropy import ( | |
17 | LigerFusedLinearCrossEntropyLoss, | |
18 | ) | |
19 | | |
20 | | |
21 | @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) | |
22 | @replace_return_docstrings( | |
23 | output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC | |
24 | ) | |
25 | def lce_forward( | |
26 | self, | |
27 | input_ids: torch.LongTensor = None, | |
28 | attention_mask: Optional[torch.Tensor] = None, | |
29 | position_ids: Optional[torch.LongTensor] = None, | |
30 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, | |
31 | inputs_embeds: Optional[torch.FloatTensor] = None, | |
32 | labels: Optional[torch.LongTensor] = None, | |
33 | use_cache: Optional[bool] = None, | |
34 | output_attentions: Optional[bool] = None, | |
35 | output_hidden_states: Optional[bool] = None, | |
36 | return_dict: Optional[bool] = None, | |
37 | cache_position: Optional[torch.LongTensor] = None, | |
38 | ) -> Union[Tuple, CausalLMOutputWithPast]: | |
39 | r""" | |
40 | Copy paste Mistral's forward but replace torch cross entropy with liger fused linear cross entropy | |
41 | | |
42 | | |
43 | Args: | |
44 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | |
45 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., | |
46 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored | |
47 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. | |
48 | | |
49 | Returns: | |
50 | | |
51 | Example: | |
52 | | |
53 | ```python | |
54 | >>> from transformers import AutoTokenizer, MistralForCausalLM | |
55 | | |
56 | >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") | |
57 | >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") | |
58 | | |
59 | >>> prompt = "Hey, are you conscious? Can you talk to me?" | |
60 | >>> inputs = tokenizer(prompt, return_tensors="pt") | |
61 | | |
62 | >>> # Generate | |
63 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) | |
64 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] | |
65 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." | |
66 | ```""" | |
67 | | |
68 | output_attentions = ( | |
69 | output_attentions | |
70 | if output_attentions is not None | |
71 | else self.config.output_attentions | |
72 | ) | |
73 | output_hidden_states = ( | |
74 | output_hidden_states | |
75 | if output_hidden_states is not None | |
76 | else self.config.output_hidden_states | |
77 | ) | |
78 | return_dict = ( | |
79 | return_dict if return_dict is not None else self.config.use_return_dict | |
80 | ) | |
81 | | |
82 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) | |
83 | outputs = self.model( | |
84 | input_ids=input_ids, | |
85 | attention_mask=attention_mask, | |
86 | position_ids=position_ids, | |
87 | past_key_values=past_key_values, | |
88 | inputs_embeds=inputs_embeds, | |
89 | use_cache=use_cache, | |
90 | output_attentions=output_attentions, | |
91 | output_hidden_states=output_hidden_states, | |
92 | return_dict=return_dict, | |
93 | cache_position=cache_position, | |
94 | ) | |
95 | | |
96 | hidden_states = outputs[0] | |
97 | | |
98 | loss = None | |
99 | logits = None | |
100 | | |
101 | if self.training and (labels is not None): | |
102 | shift_hidden_states = hidden_states[..., :-1, :].contiguous() | |
103 | shift_labels = labels[..., 1:].contiguous() | |
104 | | |
105 | # flatten tokens | |
106 | shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) | |
107 | shift_labels = shift_labels.view(-1) | |
108 | | |
109 | lce = LigerFusedLinearCrossEntropyLoss() | |
110 | loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) | |
111 | | |
112 | else: | |
113 | logits = self.lm_head(hidden_states) | |
114 | if labels is not None: | |
115 | # Upcast to float if we need to compute the loss to avoid potential precision issues | |
116 | logits = logits.float() | |
117 | # Shift so that tokens < n predict n | |
118 | shift_logits = logits[..., :-1, :].contiguous() | |
119 | shift_labels = labels[..., 1:].contiguous() | |
120 | # Flatten the tokens | |
121 | shift_logits = shift_logits.view(-1, self.config.vocab_size) | |
122 | shift_labels = shift_labels.view(-1) | |
123 | # Ensure tensors are on the same device | |
124 | shift_labels = shift_labels.to(shift_logits.device) | |
125 | loss_fct = CrossEntropyLoss() | |
126 | loss = loss_fct(shift_logits, shift_labels) | |
127 | | |
128 | if not return_dict: | |
129 | output = (logits,) + outputs[1:] | |
130 | return (loss,) + output if loss is not None else output | |
131 | | |
132 | return CausalLMOutputWithPast( | |
133 | loss=loss, | |
134 | logits=logits, | |
135 | past_key_values=outputs.past_key_values, | |
136 | hidden_states=outputs.hidden_states, | |
137 | attentions=outputs.attentions, | |
138 | ) | |
139 | | |
140 | | |
141 | # Note: gradient accumulation is not fixed in Mistral as of transformers 4.46.1 | |
142 | | |
-------------------------------------------------------------------------------- | |
/src/liger_kernel/transformers/qwen2vl_mrope.py: | |
-------------------------------------------------------------------------------- | |
1 | from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction | |
2 | | |
3 | | |
4 | def liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): | |
5 | """ | |
6 | Applies Multimodal Rotary Positional Embedding (M-RoPE) operation to query and key states. | |
7 | | |
8 | Args: | |
9 | q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim). | |
10 | k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim). | |
11 | cos (torch.Tensor): The cosine tensor of shape (3, bsz, seq_len, head_dim). | |
12 | sin (torch.Tensor): The sine tensor of shape (3, bsz, seq_len, head_dim). | |
13 | mrope_section (List[int]): The multimodal rope section for channel dimension of temporal, height and width in rope calculation. | |
14 | unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1. | |
15 | | |
16 | Returns: | |
17 | Tuple[torch.Tensor, torch.Tensor]: The query and key tensors after applying the M-RoPE operation. | |
18 | """ | |
19 | | |
20 | return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim) | |
21 | | |
-------------------------------------------------------------------------------- | |
/src/liger_kernel/transformers/rms_norm.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch | |
2 | import torch.nn as nn | |
3 | | |
4 | from liger_kernel.ops.rms_norm import LigerRMSNormFunction | |
5 | | |
6 | | |
7 | class LigerRMSNorm(nn.Module): | |
8 | def __init__( | |
9 | self, | |
10 | hidden_size, | |
11 | eps=1e-6, | |
12 | offset=0.0, | |
13 | casting_mode="llama", | |
14 | init_fn="ones", | |
15 | in_place=True, | |
16 | ): | |
17 | super().__init__() | |
18 | assert init_fn in [ | |
19 | "ones", | |
20 | "zeros", | |
21 | ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}" | |
22 | self.weight = nn.Parameter( | |
23 | torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size) | |
24 | ) | |
25 | self.variance_epsilon, self.offset, self.casting_mode, self.in_place = ( | |
26 | eps, | |
27 | offset, | |
28 | casting_mode, | |
29 | in_place, | |
30 | ) | |
31 | | |
32 | def forward(self, hidden_states): | |
33 | return LigerRMSNormFunction.apply( | |
34 | hidden_states, | |
35 | self.weight, | |
36 | self.variance_epsilon, | |
37 | self.offset, | |
38 | self.casting_mode, | |
39 | self.in_place, | |
40 | ) | |
41 | | |
42 | def extra_repr(self): | |
43 | return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}" | |
44 | | |
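A minimal sketch of the layer above, assuming a GPU-backed device; `offset`, `casting_mode`, and `in_place` keep their Llama-style defaults.

```python
import torch

from liger_kernel.transformers import LigerRMSNorm
from liger_kernel.utils import infer_device

device = infer_device()

x = torch.randn(2, 128, 4096, device=device)
norm = LigerRMSNorm(hidden_size=4096, eps=1e-6).to(device)
y = norm(x)
```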
-------------------------------------------------------------------------------- | |
/src/liger_kernel/transformers/rope.py: | |
-------------------------------------------------------------------------------- | |
1 | from liger_kernel.ops.rope import LigerRopeFunction | |
2 | | |
3 | | |
4 | def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): | |
5 | """ | |
6 | Applies Rotary Positional Embedding (RoPE) operation to query and key states. | |
7 | | |
8 | Args: | |
9 | q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim). | |
10 | k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim). | |
11 | cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim). | |
12 | sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim). | |
13 | position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None. | |
14 | unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1. | |
15 | | |
16 | Returns: | |
17 | Tuple[torch.Tensor, torch.Tensor]: The query and key tensors after applying the RoPE operation. | |
18 | """ | |
19 | | |
20 | return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim) | |
21 | | |
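A minimal sketch with the shapes from the docstring above, assuming a GPU-backed device; in a real model `cos` and `sin` come from the rotary embedding module, the random tensors here only stand in for them.

```python
import torch

from liger_kernel.transformers import liger_rotary_pos_emb
from liger_kernel.utils import infer_device

device = infer_device()

bsz, n_q_head, n_kv_head, seq_len, head_dim = 2, 32, 8, 128, 64
q = torch.randn(bsz, n_q_head, seq_len, head_dim, device=device)
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device=device)
cos = torch.randn(1, seq_len, head_dim, device=device)  # placeholder values
sin = torch.randn(1, seq_len, head_dim, device=device)  # placeholder values

q_rot, k_rot = liger_rotary_pos_emb(q, k, cos, sin)
```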
-------------------------------------------------------------------------------- | |
/src/liger_kernel/transformers/swiglu.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch.nn as nn | |
2 | | |
3 | from liger_kernel.ops.swiglu import LigerSiLUMulFunction | |
4 | | |
5 | | |
6 | class LigerSwiGLUMLP(nn.Module): | |
7 | def __init__(self, config): | |
8 | super().__init__() | |
9 | self.config = config | |
10 | self.hidden_size = config.hidden_size | |
11 | self.intermediate_size = config.intermediate_size | |
12 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) | |
13 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) | |
14 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) | |
15 | if config.hidden_act not in ["silu", "swish"]: | |
16 | raise ValueError(f"Activation function {config.hidden_act} not supported.") | |
17 | | |
18 | def forward(self, x): | |
19 | | |
20 | return self.down_proj( | |
21 | LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)) | |
22 | ) | |
23 | | |
24 | | |
25 | class LigerBlockSparseTop2MLP(nn.Module): | |
26 | def __init__(self, config): | |
27 | super().__init__() | |
28 | self.ffn_dim = config.intermediate_size | |
29 | self.hidden_dim = config.hidden_size | |
30 | | |
31 | self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) | |
32 | self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) | |
33 | self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) | |
34 | | |
35 | if config.hidden_act not in ["silu", "swish"]: | |
36 | raise ValueError(f"Activation function {config.hidden_act} not supported.") | |
37 | | |
38 | def forward(self, x): | |
39 | | |
40 | return self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x))) | |
41 | | |
42 | | |
43 | class LigerPhi3SwiGLUMLP(nn.Module): | |
44 | """ | |
45 | Patch Phi3MLP to use LigerSiLUMulFunction | |
46 | https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/models/phi3/modeling_phi3.py#L241 | |
47 | """ | |
48 | | |
49 | def __init__(self, config): | |
50 | super().__init__() | |
51 | self.config = config | |
52 | self.hidden_size = config.hidden_size | |
53 | self.intermediate_size = config.intermediate_size | |
54 | self.gate_up_proj = nn.Linear( | |
55 | self.hidden_size, 2 * self.intermediate_size, bias=False | |
56 | ) | |
57 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) | |
58 | if config.hidden_act not in ["silu", "swish"]: | |
59 | raise ValueError(f"Activation function {config.hidden_act} not supported.") | |
60 | | |
61 | def forward(self, x): | |
62 | up_states = self.gate_up_proj(x) | |
63 | gate, up_states = up_states.chunk(2, dim=-1) | |
64 | return self.down_proj(LigerSiLUMulFunction.apply(gate, up_states)) | |
65 | | |
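A minimal sketch of `LigerSwiGLUMLP`, assuming a GPU-backed device; `SimpleNamespace` stands in for a real HF config (e.g. `LlamaConfig`), which works the same way since only `hidden_size`, `intermediate_size`, and `hidden_act` are read.

```python
import torch
from types import SimpleNamespace

from liger_kernel.transformers import LigerSwiGLUMLP
from liger_kernel.utils import infer_device

device = infer_device()

cfg = SimpleNamespace(hidden_size=4096, intermediate_size=11008, hidden_act="silu")
mlp = LigerSwiGLUMLP(cfg).to(device)
y = mlp(torch.randn(2, 128, 4096, device=device))
```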
-------------------------------------------------------------------------------- | |
/src/liger_kernel/transformers/trainer/__init__.py: | |
-------------------------------------------------------------------------------- | |
1 | try: | |
2 | from liger_kernel.transformers.trainer.orpo_trainer import ( # noqa: F401 | |
3 | LigerORPOTrainer, | |
4 | ) | |
5 | except ImportError: | |
6 | raise ImportError("Please `pip install trl` to use LigerORPOTrainer") | |
7 | | |
-------------------------------------------------------------------------------- | |
/src/liger_kernel/transformers/trainer/orpo_trainer.py: | |
-------------------------------------------------------------------------------- | |
1 | from typing import Any, Callable, Dict, List, Literal, Tuple, Union | |
2 | | |
3 | import torch | |
4 | import torch.nn as nn | |
5 | from torch.distributed.fsdp import FullyShardedDataParallel | |
6 | from trl.trainer import ORPOTrainer | |
7 | | |
8 | from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss | |
9 | | |
10 | | |
11 | class _FSDPForwardRedirection: | |
12 | """ | |
13 | Modified based on | |
14 | https://github.com/Lightning-AI/pytorch-lightning/blob/d3f9c83d6efa4f1def36aa6c199600946cdb9117/src/lightning/pytorch/strategies/strategy.py#L601-L648 | |
15 | Redirect a method call through FullyShardedDataParallel.forward so that the FSDP module's root pre-forward and | |
16 | post-forward can be properly executed around the method call. | |
17 | This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only | |
18 | the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving | |
19 | GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`) | |
20 | will not work because the first `nn.Embedding` layer is not independently wrapped as an FSDP module (because of | |
21 | the transformer-based wrapping policy), and not calling it through FSDP root module forward will not all-gather | |
22 | its parameter, thus resulting in "RuntimeError: 'weight' must be 2-D" error. Similarly, if we want to call just | |
23 | the `lm_head` part of a model, we need this trick too to properly get its params all-gathered. | |
24 | """ | |
25 | | |
26 | def __call__( | |
27 | self, | |
28 | wrapper_module: FullyShardedDataParallel, | |
29 | method: Callable, | |
30 | *args: Any, | |
31 | **kwargs: Any, | |
32 | ): | |
33 | """Reroutes a method call through the `wrapper_module`'s `forward` method. | |
34 | Args: | |
35 | wrapper_module: The FSDP-wrapped module whose `forward` method the call is redirected through. | |
36 | method: The callable (typically a method of the module wrapped inside `wrapper_module`) that should | |
37 | be invoked after inputs get redirected through the `wrapper_module`'s `forward` method. | |
38 | *args: The positional arguments to `method`. They will get passed to a patched | |
39 | `forward` method instead. | |
40 | **kwargs: The keyword arguments to `method`. They will get passed to a patched | |
41 | `forward` method instead. | |
42 | | |
43 | """ | |
44 | assert isinstance(wrapper_module, FullyShardedDataParallel) | |
45 | original_module = wrapper_module._fsdp_wrapped_module | |
46 | original_forward = original_module.forward | |
47 | | |
48 | def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any: | |
49 | # Unpatch ourselves immediately before calling the method `method_name` | |
50 | # because itself may want to call the real `forward` | |
51 | original_module.forward = original_forward # type: ignore[method-assign] | |
52 | # Call the actual method e.g. `.training_step(...)` | |
53 | out = method(*_args, **_kwargs) | |
54 | return out | |
55 | | |
56 | # Patch the original_module's forward so we can redirect the arguments back to the real method | |
57 | original_module.forward = wrapped_forward # type: ignore[method-assign] | |
58 | wrapper_output = wrapper_module(*args, **kwargs) | |
59 | return wrapper_output | |
60 | | |
61 | | |
62 | class LigerORPOTrainer(ORPOTrainer): | |
63 | def concatenated_forward( | |
64 | self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] | |
65 | ) -> Tuple[ | |
66 | torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor | |
67 | ]: | |
68 | """ | |
69 | Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. | |
70 | We do this to avoid doing two forward passes, because it's faster for FSDP. | |
71 | """ | |
72 | concatenated_batch = self.concatenated_inputs( | |
73 | batch, | |
74 | is_encoder_decoder=self.is_encoder_decoder, | |
75 | label_pad_token_id=self.label_pad_token_id, | |
76 | padding_value=self.padding_value, | |
77 | device=self.accelerator.device, | |
78 | ) | |
79 | | |
80 | model_kwargs = ( | |
81 | { | |
82 | "decoder_input_ids": self._shift_right( | |
83 | concatenated_batch["concatenated_labels"] | |
84 | ), | |
85 | } | |
86 | if self.is_encoder_decoder | |
87 | else {} | |
88 | ) | |
89 | | |
90 | if self.aux_loss_enabled: | |
91 | model_kwargs["output_router_logits"] = True | |
92 | | |
93 | if isinstance(model, FullyShardedDataParallel): | |
94 | outputs = _FSDPForwardRedirection()( | |
95 | model, | |
96 | model._fsdp_wrapped_module.model, | |
97 | concatenated_batch["concatenated_input_ids"], | |
98 | attention_mask=concatenated_batch["concatenated_attention_mask"], | |
99 | use_cache=False, | |
100 | **model_kwargs, | |
101 | ) | |
102 | else: | |
103 | if isinstance(model, torch.nn.DataParallel): | |
104 | model = model.module | |
105 | outputs = model.model( | |
106 | concatenated_batch["concatenated_input_ids"], | |
107 | attention_mask=concatenated_batch["concatenated_attention_mask"], | |
108 | use_cache=False, | |
109 | **model_kwargs, | |
110 | ) | |
111 | | |
112 | orpo_loss_fn = LigerFusedLinearORPOLoss( | |
113 | ignore_index=self.label_pad_token_id, beta=self.beta | |
114 | ) | |
115 | | |
116 | def orpo_partial(lm_head, last_hidden_state, concatenated_labels): | |
117 | return orpo_loss_fn( | |
118 | lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias | |
119 | ) | |
120 | | |
121 | orpo_loss, aux_outputs = _FSDPForwardRedirection()( | |
122 | model, | |
123 | orpo_partial, | |
124 | model.lm_head, | |
125 | outputs.last_hidden_state, | |
126 | concatenated_batch["concatenated_labels"], | |
127 | ) | |
128 | return orpo_loss, aux_outputs | |
129 | | |
130 | def get_batch_loss_metrics( | |
131 | self, | |
132 | model, | |
133 | batch: Dict[str, Union[List, torch.LongTensor]], | |
134 | train_eval: Literal["train", "eval"] = "train", | |
135 | ): | |
136 | """Compute the ORPO loss and other metrics for the given batch of inputs for train or test.""" | |
137 | metrics = {} | |
138 | loss, aux_outputs = self.concatenated_forward(model, batch) | |
139 | ( | |
140 | policy_chosen_logps, | |
141 | policy_rejected_logps, | |
142 | policy_chosen_logits, | |
143 | policy_rejected_logits, | |
144 | policy_nll_loss, | |
145 | ) = aux_outputs[:5] | |
146 | | |
147 | # return loss, metrics | |
148 | chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = aux_outputs[ | |
149 | 5: | |
150 | ] | |
151 | | |
152 | reward_accuracies = (chosen_rewards > rejected_rewards).float() | |
153 | | |
154 | prefix = "eval_" if train_eval == "eval" else "" | |
155 | metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean() | |
156 | metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean() | |
157 | metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean() | |
158 | metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean() | |
159 | metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean() | |
160 | metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean() | |
161 | metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean() | |
162 | metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean() | |
163 | metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean() | |
164 | metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio | |
165 | metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen | |
166 | for k, v in metrics.items(): | |
167 | metrics[k] = v.item() | |
168 | | |
169 | return loss, metrics | |
170 | | |
-------------------------------------------------------------------------------- | |
/src/liger_kernel/transformers/trainer_integration.py: | |
-------------------------------------------------------------------------------- | |
1 | # To not break HF Trainer integration | |
2 | from liger_kernel.transformers.monkey_patch import _apply_liger_kernel # noqa: F401 | |
3 | | |
-------------------------------------------------------------------------------- | |
/src/liger_kernel/triton/__init__.py: | |
-------------------------------------------------------------------------------- | |
1 | from liger_kernel.triton.monkey_patch import ( # noqa: F401 | |
2 | apply_liger_triton_cache_manager, | |
3 | ) | |
4 | | |
-------------------------------------------------------------------------------- | |
/src/liger_kernel/triton/monkey_patch.py: | |
-------------------------------------------------------------------------------- | |
1 | import os | |
2 | import random | |
3 | | |
4 | from triton.runtime.cache import FileCacheManager | |
5 | | |
6 | | |
7 | class LigerTritonFileCacheManager(FileCacheManager): | |
8 | def put(self, data, filename, binary=True) -> str: | |
9 | if not self.cache_dir: | |
10 | raise RuntimeError("Could not create or locate cache dir") | |
11 | binary = isinstance(data, bytes) | |
12 | if not binary: | |
13 | data = str(data) | |
14 | assert self.lock_path is not None | |
15 | filepath = self._make_path(filename) | |
16 | # Random ID to avoid any collisions | |
17 | rnd_id = random.randint(0, 1000000) | |
18 | # we use the PID in case several of these exist at once, so we can see which PID made it | |
19 | pid = os.getpid() | |
20 | # use temp dir to be robust against program interruptions | |
21 | temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}") | |
22 | os.makedirs(temp_dir, exist_ok=True) | |
23 | temp_path = os.path.join(temp_dir, filename) | |
24 | | |
25 | mode = "wb" if binary else "w" | |
26 | with open(temp_path, mode) as f: | |
27 | f.write(data) | |
28 | # Replace is guaranteed to be atomic on POSIX systems if it succeeds | |
29 | # so filepath cannot see a partial write | |
30 | os.replace(temp_path, filepath) | |
31 | os.removedirs(temp_dir) | |
32 | return filepath | |
33 | | |
34 | | |
35 | def apply_liger_triton_cache_manager(): | |
36 | """ | |
37 | Experimental feature to get around transient FileNotFoundError in triton compilation. | |
38 | For more details please see https://github.com/triton-lang/triton/pull/4295 | |
39 | """ | |
40 | os.environ["TRITON_CACHE_MANAGER"] = ( | |
41 | "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager" | |
42 | ) | |
43 | | |
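A minimal sketch of enabling the cache manager above; it only sets the `TRITON_CACHE_MANAGER` environment variable, so it has to run before any Triton kernel is compiled.

```python
from liger_kernel.triton import apply_liger_triton_cache_manager

# Point Triton at the collision-robust file cache manager, then proceed
# with training / kernel launches as usual.
apply_liger_triton_cache_manager()
```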
-------------------------------------------------------------------------------- | |
/src/liger_kernel/utils.py: | |
-------------------------------------------------------------------------------- | |
1 | import torch | |
2 | | |
3 | | |
4 | def infer_device(): | |
5 | """ | |
6 | Get current device name based on available devices | |
7 | """ | |
8 | if torch.cuda.is_available(): | |
9 | return "cuda" | |
10 | elif torch.xpu.is_available(): | |
11 | return "xpu" | |
12 | else: | |
13 | return "cpu" | |
14 | | |
-------------------------------------------------------------------------------- | |
/test/__init__.py: | |
-------------------------------------------------------------------------------- | |
https://raw.githubusercontent.com/linkedin/Liger-Kernel/0bb6c72d03f0fc570b9e954c42b723ab7e0c315f/test/__init__.py | |
-------------------------------------------------------------------------------- | |
/test/chunked_loss/__init__.py: | |
-------------------------------------------------------------------------------- | |
https://raw.githubusercontent.com/linkedin/Liger-Kernel/0bb6c72d03f0fc570b9e954c42b723ab7e0c315f/test/chunked_loss/__init__.py | |
-------------------------------------------------------------------------------- | |
/test/chunked_loss/test_simpo_loss.py: | |
-------------------------------------------------------------------------------- | |
1 | from test.chunked_loss.test_cpo_loss import TorchLMHeadCPO | |
2 | from test.utils import assert_verbose_allclose, set_seed | |
3 | | |
4 | import pytest | |
5 | import torch | |
6 | | |
7 | from liger_kernel.chunked_loss import LigerFusedLinearSimPOLoss | |
8 | from liger_kernel.chunked_loss.functional import liger_fused_linear_simpo | |
9 | from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction | |
10 | from liger_kernel.utils import infer_device | |
11 | | |
12 | device = infer_device() | |
13 | | |
14 | # set random seed globally | |
15 | set_seed() | |
16 | | |
17 | | |
18 | class LigerLMHeadSimPO(torch.nn.Module): | |
19 | def __init__( | |
20 | self, | |
21 | H: int, | |
22 | V: int, | |
23 | dtype: torch.dtype, | |
24 | bias: bool = False, | |
25 | ignore_index: int = -100, | |
26 | beta: float = 0.1, | |
27 | alpha: float = 1.0, | |
28 | gamma: float = 0.5, | |
29 | ): | |
30 | super().__init__() | |
31 | self.lin = torch.nn.Linear( | |
32 | in_features=H, out_features=V, bias=bias, dtype=dtype | |
33 | ) | |
34 | self.simpo_loss = LigerFusedLinearSimPOLoss( | |
35 | ignore_index=ignore_index, beta=beta, alpha=alpha, gamma=gamma | |
36 | ) | |
37 | | |
38 | def forward(self, x, y): | |
39 | return self.simpo_loss(self.lin.weight, x, y, self.lin.bias) | |
40 | | |
41 | | |
42 | @pytest.mark.parametrize( | |
43 | "B, T, H, V", | |
44 | [ | |
45 | (8, 128, 1024, 4096), | |
46 | (3, 47, 31, 123), # random shape | |
47 | ], | |
48 | ) | |
49 | @pytest.mark.parametrize( | |
50 | "scalar, dtype, atol, rtol", | |
51 | [ | |
52 | (1.0, torch.bfloat16, 5e-3, 5e-3), | |
53 | (1.0, torch.float32, 1e-5, 5e-4), | |
54 | ], | |
55 | ) | |
56 | @pytest.mark.parametrize("bias", [True, False]) | |
57 | @pytest.mark.parametrize( | |
58 | "ignore_index, beta, gamma", [(-100, 0.1, 0.5), (42, 0.2, 0.85)] | |
59 | ) | |
60 | def test_correctness( | |
61 | B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, gamma | |
62 | ): | |
63 | B = 2 * B # SimPO loss requires B to be even | |
64 | | |
65 | torch_lm_head_simpo = TorchLMHeadCPO( | |
66 | H=H, | |
67 | V=V, | |
68 | dtype=dtype, | |
69 | bias=bias, | |
70 | ignore_index=ignore_index, | |
71 | beta=beta, | |
72 | loss_type="simpo", | |
73 | simpo_gamma=gamma, | |
74 | ) | |
75 | liger_lm_head_simpo = LigerLMHeadSimPO( | |
76 | H=H, | |
77 | V=V, | |
78 | dtype=dtype, | |
79 | bias=bias, | |
80 | ignore_index=ignore_index, | |
81 | beta=beta, | |
82 | gamma=gamma, | |
83 | ) | |
84 | | |
85 | torch_lm_head_simpo.lin.weight.data = liger_lm_head_simpo.lin.weight.data = ( | |
86 | torch.randn(V, H, device=device, dtype=dtype) | |
87 | ) | |
88 | | |
89 | if bias: | |
90 | torch_lm_head_simpo.lin.bias.data = liger_lm_head_simpo.lin.bias.data = ( | |
91 | torch.randn(V, device=device, dtype=dtype) | |
92 | ) | |
93 | | |
94 | _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar | |
95 | input1 = _input.detach().clone().requires_grad_(True) | |
96 | input2 = _input.detach().clone().requires_grad_(True) | |
97 | | |
98 | target = torch.randint( | |
99 | 0, | |
100 | V, | |
101 | ( | |
102 | B, | |
103 | T, | |
104 | ), | |
105 | device=device, | |
106 | dtype=torch.long, | |
107 | ) | |
108 | # Assign ignore_index to a random number of target elements | |
109 | num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() | |
110 | indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] | |
111 | target.view(-1)[indices_to_assign] = ignore_index | |
112 | | |
113 | loss1, aggregated_aux_outputs1 = torch_lm_head_simpo(input1, target) | |
114 | loss2, aggregated_aux_outputs2 = liger_lm_head_simpo(input2, target) | |
115 | | |
116 | assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) | |
117 | | |
118 | assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) | |
119 | | |
120 | for i in range(len(aggregated_aux_outputs1)): | |
121 | assert_verbose_allclose( | |
122 | aggregated_aux_outputs1[i], | |
123 | aggregated_aux_outputs2[i], | |
124 | atol=atol, | |
125 | rtol=rtol, | |
126 | ) | |
127 | | |
128 | loss1.backward() | |
129 | loss2.backward() | |
130 | | |
131 | assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) | |
132 | assert_verbose_allclose( | |
133 | torch_lm_head_simpo.lin.weight.grad, | |
134 | liger_lm_head_simpo.lin.weight.grad, | |
135 | atol=atol, | |
136 | rtol=rtol, | |
137 | ) | |
138 | if bias: | |
139 | assert_verbose_allclose( | |
140 | torch_lm_head_simpo.lin.bias.grad, | |
141 | liger_lm_head_simpo.lin.bias.grad, | |
142 | atol=atol, | |
143 | rtol=rtol, | |
144 | ) | |
145 | | |
146 | | |
147 | @pytest.mark.parametrize( | |
148 | "B, T, H, V", | |
149 | [ | |
150 | (2, 2, 8, 8), | |
151 | (3, 47, 31, 123), # random shape | |
152 | ], | |
153 | ) | |
154 | @pytest.mark.parametrize( | |
155 | "scalar, dtype, atol, rtol", | |
156 | [ | |
157 | (1.0, torch.bfloat16, 5e-2, 5e-1), | |
158 | (1.0, torch.float32, 1e-5, 5e-4), | |
159 | ], | |
160 | ) | |
161 | @pytest.mark.parametrize("bias", [True, False]) | |
162 | def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias): | |
163 | B = 2 * B | |
164 | | |
165 | _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar | |
166 | input1 = _input.detach().clone().requires_grad_(True) | |
167 | input2 = _input.detach().clone().requires_grad_(True) | |
168 | | |
169 | target = torch.randint( | |
170 | 0, | |
171 | V, | |
172 | ( | |
173 | B, | |
174 | T, | |
175 | ), | |
176 | device=device, | |
177 | dtype=torch.long, | |
178 | ) | |
179 | | |
180 | _weight = torch.randn(V, H, device=device, dtype=dtype) | |
181 | weight1 = _weight.detach().clone().requires_grad_(True) | |
182 | weight2 = _weight.detach().clone().requires_grad_(True) | |
183 | | |
184 | _bias = torch.randn(V, device=device, dtype=dtype) if bias else None | |
185 | bias1 = _bias.detach().clone().requires_grad_(True) if bias else None | |
186 | bias2 = _bias.detach().clone().requires_grad_(True) if bias else None | |
187 | | |
188 | loss1, aggregated_aux_outputs1 = LigerFusedLinearSimPOFunction.apply( | |
189 | input1, weight1, target, bias1 | |
190 | ) | |
191 | loss2, aggregated_aux_outputs2 = liger_fused_linear_simpo( | |
192 | input2, weight2, target, bias2 | |
193 | ) | |
194 | | |
195 | assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) | |
196 | | |
197 | loss1.backward() | |
198 | loss2.backward() | |
199 | | |
200 | assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) | |
201 | assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol) | |
202 | if bias: | |
203 | assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol) | |
204 | | |
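The `B = 2 * B` doubling above reflects the preference-loss layout: the batch is assumed to hold chosen and rejected completions concatenated along the batch dimension, so it must be even. A minimal standalone sketch of the module-level API these tests exercise (the shapes and the chosen/rejected split are illustrative assumptions; the Triton kernels expect a GPU/XPU device):

    import torch
    from liger_kernel.chunked_loss import LigerFusedLinearSimPOLoss
    from liger_kernel.utils import infer_device

    device = infer_device()
    B, T, H, V = 4, 16, 64, 128                  # B even: chosen half + rejected half (assumed layout)
    lin = torch.nn.Linear(H, V, bias=False, device=device)
    loss_fn = LigerFusedLinearSimPOLoss(beta=0.1, gamma=0.5)

    x = torch.randn(B, T, H, device=device, requires_grad=True)
    y = torch.randint(0, V, (B, T), device=device)

    # Returns (loss, aggregated_aux_outputs), matching the tests above.
    loss, aux = loss_fn(lin.weight, x, y, lin.bias)
    loss.backward()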
-------------------------------------------------------------------------------- | |
/test/conftest.py: | |
-------------------------------------------------------------------------------- | |
1 | import pytest | |
2 | import torch | |
3 | | |
4 | | |
5 | @pytest.fixture(autouse=True) | |
6 | def clear_cuda_cache(): | |
7 | yield | |
8 | torch.cuda.empty_cache() | |
9 | | |
-------------------------------------------------------------------------------- | |
/test/convergence/__init__.py: | |
-------------------------------------------------------------------------------- | |
https://raw.githubusercontent.com/linkedin/Liger-Kernel/0bb6c72d03f0fc570b9e954c42b723ab7e0c315f/test/convergence/__init__.py | |
-------------------------------------------------------------------------------- | |
/test/resources/scripts/generate_tokenized_dataset.py: | |
-------------------------------------------------------------------------------- | |
1 | import argparse | |
2 | | |
3 | from datasets import load_dataset | |
4 | from transformers import AutoTokenizer | |
5 | | |
6 | | |
7 | def prepare_dataset(tokenizer, text_file_path: str): | |
8 | """ | |
9 | Tokenizes a text file where each line is a different example. | |
10 | Padding is applied to each example. | |
11 | """ | |
12 | # Each line is a different example | |
13 | dataset = load_dataset("text", data_files={"train": text_file_path}) | |
14 | | |
15 | def tokenize_function(examples): | |
16 | return tokenizer( | |
17 | examples["text"], padding="max_length", truncation=True, max_length=128 | |
18 | ) | |
19 | | |
20 | tokenized_dataset = dataset.map( | |
21 | tokenize_function, batched=True, remove_columns=["text"] | |
22 | ) | |
23 | return tokenized_dataset["train"] | |
24 | | |
25 | | |
26 | def generate_tokenized_dataset( | |
27 | tokenizer_path: str, text_file_path: str, output_dir: str | |
28 | ) -> None: | |
29 | """ | |
30 | Generate tokenized dataset from a text file, where each line is a different example. | |
31 | | |
32 | Args: | |
33 | tokenizer_path (str): Path to the directory containing the tokenizer files. | |
34 | text_file_path (str): Path to the text file to tokenize. | |
35 | output_dir (str): Directory where the tokenized dataset will be saved | |
36 | """ | |
37 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) | |
38 | tokenizer.pad_token = tokenizer.eos_token | |
39 | | |
40 | train_dataset = prepare_dataset(tokenizer, text_file_path) | |
41 | train_dataset.save_to_disk(output_dir) | |
42 | | |
43 | | |
44 | if __name__ == "__main__": | |
45 | # Example usage: | |
46 | # python generate_tokenized_dataset.py --tokenizer_path /shared/public/models/Mistral-7B --text_file_path ./../../resources/tiny_shakespeare.txt --output_dir ./../../resources/tiny_shakespeare_tokenized | |
47 | parser = argparse.ArgumentParser( | |
48 | description="Generate tokenized dataset from a text file." | |
49 | ) | |
50 | | |
51 | # Add arguments | |
52 | parser.add_argument( | |
53 | "--tokenizer_path", | |
54 | type=str, | |
55 | required=True, | |
56 | help="Path to the directory containing the tokenizer files.", | |
57 | ) | |
58 | parser.add_argument( | |
59 | "--text_file_path", | |
60 | type=str, | |
61 | required=True, | |
62 | help="Path to the text file to tokenize.", | |
63 | ) | |
64 | parser.add_argument( | |
65 | "--output_dir", | |
66 | type=str, | |
67 | required=True, | |
68 | help="Directory where the tokenized dataset will be saved.", | |
69 | ) | |
70 | | |
71 | # Parse the arguments | |
72 | args = parser.parse_args() | |
73 | | |
74 | # Call the function with parsed arguments | |
75 | generate_tokenized_dataset( | |
76 | tokenizer_path=args.tokenizer_path, | |
77 | text_file_path=args.text_file_path, | |
78 | output_dir=args.output_dir, | |
79 | ) | |
80 | | |
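A short round-trip sketch for consuming the saved output with the standard `datasets` API (the path is the illustrative one from the usage comment above):

    from datasets import load_from_disk

    # Load the dataset written by generate_tokenized_dataset.
    tokenized = load_from_disk("./../../resources/tiny_shakespeare_tokenized")
    print(tokenized.column_names)                # e.g. input_ids, attention_mask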
-------------------------------------------------------------------------------- | |
/test/transformers/test_auto_model.py: | |
-------------------------------------------------------------------------------- | |
1 | from inspect import signature | |
2 | from unittest import mock | |
3 | from unittest.mock import MagicMock, patch | |
4 | | |
5 | from transformers import AutoConfig, AutoModelForCausalLM | |
6 | | |
7 | from liger_kernel.transformers import AutoLigerKernelForCausalLM | |
8 | from liger_kernel.transformers.monkey_patch import ( | |
9 | MODEL_TYPE_TO_APPLY_LIGER_FN, | |
10 | apply_liger_kernel_to_llama, | |
11 | ) | |
12 | | |
13 | | |
14 | def test_auto_liger_kernel_for_causal_lm_from_pretrained(): | |
15 | pretrained_model_name_or_path = "/path/to/llama/model" | |
16 | model_args = ("model_arg1", "model_arg2") | |
17 | | |
18 | original_kwargs = { | |
19 | "valid_arg_1": "some_value_1", | |
20 | "valid_arg_2": 10, | |
21 | } | |
22 | | |
23 | # These args should be passed through to apply_liger_kernel_to_llama fn | |
24 | apply_liger_kernel_kwargs = { | |
25 | "rope": False, | |
26 | "swiglu": True, | |
27 | } | |
28 | | |
29 | kwargs = {**original_kwargs, **apply_liger_kernel_kwargs} | |
30 | | |
31 | # Mock the model config instance returned from AutoConfig.from_pretrained() | |
32 | mock_model_config = MagicMock() | |
33 | mock_model_config.model_type = "llama" | |
34 | mock_llama = mock.Mock() | |
35 | | |
36 | with patch.dict( | |
37 | MODEL_TYPE_TO_APPLY_LIGER_FN, {"llama": mock_llama} | |
38 | ), mock.patch.object( | |
39 | AutoConfig, "from_pretrained", return_value=mock_model_config | |
40 | ), mock.patch.object( | |
41 | AutoModelForCausalLM, "from_pretrained", return_value="mock_model" | |
42 | ) as mock_super_from_pretrained: | |
43 | | |
44 | # Mock the function signature of apply_liger_kernel_to_llama | |
45 | mock_llama.__signature__ = signature(apply_liger_kernel_to_llama) | |
46 | | |
47 | model = AutoLigerKernelForCausalLM.from_pretrained( | |
48 | pretrained_model_name_or_path, *model_args, **kwargs | |
49 | ) | |
50 | | |
51 | # Check that the apply_liger_kernel_to_llama mock was called with the correct kwargs | |
52 | mock_llama.assert_called_once_with(rope=False, swiglu=True) | |
53 | # Check that the original kwargs are passed to super().from_pretrained | |
54 | mock_super_from_pretrained.assert_called_once_with( | |
55 | pretrained_model_name_or_path, *model_args, **original_kwargs | |
56 | ) | |
57 | assert model == "mock_model" | |
58 | | |
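For orientation, the user-facing call this test mocks out looks roughly like the sketch below; the checkpoint path is a placeholder, and `torch_dtype` stands in for any ordinary `from_pretrained` kwarg, which is assumed to pass through untouched.

    import torch
    from liger_kernel.transformers import AutoLigerKernelForCausalLM

    # Kernel flags (rope, swiglu, ...) are matched against the model-type-specific
    # apply_liger_kernel_* signature; the remaining kwargs go to the regular
    # transformers from_pretrained call.
    model = AutoLigerKernelForCausalLM.from_pretrained(
        "/path/to/llama/model",                  # placeholder path, as in the test
        rope=False,
        swiglu=True,
        torch_dtype=torch.bfloat16,
    )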
-------------------------------------------------------------------------------- | |
/test/transformers/test_embedding.py: | |
-------------------------------------------------------------------------------- | |
1 | import pytest | |
2 | import torch | |
3 | from torch.nn import Embedding | |
4 | | |
5 | from liger_kernel.transformers.experimental.embedding import LigerEmbedding | |
6 | from liger_kernel.utils import infer_device | |
7 | | |
8 | device = infer_device() | |
9 | | |
10 | SLEEP_SECONDS = 0.1 | |
11 | | |
12 | | |
13 | @pytest.mark.skip(reason="LigerEmbedding is under experimentation") | |
14 | @pytest.mark.parametrize( | |
15 | "num_embeddings, embedding_dim, padding_idx", | |
16 | [ | |
17 | (100, 64, None), | |
18 | (100, 64, None), | |
19 | (1000, 128, None), | |
20 | (100, 60, None), | |
21 | (100, 60, None), | |
22 | (1000, 120, None), | |
23 | (1000, 500, None), | |
24 | (30522, 768, None), | |
25 | (100, 64, 0), | |
26 | (1000, 128, 50), | |
27 | (30522, 768, 1), | |
28 | ], | |
29 | ) | |
30 | @pytest.mark.parametrize( | |
31 | "dtype, atol, rtol, device", | |
32 | [ | |
33 | (torch.float32, 1e-6, 1e-5, device), | |
34 | ], | |
35 | ) | |
36 | def test_embedding_correctness( | |
37 | num_embeddings, embedding_dim, padding_idx, dtype, atol, rtol, device | |
38 | ): | |
39 | print( | |
40 | f"\nTesting embedding with size: ({num_embeddings}, {embedding_dim}), padding_idx: {padding_idx}" | |
41 | ) | |
42 | torch.manual_seed(42) | |
43 | | |
44 | torch_embedding = ( | |
45 | Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) | |
46 | .to(dtype) | |
47 | .to(device) | |
48 | ) | |
49 | liger_embedding = ( | |
50 | LigerEmbedding(num_embeddings, embedding_dim, padding_idx=padding_idx) | |
51 | .to(dtype) | |
52 | .to(device) | |
53 | ) | |
54 | liger_embedding.weight.data.copy_(torch_embedding.weight.data) | |
55 | | |
56 | if padding_idx is not None: | |
57 | input_ids = torch.randint(0, num_embeddings, (32 * 10,), device=device) | |
58 | input_ids[torch.randint(0, 32 * 10, (32 * 10 // 10,))] = padding_idx | |
59 | else: | |
60 | input_ids = torch.randint(0, num_embeddings, (32 * 10,), device=device) | |
61 | | |
62 | torch_output = torch_embedding(input_ids).view(32, 10, -1) | |
63 | liger_output = liger_embedding(input_ids).view(32, 10, -1) | |
64 | | |
65 | assert torch.allclose(torch_output, liger_output, atol=atol, rtol=rtol) | |
66 | | |
67 | grad_output = torch.randn_like(torch_output) | |
68 | | |
69 | torch_output.backward(grad_output) | |
70 | liger_output.backward(grad_output) | |
71 | | |
72 | assert torch.allclose( | |
73 | torch_embedding.weight.grad, liger_embedding.weight.grad, atol=atol, rtol=rtol | |
74 | ) | |
75 | | |
-------------------------------------------------------------------------------- | |
/test/transformers/test_geglu.py: | |
-------------------------------------------------------------------------------- | |
1 | from test.utils import supports_bfloat16 | |
2 | | |
3 | import pytest | |
4 | import torch | |
5 | from transformers.models.llama.configuration_llama import LlamaConfig | |
6 | from transformers.models.llama.modeling_llama import LlamaMLP | |
7 | | |
8 | from liger_kernel.ops.geglu import LigerGELUMulFunction | |
9 | from liger_kernel.transformers.functional import liger_geglu | |
10 | from liger_kernel.transformers.geglu import LigerGEGLUMLP | |
11 | from liger_kernel.utils import infer_device | |
12 | | |
13 | device = infer_device() | |
14 | | |
15 | LLAMA_CONFIG = LlamaConfig( | |
16 | hidden_size=4096, | |
17 | intermediate_size=11008, | |
18 | hidden_act="gelu_pytorch_tanh", | |
19 | ) | |
20 | SLEEP_SECONDS = 0.1 | |
21 | | |
22 | | |
23 | @pytest.mark.parametrize( | |
24 | "bsz, seq_len, hidden_size, intermediate_size", | |
25 | [ | |
26 | (2, 2048, 2048, 4096), | |
27 | # weird shapes | |
28 | (9, 41, 341, 4231), | |
29 | ], | |
30 | ) | |
31 | @pytest.mark.parametrize( | |
32 | "dtype, atol, rtol", | |
33 | [ | |
34 | # atol is for small values: they differ more, so set atol higher | |
35 | # rtol is for larger values: they are very close, so set rtol lower | |
36 | (torch.float32, 1e-0, 2e-6), | |
37 | pytest.param( | |
38 | torch.bfloat16, | |
39 | 1e4, | |
40 | 6e-3, | |
41 | marks=pytest.mark.skipif( | |
42 | not supports_bfloat16(), reason="bfloat16 not supported on this GPU" | |
43 | ), | |
44 | ), | |
45 | ], | |
46 | ) | |
47 | def test_correctness(bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol): | |
48 | _input = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype) | |
49 | | |
50 | x1 = _input.clone().requires_grad_(True) | |
51 | x2 = _input.clone().requires_grad_(True) | |
52 | | |
53 | # initialize weights | |
54 | G = torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype) | |
55 | U = torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype) | |
56 | D = torch.randn(intermediate_size, hidden_size, device=device, dtype=dtype) | |
57 | | |
58 | llama_mlp = LlamaMLP(config=LLAMA_CONFIG).to(device).to(dtype) | |
59 | llama_mlp.gate_proj.weight.data = G.T | |
60 | llama_mlp.up_proj.weight.data = U.T | |
61 | llama_mlp.down_proj.weight.data = D.T | |
62 | | |
63 | liger_mlp = LigerGEGLUMLP(config=LLAMA_CONFIG).to(device).to(dtype) | |
64 | liger_mlp.gate_proj.weight.data = G.T | |
65 | liger_mlp.up_proj.weight.data = U.T | |
66 | liger_mlp.down_proj.weight.data = D.T | |
67 | | |
68 | y1 = llama_mlp(x1) | |
69 | y2 = liger_mlp(x2) | |
70 | | |
71 | assert torch.allclose(y1, y2, atol=atol, rtol=rtol) is True | |
72 | | |
73 | dy = torch.randn_like(y1) | |
74 | | |
75 | y1.backward(dy.clone(), retain_graph=True) | |
76 | y2.backward(dy.clone(), retain_graph=True) | |
77 | | |
78 | assert ( | |
79 | torch.allclose( | |
80 | llama_mlp.gate_proj.weight.grad, | |
81 | liger_mlp.gate_proj.weight.grad, | |
82 | atol=atol, | |
83 | rtol=rtol, | |
84 | ) | |
85 | is True | |
86 | ) | |
87 | assert ( | |
88 | torch.allclose( | |
89 | llama_mlp.up_proj.weight.grad, | |
90 | liger_mlp.up_proj.weight.grad, | |
91 | atol=atol, | |
92 | rtol=rtol, | |
93 | ) | |
94 | is True | |
95 | ) | |
96 | assert ( | |
97 | torch.allclose( | |
98 | llama_mlp.down_proj.weight.grad, | |
99 | liger_mlp.down_proj.weight.grad, | |
100 | atol=atol, | |
101 | rtol=rtol, | |
102 | ) | |
103 | is True | |
104 | ) | |
105 | | |
106 | assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) is True | |
107 | | |
108 | | |
109 | @pytest.mark.parametrize( | |
110 | "bsz, seq_len, size", | |
111 | [ | |
112 | (2, 2, 8), | |
113 | # weird shapes | |
114 | (9, 7, 41), | |
115 | ], | |
116 | ) | |
117 | @pytest.mark.parametrize( | |
118 | "dtype, atol, rtol", | |
119 | [ | |
120 | # atol is for small values: they differ more, so set atol higher | |
121 | # rtol is for larger values: they are very close, so set rtol lower | |
122 | (torch.float32, 1e-0, 2e-6), | |
123 | (torch.bfloat16, 1e4, 6e-3), | |
124 | ], | |
125 | ) | |
126 | def test_correctness_functional(bsz, seq_len, size, dtype, atol, rtol): | |
127 | _input = torch.randn(bsz, seq_len, size, device=device, dtype=dtype) | |
128 | _b = torch.randn(bsz, seq_len, size, device=device, dtype=dtype) | |
129 | | |
130 | x1 = _input.clone().requires_grad_(True) | |
131 | x2 = _input.clone().requires_grad_(True) | |
132 | | |
133 | b1 = _b.clone().requires_grad_(True) | |
134 | b2 = _b.clone().requires_grad_(True) | |
135 | | |
136 | y1 = liger_geglu(a=x1, b=b1) | |
137 | y2 = LigerGELUMulFunction.apply(x2, b2) | |
138 | | |
139 | assert torch.allclose(y1, y2, atol=atol, rtol=rtol) | |
140 | | |
141 | grad_output = torch.randn_like(y1) | |
142 | | |
143 | y1.backward(grad_output) | |
144 | y2.backward(grad_output) | |
145 | | |
146 | assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) | |
147 | assert torch.allclose(b1.grad, b2.grad, atol=atol, rtol=rtol) | |
148 | | |
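As the setup above suggests, `LigerGEGLUMLP` is exercised as a drop-in for `LlamaMLP` when the config uses the tanh-approximated GELU; a minimal sketch under that assumption (the underlying Triton kernel expects a GPU/XPU device):

    import torch
    from transformers.models.llama.configuration_llama import LlamaConfig
    from liger_kernel.transformers.geglu import LigerGEGLUMLP
    from liger_kernel.utils import infer_device

    device = infer_device()
    cfg = LlamaConfig(
        hidden_size=256, intermediate_size=512, hidden_act="gelu_pytorch_tanh"
    )
    mlp = LigerGEGLUMLP(config=cfg).to(device)
    out = mlp(torch.randn(2, 16, cfg.hidden_size, device=device))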
-------------------------------------------------------------------------------- | |
/test/transformers/test_group_norm.py: | |
-------------------------------------------------------------------------------- | |
1 | import random | |
2 | | |
3 | import pytest | |
4 | import torch | |
5 | | |
6 | from liger_kernel.transformers.group_norm import LigerGroupNorm | |
7 | from liger_kernel.utils import infer_device | |
8 | | |
9 | device = infer_device() | |
10 | | |
11 | random_batch_size = random.randint(1, 16) | |
12 | random_num_groups = random.randint(1, 32) | |
13 | random_num_channels = random_num_groups * random.randint(1, 16) | |
14 | random_hidden_size = random.randint(1, 8192) | |
15 | | |
16 | | |
17 | @pytest.mark.parametrize( | |
18 | "batch_size, num_channels, num_groups, hidden_size", | |
19 | [ | |
20 | (1, 1, 1, 3), | |
21 | (1, 4, 2, 4), | |
22 | (16, 12, 3, 4096), | |
23 | (random_batch_size, random_num_channels, random_num_groups, random_hidden_size), | |
24 | ], | |
25 | ) | |
26 | @pytest.mark.parametrize( | |
27 | "dtype, atol, rtol", | |
28 | [ | |
29 | (torch.float32, 1e-4, 1e-4), | |
30 | ], | |
31 | ) | |
32 | def test_liger_group_norm( | |
33 | batch_size, num_channels, num_groups, hidden_size, dtype, atol, rtol | |
34 | ): | |
35 | torch.manual_seed(0) | |
36 | | |
37 | _tensor = torch.randn( | |
38 | batch_size, num_channels, hidden_size, dtype=dtype, device=device | |
39 | ) | |
40 | | |
41 | liger_x = _tensor.clone().detach().requires_grad_(True) | |
42 | torch_x = _tensor.clone().detach().requires_grad_(True) | |
43 | | |
44 | liger_ln = LigerGroupNorm(num_channels, num_groups, eps=1e-6).to(dtype).to(device) | |
45 | torch_ln = ( | |
46 | torch.nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, eps=1e-6) | |
47 | .to(dtype) | |
48 | .to(device) | |
49 | ) | |
50 | | |
51 | with torch.no_grad(): | |
52 | torch_ln.weight.copy_(liger_ln.weight) | |
53 | torch_ln.bias.copy_(liger_ln.bias) | |
54 | | |
55 | liger_output = liger_ln( | |
56 | liger_x, | |
57 | ) | |
58 | torch_output = torch_ln(torch_x) | |
59 | | |
60 | assert torch.allclose(liger_output, torch_output, atol=atol, rtol=rtol) | |
61 | grad_output = torch.randn_like(torch_x) | |
62 | liger_output.backward(grad_output, retain_graph=True) | |
63 | torch_output.backward(grad_output, retain_graph=True) | |
64 | assert torch.allclose(liger_x.grad, torch_x.grad, atol=atol, rtol=rtol) | |
65 | assert torch.allclose( | |
66 | liger_ln.bias.grad, torch_ln.bias.grad, atol=atol, rtol=rtol | |
67 | ), "Bias grads different" | |
68 | assert torch.allclose( | |
69 | liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol | |
70 | ), "Weight grads different" | |
71 | | |
-------------------------------------------------------------------------------- | |
/test/transformers/test_kl_div.py: | |
-------------------------------------------------------------------------------- | |
1 | from test.utils import supports_bfloat16 | |
2 | | |
3 | import pytest | |
4 | import torch | |
5 | from torch.nn import KLDivLoss | |
6 | | |
7 | from liger_kernel.transformers.kl_div import LigerKLDIVLoss | |
8 | from liger_kernel.utils import infer_device | |
9 | | |
10 | device = infer_device() | |
11 | | |
12 | _SHAPE_PARAMS = ( | |
13 | "B, T, V", | |
14 | [ | |
15 | (1, 4096, 32000), | |
16 | # weird shape | |
17 | (41, 401, 1271), | |
18 | ], | |
19 | ) | |
20 | | |
21 | _DTYPE_PARAMS = ( | |
22 | "dtype, atol, rtol", | |
23 | [ | |
24 | pytest.param( | |
25 | torch.bfloat16, | |
26 | 1e-8, | |
27 | 5e-2, | |
28 | marks=pytest.mark.skipif( | |
29 | not supports_bfloat16(), reason="bfloat16 not supported on this GPU" | |
30 | ), | |
31 | ), | |
32 | (torch.float32, 1e-8, 1e-6), | |
33 | (torch.float16, 1e-3, 1e-3), | |
34 | ], | |
35 | ) | |
36 | | |
37 | | |
38 | def _test_correctness_once( | |
39 | target_kldiv, | |
40 | B, | |
41 | T, | |
42 | V, | |
43 | dtype, | |
44 | atol, | |
45 | rtol, | |
46 | reduction, | |
47 | log_target, | |
48 | is_last_layer=True, | |
49 | device=device, | |
50 | ): | |
51 | torch.manual_seed(0) | |
52 | torch_kldiv = KLDivLoss(reduction=reduction, log_target=log_target) | |
53 | | |
54 | input = torch.randn( | |
55 | B * T, V, device=device, dtype=dtype, requires_grad=True | |
56 | ).log_softmax(dim=-1) | |
57 | | |
58 | x1 = input.detach().clone().requires_grad_(True) | |
59 | x2 = input.detach().clone().requires_grad_(True) | |
60 | | |
61 | with torch.no_grad(): | |
62 | target = torch.randn(B * T, V, device=device).softmax(dim=-1) | |
63 | | |
64 | output = torch_kldiv(x1, target) | |
65 | output2 = target_kldiv(x2, target) | |
66 | assert torch.allclose(output, output2, atol=atol, rtol=rtol) | |
67 | | |
68 | if ( | |
69 | not is_last_layer | |
70 | ): # when the loss is the last layer, grad_output is 1.0 and the mul op is skipped; scale here to test the non-last-layer path | |
71 | output = output * 2.0 | |
72 | output2 = output2 * 2.0 | |
73 | | |
74 | if reduction == "none": | |
75 | return | |
76 | | |
77 | output.backward() | |
78 | output2.backward() | |
79 | assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) | |
80 | | |
81 | | |
82 | @pytest.mark.parametrize(*_SHAPE_PARAMS) | |
83 | @pytest.mark.parametrize("log_target", [True, False]) | |
84 | @pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"]) | |
85 | @pytest.mark.parametrize(*_DTYPE_PARAMS) | |
86 | def test_correctness(B, T, V, log_target, reduction, dtype, atol, rtol): | |
87 | liger_kldiv = LigerKLDIVLoss(reduction=reduction, log_target=log_target) | |
88 | _test_correctness_once( | |
89 | liger_kldiv, B, T, V, dtype, atol, rtol, reduction, log_target | |
90 | ) | |
91 | | |
92 | | |
93 | @pytest.mark.parametrize(*_SHAPE_PARAMS) | |
94 | @pytest.mark.parametrize("log_target", [True, False]) | |
95 | @pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"]) | |
96 | @pytest.mark.parametrize(*_DTYPE_PARAMS) | |
97 | def test_correctness_not_last(B, T, V, log_target, reduction, dtype, atol, rtol): | |
98 | liger_kldiv = LigerKLDIVLoss(reduction=reduction, log_target=log_target) | |
99 | _test_correctness_once( | |
100 | liger_kldiv, | |
101 | B, | |
102 | T, | |
103 | V, | |
104 | dtype, | |
105 | atol, | |
106 | rtol, | |
107 | reduction, | |
108 | log_target, | |
109 | is_last_layer=False, | |
110 | ) | |
111 | | |
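The harness above treats `LigerKLDIVLoss` as a drop-in for `torch.nn.KLDivLoss` over log-probabilities; a compact sketch of that usage (assuming a Triton-capable device):

    import torch
    from liger_kernel.transformers.kl_div import LigerKLDIVLoss
    from liger_kernel.utils import infer_device

    device = infer_device()
    loss_fn = LigerKLDIVLoss(reduction="batchmean", log_target=False)

    log_q = torch.randn(32, 128, device=device).log_softmax(dim=-1).requires_grad_(True)
    p = torch.randn(32, 128, device=device).softmax(dim=-1)

    loss = loss_fn(log_q, p)
    loss.backward()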
-------------------------------------------------------------------------------- | |
/test/transformers/test_layer_norm.py: | |
-------------------------------------------------------------------------------- | |
1 | import pytest | |
2 | import torch | |
3 | | |
4 | from liger_kernel.ops.layer_norm import LigerLayerNormFunction | |
5 | from liger_kernel.transformers.functional import liger_layer_norm | |
6 | from liger_kernel.transformers.layer_norm import LigerLayerNorm | |
7 | from liger_kernel.utils import infer_device | |
8 | | |
9 | device = infer_device() | |
10 | | |
11 | | |
12 | @pytest.mark.parametrize( | |
13 | "batch_size, seq_len, hidden_size", | |
14 | [ | |
15 | (2, 8, 64), | |
16 | (4, 16, 128), | |
17 | ], | |
18 | ) | |
19 | @pytest.mark.parametrize( | |
20 | "dtype, atol, rtol", | |
21 | [ | |
22 | (torch.float32, 1e-5, 1e-5), | |
23 | ], | |
24 | ) | |
25 | def test_liger_layer_norm(batch_size, seq_len, hidden_size, dtype, atol, rtol): | |
26 | torch.manual_seed(0) | |
27 | | |
28 | x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device) | |
29 | | |
30 | liger_x = x.clone().requires_grad_(True) | |
31 | torch_x = x.clone().requires_grad_(True) | |
32 | | |
33 | liger_ln = LigerLayerNorm(hidden_size, eps=1e-6).to(dtype).to(device) | |
34 | torch_ln = torch.nn.LayerNorm(hidden_size, eps=1e-6).to(dtype).to(device) | |
35 | | |
36 | with torch.no_grad(): | |
37 | torch_ln.weight.copy_(liger_ln.weight) | |
38 | torch_ln.bias.copy_(liger_ln.bias) | |
39 | | |
40 | liger_output = liger_ln(liger_x) | |
41 | torch_output = torch_ln(torch_x) | |
42 | | |
43 | assert torch.allclose(liger_output, torch_output, atol=atol, rtol=rtol) | |
44 | | |
45 | grad_output = torch.randn_like(x) | |
46 | liger_output.backward(grad_output, retain_graph=True) | |
47 | torch_output.backward(grad_output, retain_graph=True) | |
48 | | |
49 | assert torch.allclose(liger_x.grad, torch_x.grad, atol=atol, rtol=rtol) | |
50 | assert torch.allclose( | |
51 | liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol | |
52 | ) | |
53 | assert torch.allclose(liger_ln.bias.grad, torch_ln.bias.grad, atol=atol, rtol=rtol) | |
54 | | |
55 | | |
56 | @pytest.mark.parametrize( | |
57 | "batch_size, seq_len, hidden_size", | |
58 | [ | |
59 | (2, 8, 64), | |
60 | (4, 16, 128), | |
61 | ], | |
62 | ) | |
63 | @pytest.mark.parametrize( | |
64 | "dtype, atol, rtol", | |
65 | [ | |
66 | (torch.float32, 1e-5, 1e-5), | |
67 | ], | |
68 | ) | |
69 | def test_liger_layer_norm_functional( | |
70 | hidden_size, batch_size, seq_len, dtype, atol, rtol | |
71 | ): | |
72 | torch.manual_seed(0) | |
73 | | |
74 | input = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device) | |
75 | | |
76 | x1 = input.clone().requires_grad_(True) | |
77 | x2 = input.clone().requires_grad_(True) | |
78 | | |
79 | w = torch.randn(hidden_size, device=device, dtype=dtype) | |
80 | | |
81 | w1 = w.clone().requires_grad_(True) | |
82 | w2 = w.clone().requires_grad_(True) | |
83 | | |
84 | b = torch.randn(hidden_size, device=device, dtype=dtype) | |
85 | | |
86 | b1 = b.clone().requires_grad_(True) | |
87 | b2 = b.clone().requires_grad_(True) | |
88 | | |
89 | y1 = liger_layer_norm(X=x1, W=w1, B=b1, eps=1e-6) | |
90 | y2 = LigerLayerNormFunction.apply(x2, w2, b2, 1e-6) | |
91 | | |
92 | assert torch.allclose(y1, y2, atol=atol, rtol=rtol) | |
93 | | |
94 | grad_output = torch.randn_like(y2) | |
95 | | |
96 | y1.backward(grad_output, retain_graph=True) | |
97 | y2.backward(grad_output, retain_graph=True) | |
98 | | |
99 | assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) | |
100 | assert torch.allclose(w1.grad, w2.grad, atol=atol, rtol=rtol) | |
101 | assert torch.allclose(b1.grad, b2.grad, atol=atol, rtol=rtol) | |
102 | | |
-------------------------------------------------------------------------------- | |
/test/transformers/test_mm_int8int2.py: | |
-------------------------------------------------------------------------------- | |
1 | import pytest | |
2 | import torch | |
3 | | |
4 | from liger_kernel.ops.experimental.mm_int8int2 import ( | |
5 | matmul, | |
6 | pack_weights, | |
7 | unpack_weights, | |
8 | ) | |
9 | from liger_kernel.utils import infer_device | |
10 | | |
11 | device = infer_device() | |
12 | | |
13 | | |
14 | # input_features = size*4 when the weight matrix is unpacked | |
15 | @pytest.mark.skip(reason="mm_int8int2 is under experimentation") | |
16 | @pytest.mark.parametrize( | |
17 | "size", | |
18 | [ | |
19 | 2048, | |
20 | 1024, | |
21 | 512, | |
22 | ], | |
23 | ) | |
24 | @pytest.mark.parametrize( | |
25 | "batch_size", | |
26 | [1, 2, 3, 8], | |
27 | ) | |
28 | @pytest.mark.parametrize( | |
29 | "seq_len", | |
30 | [1, 7, 16, 2048], | |
31 | ) | |
32 | @pytest.mark.parametrize( | |
33 | "out_features", | |
34 | [ | |
35 | 1024, | |
36 | 2048, | |
37 | 4096, | |
38 | 10000, | |
39 | ], | |
40 | ) | |
41 | @pytest.mark.parametrize( | |
42 | "atol, rtol, device", | |
43 | [ | |
44 | (1e-2, 1e-2, device), | |
45 | ], | |
46 | ) | |
47 | def test_kernel_correctness( | |
48 | batch_size, seq_len, out_features, size, atol, rtol, device | |
49 | ): | |
50 | print(f"\nTesting kernel with size: {size}, atol: {atol}, rtol: {rtol}") | |
51 | | |
52 | # Generate the random tensors | |
53 | ht = torch.randint( | |
54 | -127, 127, (batch_size, seq_len, size * 4), device=device, dtype=torch.int8 | |
55 | ) | |
56 | u = torch.randint(0, 255, (out_features, size), device=device, dtype=torch.uint8) | |
57 | | |
58 | # Calculate dimensions | |
59 | B, M, N = ht.size() | |
60 | | |
61 | # Compute triton output | |
62 | triton_output = matmul(ht.view(B * M, N), u.T.contiguous()).view(B, M, -1) | |
63 | | |
64 | # Unpack weights and compute torch output | |
65 | unpacked = unpack_weights(u.T, bits=2).T | |
66 | torch_output = torch.matmul( | |
67 | ht.to(torch.float32), unpacked.T.contiguous().to(torch.float32) | |
68 | ) | |
69 | | |
70 | # Print the results (optional, can be commented out) | |
71 | print("triton_output =", triton_output) | |
72 | print("torch_output =", torch_output) | |
73 | | |
74 | # Check if outputs are close within the given tolerances | |
75 | assert torch.allclose( | |
76 | triton_output, torch_output.to(torch.int32), atol=atol, rtol=rtol | |
77 | ), "Results differ" | |
78 | | |
79 | | |
80 | @pytest.mark.skip(reason="mm_int8int2 is under experimentation") | |
81 | @pytest.mark.parametrize( | |
82 | "size", | |
83 | [ | |
84 | 2048, | |
85 | 1024, | |
86 | 512, | |
87 | ], | |
88 | ) | |
89 | @pytest.mark.parametrize( | |
90 | "out_features", | |
91 | [ | |
92 | 1024, | |
93 | 2048, | |
94 | 4096, | |
95 | 10000, | |
96 | ], | |
97 | ) | |
98 | @pytest.mark.parametrize( | |
99 | "device", | |
100 | [ | |
101 | device, | |
102 | ], | |
103 | ) | |
104 | def test_unpack_pack_correctness(out_features, size, device): | |
105 | u = torch.randint(0, 255, (out_features, size), device=device, dtype=torch.uint8) | |
106 | | |
107 | assert ( | |
108 | pack_weights(unpack_weights(u.T), 2) == u.T | |
109 | ).all(), "Packed weights do not match original weights." | |
110 | | |
-------------------------------------------------------------------------------- | |
/test/transformers/test_qwen2vl_mrope.py: | |
-------------------------------------------------------------------------------- | |
1 | from test.utils import supports_bfloat16 | |
2 | | |
3 | import pytest | |
4 | import torch | |
5 | | |
6 | try: | |
7 | from transformers.models.qwen2_vl.modeling_qwen2_vl import ( | |
8 | Qwen2VLRotaryEmbedding, | |
9 | apply_multimodal_rotary_pos_emb, | |
10 | ) | |
11 | | |
12 | IS_QWEN_AVAILABLE = True | |
13 | except Exception: | |
14 | IS_QWEN_AVAILABLE = False | |
15 | | |
16 | from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction | |
17 | from liger_kernel.transformers.functional import liger_qwen2vl_mrope | |
18 | from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb | |
19 | from liger_kernel.utils import infer_device | |
20 | | |
21 | device = infer_device() | |
22 | | |
23 | | |
24 | @pytest.mark.skipif( | |
25 | not IS_QWEN_AVAILABLE, reason="Qwen is not available in transformers." | |
26 | ) | |
27 | @pytest.mark.parametrize("bsz", [1, 2]) | |
28 | @pytest.mark.parametrize("seq_len", [128, 131]) | |
29 | @pytest.mark.parametrize("num_q_heads, num_kv_heads", [(64, 8), (28, 4), (12, 2)]) | |
30 | @pytest.mark.parametrize( | |
31 | "head_dim, mrope_section", | |
32 | [ | |
33 | (128, [16, 24, 24]), | |
34 | (96, [16, 16, 16]), | |
35 | (64, [8, 12, 12]), | |
36 | ], | |
37 | ) | |
38 | @pytest.mark.parametrize( | |
39 | "dtype, atol, rtol", | |
40 | [ | |
41 | (torch.float32, 1e-5, 1e-5), | |
42 | pytest.param( | |
43 | torch.bfloat16, | |
44 | 1e-1, | |
45 | 1e-5, | |
46 | marks=pytest.mark.skipif( | |
47 | not supports_bfloat16(), reason="bfloat16 not supported on this GPU" | |
48 | ), | |
49 | ), | |
50 | ], | |
51 | ) | |
52 | def test_correctness( | |
53 | bsz, seq_len, num_q_heads, num_kv_heads, head_dim, mrope_section, dtype, atol, rtol | |
54 | ): | |
55 | rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device) | |
56 | | |
57 | _tensor_q = ( | |
58 | torch.randn((bsz, seq_len, num_q_heads, head_dim), device=device) | |
59 | .transpose(1, 2) | |
60 | .to(dtype) | |
61 | ) | |
62 | | |
63 | _tensor_k = ( | |
64 | torch.randn((bsz, seq_len, num_kv_heads, head_dim), device=device) | |
65 | .transpose(1, 2) | |
66 | .to(dtype) | |
67 | ) | |
68 | | |
69 | q1 = _tensor_q.clone().requires_grad_(True) | |
70 | k1 = _tensor_k.clone().requires_grad_(True) | |
71 | | |
72 | q2 = _tensor_q.clone().requires_grad_(True) | |
73 | k2 = _tensor_k.clone().requires_grad_(True) | |
74 | | |
75 | # NOTE: this position ids distribution is different from the real one, just to test op correctness | |
76 | pos_ids = torch.arange(seq_len * 3 * bsz, device=device, dtype=torch.long).view( | |
77 | 3, bsz, seq_len | |
78 | ) | |
79 | cos, sin = rotary_emb(k1, pos_ids) | |
80 | | |
81 | # validate forward pass | |
82 | hf_q, hf_k = apply_multimodal_rotary_pos_emb(q1, k1, cos, sin, mrope_section) | |
83 | tt_q, tt_k = liger_multimodal_rotary_pos_emb(q2, k2, cos, sin, mrope_section) | |
84 | torch.testing.assert_close(hf_q, tt_q, atol=atol, rtol=rtol) | |
85 | torch.testing.assert_close(hf_k, tt_k, atol=atol, rtol=rtol) | |
86 | | |
87 | # validate backward pass | |
88 | dq, dk = ( | |
89 | torch.randn_like(hf_q, device=device), | |
90 | torch.randn_like(hf_k, device=device).to(dtype), | |
91 | ) | |
92 | | |
93 | q1_grad, k1_grad = torch.autograd.grad( | |
94 | (hf_q, hf_k), (q1, k1), (dq, dk), allow_unused=True | |
95 | ) | |
96 | q2_grad, k2_grad = torch.autograd.grad( | |
97 | (tt_q, tt_k), (q2, k2), (dq.clone(), dk.clone()), allow_unused=True | |
98 | ) | |
99 | | |
100 | torch.testing.assert_close(q1_grad, q2_grad, atol=atol, rtol=rtol) | |
101 | torch.testing.assert_close(k1_grad, k2_grad, atol=atol, rtol=rtol) | |
102 | | |
103 | | |
104 | @pytest.mark.skipif( | |
105 | not IS_QWEN_AVAILABLE, reason="Qwen is not available in transformers." | |
106 | ) | |
107 | @pytest.mark.parametrize( | |
108 | "bsz, seq_len, num_q_heads, num_kv_heads, head_dim, mrope_section", | |
109 | [ | |
110 | (1, 2, 2, 2, 8, [2, 1, 1]), | |
111 | (1, 2, 1, 2, 8, [2, 1, 1]), | |
112 | ], | |
113 | ) | |
114 | @pytest.mark.parametrize( | |
115 | "dtype, atol, rtol", | |
116 | [ | |
117 | (torch.float32, 1e-5, 1e-5), | |
118 | (torch.bfloat16, 1e-1, 1e-5), | |
119 | ], | |
120 | ) | |
121 | def test_functional_correctness( | |
122 | bsz, seq_len, num_q_heads, num_kv_heads, head_dim, mrope_section, dtype, atol, rtol | |
123 | ): | |
124 | _q = torch.randn((bsz, num_q_heads, seq_len, head_dim), device=device, dtype=dtype) | |
125 | _k = torch.randn((bsz, num_kv_heads, seq_len, head_dim), device=device, dtype=dtype) | |
126 | | |
127 | q1 = _q.clone().requires_grad_(True) | |
128 | q2 = _q.clone().requires_grad_(True) | |
129 | | |
130 | k1 = _k.clone().requires_grad_(True) | |
131 | k2 = _k.clone().requires_grad_(True) | |
132 | | |
133 | rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device) | |
134 | | |
135 | pos_ids = torch.arange(seq_len * 3 * bsz, device=device, dtype=torch.long).view( | |
136 | 3, bsz, seq_len | |
137 | ) | |
138 | cos, sin = rotary_emb(k1, pos_ids) | |
139 | | |
140 | functional_q, functional_k = liger_qwen2vl_mrope(q1, k1, cos, sin, mrope_section) | |
141 | class_q, class_k = LigerQwen2VLMRopeFunction.apply(q2, k2, cos, sin, mrope_section) | |
142 | | |
143 | torch.testing.assert_close(functional_q, class_q, atol=atol, rtol=rtol) | |
144 | torch.testing.assert_close(functional_k, class_k, atol=atol, rtol=rtol) | |
145 | | |
146 | dq, dk = torch.randn_like(functional_q), torch.randn_like(functional_k) | |
147 | | |
148 | dq1, dk1 = dq.clone(), dk.clone() | |
149 | dq2, dk2 = dq.clone(), dk.clone() | |
150 | | |
151 | q1_grad, k1_grad = torch.autograd.grad( | |
152 | (functional_q, functional_k), | |
153 | (q1, k1), | |
154 | (dq1, dk1), | |
155 | allow_unused=True, | |
156 | ) | |
157 | | |
158 | q2_grad, k2_grad = torch.autograd.grad( | |
159 | (class_q, class_k), | |
160 | (q2, k2), | |
161 | (dq2, dk2), | |
162 | allow_unused=True, | |
163 | ) | |
164 | | |
165 | torch.testing.assert_close(q1_grad, q2_grad, atol=atol, rtol=rtol) | |
166 | torch.testing.assert_close(k1_grad, k2_grad, atol=atol, rtol=rtol) | |
167 | | |
-------------------------------------------------------------------------------- | |
/test/transformers/test_rms_norm.py: | |
-------------------------------------------------------------------------------- | |
1 | import os | |
2 | from test.utils import assert_verbose_allclose, set_seed, supports_bfloat16 | |
3 | | |
4 | import pytest | |
5 | import torch | |
6 | import torch.nn as nn | |
7 | | |
8 | from liger_kernel.ops.rms_norm import LigerRMSNormFunction | |
9 | from liger_kernel.transformers.functional import liger_rms_norm | |
10 | from liger_kernel.transformers.rms_norm import LigerRMSNorm | |
11 | from liger_kernel.utils import infer_device | |
12 | | |
13 | device = infer_device() | |
14 | | |
15 | set_seed(42) | |
16 | torch.use_deterministic_algorithms(True) | |
17 | | |
18 | # Only setting torch.use_deterministic_algorithms(True) might throw the following error: | |
19 | # RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithms(true)`, | |
20 | # but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an | |
21 | # environment variable before running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8. For more information, | |
22 | # go to https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility | |
23 | | |
24 | if device == "cuda": | |
25 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" | |
26 | | |
27 | SLEEP_SECONDS = 0.1 | |
28 | | |
29 | | |
30 | class BaseRMSNorm(nn.Module): | |
31 | def __init__(self, hidden_size, eps=1e-6): | |
32 | super().__init__() | |
33 | self.weight = nn.Parameter(torch.ones(hidden_size)) | |
34 | self.variance_epsilon = eps | |
35 | | |
36 | def forward(self, hidden_states): | |
37 | input_dtype = hidden_states.dtype | |
38 | variance = hidden_states.pow(2).mean(-1, keepdim=True) | |
39 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) | |
40 | return self.weight * hidden_states.to(input_dtype) | |
41 | | |
42 | | |
43 | # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L112 | |
44 | class LlamaRMSNorm(nn.Module): | |
45 | def __init__(self, hidden_size, eps=1e-6): | |
46 | """ | |
47 | LlamaRMSNorm is equivalent to T5LayerNorm | |
48 | """ | |
49 | super().__init__() | |
50 | self.weight = nn.Parameter(torch.ones(hidden_size)) | |
51 | self.variance_epsilon = eps | |
52 | | |
53 | def forward(self, hidden_states): | |
54 | input_dtype = hidden_states.dtype | |
55 | hidden_states = hidden_states.to(torch.float32) | |
56 | variance = hidden_states.pow(2).mean(-1, keepdim=True) | |
57 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) | |
58 | return self.weight * hidden_states.to(input_dtype) | |
59 | | |
60 | | |
61 | # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L122 | |
62 | class GemmaRMSNorm(nn.Module): | |
63 | def __init__(self, hidden_size: int, eps: float = 1e-6): | |
64 | super().__init__() | |
65 | self.eps = eps | |
66 | self.weight = nn.Parameter(torch.ones(hidden_size)) | |
67 | | |
68 | def _norm(self, x): | |
69 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
70 | | |
71 | def forward(self, x): | |
72 | output = self._norm(x.float()) | |
73 | output = output * (1.0 + self.weight.float()) | |
74 | return output.type_as(x) | |
75 | | |
76 | | |
77 | @pytest.mark.flaky(reruns=3, reruns_delay=2) | |
78 | @pytest.mark.parametrize( | |
79 | "bs, sl, hd", | |
80 | [ | |
81 | (2, 128, 512), | |
82 | # weird shapes | |
83 | (5, 123, 123), | |
84 | ], | |
85 | ) | |
86 | @pytest.mark.parametrize( | |
87 | "dtype, atol, rtol", | |
88 | [ | |
89 | (torch.float32, 1e-4, 1e-6), | |
90 | pytest.param( | |
91 | torch.bfloat16, | |
92 | 2e-1, | |
93 | 2e-2, | |
94 | marks=pytest.mark.skipif( | |
95 | not supports_bfloat16(), reason="bfloat16 not supported on this GPU" | |
96 | ), | |
97 | ), | |
98 | ], | |
99 | ) | |
100 | @pytest.mark.parametrize( | |
101 | "reference, offset, casting_mode", | |
102 | [ | |
103 | (LlamaRMSNorm, 0.0, "llama"), | |
104 | (GemmaRMSNorm, 1.0, "gemma"), | |
105 | (BaseRMSNorm, 0.0, "none"), | |
106 | ], | |
107 | ) | |
108 | @pytest.mark.parametrize( | |
109 | "in_place", | |
110 | [ | |
111 | True, | |
112 | False, | |
113 | ], | |
114 | ) | |
115 | def test_correctness( | |
116 | bs, sl, hd, dtype, atol, rtol, reference, offset, casting_mode, in_place | |
117 | ): | |
118 | _tensor = torch.randn(bs, sl, hd, device=device, dtype=dtype) | |
119 | | |
120 | h1 = _tensor.clone().requires_grad_(True) | |
121 | h2 = _tensor.clone().requires_grad_(True) | |
122 | | |
123 | # do | |
124 | do = torch.randn(bs, sl, hd, device=device, dtype=dtype) | |
125 | | |
126 | # reference (llama or gemma) | |
127 | ref_rms = reference(hidden_size=hd).to(device).to(dtype) | |
128 | ref_o = ref_rms(h1) | |
129 | ref_o.backward(do, retain_graph=True) | |
130 | | |
131 | # triton | |
132 | triton_rms = ( | |
133 | LigerRMSNorm( | |
134 | hidden_size=hd, offset=offset, casting_mode=casting_mode, in_place=in_place | |
135 | ) | |
136 | .to(device) | |
137 | .to(dtype) | |
138 | ) | |
139 | triton_o = triton_rms(h2) | |
140 | triton_o.backward(do, retain_graph=True) | |
141 | | |
142 | assert_verbose_allclose(ref_o, triton_o, atol=atol, rtol=rtol) | |
143 | assert_verbose_allclose( | |
144 | ref_rms.weight.grad, triton_rms.weight.grad, atol=atol, rtol=rtol | |
145 | ) | |
146 | print(f"{h1.grad=}") | |
147 | print(f"{h2.grad=}") | |
148 | assert_verbose_allclose(h1.grad, h2.grad, atol=atol, rtol=rtol, max_print=20) | |
149 | | |
150 | | |
151 | @pytest.mark.parametrize( | |
152 | "bs, sl, hd", | |
153 | [ | |
154 | (2, 2, 8), | |
155 | # weird shapes | |
156 | (9, 7, 41), | |
157 | ], | |
158 | ) | |
159 | @pytest.mark.parametrize( | |
160 | "dtype, atol, rtol", | |
161 | [ | |
162 | (torch.float32, 1e-4, 1e-6), | |
163 | (torch.bfloat16, 2e-1, 2e-2), | |
164 | ], | |
165 | ) | |
166 | @pytest.mark.parametrize( | |
167 | "reference, offset, casting_mode", | |
168 | [ | |
169 | (LlamaRMSNorm, 0.0, "llama"), | |
170 | (GemmaRMSNorm, 1.0, "gemma"), | |
171 | ], | |
172 | ) | |
173 | def test_correctness_functional( | |
174 | bs, sl, hd, dtype, atol, rtol, reference, offset, casting_mode | |
175 | ): | |
176 | # h | |
177 | _tensor = torch.randn(bs, sl, hd, device=device, dtype=dtype) | |
178 | | |
179 | h1 = _tensor.clone().requires_grad_(True) | |
180 | h2 = _tensor.clone().requires_grad_(True) | |
181 | | |
182 | w = torch.randn(hd, device=device, dtype=dtype) | |
183 | | |
184 | y1 = liger_rms_norm(X=h1, W=w, eps=1e-6, offset=offset, casting_mode=casting_mode) | |
185 | y2 = LigerRMSNormFunction.apply(h2, w, 1e-6, offset, casting_mode) | |
186 | | |
187 | assert torch.allclose(y1, y2, atol=atol, rtol=rtol) | |
188 | | |
189 | grad = torch.randn_like(y2) | |
190 | | |
191 | y1.backward(grad) | |
192 | y2.backward(grad) | |
193 | | |
194 | assert torch.allclose(h1.grad, h2.grad, atol=atol, rtol=rtol) | |
195 | | |
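The `(reference, offset, casting_mode)` triples above encode which upstream norm each configuration matches; for instance, a Gemma-style norm would be constructed roughly as below (the hidden size is illustrative):

    import torch
    from liger_kernel.transformers.rms_norm import LigerRMSNorm
    from liger_kernel.utils import infer_device

    device = infer_device()
    # offset=1.0 with casting_mode="gemma" mirrors GemmaRMSNorm (scales by 1 + w,
    # computes in fp32); offset=0.0 with "llama" mirrors LlamaRMSNorm.
    norm = LigerRMSNorm(hidden_size=512, offset=1.0, casting_mode="gemma").to(device)
    y = norm(torch.randn(2, 128, 512, device=device))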
-------------------------------------------------------------------------------- | |
/test/transformers/test_rope.py: | |
-------------------------------------------------------------------------------- | |
1 | from test.utils import supports_bfloat16 | |
2 | | |
3 | import pytest | |
4 | import torch | |
5 | from transformers.models.llama.modeling_llama import ( | |
6 | LlamaRotaryEmbedding, | |
7 | apply_rotary_pos_emb, | |
8 | ) | |
9 | | |
10 | from liger_kernel.ops.rope import LigerRopeFunction | |
11 | from liger_kernel.transformers.functional import liger_rope | |
12 | from liger_kernel.transformers.rope import liger_rotary_pos_emb | |
13 | from liger_kernel.utils import infer_device | |
14 | | |
15 | device = infer_device() | |
16 | | |
17 | SLEEP_SECONDS = 0.1 | |
18 | | |
19 | | |
20 | @pytest.mark.parametrize( | |
21 | "bsz, seq_len, num_q_heads, num_kv_heads, head_dim", | |
22 | [ | |
23 | (1, 128, 32, 32, 64), | |
24 | (2, 128, 32, 32, 64), | |
25 | # different q/k heads | |
26 | (1, 128, 32, 8, 64), | |
27 | (2, 128, 32, 8, 64), | |
28 | # weird shapes | |
29 | # HuggingFace llama/mistral source code doesn't support odd head dimension | |
30 | # so we don't test it here | |
31 | (3, 423, 73, 213, 92), | |
32 | (3, 423, 73, 155, 92), | |
33 | ], | |
34 | ) | |
35 | @pytest.mark.parametrize( | |
36 | "dtype, atol, rtol", | |
37 | [ | |
38 | (torch.float32, 1e-5, 1e-5), | |
39 | pytest.param( | |
40 | torch.bfloat16, | |
41 | 1e-1, | |
42 | 1e-5, | |
43 | marks=pytest.mark.skipif( | |
44 | not supports_bfloat16(), reason="bfloat16 not supported on this GPU" | |
45 | ), | |
46 | ), | |
47 | ], | |
48 | ) | |
49 | def test_correctness( | |
50 | bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol | |
51 | ): | |
52 | rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) | |
53 | | |
54 | _tensor_q = ( | |
55 | torch.randn((bsz, seq_len, num_q_heads, head_dim), device=device) | |
56 | .transpose(1, 2) | |
57 | .to(dtype) | |
58 | ) | |
59 | | |
60 | _tensor_k = ( | |
61 | torch.randn((bsz, seq_len, num_kv_heads, head_dim), device=device) | |
62 | .transpose(1, 2) | |
63 | .to(dtype) | |
64 | ) | |
65 | | |
66 | q1 = _tensor_q.clone().requires_grad_(True) | |
67 | k1 = _tensor_k.clone().requires_grad_(True) | |
68 | | |
69 | q2 = _tensor_q.clone().requires_grad_(True) | |
70 | k2 = _tensor_k.clone().requires_grad_(True) | |
71 | | |
72 | pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) | |
73 | cos, sin = rotary_emb(k1, pos_ids) | |
74 | | |
75 | # validate forward pass | |
76 | hf_q, hf_k = apply_rotary_pos_emb(q1, k1, cos, sin, pos_ids) | |
77 | tt_q, tt_k = liger_rotary_pos_emb(q2, k2, cos, sin) | |
78 | assert torch.allclose(hf_q, tt_q, atol=atol, rtol=rtol) | |
79 | assert torch.allclose(hf_k, tt_k, atol=atol, rtol=rtol) | |
80 | | |
81 | # validate backward pass | |
82 | dq, dk = ( | |
83 | torch.randn_like(hf_q, device=device), | |
84 | torch.randn_like(hf_k, device=device).to(dtype), | |
85 | ) | |
86 | | |
87 | q1_grad, k1_grad = torch.autograd.grad( | |
88 | (hf_q, hf_k), (q1, k1), (dq, dk), allow_unused=True | |
89 | ) | |
90 | q2_grad, k2_grad = torch.autograd.grad( | |
91 | (tt_q, tt_k), (q2, k2), (dq.clone(), dk.clone()), allow_unused=True | |
92 | ) | |
93 | | |
94 | assert torch.allclose(q1_grad, q2_grad, atol=atol, rtol=rtol) | |
95 | assert torch.allclose(k1_grad, k2_grad, atol=atol, rtol=rtol) | |
96 | | |
97 | | |
98 | @pytest.mark.parametrize( | |
99 | "bsz, seq_len, num_q_heads, num_kv_heads, head_dim", | |
100 | [ | |
101 | (1, 2, 2, 2, 8), | |
102 | (1, 2, 1, 2, 8), | |
103 | # weird shapes | |
104 | (9, 7, 41, 41, 41), | |
105 | ], | |
106 | ) | |
107 | @pytest.mark.parametrize( | |
108 | "dtype, atol, rtol", | |
109 | [ | |
110 | (torch.float32, 1e-5, 1e-5), | |
111 | (torch.bfloat16, 1e-1, 1e-5), | |
112 | ], | |
113 | ) | |
114 | def test_functional_correctness( | |
115 | bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol | |
116 | ): | |
117 | _q = torch.randn((bsz, num_q_heads, seq_len, head_dim), device=device, dtype=dtype) | |
118 | _k = torch.randn((bsz, num_kv_heads, seq_len, head_dim), device=device, dtype=dtype) | |
119 | | |
120 | q1 = _q.clone().requires_grad_(True) | |
121 | q2 = _q.clone().requires_grad_(True) | |
122 | | |
123 | k1 = _k.clone().requires_grad_(True) | |
124 | k2 = _k.clone().requires_grad_(True) | |
125 | | |
126 | rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) | |
127 | | |
128 | pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) | |
129 | cos, sin = rotary_emb(k1, pos_ids) | |
130 | | |
131 | functional_q, functional_k = liger_rope(q=q1, k=k1, cos=cos, sin=sin) | |
132 | class_q, class_k = LigerRopeFunction.apply(q2, k2, cos, sin) | |
133 | | |
134 | assert torch.allclose(functional_q, class_q, atol=atol, rtol=rtol) | |
135 | assert torch.allclose(functional_k, class_k, atol=atol, rtol=rtol) | |
136 | | |
137 | dq, dk = torch.randn_like(functional_q), torch.randn_like(functional_k) | |
138 | | |
139 | dq1, dk1 = dq.clone(), dk.clone() | |
140 | dq2, dk2 = dq.clone(), dk.clone() | |
141 | | |
142 | q1_grad, k1_grad = torch.autograd.grad( | |
143 | (functional_q, functional_k), | |
144 | (q1, k1), | |
145 | (dq1, dk1), | |
146 | allow_unused=True, | |
147 | ) | |
148 | | |
149 | q2_grad, k2_grad = torch.autograd.grad( | |
150 | (class_q, class_k), | |
151 | (q2, k2), | |
152 | (dq2, dk2), | |
153 | allow_unused=True, | |
154 | ) | |
155 | | |
156 | assert torch.allclose(q1_grad, q2_grad, atol=atol, rtol=rtol) | |
157 | assert torch.allclose(k1_grad, k2_grad, atol=atol, rtol=rtol) | |
158 | | |
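For reference, the drop-in relationship validated above is roughly the following: `liger_rotary_pos_emb` consumes the same `(q, k, cos, sin)` tensors as the HuggingFace helper, minus the position-ids argument. A minimal sketch mirroring the test setup:

    import torch
    from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
    from liger_kernel.transformers.rope import liger_rotary_pos_emb
    from liger_kernel.utils import infer_device

    device = infer_device()
    bsz, seq_len, n_heads, head_dim = 1, 128, 8, 64

    q = torch.randn(bsz, n_heads, seq_len, head_dim, device=device)
    k = torch.randn(bsz, n_heads, seq_len, head_dim, device=device)

    rotary = LlamaRotaryEmbedding(head_dim, device=device)
    pos_ids = torch.arange(seq_len, device=device).unsqueeze(0)
    cos, sin = rotary(k, pos_ids)

    q_rot, k_rot = liger_rotary_pos_emb(q, k, cos, sin)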
-------------------------------------------------------------------------------- | |
/test/transformers/test_swiglu.py: | |
-------------------------------------------------------------------------------- | |
1 | from test.utils import supports_bfloat16 | |
2 | | |
3 | import pytest | |
4 | import torch | |
5 | from transformers.models.llama.configuration_llama import LlamaConfig | |
6 | from transformers.models.llama.modeling_llama import LlamaMLP | |
7 | from transformers.models.phi3.configuration_phi3 import Phi3Config | |
8 | from transformers.models.phi3.modeling_phi3 import Phi3MLP | |
9 | | |
10 | from liger_kernel.ops.swiglu import LigerSiLUMulFunction | |
11 | from liger_kernel.transformers.functional import liger_swiglu | |
12 | from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP, LigerSwiGLUMLP | |
13 | from liger_kernel.utils import infer_device | |
14 | | |
15 | device = infer_device() | |
16 | | |
17 | LLAMA_CONFIG = LlamaConfig( | |
18 | hidden_size=4096, | |
19 | intermediate_size=11008, | |
20 | hidden_act="silu", | |
21 | ) | |
22 | PHI3_CONFIG = Phi3Config( | |
23 | hidden_size=4096, | |
24 | intermediate_size=11008, | |
25 | hidden_act="silu", | |
26 | ) | |
27 | SLEEP_SECONDS = 0.1 | |
28 | | |
29 | | |
30 | @pytest.mark.parametrize( | |
31 | "bsz, seq_len, hidden_size, intermediate_size", | |
32 | [ | |
33 | (2, 256, 256, 512), | |
34 | # weird shapes | |
35 | (6, 42, 123, 431), | |
36 | ], | |
37 | ) | |
38 | @pytest.mark.parametrize( | |
39 | "dtype, atol, rtol", | |
40 | [ | |
41 | # atol is for small values: they differ more, so set atol higher | |
42 | # rtol is for larger values: they are very close, so set rtol lower | |
43 | (torch.float32, 1e-0, 1e-5), | |
44 | # TODO: we should find a better way to tune this. 1e4 is too large apparently | |
45 | pytest.param( | |
46 | torch.bfloat16, | |
47 | 1e4, | |
48 | 1e-2, | |
49 | marks=pytest.mark.skipif( | |
50 | not supports_bfloat16(), reason="bfloat16 not supported on this GPU" | |
51 | ), | |
52 | ), | |
53 | ], | |
54 | ) | |
55 | def test_correctness_llamamlp( | |
56 | bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol | |
57 | ): | |
58 | _input = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype) | |
59 | | |
60 | x1 = _input.clone().requires_grad_(True) | |
61 | x2 = _input.clone().requires_grad_(True) | |
62 | | |
63 | # initialize weights | |
64 | G = torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype) | |
65 | U = torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype) | |
66 | D = torch.randn(intermediate_size, hidden_size, device=device, dtype=dtype) | |
67 | | |
68 | llama_mlp = LlamaMLP(config=LLAMA_CONFIG).to(device).to(dtype) | |
69 | llama_mlp.gate_proj.weight.data = G.T | |
70 | llama_mlp.up_proj.weight.data = U.T | |
71 | llama_mlp.down_proj.weight.data = D.T | |
72 | | |
73 | liger_mlp = LigerSwiGLUMLP(config=LLAMA_CONFIG).to(device).to(dtype) | |
74 | liger_mlp.gate_proj.weight.data = G.T | |
75 | liger_mlp.up_proj.weight.data = U.T | |
76 | liger_mlp.down_proj.weight.data = D.T | |
77 | | |
78 | y1 = llama_mlp(x1) | |
79 | y2 = liger_mlp(x2) | |
80 | | |
81 | assert torch.allclose(y1, y2, atol=atol, rtol=rtol) | |
82 | | |
83 | dy = torch.randn_like(y1) | |
84 | | |
85 | y1.backward(dy.clone(), retain_graph=True) | |
86 | y2.backward(dy.clone(), retain_graph=True) | |
87 | | |
88 | assert torch.allclose( | |
89 | llama_mlp.gate_proj.weight.grad, | |
90 | liger_mlp.gate_proj.weight.grad, | |
91 | atol=atol, | |
92 | rtol=rtol, | |
93 | ) | |
94 | assert torch.allclose( | |
95 | llama_mlp.up_proj.weight.grad, | |
96 | liger_mlp.up_proj.weight.grad, | |
97 | atol=atol, | |
98 | rtol=rtol, | |
99 | ) | |
100 | assert torch.allclose( | |
101 | llama_mlp.down_proj.weight.grad, | |
102 | liger_mlp.down_proj.weight.grad, | |
103 | atol=atol, | |
104 | rtol=rtol, | |
105 | ) | |
106 | | |
107 | assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) | |
108 | | |
109 | | |
110 | @pytest.mark.parametrize( | |
111 | "bsz, seq_len, hidden_size, intermediate_size", | |
112 | [ | |
113 | (2, 256, 256, 512), | |
114 | # weird shapes | |
115 | (6, 42, 123, 431), | |
116 | ], | |
117 | ) | |
118 | @pytest.mark.parametrize( | |
119 | "dtype, atol, rtol", | |
120 | [ | |
121 | # atol covers small values, where absolute differences dominate, so it is set higher | |
122 | # rtol covers larger values, which agree closely in relative terms, so it is set lower | |
123 | (torch.float32, 1e-0, 1e-5), | |
124 | # TODO: we should find a better way to tune this. 1e4 is too large apparently | |
125 | pytest.param( | |
126 | torch.bfloat16, | |
127 | 1e4, | |
128 | 1e-2, | |
129 | marks=pytest.mark.skipif( | |
130 | not supports_bfloat16(), reason="bfloat16 not supported on this GPU" | |
131 | ), | |
132 | ), | |
133 | ], | |
134 | ) | |
135 | def test_correctness_phi3mlp( | |
136 | bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol | |
137 | ): | |
138 | _input = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype) | |
139 | | |
140 | x1 = _input.clone().requires_grad_(True) | |
141 | x2 = _input.clone().requires_grad_(True) | |
142 | | |
143 | # initialize weights | |
144 | GU = torch.randn(hidden_size, intermediate_size * 2, device=device, dtype=dtype) | |
145 | D = torch.randn(intermediate_size, hidden_size, device=device, dtype=dtype) | |
146 | | |
147 | phi3_mlp = Phi3MLP(config=PHI3_CONFIG).to(device).to(dtype) | |
148 | phi3_mlp.gate_up_proj.weight.data = GU.T | |
149 | phi3_mlp.down_proj.weight.data = D.T | |
150 | | |
151 | liger_mlp = LigerPhi3SwiGLUMLP(config=PHI3_CONFIG).to(device).to(dtype) | |
152 | liger_mlp.gate_up_proj.weight.data = GU.T | |
153 | liger_mlp.down_proj.weight.data = D.T | |
154 | | |
155 | y1 = phi3_mlp(x1) | |
156 | y2 = liger_mlp(x2) | |
157 | | |
158 | assert torch.allclose(y1, y2, atol=atol, rtol=rtol) | |
159 | | |
160 | dy = torch.randn_like(y1) | |
161 | | |
162 | y1.backward(dy.clone(), retain_graph=True) | |
163 | y2.backward(dy.clone(), retain_graph=True) | |
164 | | |
165 | assert torch.allclose( | |
166 | phi3_mlp.gate_up_proj.weight.grad, | |
167 | liger_mlp.gate_up_proj.weight.grad, | |
168 | atol=atol, | |
169 | rtol=rtol, | |
170 | ) | |
171 | assert torch.allclose( | |
172 | phi3_mlp.down_proj.weight.grad, | |
173 | liger_mlp.down_proj.weight.grad, | |
174 | atol=atol, | |
175 | rtol=rtol, | |
176 | ) | |
177 | | |
178 | assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) | |
179 | | |
180 | | |
181 | @pytest.mark.parametrize( | |
182 | "bsz, seq_len, size", | |
183 | [ | |
184 | (2, 8, 8), | |
185 | (9, 7, 41), | |
186 | ], | |
187 | ) | |
188 | @pytest.mark.parametrize( | |
189 | "dtype, atol, rtol", | |
190 | [ | |
191 | # atol covers small values, where absolute differences dominate, so it is set higher | |
192 | # rtol covers larger values, which agree closely in relative terms, so it is set lower | |
193 | (torch.float32, 1e-0, 1e-5), | |
194 | # TODO: we should find a better way to tune this. 1e4 is too large apparently | |
195 | (torch.bfloat16, 1e4, 1e-2), | |
196 | ], | |
197 | ) | |
198 | def test_correctness_functional(bsz, seq_len, size, dtype, atol, rtol): | |
199 | _input = torch.randn(bsz, seq_len, size, device=device, dtype=dtype) | |
200 | _b = torch.randn(bsz, seq_len, size, device=device, dtype=dtype) | |
201 | | |
202 | x1 = _input.clone().requires_grad_(True) | |
203 | x2 = _input.clone().requires_grad_(True) | |
204 | | |
205 | b1 = _b.clone().requires_grad_(True) | |
206 | b2 = _b.clone().requires_grad_(True) | |
207 | | |
208 | y1 = liger_swiglu(a=x1, b=b1) | |
209 | y2 = LigerSiLUMulFunction.apply(x2, b2) | |
210 | | |
211 | assert torch.allclose(y1, y2, atol=atol, rtol=rtol) | |
212 | | |
213 | # Test backward pass | |
214 | grad_output = torch.randn_like(y1) | |
215 | | |
216 | y1.backward(grad_output) | |
217 | y2.backward(grad_output) | |
218 | | |
219 | # Check that gradients match for both inputs, x and b | |
220 | assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) | |
221 | assert torch.allclose(b1.grad, b2.grad, atol=atol, rtol=rtol) | |
222 | | |
-------------------------------------------------------------------------------- | |
/test/transformers/test_trainer_integration.py: | |
-------------------------------------------------------------------------------- | |
1 | import pytest | |
2 | | |
3 | | |
4 | def test_import(): | |
5 | try: | |
6 | from liger_kernel.transformers.trainer_integration import ( # noqa: F401 | |
7 | _apply_liger_kernel, | |
8 | ) | |
9 | except Exception: | |
10 | pytest.fail("Import _apply_liger_kernel fails") | |
11 | | |
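The smoke test above only checks that the trainer-integration entry point resolves. A hedged sketch of how that entry point is typically invoked; the model_type value and the timing of the call are assumptions for illustration, not something the test asserts:

from liger_kernel.transformers.trainer_integration import _apply_liger_kernel

# Assumed usage: patch the matching Liger kernels for a model family
# before the model class is instantiated (the model_type value is hypothetical).
_apply_liger_kernel(model_type="llama")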
-------------------------------------------------------------------------------- | |
/test/transformers/test_transformers.py: | |
-------------------------------------------------------------------------------- | |
1 | import pytest | |
2 | | |
3 | | |
4 | def test_import_from_root(): | |
5 | try: | |
6 | from liger_kernel.transformers import ( # noqa: F401 | |
7 | LigerBlockSparseTop2MLP, | |
8 | LigerCrossEntropyLoss, | |
9 | LigerFusedLinearCrossEntropyLoss, | |
10 | LigerGEGLUMLP, | |
11 | LigerLayerNorm, | |
12 | LigerPhi3SwiGLUMLP, | |
13 | LigerRMSNorm, | |
14 | LigerSwiGLUMLP, | |
15 | liger_rotary_pos_emb, | |
16 | ) | |
17 | except Exception: | |
18 | pytest.fail("Import kernels from root fails") | |
19 | | |
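The names imported above are the package's root-level exports. As a hedged usage sketch for one of them, assuming a Triton-capable GPU and treating LigerCrossEntropyLoss as a drop-in for torch.nn.CrossEntropyLoss (shapes below are arbitrary):

import torch

from liger_kernel.transformers import LigerCrossEntropyLoss

device = "cuda"  # assumption: a GPU backend that can run Triton kernels
loss_fn = LigerCrossEntropyLoss()
logits = torch.randn(8, 32000, device=device, requires_grad=True)  # (batch, vocab)
targets = torch.randint(0, 32000, (8,), device=device)             # class indices
loss = loss_fn(logits, targets)
loss.backward()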
-------------------------------------------------------------------------------- | |
/test/triton/test_triton_monkey_patch.py: | |
-------------------------------------------------------------------------------- | |
1 | import pytest | |
2 | | |
3 | | |
4 | def test_import_from_root(): | |
5 | try: | |
6 | from liger_kernel.triton import apply_liger_triton_cache_manager # noqa: F401 | |
7 | except Exception: | |
8 | pytest.fail("Import kernel patch from root fails") | |
9 | | |
10 | | |
11 | def test_import_custom_cache_manager(): | |
12 | from triton.runtime.cache import get_cache_manager | |
13 | | |
14 | from liger_kernel.triton import apply_liger_triton_cache_manager | |
15 | | |
16 | apply_liger_triton_cache_manager() | |
17 | cache_manager = get_cache_manager(key="test_hash") | |
18 | from liger_kernel.triton.monkey_patch import LigerTritonFileCacheManager | |
19 | | |
20 | assert isinstance( | |
21 | cache_manager, LigerTritonFileCacheManager | |
22 | ), "Cache manager should have been LigerTritonFileCacheManager" | |
23 | |
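The last test asks Triton's cache registry for a manager and checks its concrete type. A sketch of the mechanism it relies on, assuming the patch registers the custom class through Triton's TRITON_CACHE_MANAGER environment hook, which expects a "module:ClassName" path:

import os

from triton.runtime.cache import get_cache_manager

# Assumption: pointing TRITON_CACHE_MANAGER at the Liger class is how
# apply_liger_triton_cache_manager swaps in the custom file cache manager.
os.environ["TRITON_CACHE_MANAGER"] = (
    "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager"
)

manager = get_cache_manager(key="some_kernel_hash")
print(type(manager).__name__)  # expected: LigerTritonFileCacheManager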