@kashif
Created December 15, 2024 09:47
liger code
├── benchmark
│   ├── __init__.py
│   ├── benchmarks_visualizer.py
│   └── scripts
│       ├── __init__.py
│       ├── benchmark_cpo_loss.py
│       ├── benchmark_cross_entropy.py
│       ├── benchmark_dpo_loss.py
│       ├── benchmark_embedding.py
│       ├── benchmark_fused_linear_cross_entropy.py
│       ├── benchmark_fused_linear_jsd.py
│       ├── benchmark_geglu.py
│       ├── benchmark_group_norm.py
│       ├── benchmark_jsd.py
│       ├── benchmark_kl_div.py
│       ├── benchmark_layer_norm.py
│       ├── benchmark_orpo_loss.py
│       ├── benchmark_qwen2vl_mrope.py
│       ├── benchmark_rms_norm.py
│       ├── benchmark_rope.py
│       ├── benchmark_simpo_loss.py
│       ├── benchmark_swiglu.py
│       └── utils.py
├── dev
│   └── modal
│       ├── tests.py
│       └── tests_bwd.py
├── examples
│   ├── alignment
│   │   └── run_orpo.py
│   ├── huggingface
│   │   ├── callback.py
│   │   ├── launch_on_modal.py
│   │   ├── training.py
│   │   └── training_multimodal.py
│   ├── lightning
│   │   └── training.py
│   └── medusa
│       ├── callback.py
│       ├── medusa_util.py
│       └── train.py
├── setup.py
├── src
│   └── liger_kernel
│       ├── __init__.py
│       ├── chunked_loss
│       │   ├── __init__.py
│       │   ├── cpo_loss.py
│       │   ├── dpo_loss.py
│       │   ├── functional.py
│       │   ├── fused_linear_distillation.py
│       │   ├── fused_linear_preference.py
│       │   ├── orpo_loss.py
│       │   └── simpo_loss.py
│       ├── env_report.py
│       ├── ops
│       │   ├── __init__.py
│       │   ├── cross_entropy.py
│       │   ├── experimental
│       │   │   ├── embedding.py
│       │   │   └── mm_int8int2.py
│       │   ├── fused_linear_cross_entropy.py
│       │   ├── fused_linear_jsd.py
│       │   ├── geglu.py
│       │   ├── group_norm.py
│       │   ├── jsd.py
│       │   ├── kl_div.py
│       │   ├── layer_norm.py
│       │   ├── qwen2vl_mrope.py
│       │   ├── rms_norm.py
│       │   ├── rope.py
│       │   ├── swiglu.py
│       │   └── utils.py
│       ├── transformers
│       │   ├── __init__.py
│       │   ├── auto_model.py
│       │   ├── cross_entropy.py
│       │   ├── experimental
│       │   │   └── embedding.py
│       │   ├── functional.py
│       │   ├── fused_linear_cross_entropy.py
│       │   ├── fused_linear_jsd.py
│       │   ├── geglu.py
│       │   ├── group_norm.py
│       │   ├── jsd.py
│       │   ├── kl_div.py
│       │   ├── layer_norm.py
│       │   ├── model
│       │   │   ├── __init__.py
│       │   │   ├── gemma.py
│       │   │   ├── gemma2.py
│       │   │   ├── llama.py
│       │   │   ├── mistral.py
│       │   │   ├── mixtral.py
│       │   │   ├── mllama.py
│       │   │   ├── phi3.py
│       │   │   ├── qwen2.py
│       │   │   └── qwen2_vl.py
│       │   ├── monkey_patch.py
│       │   ├── qwen2vl_mrope.py
│       │   ├── rms_norm.py
│       │   ├── rope.py
│       │   ├── swiglu.py
│       │   ├── trainer
│       │   │   ├── __init__.py
│       │   │   └── orpo_trainer.py
│       │   └── trainer_integration.py
│       ├── triton
│       │   ├── __init__.py
│       │   └── monkey_patch.py
│       └── utils.py
└── test
    ├── __init__.py
    ├── chunked_loss
    │   ├── __init__.py
    │   ├── test_cpo_loss.py
    │   ├── test_dpo_loss.py
    │   ├── test_orpo_loss.py
    │   └── test_simpo_loss.py
    ├── conftest.py
    ├── convergence
    │   ├── __init__.py
    │   ├── test_mini_models.py
    │   ├── test_mini_models_multimodal.py
    │   └── test_mini_models_with_logits.py
    ├── resources
    │   └── scripts
    │       └── generate_tokenized_dataset.py
    ├── transformers
    │   ├── test_auto_model.py
    │   ├── test_cross_entropy.py
    │   ├── test_embedding.py
    │   ├── test_fused_linear_cross_entropy.py
    │   ├── test_fused_linear_jsd.py
    │   ├── test_geglu.py
    │   ├── test_group_norm.py
    │   ├── test_jsd.py
    │   ├── test_kl_div.py
    │   ├── test_layer_norm.py
    │   ├── test_mm_int8int2.py
    │   ├── test_monkey_patch.py
    │   ├── test_qwen2vl_mrope.py
    │   ├── test_rms_norm.py
    │   ├── test_rope.py
    │   ├── test_swiglu.py
    │   ├── test_trainer_integration.py
    │   └── test_transformers.py
    ├── triton
    │   └── test_triton_monkey_patch.py
    └── utils.py
/benchmark/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linkedin/Liger-Kernel/0bb6c72d03f0fc570b9e954c42b723ab7e0c315f/benchmark/__init__.py
--------------------------------------------------------------------------------
/benchmark/benchmarks_visualizer.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from argparse import ArgumentParser
4 | from dataclasses import dataclass
5 |
6 | import matplotlib.pyplot as plt
7 | import pandas as pd
8 | import seaborn as sns
9 |
10 | DATA_PATH = "data/all_benchmark_data.csv"
11 | VISUALIZATIONS_PATH = "visualizations/"
12 |
13 |
14 | @dataclass
15 | class VisualizationsConfig:
16 | """
17 | Configuration for the visualizations script.
18 |
19 | Args:
20 |         kernel_name (str): Kernel name whose benchmark data to visualize (as produced by `scripts/benchmark_{kernel_name}.py`)
21 | metric_name (str): Metric name to visualize (speed/memory)
22 | kernel_operation_mode (str): Kernel operation mode to visualize (forward/backward/full). Defaults to "full"
23 | display (bool): Display the visualization. Defaults to False
24 |         overwrite (bool): Overwrite an existing visualization; if none exists, this flag has no effect because one is always created and saved. Defaults to False
25 |
26 | """
27 |
28 | kernel_name: str
29 | metric_name: str
30 | kernel_operation_mode: str = "full"
31 | display: bool = False
32 | overwrite: bool = False
33 |
34 |
35 | def parse_args() -> VisualizationsConfig:
36 | """Parse command line arguments into a configuration object.
37 |
38 | Returns:
39 | VisualizationsConfig: Configuration object for the visualizations script.
40 | """
41 | parser = ArgumentParser()
42 | parser.add_argument(
43 | "--kernel-name", type=str, required=True, help="Kernel name to benchmark"
44 | )
45 | parser.add_argument(
46 | "--metric-name",
47 | type=str,
48 | required=True,
49 | help="Metric name to visualize (speed/memory)",
50 | )
51 | parser.add_argument(
52 | "--kernel-operation-mode",
53 | type=str,
54 | required=True,
55 | help="Kernel operation mode to visualize (forward/backward/full)",
56 | )
57 | parser.add_argument(
58 | "--display", action="store_true", help="Display the visualization"
59 | )
60 | parser.add_argument(
61 | "--overwrite",
62 | action="store_true",
63 |         help="Overwrite the existing visualization; if none exists, this flag has no effect because one is always created",
64 | )
65 |
66 | args = parser.parse_args()
67 |
68 | return VisualizationsConfig(**dict(args._get_kwargs()))
69 |
70 |
71 | def load_data(config: VisualizationsConfig) -> pd.DataFrame:
72 | """Loads the benchmark data from the CSV file and filters it based on the configuration.
73 |
74 | Args:
75 | config (VisualizationsConfig): Configuration object for the visualizations script.
76 |
77 | Raises:
78 | ValueError: If no data is found for the given filters.
79 |
80 | Returns:
81 | pd.DataFrame: Filtered benchmark dataframe.
82 | """
83 | df = pd.read_csv(DATA_PATH)
84 | df["extra_benchmark_config"] = df["extra_benchmark_config_str"].apply(json.loads)
85 |
86 | filtered_df = df[
87 | (df["kernel_name"] == config.kernel_name)
88 | & (df["metric_name"] == config.metric_name)
89 | & (df["kernel_operation_mode"] == config.kernel_operation_mode)
90 | # Use this to filter by extra benchmark configuration property
91 | # & (data['extra_benchmark_config'].apply(lambda x: x.get('H') == 4096))
92 | # FIXME: maybe add a way to filter using some configuration, except of hardcoding it
93 | ]
94 |
95 | if filtered_df.empty:
96 | raise ValueError("No data found for the given filters")
97 |
98 | return filtered_df
99 |
100 |
101 | def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
102 | """Plots the benchmark data, saving the result if needed.
103 |
104 | Args:
105 | df (pd.DataFrame): Filtered benchmark dataframe.
106 | config (VisualizationsConfig): Configuration object for the visualizations script.
107 | """
108 | xlabel = df["x_label"].iloc[0]
109 | ylabel = f"{config.metric_name} ({df['metric_unit'].iloc[0]})"
110 | # Sort by "kernel_provider" to ensure consistent color assignment
111 | df = df.sort_values(by="kernel_provider")
112 |
113 | plt.figure(figsize=(10, 6))
114 | sns.set(style="whitegrid")
115 | ax = sns.lineplot(
116 | data=df,
117 | x="x_value",
118 | y="y_value_50",
119 | hue="kernel_provider",
120 | marker="o",
121 | palette="tab10",
122 | errorbar=("ci", None),
123 | )
124 |
125 | # Seaborn can't plot pre-computed error bars, so we need to do it manually
126 | lines = ax.get_lines()
127 | colors = [line.get_color() for line in lines]
128 |
129 | for (_, group_data), color in zip(df.groupby("kernel_provider"), colors):
130 | # for i, row in group_data.iterrows():
131 | y_error_lower = group_data["y_value_50"] - group_data["y_value_20"]
132 | y_error_upper = group_data["y_value_80"] - group_data["y_value_50"]
133 | y_error = [y_error_lower, y_error_upper]
134 |
135 | plt.errorbar(
136 | group_data["x_value"],
137 | group_data["y_value_50"],
138 | yerr=y_error,
139 | fmt="o",
140 | color=color,
141 | capsize=5,
142 | )
143 | plt.legend(title="Kernel Provider")
144 | plt.xlabel(xlabel)
145 | plt.ylabel(ylabel)
146 | plt.tight_layout()
147 |
148 | out_path = os.path.join(
149 | VISUALIZATIONS_PATH, f"{config.kernel_name}_{config.metric_name}.png"
150 | )
151 |
152 | if config.display:
153 | plt.show()
154 | if config.overwrite or not os.path.exists(
155 | out_path
156 | ): # Save the plot if it doesn't exist or if we want to overwrite it
157 | os.makedirs(VISUALIZATIONS_PATH, exist_ok=True)
158 | plt.savefig(out_path)
159 | plt.close()
160 |
161 |
162 | def main():
163 | config = parse_args()
164 | df = load_data(config)
165 | plot_data(df, config)
166 |
167 |
168 | if __name__ == "__main__":
169 | main()
170 |
--------------------------------------------------------------------------------
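
A minimal sketch of driving the visualizer programmatically rather than through the CLI, assuming it is run from the benchmark/ directory so that benchmarks_visualizer is importable and data/all_benchmark_data.csv already contains rows written by the benchmark scripts:

# Sketch: build the config by hand and reuse load_data/plot_data from above.
from benchmarks_visualizer import VisualizationsConfig, load_data, plot_data

config = VisualizationsConfig(
    kernel_name="cross_entropy",   # matches scripts/benchmark_cross_entropy.py output
    metric_name="speed",           # "speed" or "memory"
    kernel_operation_mode="full",  # "forward", "backward", or "full"
    display=False,
    overwrite=True,
)
df = load_data(config)             # raises ValueError if no rows match the filters
plot_data(df, config)              # writes visualizations/cross_entropy_speed.png

The CLI path is equivalent: parse_args() builds the same VisualizationsConfig from --kernel-name, --metric-name, and --kernel-operation-mode.
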
/benchmark/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linkedin/Liger-Kernel/0bb6c72d03f0fc570b9e954c42b723ab7e0c315f/benchmark/scripts/__init__.py
--------------------------------------------------------------------------------
/benchmark/scripts/benchmark_cpo_loss.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import torch
5 | import triton
6 | from utils import (
7 | QUANTILES,
8 | SingleBenchmarkRunInput,
9 | SingleBenchmarkRunOutput,
10 | _test_memory,
11 | parse_benchmark_script_args,
12 | run_benchmarks,
13 | )
14 |
15 | from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
16 | from liger_kernel.utils import infer_device
17 |
18 | device = infer_device()
19 |
20 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
21 |
22 |
23 | class TorchLMHeadCPO(torch.nn.Module):
24 |     """Ground truth implementation: a plain linear layer followed by the torch-based (HF) CPO loss.
25 |
26 | :param H: hidden size
27 | :param V: vocab size
28 | :param ignore_index: index to ignore
29 | :param reduction: reduction method
30 | """
31 |
32 | def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
33 | from test.chunked_loss.test_cpo_loss import HFCPOLoss
34 |
35 | super().__init__()
36 | self.lin = torch.nn.Linear(
37 | in_features=H, out_features=V, bias=False, dtype=dtype
38 | )
39 | self.cpo_loss = HFCPOLoss().get_batch_loss_metrics
40 |
41 | def forward(self, x, y):
42 | return self.cpo_loss(x, self.lin.weight, y)
43 |
44 |
45 | class LigerLMHeadCPO(torch.nn.Module):
46 | def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
47 | super().__init__()
48 | self.lin = torch.nn.Linear(
49 | in_features=H, out_features=V, bias=False, dtype=dtype
50 | )
51 | self.cpo_loss = LigerFusedLinearCPOFunction.apply
52 |
53 | def forward(self, x, y):
54 | return self.cpo_loss(x, self.lin.weight, y)
55 |
56 |
57 | #############################################################################
58 | # Test the memory consumption of the fused linear CPO loss
59 | #############################################################################
60 |
61 |
62 | def bench_memory_fused_linear_cpo_loss(
63 | input: SingleBenchmarkRunInput,
64 | ) -> SingleBenchmarkRunOutput:
65 | B = input.x
66 | T = input.extra_benchmark_config["T"]
67 | H = input.extra_benchmark_config["H"]
68 | V = input.extra_benchmark_config["V"]
69 | dtype = input.extra_benchmark_config["dtype"]
70 | provider = input.kernel_provider
71 |
72 | torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
73 | liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
74 |
75 | _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
76 | target = torch.randint(V, (B, T), dtype=torch.long, device=device)
77 |
78 | def fwd():
79 | if provider == "liger":
80 | return liger_lm_head_cpo(_input, target)
81 | elif provider == "huggingface":
82 | return torch_lm_head_cpo(_input, target)
83 |
84 | def full():
85 | y = fwd()
86 | y.backward()
87 |
88 | mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
89 | return SingleBenchmarkRunOutput(
90 | y_20=mem_20,
91 | y_50=mem_50,
92 | y_80=mem_80,
93 | )
94 |
95 |
96 | # #############################################################################
97 | # # Test the speed of the fused linear CPO loss
98 | # #############################################################################
99 |
100 |
101 | def bench_speed_fused_linear_cpo_loss(
102 | input: SingleBenchmarkRunInput,
103 | ) -> SingleBenchmarkRunOutput:
104 | B = input.x
105 | T = input.extra_benchmark_config["T"]
106 | H = input.extra_benchmark_config["H"]
107 | V = input.extra_benchmark_config["V"]
108 | dtype = input.extra_benchmark_config["dtype"]
109 | provider = input.kernel_provider
110 | mode = input.kernel_operation_mode
111 |
112 | torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
113 | liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
114 |
115 | _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
116 | target = torch.randint(V, (B, T), dtype=torch.long, device=device)
117 |
118 | def fwd():
119 | if provider == "liger":
120 | return liger_lm_head_cpo(_input, target)
121 | elif provider == "huggingface":
122 | return torch_lm_head_cpo(_input, target)
123 |
124 | if mode == "forward":
125 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
126 | fwd,
127 | rep=100,
128 | quantiles=QUANTILES,
129 | )
130 | elif mode == "backward":
131 | y = fwd()
132 |
133 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
134 | lambda: y.backward(retain_graph=True),
135 | grad_to_none=[_input],
136 | rep=100,
137 | quantiles=QUANTILES,
138 | )
139 | elif mode == "full":
140 |
141 | def full():
142 | y = fwd()
143 | y.backward()
144 |
145 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
146 | full,
147 | rep=100,
148 | quantiles=QUANTILES,
149 | )
150 | return SingleBenchmarkRunOutput(
151 | y_20=ms_20,
152 | y_50=ms_50,
153 | y_80=ms_80,
154 | )
155 |
156 |
157 | if __name__ == "__main__":
158 | args = parse_benchmark_script_args()
159 |
160 | common_configs = {
161 | "kernel_name": "fused_linear_cpo_loss",
162 | "x_name": "B",
163 |         "x_label": "batch size (B)",
164 | "x_values": [2**i for i in range(1, 5)],
165 | "kernel_providers": ["liger", "huggingface"],
166 | "extra_benchmark_configs": [
167 | {
168 | "T": 1024,
169 | "H": 4096,
170 | "V": 128256,
171 | "mode": "forward",
172 | "dtype": torch.bfloat16,
173 | }
174 | ],
175 | "overwrite": args.overwrite,
176 | }
177 |
178 | run_benchmarks(
179 | bench_test_fn=bench_speed_fused_linear_cpo_loss,
180 | kernel_operation_modes=["forward", "full"],
181 | metric_name="speed",
182 | metric_unit="ms",
183 | **common_configs
184 | )
185 | run_benchmarks(
186 | bench_test_fn=bench_memory_fused_linear_cpo_loss,
187 | kernel_operation_modes=["full"],
188 | metric_name="memory",
189 | metric_unit="MB",
190 | **common_configs
191 | )
192 |
--------------------------------------------------------------------------------
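
All benchmark scripts in this directory follow the same harness pattern: a bench_speed_*/bench_memory_* function receives a SingleBenchmarkRunInput (the swept x value plus provider, operation mode, and extra config) and returns a SingleBenchmarkRunOutput holding 20th/50th/80th percentile measurements, while run_benchmarks sweeps x_values across providers and modes and records the results for benchmarks_visualizer.py to read back. A minimal sketch of calling one benchmark function directly, assuming it is run from benchmark/scripts and that SingleBenchmarkRunInput is a dataclass whose fields match the attributes read above (see utils.py for the exact signature):

# Sketch: invoke a single benchmark point by hand, bypassing run_benchmarks.
import torch
from utils import SingleBenchmarkRunInput
from benchmark_cpo_loss import bench_speed_fused_linear_cpo_loss

single_input = SingleBenchmarkRunInput(
    x=4,                                  # B, the batch size being swept
    kernel_provider="liger",              # or "huggingface"
    kernel_operation_mode="forward",
    extra_benchmark_config={
        "T": 1024, "H": 4096, "V": 128256, "dtype": torch.bfloat16,
    },
)
out = bench_speed_fused_linear_cpo_loss(single_input)
print(out.y_20, out.y_50, out.y_80)       # 20th/50th/80th percentile latency in ms
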
/benchmark/scripts/benchmark_cross_entropy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | from torch.nn import CrossEntropyLoss
4 | from utils import (
5 | QUANTILES,
6 | SingleBenchmarkRunInput,
7 | SingleBenchmarkRunOutput,
8 | _test_memory,
9 | parse_benchmark_script_args,
10 | run_benchmarks,
11 | )
12 |
13 | from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
14 | from liger_kernel.utils import infer_device
15 |
16 | device = infer_device()
17 |
18 |
19 | def bench_memory_cross_entropy(
20 | input: SingleBenchmarkRunInput,
21 | ) -> SingleBenchmarkRunOutput:
22 | torch_ce = CrossEntropyLoss()
23 | liger_ce = LigerCrossEntropyLoss()
24 |
25 | V = input.x
26 | provider = input.kernel_provider
27 | B = input.extra_benchmark_config["B"]
28 | T = input.extra_benchmark_config["T"]
29 |
30 | _input = torch.randn(B * T, V, requires_grad=True, device=device)
31 | target = torch.randint(V, (B * T, 1), device=device).squeeze(1)
32 |
33 | def fwd():
34 | if provider == "liger":
35 | return liger_ce(_input, target)
36 | else:
37 | return torch_ce(_input, target)
38 |
39 | def full():
40 | y = fwd()
41 | y.backward()
42 |
43 | mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
44 | return SingleBenchmarkRunOutput(
45 | y_20=mem_20,
46 | y_50=mem_50,
47 | y_80=mem_80,
48 | )
49 |
50 |
51 | def bench_speed_cross_entropy(
52 | input: SingleBenchmarkRunInput,
53 | ) -> SingleBenchmarkRunOutput:
54 | torch_ce = CrossEntropyLoss()
55 | liger_ce = LigerCrossEntropyLoss()
56 |
57 | V = input.x
58 | provider = input.kernel_provider
59 | mode = input.kernel_operation_mode
60 | B = input.extra_benchmark_config["B"]
61 | T = input.extra_benchmark_config["T"]
62 |
63 | _input = torch.randn(B * T, V, requires_grad=True, device=device)
64 | target = torch.randint(V, (B * T, 1), device=device).squeeze(1)
65 |
66 | def fwd():
67 | if provider == "liger":
68 | return liger_ce(_input, target)
69 | else:
70 | return torch_ce(_input, target)
71 |
72 | if mode == "forward":
73 | ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
74 | elif mode == "backward":
75 | y = fwd()
76 |
77 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
78 | lambda: y.backward(retain_graph=True),
79 | grad_to_none=[_input],
80 | rep=100,
81 | quantiles=QUANTILES,
82 | )
83 | elif mode == "full":
84 |
85 | def full():
86 | y = fwd()
87 | y.backward()
88 |
89 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
90 | full, rep=100, quantiles=QUANTILES
91 | )
92 |
93 | return SingleBenchmarkRunOutput(
94 | y_20=ms_20,
95 | y_50=ms_50,
96 | y_80=ms_80,
97 | )
98 |
99 |
100 | if __name__ == "__main__":
101 | args = parse_benchmark_script_args()
102 |
103 | common_configs = {
104 | "kernel_name": "cross_entropy",
105 | "x_name": "V",
106 | "x_label": "vocab size",
107 | "x_values": [2**i for i in range(12, 18)],
108 | "kernel_providers": ["liger", "huggingface"],
109 | "extra_benchmark_configs": [{"B": 8, "T": 2048}],
110 | "overwrite": args.overwrite,
111 | }
112 |
113 | run_benchmarks(
114 | bench_test_fn=bench_speed_cross_entropy,
115 | kernel_operation_modes=["forward", "full"],
116 | metric_name="speed",
117 | metric_unit="ms",
118 | **common_configs
119 | )
120 | run_benchmarks(
121 | bench_test_fn=bench_memory_cross_entropy,
122 | kernel_operation_modes=["full"],
123 | metric_name="memory",
124 | metric_unit="MB",
125 | **common_configs
126 | )
127 |
--------------------------------------------------------------------------------
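
Before timing, it can help to sanity-check that the Liger kernel matches the torch baseline numerically on the same inputs. A minimal sketch, assuming a GPU-backed device is available for the Triton kernel and reusing the call convention of the benchmark above:

# Sketch: quick numerical agreement check between the two providers.
import torch
from torch.nn import CrossEntropyLoss
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.utils import infer_device

device = infer_device()
B, T, V = 2, 128, 32000
_input = torch.randn(B * T, V, device=device)
target = torch.randint(V, (B * T,), device=device)

# Both losses take (logits, target) and reduce to a scalar mean by default.
ref = CrossEntropyLoss()(_input, target)
out = LigerCrossEntropyLoss()(_input, target)
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)
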
/benchmark/scripts/benchmark_dpo_loss.py:
--------------------------------------------------------------------------------
1 | from test.chunked_loss.test_dpo_loss import HF_DPO_Loss
2 |
3 | import torch
4 | import triton
5 | from utils import (
6 | QUANTILES,
7 | SingleBenchmarkRunInput,
8 | SingleBenchmarkRunOutput,
9 | _test_memory,
10 | parse_benchmark_script_args,
11 | run_benchmarks,
12 | )
13 |
14 | from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
15 | from liger_kernel.utils import infer_device
16 |
17 | device = infer_device()
18 |
19 |
20 | class TorchDPOLoss(torch.nn.Module):
21 | def __init__(
22 | self,
23 | H: int,
24 | V: int,
25 | dtype: torch.dtype,
26 | beta: float = 0.1,
27 | ignore_index: int = -100,
28 | bias: bool = False,
29 | ):
30 | super().__init__()
31 | self.lin = torch.nn.Linear(
32 | in_features=H, out_features=V, bias=bias, dtype=dtype
33 | )
34 | self.dpo_loss = HF_DPO_Loss(beta=beta, ignore_index=ignore_index)
35 |
36 | def forward(self, x, target):
37 | return self.dpo_loss.get_batch_loss_metrics(
38 | x,
39 | self.lin.weight,
40 | target,
41 | self.lin.bias if hasattr(self.lin, "bias") else None,
42 | )
43 |
44 |
45 | class LigerDPOLoss(torch.nn.Module):
46 | def __init__(
47 | self,
48 | H: int,
49 | V: int,
50 | dtype: torch.dtype,
51 | beta: float = 0.1,
52 | ignore_index: int = -100,
53 | bias: bool = False,
54 | ):
55 | super().__init__()
56 | self.lin = torch.nn.Linear(
57 | in_features=H, out_features=V, bias=bias, dtype=dtype
58 | )
59 | self.beta = beta
60 | self.ignore_index = ignore_index
61 |
62 | def forward(self, x, target):
63 | return LigerFusedLinearDPOFunction.apply(
64 | x,
65 | self.lin.weight,
66 | target,
67 | self.lin.bias if hasattr(self.lin, "bias") else None,
68 | self.ignore_index,
69 | self.beta,
70 | True,
71 | )
72 |
73 |
74 | def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
75 | B = input.x
76 | T = input.extra_benchmark_config["T"]
77 | H = input.extra_benchmark_config["H"]
78 | V = input.extra_benchmark_config["V"]
79 | dtype = input.extra_benchmark_config["dtype"]
80 | bias = input.extra_benchmark_config["bias"]
81 | beta = input.extra_benchmark_config["beta"]
82 | ignore_index = input.extra_benchmark_config["ignore_index"]
83 | provider = input.kernel_provider
84 |
85 | torch_dpo_loss = TorchDPOLoss(
86 | H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
87 | ).to(device)
88 | liger_dpo_loss = LigerDPOLoss(
89 | H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
90 | ).to(device)
91 |
92 | # Input shape: [B, T, H]
93 | _input = torch.randn(B, T, H, device=device, dtype=dtype)
94 | # Target shape: [B, T]
95 | target = torch.randint(V, (B, T), dtype=torch.long, device=device)
96 |
97 | # Add ignore_index tokens to simulate padding
98 | num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
99 | indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
100 | target.view(-1)[indices_to_assign] = ignore_index
101 |
102 | def fwd():
103 | if provider == "liger":
104 | return liger_dpo_loss(_input, target)
105 | elif provider == "huggingface":
106 | return torch_dpo_loss(_input, target)
107 |
108 | def full():
109 | y = fwd()
110 | y.backward()
111 |
112 | mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
113 | return SingleBenchmarkRunOutput(
114 | y_20=mem_20,
115 | y_50=mem_50,
116 | y_80=mem_80,
117 | )
118 |
119 |
120 | def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
121 | B = input.x
122 | T = input.extra_benchmark_config["T"]
123 | H = input.extra_benchmark_config["H"]
124 | V = input.extra_benchmark_config["V"]
125 | dtype = input.extra_benchmark_config["dtype"]
126 | bias = input.extra_benchmark_config["bias"]
127 | beta = input.extra_benchmark_config["beta"]
128 | ignore_index = input.extra_benchmark_config["ignore_index"]
129 | provider = input.kernel_provider
130 | mode = input.kernel_operation_mode
131 |
132 | torch_dpo_loss = TorchDPOLoss(
133 | H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
134 | ).to(device)
135 | liger_dpo_loss = LigerDPOLoss(
136 | H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
137 | ).to(device)
138 |
139 | # Input shape: [B, T, H]
140 | _input = torch.randn(B, T, H, device=device, dtype=dtype)
141 |
142 | # Target shape: [B, T]
143 | target = torch.randint(V, (B, T), device=device, dtype=torch.long)
144 |
145 | # Add ignore_index tokens
146 | num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
147 | indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
148 | target.view(-1)[indices_to_assign] = ignore_index
149 |
150 | def fwd():
151 | if provider == "liger":
152 | return liger_dpo_loss(_input, target)
153 | elif provider == "huggingface":
154 | return torch_dpo_loss(_input, target)
155 |
156 | if mode == "forward":
157 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
158 | fwd,
159 | rep=100,
160 | quantiles=QUANTILES,
161 | )
162 | elif mode == "backward":
163 | y = fwd()
164 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
165 | lambda: y.backward(retain_graph=True),
166 | grad_to_none=[_input],
167 | rep=100,
168 | quantiles=QUANTILES,
169 | )
170 | elif mode == "full":
171 |
172 | def full():
173 | y = fwd()
174 | y.backward()
175 |
176 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
177 | full,
178 | rep=100,
179 | quantiles=QUANTILES,
180 | )
181 |
182 | return SingleBenchmarkRunOutput(
183 | y_20=ms_20,
184 | y_50=ms_50,
185 | y_80=ms_80,
186 | )
187 |
188 |
189 | if __name__ == "__main__":
190 | args = parse_benchmark_script_args()
191 |
192 | common_configs = {
193 | "kernel_name": "dpo_loss",
194 | "x_name": "B",
195 | "x_label": "Batch Size (B)",
196 | "x_values": [2**i for i in range(1, 6)],
197 | "kernel_providers": ["liger", "huggingface"],
198 | "extra_benchmark_configs": [
199 | {
200 | "T": 512,
201 | "H": 1024,
202 | "V": 128256,
203 | "mode": "forward",
204 | "dtype": torch.bfloat16,
205 | "bias": True,
206 | "beta": 0.1,
207 | "ignore_index": 42,
208 | }
209 | ],
210 | "overwrite": args.overwrite,
211 | }
212 |
213 | run_benchmarks(
214 | bench_test_fn=bench_speed_dpo_loss,
215 | kernel_operation_modes=["forward", "full"],
216 | metric_name="speed",
217 | metric_unit="ms",
218 | **common_configs
219 | )
220 |
221 | run_benchmarks(
222 | bench_test_fn=bench_memory_dpo_loss,
223 | kernel_operation_modes=["full"],
224 | metric_name="memory",
225 | metric_unit="MB",
226 | **common_configs
227 | )
228 |
--------------------------------------------------------------------------------
/benchmark/scripts/benchmark_embedding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | from torch.nn import Embedding
4 | from utils import (
5 | QUANTILES,
6 | SingleBenchmarkRunInput,
7 | SingleBenchmarkRunOutput,
8 | _test_memory,
9 | parse_benchmark_script_args,
10 | run_benchmarks,
11 | )
12 |
13 | from liger_kernel.transformers.experimental.embedding import LigerEmbedding
14 | from liger_kernel.utils import infer_device
15 |
16 | device = infer_device()
17 |
18 | # NOTE: For torch compile, we will just use default inductor settings. No further customization
19 | # is needed.
20 |
21 |
22 | def bench_speed_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
23 | V = input.x
24 | provider = input.kernel_provider
25 | mode = input.kernel_operation_mode
26 |
27 | B = input.extra_benchmark_config["B"]
28 | T = input.extra_benchmark_config["T"]
29 | D = input.extra_benchmark_config["D"]
30 | dtype = input.extra_benchmark_config["dtype"]
31 |
32 | torch_emb = Embedding(V, D).to(device).to(dtype)
33 | liger_emb = LigerEmbedding(V, D).to(device).to(dtype)
34 | torch_compile_emb = torch.compile(torch_emb)
35 |
36 | input_ids = torch.randint(0, V, (B, T), device=device)
37 |
38 | def fwd():
39 | if provider == "liger":
40 | return liger_emb(input_ids)
41 | elif provider == "torch_compile":
42 | return torch_compile_emb(input_ids)
43 | else:
44 | return torch_emb(input_ids)
45 |
46 | def full():
47 | output = fwd()
48 | output.backward(torch.randn_like(output))
49 |
50 | if mode == "forward":
51 | ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
52 | elif mode == "full":
53 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
54 | full, quantiles=QUANTILES, rep=100
55 | )
56 | return SingleBenchmarkRunOutput(
57 | y_20=ms_20,
58 | y_50=ms_50,
59 | y_80=ms_80,
60 | )
61 |
62 |
63 | def bench_memory_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
64 | V = input.x
65 | provider = input.kernel_provider
66 |
67 | B = input.extra_benchmark_config["B"]
68 | T = input.extra_benchmark_config["T"]
69 | D = input.extra_benchmark_config["D"]
70 | dtype = input.extra_benchmark_config["dtype"]
71 |
72 | torch_emb = Embedding(V, D).to(device).to(dtype)
73 | liger_emb = LigerEmbedding(V, D).to(device).to(dtype)
74 | torch_compile_emb = torch.compile(torch_emb)
75 |
76 | input_ids = torch.randint(0, V, (B, T), device=device)
77 |
78 | def fwd():
79 | if provider == "liger":
80 | return liger_emb(input_ids)
81 | elif provider == "torch_compile":
82 | return torch_compile_emb(input_ids)
83 | else:
84 | return torch_emb(input_ids)
85 |
86 | def full():
87 | output = fwd()
88 | output.backward(torch.randn_like(output))
89 |
90 | mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
91 | return SingleBenchmarkRunOutput(
92 | y_20=mem_20,
93 | y_50=mem_50,
94 | y_80=mem_80,
95 | )
96 |
97 |
98 | if __name__ == "__main__":
99 | args = parse_benchmark_script_args()
100 |
101 | common_configs = {
102 | "kernel_name": "embedding",
103 | "x_name": "V",
104 |         "x_label": "vocab size",
105 | "x_values": [2**i for i in range(10, 18)],
106 | "kernel_providers": ["liger", "huggingface", "torch_compile"],
107 | "extra_benchmark_configs": [
108 | # BERT
109 | {"B": 32, "T": 512, "D": 768, "dtype": torch.float32},
110 | # Llama
111 | {"B": 8, "T": 2048, "D": 4096, "dtype": torch.float32},
112 | ],
113 | "overwrite": args.overwrite,
114 | }
115 |
116 | run_benchmarks(
117 | bench_test_fn=bench_speed_embedding,
118 | kernel_operation_modes=["forward", "full"],
119 | metric_name="speed",
120 | metric_unit="ms",
121 | **common_configs
122 | )
123 | run_benchmarks(
124 | bench_test_fn=bench_memory_embedding,
125 | kernel_operation_modes=["full"],
126 | metric_name="memory",
127 | metric_unit="MB",
128 | **common_configs
129 | )
130 |
--------------------------------------------------------------------------------
/benchmark/scripts/benchmark_fused_linear_cross_entropy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | from utils import (
4 | QUANTILES,
5 | SingleBenchmarkRunInput,
6 | SingleBenchmarkRunOutput,
7 | _test_memory,
8 | parse_benchmark_script_args,
9 | run_benchmarks,
10 | )
11 |
12 | from liger_kernel.transformers.fused_linear_cross_entropy import (
13 | LigerFusedLinearCrossEntropyLoss,
14 | )
15 | from liger_kernel.utils import infer_device
16 |
17 | device = infer_device()
18 |
19 |
20 | class TorchLMHeadCE(torch.nn.Module):
21 |     """Ground truth implementation: a plain linear layer followed by the torch-based cross entropy loss.
22 |
23 | :param H: hidden size
24 | :param V: vocab size
25 | :param ignore_index: index to ignore
26 | :param reduction: reduction method
27 | """
28 |
29 | def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
30 | super().__init__()
31 | self.lin = torch.nn.Linear(
32 | in_features=H, out_features=V, bias=False, dtype=dtype
33 | )
34 | self.ce_loss = torch.nn.CrossEntropyLoss(
35 | ignore_index=ignore_index, reduction="mean"
36 | )
37 |
38 | def forward(self, x, y):
39 | logits = self.lin(x)
40 | return self.ce_loss(logits, y)
41 |
42 |
43 | class LigerLMHeadCE(torch.nn.Module):
44 | def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
45 | super().__init__()
46 | self.lin = torch.nn.Linear(
47 | in_features=H, out_features=V, bias=False, dtype=dtype
48 | )
49 | self.ce_loss = LigerFusedLinearCrossEntropyLoss(
50 | ignore_index=ignore_index, reduction="mean"
51 | )
52 |
53 | def forward(self, x, y):
54 | return self.ce_loss(self.lin.weight, x, y)
55 |
56 |
57 | #############################################################################
58 | # Test the memory consumption of the linear fused cross entropy loss
59 | #############################################################################
60 |
61 |
62 | def bench_memory_fused_linear_cross_entropy(
63 | input: SingleBenchmarkRunInput,
64 | ) -> SingleBenchmarkRunOutput:
65 | BT = input.x
66 | H = input.extra_benchmark_config["H"]
67 | V = input.extra_benchmark_config["V"]
68 | dtype = input.extra_benchmark_config["dtype"]
69 | provider = input.kernel_provider
70 |
71 | torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
72 | liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
73 |
74 | _input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device)
75 | target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1)
76 |
77 | def fwd():
78 | if provider == "liger":
79 | return liger_lm_head_ce(_input, target)
80 | elif provider == "huggingface":
81 | return torch_lm_head_ce(_input, target)
82 |
83 | def full():
84 | y = fwd()
85 | y.backward()
86 |
87 | mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
88 | return SingleBenchmarkRunOutput(
89 | y_20=mem_20,
90 | y_50=mem_50,
91 | y_80=mem_80,
92 | )
93 |
94 |
95 | # #############################################################################
96 | # # Test the speed of the fused linear cross entropy loss
97 | # #############################################################################
98 |
99 |
100 | def bench_speed_fused_linear_cross_entropy(
101 | input: SingleBenchmarkRunInput,
102 | ) -> SingleBenchmarkRunOutput:
103 | BT = input.x
104 | H = input.extra_benchmark_config["H"]
105 | V = input.extra_benchmark_config["V"]
106 | dtype = input.extra_benchmark_config["dtype"]
107 | provider = input.kernel_provider
108 | mode = input.kernel_operation_mode
109 |
110 | torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
111 | liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
112 |
113 | _input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device)
114 | target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1)
115 |
116 | def fwd():
117 | if provider == "liger":
118 | return liger_lm_head_ce(_input, target)
119 | elif provider == "huggingface":
120 | return torch_lm_head_ce(_input, target)
121 |
122 | if mode == "forward":
123 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
124 | fwd,
125 | rep=100,
126 | quantiles=QUANTILES,
127 | )
128 | elif mode == "backward":
129 | y = fwd()
130 |
131 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
132 | lambda: y.backward(retain_graph=True),
133 | grad_to_none=[_input],
134 | rep=100,
135 | quantiles=QUANTILES,
136 | )
137 | elif mode == "full":
138 |
139 | def full():
140 | y = fwd()
141 | y.backward()
142 |
143 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
144 | full,
145 | rep=100,
146 | quantiles=QUANTILES,
147 | )
148 | return SingleBenchmarkRunOutput(
149 | y_20=ms_20,
150 | y_50=ms_50,
151 | y_80=ms_80,
152 | )
153 |
154 |
155 | if __name__ == "__main__":
156 | args = parse_benchmark_script_args()
157 |
158 | common_configs = {
159 | "kernel_name": "fused_linear_cross_entropy",
160 | "x_name": "BT",
161 | "x_label": "B x T",
162 | "x_values": [2**i for i in range(12, 16)],
163 | "kernel_providers": ["liger", "huggingface"],
164 | "extra_benchmark_configs": [
165 | {"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}
166 | ],
167 | "overwrite": args.overwrite,
168 | }
169 |
170 | run_benchmarks(
171 | bench_test_fn=bench_speed_fused_linear_cross_entropy,
172 | kernel_operation_modes=["forward", "full"],
173 | metric_name="speed",
174 | metric_unit="ms",
175 | **common_configs
176 | )
177 | run_benchmarks(
178 | bench_test_fn=bench_memory_fused_linear_cross_entropy,
179 | kernel_operation_modes=["full"],
180 | metric_name="memory",
181 | metric_unit="MB",
182 | **common_configs
183 | )
184 |
--------------------------------------------------------------------------------
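
The memory comparison above is mostly about the logits: the unfused baseline has to materialize a (BT, V) logits tensor (plus its gradient during the backward pass) before computing the loss, while the fused, chunked kernel avoids holding the full tensor at once. A rough back-of-the-envelope sketch for the largest configuration swept above:

# Sketch: size of the logits tensor the unfused baseline materializes
# for the config above (V=128256, bfloat16 = 2 bytes per element).
BT = 2**15          # largest B x T value swept above
V = 128256
bytes_per_elem = 2  # bfloat16
logits_bytes = BT * V * bytes_per_elem
print(f"{logits_bytes / 2**30:.1f} GiB just for the forward logits")  # ~7.8 GiB
# and roughly the same again for the logits gradient in the backward pass
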
/benchmark/scripts/benchmark_geglu.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | from transformers.models.llama.configuration_llama import LlamaConfig
4 | from transformers.models.llama.modeling_llama import LlamaMLP
5 | from utils import (
6 | QUANTILES,
7 | SingleBenchmarkRunInput,
8 | SingleBenchmarkRunOutput,
9 | _test_memory,
10 | parse_benchmark_script_args,
11 | run_benchmarks,
12 | )
13 |
14 | from liger_kernel.transformers.geglu import LigerGEGLUMLP
15 | from liger_kernel.utils import infer_device
16 |
17 | device = infer_device()
18 |
19 |
20 | def bench_speed_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
21 | seq_len = input.x
22 | bsz = input.extra_benchmark_config["bsz"]
23 | hidden_size = input.extra_benchmark_config["hidden_size"]
24 | intermediate_size = input.extra_benchmark_config["intermediate_size"]
25 | hidden_act = input.extra_benchmark_config["hidden_act"]
26 | dtype = input.extra_benchmark_config["dtype"]
27 | provider = input.kernel_provider
28 | mode = input.kernel_operation_mode
29 |
30 | llama_config = LlamaConfig(
31 | hidden_size=hidden_size,
32 | intermediate_size=intermediate_size,
33 | hidden_act=hidden_act,
34 | )
35 |
36 | x_shape = (bsz, seq_len, hidden_size)
37 |
38 | # initialize input
39 | x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True)
40 |
41 | if provider == "liger":
42 | layer = LigerGEGLUMLP(config=llama_config).to(device).to(dtype)
43 | elif provider == "huggingface":
44 | layer = LlamaMLP(config=llama_config).to(device).to(dtype)
45 | else:
46 | raise ValueError(f"Invalid provider: {provider} for GEGLU")
47 |
48 | def fwd():
49 | return layer(x)
50 |
51 | if mode == "forward":
52 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
53 | fwd,
54 | grad_to_none=[x],
55 | rep=10,
56 | quantiles=QUANTILES,
57 | )
58 | elif mode == "backward":
59 | do = torch.randn_like(x)
60 | y = fwd()
61 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
62 | lambda: y.backward(do, retain_graph=True),
63 | grad_to_none=[x],
64 | rep=10,
65 | quantiles=QUANTILES,
66 | )
67 | else:
68 |
69 | def full():
70 | y = fwd()
71 | y.backward(torch.randn_like(y), retain_graph=True)
72 |
73 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
74 | full,
75 | grad_to_none=[x],
76 | rep=10,
77 | quantiles=QUANTILES,
78 | )
79 |
80 | return SingleBenchmarkRunOutput(
81 | y_20=ms_20,
82 | y_50=ms_50,
83 | y_80=ms_80,
84 | )
85 |
86 |
87 | def bench_memory_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
88 | seq_len = input.x
89 | bsz = input.extra_benchmark_config["bsz"]
90 | hidden_size = input.extra_benchmark_config["hidden_size"]
91 | intermediate_size = input.extra_benchmark_config["intermediate_size"]
92 | hidden_act = input.extra_benchmark_config["hidden_act"]
93 | dtype = input.extra_benchmark_config["dtype"]
94 | provider = input.kernel_provider
95 | mode = input.kernel_operation_mode
96 |
97 | llama_config = LlamaConfig(
98 | hidden_size=hidden_size,
99 | intermediate_size=intermediate_size,
100 | hidden_act=hidden_act,
101 | )
102 |
103 | x_shape = (bsz, seq_len, hidden_size)
104 | # initialize input
105 | x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True)
106 |
107 | if provider == "liger":
108 | layer = LigerGEGLUMLP(config=llama_config).to(device).to(dtype)
109 | elif provider == "huggingface":
110 | layer = LlamaMLP(config=llama_config).to(device).to(dtype)
111 | else:
112 | raise ValueError(f"Invalid provider: {provider} for GEGLU")
113 |
114 | def fwd():
115 | return layer(x)
116 |
117 | def full():
118 | y = fwd()
119 | y.backward(torch.randn_like(y), retain_graph=True)
120 |
121 | if mode == "forward":
122 | mem_50, mem_20, mem_80 = _test_memory(
123 | fwd,
124 | quantiles=QUANTILES,
125 | )
126 | elif mode == "backward":
127 | do = torch.randn_like(x)
128 | y = fwd()
129 | mem_50, mem_20, mem_80 = _test_memory(
130 | lambda: y.backward(do, retain_graph=True),
131 | quantiles=QUANTILES,
132 | )
133 | else:
134 | mem_50, mem_20, mem_80 = _test_memory(
135 | full,
136 | quantiles=QUANTILES,
137 | )
138 |
139 | return SingleBenchmarkRunOutput(
140 | y_20=mem_20,
141 | y_50=mem_50,
142 | y_80=mem_80,
143 | )
144 |
145 |
146 | if __name__ == "__main__":
147 | args = parse_benchmark_script_args()
148 |
149 | common_configs = {
150 | "kernel_name": "geglu",
151 | "x_name": "T",
152 | "x_label": "sequence length",
153 | "x_values": [2**i for i in range(10, 14)],
154 | "kernel_providers": ["liger", "huggingface"],
155 | "extra_benchmark_configs": [
156 | {
157 | "bsz": 8,
158 | "hidden_size": 4096,
159 | "intermediate_size": 11008,
160 | "hidden_act": "gelu_pytorch_tanh",
161 | "dtype": torch.bfloat16,
162 | }
163 | ],
164 | "overwrite": args.overwrite,
165 | }
166 |
167 | run_benchmarks(
168 | bench_test_fn=bench_speed_geglu,
169 | kernel_operation_modes=["full", "forward", "backward"],
170 | metric_name="speed",
171 | metric_unit="ms",
172 | **common_configs,
173 | )
174 | run_benchmarks(
175 | bench_test_fn=bench_memory_geglu,
176 | kernel_operation_modes=["full", "forward", "backward"],
177 | metric_name="memory",
178 | metric_unit="MB",
179 | **common_configs,
180 | )
181 |
--------------------------------------------------------------------------------
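
For reference, both providers implement the same gated (GEGLU-style) MLP: the Hugging Face LlamaMLP baseline computes down_proj(act(gate_proj(x)) * up_proj(x)) with the tanh-approximated GELU configured above, and LigerGEGLUMLP fuses the elementwise activation-and-multiply while the projections stay as matmuls. A minimal unfused reference sketch, assuming the projections are ordinary nn.Linear layers:

# Sketch: unfused reference of the gated MLP both providers compute.
import torch.nn.functional as F

def geglu_mlp_reference(x, gate_proj, up_proj, down_proj):
    # gate_proj/up_proj: hidden_size -> intermediate_size; down_proj: the reverse
    gate = F.gelu(gate_proj(x), approximate="tanh")  # "gelu_pytorch_tanh"
    return down_proj(gate * up_proj(x))
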
/benchmark/scripts/benchmark_group_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | from utils import (
4 | QUANTILES,
5 | SingleBenchmarkRunInput,
6 | SingleBenchmarkRunOutput,
7 | _test_memory,
8 | parse_benchmark_script_args,
9 | run_benchmarks,
10 | )
11 |
12 | from liger_kernel.transformers.group_norm import LigerGroupNorm
13 | from liger_kernel.utils import infer_device
14 |
15 | device = infer_device()
16 |
17 |
18 | def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
19 | C = input.x
20 | provider = input.kernel_provider
21 | mode = input.kernel_operation_mode
22 | extra_benchmark_config = input.extra_benchmark_config
23 | M = extra_benchmark_config["M"]
24 | H = extra_benchmark_config["H"]
25 | channels_per_group = extra_benchmark_config["channels_per_group"]
26 | eps = extra_benchmark_config["eps"]
27 | dtype = extra_benchmark_config["dtype"]
28 |
29 | x_shape = (M, C, H)
30 | triton_ln = LigerGroupNorm(
31 | num_channels=C, num_groups=C // channels_per_group, eps=eps
32 | ).to(device)
33 | torch_ln = torch.nn.GroupNorm(
34 | num_groups=C // channels_per_group, num_channels=C, eps=eps
35 | ).to(device)
36 |
37 | x = torch.randn(x_shape, dtype=dtype, device=device)
38 | dy = torch.randn_like(x)
39 | x.requires_grad_(True)
40 |
41 | def y_fwd():
42 | if provider == "liger":
43 | return triton_ln(x)
44 | if provider == "huggingface":
45 | return torch_ln(x)
46 |
47 | if mode == "forward":
48 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
49 | y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500
50 | )
51 | elif mode == "backward":
52 | y = y_fwd()
53 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
54 | lambda: y.backward(dy, retain_graph=True),
55 | quantiles=QUANTILES,
56 | grad_to_none=[x],
57 | rep=500,
58 | )
59 | elif mode == "full":
60 |
61 | def full():
62 | y = y_fwd()
63 | y.backward(dy, retain_graph=True)
64 |
65 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
66 | full, quantiles=QUANTILES, grad_to_none=[x], rep=500
67 | )
68 |
69 | return SingleBenchmarkRunOutput(
70 | y_20=ms_20,
71 | y_50=ms_50,
72 | y_80=ms_80,
73 | )
74 |
75 |
76 | def bench_memory_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
77 | C = input.x
78 | provider = input.kernel_provider
79 | extra_benchmark_config = input.extra_benchmark_config
80 | M = extra_benchmark_config["M"]
81 | H = extra_benchmark_config["H"]
82 | channels_per_group = extra_benchmark_config["channels_per_group"]
83 | eps = extra_benchmark_config["eps"]
84 | dtype = extra_benchmark_config["dtype"]
85 |
86 | x_shape = (M, C, H)
87 | triton_ln = LigerGroupNorm(
88 | num_channels=C, num_groups=C // channels_per_group, eps=eps
89 | ).to(device)
90 | torch_ln = torch.nn.GroupNorm(
91 | num_groups=C // channels_per_group, num_channels=C, eps=eps
92 | ).to(device)
93 |
94 | x = torch.randn(x_shape, dtype=dtype, device=device)
95 | dy = torch.randn_like(x)
96 | x.requires_grad_(True)
97 |
98 | def y_fwd():
99 | if provider == "liger":
100 | return triton_ln(x)
101 | if provider == "huggingface":
102 | return torch_ln(x)
103 |
104 | def full():
105 | y = y_fwd()
106 | y.backward(dy, retain_graph=True)
107 |
108 | mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
109 | return SingleBenchmarkRunOutput(
110 | y_20=mem_20,
111 | y_50=mem_50,
112 | y_80=mem_80,
113 | )
114 |
115 |
116 | if __name__ == "__main__":
117 | args = parse_benchmark_script_args()
118 |
119 | common_configs = {
120 | "kernel_name": "group_norm",
121 | "x_name": "C",
122 | "x_label": "num_channels",
123 | "x_values": [2**i for i in range(5, 12)],
124 | "kernel_providers": ["liger", "huggingface"],
125 | "extra_benchmark_configs": [
126 | {
127 | "M": 128,
128 | "H": 512,
129 | "channels_per_group": 4,
130 | "dtype": torch.float32,
131 | "eps": 1e-6,
132 | }
133 | ],
134 | "overwrite": args.overwrite,
135 | }
136 |
137 | run_benchmarks(
138 | bench_test_fn=bench_speed_group_norm,
139 | kernel_operation_modes=["forward", "full", "backward"],
140 | metric_name="speed",
141 | metric_unit="ms",
142 | **common_configs
143 | )
144 | run_benchmarks(
145 | bench_test_fn=bench_memory_group_norm,
146 | kernel_operation_modes=["full", "forward", "backward"],
147 | metric_name="memory",
148 | metric_unit="MB",
149 | **common_configs
150 | )
151 |
--------------------------------------------------------------------------------
/benchmark/scripts/benchmark_jsd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | from utils import (
4 | QUANTILES,
5 | SingleBenchmarkRunInput,
6 | SingleBenchmarkRunOutput,
7 | _test_memory,
8 | parse_benchmark_script_args,
9 | run_benchmarks,
10 | )
11 |
12 | from liger_kernel.transformers.jsd import LigerJSD
13 | from liger_kernel.utils import infer_device
14 |
15 | device = infer_device()
16 |
17 |
18 | class TorchJSD(torch.nn.Module):
19 | def __init__(
20 | self,
21 | beta: float = 0.5,
22 | ignore_index: int = -100,
23 | dtype: torch.dtype = torch.float,
24 | ):
25 | super(TorchJSD, self).__init__()
26 | self.kl = torch.nn.KLDivLoss(reduction="none", log_target=True)
27 | self.beta = beta
28 | self.ignore_index = ignore_index
29 | self.dtype = dtype
30 |
31 | def forward(
32 | self,
33 | log_q: torch.Tensor, # input
34 | log_p: torch.Tensor, # target
35 | label=None,
36 | ):
37 | log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
38 | log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1))
39 | m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta)
40 | loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + (
41 | 1 - self.beta
42 | ) * self.kl(torch.log(m), log_q).sum(dim=-1)
43 |
44 | if label is not None:
45 | loss = torch.where(label != self.ignore_index, loss, 0.0)
46 | n_non_ignore = (label != self.ignore_index).sum().item()
47 | if n_non_ignore == 0:
48 | loss = 0.0
49 | else:
50 | loss = (loss / n_non_ignore).sum()
51 | else:
52 | loss = (loss / log_q.shape[0]).sum()
53 | return loss.to(self.dtype)
54 |
55 |
56 | def bench_speed_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
57 | V = input.x
58 | B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
59 | torch_jsd = TorchJSD()
60 | liger_jsd = LigerJSD()
61 |
62 | _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(
63 | dim=-1
64 | )
65 | target = torch.randn(B * T, V, device=device).log_softmax(dim=-1)
66 |
67 | def fwd():
68 | if input.kernel_provider == "liger":
69 | return liger_jsd(_input, target)
70 | else:
71 | return torch_jsd(_input, target)
72 |
73 | if input.kernel_operation_mode == "forward":
74 | ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
75 | elif input.kernel_operation_mode == "backward":
76 | y = fwd()
77 |
78 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
79 | lambda: y.backward(retain_graph=True),
80 | quantiles=QUANTILES,
81 | grad_to_none=[_input],
82 | rep=100,
83 | )
84 | elif input.kernel_operation_mode == "full":
85 |
86 | def full():
87 | y = fwd()
88 | y.backward(retain_graph=True)
89 |
90 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
91 | full, quantiles=QUANTILES, rep=100
92 | )
93 | return SingleBenchmarkRunOutput(
94 | y_20=ms_20,
95 | y_50=ms_50,
96 | y_80=ms_80,
97 | )
98 |
99 |
100 | def bench_memory_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
101 | torch_jsd = TorchJSD()
102 | liger_jsd = LigerJSD()
103 |
104 | V = input.x
105 | B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
106 |
107 | _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(
108 | dim=-1
109 | )
110 | target = torch.randn(B * T, V, device=device).log_softmax(dim=-1)
111 |
112 | def fwd():
113 | if input.kernel_provider == "liger":
114 | return liger_jsd(_input, target)
115 | else:
116 | return torch_jsd(_input, target)
117 |
118 | def full():
119 | y = fwd()
120 | y.backward(retain_graph=True)
121 |
122 | mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
123 |
124 | return SingleBenchmarkRunOutput(
125 | y_20=mem_20,
126 | y_50=mem_50,
127 | y_80=mem_80,
128 | )
129 |
130 |
131 | if __name__ == "__main__":
132 | args = parse_benchmark_script_args()
133 | common_args = {
134 | "kernel_name": "jsd",
135 | "x_name": "V",
136 | "x_label": "vocab size",
137 | "x_values": [2**i for i in range(12, 18)],
138 | "kernel_providers": ["liger", "torch"],
139 | "extra_benchmark_configs": [{"B": 4, "T": 2048}],
140 | "overwrite": args.overwrite,
141 | }
142 |
143 | run_benchmarks(
144 | bench_test_fn=bench_memory_jsd,
145 | kernel_operation_modes=["full"],
146 | metric_name="memory",
147 | metric_unit="MB",
148 | **common_args,
149 | )
150 |
151 | run_benchmarks(
152 | bench_test_fn=bench_speed_jsd,
153 | kernel_operation_modes=["forward", "full"],
154 | metric_name="speed",
155 | metric_unit="ms",
156 | **common_args,
157 | )
158 |
--------------------------------------------------------------------------------
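
TorchJSD above computes the generalized (beta-skewed) Jensen-Shannon divergence: with M = beta * P + (1 - beta) * Q it returns beta * KL(P || M) + (1 - beta) * KL(Q || M), averaged over rows when no label is given. A tiny numeric sketch checking that against the closed form, assuming the TorchJSD class from the script above is in scope:

# Sketch: verify TorchJSD against the closed-form generalized JSD on a small input.
import torch

torch.manual_seed(0)
beta = 0.5
log_q = torch.randn(4, 8).log_softmax(dim=-1)   # "input" distribution Q
log_p = torch.randn(4, 8).log_softmax(dim=-1)   # "target" distribution P
p, q = log_p.exp(), log_q.exp()
m = beta * p + (1 - beta) * q

kl_pm = (p * (log_p - m.log())).sum(dim=-1)     # KL(P || M) per row
kl_qm = (q * (log_q - m.log())).sum(dim=-1)     # KL(Q || M) per row
expected = (beta * kl_pm + (1 - beta) * kl_qm).mean()

actual = TorchJSD(beta=beta)(log_q, log_p)      # class defined in the script above
torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-6)
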
/benchmark/scripts/benchmark_kl_div.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import triton
4 | from utils import (
5 | QUANTILES,
6 | SingleBenchmarkRunInput,
7 | SingleBenchmarkRunOutput,
8 | _test_memory,
9 | parse_benchmark_script_args,
10 | run_benchmarks,
11 | )
12 |
13 | from liger_kernel.transformers.kl_div import LigerKLDIVLoss
14 | from liger_kernel.utils import infer_device
15 |
16 | device = infer_device()
17 |
18 | S, E = 12, 18
19 |
20 |
21 | def bench_speed_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
22 | reduction = "batchmean"
23 | V = input.x
24 | B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
25 | torch_kl_div = nn.KLDivLoss(reduction=reduction)
26 | liger_kl_div = LigerKLDIVLoss(reduction=reduction)
27 |
28 | _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(
29 | dim=-1
30 | )
31 | target = torch.randn(B * T, V, device=device).softmax(dim=-1)
32 |
33 | def fwd():
34 | if input.kernel_provider == "liger":
35 | return liger_kl_div(_input, target)
36 | else:
37 | return torch_kl_div(_input, target)
38 |
39 | if input.kernel_operation_mode == "forward":
40 | ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
41 | elif input.kernel_operation_mode == "backward":
42 | y = fwd()
43 |
44 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
45 | lambda: y.backward(retain_graph=True),
46 | quantiles=QUANTILES,
47 | grad_to_none=[_input],
48 | rep=100,
49 | )
50 | elif input.kernel_operation_mode == "full":
51 |
52 | def full():
53 | y = fwd()
54 | y.backward(retain_graph=True)
55 |
56 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
57 | full, quantiles=QUANTILES, rep=100
58 | )
59 | return SingleBenchmarkRunOutput(
60 | y_20=ms_20,
61 | y_50=ms_50,
62 | y_80=ms_80,
63 | )
64 |
65 |
66 | def bench_memory_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
67 | reduction = "batchmean"
68 | torch_kl_div = nn.KLDivLoss(reduction=reduction)
69 | liger_kl_div = LigerKLDIVLoss(reduction=reduction)
70 |
71 | V = input.x
72 | B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
73 |
74 | _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(
75 | dim=-1
76 | )
77 | target = torch.randn(B * T, V, device=device).softmax(dim=-1)
78 |
79 | def fwd():
80 | if input.kernel_provider == "liger":
81 | return liger_kl_div(_input, target)
82 | else:
83 | return torch_kl_div(_input, target)
84 |
85 | def full():
86 | y = fwd()
87 | y.backward(retain_graph=True)
88 |
89 | mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
90 |
91 | return SingleBenchmarkRunOutput(
92 | y_20=mem_20,
93 | y_50=mem_50,
94 | y_80=mem_80,
95 | )
96 |
97 |
98 | if __name__ == "__main__":
99 | args = parse_benchmark_script_args()
100 | common_args = {
101 | "kernel_name": "kl_div",
102 | "x_name": "V",
103 | "x_label": "vocab size",
104 | "x_values": [2**i for i in range(12, 18)],
105 | "kernel_providers": ["liger", "torch"],
106 | "extra_benchmark_configs": [{"B": 8, "T": 512}],
107 | "overwrite": args.overwrite,
108 | }
109 |
110 | run_benchmarks(
111 | bench_test_fn=bench_memory_kldiv,
112 | kernel_operation_modes=["full"],
113 | metric_name="memory",
114 | metric_unit="MB",
115 | **common_args,
116 | )
117 |
118 | run_benchmarks(
119 | bench_test_fn=bench_speed_kldiv,
120 | kernel_operation_modes=["forward", "full"],
121 | metric_name="speed",
122 | metric_unit="ms",
123 | **common_args,
124 | )
125 |
--------------------------------------------------------------------------------
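
Note the input convention in the benchmark above: nn.KLDivLoss (and LigerKLDIVLoss, which mirrors its interface here) expects log-probabilities as the input and, with the default log_target=False, plain probabilities as the target; reduction="batchmean" divides the summed KL by the batch size. A tiny sketch of that convention in plain torch:

# Sketch: KLDivLoss input convention and the "batchmean" reduction.
import torch
import torch.nn as nn

log_q = torch.randn(8, 16).log_softmax(dim=-1)  # input: log-probabilities
p = torch.randn(8, 16).softmax(dim=-1)          # target: probabilities

loss = nn.KLDivLoss(reduction="batchmean")(log_q, p)
manual = (p * (p.log() - log_q)).sum() / log_q.shape[0]
torch.testing.assert_close(loss, manual)
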
/benchmark/scripts/benchmark_layer_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | from utils import (
4 | QUANTILES,
5 | SingleBenchmarkRunInput,
6 | SingleBenchmarkRunOutput,
7 | _test_memory,
8 | parse_benchmark_script_args,
9 | run_benchmarks,
10 | )
11 |
12 | from liger_kernel.transformers.layer_norm import LigerLayerNorm
13 | from liger_kernel.utils import infer_device
14 |
15 | device = infer_device()
16 |
17 |
18 | def bench_speed_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
19 | N = input.x
20 | provider = input.kernel_provider
21 | mode = input.kernel_operation_mode
22 | extra_benchmark_config = input.extra_benchmark_config
23 | M = extra_benchmark_config["M"]
24 | eps = extra_benchmark_config["eps"]
25 | dtype = extra_benchmark_config["dtype"]
26 |
27 | x_shape = (M, N)
28 | triton_ln = LigerLayerNorm(hidden_size=N).to(device)
29 | torch_ln = torch.nn.LayerNorm(N, eps=eps).to(device)
30 |
31 | x = torch.randn(x_shape, dtype=dtype, device=device)
32 | dy = torch.randn_like(x)
33 | x.requires_grad_(True)
34 |
35 | def y_fwd():
36 | if provider == "liger":
37 | return triton_ln(x)
38 | if provider == "huggingface":
39 | return torch_ln(x)
40 |
41 | if mode == "forward":
42 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
43 | y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500
44 | )
45 | elif mode == "backward":
46 | y = y_fwd()
47 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
48 | lambda: y.backward(dy, retain_graph=True),
49 | quantiles=QUANTILES,
50 | grad_to_none=[x],
51 | rep=500,
52 | )
53 | elif mode == "full":
54 |
55 | def full():
56 | y = y_fwd()
57 | y.backward(dy, retain_graph=True)
58 |
59 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
60 | full, quantiles=QUANTILES, grad_to_none=[x], rep=500
61 | )
62 |
63 | return SingleBenchmarkRunOutput(
64 | y_20=ms_20,
65 | y_50=ms_50,
66 | y_80=ms_80,
67 | )
68 |
69 |
70 | def bench_memory_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
71 | N = input.x
72 | provider = input.kernel_provider
73 | dtype = input.extra_benchmark_config["dtype"]
74 | M = input.extra_benchmark_config["M"]
75 | eps = input.extra_benchmark_config["eps"]
76 |
77 | x_shape = (M, N)
78 |
79 | triton_ln = LigerLayerNorm(hidden_size=N).to(device)
80 | torch_ln = torch.nn.LayerNorm(N, eps=eps).to(device)
81 |
82 | x = torch.randn(x_shape, dtype=dtype, device=device)
83 | dy = torch.randn_like(x)
84 | x.requires_grad_(True)
85 |
86 | def y_fwd():
87 | if provider == "liger":
88 | return triton_ln(x)
89 | if provider == "huggingface":
90 | return torch_ln(x)
91 |
92 | def full():
93 | y = y_fwd()
94 | y.backward(dy, retain_graph=True)
95 |
96 | mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
97 | return SingleBenchmarkRunOutput(
98 | y_20=mem_20,
99 | y_50=mem_50,
100 | y_80=mem_80,
101 | )
102 |
103 |
104 | if __name__ == "__main__":
105 | args = parse_benchmark_script_args()
106 |
107 | common_configs = {
108 | "kernel_name": "layer_norm",
109 | "x_name": "N",
110 | "x_label": "hidden size",
111 | "x_values": [2**i for i in range(10, 15)],
112 | "kernel_providers": ["liger", "huggingface"],
113 | "extra_benchmark_configs": [{"M": 4096, "dtype": torch.float32, "eps": 1e-6}],
114 | "overwrite": args.overwrite,
115 | }
116 |
117 | run_benchmarks(
118 | bench_test_fn=bench_speed_layer_norm,
119 | kernel_operation_modes=["forward", "full"],
120 | metric_name="speed",
121 | metric_unit="ms",
122 | **common_configs
123 | )
124 | run_benchmarks(
125 | bench_test_fn=bench_memory_layer_norm,
126 | kernel_operation_modes=["full"],
127 | metric_name="memory",
128 | metric_unit="MB",
129 | **common_configs
130 | )
131 |
--------------------------------------------------------------------------------
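Note: the script above treats LigerLayerNorm as a drop-in replacement for torch.nn.LayerNorm and only measures speed/memory. A minimal correctness sketch of that equivalence, not part of the repository (it assumes LigerLayerNorm exposes affine weight/bias parameters named like the PyTorch module; sizes and tolerances are illustrative):

import torch
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.utils import infer_device

device = infer_device()
M, N = 4096, 1024  # same M as the benchmark config, a small hidden size from its sweep

x = torch.randn(M, N, dtype=torch.float32, device=device)
liger_ln = LigerLayerNorm(hidden_size=N).to(device)
torch_ln = torch.nn.LayerNorm(N, eps=1e-6).to(device)

# copy the affine parameters so both modules compute the same function
# (assumption: LigerLayerNorm names them `weight` and `bias`, like torch.nn.LayerNorm)
with torch.no_grad():
    liger_ln.weight.copy_(torch_ln.weight)
    liger_ln.bias.copy_(torch_ln.bias)

print(torch.allclose(liger_ln(x), torch_ln(x), atol=1e-5, rtol=1e-5))  # expected: True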
/benchmark/scripts/benchmark_orpo_loss.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import torch
5 | import triton
6 | from utils import (
7 | QUANTILES,
8 | SingleBenchmarkRunInput,
9 | SingleBenchmarkRunOutput,
10 | _test_memory,
11 | parse_benchmark_script_args,
12 | run_benchmarks,
13 | )
14 |
15 | from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
16 | from liger_kernel.utils import infer_device
17 |
18 | device = infer_device()
19 |
20 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
21 |
22 |
23 | class TorchLMHeadORPO(torch.nn.Module):
24 | """Ground truth implementation of the linear fused with torch based cross entropy loss.
25 |
26 | :param H: hidden size
27 | :param V: vocab size
28 | :param ignore_index: index to ignore
29 | :param reduction: reduction method
30 | """
31 |
32 | def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
33 | from test.chunked_loss.test_orpo_loss import HF_ORPO_Loss
34 |
35 | super().__init__()
36 | self.lin = torch.nn.Linear(
37 | in_features=H, out_features=V, bias=False, dtype=dtype
38 | )
39 | self.orpo_loss = HF_ORPO_Loss().get_batch_loss_metrics
40 |
41 | def forward(self, x, y):
42 | return self.orpo_loss(x, self.lin.weight, y)
43 |
44 |
45 | class LigerLMHeadORPO(torch.nn.Module):
46 | def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
47 | super().__init__()
48 | self.lin = torch.nn.Linear(
49 | in_features=H, out_features=V, bias=False, dtype=dtype
50 | )
51 | self.orpo_loss = LigerFusedLinearORPOFunction.apply
52 |
53 | def forward(self, x, y):
54 | return self.orpo_loss(x, self.lin.weight, y)
55 |
56 |
57 | #############################################################################
58 | # Test the memory consumption of the fused linear ORPO loss
59 | #############################################################################
60 |
61 |
62 | def bench_memory_fused_linear_orpo_loss(
63 | input: SingleBenchmarkRunInput,
64 | ) -> SingleBenchmarkRunOutput:
65 | B = input.x
66 | T = input.extra_benchmark_config["T"]
67 | H = input.extra_benchmark_config["H"]
68 | V = input.extra_benchmark_config["V"]
69 | dtype = input.extra_benchmark_config["dtype"]
70 | provider = input.kernel_provider
71 |
72 | torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
73 | liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
74 |
75 | _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
76 | target = torch.randint(V, (B, T), dtype=torch.long, device=device)
77 |
78 | def fwd():
79 | if provider == "liger":
80 | return liger_lm_head_orpo(_input, target)
81 | elif provider == "huggingface":
82 | return torch_lm_head_orpo(_input, target)
83 |
84 | def full():
85 | y = fwd()
86 | y.backward()
87 |
88 | mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
89 | return SingleBenchmarkRunOutput(
90 | y_20=mem_20,
91 | y_50=mem_50,
92 | y_80=mem_80,
93 | )
94 |
95 |
96 | #############################################################################
97 | # Test the speed of the fused linear ORPO loss
98 | #############################################################################
99 |
100 |
101 | def bench_speed_fused_linear_orpo_loss(
102 | input: SingleBenchmarkRunInput,
103 | ) -> SingleBenchmarkRunOutput:
104 | B = input.x
105 | T = input.extra_benchmark_config["T"]
106 | H = input.extra_benchmark_config["H"]
107 | V = input.extra_benchmark_config["V"]
108 | dtype = input.extra_benchmark_config["dtype"]
109 | provider = input.kernel_provider
110 | mode = input.kernel_operation_mode
111 |
112 | torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
113 | liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
114 |
115 | _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
116 | target = torch.randint(V, (B, T), dtype=torch.long, device=device)
117 |
118 | def fwd():
119 | if provider == "liger":
120 | return liger_lm_head_orpo(_input, target)
121 | elif provider == "huggingface":
122 | return torch_lm_head_orpo(_input, target)
123 |
124 | if mode == "forward":
125 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
126 | fwd,
127 | rep=100,
128 | quantiles=QUANTILES,
129 | )
130 | elif mode == "backward":
131 | y = fwd()
132 |
133 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
134 | lambda: y.backward(retain_graph=True),
135 | grad_to_none=[_input],
136 | rep=100,
137 | quantiles=QUANTILES,
138 | )
139 | elif mode == "full":
140 |
141 | def full():
142 | y = fwd()
143 | y.backward()
144 |
145 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
146 | full,
147 | rep=100,
148 | quantiles=QUANTILES,
149 | )
150 | return SingleBenchmarkRunOutput(
151 | y_20=ms_20,
152 | y_50=ms_50,
153 | y_80=ms_80,
154 | )
155 |
156 |
157 | if __name__ == "__main__":
158 | args = parse_benchmark_script_args()
159 |
160 | common_configs = {
161 | "kernel_name": "fused_linear_orpo_loss",
162 | "x_name": "B",
163 | "x_label": "B",
164 | "x_values": [2**i for i in range(1, 5)],
165 | "kernel_providers": ["liger", "huggingface"],
166 | "extra_benchmark_configs": [
167 | {
168 | "T": 1024,
169 | "H": 4096,
170 | "V": 128256,
171 | "mode": "forward",
172 | "dtype": torch.bfloat16,
173 | }
174 | ],
175 | "overwrite": args.overwrite,
176 | }
177 |
178 | run_benchmarks(
179 | bench_test_fn=bench_speed_fused_linear_orpo_loss,
180 | kernel_operation_modes=["forward", "full"],
181 | metric_name="speed",
182 | metric_unit="ms",
183 | **common_configs
184 | )
185 | run_benchmarks(
186 | bench_test_fn=bench_memory_fused_linear_orpo_loss,
187 | kernel_operation_modes=["full"],
188 | metric_name="memory",
189 | metric_unit="MB",
190 | **common_configs
191 | )
192 |
--------------------------------------------------------------------------------
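For context, the ORPO path being timed reduces to a single autograd.Function call on the LM-head weight, the hidden states, and the targets. A scaled-down, illustrative invocation of that call (toy sizes; the script uses T=1024, H=4096, V=128256, and random data that is only meaningful for timing; in real preference training the batch concatenates chosen and rejected sequences along dim 0):

import torch
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
from liger_kernel.utils import infer_device

device = infer_device()
B, T, H, V = 2, 128, 512, 1000  # B is the swept x-axis value in the benchmark
dtype = torch.bfloat16

weight = torch.randn(V, H, dtype=dtype, device=device, requires_grad=True)  # lm_head weight, as in LigerLMHeadORPO.lin
x = torch.randn(B, T, H, dtype=dtype, device=device, requires_grad=True)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)

loss = LigerFusedLinearORPOFunction.apply(x, weight, target)
loss.backward()  # the "full" benchmark mode runs exactly this forward + backward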
/benchmark/scripts/benchmark_qwen2vl_mrope.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | from transformers.models.qwen2_vl.modeling_qwen2_vl import (
4 | Qwen2VLRotaryEmbedding,
5 | apply_multimodal_rotary_pos_emb,
6 | )
7 | from utils import (
8 | QUANTILES,
9 | SingleBenchmarkRunInput,
10 | SingleBenchmarkRunOutput,
11 | _test_memory,
12 | parse_benchmark_script_args,
13 | run_benchmarks,
14 | )
15 |
16 | from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
17 | from liger_kernel.utils import infer_device
18 |
19 | device = infer_device()
20 |
21 |
22 | def bench_speed_qwen2vl_mrope(
23 | input: SingleBenchmarkRunInput,
24 | ) -> SingleBenchmarkRunOutput:
25 | provider = input.kernel_provider
26 | mode = input.kernel_operation_mode
27 |
28 | extra_benchmark_config = input.extra_benchmark_config
29 | num_q_heads = extra_benchmark_config["num_q_heads"]
30 | num_kv_heads = extra_benchmark_config["num_kv_heads"]
31 | dtype = extra_benchmark_config["dtype"]
32 |
33 | # x can be either hidden_size or seq_len
34 | hidden_size = (
35 | extra_benchmark_config["hidden_size"]
36 | if "hidden_size" in extra_benchmark_config
37 | else input.x
38 | )
39 | seq_len = (
40 | extra_benchmark_config["seq_len"]
41 | if "seq_len" in extra_benchmark_config
42 | else input.x
43 | )
44 |
45 | head_dim = hidden_size // num_q_heads
46 | rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device)
47 | q = torch.randn(
48 | (1, seq_len, num_q_heads, head_dim),
49 | device=device,
50 | requires_grad=True,
51 | dtype=dtype,
52 | ).transpose(1, 2)
53 | k = torch.randn(
54 | (1, seq_len, num_kv_heads, head_dim),
55 | device=device,
56 | requires_grad=True,
57 | dtype=dtype,
58 | ).transpose(1, 2)
59 | dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like(
60 | k, device=device
61 | )
62 | pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1)
63 | cos, sin = rotary_emb(k, pos_ids)
64 |
65 | mrope_section_hw = head_dim * 3 // 16
66 | mrope_section = [
67 | head_dim // 2 - 2 * mrope_section_hw,
68 | mrope_section_hw,
69 | mrope_section_hw,
70 | ]
71 |
72 | def fwd():
73 | if provider == "liger":
74 | return liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)
75 | elif provider == "huggingface":
76 | return apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)
77 | else:
78 | raise ValueError(f"Invalid provider: {provider} for M-RoPE embedding")
79 |
80 | if mode == "forward":
81 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
82 | fwd,
83 | grad_to_none=[q, k],
84 | rep=400,
85 | quantiles=QUANTILES,
86 | )
87 | elif mode == "backward":
88 | q_out, k_out = fwd()
89 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
90 | lambda: torch.autograd.grad(
91 | (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True
92 | ),
93 | grad_to_none=[q, k],
94 | rep=400,
95 | quantiles=QUANTILES,
96 | )
97 | elif mode == "full":
98 |
99 | def full():
100 | q_out, k_out = fwd()
101 | torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True)
102 |
103 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
104 | full,
105 | grad_to_none=[q, k],
106 | rep=400,
107 | quantiles=QUANTILES,
108 | )
109 | return SingleBenchmarkRunOutput(
110 | y_20=ms_20,
111 | y_50=ms_50,
112 | y_80=ms_80,
113 | )
114 |
115 |
116 | def bench_memory_qwen2vl_mrope(
117 | input: SingleBenchmarkRunInput,
118 | ) -> SingleBenchmarkRunOutput:
119 | provider = input.kernel_provider
120 |
121 | extra_benchmark_config = input.extra_benchmark_config
122 | num_q_heads = extra_benchmark_config["num_q_heads"]
123 | num_kv_heads = extra_benchmark_config["num_kv_heads"]
124 | dtype = extra_benchmark_config["dtype"]
125 |
126 | # x can be either hidden_size or seq_len
127 | hidden_size = (
128 | extra_benchmark_config["hidden_size"]
129 | if "hidden_size" in extra_benchmark_config
130 | else input.x
131 | )
132 | seq_len = (
133 | extra_benchmark_config["seq_len"]
134 | if "seq_len" in extra_benchmark_config
135 | else input.x
136 | )
137 |
138 | head_dim = hidden_size // num_q_heads
139 | rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device)
140 | q = torch.randn(
141 | (1, seq_len, num_q_heads, head_dim),
142 | device=device,
143 | requires_grad=True,
144 | dtype=dtype,
145 | ).transpose(1, 2)
146 | k = torch.randn(
147 | (1, seq_len, num_kv_heads, head_dim),
148 | device=device,
149 | requires_grad=True,
150 | dtype=dtype,
151 | ).transpose(1, 2)
152 | dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like(
153 | k, device=device
154 | )
155 | pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1)
156 | cos, sin = rotary_emb(k, pos_ids)
157 |
158 | mrope_section_hw = head_dim * 3 // 16
159 | mrope_section = [
160 | head_dim // 2 - 2 * mrope_section_hw,
161 | mrope_section_hw,
162 | mrope_section_hw,
163 | ]
164 |
165 | def full():
166 | if provider == "liger":
167 | q_out, k_out = liger_multimodal_rotary_pos_emb(
168 | q, k, cos, sin, mrope_section
169 | )
170 | else:
171 | q_out, k_out = apply_multimodal_rotary_pos_emb(
172 | q, k, cos, sin, mrope_section
173 | )
174 | torch.autograd.grad(
175 | (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True
176 | )
177 |
178 | mem_50, mem_20, mem_80 = _test_memory(
179 | full,
180 | quantiles=QUANTILES,
181 | )
182 | return SingleBenchmarkRunOutput(
183 | y_20=mem_20,
184 | y_50=mem_50,
185 | y_80=mem_80,
186 | )
187 |
188 |
189 | if __name__ == "__main__":
190 | args = parse_benchmark_script_args()
191 |
192 | common_configs_varying_hidden_size = {
193 | "kernel_name": "qwen2vl_mrope",
194 | "x_name": "H",
195 | "x_label": "hidden size",
196 | "x_values": [32 * (2**i) for i in range(4, 10, 2)],
197 | "kernel_providers": ["liger", "huggingface"],
198 | "extra_benchmark_configs": [
199 | {
200 | "dtype": torch.bfloat16,
201 | "seq_len": 2048,
202 | "num_q_heads": 32,
203 | "num_kv_heads": 8,
204 | }
205 | ],
206 | "overwrite": args.overwrite,
207 | }
208 | run_benchmarks(
209 | bench_test_fn=bench_speed_qwen2vl_mrope,
210 | kernel_operation_modes=["forward", "backward", "full"],
211 | metric_name="speed",
212 | metric_unit="ms",
213 | **common_configs_varying_hidden_size,
214 | )
215 | run_benchmarks(
216 | bench_test_fn=bench_memory_qwen2vl_mrope,
217 | kernel_operation_modes=["full"],
218 | metric_name="memory",
219 | metric_unit="MB",
220 | **common_configs_varying_hidden_size,
221 | )
222 |
223 | common_configs_varying_seq_len = {
224 | "kernel_name": "qwen2vl_mrope",
225 | "x_name": "T",
226 | "x_label": "sequence length",
227 | "x_values": [2**i for i in range(10, 15)],
228 | "kernel_providers": ["liger", "huggingface"],
229 | "extra_benchmark_configs": [
230 | {
231 | "dtype": torch.bfloat16,
232 | "hidden_size": 8192,
233 | "num_q_heads": 32,
234 | "num_kv_heads": 8,
235 | }
236 | ],
237 | "overwrite": args.overwrite,
238 | }
239 | run_benchmarks(
240 | bench_test_fn=bench_speed_qwen2vl_mrope,
241 | kernel_operation_modes=["forward", "backward", "full"],
242 | metric_name="speed",
243 | metric_unit="ms",
244 | **common_configs_varying_seq_len,
245 | )
246 | run_benchmarks(
247 | bench_test_fn=bench_memory_qwen2vl_mrope,
248 | kernel_operation_modes=["full"],
249 | metric_name="memory",
250 | metric_unit="MB",
251 | **common_configs_varying_seq_len,
252 | )
253 |
--------------------------------------------------------------------------------
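The mrope_section list built in both benchmark functions splits the rotary half-dimension across the three position components of Qwen2-VL's multimodal RoPE (conventionally temporal, height, width). A worked instance of that arithmetic using the benchmark's varying-seq-len defaults (hidden_size=8192, num_q_heads=32):

head_dim = 8192 // 32                       # 256
mrope_section_hw = head_dim * 3 // 16       # 48
mrope_section = [
    head_dim // 2 - 2 * mrope_section_hw,   # 32  (temporal)
    mrope_section_hw,                       # 48  (height)
    mrope_section_hw,                       # 48  (width)
]
assert sum(mrope_section) == head_dim // 2  # the three sections cover the rotary half-dim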
/benchmark/scripts/benchmark_rms_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import triton
4 | from utils import (
5 | QUANTILES,
6 | SingleBenchmarkRunInput,
7 | SingleBenchmarkRunOutput,
8 | _test_memory,
9 | parse_benchmark_script_args,
10 | run_benchmarks,
11 | )
12 |
13 | from liger_kernel.transformers.rms_norm import LigerRMSNorm
14 | from liger_kernel.utils import infer_device
15 |
16 | device = infer_device()
17 |
18 |
19 | class LlamaRMSNorm(nn.Module):
20 | def __init__(self, hidden_size, eps=1e-6):
21 | """
22 | LlamaRMSNorm is equivalent to T5LayerNorm
23 | """
24 | super().__init__()
25 | self.weight = nn.Parameter(torch.ones(hidden_size))
26 | self.variance_epsilon = eps
27 |
28 | def forward(self, hidden_states):
29 | input_dtype = hidden_states.dtype
30 | hidden_states = hidden_states.to(torch.float32)
31 | variance = hidden_states.pow(2).mean(-1, keepdim=True)
32 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
33 | return self.weight * hidden_states.to(input_dtype)
34 |
35 |
36 | def bench_speed_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
37 | N = input.x
38 | provider = input.kernel_provider
39 | mode = input.kernel_operation_mode
40 |
41 | extra_benchmark_config = input.extra_benchmark_config
42 | M = extra_benchmark_config["M"]
43 | eps = extra_benchmark_config["eps"]
44 | dtype = extra_benchmark_config["dtype"]
45 |
46 | x_shape = (M, N)
47 |
48 | triton_rms = LigerRMSNorm(hidden_size=N, eps=eps).to(device)
49 | llama_rms = LlamaRMSNorm(hidden_size=N, eps=eps).to(device)
50 |
51 | x = torch.randn(x_shape, dtype=dtype, device=device)
52 | dy = torch.randn_like(x)
53 | x.requires_grad_(True)
54 |
55 | # utility functions
56 |
57 | def y_fwd():
58 | if provider == "liger":
59 | return triton_rms(x)
60 |
61 | if provider == "huggingface":
62 | return llama_rms(x)
63 |
64 | if mode == "forward":
65 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
66 | y_fwd,
67 | grad_to_none=[x],
68 | rep=500,
69 | quantiles=QUANTILES,
70 | )
71 | elif mode == "backward":
72 | y = y_fwd()
73 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
74 | lambda: y.backward(dy, retain_graph=True),
75 | grad_to_none=[x],
76 | rep=500,
77 | quantiles=QUANTILES,
78 | )
79 | elif mode == "full":
80 |
81 | def full():
82 | y = y_fwd()
83 | y.backward(dy, retain_graph=True)
84 |
85 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
86 | full,
87 | grad_to_none=[x],
88 | rep=500,
89 | quantiles=QUANTILES,
90 | )
91 |
92 | return SingleBenchmarkRunOutput(
93 | y_20=ms_20,
94 | y_50=ms_50,
95 | y_80=ms_80,
96 | )
97 |
98 |
99 | def bench_memory_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
100 | N = input.x
101 | provider = input.kernel_provider
102 |
103 | extra_benchmark_config = input.extra_benchmark_config
104 | M = extra_benchmark_config["M"]
105 | eps = extra_benchmark_config["eps"]
106 | dtype = extra_benchmark_config["dtype"]
107 |
108 | x_shape = (M, N)
109 |
110 | triton_rms = LigerRMSNorm(hidden_size=N, eps=eps).to(device)
111 | llama_rms = LlamaRMSNorm(hidden_size=N, eps=eps).to(device)
112 |
113 | x = torch.randn(x_shape, dtype=dtype, device=device)
114 | dy = torch.randn_like(x)
115 | x.requires_grad_(True)
116 |
117 | # utility functions
118 | def y_fwd():
119 | if provider == "liger":
120 | return triton_rms(x)
121 | if provider == "huggingface":
122 | return llama_rms(x)
123 |
124 | def full():
125 | y = y_fwd()
126 | y.backward(dy, retain_graph=True)
127 |
128 | mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
129 |
130 | return SingleBenchmarkRunOutput(
131 | y_20=mem_20,
132 | y_50=mem_50,
133 | y_80=mem_80,
134 | )
135 |
136 |
137 | if __name__ == "__main__":
138 | args = parse_benchmark_script_args()
139 |
140 | common_configs = {
141 | "kernel_name": "rms_norm",
142 | "x_name": "H",
143 | "x_label": "hidden size",
144 | "x_values": [2**i for i in range(10, 16)],
145 | "kernel_providers": ["liger", "huggingface"],
146 | "extra_benchmark_configs": [{"M": 2048, "dtype": torch.bfloat16, "eps": 1e-6}],
147 | "overwrite": args.overwrite,
148 | }
149 |
150 | run_benchmarks(
151 | bench_test_fn=bench_speed_rms_norm,
152 | kernel_operation_modes=["forward", "full", "backward"],
153 | metric_name="speed",
154 | metric_unit="ms",
155 | **common_configs
156 | )
157 | run_benchmarks(
158 | bench_test_fn=bench_memory_rms_norm,
159 | kernel_operation_modes=["full"],
160 | metric_name="memory",
161 | metric_unit="MB",
162 | **common_configs
163 | )
164 |
--------------------------------------------------------------------------------
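As a quick sanity check on the reference implementation, RMSNorm divides each row by the root mean square of its elements (plus eps) and then applies the learned scale. A tiny numeric example, assuming the LlamaRMSNorm class defined in the script above is in scope (its weight initializes to ones):

import torch

x = torch.tensor([[3.0, 4.0]])   # mean(x**2) = 12.5, rsqrt(12.5) ≈ 0.2828
eps = 1e-6
manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)  # ≈ [[0.8485, 1.1314]]

rms = LlamaRMSNorm(hidden_size=2, eps=eps)
print(torch.allclose(rms(x), manual, atol=1e-6))  # expected: True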
/benchmark/scripts/benchmark_rope.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | from transformers.models.llama.modeling_llama import (
4 | LlamaRotaryEmbedding,
5 | apply_rotary_pos_emb,
6 | )
7 | from utils import (
8 | QUANTILES,
9 | SingleBenchmarkRunInput,
10 | SingleBenchmarkRunOutput,
11 | _test_memory,
12 | parse_benchmark_script_args,
13 | run_benchmarks,
14 | )
15 |
16 | from liger_kernel.transformers.rope import liger_rotary_pos_emb
17 | from liger_kernel.utils import infer_device
18 |
19 | device = infer_device()
20 |
21 |
22 | def bench_speed_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
23 | provider = input.kernel_provider
24 | mode = input.kernel_operation_mode
25 |
26 | extra_benchmark_config = input.extra_benchmark_config
27 | num_q_heads = extra_benchmark_config["num_q_heads"]
28 | num_kv_heads = extra_benchmark_config["num_kv_heads"]
29 | dtype = extra_benchmark_config["dtype"]
30 |
31 | # x can be either hidden_size or seq_len
32 | hidden_size = (
33 | extra_benchmark_config["hidden_size"]
34 | if "hidden_size" in extra_benchmark_config
35 | else input.x
36 | )
37 | seq_len = (
38 | extra_benchmark_config["seq_len"]
39 | if "seq_len" in extra_benchmark_config
40 | else input.x
41 | )
42 |
43 | head_dim = hidden_size // num_q_heads
44 | rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
45 | q = torch.randn(
46 | (1, seq_len, num_q_heads, head_dim),
47 | device=device,
48 | requires_grad=True,
49 | dtype=dtype,
50 | ).transpose(1, 2)
51 | k = torch.randn(
52 | (1, seq_len, num_kv_heads, head_dim),
53 | device=device,
54 | requires_grad=True,
55 | dtype=dtype,
56 | ).transpose(1, 2)
57 | dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like(
58 | k, device=device
59 | )
60 | pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
61 | cos, sin = rotary_emb(k, pos_ids)
62 |
63 | def fwd():
64 | if provider == "liger":
65 | return liger_rotary_pos_emb(q, k, cos, sin, pos_ids)
66 | elif provider == "huggingface":
67 | return apply_rotary_pos_emb(q, k, cos, sin, pos_ids)
68 | else:
69 | raise ValueError(f"Invalid provider: {provider} for RoPE embedding")
70 |
71 | if mode == "forward":
72 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
73 | fwd,
74 | grad_to_none=[q, k],
75 | rep=400,
76 | quantiles=QUANTILES,
77 | )
78 | elif mode == "backward":
79 | q_out, k_out = fwd()
80 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
81 | lambda: torch.autograd.grad(
82 | (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True
83 | ),
84 | grad_to_none=[q, k],
85 | rep=400,
86 | quantiles=QUANTILES,
87 | )
88 | elif mode == "full":
89 |
90 | def full():
91 | q_out, k_out = fwd()
92 | torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True)
93 |
94 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
95 | full,
96 | grad_to_none=[q, k],
97 | rep=400,
98 | quantiles=QUANTILES,
99 | )
100 | return SingleBenchmarkRunOutput(
101 | y_20=ms_20,
102 | y_50=ms_50,
103 | y_80=ms_80,
104 | )
105 |
106 |
107 | def bench_memory_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
108 | provider = input.kernel_provider
109 |
110 | extra_benchmark_config = input.extra_benchmark_config
111 | num_q_heads = extra_benchmark_config["num_q_heads"]
112 | num_kv_heads = extra_benchmark_config["num_kv_heads"]
113 | dtype = extra_benchmark_config["dtype"]
114 |
115 | # x can be either hidden_size or seq_len
116 | hidden_size = (
117 | extra_benchmark_config["hidden_size"]
118 | if "hidden_size" in extra_benchmark_config
119 | else input.x
120 | )
121 | seq_len = (
122 | extra_benchmark_config["seq_len"]
123 | if "seq_len" in extra_benchmark_config
124 | else input.x
125 | )
126 |
127 | head_dim = hidden_size // num_q_heads
128 | rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
129 | q = torch.randn(
130 | (1, seq_len, num_q_heads, head_dim),
131 | device=device,
132 | requires_grad=True,
133 | dtype=dtype,
134 | ).transpose(1, 2)
135 | k = torch.randn(
136 | (1, seq_len, num_kv_heads, head_dim),
137 | device=device,
138 | requires_grad=True,
139 | dtype=dtype,
140 | ).transpose(1, 2)
141 | dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like(
142 | k, device=device
143 | )
144 | pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
145 | cos, sin = rotary_emb(k, pos_ids)
146 |
147 | def full():
148 | if provider == "liger":
149 | q_out, k_out = liger_rotary_pos_emb(q, k, cos, sin, pos_ids)
150 | else:
151 | q_out, k_out = apply_rotary_pos_emb(q, k, cos, sin, pos_ids)
152 | torch.autograd.grad(
153 | (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True
154 | )
155 |
156 | mem_50, mem_20, mem_80 = _test_memory(
157 | full,
158 | quantiles=QUANTILES,
159 | )
160 | return SingleBenchmarkRunOutput(
161 | y_20=mem_20,
162 | y_50=mem_50,
163 | y_80=mem_80,
164 | )
165 |
166 |
167 | if __name__ == "__main__":
168 | args = parse_benchmark_script_args()
169 |
170 | common_configs_varying_hidden_size = {
171 | "kernel_name": "rope",
172 | "x_name": "H",
173 | "x_label": "hidden size",
174 | "x_values": [32 * (2**i) for i in range(4, 10, 2)],
175 | "kernel_providers": ["liger", "huggingface"],
176 | "extra_benchmark_configs": [
177 | {
178 | "dtype": torch.bfloat16,
179 | "seq_len": 2048,
180 | "num_q_heads": 32,
181 | "num_kv_heads": 8,
182 | }
183 | ],
184 | "overwrite": args.overwrite,
185 | }
186 | run_benchmarks(
187 | bench_test_fn=bench_speed_rope,
188 | kernel_operation_modes=["forward", "backward", "full"],
189 | metric_name="speed",
190 | metric_unit="ms",
191 | **common_configs_varying_hidden_size,
192 | )
193 | run_benchmarks(
194 | bench_test_fn=bench_memory_rope,
195 | kernel_operation_modes=["full"],
196 | metric_name="memory",
197 | metric_unit="MB",
198 | **common_configs_varying_hidden_size,
199 | )
200 |
201 | common_configs_varying_seq_len = {
202 | "kernel_name": "rope",
203 | "x_name": "T",
204 | "x_label": "sequence length",
205 | "x_values": [2**i for i in range(10, 15)],
206 | "kernel_providers": ["liger", "huggingface"],
207 | "extra_benchmark_configs": [
208 | {
209 | "dtype": torch.bfloat16,
210 | "hidden_size": 8192,
211 | "num_q_heads": 32,
212 | "num_kv_heads": 8,
213 | }
214 | ],
215 | "overwrite": args.overwrite,
216 | }
217 | run_benchmarks(
218 | bench_test_fn=bench_speed_rope,
219 | kernel_operation_modes=["forward", "backward", "full"],
220 | metric_name="speed",
221 | metric_unit="ms",
222 | **common_configs_varying_seq_len,
223 | )
224 | run_benchmarks(
225 | bench_test_fn=bench_memory_rope,
226 | kernel_operation_modes=["full"],
227 | metric_name="memory",
228 | metric_unit="MB",
229 | **common_configs_varying_seq_len,
230 | )
231 |
--------------------------------------------------------------------------------
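Both providers above are expected to produce numerically matching rotated q/k tensors; the benchmark itself only measures speed and memory. A minimal equivalence sketch that mirrors the tensor construction in bench_speed_rope (small sizes, float32, tolerances illustrative):

import torch
from transformers.models.llama.modeling_llama import (
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
)
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.utils import infer_device

device = infer_device()
seq_len, num_q_heads, num_kv_heads, head_dim = 128, 8, 2, 64

rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
q = torch.randn(1, num_q_heads, seq_len, head_dim, device=device)   # (B, H, T, D) layout, as after .transpose(1, 2)
k = torch.randn(1, num_kv_heads, seq_len, head_dim, device=device)
pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
cos, sin = rotary_emb(k, pos_ids)

q_hf, k_hf = apply_rotary_pos_emb(q, k, cos, sin, pos_ids)
q_lg, k_lg = liger_rotary_pos_emb(q, k, cos, sin, pos_ids)
print(torch.allclose(q_hf, q_lg, atol=1e-5), torch.allclose(k_hf, k_lg, atol=1e-5))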
/benchmark/scripts/benchmark_simpo_loss.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import torch
5 | import triton
6 | from utils import (
7 | QUANTILES,
8 | SingleBenchmarkRunInput,
9 | SingleBenchmarkRunOutput,
10 | _test_memory,
11 | parse_benchmark_script_args,
12 | run_benchmarks,
13 | )
14 |
15 | from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
16 | from liger_kernel.utils import infer_device
17 |
18 | device = infer_device()
19 |
20 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
21 |
22 |
23 | class TorchLMHeadSimPO(torch.nn.Module):
24 | """Ground truth implementation of the linear fused with torch based cross entropy loss.
25 |
26 | :param H: hidden size
27 | :param V: vocab size
28 | :param ignore_index: index to ignore
29 | :param reduction: reduction method
30 | """
31 |
32 | def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
33 | from test.chunked_loss.test_cpo_loss import HFCPOLoss
34 |
35 | super().__init__()
36 | self.lin = torch.nn.Linear(
37 | in_features=H, out_features=V, bias=False, dtype=dtype
38 | )
39 | self.simpo_loss = HFCPOLoss(loss_type="simpo").get_batch_loss_metrics
40 |
41 | def forward(self, x, y):
42 | return self.simpo_loss(x, self.lin.weight, y)
43 |
44 |
45 | class LigerLMHeadSimPO(torch.nn.Module):
46 | def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
47 | super().__init__()
48 | self.lin = torch.nn.Linear(
49 | in_features=H, out_features=V, bias=False, dtype=dtype
50 | )
51 | self.simpo_loss = LigerFusedLinearSimPOFunction.apply
52 |
53 | def forward(self, x, y):
54 | return self.simpo_loss(x, self.lin.weight, y)
55 |
56 |
57 | #############################################################################
58 | # Test the memory consumption of the fused linear SimPO loss
59 | #############################################################################
60 |
61 |
62 | def bench_memory_fused_linear_simpo_loss(
63 | input: SingleBenchmarkRunInput,
64 | ) -> SingleBenchmarkRunOutput:
65 | B = input.x
66 | T = input.extra_benchmark_config["T"]
67 | H = input.extra_benchmark_config["H"]
68 | V = input.extra_benchmark_config["V"]
69 | dtype = input.extra_benchmark_config["dtype"]
70 | provider = input.kernel_provider
71 |
72 | torch_lm_head_simpo = TorchLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)
73 | liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)
74 |
75 | _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
76 | target = torch.randint(V, (B, T), dtype=torch.long, device=device)
77 |
78 | def fwd():
79 | if provider == "liger":
80 | return liger_lm_head_simpo(_input, target)
81 | elif provider == "huggingface":
82 | return torch_lm_head_simpo(_input, target)
83 |
84 | def full():
85 | y = fwd()
86 | y.backward()
87 |
88 | mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
89 | return SingleBenchmarkRunOutput(
90 | y_20=mem_20,
91 | y_50=mem_50,
92 | y_80=mem_80,
93 | )
94 |
95 |
96 | #############################################################################
97 | # Test the speed of the fused linear SimPO loss
98 | #############################################################################
99 |
100 |
101 | def bench_speed_fused_linear_simpo_loss(
102 | input: SingleBenchmarkRunInput,
103 | ) -> SingleBenchmarkRunOutput:
104 | B = input.x
105 | T = input.extra_benchmark_config["T"]
106 | H = input.extra_benchmark_config["H"]
107 | V = input.extra_benchmark_config["V"]
108 | dtype = input.extra_benchmark_config["dtype"]
109 | provider = input.kernel_provider
110 | mode = input.kernel_operation_mode
111 |
112 | torch_lm_head_simpo = TorchLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)
113 | liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)
114 |
115 | _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
116 | target = torch.randint(V, (B, T), dtype=torch.long, device=device)
117 |
118 | def fwd():
119 | if provider == "liger":
120 | return liger_lm_head_simpo(_input, target)
121 | elif provider == "huggingface":
122 | return torch_lm_head_simpo(_input, target)
123 |
124 | if mode == "forward":
125 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
126 | fwd,
127 | rep=100,
128 | quantiles=QUANTILES,
129 | )
130 | elif mode == "backward":
131 | y = fwd()
132 |
133 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
134 | lambda: y.backward(retain_graph=True),
135 | grad_to_none=[_input],
136 | rep=100,
137 | quantiles=QUANTILES,
138 | )
139 | elif mode == "full":
140 |
141 | def full():
142 | y = fwd()
143 | y.backward()
144 |
145 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
146 | full,
147 | rep=100,
148 | quantiles=QUANTILES,
149 | )
150 | return SingleBenchmarkRunOutput(
151 | y_20=ms_20,
152 | y_50=ms_50,
153 | y_80=ms_80,
154 | )
155 |
156 |
157 | if __name__ == "__main__":
158 | args = parse_benchmark_script_args()
159 |
160 | common_configs = {
161 | "kernel_name": "fused_linear_simpo_loss",
162 | "x_name": "B",
163 | "x_label": "B",
164 | "x_values": [2**i for i in range(1, 5)],
165 | "kernel_providers": ["liger", "huggingface"],
166 | "extra_benchmark_configs": [
167 | {
168 | "T": 1024,
169 | "H": 4096,
170 | "V": 128256,
171 | "mode": "forward",
172 | "dtype": torch.bfloat16,
173 | }
174 | ],
175 | "overwrite": args.overwrite,
176 | }
177 |
178 | run_benchmarks(
179 | bench_test_fn=bench_speed_fused_linear_simpo_loss,
180 | kernel_operation_modes=["forward", "full"],
181 | metric_name="speed",
182 | metric_unit="ms",
183 | **common_configs
184 | )
185 | run_benchmarks(
186 | bench_test_fn=bench_memory_fused_linear_simpo_loss,
187 | kernel_operation_modes=["full"],
188 | metric_name="memory",
189 | metric_unit="MB",
190 | **common_configs
191 | )
192 |
--------------------------------------------------------------------------------
/benchmark/scripts/benchmark_swiglu.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | from transformers.models.llama.configuration_llama import LlamaConfig
4 | from transformers.models.llama.modeling_llama import LlamaMLP
5 | from utils import (
6 | QUANTILES,
7 | SingleBenchmarkRunInput,
8 | SingleBenchmarkRunOutput,
9 | _test_memory,
10 | parse_benchmark_script_args,
11 | run_benchmarks,
12 | )
13 |
14 | from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
15 | from liger_kernel.utils import infer_device
16 |
17 | device = infer_device()
18 |
19 |
20 | def bench_speed_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
21 | seq_len = input.x
22 | provider = input.kernel_provider
23 | mode = input.kernel_operation_mode
24 |
25 | extra_benchmark_config = input.extra_benchmark_config
26 | bsz = extra_benchmark_config["B"]
27 | hidden_size = extra_benchmark_config["hidden_size"]
28 | dtype = extra_benchmark_config["dtype"]
29 | intermediate_size = extra_benchmark_config["intermediate_size"]
30 | hidden_act = extra_benchmark_config["hidden_act"]
31 |
32 | llama_config = LlamaConfig(
33 | hidden_size=hidden_size,
34 | intermediate_size=intermediate_size,
35 | hidden_act=hidden_act,
36 | )
37 |
38 | x_shape = (bsz, seq_len, hidden_size)
39 |
40 | # initialize input
41 | x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True)
42 |
43 | if provider == "liger":
44 | layer = LigerSwiGLUMLP(config=llama_config).to(device).to(dtype)
45 | elif provider == "huggingface":
46 | layer = LlamaMLP(config=llama_config).to(device).to(dtype)
47 | else:
48 | raise ValueError(f"Invalid provider: {provider} for SwiGLU")
49 |
50 | def fwd():
51 | return layer(x)
52 |
53 | if mode == "forward":
54 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
55 | fwd,
56 | grad_to_none=[x],
57 | quantiles=QUANTILES,
58 | rep=10,
59 | )
60 | elif mode == "backward":
61 | do = torch.randn_like(x)
62 | y = fwd()
63 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
64 | lambda: y.backward(do, retain_graph=True),
65 | grad_to_none=[x],
66 | quantiles=QUANTILES,
67 | rep=10,
68 | )
69 | else:
70 |
71 | def full():
72 | y = fwd()
73 | y.backward(torch.randn_like(y), retain_graph=True)
74 |
75 | ms_50, ms_20, ms_80 = triton.testing.do_bench(
76 | full,
77 | grad_to_none=[x],
78 | quantiles=QUANTILES,
79 | rep=10,
80 | )
81 |
82 | return SingleBenchmarkRunOutput(
83 | y_20=ms_20,
84 | y_50=ms_50,
85 | y_80=ms_80,
86 | )
87 |
88 |
89 | def bench_memory_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
90 | seq_len = input.x
91 | provider = input.kernel_provider
92 | mode = input.kernel_operation_mode
93 |
94 | extra_benchmark_config = input.extra_benchmark_config
95 | bsz = extra_benchmark_config["B"]
96 | hidden_size = extra_benchmark_config["hidden_size"]
97 | dtype = extra_benchmark_config["dtype"]
98 | intermediate_size = extra_benchmark_config["intermediate_size"]
99 | hidden_act = extra_benchmark_config["hidden_act"]
100 |
101 | llama_config = LlamaConfig(
102 | hidden_size=hidden_size,
103 | intermediate_size=intermediate_size,
104 | hidden_act=hidden_act,
105 | )
106 |
107 | x_shape = (bsz, seq_len, hidden_size)
108 |
109 | # initialize input
110 | x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True)
111 |
112 | if provider == "liger":
113 | layer = LigerSwiGLUMLP(config=llama_config).to(device).to(dtype)
114 | elif provider == "huggingface":
115 | layer = LlamaMLP(config=llama_config).to(device).to(dtype)
116 | else:
117 | raise ValueError(f"Invalid provider: {provider} for SwiGLU")
118 |
119 | def fwd():
120 | return layer(x)
121 |
122 | def full():
123 | y = fwd()
124 | y.backward(torch.randn_like(y), retain_graph=True)
125 |
126 | if mode == "forward":
127 | mem_50, mem_20, mem_80 = _test_memory(fwd, quantiles=QUANTILES)
128 | elif mode == "backward":
129 | do = torch.randn_like(x)
130 | y = fwd()
131 | mem_50, mem_20, mem_80 = _test_memory(
132 | lambda: y.backward(do, retain_graph=True), quantiles=QUANTILES
133 | )
134 | else:
135 | mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
136 |
137 | return SingleBenchmarkRunOutput(
138 | y_20=mem_20,
139 | y_50=mem_50,
140 | y_80=mem_80,
141 | )
142 |
143 |
144 | if __name__ == "__main__":
145 | args = parse_benchmark_script_args()
146 |
147 | common_configs = {
148 | "kernel_name": "swiglu",
149 | "x_name": "T",
150 | "x_label": "sequence length",
151 | "x_values": [2**i for i in range(10, 14)],
152 | "kernel_providers": ["liger", "huggingface"],
153 | "extra_benchmark_configs": [
154 | {
155 | "B": 4,
156 | "hidden_size": 4096,
157 | "dtype": torch.bfloat16,
158 | "intermediate_size": 11008,
159 | "hidden_act": "silu",
160 | }
161 | ],
162 | "overwrite": args.overwrite,
163 | }
164 |
165 | run_benchmarks(
166 | bench_test_fn=bench_speed_swiglu,
167 | kernel_operation_modes=["forward"],
168 | metric_name="speed",
169 | metric_unit="ms",
170 | **common_configs,
171 | )
172 | run_benchmarks(
173 | bench_test_fn=bench_memory_swiglu,
174 | kernel_operation_modes=["full"],
175 | metric_name="memory",
176 | metric_unit="MB",
177 | **common_configs,
178 | )
179 |
--------------------------------------------------------------------------------
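LigerSwiGLUMLP is constructed from the same LlamaConfig as LlamaMLP, which is what makes the two layers interchangeable in the benchmark above. A small drop-in sketch (toy sizes; it assumes the two modules share parameter names such as gate_proj/up_proj/down_proj, so the state_dict copy below is illustrative rather than guaranteed):

import torch
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaMLP
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from liger_kernel.utils import infer_device

device = infer_device()
config = LlamaConfig(hidden_size=256, intermediate_size=688, hidden_act="silu")
dtype = torch.bfloat16

x = torch.randn(2, 16, config.hidden_size, device=device, dtype=dtype)
hf_mlp = LlamaMLP(config=config).to(device).to(dtype)
liger_mlp = LigerSwiGLUMLP(config=config).to(device).to(dtype)
liger_mlp.load_state_dict(hf_mlp.state_dict())  # assumed: matching parameter names

print(torch.allclose(liger_mlp(x), hf_mlp(x), atol=1e-2, rtol=1e-2))  # loose bf16 tolerance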
/dev/modal/tests.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import modal
4 |
5 | ROOT_PATH = Path(__file__).parent.parent.parent
6 | REMOTE_ROOT_PATH = "/root/liger-kernel"
7 | PYTHON_VERSION = "3.12"
8 |
9 | image = modal.Image.debian_slim(python_version=PYTHON_VERSION).pip_install("uv")
10 |
11 | app = modal.App("liger_tests", image=image)
12 |
13 | # mount: add local files to the remote container
14 | repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)
15 |
16 |
17 | @app.function(gpu="A10G", mounts=[repo], timeout=60 * 15)
18 | def liger_tests():
19 | import subprocess
20 |
21 | subprocess.run(
22 | ["uv pip install -e '.[dev]' --system"],
23 | check=True,
24 | shell=True,
25 | cwd=REMOTE_ROOT_PATH,
26 | )
27 | subprocess.run(["make test"], check=True, shell=True, cwd=REMOTE_ROOT_PATH)
28 | subprocess.run(
29 | ["make test-convergence"], check=True, shell=True, cwd=REMOTE_ROOT_PATH
30 | )
31 |
--------------------------------------------------------------------------------
/dev/modal/tests_bwd.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import modal
4 |
5 | ROOT_PATH = Path(__file__).parent.parent.parent
6 | REMOTE_ROOT_PATH = "/root/liger-kernel"
7 | PYTHON_VERSION = "3.12"
8 |
9 | image = modal.Image.debian_slim(python_version=PYTHON_VERSION).pip_install("uv")
10 |
11 | app = modal.App("liger_tests_bwd", image=image)
12 |
13 | # mount: add local files to the remote container
14 | repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)
15 |
16 |
17 | @app.function(gpu="A10G", mounts=[repo], timeout=60 * 15)
18 | def liger_bwd_tests():
19 | import subprocess
20 |
21 | subprocess.run(
22 | ["uv pip install -e '.[dev]' --system"],
23 | check=True,
24 | shell=True,
25 | cwd=REMOTE_ROOT_PATH,
26 | )
27 | # force install transformers==4.44.2
28 | subprocess.run(
29 | ["uv pip install transformers==4.44.2 --system"],
30 | check=True,
31 | shell=True,
32 | cwd=REMOTE_ROOT_PATH,
33 | )
34 | subprocess.run(["make test"], check=True, shell=True, cwd=REMOTE_ROOT_PATH)
35 | subprocess.run(
36 | ["make test-convergence"], check=True, shell=True, cwd=REMOTE_ROOT_PATH
37 | )
38 |
--------------------------------------------------------------------------------
/examples/alignment/run_orpo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from datasets import load_dataset
3 | from transformers import AutoModelForCausalLM, AutoTokenizer
4 | from trl import ORPOConfig
5 |
6 | from liger_kernel.transformers.trainer import LigerORPOTrainer
7 |
8 | model = AutoModelForCausalLM.from_pretrained(
9 | "meta-llama/Llama-3.2-1B-Instruct",
10 | torch_dtype=torch.bfloat16,
11 | )
12 |
13 | tokenizer = AutoTokenizer.from_pretrained(
14 | "meta-llama/Llama-3.2-1B-Instruct",
15 | max_length=512,
16 | padding="max_length",
17 | )
18 | tokenizer.pad_token = tokenizer.eos_token
19 |
20 | train_dataset = load_dataset("trl-lib/tldr-preference", split="train")
21 |
22 | training_args = ORPOConfig(
23 | output_dir="Llama3.2_1B_Instruct",
24 | beta=0.1,
25 | max_length=128,
26 | per_device_train_batch_size=32,
27 | max_steps=100,
28 | save_strategy="no",
29 | )
30 |
31 | trainer = LigerORPOTrainer(
32 | model=model, args=training_args, tokenizer=tokenizer, train_dataset=train_dataset
33 | )
34 |
35 | trainer.train()
36 |
--------------------------------------------------------------------------------
/examples/huggingface/launch_on_modal.py:
--------------------------------------------------------------------------------
1 | """
2 | launch_on_modal.py
3 |
4 | This tool is designed to launch scripts using Modal.
5 |
6 | It sets up the necessary environment, including GPU resources and Python dependencies,
7 | and executes the specified training script remotely.
8 |
9 | ### Setup and Usage
10 | ```bash
11 | pip install modal
12 | modal setup # authenticate with Modal
13 | export HF_TOKEN="your_huggingface_token" # if using a gated model such as llama3
14 | modal run launch_on_modal.py --script "run_qwen2_vl.sh"
15 | ```
16 |
17 | ### Caveats
18 | This tool is intended as an easy on-ramp to using Liger-Kernel for fine-tuning LLMs and
19 | VLMs - it is a reproducible way to run benchmarks and example scripts. However, it is not
20 | the best way to develop a model on Modal, as it re-downloads the model and dataset each
21 | time it is run. For iterative development, consider using `modal.Volume` to cache the
22 | model and dataset between runs.
23 | """
24 |
25 | import os
26 |
27 | import modal
28 | from modal import gpu
29 |
30 | TWO_HOURS = 2 * 60 * 60
31 | SIXTEEN_GB = 16 * 1024
32 |
33 | app = modal.App("liger-example")
34 |
35 | image = (
36 | modal.Image.debian_slim()
37 | .pip_install_from_requirements("requirements.txt")
38 | .copy_local_dir(".", "/root")
39 | )
40 |
41 | if "HF_TOKEN" not in os.environ:
42 | print("HF_TOKEN not found in environment variables, using an empty token.")
43 | hf_token_secret = modal.Secret.from_dict({"HF_TOKEN": os.environ.get("HF_TOKEN", "")})
44 |
45 |
46 | @app.function(
47 | gpu=gpu.A100(count=4, size="80GB"),
48 | image=image,
49 | timeout=TWO_HOURS,
50 | memory=SIXTEEN_GB,
51 | secrets=[hf_token_secret],
52 | )
53 | def launch_script(script: str):
54 | import subprocess
55 |
56 | script_path = f"/root/{script}"
57 | os.chmod(script_path, 0o755) # make script executable
58 |
59 | print(f"Running script: {script_path}")
60 | subprocess.run([script_path], check=True, cwd="/root", env=os.environ.copy())
61 |
62 |
63 | @app.local_entrypoint()
64 | def main(script: str):
65 | """
66 | Launch a script remotely on modal.
67 | ```bash
68 | export HF_TOKEN="your_huggingface_token" # if using a gated model such as llama3
69 | modal run --detach launch_on_modal.py --script "run_qwen2_vl.sh"
70 | ```
71 | """
72 | launch_script.remote(script=script)
73 |
--------------------------------------------------------------------------------
/examples/huggingface/training.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import datasets
4 | import torch
5 | import transformers
6 | from callback import EfficiencyCallback
7 | from trl import DataCollatorForCompletionOnlyLM, SFTTrainer
8 |
9 | from liger_kernel.transformers import AutoLigerKernelForCausalLM
10 |
11 |
12 | @dataclass
13 | class CustomArguments:
14 | model_name: str = "meta-llama/Meta-Llama-3-8B"
15 | dataset: str = "tatsu-lab/alpaca"
16 | max_seq_length: int = 512
17 | use_liger: bool = False
18 |
19 |
20 | def formatting_prompts_func(example):
21 | return example["text"]
22 |
23 |
24 | def train():
25 | parser = transformers.HfArgumentParser(
26 | (transformers.TrainingArguments, CustomArguments)
27 | )
28 | training_args, custom_args = parser.parse_args_into_dataclasses()
29 | tokenizer = transformers.AutoTokenizer.from_pretrained(
30 | custom_args.model_name,
31 | padding_side="left",
32 | truncation_side="left",
33 | )
34 | tokenizer.pad_token = tokenizer.eos_token
35 |
36 | dataset = datasets.load_dataset(custom_args.dataset)["train"].train_test_split(
37 | test_size=0.1
38 | )
39 | train_dataset = dataset["train"]
40 | eval_dataset = dataset["test"]
41 | response_prompt = tokenizer.encode("### Response:\n", add_special_tokens=False)
42 | collator = DataCollatorForCompletionOnlyLM(
43 | tokenizer=tokenizer,
44 | response_template=response_prompt,
45 | pad_to_multiple_of=16,
46 | )
47 |
48 | if custom_args.use_liger:
49 | model = AutoLigerKernelForCausalLM.from_pretrained(
50 | custom_args.model_name,
51 | trust_remote_code=True,
52 | use_cache=False,
53 | torch_dtype=torch.bfloat16,
54 | # These args will get passed to the appropriate apply_liger_kernel_to_* function
55 | # to override the default settings
56 | # cross_entropy=True,
57 | # fused_linear_cross_entropy=False,
58 | )
59 | else:
60 | model = transformers.AutoModelForCausalLM.from_pretrained(
61 | custom_args.model_name,
62 | trust_remote_code=True,
63 | use_cache=False,
64 | torch_dtype=torch.bfloat16,
65 | )
66 |
67 | trainer = SFTTrainer(
68 | model=model,
69 | args=training_args,
70 | data_collator=collator,
71 | max_seq_length=custom_args.max_seq_length,
72 | train_dataset=train_dataset,
73 | eval_dataset=eval_dataset,
74 | formatting_func=formatting_prompts_func,
75 | callbacks=[EfficiencyCallback()],
76 | )
77 | trainer.train()
78 |
79 |
80 | if __name__ == "__main__":
81 | train()
82 |
--------------------------------------------------------------------------------
/examples/huggingface/training_multimodal.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass
3 |
4 | import datasets
5 | import torch
6 | import transformers
7 | from callback import EfficiencyCallback
8 | from datasets import Image as ImageFeature
9 | from trl import SFTTrainer
10 |
11 | from liger_kernel.transformers import monkey_patch
12 |
13 |
14 | @dataclass
15 | class CustomArguments:
16 | model_name: str = "Qwen/Qwen2-VL-2B-Instruct"
17 | dataset: str = "HuggingFaceM4/the_cauldron"
18 | dataset_subset: str = "ai2d"
19 | dataset_split: str = "train"
20 | max_seq_length: int = 512
21 | dataset_text_field: str = "texts"
22 | use_liger: bool = False
23 |
24 |
25 | def construct_model_and_processor(model_name: str, use_liger: bool) -> torch.nn.Module:
26 | if "Qwen2-VL" in model_name:
27 | from transformers import Qwen2VLForConditionalGeneration
28 |
29 | # These settings are used to reduce the memory footprint of the Qwen2-VL model,
30 | # which supports training/inference on images in their native resolution. Large
31 | # images -> many visual tokens (a max of 16384) -> large memory consumption.
32 | # If fine-tuning for a real-world application, consider these values carefully.
33 | min_visual_tokens_per_image = 256
34 | max_visual_tokens_per_image = 256
35 |
36 | processor = transformers.AutoProcessor.from_pretrained(
37 | model_name,
38 | padding_side="left",
39 | truncation_side="left",
40 | min_pixels=min_visual_tokens_per_image * 28 * 28, # patch size is 14x14
41 | max_pixels=max_visual_tokens_per_image * 28 * 28, # 4 patches / token
42 | )
43 | processor.tokenizer.pad_token = processor.tokenizer.eos_token
44 | image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
45 |
46 | if use_liger:
47 | print("Applying Liger Kernel to Qwen2-VL model")
48 | monkey_patch.apply_liger_kernel_to_qwen2_vl(
49 | # These args can be used to override the default Liger settings
50 | # cross_entropy=True,
51 | # fused_linear_cross_entropy=False,
52 | )
53 |
54 | model = Qwen2VLForConditionalGeneration.from_pretrained(
55 | pretrained_model_name_or_path=model_name,
56 | use_cache=False,
57 | torch_dtype=torch.bfloat16,
58 | low_cpu_mem_usage=True,
59 | attn_implementation="sdpa",
60 | )
61 | return model, processor, image_token_id
62 |
63 | raise NotImplementedError(f"Model {model_name} not supported")
64 |
65 |
66 | def _validate_and_extract_the_cauldron(examples) -> dict[str, list]:
67 | batch_texts = []
68 | batch_images = []
69 | for images, texts in zip(examples["images"], examples["texts"]):
70 | if not images:
71 | raise ValueError("No image found in example from the_cauldron dataset")
72 | if len(images) > 1:
73 | raise ValueError("Only one image per example is supported")
74 | batch_texts.extend(texts)
75 | batch_images.extend([images[0]] * len(texts))
76 | return {"texts": batch_texts, "images": batch_images}
77 |
78 |
79 | def _format_for_convo(example, tokenizer):
80 | # cauldron data is already in message format {"user": ..., "assistant": ...}
81 | text = example["texts"]
82 | messages = [
83 | {
84 | "role": "user",
85 | "content": [{"type": "image"}, {"type": "text", "text": text["user"]}],
86 | },
87 | {"role": "assistant", "content": [{"type": "text", "text": text["assistant"]}]},
88 | ]
89 | text = tokenizer.apply_chat_template(messages, tokenize=False)
90 | return {"texts": text}
91 |
92 |
93 | def train():
94 | parser = transformers.HfArgumentParser(
95 | (transformers.TrainingArguments, CustomArguments)
96 | )
97 | training_args, custom_args = parser.parse_args_into_dataclasses()
98 | training_args.remove_unused_columns = False # required to not drop the image column
99 | training_args.dataset_kwargs = {"skip_prepare_dataset": True}
100 |
101 | model, processor, image_token_id = construct_model_and_processor(
102 | custom_args.model_name, custom_args.use_liger
103 | )
104 |
105 | dataset = (
106 | datasets.load_dataset(
107 | custom_args.dataset,
108 | custom_args.dataset_subset,
109 | split=custom_args.dataset_split,
110 | )
111 | .map(
112 | _validate_and_extract_the_cauldron,
113 | batched=True,
114 | num_proc=min(os.cpu_count(), 16),
115 | desc="Extracting text and images",
116 | )
117 | .map(
118 | _format_for_convo,
119 | fn_kwargs={"tokenizer": processor.tokenizer},
120 | desc="Formatting for convo",
121 | )
122 | .cast_column("images", ImageFeature())
123 | .train_test_split(test_size=0.1)
124 | )
125 |
126 | train_dataset = dataset["train"]
127 | eval_dataset = dataset["test"]
128 |
129 | def collate_fn(examples):
130 | """
131 | Taken directly from the TRL documentation with minor modifications:
132 | https://huggingface.co/docs/trl/en/sft_trainer#a-custom-collator-for-processing-multi-modal-data
133 |
134 | Modifications:
135 | 1. `apply_chat_template` is used to preprocess the texts before training begins (see above)
136 | 2. `example["messages"]` -> `example["texts"]` to conform with the_cauldron dataset schema
137 | 3. Ignoring image tokens in the loss computation
138 | """
139 | # Get the texts and images
140 | texts = [example["texts"] for example in examples]
141 | images = [example["images"] for example in examples]
142 |
143 | # Tokenize the texts and process the images
144 | batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
145 |
146 | # The labels are the input_ids, and we mask the padding tokens in the loss computation
147 | labels = batch["input_ids"].clone()
148 | labels[labels == processor.tokenizer.pad_token_id] = -100
149 |
150 | # Ignore the image token index in the loss computation
151 | labels[labels == image_token_id] = -100
152 | batch["labels"] = labels
153 |
154 | return batch
155 |
156 | trainer = SFTTrainer(
157 | model=model,
158 | args=training_args,
159 | data_collator=collate_fn,
160 | max_seq_length=custom_args.max_seq_length,
161 | dataset_text_field=custom_args.dataset_text_field,
162 | train_dataset=train_dataset,
163 | eval_dataset=eval_dataset,
164 | tokenizer=processor.tokenizer,
165 | callbacks=[EfficiencyCallback()],
166 | )
167 | trainer.train()
168 |
169 |
170 | if __name__ == "__main__":
171 | train()
172 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # setup.py
2 |
3 | import subprocess
4 | from typing import Literal
5 |
6 | from setuptools import setup
7 |
8 |
9 | def get_default_dependencies():
10 | """Determine the appropriate dependencies based on detected hardware."""
11 | platform = get_platform()
12 |
13 | if platform in ["cuda", "cpu"]:
14 | return [
15 | "torch>=2.1.2",
16 | "triton>=2.3.1",
17 | ]
18 | elif platform == "rocm":
19 | return [
20 | "torch>=2.6.0.dev",
21 | "triton>=3.0.0",
22 | ]
23 |
24 |
25 | def get_optional_dependencies():
26 | """Get optional dependency groups."""
27 | return {
28 | "dev": [
29 | "transformers>=4.44.2",
30 | "matplotlib>=3.7.2",
31 | "flake8>=4.0.1.1",
32 | "black>=24.4.2",
33 | "isort>=5.13.2",
34 | "pytest>=7.1.2",
35 | "pytest-xdist",
36 | "pytest-rerunfailures",
37 | "datasets>=2.19.2",
38 | "seaborn",
39 | ]
40 | }
41 |
42 |
43 | # TODO: add intel XPU
44 | def get_platform() -> Literal["cuda", "rocm", "cpu"]:
45 | """
46 | Detect whether the system has NVIDIA or AMD GPU without torch dependency.
47 | """
48 | # Try nvidia-smi first
49 | try:
50 | subprocess.run(["nvidia-smi"], check=True)
51 | print("NVIDIA GPU detected")
52 | return "cuda"
53 | except (subprocess.SubprocessError, FileNotFoundError):
54 | # If nvidia-smi fails, check for ROCm
55 | try:
56 | subprocess.run(["rocm-smi"], check=True)
57 | print("ROCm GPU detected")
58 | return "rocm"
59 | except (subprocess.SubprocessError, FileNotFoundError):
60 | print("No GPU detected")
61 | return "cpu"
62 |
63 |
64 | setup(
65 | name="liger_kernel",
66 | package_dir={"": "src"},
67 | packages=["liger_kernel"],
68 | install_requires=get_default_dependencies(),
69 | extras_require=get_optional_dependencies(),
70 | )
71 |
--------------------------------------------------------------------------------
/src/liger_kernel/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linkedin/Liger-Kernel/0bb6c72d03f0fc570b9e954c42b723ab7e0c315f/src/liger_kernel/__init__.py
--------------------------------------------------------------------------------
/src/liger_kernel/chunked_loss/__init__.py:
--------------------------------------------------------------------------------
1 | from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
2 | from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
3 | from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
4 | from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401
5 |
--------------------------------------------------------------------------------
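These re-exports mean the four chunked preference losses can be imported directly from the subpackage rather than from the individual modules. For example (constructor arguments follow the class definitions in the files below):

from liger_kernel.chunked_loss import (
    LigerFusedLinearCPOLoss,
    LigerFusedLinearDPOLoss,
    LigerFusedLinearORPOLoss,
    LigerFusedLinearSimPOLoss,
)

cpo_loss = LigerFusedLinearCPOLoss(ignore_index=-100, beta=0.1)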
/src/liger_kernel/chunked_loss/cpo_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from liger_kernel.chunked_loss.fused_linear_preference import (
5 | LigerFusedLinearPreferenceBase,
6 | )
7 |
8 |
9 | class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
10 |
11 | @staticmethod
12 | def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
13 | """
14 | Paper: https://arxiv.org/pdf/2401.08417
15 |
16 | Formula:
17 | L(π_θ; U) = -E_(x,y_w,y_l)~D[log σ(β log π_θ(y_w|x) - β log π_θ(y_l|x))]
18 |
19 | Where:
20 | - π_θ(y|x): Policy (model) probability
21 | - y_w: Chosen sequence
22 | - y_l: Rejected sequence
23 | - σ: Sigmoid function
24 | - β: Temperature parameter
25 | - E: Expected value over the dataset D
26 | - D: Dataset of preferences
27 |
28 | Args:
29 | chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
30 | rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
31 | full_target (torch.Tensor): Non chunked full target tensor
32 | beta (float): Weight for the CPO loss
33 | """
34 | logits = beta * (chosen_logps - rejected_logps)
35 | loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
36 | return loss
37 |
38 | @staticmethod
39 | def forward(
40 | ctx,
41 | _input,
42 | weight,
43 | target,
44 | bias=None,
45 | ignore_index=-100,
46 | beta=0.1,
47 | alpha=1.0,
48 | compute_nll_loss=True,
49 | compiled=True,
50 | ):
51 | return LigerFusedLinearPreferenceBase.forward(
52 | ctx,
53 | _input,
54 | weight,
55 | target,
56 | bias,
57 | loss_fn=LigerFusedLinearCPOFunction.preference_loss_fn,
58 | ignore_index=ignore_index,
59 | alpha=alpha,
60 | beta=beta,
61 | compute_nll_loss=compute_nll_loss,
62 | compiled=compiled,
63 | )
64 |
65 | @staticmethod
66 | def backward(ctx, *grad_output):
67 | grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
68 | return *grads, None, None, None, None, None
69 |
70 |
71 | class LigerFusedLinearCPOLoss(torch.nn.Module):
72 | """
73 | Fused linear layer with CPO loss.
74 | """
75 |
76 | def __init__(
77 | self,
78 | ignore_index: int = -100,
79 | beta: float = 0.1,
80 | alpha: float = 1.0,
81 | compute_nll_loss: bool = True,
82 | compiled: bool = True,
83 | ):
84 | """
85 | Args:
86 | ignore_index (int): Index to ignore in the loss.
87 |             beta (float): Weight for the CPO preference loss.
88 | """
89 | super().__init__()
90 | self.ignore_index = ignore_index
91 | self.beta = beta
92 | self.alpha = alpha
93 | self.compute_nll_loss = compute_nll_loss
94 | self.compiled = compiled
95 |
96 | def forward(self, lin_weight, _input, target, bias=None):
97 | return LigerFusedLinearCPOFunction.apply(
98 | _input,
99 | lin_weight,
100 | target,
101 | bias,
102 | self.ignore_index,
103 | self.beta,
104 | self.alpha,
105 | self.compute_nll_loss,
106 | self.compiled,
107 | )
108 |
--------------------------------------------------------------------------------
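A minimal usage sketch for `LigerFusedLinearCPOLoss` above (not a file from the repository; assumes a CUDA device with `liger-kernel` installed and the docstring convention that chosen and rejected sequences are concatenated along the batch dimension, so the batch holds 2 × the number of preference pairs; shapes are illustrative):

```python
import torch

from liger_kernel.chunked_loss import LigerFusedLinearCPOLoss

B, T, H, V = 2, 8, 64, 128  # 2 preference pairs -> 4 sequences in the batch
hidden = torch.randn(2 * B, T, H, device="cuda", requires_grad=True)
target = torch.randint(0, V, (2 * B, T), device="cuda")
lm_head = torch.nn.Linear(H, V, bias=False, device="cuda")

loss_fn = LigerFusedLinearCPOLoss(beta=0.1, alpha=1.0)
out = loss_fn(lm_head.weight, hidden, target)      # note: the weight comes first
loss = out[0] if isinstance(out, tuple) else out   # output may carry aux statistics
loss.backward()
```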
/src/liger_kernel/chunked_loss/dpo_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from liger_kernel.chunked_loss.fused_linear_preference import (
5 | LigerFusedLinearPreferenceBase,
6 | )
7 |
8 |
9 | class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
10 |
11 | @staticmethod
12 | def preference_loss_fn(
13 | chosen_logps,
14 | rejected_logps,
15 | full_target,
16 | ref_chosen_logps=None,
17 | ref_rejected_logps=None,
18 | beta=0.1,
19 | ):
20 | """
21 | Paper: https://arxiv.org/pdf/2305.18290
22 |
23 | Formula:
24 | L_DPO = -E[ log_sigmoid( β * (log(π(y_w|x)/π_ref(y_w|x)) - log(π(y_l|x)/π_ref(y_l|x))) ) ]
25 |
26 | Where:
27 | - π(y|x): Policy (model) probability
28 | - π_ref(y|x): Reference model probability
29 | - y_w: Chosen sequence
30 | - y_l: Rejected sequence
31 | - β: Weight for the direct preference loss
32 | - E: Expected value over the dataset
33 |
34 | Args:
35 | chosen_logps: Log probabilities of chosen tokens (batch_size,)
36 | rejected_logps: Log probabilities of rejected tokens (batch_size,)
37 | full_target: Non chunked full target tensor
38 | ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
39 | ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
40 | beta: Weight for the direct preference loss
41 | """
42 |
43 | if ref_chosen_logps is None:
44 | ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device)
45 | if ref_rejected_logps is None:
46 | ref_rejected_logps = torch.tensor(0.0, device=rejected_logps.device)
47 |
48 | chosen_logratios = chosen_logps - ref_chosen_logps
49 | rejected_logratios = rejected_logps - ref_rejected_logps
50 |
51 | logits_diff = beta * (chosen_logratios - rejected_logratios)
52 | loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
53 | return loss
54 |
55 | @staticmethod
56 | def forward(
57 | ctx,
58 | _input,
59 | weight,
60 | target,
61 | bias=None,
62 | ref_input=None,
63 | ref_weight=None,
64 | ref_bias=None,
65 | ignore_index=-100,
66 | beta=0.1,
67 | compute_nll_loss=True,
68 | compiled=True,
69 | use_ref_model=True,
70 | ):
71 | return LigerFusedLinearPreferenceBase.forward(
72 | ctx=ctx,
73 | _input=_input,
74 | weight=weight,
75 | target=target,
76 | bias=bias,
77 | loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn,
78 | ignore_index=ignore_index,
79 | beta=beta,
80 | compute_nll_loss=compute_nll_loss,
81 | compiled=compiled,
82 | use_ref_model=use_ref_model,
83 | ref_input=ref_input,
84 | ref_weight=ref_weight,
85 | ref_bias=ref_bias,
86 | )
87 |
88 | @staticmethod
89 | def backward(ctx, *grad_output):
90 | grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
91 | return *grads, None, None, None, None, None, None, None, None
92 |
93 |
94 | class LigerFusedLinearDPOLoss(torch.nn.Module):
95 | """
96 | Fused linear layer with DPO loss.
97 | """
98 |
99 | def __init__(
100 | self,
101 | ignore_index: int = -100,
102 | beta: float = 0.1,
103 | compute_nll_loss: bool = True,
104 | compiled: bool = True,
105 | use_ref_model: bool = False,
106 | ):
107 | """
108 | Args:
109 | ignore_index (int): Index to ignore in the loss.
110 |             beta (float): Weight for the direct preference loss.
111 | compute_nll_loss (bool): Whether to compute the NLL loss.
112 | compiled (bool): Whether to use the torch compiled kernel.
113 | use_ref_model (bool): Whether to use a reference model for the DPO loss.
114 | """
115 | super().__init__()
116 | self.ignore_index = ignore_index
117 | self.beta = beta
118 | self.compute_nll_loss = compute_nll_loss
119 | self.compiled = compiled
120 | self.use_ref_model = use_ref_model
121 |
122 | def forward(
123 | self,
124 | lin_weight,
125 | _input,
126 | target,
127 | bias=None,
128 | ref_input=None,
129 | ref_weight=None,
130 | ref_bias=None,
131 | ):
132 | return LigerFusedLinearDPOFunction.apply(
133 | _input,
134 | lin_weight,
135 | target,
136 | bias,
137 | ref_input,
138 | ref_weight,
139 | ref_bias,
140 | self.ignore_index,
141 | self.beta,
142 | self.compute_nll_loss,
143 | self.compiled,
144 | self.use_ref_model,
145 | )
146 |
--------------------------------------------------------------------------------
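A minimal usage sketch for `LigerFusedLinearDPOLoss` above with a frozen reference model (not a file from the repository; same assumptions and batch convention as the CPO sketch, with `use_ref_model=True` so reference activations and weights are supplied):

```python
import torch

from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss

B, T, H, V = 2, 8, 64, 128
hidden = torch.randn(2 * B, T, H, device="cuda", requires_grad=True)
ref_hidden = torch.randn(2 * B, T, H, device="cuda")  # reference model: no grad needed
target = torch.randint(0, V, (2 * B, T), device="cuda")
policy_head = torch.nn.Linear(H, V, bias=False, device="cuda")
ref_head = torch.nn.Linear(H, V, bias=False, device="cuda")

loss_fn = LigerFusedLinearDPOLoss(beta=0.1, use_ref_model=True)
out = loss_fn(
    policy_head.weight,
    hidden,
    target,
    ref_input=ref_hidden,
    ref_weight=ref_head.weight,
)
loss = out[0] if isinstance(out, tuple) else out
loss.backward()
```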
/src/liger_kernel/chunked_loss/functional.py:
--------------------------------------------------------------------------------
1 | from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
2 | from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
3 | from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
4 | from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
5 |
6 | liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
7 | liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
8 | liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
9 | liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
10 |
--------------------------------------------------------------------------------
/src/liger_kernel/chunked_loss/orpo_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from liger_kernel.chunked_loss.fused_linear_preference import (
5 | LigerFusedLinearPreferenceBase,
6 | )
7 |
8 |
9 | class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
10 |
11 | @staticmethod
12 | def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
13 | """
14 | Paper: https://arxiv.org/pdf/2403.07691
15 |
16 | Formula:
17 | Compute odds-ratio loss: L_OR = -log(σ(log(odds_θ(y_w|x) / odds_θ(y_l|x))))
18 | where odds_θ(y|x) = P_θ(y|x) / (1 - P_θ(y|x))
19 |
20 | Where:
21 | - P_θ(y|x): Policy (model) probability
22 | - y_w: Chosen sequence
23 | - y_l: Rejected sequence
24 | - σ: Sigmoid function
25 | - β: Weight for the odds ratio loss
26 | - odds_θ: Odds function for the policy
27 |
28 | Args:
29 | chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
30 | rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
31 | full_target (torch.Tensor): Non chunked full target tensor
32 | beta (float): Weight for the odds ratio loss.
33 | """
34 | log_odds = (chosen_logps - rejected_logps) - (
35 | torch.log1p(-torch.exp(chosen_logps))
36 | - torch.log1p(-torch.exp(rejected_logps))
37 | )
38 | ratio = F.logsigmoid(log_odds)
39 | loss = beta * ratio.sum() / (full_target.shape[0] // 2)
40 |
41 | chosen_rewards = beta * chosen_logps
42 | rejected_rewards = beta * rejected_logps
43 |
44 | log_odds_ratio = torch.sum(ratio) / (full_target.shape[0] // 2)
45 | log_odds_chosen = torch.sum(log_odds) / (full_target.shape[0] // 2)
46 |
47 | return loss, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen
48 |
49 | @staticmethod
50 | def forward(
51 | ctx,
52 | _input,
53 | weight,
54 | target,
55 | bias=None,
56 | ignore_index=-100,
57 | beta=0.1,
58 | compute_nll_loss=True,
59 | compiled=True,
60 | ):
61 | return LigerFusedLinearPreferenceBase.forward(
62 | ctx=ctx,
63 | _input=_input,
64 | weight=weight,
65 | target=target,
66 | bias=bias,
67 | loss_fn=LigerFusedLinearORPOFunction.preference_loss_fn,
68 | ignore_index=ignore_index,
69 | beta=beta,
70 | compute_nll_loss=compute_nll_loss,
71 | compiled=compiled,
72 | )
73 |
74 | @staticmethod
75 | def backward(ctx, *grad_output):
76 | grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
77 | return *grads, None, None, None, None
78 |
79 |
80 | class LigerFusedLinearORPOLoss(torch.nn.Module):
81 | """
82 | Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
83 | """
84 |
85 | def __init__(
86 | self,
87 | ignore_index: int = -100,
88 | beta: float = 0.1,
89 | compute_nll_loss: bool = True,
90 | compiled: bool = True,
91 | ):
92 | """
93 | Args:
94 | ignore_index (int): Index to ignore in the loss.
95 | beta (float): Weight for the odds ratio loss.
96 | """
97 | super().__init__()
98 | self.ignore_index = ignore_index
99 | self.beta = beta
100 | self.compute_nll_loss = compute_nll_loss
101 | self.compiled = compiled
102 |
103 | def forward(self, lin_weight, _input, target, bias=None):
104 | return LigerFusedLinearORPOFunction.apply(
105 | _input,
106 | lin_weight,
107 | target,
108 | bias,
109 | self.ignore_index,
110 | self.beta,
111 | self.compute_nll_loss,
112 | self.compiled,
113 | )
114 |
--------------------------------------------------------------------------------
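A minimal usage sketch for `LigerFusedLinearORPOLoss` above (not a file from the repository; same assumptions and batch convention as the CPO sketch; `preference_loss_fn` also returns reward statistics, so the output is treated as possibly being a tuple):

```python
import torch

from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss

B, T, H, V = 2, 8, 64, 128
hidden = torch.randn(2 * B, T, H, device="cuda", requires_grad=True)
target = torch.randint(0, V, (2 * B, T), device="cuda")
lm_head = torch.nn.Linear(H, V, bias=False, device="cuda")

loss_fn = LigerFusedLinearORPOLoss(beta=0.1)
out = loss_fn(lm_head.weight, hidden, target)
loss = out[0] if isinstance(out, tuple) else out
loss.backward()
```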
/src/liger_kernel/chunked_loss/simpo_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from liger_kernel.chunked_loss.fused_linear_preference import (
5 | LigerFusedLinearPreferenceBase,
6 | )
7 |
8 |
9 | class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
10 |
11 | @staticmethod
12 | def preference_loss_fn(
13 | chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5
14 | ):
15 | """
16 | Paper: https://arxiv.org/pdf/2405.14734
17 |
18 | Formula:
19 | L_SimPO(π_θ) = -E [log σ(β/|y_w| log π_θ(y_w|x) - β/|y_l| log π_θ(y_l|x) - γ)]
20 |
21 | Where:
22 | - π_θ(y|x): Policy (model) probability
23 | - y_w: Chosen sequence
24 | - y_l: Rejected sequence
25 | - |y_w|, |y_l|: Sequence lengths
26 | - σ: Sigmoid function
27 | - β: beta weight
28 |             - γ: gamma margin term
29 |
30 | Args:
31 | chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
32 | rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
33 | full_target: Non chunked full target tensor
34 | beta (float): beta weight
35 |             gamma (float): gamma margin term
36 | """
37 | logits = beta * (chosen_logps - rejected_logps) - gamma
38 | loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
39 | return loss
40 |
41 | @staticmethod
42 | def forward(
43 | ctx,
44 | _input,
45 | weight,
46 | target,
47 | bias=None,
48 | ignore_index=-100,
49 | beta=0.1,
50 | alpha=1.0,
51 | compute_nll_loss=False,
52 | compiled=True,
53 | gamma=0.5,
54 | ):
55 | return LigerFusedLinearPreferenceBase.forward(
56 | ctx,
57 | _input,
58 | weight,
59 | target,
60 | bias,
61 | loss_fn=LigerFusedLinearSimPOFunction.preference_loss_fn,
62 | compute_nll_loss=compute_nll_loss,
63 | ignore_index=ignore_index,
64 | alpha=alpha,
65 | beta=beta,
66 | compiled=compiled,
67 | gamma=gamma,
68 | )
69 |
70 | @staticmethod
71 | def backward(ctx, *grad_output):
72 | grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
73 | return *grads, None, None, None, None, None, None
74 |
75 |
76 | class LigerFusedLinearSimPOLoss(torch.nn.Module):
77 | """
78 | Fused linear layer with SimPO loss.
79 | """
80 |
81 | def __init__(
82 | self,
83 | ignore_index: int = -100,
84 | beta: float = 0.1,
85 | alpha: float = 1.0,
86 | compute_nll_loss: bool = True,
87 | compiled: bool = True,
88 | gamma: float = 0.5,
89 | ):
90 | """
91 | Args:
92 | ignore_index (int): Index to ignore in the loss.
93 |             beta (float): beta weight for the SimPO loss.
94 | """
95 | super().__init__()
96 | self.ignore_index = ignore_index
97 | self.beta = beta
98 | self.alpha = alpha
99 | self.compute_nll_loss = compute_nll_loss
100 | self.compiled = compiled
101 | self.gamma = gamma
102 |
103 | def forward(self, lin_weight, _input, target, bias=None):
104 | return LigerFusedLinearSimPOFunction.apply(
105 | _input,
106 | lin_weight,
107 | target,
108 | bias,
109 | self.ignore_index,
110 | self.beta,
111 | self.alpha,
112 | self.compute_nll_loss,
113 | self.compiled,
114 | self.gamma,
115 | )
116 |
--------------------------------------------------------------------------------
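A minimal usage sketch for `LigerFusedLinearSimPOLoss` above (not a file from the repository; same assumptions and batch convention as the CPO sketch; `gamma` is the margin subtracted from the scaled log-probability difference):

```python
import torch

from liger_kernel.chunked_loss import LigerFusedLinearSimPOLoss

B, T, H, V = 2, 8, 64, 128
hidden = torch.randn(2 * B, T, H, device="cuda", requires_grad=True)
target = torch.randint(0, V, (2 * B, T), device="cuda")
lm_head = torch.nn.Linear(H, V, bias=False, device="cuda")

loss_fn = LigerFusedLinearSimPOLoss(beta=0.1, gamma=0.5, compute_nll_loss=False)
out = loss_fn(lm_head.weight, hidden, target)
loss = out[0] if isinstance(out, tuple) else out
loss.backward()
```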
/src/liger_kernel/env_report.py:
--------------------------------------------------------------------------------
1 | import platform
2 | import sys
3 | from importlib.metadata import version
4 |
5 |
6 | def print_env_report():
7 | """
8 |
9 | Prints a report of the environment. Useful for debugging and reproducibility.
10 | Usage:
11 | ```
12 | python -m liger_kernel.env_report
13 | ```
14 |
15 | """
16 | print("Environment Report:")
17 | print("-------------------")
18 | print(f"Operating System: {platform.platform()}")
19 | print(f"Python version: {sys.version.split()[0]}")
20 |
21 | try:
22 | print(f"Liger Kernel version: {version('liger-kernel')}")
23 | except ImportError:
24 | print("Liger Kernel: Not installed")
25 |
26 | try:
27 | import torch
28 |
29 | print(f"PyTorch version: {torch.__version__}")
30 | cuda_version = (
31 | torch.version.cuda if torch.cuda.is_available() else "Not available"
32 | )
33 | print(f"CUDA version: {cuda_version}")
34 | hip_version = (
35 | torch.version.hip
36 | if torch.cuda.is_available() and torch.version.hip
37 | else "Not available"
38 | )
39 | print(f"HIP(ROCm) version: {hip_version}")
40 |
41 | except ImportError:
42 | print("PyTorch: Not installed")
43 | print("CUDA version: Unable to query")
44 | print("HIP(ROCm) version: Unable to query")
45 |
46 | try:
47 | import triton
48 |
49 | print(f"Triton version: {triton.__version__}")
50 | except ImportError:
51 | print("Triton: Not installed")
52 |
53 | try:
54 | import transformers
55 |
56 | print(f"Transformers version: {transformers.__version__}")
57 | except ImportError:
58 | print("Transformers: Not installed")
59 |
60 | try:
61 | xpu_version = (
62 | torch.version.xpu if torch.xpu.is_available() else "XPU Not Available"
63 | )
64 | print(f"XPU version: {xpu_version}")
65 |     except (ImportError, NameError, AttributeError):
66 | print("XPU version: Unable to query")
67 |
68 |
69 | if __name__ == "__main__":
70 | print_env_report()
71 |
--------------------------------------------------------------------------------
/src/liger_kernel/ops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linkedin/Liger-Kernel/0bb6c72d03f0fc570b9e954c42b723ab7e0c315f/src/liger_kernel/ops/__init__.py
--------------------------------------------------------------------------------
/src/liger_kernel/ops/experimental/embedding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | import triton.language as tl
4 |
5 | from liger_kernel.ops.utils import ensure_contiguous
6 |
7 |
8 | @triton.jit
9 | def embedding_forward_kernel(
10 | embeddings_ptr,
11 | indices_ptr,
12 | output_ptr,
13 | n_elements,
14 | embedding_dim: tl.constexpr,
15 | BLOCK_SIZE_M: tl.constexpr,
16 | BLOCK_SIZE_N: tl.constexpr,
17 | ):
18 | pid_m = tl.program_id(0)
19 | pid_n = tl.program_id(1)
20 |
21 | start_m = pid_m * BLOCK_SIZE_M
22 | start_n = pid_n * BLOCK_SIZE_N
23 | offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
24 | mask_m = offsets_m < n_elements
25 | indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)
26 | offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
27 | mask_n = offsets_n < embedding_dim
28 |
29 | embedding_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]
30 | embeddings = tl.load(
31 | embeddings_ptr + embedding_offsets,
32 | mask=mask_m[:, None] & mask_n[None, :],
33 | other=0.0,
34 | )
35 |
36 | output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :]
37 | tl.store(
38 | output_ptr + output_offsets, embeddings, mask=mask_m[:, None] & mask_n[None, :]
39 | )
40 |
41 |
42 | @triton.jit
43 | def embedding_backward_kernel(
44 | grad_output_ptr,
45 | grad_weight_ptr,
46 | indices_ptr,
47 | n_elements,
48 | embedding_dim: tl.constexpr,
49 | BLOCK_SIZE_M: tl.constexpr,
50 | BLOCK_SIZE_N: tl.constexpr,
51 | ):
52 | pid_m = tl.program_id(0)
53 | pid_n = tl.program_id(1)
54 |
55 | start_m = pid_m * BLOCK_SIZE_M
56 | start_n = pid_n * BLOCK_SIZE_N
57 | offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
58 | mask_m = offsets_m < n_elements
59 | indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)
60 | offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
61 | mask_n = offsets_n < embedding_dim
62 |
63 | grad_output = tl.load(
64 | grad_output_ptr + offsets_m[:, None] * embedding_dim + offsets_n[None, :],
65 | mask=mask_m[:, None] & mask_n[None, :],
66 | other=0.0,
67 | )
68 |
69 | grad_weight_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]
70 |
71 | tl.atomic_add(
72 | grad_weight_ptr + grad_weight_offsets,
73 | grad_output,
74 | mask=mask_m[:, None] & mask_n[None, :],
75 | )
76 |
77 |
78 | class LigerEmbeddingFunction(torch.autograd.Function):
79 | @staticmethod
80 | @ensure_contiguous
81 | def forward(ctx, embeddings: torch.Tensor, indices: torch.Tensor):
82 | ori_shape = indices.shape
83 | indices = indices.view(-1)
84 | output = torch.empty(
85 | indices.shape[0],
86 | embeddings.shape[1],
87 | device=indices.device,
88 | dtype=embeddings.dtype,
89 | )
90 |
91 | n_elements = indices.numel()
92 | embedding_dim = embeddings.shape[1]
93 |
94 | BLOCK_SIZE_M = triton.next_power_of_2(min(128, embedding_dim))
95 | BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim))
96 | grid = (
97 | triton.cdiv(n_elements, BLOCK_SIZE_M),
98 | triton.cdiv(embedding_dim, BLOCK_SIZE_N),
99 | )
100 |
101 | embedding_forward_kernel[grid](
102 | embeddings,
103 | indices,
104 | output,
105 | n_elements,
106 | embedding_dim=embedding_dim,
107 | BLOCK_SIZE_M=BLOCK_SIZE_M,
108 | BLOCK_SIZE_N=BLOCK_SIZE_N,
109 | )
110 |
111 | ctx.save_for_backward(indices, embeddings)
112 |
113 | return output.view(*ori_shape, -1)
114 |
115 | @staticmethod
116 | @ensure_contiguous
117 | def backward(ctx, grad_output: torch.Tensor):
118 | indices, embedding_table = ctx.saved_tensors
119 | grad_output = grad_output.contiguous().view(-1, embedding_table.shape[1])
120 |
121 | grad_weight = torch.zeros_like(embedding_table)
122 |
123 | n_elements = indices.numel()
124 | embedding_dim = embedding_table.shape[1]
125 |
126 | BLOCK_SIZE_M = triton.next_power_of_2(min(128, embedding_dim))
127 | BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim))
128 | grid = (
129 | triton.cdiv(n_elements, BLOCK_SIZE_M),
130 | triton.cdiv(embedding_dim, BLOCK_SIZE_N),
131 | )
132 |
133 | embedding_backward_kernel[grid](
134 | grad_output,
135 | grad_weight,
136 | indices,
137 | n_elements,
138 | embedding_dim=embedding_dim,
139 | BLOCK_SIZE_M=BLOCK_SIZE_M,
140 | BLOCK_SIZE_N=BLOCK_SIZE_N,
141 | )
142 |
143 | return grad_weight, None
144 |
--------------------------------------------------------------------------------
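A minimal usage sketch for the experimental `LigerEmbeddingFunction` above (not a file from the repository; assumes a CUDA device — the backward kernel scatters gradients into the embedding table with `tl.atomic_add`):

```python
import torch

from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction

weight = torch.randn(1000, 64, device="cuda", requires_grad=True)  # (num_embeddings, embedding_dim)
indices = torch.randint(0, 1000, (2, 16), device="cuda")           # any leading shape

out = LigerEmbeddingFunction.apply(weight, indices)  # -> (2, 16, 64)
out.sum().backward()
assert weight.grad.shape == weight.shape
```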
/src/liger_kernel/ops/geglu.py:
--------------------------------------------------------------------------------
1 | import operator
2 |
3 | import torch
4 | import triton
5 | import triton.language as tl
6 |
7 | from liger_kernel.ops.utils import (
8 | calculate_settings,
9 | compare_version,
10 | ensure_contiguous,
11 | )
12 |
13 | if compare_version("triton", operator.ge, "3.0.0"):
14 | try:
15 | # typical import path with dispatch available
16 | from triton.language.extra.libdevice import tanh
17 | except ModuleNotFoundError:
18 | # for working with NGC containers
19 | from triton.language.extra.cuda.libdevice import tanh
20 | else:
21 | from triton.language.math import tanh
22 |
23 |
24 | @triton.jit
25 | def _geglu_tanh_forward_kernel(
26 | a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
27 | ):
28 | program_id = tl.program_id(0).to(tl.int64)
29 |
30 | # locate start index
31 | a += program_id * stride
32 | b += program_id * stride
33 | c += program_id * stride
34 |
35 | col_offsets = tl.arange(0, BLOCK_SIZE)
36 | mask = col_offsets < n_cols
37 | a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
38 | b_row = tl.load(b + col_offsets, mask=mask, other=0)
39 |
40 | # tanh approximation form of GELU is computed with:
41 | # 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))
42 | sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)
43 | a_cubed = a_row * a_row * a_row
44 | tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
45 | tanh_result = tanh(tanh_arg)
46 | geglu_a = 0.5 * a_row * (1 + tanh_result)
47 | c_row = geglu_a * b_row
48 | tl.store(c + col_offsets, c_row, mask=mask)
49 |
50 |
51 | @triton.jit
52 | def _geglu_tanh_backward_kernel(
53 | dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
54 | ):
55 | program_id = tl.program_id(0).to(tl.int64)
56 |
57 | # locate start index
58 | dc += program_id * stride
59 | a += program_id * stride
60 | b += program_id * stride
61 |
62 | col_offsets = tl.arange(0, BLOCK_SIZE)
63 | mask = col_offsets < n_cols
64 |
65 | dc_row = tl.load(dc + col_offsets, mask=mask, other=0)
66 | a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
67 | b_row = tl.load(b + col_offsets, mask=mask, other=0)
68 |
69 | # recomputation to save memory
70 | sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)
71 | a_cubed = a_row * a_row * a_row
72 | tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
73 | tanh_result = tanh(tanh_arg)
74 | geglu_a = 0.5 * a_row * (1 + tanh_result)
75 |
76 | db_row = dc_row * geglu_a
77 |
78 | # Gradient w.r.t. a can be computed with:
79 | # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
80 | # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
81 | term1 = 0.5 * (1 + tanh_result)
82 | tanh_sq = tanh_result * tanh_result
83 | term2 = (
84 | 0.5
85 | * a_row
86 | * (1 - tanh_sq)
87 | * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
88 | )
89 | da_row = dc_row * b_row * (term1 + term2)
90 |
91 | tl.store(a + col_offsets, da_row, mask=mask)
92 | tl.store(b + col_offsets, db_row, mask=mask)
93 |
94 |
95 | def geglu_forward(a, b):
96 | ori_shape = a.shape
97 |
98 | n_cols = ori_shape[-1]
99 | a = a.view(-1, n_cols)
100 | b = b.view(-1, n_cols)
101 | c = torch.empty_like(a)
102 | n_rows = a.shape[0]
103 |
104 | BLOCK_SIZE, num_warps = calculate_settings(n_cols)
105 |
106 | _geglu_tanh_forward_kernel[(n_rows,)](
107 | a,
108 | b,
109 | c,
110 | c.stride(-2),
111 | n_cols=n_cols,
112 | BLOCK_SIZE=BLOCK_SIZE,
113 | num_warps=num_warps,
114 | )
115 | return a, b, c.view(*ori_shape)
116 |
117 |
118 | def geglu_backward(a, b, dc):
119 | ori_shape = dc.shape
120 | n_cols = ori_shape[-1]
121 | dc = dc.view(-1, n_cols)
122 | n_rows = dc.shape[0]
123 |
124 | BLOCK_SIZE, num_warps = calculate_settings(n_cols)
125 |
126 | _geglu_tanh_backward_kernel[(n_rows,)](
127 | dc,
128 | a,
129 | b,
130 | dc.stride(-2),
131 | n_cols=n_cols,
132 | BLOCK_SIZE=BLOCK_SIZE,
133 | num_warps=num_warps,
134 | )
135 |
136 | return a.view(*ori_shape), b.view(*ori_shape)
137 |
138 |
139 | class LigerGELUMulFunction(torch.autograd.Function):
140 | @staticmethod
141 | @ensure_contiguous
142 | def forward(ctx, a, b):
143 | a, b, c = geglu_forward(a, b)
144 | ctx.save_for_backward(a, b)
145 | return c
146 |
147 | @staticmethod
148 | @ensure_contiguous
149 | def backward(ctx, dc):
150 | a, b = ctx.saved_tensors
151 | a, b = geglu_backward(a, b, dc)
152 | return a, b
153 |
--------------------------------------------------------------------------------
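A hedged sanity check for the GEGLU kernels above (not a file from the repository): the tanh-form GELU recomputed in the kernel should match PyTorch's `F.gelu(..., approximate="tanh")` multiplied by the up-projection values. Assumes a CUDA device and float32 inputs.

```python
import torch
import torch.nn.functional as F

from liger_kernel.ops.geglu import LigerGELUMulFunction

a = torch.randn(4, 256, device="cuda")  # gate projection output
b = torch.randn(4, 256, device="cuda")  # up projection output

liger_out = LigerGELUMulFunction.apply(a, b)
ref_out = F.gelu(a, approximate="tanh") * b
torch.testing.assert_close(liger_out, ref_out, rtol=1e-4, atol=1e-4)
```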
/src/liger_kernel/ops/jsd.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | import triton
5 | import triton.language as tl
6 |
7 | from liger_kernel.ops.utils import ensure_contiguous
8 |
9 |
10 | @triton.jit
11 | def _jsd_kernel(
12 | X_ptr, # input in logspace, X = log Q
13 | X_stride,
14 | Y_ptr, # ground truth in logspace, Y = log P
15 | Y_stride,
16 | loss_ptr,
17 | loss_stride,
18 | dX_ptr,
19 | dX_stride,
20 | label_ptr,
21 | beta: tl.constexpr,
22 | n_non_ignore: int,
23 | ignore_index: tl.constexpr,
24 | n_cols,
25 | BLOCK_SIZE: tl.constexpr,
26 | HAS_LABEL: tl.constexpr,
27 | ):
28 | # JSD(P || Q) = (KL(P || M) + KL(Q || M)) / 2, M = (1/2) * (P + Q) = (1/2) * (e ^ Y + e ^ X)
29 | # = sum(P * log P + Q * log Q - 2 * M * log M) / 2
30 | # = sum(e ^ Y * Y + e ^ X * X - 2 * M * log M) / 2
31 | # grad_x_i = 0.5 * Q * (X - log_M)
32 | pid = tl.program_id(0).to(tl.int64)
33 | X_ptr += pid * X_stride
34 | dX_ptr += pid * dX_stride
35 | Y_ptr += pid * Y_stride
36 | loss_ptr += pid * loss_stride
37 | label_ptr += pid
38 |
39 | if HAS_LABEL:
40 | label = tl.load(label_ptr)
41 | if label == ignore_index:
42 | for i in range(0, n_cols, BLOCK_SIZE):
43 | offsets = i + tl.arange(0, BLOCK_SIZE)
44 | tl.store(dX_ptr + offsets, 0.0, mask=offsets < n_cols)
45 | return
46 |
47 | for i in range(0, n_cols, BLOCK_SIZE):
48 | offsets = i + tl.arange(0, BLOCK_SIZE)
49 | mask = offsets < n_cols
50 | X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
51 | Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
52 |
53 | if beta == 0.0: # forward KL
54 | Y_prob = tl.exp(Y)
55 | loss = Y_prob * (Y - X)
56 | dX = -Y_prob
57 | elif beta == 1.0:
58 | X_prob = tl.exp(X)
59 | loss = X_prob * (X - Y)
60 | dX = loss + X_prob
61 | else:
62 | Q = tl.exp(X)
63 | P = tl.exp(Y)
64 | M = beta * P + (1 - beta) * Q
65 | log_M = tl.log(M)
66 |
67 | loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
68 | dX = (1 - beta) * Q * (X - log_M)
69 |
70 | loss = loss / n_non_ignore
71 | dX = dX / n_non_ignore
72 | tl.store(loss_ptr + offsets, loss, mask=mask)
73 | tl.store(dX_ptr + offsets, dX, mask=mask)
74 |
75 |
76 | MAX_FUSED_SIZE = 65536
77 |
78 |
79 | def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
80 | BT, V = _input.shape
81 | n_rows = BT
82 | BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
83 | # non reduction loss
84 | loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
85 | dX = torch.empty_like(_input)
86 |
87 | if has_label:
88 | n_non_ignore = (shift_labels != ignore_index).sum().item()
89 | else:
90 | n_non_ignore = BT
91 |
92 | _jsd_kernel[(n_rows,)](
93 | X_ptr=_input, # input in logspace, X = log Q
94 | X_stride=_input.stride(-2),
95 | Y_ptr=target, # ground truth in logspace, Y = log P
96 | Y_stride=target.stride(-2),
97 | loss_ptr=loss,
98 | loss_stride=loss.stride(-2),
99 | dX_ptr=dX,
100 | dX_stride=dX.stride(-2),
101 | label_ptr=(
102 | shift_labels if has_label else torch.empty(1, device=_input.device)
103 | ), # dummy ptr if no label
104 | beta=beta,
105 | n_non_ignore=n_non_ignore,
106 | ignore_index=ignore_index,
107 | n_cols=V,
108 | BLOCK_SIZE=BLOCK_SIZE,
109 | HAS_LABEL=has_label,
110 | )
111 |
112 | loss = torch.sum(loss)
113 | return loss.to(_input.dtype), dX
114 |
115 |
116 | def jsd_backward(dX, grad_output):
117 | # If jsd is the last layer, grad_output is 1.0. Skip the mul to save time
118 | if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
119 | return dX
120 | else:
121 | return grad_output * dX
122 |
123 |
124 | class LigerJSDFunction(torch.autograd.Function):
125 | r"""
126 | This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence.
127 | .. math::
128 | JSD(\beta)(P || Q)
129 | = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))
130 |
131 | .. note::
132 | As all the other losses in PyTorch, this function expects the first argument,
133 | :attr:`_input`, to be the predictions, the output of the student model, in log-space
134 | and the second, :attr:`target`, to be the observations, the output of the teacher model, in log-space.
135 | This differs from the standard mathematical notation :math:`JSD(P || Q)` where
136 | :math:`P` denotes the teacher model and :math:`Q` denotes the student model.
137 | """
138 |
139 | @staticmethod
140 | @ensure_contiguous
141 | def forward(
142 | ctx,
143 | _input: torch.Tensor,
144 | target: torch.Tensor,
145 | shift_labels: Optional[torch.Tensor] = None,
146 | beta: float = 0.5,
147 | ignore_index: int = -100,
148 | ) -> torch.Tensor:
149 | """
150 | Args:
151 | _input (torch.Tensor): predict values with shape (BT, V) in logspace
152 | target (torch.Tensor): ground truth values with shape (BT, V) in logspace
153 | shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
154 | beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
155 | ignore_index (int): the index to ignore. Default: -100
156 |
157 | Returns:
158 | loss (torch.Tensor): generalized JSD
159 | """
160 | has_label = False
161 | if shift_labels is not None:
162 | assert shift_labels.shape == (
163 | _input.shape[0],
164 | ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
165 | shift_labels = shift_labels.contiguous()
166 | has_label = True
167 |
168 | loss, dX = jsd_forward(
169 | _input, target, shift_labels, beta, ignore_index, has_label
170 | )
171 | ctx.save_for_backward(dX)
172 | return loss
173 |
174 | @staticmethod
175 | @ensure_contiguous
176 | def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
177 | (dX,) = ctx.saved_tensors
178 | dX = jsd_backward(dX, grad_output)
179 | return (
180 | dX,
181 | None,
182 | None,
183 | None,
184 | None,
185 | )
186 |
--------------------------------------------------------------------------------
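A hedged reference check for the generalized JSD kernel above (not a file from the repository): with 0 < β < 1 the per-row loss is β·P·(Y − log M) + (1 − β)·Q·(X − log M) with M = β·P + (1 − β)·Q, summed and divided by the number of non-ignored rows. Assumes a CUDA device.

```python
import torch

from liger_kernel.ops.jsd import LigerJSDFunction

BT, V, beta = 8, 32, 0.5
log_q = torch.randn(BT, V, device="cuda").log_softmax(dim=-1)  # student, X = log Q
log_p = torch.randn(BT, V, device="cuda").log_softmax(dim=-1)  # teacher, Y = log P

liger_loss = LigerJSDFunction.apply(log_q, log_p, None, beta)

P, Q = log_p.exp(), log_q.exp()
M = beta * P + (1 - beta) * Q
ref_loss = (beta * P * (log_p - M.log()) + (1 - beta) * Q * (log_q - M.log())).sum() / BT
torch.testing.assert_close(liger_loss, ref_loss, rtol=1e-4, atol=1e-4)
```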
/src/liger_kernel/ops/swiglu.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | import triton.language as tl
4 |
5 | from liger_kernel.ops.utils import calculate_settings, ensure_contiguous
6 |
7 |
8 | @triton.jit
9 | def silu(x):
10 | return x * tl.sigmoid(x)
11 |
12 |
13 | @triton.jit
14 | def _swiglu_forward_kernel(
15 | a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
16 | ):
17 | program_id = tl.program_id(0).to(tl.int64)
18 |
19 | # locate start index
20 | a_ptr += program_id * stride
21 | b_ptr += program_id * stride
22 | c_ptr += program_id * stride
23 |
24 | col_offsets = tl.arange(0, BLOCK_SIZE)
25 | mask = col_offsets < n_cols
26 |
27 | # sigmoid requires type float32
28 | a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
29 | b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
30 | c_row = silu(a_row) * b_row
31 | tl.store(c_ptr + col_offsets, c_row, mask=mask)
32 |
33 |
34 | @triton.jit
35 | def _swiglu_backward_kernel(
36 | dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
37 | ):
38 | program_id = tl.program_id(0).to(tl.int64)
39 |
40 | # locate start index
41 | dc_ptr += program_id * stride
42 | a_ptr += program_id * stride
43 | b_ptr += program_id * stride
44 |
45 | col_offsets = tl.arange(0, BLOCK_SIZE)
46 | mask = col_offsets < n_cols
47 |
48 | dc_row = tl.load(dc_ptr + col_offsets, mask=mask, other=0)
49 | # sigmoid requires type float32
50 | a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
51 | b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
52 |
53 | # recomputation to save memory
54 | sig_a = tl.sigmoid(a_row)
55 | silu_a = a_row * sig_a
56 | db_row = dc_row * silu_a
57 | da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row
58 |
59 | tl.store(a_ptr + col_offsets, da_row, mask=mask)
60 | tl.store(b_ptr + col_offsets, db_row, mask=mask)
61 |
62 |
63 | def swiglu_forward(a, b):
64 | ori_shape = a.shape
65 |
66 | n_cols = ori_shape[-1]
67 | a = a.view(-1, n_cols)
68 | b = b.view(-1, n_cols)
69 | c = torch.empty_like(a)
70 | n_rows = a.shape[0]
71 |
72 | BLOCK_SIZE, num_warps = calculate_settings(n_cols)
73 |
74 | _swiglu_forward_kernel[(n_rows,)](
75 | a,
76 | b,
77 | c,
78 | c.stride(-2),
79 | n_cols=n_cols,
80 | BLOCK_SIZE=BLOCK_SIZE,
81 | num_warps=num_warps,
82 | )
83 | return a, b, c.view(*ori_shape)
84 |
85 |
86 | def swiglu_backward(a, b, dc):
87 |
88 | ori_shape = dc.shape
89 | n_cols = ori_shape[-1]
90 | dc = dc.view(-1, n_cols)
91 | n_rows = dc.shape[0]
92 |
93 | BLOCK_SIZE, num_warps = calculate_settings(n_cols)
94 |
95 | _swiglu_backward_kernel[(n_rows,)](
96 | dc,
97 | a,
98 | b,
99 | dc.stride(-2),
100 | n_cols=n_cols,
101 | BLOCK_SIZE=BLOCK_SIZE,
102 | num_warps=num_warps,
103 | )
104 | return a.view(*ori_shape), b.view(*ori_shape)
105 |
106 |
107 | class LigerSiLUMulFunction(torch.autograd.Function):
108 | @staticmethod
109 | @ensure_contiguous
110 | def forward(ctx, a, b):
111 | a, b, c = swiglu_forward(a, b)
112 | ctx.save_for_backward(a, b)
113 | return c
114 |
115 | @staticmethod
116 | @ensure_contiguous
117 | def backward(ctx, dc):
118 | a, b = ctx.saved_tensors
119 | a, b = swiglu_backward(a, b, dc)
120 | return a, b
121 |
--------------------------------------------------------------------------------
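A hedged sanity check for the SwiGLU kernels above (not a file from the repository): the fused forward computes `silu(a) * b`, so it should match the unfused PyTorch expression. Assumes a CUDA device and float32 inputs.

```python
import torch
import torch.nn.functional as F

from liger_kernel.ops.swiglu import LigerSiLUMulFunction

a = torch.randn(4, 256, device="cuda")  # gate projection output
b = torch.randn(4, 256, device="cuda")  # up projection output

liger_out = LigerSiLUMulFunction.apply(a, b)
ref_out = F.silu(a) * b
torch.testing.assert_close(liger_out, ref_out, rtol=1e-4, atol=1e-4)
```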
/src/liger_kernel/ops/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
3 | See the original Unsloth repository at https://github.com/unslothai/unsloth.
4 |
5 | The following line
6 | https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23
7 | is based on code from Unsloth, located at:
8 | https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
9 |
10 | Modifications made by Yanning Chen, 2024.
11 | """
12 |
13 | import functools
14 | import importlib
15 | import operator
16 | from typing import Callable
17 |
18 | import torch
19 | import triton
20 | import triton.language as tl
21 | from packaging.version import Version
22 |
23 | from liger_kernel.utils import infer_device
24 |
25 |
26 | def is_hip() -> bool:
27 | return torch.version.hip is not None
28 |
29 |
30 | def ensure_contiguous(fn):
31 | @functools.wraps(fn)
32 | def wrapper(ctx, *args, **kwargs):
33 | def maybe_to_contiguous(x):
34 | return x.contiguous() if isinstance(x, torch.Tensor) else x
35 |
36 | args = [maybe_to_contiguous(arg) for arg in args]
37 | kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()}
38 | return fn(ctx, *args, **kwargs)
39 |
40 | return wrapper
41 |
42 |
43 | def calculate_settings(n):
44 | # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
45 |
46 | MAX_FUSED_SIZE = 65536
47 | BLOCK_SIZE = triton.next_power_of_2(n)
48 | if BLOCK_SIZE > MAX_FUSED_SIZE:
49 | raise RuntimeError(
50 | f"Cannot launch Triton kernel since n = {n} exceeds "
51 | f"the recommended Triton blocksize = {MAX_FUSED_SIZE}."
52 | )
53 |
54 | num_warps = 4
55 | if BLOCK_SIZE >= 32768:
56 | num_warps = 32 if not is_hip() else 16
57 | elif BLOCK_SIZE >= 8192:
58 | num_warps = 16
59 | elif BLOCK_SIZE >= 2048:
60 | num_warps = 8
61 | return BLOCK_SIZE, num_warps
62 |
63 |
64 | def compare_version(package: str, operator: Callable, target: str):
65 | try:
66 | pkg = importlib.import_module(package)
67 | except ImportError:
68 | return False
69 | pkg_version = Version(pkg.__version__)
70 | return operator(pkg_version, Version(target))
71 |
72 |
73 | def get_amp_custom_fwd_bwd() -> Callable:
74 | device = infer_device()
75 | if compare_version("torch", operator.ge, "2.4.0"):
76 | return (
77 | functools.partial(torch.amp.custom_fwd, device_type=device),
78 | functools.partial(torch.amp.custom_bwd, device_type=device),
79 | )
80 | return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
81 |
82 |
83 | amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()
84 |
85 |
86 | torch_to_triton_dtype = {
87 | torch.float32: tl.float32,
88 | torch.float16: tl.float16,
89 | torch.bfloat16: tl.bfloat16,
90 | }
91 |
92 |
93 | @triton.jit
94 | def element_mul_kernel(
95 | X_ptr,
96 | X_stride,
97 | grad_output_ptr,
98 | n_cols,
99 | BLOCK_SIZE: tl.constexpr,
100 | ):
101 | """
102 | This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr.
103 | The multiplication is performed in-place on the tensor pointed by X_ptr.
104 |
105 | Parameters:
106 | X_ptr: Pointer to the input tensor.
107 | X_stride (int): The stride of the input tensor.
108 | grad_output_ptr: Pointer to the gradient output value.
109 | n_cols (int): The number of columns in the input tensor.
110 | BLOCK_SIZE (int): The block size for Triton operations.
111 | """
112 |
113 | # Get the program ID and convert it to int64 to avoid overflow
114 | program_id = tl.program_id(0).to(tl.int64)
115 |
116 | # Locate the start index
117 | X_ptr += program_id * X_stride
118 |
119 | # Load the gradient output value
120 | grad_output = tl.load(grad_output_ptr)
121 |
122 | # Perform the element-wise multiplication
123 | for i in range(0, n_cols, BLOCK_SIZE):
124 | X_offsets = i + tl.arange(0, BLOCK_SIZE)
125 | X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
126 | tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
127 |
--------------------------------------------------------------------------------
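A small worked example for `calculate_settings` above (not a file from the repository): for a Llama-like vocabulary of 32,000, the block size rounds up to the next power of two, 32,768, which selects 32 warps on CUDA and 16 on ROCm per `is_hip()`.

```python
from liger_kernel.ops.utils import calculate_settings

BLOCK_SIZE, num_warps = calculate_settings(32000)
print(BLOCK_SIZE, num_warps)  # 32768, 32 on CUDA (16 on ROCm)
```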
/src/liger_kernel/transformers/__init__.py:
--------------------------------------------------------------------------------
1 | from liger_kernel.transformers.auto_model import ( # noqa: F401
2 | AutoLigerKernelForCausalLM,
3 | )
4 | from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss # noqa: F401
5 | from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: F401
6 | LigerFusedLinearCrossEntropyLoss,
7 | )
8 | from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401
9 | from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401
10 | from liger_kernel.transformers.jsd import LigerJSD # noqa: F401
11 | from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401
12 | from liger_kernel.transformers.monkey_patch import ( # noqa: F401
13 | _apply_liger_kernel,
14 | _apply_liger_kernel_to_instance,
15 | apply_liger_kernel_to_gemma,
16 | apply_liger_kernel_to_gemma2,
17 | apply_liger_kernel_to_llama,
18 | apply_liger_kernel_to_mistral,
19 | apply_liger_kernel_to_mixtral,
20 | apply_liger_kernel_to_mllama,
21 | apply_liger_kernel_to_phi3,
22 | apply_liger_kernel_to_qwen2,
23 | apply_liger_kernel_to_qwen2_vl,
24 | )
25 | from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
26 | from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
27 | from liger_kernel.transformers.swiglu import ( # noqa: F401
28 | LigerBlockSparseTop2MLP,
29 | LigerPhi3SwiGLUMLP,
30 | LigerSwiGLUMLP,
31 | )
32 |
--------------------------------------------------------------------------------
/src/liger_kernel/transformers/auto_model.py:
--------------------------------------------------------------------------------
1 | import inspect
2 |
3 | from transformers import AutoConfig, AutoModelForCausalLM
4 |
5 | from liger_kernel.transformers.monkey_patch import (
6 | MODEL_TYPE_TO_APPLY_LIGER_FN,
7 | _apply_liger_kernel,
8 | )
9 |
10 |
11 | def _get_model_config(model_dir, **model_init_kwargs):
12 | config = AutoConfig.from_pretrained(model_dir, **model_init_kwargs)
13 | return config
14 |
15 |
16 | class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
17 | """
18 | This class is a drop-in replacement for AutoModelForCausalLM that applies the Liger Kernel to the model
19 | if applicable.
20 | """
21 |
22 | @classmethod
23 | def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
24 | model_config = _get_model_config(pretrained_model_name_or_path, **kwargs)
25 |
26 | # Determine the model type and apply the Liger Kernel if applicable
27 | # Note: _apply_liger_kernel will only pass relevant kwargs to the apply_liger_kernel_to_* function
28 | model_type = model_config.model_type
29 |
30 | _apply_liger_kernel(model_type, **kwargs)
31 |
32 | # Filter out kwargs that were passed to the apply_liger_* function, which will cause
33 | # model initialization errors otherwise
34 | apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
35 | apply_fn_signature = inspect.signature(apply_fn)
36 |
37 | applicable_kwargs = {
38 | key: value
39 | for key, value in kwargs.items()
40 | if key not in apply_fn_signature.parameters
41 | }
42 |
43 | return super().from_pretrained(
44 | pretrained_model_name_or_path, *model_args, **applicable_kwargs
45 | )
46 |
--------------------------------------------------------------------------------
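A minimal usage sketch for `AutoLigerKernelForCausalLM` above (not a file from the repository; the checkpoint id below is a placeholder — any Hugging Face causal LM whose `model_type` appears in `MODEL_TYPE_TO_APPLY_LIGER_FN` gets the Liger patches applied before loading):

```python
from liger_kernel.transformers import AutoLigerKernelForCausalLM

# Placeholder checkpoint id; substitute any supported causal LM.
model = AutoLigerKernelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
```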
/src/liger_kernel/transformers/cross_entropy.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 |
5 | from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
6 |
7 |
8 | class LigerCrossEntropyLoss(torch.nn.Module):
9 | def __init__(
10 | self,
11 | ignore_index: int = -100,
12 | lse_square_scale: float = 0.0,
13 | label_smoothing: float = 0.0,
14 | reduction: str = "mean",
15 | softcap: Optional[float] = None,
16 | return_z_loss: bool = False,
17 | ):
18 | super().__init__()
19 | assert (label_smoothing >= 0) and (
20 | label_smoothing <= 1
21 | ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
25 | assert reduction in {
26 | "mean",
27 | "sum",
28 | "none",
29 | }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
30 | assert (
31 | softcap is None or softcap > 0
32 |         ), f"softcap must be greater than 0.0 or None. Got: {softcap}"
33 | self.ignore_index = ignore_index
34 | self.lse_square_scale = lse_square_scale
35 | self.label_smoothing = label_smoothing
36 | self.reduction = reduction
37 | self.softcap = softcap
38 | self.return_z_loss = return_z_loss
39 |
40 | def forward(self, _input: torch.Tensor, target: torch.Tensor):
41 | loss, z_loss = LigerCrossEntropyFunction.apply(
42 | _input,
43 | target,
44 | self.ignore_index,
45 | self.lse_square_scale,
46 | self.label_smoothing,
47 | self.reduction,
48 | self.softcap,
49 | self.return_z_loss,
50 | )
51 | if not self.return_z_loss:
52 | return loss
53 | return loss, z_loss
54 |
--------------------------------------------------------------------------------
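A minimal usage sketch for `LigerCrossEntropyLoss` above (not a file from the repository; assumes a CUDA device, logits flattened to `(BT, V)`, and integer targets of shape `(BT,)`):

```python
import torch

from liger_kernel.transformers import LigerCrossEntropyLoss

BT, V = 16, 32000
logits = torch.randn(BT, V, device="cuda", requires_grad=True)
target = torch.randint(0, V, (BT,), device="cuda")
target[0] = -100  # ignored position

loss = LigerCrossEntropyLoss(ignore_index=-100)(logits, target)
loss.backward()
```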
/src/liger_kernel/transformers/experimental/embedding.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction
7 |
8 |
9 | class LigerEmbedding(nn.Module):
10 | def __init__(
11 | self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None
12 | ):
13 | super().__init__()
14 | self.num_embeddings = num_embeddings
15 | self.embedding_dim = embedding_dim
16 | self.padding_idx = padding_idx
17 | self.weight = nn.Parameter(torch.randn(num_embeddings, embedding_dim))
18 |
19 | if padding_idx is not None:
20 | with torch.no_grad():
21 | self.weight[padding_idx].fill_(0)
22 |
23 | def forward(self, indices):
24 | embedded = LigerEmbeddingFunction.apply(self.weight, indices)
25 | if self.padding_idx is not None:
26 | embedded = embedded.clone()
27 | embedded[indices == self.padding_idx] = 0
28 | return embedded
29 |
--------------------------------------------------------------------------------
/src/liger_kernel/transformers/functional.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
4 | from liger_kernel.ops.fused_linear_cross_entropy import (
5 | LigerFusedLinearCrossEntropyFunction,
6 | )
7 | from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
8 | from liger_kernel.ops.geglu import LigerGELUMulFunction
9 | from liger_kernel.ops.group_norm import LigerGroupNormFunction
10 | from liger_kernel.ops.jsd import LigerJSDFunction
11 | from liger_kernel.ops.kl_div import LigerKLDivLossFunction
12 | from liger_kernel.ops.layer_norm import LigerLayerNormFunction
13 | from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
14 | from liger_kernel.ops.rms_norm import LigerRMSNormFunction
15 | from liger_kernel.ops.rope import LigerRopeFunction
16 | from liger_kernel.ops.swiglu import LigerSiLUMulFunction
17 |
18 |
19 | # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
20 | # `weight` and `size_average` are placeholders and not implemented yet
21 | def liger_cross_entropy(
22 | input,
23 | target,
24 | weight=None,
25 | size_average=None,
26 | ignore_index: int = -100,
27 | reduce=None,
28 | reduction: str = "mean",
29 | label_smoothing: float = 0.0,
30 | lse_square_scale: float = 0.0,
31 | softcap: Optional[float] = None,
32 | return_z_loss: bool = False,
33 | ):
34 | loss, z_loss = LigerCrossEntropyFunction.apply(
35 | input,
36 | target,
37 | ignore_index,
38 | lse_square_scale,
39 | label_smoothing,
40 | reduction,
41 | softcap,
42 | return_z_loss,
43 | )
44 | if not return_z_loss:
45 | return loss
46 | return loss, z_loss
47 |
48 |
49 | def liger_fused_linear_cross_entropy(
50 | input,
51 | weight,
52 | target,
53 | bias=None,
54 | ignore_index: int = -100,
55 | lse_square_scale: float = 0.0,
56 | label_smoothing: float = 0.0,
57 | reduction: str = "mean",
58 | softcap: Optional[float] = None,
59 | ):
60 | return LigerFusedLinearCrossEntropyFunction.apply(
61 | input,
62 | weight,
63 | target,
64 | bias,
65 | ignore_index,
66 | lse_square_scale,
67 | label_smoothing,
68 | reduction,
69 | softcap,
70 | )
71 |
72 |
73 | def liger_fused_linear_jsd(
74 | student_input,
75 | student_weight,
76 | teacher_input,
77 | teacher_weight,
78 | shift_labels=None,
79 | jsd_beta: float = 0.5,
80 | ignore_index: int = -100,
81 | temperature: float = 1.0,
82 | ):
83 | return LigerFusedLinearJSDFunction.apply(
84 | student_input,
85 | student_weight,
86 | teacher_input,
87 | teacher_weight,
88 | shift_labels,
89 | jsd_beta,
90 | ignore_index,
91 | temperature,
92 | )
93 |
94 |
95 | def liger_geglu(a, b):
96 | return LigerGELUMulFunction.apply(a, b)
97 |
98 |
99 | def liger_group_norm(
100 | X,
101 | affine_scaling_weight,
102 | affine_shifting_bias,
103 | num_channels,
104 | num_groups,
105 | eps,
106 | ):
107 | return LigerGroupNormFunction.apply(
108 | X,
109 | affine_scaling_weight,
110 | affine_shifting_bias,
111 | num_channels,
112 | num_groups,
113 | eps,
114 | )
115 |
116 |
117 | def liger_jsd(
118 | input,
119 | target,
120 | shift_labels=None,
121 | beta: float = 0.5,
122 | ignore_index: int = -100,
123 | ):
124 | return LigerJSDFunction.apply(
125 | input,
126 | target,
127 | shift_labels,
128 | beta,
129 | ignore_index,
130 | )
131 |
132 |
133 | # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.kl_div.html#torch.nn.functional.kl_div
134 | # `size_average` and `reduce` are being deprecated in torch API and are placeholders here
135 | def liger_kl_div(
136 | input,
137 | target,
138 | size_average: bool = True,
139 | reduce: bool = True,
140 | reduction: str = "mean",
141 | log_target: bool = False,
142 | eps: float = 1e-10,
143 | ):
144 | # Note: the default reduction in torch is `mean`, but being `batchmean` in Liger
145 | return LigerKLDivLossFunction.apply(
146 | input,
147 | target,
148 | reduction,
149 | log_target,
150 | eps,
151 | )
152 |
153 |
154 | def liger_layer_norm(X, W, B, eps):
155 | return LigerLayerNormFunction.apply(X, W, B, eps)
156 |
157 |
158 | def liger_qwen2vl_mrope(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
159 | return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)
160 |
161 |
162 | def liger_rms_norm(
163 | X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True
164 | ):
165 | return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)
166 |
167 |
168 | def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
169 | return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
170 |
171 |
172 | def liger_swiglu(a, b):
173 | return LigerSiLUMulFunction.apply(a, b)
174 |
--------------------------------------------------------------------------------
/src/liger_kernel/transformers/fused_linear_cross_entropy.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 |
5 | from liger_kernel.ops.fused_linear_cross_entropy import (
6 | LigerFusedLinearCrossEntropyFunction,
7 | )
8 |
9 |
10 | class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
11 | def __init__(
12 | self,
13 | ignore_index: int = -100,
14 | lse_square_scale: float = 0.0,
15 | label_smoothing: float = 0.0,
16 | reduction: str = "mean",
17 | softcap: Optional[float] = None,
18 | ):
19 | super().__init__()
20 | assert (label_smoothing >= 0) and (
21 | label_smoothing <= 1
22 | ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
23 | assert reduction in {
24 | "mean",
25 | "sum",
26 | "none",
27 | }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
28 | assert (
29 | softcap is None or softcap > 0
30 |         ), f"softcap must be greater than 0.0 or None. Got: {softcap}"
31 | self.ignore_index = ignore_index
32 | self.lse_square_scale = lse_square_scale
33 | self.label_smoothing = label_smoothing
34 | self.reduction = reduction
35 | self.softcap = softcap
36 |
37 | def forward(self, lin_weight, _input, target, bias=None):
38 | return LigerFusedLinearCrossEntropyFunction.apply(
39 | _input,
40 | lin_weight,
41 | target,
42 | bias,
43 | self.ignore_index,
44 | self.lse_square_scale,
45 | self.label_smoothing,
46 | self.reduction,
47 | self.softcap,
48 | )
49 |
--------------------------------------------------------------------------------
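A minimal usage sketch for `LigerFusedLinearCrossEntropyLoss` above (not a file from the repository; assumes a CUDA device — the lm_head weight and flattened hidden states are passed in so the `(BT, V)` logits are never materialized):

```python
import torch

from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

BT, H, V = 16, 64, 32000
hidden = torch.randn(BT, H, device="cuda", requires_grad=True)
target = torch.randint(0, V, (BT,), device="cuda")
lm_head = torch.nn.Linear(H, V, bias=False, device="cuda")

loss = LigerFusedLinearCrossEntropyLoss()(lm_head.weight, hidden, target)
loss.backward()
```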
/src/liger_kernel/transformers/fused_linear_jsd.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 |
5 | from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
6 |
7 |
8 | class LigerFusedLinearJSD(torch.nn.Module):
9 | r"""Fusing the last linear layer with generalized JSD
10 |
11 | Handle the forward and backward pass of the final linear layer via JSD by avoiding
12 | the materialization of the large logits tensor.
13 |
14 | Args:
15 | jsd_beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
16 | ignore_index (int): The index to ignore in the target. Default: `-100`
17 | temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`
18 |
19 | Shape:
20 | - student_input: :math:`(BT, H)`, where B is batch size, T is sequence length, H is hidden dimension.
21 | - student_weight: :math:`(V, H)`, where V is vocab size.
22 | - teacher_input: :math:`(BT, H')`, where H' is hidden dimension of the teacher model.
23 | - teacher_weight: :math:`(V, H')`, where hidden size H and H' can be different.
24 | - shift_labels: :math:`(BT,)`
25 | - Output: a scalar.
26 |
27 | Examples:
28 | ```python
29 | >>> (B, T, H_s, H_t, V) = (2, 2, 3, 5, 10)
30 | >>> fused_jsd = LigerFusedLinearJSD(jsd_beta=0.1, temperature=2.0)
31 | >>> # generate inputs and weights
32 | >>> student_input = torch.rand(B * T, H_s, device="cuda", requires_grad=True)
33 | >>> student_lin = torch.nn.Linear(H_s, V, bias=False, device="cuda")
34 | >>> # teacher input doesn't require grad, hidden_dim can be different from student's
35 | >>> teacher_input = torch.rand(B * T, H_t, device="cuda")
36 | >>> teacher_lin = torch.nn.Linear(H_t, V, bias=False, device="cuda")
37 | >>> output = fused_jsd(student_input, student_lin.weight, teacher_input, teacher_lin.weight)
38 | >>> output.backward()
39 | >>>
40 | >>> # Example with labels for supervised fine-tuning (SFT) context:
41 | >>>
42 | >>> # Assume hidden_states, lm_heads and corresponding labels are given
43 | >>> student_lm_head = torch.nn.Linear(H_s, V, bias=False)
44 | >>> student_hidden_states = torch.randn(B * T, H_s, requires_grad=True).log_softmax(dim=-1)
45 | >>> teacher_lm_head = torch.nn.Linear(H_t, V, bias=False)
46 | >>> teacher_hidden_states = torch.randn(B * T, H_t).log_softmax(dim=-1)
47 |         >>> labels = torch.randint(0, V, (B * T,), dtype=torch.long)
48 | >>>
49 | >>> # Shift so that tokens < n predict n
50 | >>> shift_student_hidden_states = student_hidden_states[..., :-1, :].contiguous()
51 | >>> shift_teacher_hidden_states = teacher_hidden_states[..., :-1, :].contiguous()
52 | >>> shift_labels = labels[..., 1:].contiguous()
53 | >>>
54 | >>> # Flatten tokens
55 | >>> shift_student_hidden_states = shift_student_hidden_states.view(-1, V)
56 | >>> shift_teacher_hidden_states = shift_teacher_hidden_states.view(-1, V)
57 | >>> shift_labels = shift_labels.view(-1)
58 | >>>
59 | >>> # Calculate loss
60 |         >>> loss_fct = LigerFusedLinearJSD(jsd_beta=0.1)
61 |         >>> loss = loss_fct(
62 |         >>>     shift_student_hidden_states,
63 | >>> student_lm_head.weight,
64 | >>> shift_teacher_hidden_states,
65 | >>> teacher_lm_head.weight,
66 | >>> shift_labels
67 | >>> )
68 | ```
69 | """
70 |
71 | def __init__(self, jsd_beta=0.5, ignore_index=-100, temperature=1.0):
72 | super().__init__()
73 | assert temperature != 0, "temperature cannot be 0."
74 | self.jsd_beta = jsd_beta
75 | self.temperature = temperature
76 | self.ignore_index = ignore_index
77 |
78 | def forward(
79 | self,
80 | student_input: torch.Tensor,
81 | student_weight: torch.Tensor,
82 | teacher_input: torch.Tensor,
83 | teacher_weight: torch.Tensor,
84 | shift_labels: Optional[torch.LongTensor],
85 | ):
86 | return LigerFusedLinearJSDFunction.apply(
87 | student_input,
88 | student_weight,
89 | teacher_input,
90 | teacher_weight,
91 | shift_labels,
92 | self.jsd_beta,
93 | self.ignore_index,
94 | self.temperature,
95 | )
96 |
--------------------------------------------------------------------------------
/src/liger_kernel/transformers/geglu.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from liger_kernel.ops.geglu import LigerGELUMulFunction
4 |
5 |
6 | class LigerGEGLUMLP(nn.Module):
7 | def __init__(self, config):
8 | super().__init__()
9 | self.config = config
10 | self.hidden_size = config.hidden_size
11 | self.intermediate_size = config.intermediate_size
12 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
13 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
14 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
15 | # TODO: support exact GELU
16 | # Right now Gemma 1, 1.1 and 2 models are all using `gelu_pytorch_tanh`
17 | # https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/gemma/modeling_gemma.py#L175
18 | # https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/activations.py#L46
19 |         # So we can safely assume the tanh approximation form is used all the time
20 |
21 | def forward(self, x):
22 |
23 | return self.down_proj(
24 | LigerGELUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
25 | )
26 |
--------------------------------------------------------------------------------
/src/liger_kernel/transformers/group_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from liger_kernel.ops.group_norm import LigerGroupNormFunction
5 |
6 |
7 | class LigerGroupNorm(nn.Module):
8 | def __init__(self, num_channels, num_groups, eps=1e-6, bias=False, init_fn="ones"):
9 | """
10 | A Group Normalization layer.
11 | Args:
12 | num_channels (int): Number of channels in the input tensor.
13 | num_groups (int): Number of groups to divide the channels into.
14 | eps (float, optional): A value added to the denominator for numerical stability. Default: 1e-6.
15 | bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``False``.
16 | init_fn (str, optional): Initialization function for the learnable parameters. Default: "ones".
17 | """
18 | super().__init__()
19 | assert init_fn in [
20 | "ones",
21 | "zeros",
22 | ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
23 |
24 | assert (
25 | num_channels % num_groups == 0
26 | ), f"Number of channels {num_channels} must be divisible by num_groups {num_groups}"
27 | self.num_channels = num_channels
28 | self.num_groups = num_groups
29 | self.eps = eps
30 | self.weight = nn.Parameter(
31 | torch.ones(num_channels) if init_fn == "ones" else torch.zeros(num_channels)
32 | )
33 | self.bias = nn.Parameter(
34 | torch.randn(num_channels) if bias else torch.zeros(num_channels)
35 | )
36 | self.variance_epsilon = eps
37 |
38 | def forward(self, hidden_states):
39 | # hidden_states: (batch_size, num_channels, *)
40 | assert (
41 | hidden_states.dim() >= 3
42 |         ), f"Input must have at least 3 dimensions, got {hidden_states.dim()}"
43 | assert (
44 | hidden_states.size(1) == self.num_channels
45 | ), f"Input tensor must have {self.num_channels} channels, got {hidden_states.size(1)}"
46 | return LigerGroupNormFunction.apply(
47 | hidden_states,
48 | self.weight,
49 | self.bias,
50 | self.num_channels,
51 | self.num_groups,
52 | self.variance_epsilon,
53 | )
54 |
55 | def extra_repr(self):
56 |         return f"num_channels={self.num_channels}, num_groups={self.num_groups}, eps={self.eps}"
57 |
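A minimal usage sketch for `LigerGroupNorm`, assuming a CUDA device; the input shape follows the `(batch_size, num_channels, *)` convention checked in `forward`.

```python
import torch

from liger_kernel.transformers.group_norm import LigerGroupNorm

gn = LigerGroupNorm(num_channels=12, num_groups=3, eps=1e-6).to("cuda")

x = torch.randn(4, 12, 128, device="cuda", requires_grad=True)  # (B, C, hidden)
y = gn(x)                                                       # normalized within each group of channels
y.mean().backward()
```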
--------------------------------------------------------------------------------
/src/liger_kernel/transformers/jsd.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 |
5 | from liger_kernel.ops.jsd import LigerJSDFunction
6 |
7 |
8 | class LigerJSD(torch.nn.Module):
9 | r"""The generalized Jensen-Shannon Divergence.
10 | .. math::
11 | JSD(\beta)(P || Q)
12 | = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))
13 | .. note::
14 | As all the other losses in PyTorch, this function expects the first argument,
15 | :attr:`log_q`, to be the predictions, the output of the student model in log-space,
16 | and the second, :attr:`log_p`, to be the observations, the output of the teacher model in log-space.
17 | This differs from the standard mathematical notation :math:`JSD(P || Q)` where
18 | :math:`P` denotes the teacher model and :math:`Q` denotes the student model.
19 |
20 | Args:
21 | beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
22 | ignore_index (int): The index to ignore in the target. Default: `-100`
23 |
24 | Shape:
25 | - Input: :math:`(BT, V)`, where B is batch size, T is sequence length, V is vocab size.
26 | - Target: :math:`(BT, V)`, same shape as the input.
27 | - shift_labels (Optional): :math:`(BT,)`
28 | - Output: a scalar.
29 |
30 | Examples:
31 | ```python
32 | >>> (B, T, V) = (2, 2, 5)
33 | >>> jsd = LigerJSD(beta=0.1)
34 | >>> # input should be a distribution in the log space
35 | >>> input = torch.randn(B * T, V, requires_grad=True).log_softmax(dim=-1)
36 | >>> target = torch.randn(B * T, V).log_softmax(dim=-1)
37 | >>> output = jsd(input, target)
38 | >>>
39 | >>> # Example with labels for supervised fine-tuning (SFT) context
40 | >>> # Assume logits and corresponding labels are given
41 | >>> student_logits = torch.randn(B * T, V, requires_grad=True).log_softmax(dim=-1)
42 | >>> teacher_logits = torch.randn(B * T, V).log_softmax(dim=-1)
43 | >>> labels = torch.randint(0, V, (B * T,), dtype=torch.long)
44 | >>> # Shift so that tokens < n predict n
45 | >>> shift_student_logits = student_logits[..., :-1, :].contiguous()
46 | >>> shift_teacher_logits = teacher_logits[..., :-1, :].contiguous()
47 | >>> shift_labels = labels[..., 1:].contiguous()
48 | >>> # Flatten tokens
49 | >>> shift_student_logits = shift_student_logits.view(-1, V)
50 | >>> shift_teacher_logits = shift_teacher_logits.view(-1, V)
51 | >>> shift_labels = shift_labels.view(-1)
52 | >>> # Calculate loss
53 | >>> loss_fct = LigerJSD(beta=0.1)
54 | >>> loss = loss_fct(shift_student_logits, shift_teacher_logits, shift_labels)
55 |
56 | ```
57 | """
58 |
59 | def __init__(self, beta: float = 0.5, ignore_index: int = -100):
60 | super().__init__()
61 | self.beta = beta
62 | self.ignore_index = ignore_index
63 |
64 | def forward(
65 | self,
66 | log_q: torch.Tensor,
67 | log_p: torch.Tensor,
68 | shift_labels: Optional[torch.LongTensor] = None,
69 | ):
70 | return LigerJSDFunction.apply(
71 | log_q, log_p, shift_labels, self.beta, self.ignore_index
72 | )
73 |
--------------------------------------------------------------------------------
/src/liger_kernel/transformers/kl_div.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from liger_kernel.ops.kl_div import LigerKLDivLossFunction
4 |
5 |
6 | class LigerKLDIVLoss(nn.KLDivLoss):
7 | def __init__(self, eps: float = 1e-10, *args, **kwargs):
8 | super(LigerKLDIVLoss, self).__init__(*args, **kwargs)
9 | self.eps = eps
10 |
11 | def forward(self, y_pred, y_true):
12 | return LigerKLDivLossFunction.apply(
13 | y_pred, y_true, self.reduction, self.log_target, self.eps
14 | )
15 |
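A minimal usage sketch for `LigerKLDIVLoss`, assuming a CUDA device. As with `torch.nn.KLDivLoss`, the prediction must be in log-space; the target is in probability space unless `log_target=True` is passed through to the parent class.

```python
import torch

from liger_kernel.transformers.kl_div import LigerKLDIVLoss

loss_fn = LigerKLDIVLoss(reduction="batchmean")

log_q = torch.randn(8, 4096, device="cuda", requires_grad=True).log_softmax(dim=-1)  # predictions (log-probs)
p = torch.randn(8, 4096, device="cuda").softmax(dim=-1)                              # target (probs)

loss = loss_fn(log_q, p)
loss.backward()
```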
--------------------------------------------------------------------------------
/src/liger_kernel/transformers/layer_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from liger_kernel.ops.layer_norm import LigerLayerNormFunction
5 |
6 |
7 | class LigerLayerNorm(nn.Module):
8 | def __init__(self, hidden_size, eps=1e-6, bias=False, init_fn="ones"):
9 | super().__init__()
10 | assert init_fn in [
11 | "ones",
12 | "zeros",
13 | ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
14 | self.hidden_size = hidden_size
15 | self.eps = eps
16 | self.weight = nn.Parameter(
17 | torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
18 | )
19 | self.bias = nn.Parameter(
20 | torch.randn(hidden_size) if bias else torch.zeros(hidden_size)
21 | )
22 | self.variance_epsilon = eps
23 |
24 | def forward(self, hidden_states):
25 | return LigerLayerNormFunction.apply(
26 | hidden_states, self.weight, self.bias, self.variance_epsilon
27 | )
28 |
29 | def extra_repr(self):
30 | return f"{self.hidden_size}, eps={self.eps}"
31 |
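A minimal usage sketch for `LigerLayerNorm` as a drop-in for `torch.nn.LayerNorm`, assuming a CUDA device; note that, per the constructor above, `bias=True` initializes the bias randomly rather than with zeros.

```python
import torch

from liger_kernel.transformers.layer_norm import LigerLayerNorm

ln = LigerLayerNorm(hidden_size=128, eps=1e-6, bias=True).to("cuda")

x = torch.randn(2, 16, 128, device="cuda", requires_grad=True)
y = ln(x)
y.sum().backward()
```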
--------------------------------------------------------------------------------
/src/liger_kernel/transformers/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linkedin/Liger-Kernel/0bb6c72d03f0fc570b9e954c42b723ab7e0c315f/src/liger_kernel/transformers/model/__init__.py
--------------------------------------------------------------------------------
/src/liger_kernel/transformers/model/mistral.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple, Union
2 |
3 | import torch
4 | from torch.nn import CrossEntropyLoss
5 | from transformers.cache_utils import Cache
6 | from transformers.modeling_outputs import CausalLMOutputWithPast
7 | from transformers.models.mistral.modeling_mistral import (
8 | _CONFIG_FOR_DOC,
9 | MISTRAL_INPUTS_DOCSTRING,
10 | )
11 | from transformers.utils import (
12 | add_start_docstrings_to_model_forward,
13 | replace_return_docstrings,
14 | )
15 |
16 | from liger_kernel.transformers.fused_linear_cross_entropy import (
17 | LigerFusedLinearCrossEntropyLoss,
18 | )
19 |
20 |
21 | @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
22 | @replace_return_docstrings(
23 | output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
24 | )
25 | def lce_forward(
26 | self,
27 | input_ids: torch.LongTensor = None,
28 | attention_mask: Optional[torch.Tensor] = None,
29 | position_ids: Optional[torch.LongTensor] = None,
30 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
31 | inputs_embeds: Optional[torch.FloatTensor] = None,
32 | labels: Optional[torch.LongTensor] = None,
33 | use_cache: Optional[bool] = None,
34 | output_attentions: Optional[bool] = None,
35 | output_hidden_states: Optional[bool] = None,
36 | return_dict: Optional[bool] = None,
37 | cache_position: Optional[torch.LongTensor] = None,
38 | ) -> Union[Tuple, CausalLMOutputWithPast]:
39 | r"""
40 | Copy paste Mistral's forward but replace torch cross entropy with liger fused linear cross entropy
41 |
42 |
43 | Args:
44 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
45 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
46 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
47 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
48 |
49 | Returns:
50 |
51 | Example:
52 |
53 | ```python
54 | >>> from transformers import AutoTokenizer, MistralForCausalLM
55 |
56 | >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
57 | >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
58 |
59 | >>> prompt = "Hey, are you conscious? Can you talk to me?"
60 | >>> inputs = tokenizer(prompt, return_tensors="pt")
61 |
62 | >>> # Generate
63 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
64 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
65 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
66 | ```"""
67 |
68 | output_attentions = (
69 | output_attentions
70 | if output_attentions is not None
71 | else self.config.output_attentions
72 | )
73 | output_hidden_states = (
74 | output_hidden_states
75 | if output_hidden_states is not None
76 | else self.config.output_hidden_states
77 | )
78 | return_dict = (
79 | return_dict if return_dict is not None else self.config.use_return_dict
80 | )
81 |
82 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
83 | outputs = self.model(
84 | input_ids=input_ids,
85 | attention_mask=attention_mask,
86 | position_ids=position_ids,
87 | past_key_values=past_key_values,
88 | inputs_embeds=inputs_embeds,
89 | use_cache=use_cache,
90 | output_attentions=output_attentions,
91 | output_hidden_states=output_hidden_states,
92 | return_dict=return_dict,
93 | cache_position=cache_position,
94 | )
95 |
96 | hidden_states = outputs[0]
97 |
98 | loss = None
99 | logits = None
100 |
101 | if self.training and (labels is not None):
102 | shift_hidden_states = hidden_states[..., :-1, :].contiguous()
103 | shift_labels = labels[..., 1:].contiguous()
104 |
105 | # flatten tokens
106 | shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
107 | shift_labels = shift_labels.view(-1)
108 |
109 | lce = LigerFusedLinearCrossEntropyLoss()
110 | loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
111 |
112 | else:
113 | logits = self.lm_head(hidden_states)
114 | if labels is not None:
115 | # Upcast to float if we need to compute the loss to avoid potential precision issues
116 | logits = logits.float()
117 | # Shift so that tokens < n predict n
118 | shift_logits = logits[..., :-1, :].contiguous()
119 | shift_labels = labels[..., 1:].contiguous()
120 | # Flatten the tokens
121 | shift_logits = shift_logits.view(-1, self.config.vocab_size)
122 | shift_labels = shift_labels.view(-1)
123 | # Ensure tensors are on the same device
124 | shift_labels = shift_labels.to(shift_logits.device)
125 | loss_fct = CrossEntropyLoss()
126 | loss = loss_fct(shift_logits, shift_labels)
127 |
128 | if not return_dict:
129 | output = (logits,) + outputs[1:]
130 | return (loss,) + output if loss is not None else output
131 |
132 | return CausalLMOutputWithPast(
133 | loss=loss,
134 | logits=logits,
135 | past_key_values=outputs.past_key_values,
136 | hidden_states=outputs.hidden_states,
137 | attentions=outputs.attentions,
138 | )
139 |
140 |
141 | # Note: the gradient accumulation fix is not applied to Mistral as of transformers 4.46.1
142 |
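A standalone sketch of the training-path pattern above, assuming a CUDA device and illustrative shapes: the fused loss consumes the `lm_head` weight and the shifted hidden states directly, so the full `(B*T, V)` logits tensor is never materialized.

```python
import torch

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

B, T, H, V = 2, 16, 64, 1000  # illustrative sizes
hidden_states = torch.randn(B, T, H, device="cuda", requires_grad=True)
lm_head_weight = torch.randn(V, H, device="cuda", requires_grad=True)
labels = torch.randint(0, V, (B, T), device="cuda")

# Shift so that tokens < n predict n, then flatten
shift_hidden = hidden_states[..., :-1, :].contiguous().view(-1, H)
shift_labels = labels[..., 1:].contiguous().view(-1)

lce = LigerFusedLinearCrossEntropyLoss()
loss = lce(lm_head_weight, shift_hidden, shift_labels)
loss.backward()
```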
--------------------------------------------------------------------------------
/src/liger_kernel/transformers/qwen2vl_mrope.py:
--------------------------------------------------------------------------------
1 | from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
2 |
3 |
4 | def liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
5 | """
6 | Applies Multimodal Rotary Positional Embedding (M-RoPE) operation to query and key states.
7 |
8 | Args:
9 | q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
10 | k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
11 | cos (torch.Tensor): The cosine tensor of shape (3, bsz, seq_len, head_dim).
12 | sin (torch.Tensor): The sine tensor of shape (3, bsz, seq_len, head_dim).
13 | mrope_section (List[int]): The multimodal rope section for channel dimension of temporal, height and width in rope calculation.
14 | unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
15 |
16 | Returns:
17 | Tuple[torch.Tensor, torch.Tensor]: The query and key tensors after applying the M-RoPE operation.
18 | """
19 |
20 | return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)
21 |
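A minimal shape sketch for `liger_multimodal_rotary_pos_emb`, assuming a CUDA device and that the `mrope_section` entries sum to `head_dim // 2` (the temporal/height/width split used by Qwen2-VL); the random `cos`/`sin` tensors only demonstrate the expected shapes.

```python
import torch

from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb

bsz, n_q_head, n_kv_head, seq_len, head_dim = 1, 8, 2, 32, 64
mrope_section = [8, 12, 12]  # assumed to sum to head_dim // 2 = 32

q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda")
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda")
cos = torch.randn(3, bsz, seq_len, head_dim, device="cuda")  # (3, bsz, seq_len, head_dim)
sin = torch.randn(3, bsz, seq_len, head_dim, device="cuda")

q_rot, k_rot = liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)
```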
--------------------------------------------------------------------------------
/src/liger_kernel/transformers/rms_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from liger_kernel.ops.rms_norm import LigerRMSNormFunction
5 |
6 |
7 | class LigerRMSNorm(nn.Module):
8 | def __init__(
9 | self,
10 | hidden_size,
11 | eps=1e-6,
12 | offset=0.0,
13 | casting_mode="llama",
14 | init_fn="ones",
15 | in_place=True,
16 | ):
17 | super().__init__()
18 | assert init_fn in [
19 | "ones",
20 | "zeros",
21 | ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
22 | self.weight = nn.Parameter(
23 | torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
24 | )
25 | self.variance_epsilon, self.offset, self.casting_mode, self.in_place = (
26 | eps,
27 | offset,
28 | casting_mode,
29 | in_place,
30 | )
31 |
32 | def forward(self, hidden_states):
33 | return LigerRMSNormFunction.apply(
34 | hidden_states,
35 | self.weight,
36 | self.variance_epsilon,
37 | self.offset,
38 | self.casting_mode,
39 | self.in_place,
40 | )
41 |
42 | def extra_repr(self):
43 | return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
44 |
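A minimal usage sketch for `LigerRMSNorm`, assuming a CUDA device. `offset` and `casting_mode` should match the model family being patched (for example, Gemma-style norms use a `+1` weight offset, while the Llama-style defaults shown here do not).

```python
import torch

from liger_kernel.transformers.rms_norm import LigerRMSNorm

rms = LigerRMSNorm(hidden_size=128, eps=1e-6, offset=0.0, casting_mode="llama").to("cuda")

x = torch.randn(2, 16, 128, device="cuda", requires_grad=True)
y = rms(x)
y.sum().backward()
```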
--------------------------------------------------------------------------------
/src/liger_kernel/transformers/rope.py:
--------------------------------------------------------------------------------
1 | from liger_kernel.ops.rope import LigerRopeFunction
2 |
3 |
4 | def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
5 | """
6 | Applies Rotary Positional Embedding (RoPE) operation to query and key states.
7 |
8 | Args:
9 | q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
10 | k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
11 | cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim).
12 | sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim).
13 | position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None.
14 | unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
15 |
16 | Returns:
17 | Tuple[torch.Tensor, torch.Tensor]: The query and key tensors after applying the RoPE operation.
18 | """
19 |
20 | return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
21 |
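A minimal shape sketch for `liger_rotary_pos_emb`, assuming a CUDA device; the `cos`/`sin` tensors follow the Hugging Face convention of shape `(1, seq_len, head_dim)` with the rotary frequencies duplicated across the two halves of `head_dim`.

```python
import torch

from liger_kernel.transformers.rope import liger_rotary_pos_emb

bsz, n_q_head, n_kv_head, seq_len, head_dim = 1, 8, 2, 32, 64

q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda")
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda")

# Standard RoPE tables: (1, seq_len, head_dim), frequencies repeated over both halves.
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device="cuda").float() / head_dim))
t = torch.arange(seq_len, device="cuda").float()
freqs = torch.outer(t, inv_freq)          # (seq_len, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)   # (seq_len, head_dim)
cos, sin = emb.cos()[None, :, :], emb.sin()[None, :, :]

q_rot, k_rot = liger_rotary_pos_emb(q, k, cos, sin)
```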
--------------------------------------------------------------------------------
/src/liger_kernel/transformers/swiglu.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from liger_kernel.ops.swiglu import LigerSiLUMulFunction
4 |
5 |
6 | class LigerSwiGLUMLP(nn.Module):
7 | def __init__(self, config):
8 | super().__init__()
9 | self.config = config
10 | self.hidden_size = config.hidden_size
11 | self.intermediate_size = config.intermediate_size
12 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
13 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
14 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
15 | if config.hidden_act not in ["silu", "swish"]:
16 | raise ValueError(f"Activation function {config.hidden_act} not supported.")
17 |
18 | def forward(self, x):
19 |
20 | return self.down_proj(
21 | LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
22 | )
23 |
24 |
25 | class LigerBlockSparseTop2MLP(nn.Module):
26 | def __init__(self, config):
27 | super().__init__()
28 | self.ffn_dim = config.intermediate_size
29 | self.hidden_dim = config.hidden_size
30 |
31 | self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
32 | self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
33 | self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
34 |
35 | if config.hidden_act not in ["silu", "swish"]:
36 | raise ValueError(f"Activation function {config.hidden_act} not supported.")
37 |
38 | def forward(self, x):
39 |
40 | return self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x)))
41 |
42 |
43 | class LigerPhi3SwiGLUMLP(nn.Module):
44 | """
45 | Patch Phi3MLP to use LigerSiLUMulFunction
46 | https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/models/phi3/modeling_phi3.py#L241
47 | """
48 |
49 | def __init__(self, config):
50 | super().__init__()
51 | self.config = config
52 | self.hidden_size = config.hidden_size
53 | self.intermediate_size = config.intermediate_size
54 | self.gate_up_proj = nn.Linear(
55 | self.hidden_size, 2 * self.intermediate_size, bias=False
56 | )
57 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
58 | if config.hidden_act not in ["silu", "swish"]:
59 | raise ValueError(f"Activation function {config.hidden_act} not supported.")
60 |
61 | def forward(self, x):
62 | up_states = self.gate_up_proj(x)
63 | gate, up_states = up_states.chunk(2, dim=-1)
64 | return self.down_proj(LigerSiLUMulFunction.apply(gate, up_states))
65 |
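A minimal usage sketch for `LigerSwiGLUMLP`, assuming a CUDA device; the `SimpleNamespace` config is illustrative and only needs `hidden_size`, `intermediate_size`, and a `hidden_act` of `"silu"` or `"swish"`.

```python
from types import SimpleNamespace

import torch

from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

config = SimpleNamespace(hidden_size=64, intermediate_size=256, hidden_act="silu")
mlp = LigerSwiGLUMLP(config).to("cuda")

x = torch.randn(2, 8, 64, device="cuda", requires_grad=True)
y = mlp(x)
y.sum().backward()
```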
--------------------------------------------------------------------------------
/src/liger_kernel/transformers/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | try:
2 | from liger_kernel.transformers.trainer.orpo_trainer import ( # noqa: F401
3 | LigerORPOTrainer,
4 | )
5 | except ImportError:
6 | raise ImportError("Please `pip install trl` to use LigerORPOTrainer")
7 |
--------------------------------------------------------------------------------
/src/liger_kernel/transformers/trainer/orpo_trainer.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, List, Literal, Tuple, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.distributed.fsdp import FullyShardedDataParallel
6 | from trl.trainer import ORPOTrainer
7 |
8 | from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
9 |
10 |
11 | class _FSDPForwardRedirection:
12 | """
13 | Modified based on
14 | https://github.com/Lightning-AI/pytorch-lightning/blob/d3f9c83d6efa4f1def36aa6c199600946cdb9117/src/lightning/pytorch/strategies/strategy.py#L601-L648
15 | Redirect a method call through FullyShardedDataParallel.forward so that the FSDP module's root pre-forward and
16 | post-forward can be properly executed around the method call.
17 |     This is needed in cases where we call a submodule of an FSDP module. For instance, when we want to call only
18 | the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving
19 | GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`)
20 | will not work because the first `nn.Emebedding` layer is not independently wrapped as a FSDP module (because of
21 |     will not work because the first `nn.Embedding` layer is not independently wrapped as an FSDP module (because of
22 | its parameter, thus resulting in "RuntimeError: 'weight' must be 2-D" error. Similarly, if we want to call just
23 | the `lm_head` part of a model, we need this trick too to properly get its params all-gathered.
24 | """
25 |
26 | def __call__(
27 | self,
28 | wrapper_module: FullyShardedDataParallel,
29 | method: Callable,
30 | *args: Any,
31 | **kwargs: Any,
32 | ):
33 | """Reroutes a method call through the `wrapper_module`'s `forward` method.
34 | Args:
35 |             wrapper_module: The FSDP module that wraps the module whose method is being called.
36 |             method: The callable to invoke (e.g. a bound method or submodule of the wrapped
37 |                 module); its inputs get redirected through the `wrapper_module`'s `forward`
38 |                 method.
39 |             *args: The positional arguments to `method`. They will get passed to a patched
40 |                 `forward` method instead.
41 |             **kwargs: The keyword arguments to `method`. They will get passed to a patched
42 |                 `forward` method instead.
43 | """
44 | assert isinstance(wrapper_module, FullyShardedDataParallel)
45 | original_module = wrapper_module._fsdp_wrapped_module
46 | original_forward = original_module.forward
47 |
48 | def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
49 | # Unpatch ourselves immediately before calling the method `method_name`
50 | # because itself may want to call the real `forward`
51 | original_module.forward = original_forward # type: ignore[method-assign]
52 | # Call the actual method e.g. `.training_step(...)`
53 | out = method(*_args, **_kwargs)
54 | return out
55 |
56 | # Patch the original_module's forward so we can redirect the arguments back to the real method
57 | original_module.forward = wrapped_forward # type: ignore[method-assign]
58 | wrapper_output = wrapper_module(*args, **kwargs)
59 | return wrapper_output
60 |
61 |
62 | class LigerORPOTrainer(ORPOTrainer):
63 | def concatenated_forward(
64 | self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
65 | ) -> Tuple[
66 | torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor
67 | ]:
68 | """
69 | Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
70 | We do this to avoid doing two forward passes, because it's faster for FSDP.
71 | """
72 | concatenated_batch = self.concatenated_inputs(
73 | batch,
74 | is_encoder_decoder=self.is_encoder_decoder,
75 | label_pad_token_id=self.label_pad_token_id,
76 | padding_value=self.padding_value,
77 | device=self.accelerator.device,
78 | )
79 |
80 | model_kwargs = (
81 | {
82 | "decoder_input_ids": self._shift_right(
83 | concatenated_batch["concatenated_labels"]
84 | ),
85 | }
86 | if self.is_encoder_decoder
87 | else {}
88 | )
89 |
90 | if self.aux_loss_enabled:
91 | model_kwargs["output_router_logits"] = True
92 |
93 | if isinstance(model, FullyShardedDataParallel):
94 | outputs = _FSDPForwardRedirection()(
95 | model,
96 | model._fsdp_wrapped_module.model,
97 | concatenated_batch["concatenated_input_ids"],
98 | attention_mask=concatenated_batch["concatenated_attention_mask"],
99 | use_cache=False,
100 | **model_kwargs,
101 | )
102 | else:
103 | if isinstance(model, torch.nn.DataParallel):
104 | model = model.module
105 | outputs = model.model(
106 | concatenated_batch["concatenated_input_ids"],
107 | attention_mask=concatenated_batch["concatenated_attention_mask"],
108 | use_cache=False,
109 | **model_kwargs,
110 | )
111 |
112 | orpo_loss_fn = LigerFusedLinearORPOLoss(
113 | ignore_index=self.label_pad_token_id, beta=self.beta
114 | )
115 |
116 | def orpo_partial(lm_head, last_hidden_state, concatenated_labels):
117 | return orpo_loss_fn(
118 | lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias
119 | )
120 |
121 | orpo_loss, aux_outputs = _FSDPForwardRedirection()(
122 | model,
123 | orpo_partial,
124 | model.lm_head,
125 | outputs.last_hidden_state,
126 | concatenated_batch["concatenated_labels"],
127 | )
128 | return orpo_loss, aux_outputs
129 |
130 | def get_batch_loss_metrics(
131 | self,
132 | model,
133 | batch: Dict[str, Union[List, torch.LongTensor]],
134 | train_eval: Literal["train", "eval"] = "train",
135 | ):
136 | """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
137 | metrics = {}
138 | loss, aux_outputs = self.concatenated_forward(model, batch)
139 | (
140 | policy_chosen_logps,
141 | policy_rejected_logps,
142 | policy_chosen_logits,
143 | policy_rejected_logits,
144 | policy_nll_loss,
145 | ) = aux_outputs[:5]
146 |
147 | # return loss, metrics
148 | chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = aux_outputs[
149 | 5:
150 | ]
151 |
152 | reward_accuracies = (chosen_rewards > rejected_rewards).float()
153 |
154 | prefix = "eval_" if train_eval == "eval" else ""
155 | metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean()
156 | metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean()
157 | metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean()
158 | metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean()
159 | metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean()
160 | metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean()
161 | metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean()
162 | metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean()
163 | metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean()
164 | metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
165 | metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen
166 | for k, v in metrics.items():
167 | metrics[k] = v.item()
168 |
169 | return loss, metrics
170 |
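A simplified, FSDP-free illustration of the forward-redirection trick used by `_FSDPForwardRedirection`: the wrapped module's `forward` is temporarily swapped for a shim so that calling the wrapper routes an arbitrary method call through the wrapper's own pre/post-forward logic. `LoggingWrapper` and `call_through_wrapper` are toy stand-ins for `FullyShardedDataParallel` and the class above, not part of the library.

```python
import torch
import torch.nn as nn


class LoggingWrapper(nn.Module):
    """Toy wrapper standing in for FSDP: it runs extra logic around the wrapped forward."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        print("wrapper pre-forward")   # FSDP would all-gather sharded parameters here
        out = self.module(*args, **kwargs)
        print("wrapper post-forward")  # FSDP would reshard parameters here
        return out


def call_through_wrapper(wrapper, method, *args, **kwargs):
    inner = wrapper.module
    original_forward = inner.forward

    def shim(*a, **kw):
        inner.forward = original_forward  # unpatch before running the real method
        return method(*a, **kw)

    inner.forward = shim                  # the next wrapper call is rerouted into `method`
    return wrapper(*args, **kwargs)


linear = nn.Linear(4, 4)
wrapped = LoggingWrapper(linear)
x = torch.randn(2, 4)
# Runs linear's original forward, but with the wrapper's pre/post logic around it.
out = call_through_wrapper(wrapped, linear.forward, x)
```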
--------------------------------------------------------------------------------
/src/liger_kernel/transformers/trainer_integration.py:
--------------------------------------------------------------------------------
1 | # To not break HF Trainer integration
2 | from liger_kernel.transformers.monkey_patch import _apply_liger_kernel # noqa: F401
3 |
--------------------------------------------------------------------------------
/src/liger_kernel/triton/__init__.py:
--------------------------------------------------------------------------------
1 | from liger_kernel.triton.monkey_patch import ( # noqa: F401
2 | apply_liger_triton_cache_manager,
3 | )
4 |
--------------------------------------------------------------------------------
/src/liger_kernel/triton/monkey_patch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | from triton.runtime.cache import FileCacheManager
5 |
6 |
7 | class LigerTritonFileCacheManager(FileCacheManager):
8 | def put(self, data, filename, binary=True) -> str:
9 | if not self.cache_dir:
10 | raise RuntimeError("Could not create or locate cache dir")
11 | binary = isinstance(data, bytes)
12 | if not binary:
13 | data = str(data)
14 | assert self.lock_path is not None
15 | filepath = self._make_path(filename)
16 | # Random ID to avoid any collisions
17 | rnd_id = random.randint(0, 1000000)
18 |         # include the PID in case several of these exist at once, so we can see which PID made it
19 | pid = os.getpid()
20 | # use temp dir to be robust against program interruptions
21 | temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}")
22 | os.makedirs(temp_dir, exist_ok=True)
23 | temp_path = os.path.join(temp_dir, filename)
24 |
25 | mode = "wb" if binary else "w"
26 | with open(temp_path, mode) as f:
27 | f.write(data)
28 | # Replace is guaranteed to be atomic on POSIX systems if it succeeds
29 | # so filepath cannot see a partial write
30 | os.replace(temp_path, filepath)
31 | os.removedirs(temp_dir)
32 | return filepath
33 |
34 |
35 | def apply_liger_triton_cache_manager():
36 | """
37 | Experimental feature to get around transient FileNotFoundError in triton compilation.
38 | For more details please see https://github.com/triton-lang/triton/pull/4295
39 | """
40 | os.environ["TRITON_CACHE_MANAGER"] = (
41 | "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager"
42 | )
43 |
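Usage note: `apply_liger_triton_cache_manager` only sets the `TRITON_CACHE_MANAGER` environment variable, so it has to be called before any Triton kernel is compiled, i.e. before the first Liger op runs.

```python
from liger_kernel.triton import apply_liger_triton_cache_manager

# Call early, before model construction / the first Liger kernel launch.
apply_liger_triton_cache_manager()
```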
--------------------------------------------------------------------------------
/src/liger_kernel/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def infer_device():
5 | """
6 | Get current device name based on available devices
7 | """
8 | if torch.cuda.is_available():
9 | return "cuda"
10 | elif torch.xpu.is_available():
11 | return "xpu"
12 | else:
13 | return "cpu"
14 |
--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linkedin/Liger-Kernel/0bb6c72d03f0fc570b9e954c42b723ab7e0c315f/test/__init__.py
--------------------------------------------------------------------------------
/test/chunked_loss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linkedin/Liger-Kernel/0bb6c72d03f0fc570b9e954c42b723ab7e0c315f/test/chunked_loss/__init__.py
--------------------------------------------------------------------------------
/test/chunked_loss/test_simpo_loss.py:
--------------------------------------------------------------------------------
1 | from test.chunked_loss.test_cpo_loss import TorchLMHeadCPO
2 | from test.utils import assert_verbose_allclose, set_seed
3 |
4 | import pytest
5 | import torch
6 |
7 | from liger_kernel.chunked_loss import LigerFusedLinearSimPOLoss
8 | from liger_kernel.chunked_loss.functional import liger_fused_linear_simpo
9 | from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
10 | from liger_kernel.utils import infer_device
11 |
12 | device = infer_device()
13 |
14 | # set random seed globally
15 | set_seed()
16 |
17 |
18 | class LigerLMHeadSimPO(torch.nn.Module):
19 | def __init__(
20 | self,
21 | H: int,
22 | V: int,
23 | dtype: torch.dtype,
24 | bias: bool = False,
25 | ignore_index: int = -100,
26 | beta: float = 0.1,
27 | alpha: float = 1.0,
28 | gamma: float = 0.5,
29 | ):
30 | super().__init__()
31 | self.lin = torch.nn.Linear(
32 | in_features=H, out_features=V, bias=bias, dtype=dtype
33 | )
34 | self.simpo_loss = LigerFusedLinearSimPOLoss(
35 | ignore_index=ignore_index, beta=beta, alpha=alpha, gamma=gamma
36 | )
37 |
38 | def forward(self, x, y):
39 | return self.simpo_loss(self.lin.weight, x, y, self.lin.bias)
40 |
41 |
42 | @pytest.mark.parametrize(
43 | "B, T, H, V",
44 | [
45 | (8, 128, 1024, 4096),
46 | (3, 47, 31, 123), # random shape
47 | ],
48 | )
49 | @pytest.mark.parametrize(
50 | "scalar, dtype, atol, rtol",
51 | [
52 | (1.0, torch.bfloat16, 5e-3, 5e-3),
53 | (1.0, torch.float32, 1e-5, 5e-4),
54 | ],
55 | )
56 | @pytest.mark.parametrize("bias", [True, False])
57 | @pytest.mark.parametrize(
58 | "ignore_index, beta, gamma", [(-100, 0.1, 0.5), (42, 0.2, 0.85)]
59 | )
60 | def test_correctness(
61 | B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, gamma
62 | ):
63 | B = 2 * B # SimPO loss requires B to be even
64 |
65 | torch_lm_head_simpo = TorchLMHeadCPO(
66 | H=H,
67 | V=V,
68 | dtype=dtype,
69 | bias=bias,
70 | ignore_index=ignore_index,
71 | beta=beta,
72 | loss_type="simpo",
73 | simpo_gamma=gamma,
74 | )
75 | liger_lm_head_simpo = LigerLMHeadSimPO(
76 | H=H,
77 | V=V,
78 | dtype=dtype,
79 | bias=bias,
80 | ignore_index=ignore_index,
81 | beta=beta,
82 | gamma=gamma,
83 | )
84 |
85 | torch_lm_head_simpo.lin.weight.data = liger_lm_head_simpo.lin.weight.data = (
86 | torch.randn(V, H, device=device, dtype=dtype)
87 | )
88 |
89 | if bias:
90 | torch_lm_head_simpo.lin.bias.data = liger_lm_head_simpo.lin.bias.data = (
91 | torch.randn(V, device=device, dtype=dtype)
92 | )
93 |
94 | _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
95 | input1 = _input.detach().clone().requires_grad_(True)
96 | input2 = _input.detach().clone().requires_grad_(True)
97 |
98 | target = torch.randint(
99 | 0,
100 | V,
101 | (
102 | B,
103 | T,
104 | ),
105 | device=device,
106 | dtype=torch.long,
107 | )
108 | # Assign some random number of elements as ignore_index
109 | num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
110 | indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
111 | target.view(-1)[indices_to_assign] = ignore_index
112 |
113 | loss1, aggregated_aux_outputs1 = torch_lm_head_simpo(input1, target)
114 | loss2, aggregated_aux_outputs2 = liger_lm_head_simpo(input2, target)
115 |
116 | assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
117 |
118 | assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2)
119 |
120 | for i in range(len(aggregated_aux_outputs1)):
121 | assert_verbose_allclose(
122 | aggregated_aux_outputs1[i],
123 | aggregated_aux_outputs2[i],
124 | atol=atol,
125 | rtol=rtol,
126 | )
127 |
128 | loss1.backward()
129 | loss2.backward()
130 |
131 | assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol)
132 | assert_verbose_allclose(
133 | torch_lm_head_simpo.lin.weight.grad,
134 | liger_lm_head_simpo.lin.weight.grad,
135 | atol=atol,
136 | rtol=rtol,
137 | )
138 | if bias:
139 | assert_verbose_allclose(
140 | torch_lm_head_simpo.lin.bias.grad,
141 | liger_lm_head_simpo.lin.bias.grad,
142 | atol=atol,
143 | rtol=rtol,
144 | )
145 |
146 |
147 | @pytest.mark.parametrize(
148 | "B, T, H, V",
149 | [
150 | (2, 2, 8, 8),
151 | (3, 47, 31, 123), # random shape
152 | ],
153 | )
154 | @pytest.mark.parametrize(
155 | "scalar, dtype, atol, rtol",
156 | [
157 | (1.0, torch.bfloat16, 5e-2, 5e-1),
158 | (1.0, torch.float32, 1e-5, 5e-4),
159 | ],
160 | )
161 | @pytest.mark.parametrize("bias", [True, False])
162 | def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias):
163 | B = 2 * B
164 |
165 | _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
166 | input1 = _input.detach().clone().requires_grad_(True)
167 | input2 = _input.detach().clone().requires_grad_(True)
168 |
169 | target = torch.randint(
170 | 0,
171 | V,
172 | (
173 | B,
174 | T,
175 | ),
176 | device=device,
177 | dtype=torch.long,
178 | )
179 |
180 | _weight = torch.randn(V, H, device=device, dtype=dtype)
181 | weight1 = _weight.detach().clone().requires_grad_(True)
182 | weight2 = _weight.detach().clone().requires_grad_(True)
183 |
184 | _bias = torch.randn(V, device=device, dtype=dtype) if bias else None
185 | bias1 = _bias.detach().clone().requires_grad_(True) if bias else None
186 | bias2 = _bias.detach().clone().requires_grad_(True) if bias else None
187 |
188 | loss1, aggregated_aux_outputs1 = LigerFusedLinearSimPOFunction.apply(
189 | input1, weight1, target, bias1
190 | )
191 | loss2, aggregated_aux_outputs2 = liger_fused_linear_simpo(
192 | input2, weight2, target, bias2
193 | )
194 |
195 | assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
196 |
197 | loss1.backward()
198 | loss2.backward()
199 |
200 | assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol)
201 | assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol)
202 | if bias:
203 | assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol)
204 |
--------------------------------------------------------------------------------
/test/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 |
5 | @pytest.fixture(autouse=True)
6 | def clear_cuda_cache():
7 | yield
8 | torch.cuda.empty_cache()
9 |
--------------------------------------------------------------------------------
/test/convergence/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linkedin/Liger-Kernel/0bb6c72d03f0fc570b9e954c42b723ab7e0c315f/test/convergence/__init__.py
--------------------------------------------------------------------------------
/test/resources/scripts/generate_tokenized_dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from datasets import load_dataset
4 | from transformers import AutoTokenizer
5 |
6 |
7 | def prepare_dataset(tokenizer, text_file_path: str):
8 | """
9 | Tokenizes a text file where each line is a different example.
10 | Padding is applied to each example.
11 | """
12 | # Each line is a different example
13 | dataset = load_dataset("text", data_files={"train": text_file_path})
14 |
15 | def tokenize_function(examples):
16 | return tokenizer(
17 | examples["text"], padding="max_length", truncation=True, max_length=128
18 | )
19 |
20 | tokenized_dataset = dataset.map(
21 | tokenize_function, batched=True, remove_columns=["text"]
22 | )
23 | return tokenized_dataset["train"]
24 |
25 |
26 | def generate_tokenized_dataset(
27 | tokenizer_path: str, text_file_path: str, output_dir: str
28 | ) -> None:
29 | """
30 | Generate tokenized dataset from a text file, where each line is a different example.
31 |
32 | Args:
33 | tokenizer_path (str): Path to the directory containing the tokenizer files.
34 | text_file_path (str): Path to the text file to tokenize.
35 | output_dir (str): Directory where the tokenized dataset will be saved
36 | """
37 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
38 | tokenizer.pad_token = tokenizer.eos_token
39 |
40 | train_dataset = prepare_dataset(tokenizer, text_file_path)
41 | train_dataset.save_to_disk(output_dir)
42 |
43 |
44 | if __name__ == "__main__":
45 | # Example usage:
46 | # python generate_tokenized_dataset.py --tokenizer_path /shared/public/models/Mistral-7B --text_file_path ./../../resources/tiny_shakespeare.txt --output_dir ./../../resources/tiny_shakespeare_tokenized
47 | parser = argparse.ArgumentParser(
48 | description="Generate tokenized dataset from a text file."
49 | )
50 |
51 | # Add arguments
52 | parser.add_argument(
53 | "--tokenizer_path",
54 | type=str,
55 | required=True,
56 | help="Path to the directory containing the tokenizer files.",
57 | )
58 | parser.add_argument(
59 | "--text_file_path",
60 | type=str,
61 | required=True,
62 | help="Path to the text file to tokenize.",
63 | )
64 | parser.add_argument(
65 | "--output_dir",
66 | type=str,
67 | required=True,
68 | help="Directory where the tokenized dataset will be saved.",
69 | )
70 |
71 | # Parse the arguments
72 | args = parser.parse_args()
73 |
74 | # Call the function with parsed arguments
75 | generate_tokenized_dataset(
76 | tokenizer_path=args.tokenizer_path,
77 | text_file_path=args.text_file_path,
78 | output_dir=args.output_dir,
79 | )
80 |
--------------------------------------------------------------------------------
/test/transformers/test_auto_model.py:
--------------------------------------------------------------------------------
1 | from inspect import signature
2 | from unittest import mock
3 | from unittest.mock import MagicMock, patch
4 |
5 | from transformers import AutoConfig, AutoModelForCausalLM
6 |
7 | from liger_kernel.transformers import AutoLigerKernelForCausalLM
8 | from liger_kernel.transformers.monkey_patch import (
9 | MODEL_TYPE_TO_APPLY_LIGER_FN,
10 | apply_liger_kernel_to_llama,
11 | )
12 |
13 |
14 | def test_auto_liger_kernel_for_causal_lm_from_pretrained():
15 | pretrained_model_name_or_path = "/path/to/llama/model"
16 | model_args = ("model_arg1", "model_arg2")
17 |
18 | original_kwargs = {
19 | "valid_arg_1": "some_value_1",
20 | "valid_arg_2": 10,
21 | }
22 |
23 | # These args should be passed through to apply_liger_kernel_to_llama fn
24 | apply_liger_kernel_kwargs = {
25 | "rope": False,
26 | "swiglu": True,
27 | }
28 |
29 | kwargs = {**original_kwargs, **apply_liger_kernel_kwargs}
30 |
31 | # Mock the model config instance returned from AutoConfig.from_pretrained()
32 | mock_model_config = MagicMock()
33 | mock_model_config.model_type = "llama"
34 | mock_llama = mock.Mock()
35 |
36 | with patch.dict(
37 | MODEL_TYPE_TO_APPLY_LIGER_FN, {"llama": mock_llama}
38 | ), mock.patch.object(
39 | AutoConfig, "from_pretrained", return_value=mock_model_config
40 | ), mock.patch.object(
41 | AutoModelForCausalLM, "from_pretrained", return_value="mock_model"
42 | ) as mock_super_from_pretrained:
43 |
44 | # Mock the function signature of apply_liger_kernel_to_llama
45 | mock_llama.__signature__ = signature(apply_liger_kernel_to_llama)
46 |
47 | model = AutoLigerKernelForCausalLM.from_pretrained(
48 | pretrained_model_name_or_path, *model_args, **kwargs
49 | )
50 |
51 | # Check that the apply_liger_kernel_to_llama mock was called with the correct kwargs
52 | mock_llama.assert_called_once_with(rope=False, swiglu=True)
53 | # Check that the original kwargs are passed to super().from_pretrained
54 | mock_super_from_pretrained.assert_called_once_with(
55 | pretrained_model_name_or_path, *model_args, **original_kwargs
56 | )
57 | assert model == "mock_model"
58 |
--------------------------------------------------------------------------------
/test/transformers/test_embedding.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torch.nn import Embedding
4 |
5 | from liger_kernel.transformers.experimental.embedding import LigerEmbedding
6 | from liger_kernel.utils import infer_device
7 |
8 | device = infer_device()
9 |
10 | SLEEP_SECONDS = 0.1
11 |
12 |
13 | @pytest.mark.skip(reason="LigerEmbedding is under experimentation")
14 | @pytest.mark.parametrize(
15 | "num_embeddings, embedding_dim, padding_idx",
16 | [
17 | (100, 64, None),
18 | (100, 64, None),
19 | (1000, 128, None),
20 | (100, 60, None),
21 | (100, 60, None),
22 | (1000, 120, None),
23 | (1000, 500, None),
24 | (30522, 768, None),
25 | (100, 64, 0),
26 | (1000, 128, 50),
27 | (30522, 768, 1),
28 | ],
29 | )
30 | @pytest.mark.parametrize(
31 | "dtype, atol, rtol, device",
32 | [
33 | (torch.float32, 1e-6, 1e-5, device),
34 | ],
35 | )
36 | def test_embedding_correctness(
37 | num_embeddings, embedding_dim, padding_idx, dtype, atol, rtol, device
38 | ):
39 | print(
40 | f"\nTesting embedding with size: ({num_embeddings}, {embedding_dim}), padding_idx: {padding_idx}"
41 | )
42 | torch.manual_seed(42)
43 |
44 | torch_embedding = (
45 | Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
46 | .to(dtype)
47 | .to(device)
48 | )
49 | liger_embedding = (
50 | LigerEmbedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
51 | .to(dtype)
52 | .to(device)
53 | )
54 | liger_embedding.weight.data.copy_(torch_embedding.weight.data)
55 |
56 | if padding_idx is not None:
57 | input_ids = torch.randint(0, num_embeddings, (32 * 10,), device=device)
58 | input_ids[torch.randint(0, 32 * 10, (32 * 10 // 10,))] = padding_idx
59 | else:
60 | input_ids = torch.randint(0, num_embeddings, (32 * 10,), device=device)
61 |
62 | torch_output = torch_embedding(input_ids).view(32, 10, -1)
63 | liger_output = liger_embedding(input_ids).view(32, 10, -1)
64 |
65 | assert torch.allclose(torch_output, liger_output, atol=atol, rtol=rtol)
66 |
67 | grad_output = torch.randn_like(torch_output)
68 |
69 | torch_output.backward(grad_output)
70 | liger_output.backward(grad_output)
71 |
72 | assert torch.allclose(
73 | torch_embedding.weight.grad, liger_embedding.weight.grad, atol=atol, rtol=rtol
74 | )
75 |
--------------------------------------------------------------------------------
/test/transformers/test_geglu.py:
--------------------------------------------------------------------------------
1 | from test.utils import supports_bfloat16
2 |
3 | import pytest
4 | import torch
5 | from transformers.models.llama.configuration_llama import LlamaConfig
6 | from transformers.models.llama.modeling_llama import LlamaMLP
7 |
8 | from liger_kernel.ops.geglu import LigerGELUMulFunction
9 | from liger_kernel.transformers.functional import liger_geglu
10 | from liger_kernel.transformers.geglu import LigerGEGLUMLP
11 | from liger_kernel.utils import infer_device
12 |
13 | device = infer_device()
14 |
15 | LLAMA_CONFIG = LlamaConfig(
16 | hidden_size=4096,
17 | intermediate_size=11008,
18 | hidden_act="gelu_pytorch_tanh",
19 | )
20 | SLEEP_SECONDS = 0.1
21 |
22 |
23 | @pytest.mark.parametrize(
24 | "bsz, seq_len, hidden_size, intermediate_size",
25 | [
26 | (2, 2048, 2048, 4096),
27 | # weird shapes
28 | (9, 41, 341, 4231),
29 | ],
30 | )
31 | @pytest.mark.parametrize(
32 | "dtype, atol, rtol",
33 | [
34 | # atol is for small values: they have more difference, so set atol higher
35 | # rtol is for larger values: they are very close, so set rtol lower
36 | (torch.float32, 1e-0, 2e-6),
37 | pytest.param(
38 | torch.bfloat16,
39 | 1e4,
40 | 6e-3,
41 | marks=pytest.mark.skipif(
42 | not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
43 | ),
44 | ),
45 | ],
46 | )
47 | def test_correctness(bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol):
48 | _input = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype)
49 |
50 | x1 = _input.clone().requires_grad_(True)
51 | x2 = _input.clone().requires_grad_(True)
52 |
53 | # initialize weights
54 | G = torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype)
55 | U = torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype)
56 | D = torch.randn(intermediate_size, hidden_size, device=device, dtype=dtype)
57 |
58 | llama_mlp = LlamaMLP(config=LLAMA_CONFIG).to(device).to(dtype)
59 | llama_mlp.gate_proj.weight.data = G.T
60 | llama_mlp.up_proj.weight.data = U.T
61 | llama_mlp.down_proj.weight.data = D.T
62 |
63 | liger_mlp = LigerGEGLUMLP(config=LLAMA_CONFIG).to(device).to(dtype)
64 | liger_mlp.gate_proj.weight.data = G.T
65 | liger_mlp.up_proj.weight.data = U.T
66 | liger_mlp.down_proj.weight.data = D.T
67 |
68 | y1 = llama_mlp(x1)
69 | y2 = liger_mlp(x2)
70 |
71 | assert torch.allclose(y1, y2, atol=atol, rtol=rtol) is True
72 |
73 | dy = torch.randn_like(y1)
74 |
75 | y1.backward(dy.clone(), retain_graph=True)
76 | y2.backward(dy.clone(), retain_graph=True)
77 |
78 | assert (
79 | torch.allclose(
80 | llama_mlp.gate_proj.weight.grad,
81 | liger_mlp.gate_proj.weight.grad,
82 | atol=atol,
83 | rtol=rtol,
84 | )
85 | is True
86 | )
87 | assert (
88 | torch.allclose(
89 | llama_mlp.up_proj.weight.grad,
90 | liger_mlp.up_proj.weight.grad,
91 | atol=atol,
92 | rtol=rtol,
93 | )
94 | is True
95 | )
96 | assert (
97 | torch.allclose(
98 | llama_mlp.down_proj.weight.grad,
99 | liger_mlp.down_proj.weight.grad,
100 | atol=atol,
101 | rtol=rtol,
102 | )
103 | is True
104 | )
105 |
106 | assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) is True
107 |
108 |
109 | @pytest.mark.parametrize(
110 | "bsz, seq_len, size",
111 | [
112 | (2, 2, 8),
113 | # weird shapes
114 | (9, 7, 41),
115 | ],
116 | )
117 | @pytest.mark.parametrize(
118 | "dtype, atol, rtol",
119 | [
120 | # atol is for small values: they have more difference, so set atol higher
121 | # rtol is for larger values: they are very close, so set rtol lower
122 | (torch.float32, 1e-0, 2e-6),
123 | (torch.bfloat16, 1e4, 6e-3),
124 | ],
125 | )
126 | def test_correctness_functional(bsz, seq_len, size, dtype, atol, rtol):
127 | _input = torch.randn(bsz, seq_len, size, device=device, dtype=dtype)
128 | _b = torch.randn(bsz, seq_len, size, device=device, dtype=dtype)
129 |
130 | x1 = _input.clone().requires_grad_(True)
131 | x2 = _input.clone().requires_grad_(True)
132 |
133 | b1 = _b.clone().requires_grad_(True)
134 | b2 = _b.clone().requires_grad_(True)
135 |
136 | y1 = liger_geglu(a=x1, b=b1)
137 | y2 = LigerGELUMulFunction.apply(x2, b2)
138 |
139 | assert torch.allclose(y1, y2, atol=atol, rtol=rtol)
140 |
141 | grad_output = torch.randn_like(y1)
142 |
143 | y1.backward(grad_output)
144 | y2.backward(grad_output)
145 |
146 | assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol)
147 | assert torch.allclose(b1.grad, b2.grad, atol=atol, rtol=rtol)
148 |
--------------------------------------------------------------------------------
/test/transformers/test_group_norm.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import pytest
4 | import torch
5 |
6 | from liger_kernel.transformers.group_norm import LigerGroupNorm
7 | from liger_kernel.utils import infer_device
8 |
9 | device = infer_device()
10 |
11 | random_batch_size = random.randint(1, 16)
12 | random_num_groups = random.randint(1, 32)
13 | random_num_channels = random_num_groups * random.randint(1, 16)
14 | random_hidden_size = random.randint(1, 8192)
15 |
16 |
17 | @pytest.mark.parametrize(
18 | "batch_size, num_channels, num_groups, hidden_size",
19 | [
20 | (1, 1, 1, 3),
21 | (1, 4, 2, 4),
22 | (16, 12, 3, 4096),
23 | (random_batch_size, random_num_channels, random_num_groups, random_hidden_size),
24 | ],
25 | )
26 | @pytest.mark.parametrize(
27 | "dtype, atol, rtol",
28 | [
29 | (torch.float32, 1e-4, 1e-4),
30 | ],
31 | )
32 | def test_liger_group_norm(
33 | batch_size, num_channels, num_groups, hidden_size, dtype, atol, rtol
34 | ):
35 | torch.manual_seed(0)
36 |
37 | _tensor = torch.randn(
38 | batch_size, num_channels, hidden_size, dtype=dtype, device=device
39 | )
40 |
41 | liger_x = _tensor.clone().detach().requires_grad_(True)
42 | torch_x = _tensor.clone().detach().requires_grad_(True)
43 |
44 | liger_ln = LigerGroupNorm(num_channels, num_groups, eps=1e-6).to(dtype).to(device)
45 | torch_ln = (
46 | torch.nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, eps=1e-6)
47 | .to(dtype)
48 | .to(device)
49 | )
50 |
51 | with torch.no_grad():
52 | torch_ln.weight.copy_(liger_ln.weight)
53 | torch_ln.bias.copy_(liger_ln.bias)
54 |
55 | liger_output = liger_ln(
56 | liger_x,
57 | )
58 | torch_output = torch_ln(torch_x)
59 |
60 | assert torch.allclose(liger_output, torch_output, atol=atol, rtol=rtol)
61 | grad_output = torch.randn_like(torch_x)
62 | liger_output.backward(grad_output, retain_graph=True)
63 | torch_output.backward(grad_output, retain_graph=True)
64 | assert torch.allclose(liger_x.grad, torch_x.grad, atol=atol, rtol=rtol)
65 | assert torch.allclose(
66 | liger_ln.bias.grad, torch_ln.bias.grad, atol=atol, rtol=rtol
67 | ), "Bias grads different"
68 | assert torch.allclose(
69 | liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol
70 | ), "Weight grads different"
71 |
--------------------------------------------------------------------------------
/test/transformers/test_kl_div.py:
--------------------------------------------------------------------------------
1 | from test.utils import supports_bfloat16
2 |
3 | import pytest
4 | import torch
5 | from torch.nn import KLDivLoss
6 |
7 | from liger_kernel.transformers.kl_div import LigerKLDIVLoss
8 | from liger_kernel.utils import infer_device
9 |
10 | device = infer_device()
11 |
12 | _SHAPE_PARAMS = (
13 | "B, T, V",
14 | [
15 | (1, 4096, 32000),
16 | # weird shape
17 | (41, 401, 1271),
18 | ],
19 | )
20 |
21 | _DTYPE_PARAMS = (
22 | "dtype, atol, rtol",
23 | [
24 | pytest.param(
25 | torch.bfloat16,
26 | 1e-8,
27 | 5e-2,
28 | marks=pytest.mark.skipif(
29 | not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
30 | ),
31 | ),
32 | (torch.float32, 1e-8, 1e-6),
33 | (torch.float16, 1e-3, 1e-3),
34 | ],
35 | )
36 |
37 |
38 | def _test_correctness_once(
39 | target_kldiv,
40 | B,
41 | T,
42 | V,
43 | dtype,
44 | atol,
45 | rtol,
46 | reduction,
47 | log_target,
48 | is_last_layer=True,
49 | device=device,
50 | ):
51 | torch.manual_seed(0)
52 | torch_kldiv = KLDivLoss(reduction=reduction, log_target=log_target)
53 |
54 | input = torch.randn(
55 | B * T, V, device=device, dtype=dtype, requires_grad=True
56 | ).log_softmax(dim=-1)
57 |
58 | x1 = input.detach().clone().requires_grad_(True)
59 | x2 = input.detach().clone().requires_grad_(True)
60 |
61 | with torch.no_grad():
62 | target = torch.randn(B * T, V, device=device).softmax(dim=-1)
63 |
64 | output = torch_kldiv(x1, target)
65 | output2 = target_kldiv(x2, target)
66 | assert torch.allclose(output, output2, atol=atol, rtol=rtol)
67 |
68 | if (
69 | not is_last_layer
70 |     ):  # when the loss is the last layer, grad_output is 1.0 and the mul op is skipped; scale here to also exercise the non-last-layer path
71 | output = output * 2.0
72 | output2 = output2 * 2.0
73 |
74 | if reduction == "none":
75 | return
76 |
77 | output.backward()
78 | output2.backward()
79 | assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol)
80 |
81 |
82 | @pytest.mark.parametrize(*_SHAPE_PARAMS)
83 | @pytest.mark.parametrize("log_target", [True, False])
84 | @pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"])
85 | @pytest.mark.parametrize(*_DTYPE_PARAMS)
86 | def test_correctness(B, T, V, log_target, reduction, dtype, atol, rtol):
87 | liger_kldiv = LigerKLDIVLoss(reduction=reduction, log_target=log_target)
88 | _test_correctness_once(
89 | liger_kldiv, B, T, V, dtype, atol, rtol, reduction, log_target
90 | )
91 |
92 |
93 | @pytest.mark.parametrize(*_SHAPE_PARAMS)
94 | @pytest.mark.parametrize("log_target", [True, False])
95 | @pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"])
96 | @pytest.mark.parametrize(*_DTYPE_PARAMS)
97 | def test_correctness_not_last(B, T, V, log_target, reduction, dtype, atol, rtol):
98 | liger_kldiv = LigerKLDIVLoss(reduction=reduction, log_target=log_target)
99 | _test_correctness_once(
100 | liger_kldiv,
101 | B,
102 | T,
103 | V,
104 | dtype,
105 | atol,
106 | rtol,
107 | reduction,
108 | log_target,
109 | is_last_layer=False,
110 | )
111 |
--------------------------------------------------------------------------------
/test/transformers/test_layer_norm.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from liger_kernel.ops.layer_norm import LigerLayerNormFunction
5 | from liger_kernel.transformers.functional import liger_layer_norm
6 | from liger_kernel.transformers.layer_norm import LigerLayerNorm
7 | from liger_kernel.utils import infer_device
8 |
9 | device = infer_device()
10 |
11 |
12 | @pytest.mark.parametrize(
13 | "batch_size, seq_len, hidden_size",
14 | [
15 | (2, 8, 64),
16 | (4, 16, 128),
17 | ],
18 | )
19 | @pytest.mark.parametrize(
20 | "dtype, atol, rtol",
21 | [
22 | (torch.float32, 1e-5, 1e-5),
23 | ],
24 | )
25 | def test_liger_layer_norm(batch_size, seq_len, hidden_size, dtype, atol, rtol):
26 | torch.manual_seed(0)
27 |
28 | x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device)
29 |
30 | liger_x = x.clone().requires_grad_(True)
31 | torch_x = x.clone().requires_grad_(True)
32 |
33 | liger_ln = LigerLayerNorm(hidden_size, eps=1e-6).to(dtype).to(device)
34 | torch_ln = torch.nn.LayerNorm(hidden_size, eps=1e-6).to(dtype).to(device)
35 |
36 | with torch.no_grad():
37 | torch_ln.weight.copy_(liger_ln.weight)
38 | torch_ln.bias.copy_(liger_ln.bias)
39 |
40 | liger_output = liger_ln(liger_x)
41 | torch_output = torch_ln(torch_x)
42 |
43 | assert torch.allclose(liger_output, torch_output, atol=atol, rtol=rtol)
44 |
45 | grad_output = torch.randn_like(x)
46 | liger_output.backward(grad_output, retain_graph=True)
47 | torch_output.backward(grad_output, retain_graph=True)
48 |
49 | assert torch.allclose(liger_x.grad, torch_x.grad, atol=atol, rtol=rtol)
50 | assert torch.allclose(
51 | liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol
52 | )
53 | assert torch.allclose(liger_ln.bias.grad, torch_ln.bias.grad, atol=atol, rtol=rtol)
54 |
55 |
56 | @pytest.mark.parametrize(
57 | "batch_size, seq_len, hidden_size",
58 | [
59 | (2, 8, 64),
60 | (4, 16, 128),
61 | ],
62 | )
63 | @pytest.mark.parametrize(
64 | "dtype, atol, rtol",
65 | [
66 | (torch.float32, 1e-5, 1e-5),
67 | ],
68 | )
69 | def test_liger_layer_norm_functional(
70 | hidden_size, batch_size, seq_len, dtype, atol, rtol
71 | ):
72 | torch.manual_seed(0)
73 |
74 | input = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device)
75 |
76 | x1 = input.clone().requires_grad_(True)
77 | x2 = input.clone().requires_grad_(True)
78 |
79 | w = torch.randn(hidden_size, device=device, dtype=dtype)
80 |
81 | w1 = w.clone().requires_grad_(True)
82 | w2 = w.clone().requires_grad_(True)
83 |
84 | b = torch.randn(hidden_size, device=device, dtype=dtype)
85 |
86 | b1 = b.clone().requires_grad_(True)
87 | b2 = b.clone().requires_grad_(True)
88 |
89 | y1 = liger_layer_norm(X=x1, W=w1, B=b1, eps=1e-6)
90 | y2 = LigerLayerNormFunction.apply(x2, w2, b2, 1e-6)
91 |
92 | assert torch.allclose(y1, y2, atol=atol, rtol=rtol)
93 |
94 | grad_output = torch.randn_like(y2)
95 |
96 | y1.backward(grad_output, retain_graph=True)
97 | y2.backward(grad_output, retain_graph=True)
98 |
99 | assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol)
100 | assert torch.allclose(w1.grad, w2.grad, atol=atol, rtol=rtol)
101 | assert torch.allclose(b1.grad, b2.grad, atol=atol, rtol=rtol)
102 |
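As a quick orientation beyond the pytest plumbing above, here is a minimal usage sketch built only from the constructor and call signatures exercised in this file (device resolution via infer_device as above; shapes are arbitrary):

import torch
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.utils import infer_device

device = infer_device()
hidden_size = 64
ln = LigerLayerNorm(hidden_size, eps=1e-6).to(device)                  # same constructor as in the test
x = torch.randn(2, 8, hidden_size, device=device, requires_grad=True)
y = ln(x)                                                              # matches torch.nn.LayerNorm(hidden_size, eps=1e-6) within tolerance
y.sum().backward()                                                     # populates x.grad, ln.weight.grad, ln.bias.grad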
--------------------------------------------------------------------------------
/test/transformers/test_mm_int8int2.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from liger_kernel.ops.experimental.mm_int8int2 import (
5 | matmul,
6 | pack_weights,
7 | unpack_weights,
8 | )
9 | from liger_kernel.utils import infer_device
10 |
11 | device = infer_device()
12 |
13 |
14 | # input_features = size*4 when the weight matrix is unpacked
15 | @pytest.mark.skip(reason="mm_int8int2 is under experimentation")
16 | @pytest.mark.parametrize(
17 | "size",
18 | [
19 | 2048,
20 | 1024,
21 | 512,
22 | ],
23 | )
24 | @pytest.mark.parametrize(
25 | "batch_size",
26 | [1, 2, 3, 8],
27 | )
28 | @pytest.mark.parametrize(
29 | "seq_len",
30 | [1, 7, 16, 2048],
31 | )
32 | @pytest.mark.parametrize(
33 | "out_features",
34 | [
35 | 1024,
36 | 2048,
37 | 4096,
38 | 10000,
39 | ],
40 | )
41 | @pytest.mark.parametrize(
42 | "atol, rtol, device",
43 | [
44 | (1e-2, 1e-2, device),
45 | ],
46 | )
47 | def test_kernel_correctness(
48 | batch_size, seq_len, out_features, size, atol, rtol, device
49 | ):
50 | print(f"\nTesting kernel with size: {size}, atol: {atol}, rtol: {rtol}")
51 |
52 | # Generate the random tensors
53 | ht = torch.randint(
54 | -127, 127, (batch_size, seq_len, size * 4), device=device, dtype=torch.int8
55 | )
56 | u = torch.randint(0, 255, (out_features, size), device=device, dtype=torch.uint8)
57 |
58 | # Calculate dimensions
59 | B, M, N = ht.size()
60 |
61 | # Compute triton output
62 | triton_output = matmul(ht.view(B * M, N), u.T.contiguous()).view(B, M, -1)
63 |
64 | # Unpack weights and compute torch output
65 | unpacked = unpack_weights(u.T, bits=2).T
66 | torch_output = torch.matmul(
67 | ht.to(torch.float32), unpacked.T.contiguous().to(torch.float32)
68 | )
69 |
70 | # Print the results (optional, can be commented out)
71 | print("triton_output =", triton_output)
72 | print("torch_output =", torch_output)
73 |
74 | # Check if outputs are close within the given tolerances
75 | assert torch.allclose(
76 | triton_output, torch_output.to(torch.int32), atol=atol, rtol=rtol
77 | ), "Results differ"
78 |
79 |
80 | @pytest.mark.skip(reason="mm_int8int2 is under experimentation")
81 | @pytest.mark.parametrize(
82 | "size",
83 | [
84 | 2048,
85 | 1024,
86 | 512,
87 | ],
88 | )
89 | @pytest.mark.parametrize(
90 | "out_features",
91 | [
92 | 1024,
93 | 2048,
94 | 4096,
95 | 10000,
96 | ],
97 | )
98 | @pytest.mark.parametrize(
99 | "device",
100 | [
101 | device,
102 | ],
103 | )
104 | def test_unpack_pack_correctness(out_features, size, device):
105 | u = torch.randint(0, 255, (out_features, size), device=device, dtype=torch.uint8)
106 |
107 | assert (
108 | pack_weights(unpack_weights(u.T), 2) == u.T
109 | ).all(), "Packed weights do not match original weights."
110 |
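The comment at the top of this file ("input_features = size*4 when the weight matrix is unpacked") follows from each uint8 byte storing four 2-bit values. A minimal plain-PyTorch sketch of that packing arithmetic; the actual bit layout used by pack_weights/unpack_weights may differ, this only illustrates the 4x relationship:

import torch

packed = torch.tensor([0b11_01_10_00], dtype=torch.uint8)          # one byte, four 2-bit fields
shifts = torch.tensor([0, 2, 4, 6], dtype=torch.uint8)
unpacked = (packed.unsqueeze(-1) >> shifts) & 0b11                  # -> tensor([[0, 2, 1, 3]])
repacked = (unpacked.to(torch.int64) << shifts).sum(dim=-1).to(torch.uint8)
assert torch.equal(repacked, packed)                                # lossless round-trip: 1 byte <-> 4 weights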
--------------------------------------------------------------------------------
/test/transformers/test_qwen2vl_mrope.py:
--------------------------------------------------------------------------------
1 | from test.utils import supports_bfloat16
2 |
3 | import pytest
4 | import torch
5 |
6 | try:
7 | from transformers.models.qwen2_vl.modeling_qwen2_vl import (
8 | Qwen2VLRotaryEmbedding,
9 | apply_multimodal_rotary_pos_emb,
10 | )
11 |
12 | IS_QWEN_AVAILABLE = True
13 | except Exception:
14 | IS_QWEN_AVAILABLE = False
15 |
16 | from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
17 | from liger_kernel.transformers.functional import liger_qwen2vl_mrope
18 | from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
19 | from liger_kernel.utils import infer_device
20 |
21 | device = infer_device()
22 |
23 |
24 | @pytest.mark.skipif(
25 | not IS_QWEN_AVAILABLE, reason="Qwen is not available in transformers."
26 | )
27 | @pytest.mark.parametrize("bsz", [1, 2])
28 | @pytest.mark.parametrize("seq_len", [128, 131])
29 | @pytest.mark.parametrize("num_q_heads, num_kv_heads", [(64, 8), (28, 4), (12, 2)])
30 | @pytest.mark.parametrize(
31 | "head_dim, mrope_section",
32 | [
33 | (128, [16, 24, 24]),
34 | (96, [16, 16, 16]),
35 | (64, [8, 12, 12]),
36 | ],
37 | )
38 | @pytest.mark.parametrize(
39 | "dtype, atol, rtol",
40 | [
41 | (torch.float32, 1e-5, 1e-5),
42 | pytest.param(
43 | torch.bfloat16,
44 | 1e-1,
45 | 1e-5,
46 | marks=pytest.mark.skipif(
47 | not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
48 | ),
49 | ),
50 | ],
51 | )
52 | def test_correctness(
53 | bsz, seq_len, num_q_heads, num_kv_heads, head_dim, mrope_section, dtype, atol, rtol
54 | ):
55 | rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device)
56 |
57 | _tensor_q = (
58 | torch.randn((bsz, seq_len, num_q_heads, head_dim), device=device)
59 | .transpose(1, 2)
60 | .to(dtype)
61 | )
62 |
63 | _tensor_k = (
64 | torch.randn((bsz, seq_len, num_kv_heads, head_dim), device=device)
65 | .transpose(1, 2)
66 | .to(dtype)
67 | )
68 |
69 | q1 = _tensor_q.clone().requires_grad_(True)
70 | k1 = _tensor_k.clone().requires_grad_(True)
71 |
72 | q2 = _tensor_q.clone().requires_grad_(True)
73 | k2 = _tensor_k.clone().requires_grad_(True)
74 |
 75 |     # NOTE: this position-ids distribution differs from the real one; it is only meant to exercise op correctness
76 | pos_ids = torch.arange(seq_len * 3 * bsz, device=device, dtype=torch.long).view(
77 | 3, bsz, seq_len
78 | )
79 | cos, sin = rotary_emb(k1, pos_ids)
80 |
81 | # validate forward pass
82 | hf_q, hf_k = apply_multimodal_rotary_pos_emb(q1, k1, cos, sin, mrope_section)
83 | tt_q, tt_k = liger_multimodal_rotary_pos_emb(q2, k2, cos, sin, mrope_section)
84 | torch.testing.assert_close(hf_q, tt_q, atol=atol, rtol=rtol)
85 | torch.testing.assert_close(hf_k, tt_k, atol=atol, rtol=rtol)
86 |
87 | # validate backward pass
88 | dq, dk = (
89 | torch.randn_like(hf_q, device=device),
90 | torch.randn_like(hf_k, device=device).to(dtype),
91 | )
92 |
93 | q1_grad, k1_grad = torch.autograd.grad(
94 | (hf_q, hf_k), (q1, k1), (dq, dk), allow_unused=True
95 | )
96 | q2_grad, k2_grad = torch.autograd.grad(
97 | (tt_q, tt_k), (q2, k2), (dq.clone(), dk.clone()), allow_unused=True
98 | )
99 |
100 | torch.testing.assert_close(q1_grad, q2_grad, atol=atol, rtol=rtol)
101 | torch.testing.assert_close(k1_grad, k2_grad, atol=atol, rtol=rtol)
102 |
103 |
104 | @pytest.mark.skipif(
105 | not IS_QWEN_AVAILABLE, reason="Qwen is not available in transformers."
106 | )
107 | @pytest.mark.parametrize(
108 | "bsz, seq_len, num_q_heads, num_kv_heads, head_dim, mrope_section",
109 | [
110 | (1, 2, 2, 2, 8, [2, 1, 1]),
111 | (1, 2, 1, 2, 8, [2, 1, 1]),
112 | ],
113 | )
114 | @pytest.mark.parametrize(
115 | "dtype, atol, rtol",
116 | [
117 | (torch.float32, 1e-5, 1e-5),
118 | (torch.bfloat16, 1e-1, 1e-5),
119 | ],
120 | )
121 | def test_functional_correctness(
122 | bsz, seq_len, num_q_heads, num_kv_heads, head_dim, mrope_section, dtype, atol, rtol
123 | ):
124 | _q = torch.randn((bsz, num_q_heads, seq_len, head_dim), device=device, dtype=dtype)
125 | _k = torch.randn((bsz, num_kv_heads, seq_len, head_dim), device=device, dtype=dtype)
126 |
127 | q1 = _q.clone().requires_grad_(True)
128 | q2 = _q.clone().requires_grad_(True)
129 |
130 | k1 = _k.clone().requires_grad_(True)
131 | k2 = _k.clone().requires_grad_(True)
132 |
133 | rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device)
134 |
135 | pos_ids = torch.arange(seq_len * 3 * bsz, device=device, dtype=torch.long).view(
136 | 3, bsz, seq_len
137 | )
138 | cos, sin = rotary_emb(k1, pos_ids)
139 |
140 | functional_q, functional_k = liger_qwen2vl_mrope(q1, k1, cos, sin, mrope_section)
141 | class_q, class_k = LigerQwen2VLMRopeFunction.apply(q2, k2, cos, sin, mrope_section)
142 |
143 | torch.testing.assert_close(functional_q, class_q, atol=atol, rtol=rtol)
144 | torch.testing.assert_close(functional_k, class_k, atol=atol, rtol=rtol)
145 |
146 | dq, dk = torch.randn_like(functional_q), torch.randn_like(functional_k)
147 |
148 | dq1, dk1 = dq.clone(), dk.clone()
149 | dq2, dk2 = dq.clone(), dk.clone()
150 |
151 | q1_grad, k1_grad = torch.autograd.grad(
152 | (functional_q, functional_k),
153 | (q1, k1),
154 | (dq1, dk1),
155 | allow_unused=True,
156 | )
157 |
158 | q2_grad, k2_grad = torch.autograd.grad(
159 | (class_q, class_k),
160 | (q2, k2),
161 | (dq2, dk2),
162 | allow_unused=True,
163 | )
164 |
165 | torch.testing.assert_close(q1_grad, q2_grad, atol=atol, rtol=rtol)
166 | torch.testing.assert_close(k1_grad, k2_grad, atol=atol, rtol=rtol)
167 |
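One detail worth calling out in the parametrization above: every (head_dim, mrope_section) pair satisfies sum(mrope_section) == head_dim // 2, i.e. the three sections (conventionally the temporal/height/width position streams in Qwen2-VL) exactly partition the rotary half-dimension that cos/sin cover. A quick check over the pairs used in this test:

for head_dim, mrope_section in [(128, [16, 24, 24]), (96, [16, 16, 16]), (64, [8, 12, 12])]:
    assert sum(mrope_section) == head_dim // 2   # sections cover the rotary half-dim exactly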
--------------------------------------------------------------------------------
/test/transformers/test_rms_norm.py:
--------------------------------------------------------------------------------
1 | import os
2 | from test.utils import assert_verbose_allclose, set_seed, supports_bfloat16
3 |
4 | import pytest
5 | import torch
6 | import torch.nn as nn
7 |
8 | from liger_kernel.ops.rms_norm import LigerRMSNormFunction
9 | from liger_kernel.transformers.functional import liger_rms_norm
10 | from liger_kernel.transformers.rms_norm import LigerRMSNorm
11 | from liger_kernel.utils import infer_device
12 |
13 | device = infer_device()
14 |
15 | set_seed(42)
16 | torch.use_deterministic_algorithms(True)
17 |
18 | # Only setting torch.use_deterministic_algorithms(True) might throw the following error:
19 | # RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithms(true)`,
20 | # but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an
21 | # environment variable before running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8. For more information,
22 | # go to https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
23 |
24 | if device == "cuda":
25 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
26 |
27 | SLEEP_SECONDS = 0.1
28 |
29 |
30 | class BaseRMSNorm(nn.Module):
31 | def __init__(self, hidden_size, eps=1e-6):
32 | super().__init__()
33 | self.weight = nn.Parameter(torch.ones(hidden_size))
34 | self.variance_epsilon = eps
35 |
36 | def forward(self, hidden_states):
37 | input_dtype = hidden_states.dtype
38 | variance = hidden_states.pow(2).mean(-1, keepdim=True)
39 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
40 | return self.weight * hidden_states.to(input_dtype)
41 |
42 |
43 | # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L112
44 | class LlamaRMSNorm(nn.Module):
45 | def __init__(self, hidden_size, eps=1e-6):
46 | """
47 | LlamaRMSNorm is equivalent to T5LayerNorm
48 | """
49 | super().__init__()
50 | self.weight = nn.Parameter(torch.ones(hidden_size))
51 | self.variance_epsilon = eps
52 |
53 | def forward(self, hidden_states):
54 | input_dtype = hidden_states.dtype
55 | hidden_states = hidden_states.to(torch.float32)
56 | variance = hidden_states.pow(2).mean(-1, keepdim=True)
57 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
58 | return self.weight * hidden_states.to(input_dtype)
59 |
60 |
61 | # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L122
62 | class GemmaRMSNorm(nn.Module):
63 | def __init__(self, hidden_size: int, eps: float = 1e-6):
64 | super().__init__()
65 | self.eps = eps
66 | self.weight = nn.Parameter(torch.ones(hidden_size))
67 |
68 | def _norm(self, x):
69 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
70 |
71 | def forward(self, x):
72 | output = self._norm(x.float())
73 | output = output * (1.0 + self.weight.float())
74 | return output.type_as(x)
75 |
76 |
77 | @pytest.mark.flaky(reruns=3, reruns_delay=2)
78 | @pytest.mark.parametrize(
79 | "bs, sl, hd",
80 | [
81 | (2, 128, 512),
82 | # weird shapes
83 | (5, 123, 123),
84 | ],
85 | )
86 | @pytest.mark.parametrize(
87 | "dtype, atol, rtol",
88 | [
89 | (torch.float32, 1e-4, 1e-6),
90 | pytest.param(
91 | torch.bfloat16,
92 | 2e-1,
93 | 2e-2,
94 | marks=pytest.mark.skipif(
95 | not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
96 | ),
97 | ),
98 | ],
99 | )
100 | @pytest.mark.parametrize(
101 | "reference, offset, casting_mode",
102 | [
103 | (LlamaRMSNorm, 0.0, "llama"),
104 | (GemmaRMSNorm, 1.0, "gemma"),
105 | (BaseRMSNorm, 0.0, "none"),
106 | ],
107 | )
108 | @pytest.mark.parametrize(
109 | "in_place",
110 | [
111 | True,
112 | False,
113 | ],
114 | )
115 | def test_correctness(
116 | bs, sl, hd, dtype, atol, rtol, reference, offset, casting_mode, in_place
117 | ):
118 | _tensor = torch.randn(bs, sl, hd, device=device, dtype=dtype)
119 |
120 | h1 = _tensor.clone().requires_grad_(True)
121 | h2 = _tensor.clone().requires_grad_(True)
122 |
123 | # do
124 | do = torch.randn(bs, sl, hd, device=device, dtype=dtype)
125 |
126 |     # reference (llama, gemma, or base)
127 | ref_rms = reference(hidden_size=hd).to(device).to(dtype)
128 | ref_o = ref_rms(h1)
129 | ref_o.backward(do, retain_graph=True)
130 |
131 | # triton
132 | triton_rms = (
133 | LigerRMSNorm(
134 | hidden_size=hd, offset=offset, casting_mode=casting_mode, in_place=in_place
135 | )
136 | .to(device)
137 | .to(dtype)
138 | )
139 | triton_o = triton_rms(h2)
140 | triton_o.backward(do, retain_graph=True)
141 |
142 | assert_verbose_allclose(ref_o, triton_o, atol=atol, rtol=rtol)
143 | assert_verbose_allclose(
144 | ref_rms.weight.grad, triton_rms.weight.grad, atol=atol, rtol=rtol
145 | )
146 | print(f"{h1.grad=}")
147 | print(f"{h2.grad=}")
148 | assert_verbose_allclose(h1.grad, h2.grad, atol=atol, rtol=rtol, max_print=20)
149 |
150 |
151 | @pytest.mark.parametrize(
152 | "bs, sl, hd",
153 | [
154 | (2, 2, 8),
155 | # weird shapes
156 | (9, 7, 41),
157 | ],
158 | )
159 | @pytest.mark.parametrize(
160 | "dtype, atol, rtol",
161 | [
162 | (torch.float32, 1e-4, 1e-6),
163 | (torch.bfloat16, 2e-1, 2e-2),
164 | ],
165 | )
166 | @pytest.mark.parametrize(
167 | "reference, offset, casting_mode",
168 | [
169 | (LlamaRMSNorm, 0.0, "llama"),
170 | (GemmaRMSNorm, 1.0, "gemma"),
171 | ],
172 | )
173 | def test_correctness_functional(
174 | bs, sl, hd, dtype, atol, rtol, reference, offset, casting_mode
175 | ):
176 | # h
177 | _tensor = torch.randn(bs, sl, hd, device=device, dtype=dtype)
178 |
179 | h1 = _tensor.clone().requires_grad_(True)
180 | h2 = _tensor.clone().requires_grad_(True)
181 |
182 | w = torch.randn(hd, device=device, dtype=dtype)
183 |
184 | y1 = liger_rms_norm(X=h1, W=w, eps=1e-6, offset=offset, casting_mode=casting_mode)
185 | y2 = LigerRMSNormFunction.apply(h2, w, 1e-6, offset, casting_mode)
186 |
187 | assert torch.allclose(y1, y2, atol=atol, rtol=rtol)
188 |
189 | grad = torch.randn_like(y2)
190 |
191 | y1.backward(grad)
192 | y2.backward(grad)
193 |
194 | assert torch.allclose(h1.grad, h2.grad, atol=atol, rtol=rtol)
195 |
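The three reference modules above share the same core, x * rsqrt(mean(x^2, dim=-1) + eps), and differ only in casting and in how the learned weight is applied: Llama scales by weight (offset 0.0), while Gemma scales by (1 + weight), which is why its Liger counterpart is parametrized with offset=1.0. A small sketch of that unified view (plain PyTorch, float32 inputs so the casting differences vanish; the helper name is mine, not the library's):

import torch

def rms_norm_reference(x, weight, eps=1e-6, offset=0.0):
    # Shared core of the reference modules: RMS-normalize the last dim, then scale by (weight + offset).
    variance = x.pow(2).mean(-1, keepdim=True)
    return (weight + offset) * (x * torch.rsqrt(variance + eps))

x, w = torch.randn(2, 4, 8), torch.randn(8)
gemma_style = (1.0 + w) * (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6))
assert torch.allclose(rms_norm_reference(x, w, offset=1.0), gemma_style)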
--------------------------------------------------------------------------------
/test/transformers/test_rope.py:
--------------------------------------------------------------------------------
1 | from test.utils import supports_bfloat16
2 |
3 | import pytest
4 | import torch
5 | from transformers.models.llama.modeling_llama import (
6 | LlamaRotaryEmbedding,
7 | apply_rotary_pos_emb,
8 | )
9 |
10 | from liger_kernel.ops.rope import LigerRopeFunction
11 | from liger_kernel.transformers.functional import liger_rope
12 | from liger_kernel.transformers.rope import liger_rotary_pos_emb
13 | from liger_kernel.utils import infer_device
14 |
15 | device = infer_device()
16 |
17 | SLEEP_SECONDS = 0.1
18 |
19 |
20 | @pytest.mark.parametrize(
21 | "bsz, seq_len, num_q_heads, num_kv_heads, head_dim",
22 | [
23 | (1, 128, 32, 32, 64),
24 | (2, 128, 32, 32, 64),
25 | # different q/k heads
26 | (1, 128, 32, 8, 64),
27 | (2, 128, 32, 8, 64),
28 | # weird shapes
29 | # HuggingFace llama/mistral source code doesn't support odd head dimension
30 | # so we don't test it here
31 | (3, 423, 73, 213, 92),
32 | (3, 423, 73, 155, 92),
33 | ],
34 | )
35 | @pytest.mark.parametrize(
36 | "dtype, atol, rtol",
37 | [
38 | (torch.float32, 1e-5, 1e-5),
39 | pytest.param(
40 | torch.bfloat16,
41 | 1e-1,
42 | 1e-5,
43 | marks=pytest.mark.skipif(
44 | not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
45 | ),
46 | ),
47 | ],
48 | )
49 | def test_correctness(
50 | bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol
51 | ):
52 | rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
53 |
54 | _tensor_q = (
55 | torch.randn((bsz, seq_len, num_q_heads, head_dim), device=device)
56 | .transpose(1, 2)
57 | .to(dtype)
58 | )
59 |
60 | _tensor_k = (
61 | torch.randn((bsz, seq_len, num_kv_heads, head_dim), device=device)
62 | .transpose(1, 2)
63 | .to(dtype)
64 | )
65 |
66 | q1 = _tensor_q.clone().requires_grad_(True)
67 | k1 = _tensor_k.clone().requires_grad_(True)
68 |
69 | q2 = _tensor_q.clone().requires_grad_(True)
70 | k2 = _tensor_k.clone().requires_grad_(True)
71 |
72 | pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
73 | cos, sin = rotary_emb(k1, pos_ids)
74 |
75 | # validate forward pass
76 | hf_q, hf_k = apply_rotary_pos_emb(q1, k1, cos, sin, pos_ids)
77 | tt_q, tt_k = liger_rotary_pos_emb(q2, k2, cos, sin)
78 | assert torch.allclose(hf_q, tt_q, atol=atol, rtol=rtol)
79 | assert torch.allclose(hf_k, tt_k, atol=atol, rtol=rtol)
80 |
81 | # validate backward pass
82 | dq, dk = (
83 | torch.randn_like(hf_q, device=device),
84 | torch.randn_like(hf_k, device=device).to(dtype),
85 | )
86 |
87 | q1_grad, k1_grad = torch.autograd.grad(
88 | (hf_q, hf_k), (q1, k1), (dq, dk), allow_unused=True
89 | )
90 | q2_grad, k2_grad = torch.autograd.grad(
91 | (tt_q, tt_k), (q2, k2), (dq.clone(), dk.clone()), allow_unused=True
92 | )
93 |
94 | assert torch.allclose(q1_grad, q2_grad, atol=atol, rtol=rtol)
95 | assert torch.allclose(k1_grad, k2_grad, atol=atol, rtol=rtol)
96 |
97 |
98 | @pytest.mark.parametrize(
99 | "bsz, seq_len, num_q_heads, num_kv_heads, head_dim",
100 | [
101 | (1, 2, 2, 2, 8),
102 | (1, 2, 1, 2, 8),
103 | # weird shapes
104 | (9, 7, 41, 41, 41),
105 | ],
106 | )
107 | @pytest.mark.parametrize(
108 | "dtype, atol, rtol",
109 | [
110 | (torch.float32, 1e-5, 1e-5),
111 | (torch.bfloat16, 1e-1, 1e-5),
112 | ],
113 | )
114 | def test_functional_correctness(
115 | bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol
116 | ):
117 | _q = torch.randn((bsz, num_q_heads, seq_len, head_dim), device=device, dtype=dtype)
118 | _k = torch.randn((bsz, num_kv_heads, seq_len, head_dim), device=device, dtype=dtype)
119 |
120 | q1 = _q.clone().requires_grad_(True)
121 | q2 = _q.clone().requires_grad_(True)
122 |
123 | k1 = _k.clone().requires_grad_(True)
124 | k2 = _k.clone().requires_grad_(True)
125 |
126 | rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
127 |
128 | pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
129 | cos, sin = rotary_emb(k1, pos_ids)
130 |
131 | functional_q, functional_k = liger_rope(q=q1, k=k1, cos=cos, sin=sin)
132 | class_q, class_k = LigerRopeFunction.apply(q2, k2, cos, sin)
133 |
134 | assert torch.allclose(functional_q, class_q, atol=atol, rtol=rtol)
135 | assert torch.allclose(functional_k, class_k, atol=atol, rtol=rtol)
136 |
137 | dq, dk = torch.randn_like(functional_q), torch.randn_like(functional_k)
138 |
139 | dq1, dk1 = dq.clone(), dk.clone()
140 | dq2, dk2 = dq.clone(), dk.clone()
141 |
142 | q1_grad, k1_grad = torch.autograd.grad(
143 | (functional_q, functional_k),
144 | (q1, k1),
145 | (dq1, dk1),
146 | allow_unused=True,
147 | )
148 |
149 | q2_grad, k2_grad = torch.autograd.grad(
150 | (class_q, class_k),
151 | (q2, k2),
152 | (dq2, dk2),
153 | allow_unused=True,
154 | )
155 |
156 | assert torch.allclose(q1_grad, q2_grad, atol=atol, rtol=rtol)
157 | assert torch.allclose(k1_grad, k2_grad, atol=atol, rtol=rtol)
158 |
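For reference, the HF function used as ground truth here applies the standard rotate-half formulation, and both liger_rotary_pos_emb and LigerRopeFunction are validated against it. A minimal sketch of that formulation (it mirrors recent transformers versions of apply_rotary_pos_emb; broadcasting details can vary across versions, so treat this as an illustration rather than the exact reference):

import torch

def rotate_half(x):
    # Split the head dim in two and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_reference(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin of shape (bsz, seq_len, head_dim) are broadcast over the head axis.
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)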
--------------------------------------------------------------------------------
/test/transformers/test_swiglu.py:
--------------------------------------------------------------------------------
1 | from test.utils import supports_bfloat16
2 |
3 | import pytest
4 | import torch
5 | from transformers.models.llama.configuration_llama import LlamaConfig
6 | from transformers.models.llama.modeling_llama import LlamaMLP
7 | from transformers.models.phi3.configuration_phi3 import Phi3Config
8 | from transformers.models.phi3.modeling_phi3 import Phi3MLP
9 |
10 | from liger_kernel.ops.swiglu import LigerSiLUMulFunction
11 | from liger_kernel.transformers.functional import liger_swiglu
12 | from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP, LigerSwiGLUMLP
13 | from liger_kernel.utils import infer_device
14 |
15 | device = infer_device()
16 |
17 | LLAMA_CONFIG = LlamaConfig(
18 | hidden_size=4096,
19 | intermediate_size=11008,
20 | hidden_act="silu",
21 | )
22 | PHI3_CONFIG = Phi3Config(
23 | hidden_size=4096,
24 | intermediate_size=11008,
25 | hidden_act="silu",
26 | )
27 | SLEEP_SECONDS = 0.1
28 |
29 |
30 | @pytest.mark.parametrize(
31 | "bsz, seq_len, hidden_size, intermediate_size",
32 | [
33 | (2, 256, 256, 512),
34 | # weird shapes
35 | (6, 42, 123, 431),
36 | ],
37 | )
38 | @pytest.mark.parametrize(
39 | "dtype, atol, rtol",
40 | [
41 | # atol is for small values: they have more difference, so set atol higher
42 | # rtol is for larger values: they are very close, so set rtol lower
43 | (torch.float32, 1e-0, 1e-5),
 44 |         # TODO: we should find a better way to tune this; an atol of 1e4 is clearly too loose
45 | pytest.param(
46 | torch.bfloat16,
47 | 1e4,
48 | 1e-2,
49 | marks=pytest.mark.skipif(
50 | not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
51 | ),
52 | ),
53 | ],
54 | )
55 | def test_correctness_llamamlp(
56 | bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol
57 | ):
58 | _input = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype)
59 |
60 | x1 = _input.clone().requires_grad_(True)
61 | x2 = _input.clone().requires_grad_(True)
62 |
63 | # initialize weights
64 | G = torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype)
65 | U = torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype)
66 | D = torch.randn(intermediate_size, hidden_size, device=device, dtype=dtype)
67 |
68 | llama_mlp = LlamaMLP(config=LLAMA_CONFIG).to(device).to(dtype)
69 | llama_mlp.gate_proj.weight.data = G.T
70 | llama_mlp.up_proj.weight.data = U.T
71 | llama_mlp.down_proj.weight.data = D.T
72 |
73 | liger_mlp = LigerSwiGLUMLP(config=LLAMA_CONFIG).to(device).to(dtype)
74 | liger_mlp.gate_proj.weight.data = G.T
75 | liger_mlp.up_proj.weight.data = U.T
76 | liger_mlp.down_proj.weight.data = D.T
77 |
78 | y1 = llama_mlp(x1)
79 | y2 = liger_mlp(x2)
80 |
81 | assert torch.allclose(y1, y2, atol=atol, rtol=rtol)
82 |
83 | dy = torch.randn_like(y1)
84 |
85 | y1.backward(dy.clone(), retain_graph=True)
86 | y2.backward(dy.clone(), retain_graph=True)
87 |
88 | assert torch.allclose(
89 | llama_mlp.gate_proj.weight.grad,
90 | liger_mlp.gate_proj.weight.grad,
91 | atol=atol,
92 | rtol=rtol,
93 | )
94 | assert torch.allclose(
95 | llama_mlp.up_proj.weight.grad,
96 | liger_mlp.up_proj.weight.grad,
97 | atol=atol,
98 | rtol=rtol,
99 | )
100 | assert torch.allclose(
101 | llama_mlp.down_proj.weight.grad,
102 | liger_mlp.down_proj.weight.grad,
103 | atol=atol,
104 | rtol=rtol,
105 | )
106 |
107 | assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol)
108 |
109 |
110 | @pytest.mark.parametrize(
111 | "bsz, seq_len, hidden_size, intermediate_size",
112 | [
113 | (2, 256, 256, 512),
114 | # weird shapes
115 | (6, 42, 123, 431),
116 | ],
117 | )
118 | @pytest.mark.parametrize(
119 | "dtype, atol, rtol",
120 | [
121 | # atol is for small values: they have more difference, so set atol higher
122 | # rtol is for larger values: they are very close, so set rtol lower
123 | (torch.float32, 1e-0, 1e-5),
124 |         # TODO: we should find a better way to tune this; an atol of 1e4 is clearly too loose
125 | pytest.param(
126 | torch.bfloat16,
127 | 1e4,
128 | 1e-2,
129 | marks=pytest.mark.skipif(
130 | not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
131 | ),
132 | ),
133 | ],
134 | )
135 | def test_correctness_phi3mlp(
136 | bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol
137 | ):
138 | _input = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype)
139 |
140 | x1 = _input.clone().requires_grad_(True)
141 | x2 = _input.clone().requires_grad_(True)
142 |
143 | # initialize weights
144 | GU = torch.randn(hidden_size, intermediate_size * 2, device=device, dtype=dtype)
145 | D = torch.randn(intermediate_size, hidden_size, device=device, dtype=dtype)
146 |
147 | phi3_mlp = Phi3MLP(config=PHI3_CONFIG).to(device).to(dtype)
148 | phi3_mlp.gate_up_proj.weight.data = GU.T
149 | phi3_mlp.down_proj.weight.data = D.T
150 |
151 | liger_mlp = LigerPhi3SwiGLUMLP(config=PHI3_CONFIG).to(device).to(dtype)
152 | liger_mlp.gate_up_proj.weight.data = GU.T
153 | liger_mlp.down_proj.weight.data = D.T
154 |
155 | y1 = phi3_mlp(x1)
156 | y2 = liger_mlp(x2)
157 |
158 | assert torch.allclose(y1, y2, atol=atol, rtol=rtol)
159 |
160 | dy = torch.randn_like(y1)
161 |
162 | y1.backward(dy.clone(), retain_graph=True)
163 | y2.backward(dy.clone(), retain_graph=True)
164 |
165 | assert torch.allclose(
166 | phi3_mlp.gate_up_proj.weight.grad,
167 | liger_mlp.gate_up_proj.weight.grad,
168 | atol=atol,
169 | rtol=rtol,
170 | )
171 | assert torch.allclose(
172 | phi3_mlp.down_proj.weight.grad,
173 | liger_mlp.down_proj.weight.grad,
174 | atol=atol,
175 | rtol=rtol,
176 | )
177 |
178 | assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol)
179 |
180 |
181 | @pytest.mark.parametrize(
182 | "bsz, seq_len, size",
183 | [
184 | (2, 8, 8),
185 | (9, 7, 41),
186 | ],
187 | )
188 | @pytest.mark.parametrize(
189 | "dtype, atol, rtol",
190 | [
191 | # atol is for small values: they have more difference, so set atol higher
192 | # rtol is for larger values: they are very close, so set rtol lower
193 | (torch.float32, 1e-0, 1e-5),
194 |         # TODO: we should find a better way to tune this; an atol of 1e4 is clearly too loose
195 | (torch.bfloat16, 1e4, 1e-2),
196 | ],
197 | )
198 | def test_correctness_functional(bsz, seq_len, size, dtype, atol, rtol):
199 | _input = torch.randn(bsz, seq_len, size, device=device, dtype=dtype)
200 | _b = torch.randn(bsz, seq_len, size, device=device, dtype=dtype)
201 |
202 | x1 = _input.clone().requires_grad_(True)
203 | x2 = _input.clone().requires_grad_(True)
204 |
205 | b1 = _b.clone().requires_grad_(True)
206 | b2 = _b.clone().requires_grad_(True)
207 |
208 | y1 = liger_swiglu(a=x1, b=b1)
209 | y2 = LigerSiLUMulFunction.apply(x2, b2)
210 |
211 | assert torch.allclose(y1, y2, atol=atol, rtol=rtol)
212 |
213 | # Test backward pass
214 | grad_output = torch.randn_like(y1)
215 |
216 | y1.backward(grad_output)
217 | y2.backward(grad_output)
218 |
219 | # Check if gradients are close for x
220 | assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol)
221 | assert torch.allclose(b1.grad, b2.grad, atol=atol, rtol=rtol)
222 |
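The functional test above pits liger_swiglu against LigerSiLUMulFunction directly, while the MLP tests pin down what that op is meant to compute, since LlamaMLP evaluates down_proj(silu(gate_proj(x)) * up_proj(x)). A plain-PyTorch sketch of the element-wise core being fused (the helper name is mine, not the library's):

import torch
import torch.nn.functional as F

def swiglu_reference(a, b):
    # silu(a) * b -- the gating product that the fused kernel replaces.
    return F.silu(a) * b

a, b = torch.randn(2, 8, 8), torch.randn(2, 8, 8)
out = swiglu_reference(a, b)   # same shape as a and b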
--------------------------------------------------------------------------------
/test/transformers/test_trainer_integration.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 |
4 | def test_import():
5 | try:
6 | from liger_kernel.transformers.trainer_integration import ( # noqa: F401
7 | _apply_liger_kernel,
8 | )
9 | except Exception:
10 | pytest.fail("Import _apply_liger_kernel fails")
11 |
--------------------------------------------------------------------------------
/test/transformers/test_transformers.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 |
4 | def test_import_from_root():
5 | try:
6 | from liger_kernel.transformers import ( # noqa: F401
7 | LigerBlockSparseTop2MLP,
8 | LigerCrossEntropyLoss,
9 | LigerFusedLinearCrossEntropyLoss,
10 | LigerGEGLUMLP,
11 | LigerLayerNorm,
12 | LigerPhi3SwiGLUMLP,
13 | LigerRMSNorm,
14 | LigerSwiGLUMLP,
15 | liger_rotary_pos_emb,
16 | )
17 | except Exception:
18 | pytest.fail("Import kernels from root fails")
19 |
--------------------------------------------------------------------------------
/test/triton/test_triton_monkey_patch.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 |
4 | def test_import_from_root():
5 | try:
6 | from liger_kernel.triton import apply_liger_triton_cache_manager # noqa: F401
7 | except Exception:
8 | pytest.fail("Import kernel patch from root fails")
9 |
10 |
11 | def test_import_custom_cache_manager():
12 | from triton.runtime.cache import get_cache_manager
13 |
14 | from liger_kernel.triton import apply_liger_triton_cache_manager
15 |
16 | apply_liger_triton_cache_manager()
17 | cache_manager = get_cache_manager(key="test_hash")
18 | from liger_kernel.triton.monkey_patch import LigerTritonFileCacheManager
19 |
20 | assert isinstance(
21 | cache_manager, LigerTritonFileCacheManager
22 | ), "Cache manager should have been LigerTritonFileCacheManager"
23 |