Skip to content

Instantly share code, notes, and snippets.

@antiagainst
Last active November 18, 2022 01:18
Show Gist options
  • Save antiagainst/146af8b1a25960a4ac788bacff267f9e to your computer and use it in GitHub Desktop.
Save antiagainst/146af8b1a25960a4ac788bacff267f9e to your computer and use it in GitHub Desktop.
// SPIR-V kernel module (MLIR spirv dialect) that uses SPV_NV_cooperative_matrix.
// Judging by the entry-point name and the 1280-wide strides below, this appears to be the
// GPU lowering of the linalg.matmul in @forward (512x1280 * 1280x1280, f16).
// NOTE(review): that correspondence is inferred from names/constants; confirm.
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Float16, CooperativeMatrixNV], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]> {
spirv.GlobalVariable @__builtin_var_WorkgroupId__ built_in("WorkgroupId") : !spirv.ptr<vector<3xi32>, Input>
// Workgroup-shared staging buffers of 256 vector<4xf32> each: every invocation of the
// 128x2x1 workgroup writes exactly one element per K step.
// __workgroup_mem__5 stages the RHS (second MulAdd operand).
spirv.GlobalVariable @__workgroup_mem__5 : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>
// __workgroup_mem__4 stages the LHS (first MulAdd operand).
spirv.GlobalVariable @__workgroup_mem__4 : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>
spirv.GlobalVariable @__builtin_var_LocalInvocationId__ built_in("LocalInvocationId") : !spirv.ptr<vector<3xi32>, Input>
// Storage buffers typed as runtime arrays of vector<4xf32> (16-byte elements). The f16
// cooperative matrices are loaded/stored through these pointers, so every offset and stride
// in this module is in units of vector<4xf32>; each element presumably packs 8 f16 values
// (NOTE(review): confirm packing against the host-side buffer layout).
// bind(0, 0): LHS (read), bind(0, 1): RHS (read), bind(0, 2): result (written).
spirv.GlobalVariable @__resource_var_0_0_ bind(0, 0) : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
spirv.GlobalVariable @__resource_var_0_1_ bind(0, 1) : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
spirv.GlobalVariable @__resource_var_0_2_ bind(0, 2) : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
// Kernel body: tiled matmul with a K loop (step 32), shared-memory staging, and 2x2
// 16x16 cooperative-matrix accumulators per subgroup, stored to bind(0, 2) at the end.
spirv.func @forward_dispatch_0_matmul_512x1280x1280() "None" {
// Constants. Several are pre-folded strides in vector<4xf32> units, where 160 elements
// form one row of every buffer: 2560 = 16 rows, 5120 = 32 rows, 10240 = 64 rows,
// 2562 = 16 rows + 2 elements. 156 = 160 - 4 and 152 = 160 - 8 (see index math below).
%false = spirv.Constant false
%cst2562_i32 = spirv.Constant 2562 : i32
%cst130_i32 = spirv.Constant 130 : i32
%cst66_i32 = spirv.Constant 66 : i32
%cst64_i32 = spirv.Constant 64 : i32
%cst2_i32 = spirv.Constant 2 : i32
%cst152_i32 = spirv.Constant 152 : i32
%cst2560_i32 = spirv.Constant 2560 : i32
%cst160_i32 = spirv.Constant 160 : i32
%cst256_i32 = spirv.Constant 256 : i32
%cst128_i32 = spirv.Constant 128 : i32
%cst156_i32 = spirv.Constant 156 : i32
%cst4_i32 = spirv.Constant 4 : i32
%cst-1_i32 = spirv.Constant -1 : i32
%cst8_i32 = spirv.Constant 8 : i32
%cst10240_i32 = spirv.Constant 10240 : i32
%cst5120_i32 = spirv.Constant 5120 : i32
%cst1280_i32 = spirv.Constant 1280 : i32
%cst32_i32 = spirv.Constant 32 : i32
%cst0_i32 = spirv.Constant 0 : i32
%cst_f16 = spirv.Constant 0.000000e+00 : f16
// All-zero 16x16 f16 cooperative matrix: initial value of the four accumulators.
%0 = spirv.CompositeConstruct %cst_f16 : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup>
// Local invocation id: %2 = x, %4 = y, %6 = z (z is always 0 with LocalSize 128x2x1).
%__builtin_var_LocalInvocationId___addr = spirv.mlir.addressof @__builtin_var_LocalInvocationId__ : !spirv.ptr<vector<3xi32>, Input>
%1 = spirv.Load "Input" %__builtin_var_LocalInvocationId___addr : vector<3xi32>
%2 = spirv.CompositeExtract %1[0 : i32] : vector<3xi32>
%3 = spirv.Load "Input" %__builtin_var_LocalInvocationId___addr : vector<3xi32>
%4 = spirv.CompositeExtract %3[1 : i32] : vector<3xi32>
%5 = spirv.Load "Input" %__builtin_var_LocalInvocationId___addr : vector<3xi32>
%6 = spirv.CompositeExtract %5[2 : i32] : vector<3xi32>
%__workgroup_mem__4_addr = spirv.mlir.addressof @__workgroup_mem__4 : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>
%__workgroup_mem__5_addr = spirv.mlir.addressof @__workgroup_mem__5 : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>
%__resource_var_0_0__addr = spirv.mlir.addressof @__resource_var_0_0_ : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
%__resource_var_0_1__addr = spirv.mlir.addressof @__resource_var_0_1_ : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
%__resource_var_0_2__addr = spirv.mlir.addressof @__resource_var_0_2_ : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
// Workgroup id: %8 = x selects an 8-element column block of the RHS/result (%8 * 8);
// %10 = y selects a 64-row block of the LHS/result (%10 * 10240 = 64 rows of 160).
%__builtin_var_WorkgroupId___addr = spirv.mlir.addressof @__builtin_var_WorkgroupId__ : !spirv.ptr<vector<3xi32>, Input>
%7 = spirv.Load "Input" %__builtin_var_WorkgroupId___addr : vector<3xi32>
%8 = spirv.CompositeExtract %7[0 : i32] : vector<3xi32>
%9 = spirv.Load "Input" %__builtin_var_WorkgroupId___addr : vector<3xi32>
%10 = spirv.CompositeExtract %9[1 : i32] : vector<3xi32>
// Function-scope slots that receive the four accumulators on every loop iteration so their
// final values can be loaded after the loop region. They are written only inside the loop
// body; this relies on the loop running at least once, which holds for the constant
// trip count here (0 to 1280 step 32 = 40 iterations).
%11 = spirv.Variable : !spirv.ptr<!spirv.coopmatrix<16x16xf16, Subgroup>, Function>
%12 = spirv.Variable : !spirv.ptr<!spirv.coopmatrix<16x16xf16, Subgroup>, Function>
%13 = spirv.Variable : !spirv.ptr<!spirv.coopmatrix<16x16xf16, Subgroup>, Function>
%14 = spirv.Variable : !spirv.ptr<!spirv.coopmatrix<16x16xf16, Subgroup>, Function>
// Reduction loop over K: %39 = 0, 32, ..., 1248. Block args %40..%43 carry the four
// accumulators between iterations.
spirv.mlir.loop {
spirv.Branch ^bb1(%cst0_i32, %0, %0, %0, %0 : i32, !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup>)
^bb1(%39: i32, %40: !spirv.coopmatrix<16x16xf16, Subgroup>, %41: !spirv.coopmatrix<16x16xf16, Subgroup>, %42: !spirv.coopmatrix<16x16xf16, Subgroup>, %43: !spirv.coopmatrix<16x16xf16, Subgroup>): // 2 preds: ^bb0, ^bb2
%44 = spirv.SLessThan %39, %cst1280_i32 : i32
spirv.BranchConditional %44, ^bb2, ^bb3
^bb2: // pred: ^bb1
// Barrier before overwriting the staging buffers: the previous iteration's cooperative
// loads may still be reading them.
spirv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
// LHS global element index %65 (row stride 160). Each invocation copies one element:
// column x%4 of the 4-element K slice, row x/4 + 32*y (+ 64*z) of the workgroup's 64-row
// block at %10 * 10240, K offset floor(%39 / 8). The x/4 term appears as
// x + 156*(x/4) == x%4 + 160*(x/4) for x >= 0.
%45 = spirv.IMul %4, %cst5120_i32 : i32
%46 = spirv.IAdd %2, %45 : i32
%47 = spirv.IMul %6, %cst10240_i32 : i32
%48 = spirv.IAdd %46, %47 : i32
%49 = spirv.IMul %10, %cst10240_i32 : i32
%50 = spirv.IAdd %48, %49 : i32
// floor(%39 / 8) via the signed floor-division idiom: for n < 0 it computes
// -1 - ((-1 - n) / 8), otherwise n / 8.
%51 = spirv.SLessThan %39, %cst0_i32 : i32
%52 = spirv.ISub %cst-1_i32, %39 : i32
%53 = spirv.Select %51, %52, %39 : i1, i32
%54 = spirv.SDiv %53, %cst8_i32 : i32
%55 = spirv.ISub %cst-1_i32, %54 : i32
%56 = spirv.Select %51, %55, %54 : i1, i32
%57 = spirv.IAdd %50, %56 : i32
// floor(x / 4) with the same idiom; %58 and %60 are reused below for floor(x / 8) and
// floor(x / 64).
%58 = spirv.SLessThan %2, %cst0_i32 : i32
%59 = spirv.ISub %cst-1_i32, %2 : i32
%60 = spirv.Select %58, %59, %2 : i1, i32
%61 = spirv.SDiv %60, %cst4_i32 : i32
%62 = spirv.ISub %cst-1_i32, %61 : i32
%63 = spirv.Select %58, %62, %61 : i1, i32
%64 = spirv.IMul %63, %cst156_i32 : i32
%65 = spirv.IAdd %57, %64 : i32
%66 = spirv.AccessChain %__resource_var_0_0__addr[%cst0_i32, %65] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>, i32, i32
%67 = spirv.Load "StorageBuffer" %66 : vector<4xf32>
// Staging slot %71 = x + 128*y + 256*z, i.e. the linear invocation index (0..255).
// %68 = 128*y is reused as the base offset of this subgroup's LHS tiles.
%68 = spirv.IMul %4, %cst128_i32 : i32
%69 = spirv.IAdd %2, %68 : i32
%70 = spirv.IMul %6, %cst256_i32 : i32
%71 = spirv.IAdd %69, %70 : i32
%72 = spirv.AccessChain %__workgroup_mem__4_addr[%cst0_i32, %71] : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>, i32, i32
spirv.Store "Workgroup" %72, %67 : vector<4xf32>
// RHS global element index %85 (row stride 160): row %39 + x/8 + 16*y (+ 32*z), column
// x%8 of the workgroup's 8-element column block at %8 * 8. The x/8 term appears as
// x + 152*(x/8) == x%8 + 160*(x/8) for x >= 0.
%73 = spirv.IMul %39, %cst160_i32 : i32
%74 = spirv.IAdd %73, %2 : i32
%75 = spirv.IMul %4, %cst2560_i32 : i32
%76 = spirv.IAdd %74, %75 : i32
%77 = spirv.IMul %6, %cst5120_i32 : i32
%78 = spirv.IAdd %76, %77 : i32
%79 = spirv.IMul %8, %cst8_i32 : i32
%80 = spirv.IAdd %78, %79 : i32
%81 = spirv.SDiv %60, %cst8_i32 : i32
%82 = spirv.ISub %cst-1_i32, %81 : i32
%83 = spirv.Select %58, %82, %81 : i1, i32
%84 = spirv.IMul %83, %cst152_i32 : i32
%85 = spirv.IAdd %80, %84 : i32
%86 = spirv.AccessChain %__resource_var_0_1__addr[%cst0_i32, %85] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>, i32, i32
%87 = spirv.Load "StorageBuffer" %86 : vector<4xf32>
%88 = spirv.AccessChain %__workgroup_mem__5_addr[%cst0_i32, %71] : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>, i32, i32
spirv.Store "Workgroup" %88, %87 : vector<4xf32>
// Barrier so both staging buffers are fully written before the cooperative loads.
spirv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
// Cooperative-matrix operands. Their offsets depend only on y (%68) and floor(x / 64) (%103).
// NOTE(review): with a subgroup size of 64 this yields 4 subgroups, each owning a 2x2
// block of 16x16 result tiles; confirm the target's subgroup size.
// LHS tiles from __workgroup_mem__4 (row stride 4 elements), relative to %68:
// %90 = A[m0,k0] (+0), %93 = A[m0,k1] (+2), %96 = A[m1,k0] (+64 = 16 rows),
// %99 = A[m1,k1] (+66).
%89 = spirv.AccessChain %__workgroup_mem__4_addr[%cst0_i32, %68] : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>, i32, i32
%90 = spirv.NV.CooperativeMatrixLoad %89, %cst4_i32, %false : !spirv.ptr<vector<4xf32>, Workgroup> as !spirv.coopmatrix<16x16xf16, Subgroup>
%91 = spirv.IAdd %68, %cst2_i32 : i32
%92 = spirv.AccessChain %__workgroup_mem__4_addr[%cst0_i32, %91] : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>, i32, i32
%93 = spirv.NV.CooperativeMatrixLoad %92, %cst4_i32, %false : !spirv.ptr<vector<4xf32>, Workgroup> as !spirv.coopmatrix<16x16xf16, Subgroup>
%94 = spirv.IAdd %68, %cst64_i32 : i32
%95 = spirv.AccessChain %__workgroup_mem__4_addr[%cst0_i32, %94] : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>, i32, i32
%96 = spirv.NV.CooperativeMatrixLoad %95, %cst4_i32, %false : !spirv.ptr<vector<4xf32>, Workgroup> as !spirv.coopmatrix<16x16xf16, Subgroup>
%97 = spirv.IAdd %68, %cst66_i32 : i32
%98 = spirv.AccessChain %__workgroup_mem__4_addr[%cst0_i32, %97] : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>, i32, i32
%99 = spirv.NV.CooperativeMatrixLoad %98, %cst4_i32, %false : !spirv.ptr<vector<4xf32>, Workgroup> as !spirv.coopmatrix<16x16xf16, Subgroup>
// RHS tiles from __workgroup_mem__5 (row stride 8 elements), relative to
// %103 = 4 * floor(x / 64): %105 = B[k0,n0] (+0), %108 = B[k0,n1] (+2),
// %111 = B[k1,n0] (+128 = 16 rows), %114 = B[k1,n1] (+130).
%100 = spirv.SDiv %60, %cst64_i32 : i32
%101 = spirv.ISub %cst-1_i32, %100 : i32
%102 = spirv.Select %58, %101, %100 : i1, i32
%103 = spirv.IMul %102, %cst4_i32 : i32
%104 = spirv.AccessChain %__workgroup_mem__5_addr[%cst0_i32, %103] : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>, i32, i32
%105 = spirv.NV.CooperativeMatrixLoad %104, %cst8_i32, %false : !spirv.ptr<vector<4xf32>, Workgroup> as !spirv.coopmatrix<16x16xf16, Subgroup>
%106 = spirv.IAdd %103, %cst2_i32 : i32
%107 = spirv.AccessChain %__workgroup_mem__5_addr[%cst0_i32, %106] : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>, i32, i32
%108 = spirv.NV.CooperativeMatrixLoad %107, %cst8_i32, %false : !spirv.ptr<vector<4xf32>, Workgroup> as !spirv.coopmatrix<16x16xf16, Subgroup>
%109 = spirv.IAdd %103, %cst128_i32 : i32
%110 = spirv.AccessChain %__workgroup_mem__5_addr[%cst0_i32, %109] : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>, i32, i32
%111 = spirv.NV.CooperativeMatrixLoad %110, %cst8_i32, %false : !spirv.ptr<vector<4xf32>, Workgroup> as !spirv.coopmatrix<16x16xf16, Subgroup>
%112 = spirv.IAdd %103, %cst130_i32 : i32
%113 = spirv.AccessChain %__workgroup_mem__5_addr[%cst0_i32, %112] : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>, i32, i32
%114 = spirv.NV.CooperativeMatrixLoad %113, %cst8_i32, %false : !spirv.ptr<vector<4xf32>, Workgroup> as !spirv.coopmatrix<16x16xf16, Subgroup>
// C[mi,nj] += A[mi,k0]*B[k0,nj] + A[mi,k1]*B[k1,nj]:
// %116 = C[m0,n0], %118 = C[m0,n1], %120 = C[m1,n0], %122 = C[m1,n1].
%115 = spirv.NV.CooperativeMatrixMulAdd %90, %105, %40 : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
%116 = spirv.NV.CooperativeMatrixMulAdd %93, %111, %115 : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
%117 = spirv.NV.CooperativeMatrixMulAdd %90, %108, %41 : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
%118 = spirv.NV.CooperativeMatrixMulAdd %93, %114, %117 : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
%119 = spirv.NV.CooperativeMatrixMulAdd %96, %105, %42 : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
%120 = spirv.NV.CooperativeMatrixMulAdd %99, %111, %119 : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
%121 = spirv.NV.CooperativeMatrixMulAdd %96, %108, %43 : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
%122 = spirv.NV.CooperativeMatrixMulAdd %99, %114, %121 : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
// Spill the updated accumulators so they can be read after the loop region exits.
spirv.Store "Function" %11, %116 : !spirv.coopmatrix<16x16xf16, Subgroup>
spirv.Store "Function" %12, %118 : !spirv.coopmatrix<16x16xf16, Subgroup>
spirv.Store "Function" %13, %120 : !spirv.coopmatrix<16x16xf16, Subgroup>
spirv.Store "Function" %14, %122 : !spirv.coopmatrix<16x16xf16, Subgroup>
// Advance K by 32 and iterate with the new accumulators.
%123 = spirv.IAdd %39, %cst32_i32 : i32
spirv.Branch ^bb1(%123, %116, %118, %120, %122 : i32, !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup>)
^bb3: // pred: ^bb1
spirv.mlir.merge
}
// Final accumulators: %15 = C[m1,n1], %16 = C[m1,n0], %17 = C[m0,n1], %18 = C[m0,n0].
%15 = spirv.Load "Function" %14 : !spirv.coopmatrix<16x16xf16, Subgroup>
%16 = spirv.Load "Function" %13 : !spirv.coopmatrix<16x16xf16, Subgroup>
%17 = spirv.Load "Function" %12 : !spirv.coopmatrix<16x16xf16, Subgroup>
%18 = spirv.Load "Function" %11 : !spirv.coopmatrix<16x16xf16, Subgroup>
// Result base element %31 (row stride 160):
// 5120*y + 10240*wgY (row block) + 8*wgX + 4*floor(x / 64) (column block).
%19 = spirv.IMul %4, %cst5120_i32 : i32
%20 = spirv.IMul %10, %cst10240_i32 : i32
%21 = spirv.IAdd %19, %20 : i32
%22 = spirv.IMul %8, %cst8_i32 : i32
%23 = spirv.IAdd %21, %22 : i32
%24 = spirv.SLessThan %2, %cst0_i32 : i32
%25 = spirv.ISub %cst-1_i32, %2 : i32
%26 = spirv.Select %24, %25, %2 : i1, i32
%27 = spirv.SDiv %26, %cst64_i32 : i32
%28 = spirv.ISub %cst-1_i32, %27 : i32
%29 = spirv.Select %24, %28, %27 : i1, i32
%30 = spirv.IMul %29, %cst4_i32 : i32
%31 = spirv.IAdd %23, %30 : i32
// Store the four tiles with row stride 160: +2562 (16 rows + 2) <- C[m1,n1],
// +2560 (16 rows) <- C[m1,n0], +2 <- C[m0,n1], +0 <- C[m0,n0].
%32 = spirv.IAdd %31, %cst2562_i32 : i32
%33 = spirv.AccessChain %__resource_var_0_2__addr[%cst0_i32, %32] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>, i32, i32
spirv.NV.CooperativeMatrixStore %33, %15, %cst160_i32, %false : !spirv.ptr<vector<4xf32>, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup>
%34 = spirv.IAdd %31, %cst2560_i32 : i32
%35 = spirv.AccessChain %__resource_var_0_2__addr[%cst0_i32, %34] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>, i32, i32
spirv.NV.CooperativeMatrixStore %35, %16, %cst160_i32, %false : !spirv.ptr<vector<4xf32>, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup>
%36 = spirv.IAdd %31, %cst2_i32 : i32
%37 = spirv.AccessChain %__resource_var_0_2__addr[%cst0_i32, %36] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>, i32, i32
spirv.NV.CooperativeMatrixStore %37, %17, %cst160_i32, %false : !spirv.ptr<vector<4xf32>, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup>
%38 = spirv.AccessChain %__resource_var_0_2__addr[%cst0_i32, %31] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>, i32, i32
spirv.NV.CooperativeMatrixStore %38, %18, %cst160_i32, %false : !spirv.ptr<vector<4xf32>, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup>
spirv.Return
}
// Entry point; the interface lists the two Input builtins the kernel reads.
// LocalSize 128x2x1 = 256 invocations, matching the 256-element staging arrays.
spirv.EntryPoint "GLCompute" @forward_dispatch_0_matmul_512x1280x1280, @__builtin_var_LocalInvocationId__, @__builtin_var_WorkgroupId__
spirv.ExecutionMode @forward_dispatch_0_matmul_512x1280x1280 "LocalSize", 128, 2, 1
}
// Computes lhs * rhs for f16 operands: a 512x1280 by 1280x1280 matrix product
// returned as a 512x1280 tensor.
func.func @forward(%lhs: tensor<512x1280xf16>, %rhs: tensor<1280x1280xf16>) -> tensor<512x1280xf16> {
  // Destination storage for the product.
  %init = tensor.empty() : tensor<512x1280xf16>
  // linalg.matmul accumulates into its outs operand, so start from all zeros.
  %zero = arith.constant 0.000000e+00 : f16
  %acc = linalg.fill ins(%zero : f16) outs(%init : tensor<512x1280xf16>) -> tensor<512x1280xf16>
  %product = linalg.matmul ins(%lhs, %rhs : tensor<512x1280xf16>, tensor<1280x1280xf16>) outs(%acc : tensor<512x1280xf16>) -> tensor<512x1280xf16>
  return %product : tensor<512x1280xf16>
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment