Last active: November 18, 2022 01:18
Save antiagainst/146af8b1a25960a4ac788bacff267f9e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// SPIR-V compute kernel @forward_dispatch_0_matmul_512x1280x1280:
//   result(512x1280, f16) = lhs(512x1280, f16) * rhs(1280x1280, f16),
// accumulated in f16 with NV cooperative matrices (16x16, Subgroup scope).
//
// Buffers (descriptor set 0): binding 0 = LHS, binding 1 = RHS,
// binding 2 = result. All three are addressed as vector<4xf32>
// (16 bytes = 8 f16 values), so one 1280-wide f16 row spans 160 vectors.
// The cooperative-matrix loads/stores reinterpret those vectors as f16 data;
// their stride operands are in units of vector<4xf32>.
//
// Tiling, as encoded by the index arithmetic below:
//   * WorkgroupId.y selects a 64-row block of the result (10240 = 64 * 160
//     vectors); WorkgroupId.x selects a 64-column block (8 vectors = 64 f16).
//   * K is walked in steps of 32. Each step stages a 64x32 LHS tile in
//     @__workgroup_mem__4 (row stride 4 vectors) and a 32x64 RHS tile in
//     @__workgroup_mem__5 (row stride 8 vectors). Each of the 128x2
//     invocations copies one vector of each 256-vector tile.
//   * LocalInvocationId.y and (LocalInvocationId.x floordiv 64) select a
//     32x32 result sub-tile, accumulated as 2x2 16x16 cooperative matrices.
//     NOTE(review): this split presumably assumes a subgroup size of 64;
//     confirm against the target environment.
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Float16, CooperativeMatrixNV], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]> {
  spirv.GlobalVariable @__builtin_var_WorkgroupId__ built_in("WorkgroupId") : !spirv.ptr<vector<3xi32>, Input>
  // Shared staging tiles: _4 holds the 64x32 LHS tile, _5 the 32x64 RHS tile.
  spirv.GlobalVariable @__workgroup_mem__5 : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>
  spirv.GlobalVariable @__workgroup_mem__4 : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>
  spirv.GlobalVariable @__builtin_var_LocalInvocationId__ built_in("LocalInvocationId") : !spirv.ptr<vector<3xi32>, Input>
  // Binding 0: LHS (read), binding 1: RHS (read), binding 2: result (written).
  spirv.GlobalVariable @__resource_var_0_0_ bind(0, 0) : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
  spirv.GlobalVariable @__resource_var_0_1_ bind(0, 1) : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
  spirv.GlobalVariable @__resource_var_0_2_ bind(0, 2) : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
  spirv.func @forward_dispatch_0_matmul_512x1280x1280() "None" {
    %false = spirv.Constant false
    %cst2562_i32 = spirv.Constant 2562 : i32
    %cst130_i32 = spirv.Constant 130 : i32
    %cst66_i32 = spirv.Constant 66 : i32
    %cst64_i32 = spirv.Constant 64 : i32
    %cst2_i32 = spirv.Constant 2 : i32
    %cst152_i32 = spirv.Constant 152 : i32
    %cst2560_i32 = spirv.Constant 2560 : i32
    %cst160_i32 = spirv.Constant 160 : i32
    %cst256_i32 = spirv.Constant 256 : i32
    %cst128_i32 = spirv.Constant 128 : i32
    %cst156_i32 = spirv.Constant 156 : i32
    %cst4_i32 = spirv.Constant 4 : i32
    %cst-1_i32 = spirv.Constant -1 : i32
    %cst8_i32 = spirv.Constant 8 : i32
    %cst10240_i32 = spirv.Constant 10240 : i32
    %cst5120_i32 = spirv.Constant 5120 : i32
    %cst1280_i32 = spirv.Constant 1280 : i32
    %cst32_i32 = spirv.Constant 32 : i32
    %cst0_i32 = spirv.Constant 0 : i32
    %cst_f16 = spirv.Constant 0.000000e+00 : f16
    // Zero-filled 16x16 matrix: initial value of all four accumulators.
    %0 = spirv.CompositeConstruct %cst_f16 : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup>
    // LocalInvocationId x / y / z -> %2 / %4 / %6.
    %__builtin_var_LocalInvocationId___addr = spirv.mlir.addressof @__builtin_var_LocalInvocationId__ : !spirv.ptr<vector<3xi32>, Input>
    %1 = spirv.Load "Input" %__builtin_var_LocalInvocationId___addr : vector<3xi32>
    %2 = spirv.CompositeExtract %1[0 : i32] : vector<3xi32>
    %3 = spirv.Load "Input" %__builtin_var_LocalInvocationId___addr : vector<3xi32>
    %4 = spirv.CompositeExtract %3[1 : i32] : vector<3xi32>
    %5 = spirv.Load "Input" %__builtin_var_LocalInvocationId___addr : vector<3xi32>
    %6 = spirv.CompositeExtract %5[2 : i32] : vector<3xi32>
    %__workgroup_mem__4_addr = spirv.mlir.addressof @__workgroup_mem__4 : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>
    %__workgroup_mem__5_addr = spirv.mlir.addressof @__workgroup_mem__5 : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>
    %__resource_var_0_0__addr = spirv.mlir.addressof @__resource_var_0_0_ : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
    %__resource_var_0_1__addr = spirv.mlir.addressof @__resource_var_0_1_ : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
    %__resource_var_0_2__addr = spirv.mlir.addressof @__resource_var_0_2_ : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
    // WorkgroupId x / y -> %8 / %10.
    %__builtin_var_WorkgroupId___addr = spirv.mlir.addressof @__builtin_var_WorkgroupId__ : !spirv.ptr<vector<3xi32>, Input>
    %7 = spirv.Load "Input" %__builtin_var_WorkgroupId___addr : vector<3xi32>
    %8 = spirv.CompositeExtract %7[0 : i32] : vector<3xi32>
    %9 = spirv.Load "Input" %__builtin_var_WorkgroupId___addr : vector<3xi32>
    %10 = spirv.CompositeExtract %9[1 : i32] : vector<3xi32>
    // Function-scope slots that carry the final accumulators (C00, C01, C10,
    // C11) out of the loop region. They are only written inside the loop body,
    // so they would be undefined on a zero-trip loop; here the trip count is
    // fixed at 1280 / 32 = 40.
    %11 = spirv.Variable : !spirv.ptr<!spirv.coopmatrix<16x16xf16, Subgroup>, Function>
    %12 = spirv.Variable : !spirv.ptr<!spirv.coopmatrix<16x16xf16, Subgroup>, Function>
    %13 = spirv.Variable : !spirv.ptr<!spirv.coopmatrix<16x16xf16, Subgroup>, Function>
    %14 = spirv.Variable : !spirv.ptr<!spirv.coopmatrix<16x16xf16, Subgroup>, Function>
    spirv.mlir.loop {
      spirv.Branch ^bb1(%cst0_i32, %0, %0, %0, %0 : i32, !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup>)
    // %39 = k, the f16 offset along K. %40 / %41 / %42 / %43 = accumulators
    // C00 / C01 / C10 / C11 (row block, column block) of this 32x32 sub-tile.
    ^bb1(%39: i32, %40: !spirv.coopmatrix<16x16xf16, Subgroup>, %41: !spirv.coopmatrix<16x16xf16, Subgroup>, %42: !spirv.coopmatrix<16x16xf16, Subgroup>, %43: !spirv.coopmatrix<16x16xf16, Subgroup>):  // 2 preds: ^bb0, ^bb2
      %44 = spirv.SLessThan %39, %cst1280_i32 : i32
      spirv.BranchConditional %44, ^bb2, ^bb3
    ^bb2:  // pred: ^bb1
      // Do not overwrite the shared tiles until every invocation has finished
      // reading them in the previous iteration.
      spirv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
      // ---- Stage LHS: global -> @__workgroup_mem__4 ----
      // Global vector index = (wgY*64 + y*32 + z*64 + x/4) * 160 + k/8 + x%4,
      // using (x floordiv 4) * 156 + x == (x floordiv 4) * 160 + x % 4.
      // (z is always 0 with LocalSize z = 1.)
      %45 = spirv.IMul %4, %cst5120_i32 : i32
      %46 = spirv.IAdd %2, %45 : i32
      %47 = spirv.IMul %6, %cst10240_i32 : i32
      %48 = spirv.IAdd %46, %47 : i32
      %49 = spirv.IMul %10, %cst10240_i32 : i32
      %50 = spirv.IAdd %48, %49 : i32
      // %51..%56: k floordiv 8, lowered as v < 0 ? -1 - (-1 - v) / d : v / d.
      %51 = spirv.SLessThan %39, %cst0_i32 : i32
      %52 = spirv.ISub %cst-1_i32, %39 : i32
      %53 = spirv.Select %51, %52, %39 : i1, i32
      %54 = spirv.SDiv %53, %cst8_i32 : i32
      %55 = spirv.ISub %cst-1_i32, %54 : i32
      %56 = spirv.Select %51, %55, %54 : i1, i32
      %57 = spirv.IAdd %50, %56 : i32
      // %58..%63: x floordiv 4 (same lowering; %58 / %60 are reused below).
      %58 = spirv.SLessThan %2, %cst0_i32 : i32
      %59 = spirv.ISub %cst-1_i32, %2 : i32
      %60 = spirv.Select %58, %59, %2 : i1, i32
      %61 = spirv.SDiv %60, %cst4_i32 : i32
      %62 = spirv.ISub %cst-1_i32, %61 : i32
      %63 = spirv.Select %58, %62, %61 : i1, i32
      %64 = spirv.IMul %63, %cst156_i32 : i32
      %65 = spirv.IAdd %57, %64 : i32
      %66 = spirv.AccessChain %__resource_var_0_0__addr[%cst0_i32, %65] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>, i32, i32
      %67 = spirv.Load "StorageBuffer" %66 : vector<4xf32>
      // Shared index x + 128*y + 256*z: row (y*32 + x/4), column vector x%4 in
      // a tile that is 4 vectors (32 f16) wide.
      %68 = spirv.IMul %4, %cst128_i32 : i32
      %69 = spirv.IAdd %2, %68 : i32
      %70 = spirv.IMul %6, %cst256_i32 : i32
      %71 = spirv.IAdd %69, %70 : i32
      %72 = spirv.AccessChain %__workgroup_mem__4_addr[%cst0_i32, %71] : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>, i32, i32
      spirv.Store "Workgroup" %72, %67 : vector<4xf32>
      // ---- Stage RHS: global -> @__workgroup_mem__5 ----
      // Global vector index = (k + y*16 + z*32 + x/8) * 160 + wgX*8 + x%8,
      // using (x floordiv 8) * 152 + x == (x floordiv 8) * 160 + x % 8.
      %73 = spirv.IMul %39, %cst160_i32 : i32
      %74 = spirv.IAdd %73, %2 : i32
      %75 = spirv.IMul %4, %cst2560_i32 : i32
      %76 = spirv.IAdd %74, %75 : i32
      %77 = spirv.IMul %6, %cst5120_i32 : i32
      %78 = spirv.IAdd %76, %77 : i32
      %79 = spirv.IMul %8, %cst8_i32 : i32
      %80 = spirv.IAdd %78, %79 : i32
      %81 = spirv.SDiv %60, %cst8_i32 : i32
      %82 = spirv.ISub %cst-1_i32, %81 : i32
      %83 = spirv.Select %58, %82, %81 : i1, i32
      %84 = spirv.IMul %83, %cst152_i32 : i32
      %85 = spirv.IAdd %80, %84 : i32
      %86 = spirv.AccessChain %__resource_var_0_1__addr[%cst0_i32, %85] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>, i32, i32
      %87 = spirv.Load "StorageBuffer" %86 : vector<4xf32>
      // Same shared index %71: here row (y*16 + x/8), column vector x%8 in a
      // tile that is 8 vectors (64 f16) wide.
      %88 = spirv.AccessChain %__workgroup_mem__5_addr[%cst0_i32, %71] : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>, i32, i32
      spirv.Store "Workgroup" %88, %87 : vector<4xf32>
      // Make the staged tiles visible to the whole workgroup before reading them.
      spirv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
      // ---- LHS fragments (stride 4 vectors = one 32-f16 tile row) ----
      // Base %68 = row 32*y. +2 vectors = K columns 16..31; +64 vectors = 16 rows.
      // %90 = A[rows 0-15][k 0-15], %93 = A[rows 0-15][k 16-31],
      // %96 = A[rows 16-31][k 0-15], %99 = A[rows 16-31][k 16-31].
      %89 = spirv.AccessChain %__workgroup_mem__4_addr[%cst0_i32, %68] : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>, i32, i32
      %90 = spirv.NV.CooperativeMatrixLoad %89, %cst4_i32, %false : !spirv.ptr<vector<4xf32>, Workgroup> as !spirv.coopmatrix<16x16xf16, Subgroup>
      %91 = spirv.IAdd %68, %cst2_i32 : i32
      %92 = spirv.AccessChain %__workgroup_mem__4_addr[%cst0_i32, %91] : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>, i32, i32
      %93 = spirv.NV.CooperativeMatrixLoad %92, %cst4_i32, %false : !spirv.ptr<vector<4xf32>, Workgroup> as !spirv.coopmatrix<16x16xf16, Subgroup>
      %94 = spirv.IAdd %68, %cst64_i32 : i32
      %95 = spirv.AccessChain %__workgroup_mem__4_addr[%cst0_i32, %94] : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>, i32, i32
      %96 = spirv.NV.CooperativeMatrixLoad %95, %cst4_i32, %false : !spirv.ptr<vector<4xf32>, Workgroup> as !spirv.coopmatrix<16x16xf16, Subgroup>
      %97 = spirv.IAdd %68, %cst66_i32 : i32
      %98 = spirv.AccessChain %__workgroup_mem__4_addr[%cst0_i32, %97] : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>, i32, i32
      %99 = spirv.NV.CooperativeMatrixLoad %98, %cst4_i32, %false : !spirv.ptr<vector<4xf32>, Workgroup> as !spirv.coopmatrix<16x16xf16, Subgroup>
      // ---- RHS fragments (stride 8 vectors = one 64-f16 tile row) ----
      // Base %103 = (x floordiv 64) * 4 vectors = column 32 * (x floordiv 64).
      // +2 vectors = 16 columns; +128 vectors = 16 K rows.
      // %105 = B[k 0-15][cols 0-15], %108 = B[k 0-15][cols 16-31],
      // %111 = B[k 16-31][cols 0-15], %114 = B[k 16-31][cols 16-31].
      %100 = spirv.SDiv %60, %cst64_i32 : i32
      %101 = spirv.ISub %cst-1_i32, %100 : i32
      %102 = spirv.Select %58, %101, %100 : i1, i32
      %103 = spirv.IMul %102, %cst4_i32 : i32
      %104 = spirv.AccessChain %__workgroup_mem__5_addr[%cst0_i32, %103] : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>, i32, i32
      %105 = spirv.NV.CooperativeMatrixLoad %104, %cst8_i32, %false : !spirv.ptr<vector<4xf32>, Workgroup> as !spirv.coopmatrix<16x16xf16, Subgroup>
      %106 = spirv.IAdd %103, %cst2_i32 : i32
      %107 = spirv.AccessChain %__workgroup_mem__5_addr[%cst0_i32, %106] : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>, i32, i32
      %108 = spirv.NV.CooperativeMatrixLoad %107, %cst8_i32, %false : !spirv.ptr<vector<4xf32>, Workgroup> as !spirv.coopmatrix<16x16xf16, Subgroup>
      %109 = spirv.IAdd %103, %cst128_i32 : i32
      %110 = spirv.AccessChain %__workgroup_mem__5_addr[%cst0_i32, %109] : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>, i32, i32
      %111 = spirv.NV.CooperativeMatrixLoad %110, %cst8_i32, %false : !spirv.ptr<vector<4xf32>, Workgroup> as !spirv.coopmatrix<16x16xf16, Subgroup>
      %112 = spirv.IAdd %103, %cst130_i32 : i32
      %113 = spirv.AccessChain %__workgroup_mem__5_addr[%cst0_i32, %112] : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>, i32, i32
      %114 = spirv.NV.CooperativeMatrixLoad %113, %cst8_i32, %false : !spirv.ptr<vector<4xf32>, Workgroup> as !spirv.coopmatrix<16x16xf16, Subgroup>
      // ---- Accumulate: Cij += Ai0 * B0j + Ai1 * B1j ----
      // C00
      %115 = spirv.NV.CooperativeMatrixMulAdd %90, %105, %40 : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
      %116 = spirv.NV.CooperativeMatrixMulAdd %93, %111, %115 : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
      // C01
      %117 = spirv.NV.CooperativeMatrixMulAdd %90, %108, %41 : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
      %118 = spirv.NV.CooperativeMatrixMulAdd %93, %114, %117 : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
      // C10
      %119 = spirv.NV.CooperativeMatrixMulAdd %96, %105, %42 : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
      %120 = spirv.NV.CooperativeMatrixMulAdd %99, %111, %119 : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
      // C11
      %121 = spirv.NV.CooperativeMatrixMulAdd %96, %108, %43 : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
      %122 = spirv.NV.CooperativeMatrixMulAdd %99, %114, %121 : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
      // Publish the running accumulators so they are readable after the loop.
      spirv.Store "Function" %11, %116 : !spirv.coopmatrix<16x16xf16, Subgroup>
      spirv.Store "Function" %12, %118 : !spirv.coopmatrix<16x16xf16, Subgroup>
      spirv.Store "Function" %13, %120 : !spirv.coopmatrix<16x16xf16, Subgroup>
      spirv.Store "Function" %14, %122 : !spirv.coopmatrix<16x16xf16, Subgroup>
      %123 = spirv.IAdd %39, %cst32_i32 : i32
      spirv.Branch ^bb1(%123, %116, %118, %120, %122 : i32, !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup>)
    ^bb3:  // pred: ^bb1
      spirv.mlir.merge
    }
    // %15 = C11, %16 = C10, %17 = C01, %18 = C00.
    %15 = spirv.Load "Function" %14 : !spirv.coopmatrix<16x16xf16, Subgroup>
    %16 = spirv.Load "Function" %13 : !spirv.coopmatrix<16x16xf16, Subgroup>
    %17 = spirv.Load "Function" %12 : !spirv.coopmatrix<16x16xf16, Subgroup>
    %18 = spirv.Load "Function" %11 : !spirv.coopmatrix<16x16xf16, Subgroup>
    // Result base %31 = (wgY*64 + y*32) * 160 + wgX*8 + (x floordiv 64) * 4:
    // the top-left vector of this invocation's 32x32 sub-tile.
    %19 = spirv.IMul %4, %cst5120_i32 : i32
    %20 = spirv.IMul %10, %cst10240_i32 : i32
    %21 = spirv.IAdd %19, %20 : i32
    %22 = spirv.IMul %8, %cst8_i32 : i32
    %23 = spirv.IAdd %21, %22 : i32
    %24 = spirv.SLessThan %2, %cst0_i32 : i32
    %25 = spirv.ISub %cst-1_i32, %2 : i32
    %26 = spirv.Select %24, %25, %2 : i1, i32
    %27 = spirv.SDiv %26, %cst64_i32 : i32
    %28 = spirv.ISub %cst-1_i32, %27 : i32
    %29 = spirv.Select %24, %28, %27 : i1, i32
    %30 = spirv.IMul %29, %cst4_i32 : i32
    %31 = spirv.IAdd %23, %30 : i32
    // Stores use stride 160 vectors (one 1280-wide f16 row).
    // Offsets: +2560 = 16 rows down, +2 = 16 columns right.
    // C11 at +2562.
    %32 = spirv.IAdd %31, %cst2562_i32 : i32
    %33 = spirv.AccessChain %__resource_var_0_2__addr[%cst0_i32, %32] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>, i32, i32
    spirv.NV.CooperativeMatrixStore %33, %15, %cst160_i32, %false : !spirv.ptr<vector<4xf32>, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup>
    // C10 at +2560.
    %34 = spirv.IAdd %31, %cst2560_i32 : i32
    %35 = spirv.AccessChain %__resource_var_0_2__addr[%cst0_i32, %34] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>, i32, i32
    spirv.NV.CooperativeMatrixStore %35, %16, %cst160_i32, %false : !spirv.ptr<vector<4xf32>, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup>
    // C01 at +2.
    %36 = spirv.IAdd %31, %cst2_i32 : i32
    %37 = spirv.AccessChain %__resource_var_0_2__addr[%cst0_i32, %36] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>, i32, i32
    spirv.NV.CooperativeMatrixStore %37, %17, %cst160_i32, %false : !spirv.ptr<vector<4xf32>, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup>
    // C00 at +0.
    %38 = spirv.AccessChain %__resource_var_0_2__addr[%cst0_i32, %31] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>, i32, i32
    spirv.NV.CooperativeMatrixStore %38, %18, %cst160_i32, %false : !spirv.ptr<vector<4xf32>, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup>
    spirv.Return
  }
  spirv.EntryPoint "GLCompute" @forward_dispatch_0_matmul_512x1280x1280, @__builtin_var_LocalInvocationId__, @__builtin_var_WorkgroupId__
  // 128 x 2 x 1 = 256 invocations: one per staged vector of each 256-vector tile.
  spirv.ExecutionMode @forward_dispatch_0_matmul_512x1280x1280 "LocalSize", 128, 2, 1
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// @forward: result(512x1280) = %arg0(512x1280) * %arg1(1280x1280), all f16.
// NOTE(review): presumably the source of the SPIR-V dispatch
// @forward_dispatch_0_matmul_512x1280x1280; confirm against the compiler
// invocation that produced it.
func.func @forward(%arg0: tensor<512x1280xf16>, %arg1: tensor<1280x1280xf16>) -> tensor<512x1280xf16> {
  %cst = arith.constant 0.000000e+00 : f16
  // linalg.matmul accumulates into its `outs` operand, so the destination is
  // zero-filled first.
  %2 = tensor.empty() : tensor<512x1280xf16>
  %3 = linalg.fill ins(%cst : f16) outs(%2 : tensor<512x1280xf16>) -> tensor<512x1280xf16>
  %4 = linalg.matmul ins(%arg0, %arg1 : tensor<512x1280xf16>, tensor<1280x1280xf16>) outs(%3 : tensor<512x1280xf16>) -> tensor<512x1280xf16>
  return %4 : tensor<512x1280xf16>
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.