Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Created November 14, 2022 17:17
Show Gist options
  • Save AmosLewis/1517a0ecbe4baf153222a0631a39e119 to your computer and use it in GitHub Desktop.
Save AmosLewis/1517a0ecbe4baf153222a0631a39e119 to your computer and use it in GitHub Desktop.
func.func @forward(%arg0: !torch.vtensor<[1,128],si64>) -> !torch.vtensor<[1,128,1],f32>{
%int1 = torch.constant.int 1
%int32 = torch.constant.int 32
%int128 = torch.constant.int 128
%float1.000000e00 = torch.constant.float 1.000000e+00
%0 = torch.vtensor.literal(dense<0.000000e+00> : tensor<2xf32>) : !torch.vtensor<[2],f32>
%1 = torch.vtensor.literal(dense_resource<__elided__> : tensor<2x32xf32>) : !torch.vtensor<[2,32],f32>
%2 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x32xf32>) : !torch.vtensor<[32,32],f32>
%3 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x37xf32>) : !torch.vtensor<[32,37],f32>
%4 = torch.vtensor.literal(dense_resource<__elided__> : tensor<37x32xf32>) : !torch.vtensor<[37,32],f32>
%5 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x32xf32>) : !torch.vtensor<[32,32],f32>
%6 = torch.vtensor.literal(dense_resource<__elided__> : tensor<96x32xf32>) : !torch.vtensor<[96,32],f32>
%7 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x37xf32>) : !torch.vtensor<[32,37],f32>
%8 = torch.vtensor.literal(dense_resource<__elided__> : tensor<37x32xf32>) : !torch.vtensor<[37,32],f32>
%9 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x32xf32>) : !torch.vtensor<[32,32],f32>
%10 = torch.vtensor.literal(dense_resource<__elided__> : tensor<96x32xf32>) : !torch.vtensor<[96,32],f32>
%11 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x37xf32>) : !torch.vtensor<[32,37],f32>
%12 = torch.vtensor.literal(dense_resource<__elided__> : tensor<37x32xf32>) : !torch.vtensor<[37,32],f32>
%13 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x32xf32>) : !torch.vtensor<[32,32],f32>
%14 = torch.vtensor.literal(dense_resource<__elided__> : tensor<96x32xf32>) : !torch.vtensor<[96,32],f32>
%15 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x37xf32>) : !torch.vtensor<[32,37],f32>
%16 = torch.vtensor.literal(dense_resource<__elided__> : tensor<37x32xf32>) : !torch.vtensor<[37,32],f32>
%17 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x32xf32>) : !torch.vtensor<[32,32],f32>
%18 = torch.vtensor.literal(dense_resource<__elided__> : tensor<96x32xf32>) : !torch.vtensor<[96,32],f32>
%19 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x37xf32>) : !torch.vtensor<[32,37],f32>
%20 = torch.vtensor.literal(dense<0.000000e+00> : tensor<37xf32>) : !torch.vtensor<[37],f32>
%21 = torch.vtensor.literal(dense_resource<__elided__> : tensor<37x32xf32>) : !torch.vtensor<[37,32],f32>
%22 = torch.vtensor.literal(dense_resource<__elided__> : tensor<32x32xf32>) : !torch.vtensor<[32,32],f32>
%23 = torch.vtensor.literal(dense<-3.40282347E+38> : tensor<f32>) : !torch.vtensor<[],f32>
%24 = torch.vtensor.literal(dense_resource<__elided__> : tensor<96x32xf32>) : !torch.vtensor<[96,32],f32>
%25 = torch.vtensor.literal(dense<0.000000e+00> : tensor<32xf32>) : !torch.vtensor<[32],f32>
%26 = torch.vtensor.literal(dense<1.000000e+00> : tensor<32xf32>) : !torch.vtensor<[32],f32>
%27 = torch.vtensor.literal(dense_resource<__elided__> : tensor<16x32xf32>) : !torch.vtensor<[16,32],f32>
%28 = torch.vtensor.literal(dense_resource<__elided__> : tensor<512x32xf32>) : !torch.vtensor<[512,32],f32>
%29 = torch.vtensor.literal(dense_resource<__elided__> : tensor<1000x32xf32>) : !torch.vtensor<[1000,32],f32>
%30 = torch.vtensor.literal(dense_resource<__elided__> : tensor<1x512xsi64>) : !torch.vtensor<[1,512],si64>
%false = torch.constant.bool false
%none = torch.constant.none
%int-1 = torch.constant.int -1
%true = torch.constant.bool true
%int-2 = torch.constant.int -2
%int11 = torch.constant.int 11
%str = torch.constant.str "none"
%int0 = torch.constant.int 0
%int9223372036854775807 = torch.constant.int 9223372036854775807
%int2 = torch.constant.int 2
%float9.999990e-08 = torch.constant.float 9.9999999999999995E-8
%int96 = torch.constant.int 96
%int4 = torch.constant.int 4
%int3 = torch.constant.int 3
%int8 = torch.constant.int 8
%int16 = torch.constant.int 16
%int24 = torch.constant.int 24
%float4.000000e00 = torch.constant.float 4.000000e+00
%int37 = torch.constant.int 37
%cpu = torch.constant.device "cpu"
%31 = torch.prim.ListConstruct %int1, %int128 : (!torch.int, !torch.int) -> !torch.list<int>
%32 = torch.aten.ones %31, %none, %none, %cpu, %false : !torch.list<int>, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[1,128],f32>
%33 = torch.aten.zeros %31, %int4, %none, %cpu, %false : !torch.list<int>, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[1,128],si64>
%34 = torch.aten.slice.Tensor %30, %int0, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,512],si64>
%35 = torch.aten.slice.Tensor %34, %int1, %int0, %int128, %int1 : !torch.vtensor<[1,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128],si64>
%36 = torch.aten.embedding %29, %arg0, %int0, %false, %false : !torch.vtensor<[1000,32],f32>, !torch.vtensor<[1,128],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[1,128,32],f32>
%37 = torch.aten.embedding %28, %35, %int-1, %false, %false : !torch.vtensor<[512,32],f32>, !torch.vtensor<[1,128],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[1,128,32],f32>
%38 = torch.aten.add.Tensor %36, %37, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%39 = torch.aten.embedding %27, %33, %int-1, %false, %false : !torch.vtensor<[16,32],f32>, !torch.vtensor<[1,128],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[1,128,32],f32>
%40 = torch.aten.add.Tensor %38, %39, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%41 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%42 = torch.aten.sum.dim_IntList %40, %41, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%43 = torch.aten.div.Scalar %42, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%44 = torch.aten.sub.Tensor %40, %43, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%45 = torch.aten.pow.Tensor_Scalar %44, %int2 : !torch.vtensor<[1,128,32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%46 = torch.aten.sum.dim_IntList %45, %41, %true, %none : !torch.vtensor<[1,128,32],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%47 = torch.aten.div.Scalar %46, %int32 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%48 = torch.aten.sub.Tensor %40, %43, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%49 = torch.aten.add.Scalar %47, %float9.999990e-08, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%50 = torch.aten.sqrt %49 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%51 = torch.aten.div.Tensor %48, %50 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%52 = torch.aten.mul.Tensor %26, %51 : !torch.vtensor<[32],f32>, !torch.vtensor<[1,128,32],f32> -> !torch.vtensor<[1,128,32],f32>
%53 = torch.aten.add.Tensor %52, %25, %int1 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[1,128,32],f32>
%54 = torch.aten.unsqueeze %32, %int2 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%55 = torch.aten.mul.Tensor %53, %54 : !torch.vtensor<[1,128,32],f32>, !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,32],f32>
%56 = torch.aten.unsqueeze %32, %int1 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,1,128],f32>
%57 = torch.aten.unsqueeze %56, %int2 : !torch.vtensor<[1,1,128],f32>, !torch.int -> !torch.vtensor<[1,1,1,128],f32>
%58 = torch.aten.squeeze.dim %57, %int-2 : !torch.vtensor<[1,1,1,128],f32>, !torch.int -> !torch.vtensor<[1,1,128],f32>
%59 = torch.aten.unsqueeze %58, %int-1 : !torch.vtensor<[1,1,128],f32>, !torch.int -> !torch.vtensor<[1,1,128,1],f32>
%60 = torch.aten.mul.Tensor %57, %59 : !torch.vtensor<[1,1,1,128],f32>, !torch.vtensor<[1,1,128,1],f32> -> !torch.vtensor<[1,1,128,128],f32>
%61 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
%62 = torch.aten.to.dtype %61, %int1, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],si8>
return %54 : !torch.vtensor<[1,128,1],f32>
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment