Last active
March 19, 2025 22:41
-
-
Save HDCharles/15813ce9105b24b75d8fcfc9fe595d59 to your computer and use it in GitHub Desktop.
This shows a case where a mixture-of-experts (MoE) layer doesn't work with `torch.compile`.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch import Tensor, nn

# Seed torch's RNG so the randomly initialised weights and inputs are the same on every run.
torch.manual_seed(0)

# Shape legend used in the tensor comments below:
#   T      number of tokens
#   E      number of experts
#   D      token (model) dimension
#   I      expert intermediate dimension
#   A      number of activated experts per token
#   T'(e)  number of tokens routed to expert e
class MOEFeedForward(nn.Module):
    """Top-k gated mixture-of-experts feed-forward layer.

    Every token is scored against all experts by ``gate``. The
    ``num_activated_experts`` most probable experts are kept, their
    probabilities are renormalised to sum to one, and the token is handed to
    ``cond_ffn`` together with those expert indices and mixing weights.
    """

    def __init__(self) -> None:
        super().__init__()
        self.gate = nn.Linear(4, 8, bias=False)  # D -> E routing logits
        self.cond_ffn = ConditionalFeedForward()
        self.dim = 4
        self.num_activated_experts = 2

    def forward(self, x: Tensor) -> Tensor:
        """Route each ``dim``-sized vector of ``x`` through its top experts.

        ``x`` is flattened to ``[T, D]`` with ``view``. The expert output is
        reshaped back to ``[x.shape[0], -1, D]``.
        """
        leading_dim = x.shape[0]
        tokens = x.view(-1, self.dim)  # [T, D]
        routing_probs = F.softmax(self.gate(tokens), dim=-1)  # [T, E]
        top_probs, top_experts = torch.topk(
            routing_probs, self.num_activated_experts, dim=-1
        )  # [T, A], [T, A]
        # Renormalise so the kept experts' weights sum to one for each token.
        mixing = top_probs / top_probs.sum(dim=-1, keepdim=True)  # [T, A]
        mixed = self.cond_ffn(tokens, top_experts, mixing, self.num_activated_experts)
        return mixed.reshape(leading_dim, -1, self.dim)
class ConditionalFeedForward(nn.Module):
    """Bank of SwiGLU expert FFNs, evaluated only for each token's selected experts.

    Expert weights are stacked along a leading expert axis: ``w1`` and ``w3``
    are ``[E, I, D]`` and ``w2`` is ``[E, D, I]``. Expert ``e`` maps a token
    ``x`` to ``linear(silu(linear(x, w1[e])) * linear(x, w3[e]), w2[e])``.

    Args:
        num_experts: E, number of experts.
        dim: D, token (input and output) dimension.
        intermediate_dim: I, hidden dimension of every expert.
        device: device the weights are created on.
        dtype: dtype of the weights.

    The defaults reproduce the original hard-coded configuration
    (8 experts, D=4, I=16, CUDA, bfloat16).
    """

    def __init__(
        self,
        num_experts: int = 8,
        dim: int = 4,
        intermediate_dim: int = 16,
        device="cuda",
        dtype: torch.dtype = torch.bfloat16,
    ) -> None:
        super().__init__()
        factory = {"device": device, "dtype": dtype}
        # Same creation order and shapes as the original (w1, w2, w3), so seeded
        # runs with the default arguments draw identical weights.
        self.w1 = nn.Parameter(torch.randn((num_experts, intermediate_dim, dim), **factory))  # E, I, D
        self.w2 = nn.Parameter(torch.randn((num_experts, dim, intermediate_dim), **factory))  # E, D, I
        self.w3 = nn.Parameter(torch.randn((num_experts, intermediate_dim, dim), **factory))  # E, I, D
        self.num_experts = num_experts

    def forward(
        self,
        x: Tensor,  # T, D
        expert_indices: Tensor,  # T, A
        expert_weights: Tensor,  # T, A
        num_activated_experts: int,
    ) -> Tensor:
        """Return the expert-weighted mixture for every token, shape ``[T, D]``.

        Args:
            x: token vectors, ``[T, D]``.
            expert_indices: indices of each token's selected experts, ``[T, A]``.
            expert_weights: mixing weight of each selected expert, ``[T, A]``;
                expected to have the same dtype as ``x``.
            num_activated_experts: A, used by the single-token loop.
        """
        if x.shape[0] == 1:
            return self._forward_single_token(
                x, expert_indices, expert_weights, num_activated_experts
            )
        return self._forward_batched(x, expert_indices, expert_weights)

    def _forward_single_token(
        self,
        x: Tensor,
        expert_indices: Tensor,
        expert_weights: Tensor,
        num_activated_experts: int,
    ) -> Tensor:
        """Decode path (T == 1): evaluate the A selected experts one at a time."""
        # reshape(-1) rather than squeeze(): squeeze() collapsed a single
        # selected expert (A == 1) to a 0-d index, which broke the gather below.
        selected = expert_indices.reshape(-1)  # [A]
        w1 = self.w1[selected]  # [A, I, D]
        w2 = self.w2[selected]  # [A, D, I]
        w3 = self.w3[selected]  # [A, I, D]
        outs = [
            F.linear(F.silu(F.linear(x, w1[index])) * F.linear(x, w3[index]), w2[index])
            for index in range(num_activated_experts)
        ]  # each [1, D]
        mixed_outs = torch.cat(outs, dim=0)  # [A, D]
        # Weighted sum over experts, returned as [1, D] == [T, D]
        # (previously unsqueeze(-1) produced [D, 1]).
        return (mixed_outs * expert_weights.view(-1, 1)).sum(dim=0).unsqueeze(0)

    def _forward_batched(
        self, x: Tensor, expert_indices: Tensor, expert_weights: Tensor
    ) -> Tensor:
        """General path (any T): gather each token's expert weights and contract with einsum."""
        # The gathers materialise per-token copies of the expert weights, so
        # memory grows with T * A.
        w1 = self.w1[expert_indices]  # [T, A, I, D]
        w2 = self.w2[expert_indices]  # [T, A, D, I]
        w3 = self.w3[expert_indices]  # [T, A, I, D]
        hidden = F.silu(torch.einsum("td,taid->tai", x, w1)) * torch.einsum(
            "td,taid->tai", x, w3
        )  # [T, A, I]
        expert_outs = torch.einsum("tai,tadi->tad", hidden, w2)  # [T, A, D]
        return torch.einsum("tad,ta->td", expert_outs, expert_weights)  # [T, D]
# Repro: run the MoE layer eagerly and under torch.compile on the same inputs and
# print the output sums so the two can be compared. Everything is bfloat16 on CUDA.
moe = MOEFeedForward().to("cuda").to(torch.bfloat16)
input1 = torch.randn(1, 1, 4).to("cuda").to(torch.bfloat16)
input2 = torch.randn(1, 1, 4).to("cuda").to(torch.bfloat16)
# NOTE(review): this setting governs float32 matmuls; every tensor here is
# bfloat16, so it presumably has no effect on the results below.
torch.set_float32_matmul_precision("highest")

# Both the eager reference and the compiled runs happen under no_grad, so the two
# paths are compared under the same autograd mode. The parameters require grad,
# so outside no_grad the compiled graph would also be built to support backward.
with torch.no_grad():
    out1 = moe(input1)
    print(out1.sum())
    out2 = moe(input2)
    print(out2.sum())

    # Without fullgraph=True this works:
    #   moe_c = torch.compile(moe, mode="reduce-overhead")
    moe_c = torch.compile(moe, mode="reduce-overhead", fullgraph=True)  # this fails on token shuffle part
    # Warm-up calls before the measured ones; "reduce-overhead" uses CUDA graphs,
    # which are recorded after warm-up runs.
    moe_c(input1)
    moe_c(input2)
    # With CUDA graphs a later call may reuse an earlier call's output buffer, so
    # each result is consumed (printed) before the next call.
    out1c = moe_c(input1)
    print(out1c.sum())
    out2c = moe_c(input2)
    print(out2c.sum())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
tensor(2.2656, device='cuda:0', dtype=torch.bfloat16)
tensor(22.5000, device='cuda:0', dtype=torch.bfloat16)
tensor(2.2500, device='cuda:0', dtype=torch.bfloat16)
tensor(22.3750, device='cuda:0', dtype=torch.bfloat16)
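(The eager outputs, the first two lines, should match the compiled outputs, the last two lines. Each printed pair differs by exactly one bfloat16 ulp: 1/64 near 2.25 and 1/8 near 22.4.)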