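// Excerpt from the StableHLO lowering of a jitted training step
// (@jit__train_step_kernel). The // comments are annotations added for
// readability; they interpret the IR and are not part of the original dump.

// Hash the raw (500, 256) i32 indices %1136 into the embedding table:
// scale by a broadcast scalar, then take the remainder modulo 999993.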
%1140 = stablehlo.broadcast_in_dim %1139, dims = [] : (tensor<i32>) -> tensor<500x256xi32>
%1141 = stablehlo.multiply %1136, %1140 : tensor<500x256xi32>
%1142 = stablehlo.constant dense<999993> : tensor<i64>
%1143 = call @jit__train_step_kernel$remainder_4(%1141, %1142) : (tensor<500x256xi32>, tensor<i64>) -> tensor<500x256xi32>
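// Add the per-slot offsets %arg9 (256,) i32 to every row of the hashed indices.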
%1144 = stablehlo.broadcast_in_dim %arg9, dims = [1] : (tensor<256xi32>) -> tensor<1x256xi32>
%1145 = stablehlo.broadcast_in_dim %1144, dims = [0, 1] : (tensor<1x256xi32>) -> tensor<500x256xi32>
%1146 = stablehlo.add %1143, %1145 : tensor<500x256xi32>
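// The sum can be negative (the remainder and/or the offsets are signed);
// wrap negative indices into [0, 1000000) by adding 1000000.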
%1147 = stablehlo.constant dense<0> : tensor<i32>
%1148 = stablehlo.broadcast_in_dim %1147, dims = [] : (tensor<i32>) -> tensor<500x256xi32>
%1149 = stablehlo.compare LT, %1146, %1148, SIGNED : (tensor<500x256xi32>, tensor<500x256xi32>) -> tensor<500x256xi1>
%1150 = stablehlo.constant dense<1000000> : tensor<i32>
%1151 = stablehlo.broadcast_in_dim %1150, dims = [] : (tensor<i32>) -> tensor<500x256xi32>
%1152 = stablehlo.add %1146, %1151 : tensor<500x256xi32>
%1153 = stablehlo.select %1149, %1152, %1146 : tensor<500x256xi1>, tensor<500x256xi32>
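// Reshape the wrapped indices to (500, 256, 1) and gather the corresponding
// scalar entries from the embedding table %arg2, yielding (500, 256) f32.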
%1154 = stablehlo.broadcast_in_dim %1153, dims = [0, 1] : (tensor<500x256xi32>) -> tensor<500x256x1xi32>
%1155 = "stablehlo.gather"(%arg2, %1154) {dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, in
%1156 = stablehlo.transpose %1155, dims = [1, 0] : (tensor<500x256xf32>) -> tensor<256x500xf32>
%1157 = "stablehlo.dot_general"(%54, %1156) {dot_dimension_numbers = #stablehlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>, precision_config = [#stablehlo.precision<DEFAULT>, #stablehlo.precision<DEFAULT>]} : (tensor<512x256xf32>, tensor<256x500xf32>) -> tensor<512x500xf32>
%1158 = stablehlo.reshape %1107 : (tensor<512xf32>) -> tensor<512x1xf32>
%1159 = stablehlo.concatenate %1158, %1157, dim = 1 : (tensor<512x1xf32>, tensor<512x500xf32>) -> tensor<512x501xf32>
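// Row-wise log-softmax over the 501 columns; the extra results are residuals
// reused by the backward pass below.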
%1160:3 = call @jit__train_step_kernel$log_softmax(%1159) : (tensor<512x501xf32>) -> (tensor<512x501xf32>, tensor<512x1xf32>, tensor<512x501xf32>)
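// The loss reads only column 0 of the log-probabilities: gather
// log_probs[:, 0] as a (512,) vector and sum it over the batch.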
%1161 = stablehlo.constant dense<0> : tensor<i32>
%1162 = stablehlo.broadcast_in_dim %1161, dims = [] : (tensor<i32>) -> tensor<1xi32>
%1163 = "stablehlo.gather"(%1160#0, %1162) {dimension_numbers = #stablehlo.gather<offset_dims = [0], collapsed_slice_dims = [1], start_index_map = [1]>, indices_are_sorted = true, slice_sizes = dense<[512, 1]> : tensor<2xi64>} : (tensor<512x501xf32>, tensor<1xi32>) -> tensor<512xf32>
%1164 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%1165 = stablehlo.reduce(%1163 init: %1164) across dimensions = [0] : (tensor<512xf32>, tensor<f32>) -> tensor<f32>
  reducer(%arg10: tensor<f32>, %arg11: tensor<f32>) {
    %1480 = stablehlo.add %arg10, %arg11 : tensor<f32>
    stablehlo.return %1480 : tensor<f32>
  }
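// Divide by the batch size 512 and negate: mean negative log-likelihood.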
%1166 = stablehlo.constant dense<5.120000e+02> : tensor<f32>
%1167 = stablehlo.divide %1165, %1166 : tensor<f32>
%1168 = stablehlo.negate %1167 : tensor<f32>
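// Backward pass: the cotangent of each log_probs[i, 0] is -1/512, broadcast
// to a (512,) vector. The reduce across no dimensions below is a no-op copy.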
%1169 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%1170 = stablehlo.negate %1169 : tensor<f32>
%1171 = stablehlo.constant dense<5.120000e+02> : tensor<f32>
%1172 = stablehlo.divide %1170, %1171 : tensor<f32>
%1173 = stablehlo.broadcast_in_dim %1172, dims = [] : (tensor<f32>) -> tensor<512xf32>
%1174 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%1175 = stablehlo.reduce(%1173 init: %1174) across dimensions = [] : (tensor<512xf32>, tensor<f32>) -> tensor<512xf32>
  reducer(%arg10: tensor<f32>, %arg11: tensor<f32>) {
    %1480 = stablehlo.add %arg10, %arg11 : tensor<f32>
    stablehlo.return %1480 : tensor<f32>
  }
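// Scatter the (512,) cotangent into column 0 of a zero-initialized
// (512, 501) tensor: the gradient with respect to the log-probabilities.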
%1176 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%1177 = stablehlo.broadcast_in_dim %1176, dims = [] : (tensor<f32>) -> tensor<512x501xf32>
%1178 = "stablehlo.scatter"(%1177, %1162, %1175) ({
  ^bb0(%arg10: tensor<f32>, %arg11: tensor<f32>):
    %1480 = stablehlo.add %arg10, %arg11 : tensor<f32>
    stablehlo.return %1480 : tensor<f32>
}) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1]>, unique_indices = true} : (tensor<512x501xf32>, tensor<1xi32>, tensor<512xf32>) -> tensor<512x501xf32>
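// Backward of the log-softmax, consuming the residuals saved by the forward
// call above.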
%1179 = call @jit__train_step_kernel$log_softmax_6(%1160#1, %1160#2, %1178) : (tensor<512x1xf32>, tensor<512x501xf32>, tensor<512x501xf32>) -> tensor<512x501xf32>
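For orientation, here is a minimal JAX sketch of the forward computation this excerpt appears to implement. Every Python name below (table, offsets, w, pos, ids, mult) is invented for illustration, and the structure is inferred from the IR rather than taken from the original program; shapes and constants are read off the tensor types above.

import jax
import jax.numpy as jnp

def loss_fn(table, offsets, w, pos, ids, mult):
    # Hypothetical reconstruction -- names are guesses, shapes come from the IR.
    # table:   (1_000_000,) f32  embedding values (%arg2)
    # offsets: (256,) i32        per-slot offsets (%arg9)
    # w:       (512, 256) f32    projection weights (%54)
    # pos:     (512,) f32        positive-class scores (%1107)
    # ids:     (500, 256) i32    raw indices (%1136); mult: i32 scalar (%1139)
    h = (ids * mult) % 999993                   # %1140..%1143
    h = h + offsets[None, :]                    # %1144..%1146
    h = jnp.where(h < 0, h + 1_000_000, h)      # %1147..%1153
    emb = table[h]                              # gather -> (500, 256) f32
    scores = w @ emb.T                          # (512, 256) @ (256, 500)
    logits = jnp.concatenate([pos[:, None], scores], axis=1)  # (512, 501)
    logp = jax.nn.log_softmax(logits, axis=-1)  # %1160
    return -logp[:, 0].mean()                   # %1163..%1168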