"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Fused operators for activation layers."""
import logging
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_hip

logger = logging.getLogger(__name__)

try:
    if is_hip():
        raise ImportError("FlashInfer is not supported on AMD GPUs.")
    from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
except ImportError as e:
    logger.warning(
        f"FlashInfer is not available. Falling back to other kernel libraries. Message: {e}"
    )
    from vllm._custom_ops import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul


class SiluAndMul(CustomOp):
    """SiLU-gated linear unit: splits the last dimension in half and computes
    silu(x[..., :d]) * x[..., d:]."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        silu_and_mul(x, out)
        return out
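

# Illustrative usage of SiluAndMul (added comment, not in the original gist):
# the input is the gate and up projections concatenated on the last dimension,
# so a [..., 2 * d] input yields a [..., d] output.
#
#   act = SiluAndMul()
#   x = torch.randn(2, 2 * 11008)      # e.g. a LLaMA-style gate_up_proj output
#   y = act.forward_native(x)          # -> shape [2, 11008]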


class GeluAndMul(CustomOp):
    """GELU-gated linear unit: splits the last dimension in half and computes
    gelu(x[..., :d]) * x[..., d:] with either exact or tanh-approximate GELU."""

    def __init__(self, approximate="tanh"):
        super().__init__()
        self.approximate = approximate

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if self.approximate == "tanh":
            gelu_tanh_and_mul(x, out)
        elif self.approximate == "none":
            gelu_and_mul(x, out)
        else:
            raise RuntimeError("GeluAndMul only supports approximate='tanh' or 'none'.")
        return out
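

# Illustrative usage of GeluAndMul (added comment, not in the original gist):
# approximate="tanh" matches the registry's "gelu_pytorch_tanh" variant, while
# approximate="none" uses exact GELU. Both expect [..., 2 * d] and return [..., d].
#
#   act = GeluAndMul(approximate="none")
#   y = act.forward_native(torch.randn(2, 2 * 4096))   # -> shape [2, 4096]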


class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size, tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype)
        )
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
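

# Sketch of the tensor-parallel sharding done in ScaledActivation.weight_loader
# (assuming input_is_parallel=True): TP rank r keeps the contiguous slice
# [r * shard_size, (r + 1) * shard_size) of the full scales vector, e.g.
#
#   full_scales = torch.arange(8.0)                  # hypothetical, tp_size = 4
#   rank1_shard = full_scales.narrow(0, 1 * 2, 2)    # tensor([2., 3.])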


_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU(),
    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
}


def get_act_fn(
    act_fn_name: str,
    quant_config: Optional[QuantizationConfig] = None,
    intermediate_size: Optional[int] = None,
    input_is_parallel: bool = True,
    params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
    """Get an activation function by name."""
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(f"Activation function {act_fn_name!r} is not supported.")
    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
    if quant_config is not None and act_fn_name in quant_config.get_scaled_act_names():
        if intermediate_size is None:
            raise ValueError(
                "intermediate_size must be specified for scaled "
                "activation functions."
            )
        return ScaledActivation(
            act_fn, intermediate_size, input_is_parallel, params_dtype
        )
    return act_fn
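

# Minimal smoke test added for illustration (not part of the original gist); it
# exercises only the pure-PyTorch paths, so it runs on CPU without the fused
# CUDA kernels.
if __name__ == "__main__":
    x = torch.randn(2, 8)
    print(SiluAndMul().forward_native(x).shape)                   # torch.Size([2, 4])
    print(GeluAndMul(approximate="none").forward_native(x).shape)  # torch.Size([2, 4])
    print(get_act_fn("gelu_pytorch_tanh")(torch.randn(2, 4)).shape)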