Skip to content

Instantly share code, notes, and snippets.

@HDCharles
Last active March 19, 2025 22:41
Show Gist options
  • Save HDCharles/15813ce9105b24b75d8fcfc9fe595d59 to your computer and use it in GitHub Desktop.
Save HDCharles/15813ce9105b24b75d8fcfc9fe595d59 to your computer and use it in GitHub Desktop.
This shows a case where a mixture-of-experts (MoE) layer doesn't work correctly with torch.compile.
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from dataclasses import dataclass
# Seed the RNG so the torch.randn-initialized weights and inputs in this
# script are reproducible from run to run.
torch.manual_seed(0)
# Legend for the shape annotations used in the comments throughout this file:
# T tokens
# E experts
# D dim
# I intermediate dim
# A activated experts
# T'(e) tokens for expert e
class MOEFeedForward(nn.Module):
    """Mixture-of-experts feed-forward layer with top-k routing.

    A bias-free linear gate scores every token against 8 experts, the 2
    highest-probability experts are selected per token, their probabilities
    are renormalized to sum to 1, and the token is handed to
    ``ConditionalFeedForward`` together with the routing decision.
    """

    def __init__(self) -> None:
        super().__init__()
        # Router: maps a D=4 token to E=8 expert logits.
        self.gate = nn.Linear(4, 8, bias=False)
        self.cond_ffn = ConditionalFeedForward()
        self.dim = 4
        self.num_activated_experts = 2

    def forward(self, x: Tensor) -> Tensor:
        """Route ``x`` through its top-k experts.

        Args:
            x: input whose trailing dimension is ``self.dim``; it is
                flattened to ``[T, D]`` before routing.

        Returns:
            The expert output reshaped to ``[x.shape[0], -1, self.dim]``.
        """
        leading = x.shape[0]
        tokens = x.view(-1, self.dim)  # [T, D]

        # Softmax over all experts first, then keep only the top-A entries.
        probs = F.softmax(self.gate(tokens), dim=-1)  # [T, E]
        top_weights, top_indices = torch.topk(
            probs, self.num_activated_experts, dim=-1
        )  # [T, A], [T, A]

        # Renormalize in place so the kept weights sum to 1 per token.
        top_weights.div_(top_weights.sum(dim=-1, keepdim=True))  # [T, A]

        mixed = self.cond_ffn(
            tokens, top_indices, top_weights, self.num_activated_experts
        )
        return mixed.reshape(leading, -1, self.dim)
class ConditionalFeedForward(nn.Module):
    """Bank of 8 gated (SiLU) feed-forward experts applied to routed tokens.

    Each expert ``e`` computes ``w2[e] @ (silu(w1[e] @ x) * (w3[e] @ x))``.
    Only the single-token path (``T == 1``) is implemented: the selected
    experts are evaluated one by one in a Python loop and combined using the
    routing weights.
    """

    # Not read anywhere in the visible code. Kept as a class attribute so any
    # external reference to it keeps working (the original indentation was
    # lost, so its intended scope could not be confirmed).
    algorithm = "forloop"

    def __init__(self):
        super().__init__()
        # Expert weights are created directly on the GPU in bfloat16.
        self.w1 = nn.Parameter(
            torch.randn((8, 16, 4), device="cuda", dtype=torch.bfloat16)
        )  # E, I, D
        self.w2 = nn.Parameter(
            torch.randn((8, 4, 16), device="cuda", dtype=torch.bfloat16)
        )  # E, D, I
        self.w3 = nn.Parameter(
            torch.randn((8, 16, 4), device="cuda", dtype=torch.bfloat16)
        )  # E, I, D
        self.num_experts = 8

    def forward(
        self,
        x: Tensor,  # T, D
        expert_indices: Tensor,  # T, A
        expert_weights: Tensor,  # T, A
        num_activated_experts: int,
    ) -> Tensor:
        """Evaluate the selected experts for a single token and mix them.

        Args:
            x: token activations, ``[T, D]`` with ``T == 1``.
            expert_indices: indices of the selected experts, ``[T, A]``.
            expert_weights: mixing weight of each selected expert, ``[T, A]``.
            num_activated_experts: number of experts ``A`` to evaluate.

        Returns:
            The weighted sum of the expert outputs with shape ``[D, 1]``
            (the sum over experts yields ``[D]``, then a trailing axis is
            appended).

        Raises:
            NotImplementedError: if ``x`` holds more than one token.
        """
        if x.shape[0] != 1:
            raise NotImplementedError("not implemented")

        # Flatten [1, A] -> [A]. A bare squeeze() would also drop the expert
        # axis when A == 1, leaving a 0-d index so that the gathers below
        # would lose their leading dimension.
        flat_indices = expert_indices.reshape(-1)
        w1 = self.w1[flat_indices]  # A, I, D
        w2 = self.w2[flat_indices]  # A, D, I
        w3 = self.w3[flat_indices]  # A, I, D

        # One gated feed-forward evaluation per selected expert, each [1, D].
        outs = [
            F.linear(
                F.silu(F.linear(x, w1[index])) * F.linear(x, w3[index]),
                w2[index],
            )
            for index in range(num_activated_experts)
        ]
        mixed_outs = torch.cat(outs, dim=0)  # A, D
        final_out = (
            (mixed_outs * expert_weights.view(-1, 1)).sum(dim=0).unsqueeze(-1)
        )
        return final_out
# Build the layer and two independent single-token inputs ([1, 1, 4]) on the
# GPU in bfloat16. The inputs are sampled on the CPU first, then moved/cast.
moe = MOEFeedForward().to("cuda").to(torch.bfloat16)
input1 = torch.randn(1, 1, 4).to("cuda").to(torch.bfloat16)
input2 = torch.randn(1, 1, 4).to("cuda").to(torch.bfloat16)
# Select the "highest" float32 matmul precision setting.
torch.set_float32_matmul_precision("highest")
with torch.no_grad():
    # Eager-mode reference results.
    out1 = moe(input1)
    print(out1.sum())
    out2 = moe(input2)
    print(out2.sum())
    # Compiling without fullgraph=True works:
    # moe_c = torch.compile(moe, mode="reduce-overhead") # working
    # Compiling with fullgraph=True fails on the token shuffle part:
    moe_c = torch.compile(moe, mode="reduce-overhead", fullgraph=True) #this fails on token shuffle part
    # The results of these first two calls are discarded -- presumably
    # warm-up runs so compilation happens before the values that get printed
    # (TODO confirm).
    moe_c(input1)
    moe_c(input2)
    # Compiled results; their sums are expected to equal the eager sums
    # printed above.
    out1c = moe_c(input1)
    print(out1c.sum())
    out2c = moe_c(input2)
    print(out2c.sum())
@HDCharles
Copy link
Author

tensor(2.2656, device='cuda:0', dtype=torch.bfloat16)
tensor(22.5000, device='cuda:0', dtype=torch.bfloat16)
tensor(2.2500, device='cuda:0', dtype=torch.bfloat16)
tensor(22.3750, device='cuda:0', dtype=torch.bfloat16)

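(The first two lines are the eager outputs and the last two are the compiled outputs; they should match, but they don't.)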

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment