This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import torch | |
import triton | |
import triton.language as tl | |
@triton.jit | |
def scaled_dot_kernel( | |
# Pointers to matrices | |
a_ptr, b_ptr, output_ptr, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import triton | |
import triton.language as tl | |
@triton.jit | |
def mxfp_matmul( | |
a_ptr, b_ptr, output_ptr, | |
a_scale, b_scale, | |
M, N, K, | |
stride_scale: tl.constexpr, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
os.environ["TORCH_COMPILE_DISABLE"] = "1" | |
os.environ["TRITON_ALWAYS_COMPILE"] = "1" | |
import argparse, torch, triton, triton.language as tl | |
from triton.tools.mxfp import MXFP4Tensor | |
import time | |
def scaleDot_ref(A, B, sA_grouped, sB_grouped, GROUP_K: int): | |
sA = 2 ** (sA_grouped.float() - 127.0) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
os.environ["TRITON_CACHE_DIR"] = "./cache" | |
os.environ["TRITON_DUMP_DIR"] = "./cache" | |
import torch, triton, triton.language as tl, os, statistics as stats | |
# --------------------------- | |
# Fused 2-dot kernel | |
# Dot A: (M1,K1)x(K1,N1) -> C1 | |
# Dot B: (M2,K2)x(K2,N2) -> C2 (smaller) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// -----// IR Dump After Canonicalizer (canonicalize) //----- // | |
func.func @test_fusion(%arg0: tensor<32x16x256x256xf32>, %arg1: tensor<32xf32>, %arg2: tensor<32x16xf32>, %arg3: tensor<32x16xf32>) -> tensor<512x258x258xf32> { | |
%cst = arith.constant 1.000000e+00 : f32 | |
%cst_0 = arith.constant 0.000000e+00 : f32 | |
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<32x16x256x256xf32>, tensor<32xf32>, tensor<32x16xf32>, tensor<32x16xf32>) outs(%arg0 : tensor<32x16x256x256xf32>) { | |
^bb0(%in: f32, %in_1: f32, %in_2: f32, %in_3: f32, %out: f32): | |
%1 = arith.addf %in_1, %cst : f32 | |
%2 = math.rsqrt %1 : f32 | |
%3 = arith.mulf %in, %2 : f32 | |
%4 = arith.mulf %3, %in_2 : f32 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
" Don't try to be vi compatible | |
set nocompatible | |
" Helps force plugins to load correctly when it is turned back on below | |
filetype off | |
" TODO: Load plugins here (pathogen or vundle) | |
" Turn on syntax highlighting | |
syntax on |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
''' | |
Train CIFAR10 with PyTorch. | |
based on https://github.com/kuangliu/pytorch-cifar | |
''' | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import torch.nn.functional as F | |
import torch.backends.cudnn as cudnn |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class GATMHA(nn.Module): | |
def __init__(self, hidden_size, n_heads, lralpha=0.2): | |
super(GATMHA, self).__init__() | |
self.W = nn.Linear(hidden_size, hidden_size, bias=True) | |
self.Q = nn.Linear(hidden_size, hidden_size, bias=True) | |
self.a = nn.Parameter(torch.FloatTensor(n_heads, 2 * (hidden_size // n_heads))) | |
self.lralpha = lralpha | |
self.n_heads = n_heads | |
self.n_hidden_per_head = hidden_size // n_heads | |
nn.init.uniform_(self.W.weight, -1 / np.sqrt(hidden_size), 1 / np.sqrt(hidden_size)) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
NewerOlder