spec_scatter.txt
size(inputs) = size(updates) = size(results) = 1 = N
input = inputs[0]
update = updates[0]
result = results[0]
// %input:
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10], [11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// [3, 4, 2]
// %update:
// [
// [ [[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]] ],
// [ [[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]] ]
// ]
// shape of update=updates[0] is [2,3,2,2]
// %scatter_indices:
// [
// [ [0, 2], [1, 0], [2, 1] ],
// [ [0, 1], [1, 0], [0, 2] ]
// ]
// [2, 3, 2]
The index space of updates[0] has 24 (= 2x3x2x2) entries.
For example, (0,0,0,0), (0,0,0,1), (0,0,1,0), (0,0,1,1), (0,1,0,0), (0,1,0,1), (0,1,1,0), (0,1,1,1), (0,2,0,0), ..., (1,2,1,0), (1,2,1,1)
Index in update: U (U1, U2, ..., U23, U24)
Index in result: I (I1, I2, ..., I23, I24)
Then the updated value of results[0][I] is
update_computation(results[0][I] /*old value*/, updates[0][U] /*new value*/).
// %scatter_indices: [[[0, 2], [1, 0], [2, 1]], [[0, 1], [1, 0], [0, 2]]]
// [2, 3, 2]
U -> I
update_index = U23 = (1,2,1,0)
update_scatter_dims = [0,1]
update_scatter_index = [1,2]
start_index = [0,2]
full_start_index = [2, 0, 0]
update_window_index = [1,0]
full_window_index = [0,1,0].
result_index = [2, 0, 0] + [0, 1, 0] = [2, 1, 0]
results[0][2,1,0] = update_computation(results[0][2,1,0] /*old value*/, updates[0][1,2,1,0] /*new value*/) = add(19, 1) = 20
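A minimal Python sketch of the index arithmetic summarized above, hard-coded to this example's attributes (update_window_dims = [2,3], inserted_window_dims = [0], scatter_dims_to_operand_dims = [1,0], index_vector_dim = 2); it is an illustration of the derivation, not the StableHLO implementation:

scatter_indices = [[[0, 2], [1, 0], [2, 1]],
                   [[0, 1], [1, 0], [0, 2]]]         # shape [2, 3, 2]
update_window_dims = [2, 3]
inserted_window_dims = [0]
scatter_dims_to_operand_dims = [1, 0]

update_index = (1, 2, 1, 0)                                                       # U23
update_scatter_dims = [d for d in range(4) if d not in update_window_dims]        # [0, 1]
update_scatter_index = [update_index[d] for d in update_scatter_dims]             # [1, 2]
# index_vector_dim = 2 < rank(scatter_indices) = 3, so take the whole index vector:
start_index = scatter_indices[update_scatter_index[0]][update_scatter_index[1]]   # [0, 2]

full_start_index = [0] * 3                           # one slot per dim of inputs[0]
for d_start, d_input in enumerate(scatter_dims_to_operand_dims):
    full_start_index[d_input] = start_index[d_start]                              # [2, 0, 0]

update_window_index = [update_index[d] for d in update_window_dims]               # [1, 0]
wi = iter(update_window_index)                       # insert a 0 at each inserted window dim
full_window_index = [0 if d in inserted_window_dims else next(wi) for d in range(3)]  # [0, 1, 0]

result_index = [s + w for s, w in zip(full_start_index, full_window_index)]
print(result_index)                                  # [2, 1, 0]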
STEP-BY-STEP DERIVATION:
More formally, for all update_index(U) in index_space(updates[0]):
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].
GIVEN:
// %update: [
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
// ]
// shape of update=updates[0] is [2,3,2,2]
->
d in axes(updates[0]) = axes(update) = [0, 1, 2, 3]
GIVEN update_window_dims: [2,3]
->
update_scatter_dims = [0,1]
NOTE: Why do we need update_window_dims?
update_window_dims means that, in %update, only dims 2 and 3 hold the actual window values that get written; the other dims are only used to index those windows. The update could be rearranged along the non-window dims as long as the values along dims 2 and 3 are unchanged. For example, the shape of %update is [2,3,2,2], but if we reshaped it to [1,6,2,2] (and reshaped scatter_indices accordingly), it should, in theory, still work: update_window_dims would still be [2,3] and update_scatter_dims would still be [0,1].
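A quick Python check of this note; the [1, 6, 2, 2] shape is the hypothetical reshape mentioned above:

update_window_dims = [2, 3]
for shape in [(2, 3, 2, 2), (1, 6, 2, 2)]:   # original %update shape and the hypothetical reshape
    axes = range(len(shape))
    update_scatter_dims = [d for d in axes if d not in update_window_dims]
    print(shape, "->", update_scatter_dims)  # both print [0, 1]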
update_scatter_index = update_index[update_scatter_dims...].
If we choose update_index = U23 = (1, 2, 1, 0)
->
update_scatter_index = update_index[update_scatter_dims]
= update_index[not_update_window_dims]
= U23[[0,1]]
= [1,2]
start_index is defined as:
scatter_indices[si0, ..., :, ..., siN] where si are individual elements in update_scatter_index and : is inserted at the index_vector_dim index, if index_vector_dim < rank(scatter_indices).
GIVEN
update_scatter_index = U23[[0,1]] = [1,2]
->
si0 = 1, si1 = 2
GIVEN
scatter_indices: [
[[0, 2], [1, 0], [2, 1]],
[[0, 1], [1, 0], [0, 2]]
] // [2, 3, 2]
->
rank(scatter_indices) = 3
GIVEN index_vector_dim = 2
->
index_vector_dim < rank(scatter_indices)
2 < 3
->
start_index = scatter_indices[si0,si1,:]
= scatter_indices(update_scatter_index...,:)
= scatter_indices[1,2,:]
= [0, 2]
: means take all elements along dim 2 of scatter_indices at position [1, 2]
[scatter_indices[update_scatter_index]] otherwise.
-> SKIP (this branch does not apply here, since index_vector_dim < rank(scatter_indices))
NOTE: index_vector_dim = 2 means the index vectors in scatter_indices, such as [0, 2], lie along dim 2 of scatter_indices.
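A NumPy sketch (illustrative only) of where the ':' goes for this example's index_vector_dim = 2:

import numpy as np

scatter_indices = np.array([[[0, 2], [1, 0], [2, 1]],
                            [[0, 1], [1, 0], [0, 2]]])   # shape (2, 3, 2), rank 3
si0, si1 = 1, 2                                          # update_scatter_index = [1, 2]

# index_vector_dim = 2 < rank(scatter_indices) = 3, so ':' is inserted at position 2:
start_index = scatter_indices[si0, si1, :]
print(start_index)                                       # [0 2]
# If index_vector_dim equaled rank(scatter_indices), there would be no index-vector axis,
# and the 'otherwise' branch [scatter_indices[update_scatter_index]] would wrap the single
# selected element into a one-element start_index instead.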
// %input: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10], [11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// [3, 4, 2]
For d_input in axes(inputs[0]):
GIVEN input = inputs[0],
->
input.shape = [3,4,2]
->
d_input = 0, 1, 2
full_start_index[d_input] = start_index[d_start] if d_input = scatter_dims_to_operand_dims[d_start].
full_start_index[d_input] = 0 otherwise.
GIVEN
scatter_dims_to_operand_dims = [1, 0]
start_index = [0,2]
->
If d_input = 0
full_start_index[d_input] = full_start_index[0] = start_index[d_start]
If d_start = 0, SKIP
-> scatter_dims_to_operand_dims[d_start]
= scatter_dims_to_operand_dims[0]
= 1
-> d_input != scatter_dims_to_operand_dims[d_start] SKIP
If d_start = 1, SUCCESS
-> scatter_dims_to_operand_dims[d_start]
= scatter_dims_to_operand_dims[1]
= 0,
-> d_input == scatter_dims_to_operand_dims[d_start] = 0
-> start_index[d_start]
= start_index[1]
= 2
-> full_start_index[d_input] = start_index[d_start]
-> full_start_index[0] = start_index[1] = 2
If d_input = 1
full_start_index[d_input] = full_start_index[1] = start_index[d_start]
If d_start = 0, SUCCESS
-> scatter_dims_to_operand_dims[d_start] = scatter_dims_to_operand_dims[0]= 1
-> d_input == scatter_dims_to_operand_dims[d_start] = 1
-> full_start_index[d_input] = start_index[d_start]
-> full_start_index[1] = start_index[0] = 0
If d_start = 1, SKIP
-> scatter_dims_to_operand_dims[d_start] = scatter_dims_to_operand_dims[1]= 0,
-> d_input != scatter_dims_to_operand_dims[d_start] SKIP
If d_input = 2
full_start_index[d_input] = full_start_index[2] = start_index[d_start]
If d_start = 0, SKIP
-> scatter_dims_to_operand_dims[d_start] = scatter_dims_to_operand_dims[0]= 1
-> d_input != scatter_dims_to_operand_dims[d_start] SKIP
If d_start = 1, SKIP
-> scatter_dims_to_operand_dims[d_start] = scatter_dims_to_operand_dims[1]= 0,
-> d_input != scatter_dims_to_operand_dims[d_start] SKIP
full_start_index[d_input] = full_start_index[2] = 0.
SO full_start_index = [2, 0, 0]
NOTE: scatter_dims_to_operand_dims is needed because each component of an index vector in scatter_indices may refer to a different dim of the operand. Take the value [0, 2] in scatter_indices: since scatter_dims_to_operand_dims = [1, 0], the component 2 (at position 1) selects along the input's dim 0, giving full_start_index = [2, 0, 0].
If we change scatter_dims_to_operand_dims to [0, 1], the component 2 instead selects along the input's dim 1, and the result is full_start_index = [0, 2, 0].
scatter_dims_to_operand_dims = [0, 1] is the more intuitive setting in most cases; this example intentionally flips it to show that the mapping from index components to operand dims is flexible.
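A small Python sketch comparing the two mappings discussed in this note; build_full_start_index is just an illustrative helper, not a StableHLO API:

def build_full_start_index(start_index, scatter_dims_to_operand_dims, input_rank):
    # Place each start_index component at the operand dim it maps to; all other dims get 0.
    full = [0] * input_rank
    for d_start, d_input in enumerate(scatter_dims_to_operand_dims):
        full[d_input] = start_index[d_start]
    return full

print(build_full_start_index([0, 2], [1, 0], 3))   # [2, 0, 0]  (this example)
print(build_full_start_index([0, 2], [0, 1], 3))   # [0, 2, 0]  (the more intuitive mapping)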
update_window_index = update_index[update_window_dims...].
GIVEN
update_window_dims: [2,3]
update_index = U23 = (1,2,1,0)
->
update_window_index = [1,0]
full_window_index = [wi0, ..., 0, ..., wiN] where wi are individual elements in update_window_index, and 0 is inserted at indices from inserted_window_dims.
GIVEN
update_window_index = [1,0]
->
wi0 = 1, wi1 = 0
GIVEN
inserted_window_dims = [0]
->
full_window_index = [0, 1, 0]
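A small Python sketch of how inserted_window_dims pads the window index; build_full_window_index is an illustrative helper, and inserted_window_dims = [1] is a hypothetical variation (only [0] matches this example):

def build_full_window_index(update_window_index, inserted_window_dims, input_rank):
    # Insert a 0 at every inserted window dim; fill the remaining dims with the window index.
    wi = iter(update_window_index)
    return [0 if d in inserted_window_dims else next(wi) for d in range(input_rank)]

print(build_full_window_index([1, 0], [0], 3))   # [0, 1, 0]  (this example)
print(build_full_window_index([1, 0], [1], 3))   # [1, 0, 0]  (hypothetical inserted_window_dims = [1])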
result_index = full_start_index + full_window_index.
GIVEN
full_start_index = [2, 0, 0]
full_window_index = [0, 1, 0]
->
result_index = [2, 0, 0] + [0, 1, 0] = [2, 1, 0]
Given that, results = exec(schedule, inputs), where:
schedule is an implementation-defined permutation of index_space(updates[0]).
exec([update_index, ...], results) = exec([...], updated_results) where:
If result_index is in bounds for shape(results...)
updated_values = update_computation(results...[result_index], updates...[update_index]).
updated_results is a copy of results with results...[result_index] set to updated_values....
Otherwise
updated_results = results.
exec([], results) = results.
results[0][2,1,0]
= updated_value
= update_computation(results...[result_index], updates...[update_index])
= update_computation(results[0][result_index], updates[0][U23])
= update_computation(results[0][2,1,0] /*old value*/, updates[0][1,2,1,0] /*new value*/) = add(19, 1) = 20
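A Python sketch of a single exec step from the definition above, applied to the U23 step just computed (update_computation = add, and updates[0][1,2,1,0] = 1); exec_step is an illustrative helper, not the StableHLO interpreter:

inputs0 = [[[1, 2], [3, 4], [5, 6], [7, 8]],
           [[9, 10], [11, 12], [13, 14], [15, 16]],
           [[17, 18], [19, 20], [21, 22], [23, 24]]]           # inputs[0], shape [3, 4, 2]

def exec_step(results0, result_index, update_value, update_computation):
    # One step of exec: if result_index is in bounds, combine old and new values; otherwise no-op.
    i0, i1, i2 = result_index
    if 0 <= i0 < len(results0) and 0 <= i1 < len(results0[0]) and 0 <= i2 < len(results0[0][0]):
        results0[i0][i1][i2] = update_computation(results0[i0][i1][i2], update_value)
    return results0

results0 = [[list(pair) for pair in row] for row in inputs0]   # results[0] starts as a copy of inputs[0]
exec_step(results0, [2, 1, 0], 1, lambda old, new: old + new)  # updates[0][1, 2, 1, 0] == 1
print(results0[2][1][0])                                       # add(19, 1) = 20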
AmosLewis commented Jul 21, 2023

// RUN: stablehlo-translate --interpret -split-input-file %s

func.func @scatter_op_test() {
  %inputs = stablehlo.constant dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>
  %scatter_indices = stablehlo.constant dense<[[0, 1], [0, 2], [0, 3]]> : tensor<3x2xi64>
  %updates = stablehlo.constant dense<[[4], [5], [6]]> : tensor<3x1xi64>
  %result = "stablehlo.scatter"(%inputs, %scatter_indices, %updates) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    stablehlo.return %arg1 : tensor<i64>
  }) {
    scatter_dimension_numbers = #stablehlo.scatter<
      update_window_dims = [1],
      inserted_window_dims = [0],
      scatter_dims_to_operand_dims = [0, 1],
      index_vector_dim = 1>,
    indices_are_sorted = false,
    unique_indices = false
  } : (tensor<1x4xi64>, tensor<3x2xi64>, tensor<3x1xi64>) -> tensor<1x4xi64>
  check.expect_eq_const %result, dense<[[0, 4, 5, 6]]> : tensor<1x4xi64>
  func.return
}
Old:     [[0, 1, 2, 3]]
updates: [[4], [5], [6]]
Result:  [[0, 4, 5, 6]]
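The same index arithmetic, sketched in Python for this small test (the region returns the update value, so the update overwrites the old element); hard-coded to this test's attributes, not a general implementation:

inputs0 = [[0, 1, 2, 3]]                       # tensor<1x4xi64>
scatter_indices = [[0, 1], [0, 2], [0, 3]]     # tensor<3x2xi64>, index_vector_dim = 1
updates0 = [[4], [5], [6]]                     # tensor<3x1xi64>
# update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0, 1]

results0 = [row[:] for row in inputs0]
for u0 in range(3):                            # u0 runs over the scatter dim of updates
    for u1 in range(1):                        # u1 runs over update_window_dims = [1]
        start_index = scatter_indices[u0]      # index_vector_dim = 1 < rank(scatter_indices) = 2
        full_start_index = list(start_index)   # scatter_dims_to_operand_dims = [0, 1]: identity mapping
        full_window_index = [0, u1]            # 0 inserted at inserted_window_dims = [0]
        i0, i1 = (s + w for s, w in zip(full_start_index, full_window_index))
        results0[i0][i1] = updates0[u0][u1]    # the region returns the update value
print(results0)                                # [[0, 4, 5, 6]]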

Detailed explanations: see spec_scatter.txt above.

AmosLewis commented Aug 1, 2023

func.func @torch.aten._index_put_impl(%input: !torch.vtensor<[1,4],si64>, %index: !torch.vtensor<[3],si64>, %fillValues: !torch.vtensor<[3,1],si64>) -> !torch.vtensor<[1,4],si64>{
  %false = torch.constant.bool false
  %none = torch.constant.none
  %indices = torch.prim.ListConstruct %none, %index : (!torch.none, !torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
  %out = torch.aten._index_put_impl %input, %indices, %fillValues, %false, %false : !torch.vtensor<[1,4],si64>, !torch.list<optional<vtensor>>, !torch.vtensor<[3,1],si64>, !torch.bool, !torch.bool -> !torch.vtensor<[1,4],si64>
  return %out : !torch.vtensor<[1,4],si64>
}

(mlir_venv) ➜ torch-mlir git:(stablehlo_indexput) ✗ torch-mlir-opt --convert-torch-to-stablehlo /nodclouddata/chi/src/torch-mlir/test_indexput.mlir
scatter_indices:
%8 = stablehlo.concatenate %6, %7, dim = 1 : (tensor<3x1xi64>, tensor<3x1xi64>) -> tensor<3x2xi64>
updates:
%0 = unrealized_conversion_cast %arg2 : !torch.vtensor<[3,1],si64> to tensor<3x1xi64>
input:
%1 = unrealized_conversion_cast %arg0 : !torch.vtensor<[1,4],si64> to tensor<1x4xi64>
indexVectorDim: 1
scatterDimOperandDimMap: 0 1
insertedWindowDims: 0
updateWindowDims: 1
rank-of('scatter_indices'): 2
size-of('update_window_dims'): 1

module {
  func.func @torch.aten._index_put_impl(%arg0: !torch.vtensor<[1,4],si64>, %arg1: !torch.vtensor<[3],si64>, %arg2: !torch.vtensor<[3,1],si64>) -> !torch.vtensor<[1,4],si64> {
    %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,4],si64> -> tensor<1x4xi64>
    %1 = torch_c.to_builtin_tensor %arg2 : !torch.vtensor<[3,1],si64> -> tensor<3x1xi64>
    %false = torch.constant.bool false
    %none = torch.constant.none
    %2 = torch.prim.ListConstruct %none, %arg1 : (!torch.none, !torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
    %3 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[3],si64> -> tensor<3xi64>
    %4 = stablehlo.constant dense<0> : tensor<3xi64>
    %5 = stablehlo.reshape %4 : (tensor<3xi64>) -> tensor<3x1xi64>
    %6 = stablehlo.reshape %3 : (tensor<3xi64>) -> tensor<3x1xi64>
    %7 = stablehlo.concatenate %5, %6, dim = 1 : (tensor<3x1xi64>, tensor<3x1xi64>) -> tensor<3x2xi64>
    %8 = "stablehlo.scatter"(%0, %7, %1) ({
    ^bb0(%arg3: tensor<i64>, %arg4: tensor<i64>):
      stablehlo.return %arg4 : tensor<i64>
    }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = false} : (tensor<1x4xi64>, tensor<3x2xi64>, tensor<3x1xi64>) -> tensor<1x4xi64>
    %9 = torch_c.from_builtin_tensor %8 : tensor<1x4xi64> -> !torch.vtensor<[1,4],si64>
    return %9 : !torch.vtensor<[1,4],si64>
  }
}
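A NumPy sketch of how this lowering assembles scatter_indices (%4 through %7 above): a zero column for the dim covered by %none, concatenated along dim 1 with the user-supplied index; the index values [1, 2, 3] are just example inputs for illustration:

import numpy as np

index = np.array([1, 2, 3], dtype=np.int64)        # stands in for %arg1 : tensor<3xi64>
zeros = np.zeros(3, dtype=np.int64)                # %4: the dim-0 coordinate is always 0 (the %none entry)
scatter_indices = np.concatenate(
    [zeros.reshape(3, 1), index.reshape(3, 1)],    # %5, %6: reshape both to tensor<3x1xi64>
    axis=1)                                        # %7: concatenate along dim 1 -> tensor<3x2xi64>
print(scatter_indices)                             # rows: [0 1], [0 2], [0 3]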
