Last active
August 1, 2023 01:46
-
-
Save AmosLewis/197d7b0e512db98d44715f269ad30069 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
spec_scatter.txt | |
size(inputs) = size(updates) = size(results) = 1 = N | |
input = inputs[0] | |
update = updates[0] | |
result = results[0] | |
// %input: | |
// [ | |
// [[1, 2], [3, 4], [5, 6], [7, 8]], | |
// [[9, 10], [11, 12], [13, 14], [15, 16]], | |
// [[17, 18], [19, 20], [21, 22], [23, 24]] | |
// ] | |
// [3, 4, 2] | |
// %update: | |
// [ | |
// [ [[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]] ], | |
// [ [[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]] ] | |
// ] | |
// shape of update=updates[0] is [2,3,2,2] | |
// %scatter_indices: | |
// [ | |
// [ [0, 2], [1, 0], [2, 1] ], | |
// [ [0, 1], [1, 0], [0, 2] ] | |
// ] | |
// [2, 3, 2] | |
index space of updates[0] has 24 (=2x3x2x2) entries. | |
For example, (0,0,0,0), (0,0,0,1), (0,0,1,0), (0,0,1,1), (0,1,0,0), (0,1,0,1), (0,1,1,0), (0,1,1,1), (0,2,0,0), ..., (1,2,1,0), (1,2,1,1) | |
Index in update: U | |
U1,U2....,U23,U24 | |
Index in result: I | |
I1,I2...,I23,I24 | |
Then the updated value of results[0][I] is | |
update_computation(results[0][I] /*old value*/, updates[0][U] /*new value*/). | |
// %scatter_indices: [[[0, 2], [1, 0], [2, 1]], [[0, 1], [1, 0], [0, 2]]] | |
// [2, 3, 2] | |
U -> I | |
update_index = U23 = (1,2,1,0) | |
update_scatter_dims = [0,1] | |
update_scatter_index = [1,2] | |
start_index = [0,2] | |
full_start_index = [2, 0, 0] | |
update_window_index = [1,0] | |
full_window_index = [0,1,0]. | |
result_index = [2, 0, 0] + [0, 1, 0] = [2, 1, 0] | |
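results[0][2,1,0] = update_computation(results[0][2,1,0] /*old value*/, updates[0][1,2,1,0]) = add(19, 1) = 20 | |
(19 is the input value at [2,1,0]. Update (0,0,1,0) also maps to [2,1,0], so after both updates are applied the final value there is 21.) | |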
STEP BY STEP DERIVED: | |
More formally, for all update_index(U) in index_space(updates[0]): | |
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]. | |
GIVEN: | |
// %update: [ | |
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]], | |
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]] | |
// ] | |
// shape of update=updates[0] is [2,3,2,2] | |
-> | |
d = axes(updates[0]) = axes(update) :[0,1,2,3] | |
GIVEN update_window_dims: [2,3] | |
-> | |
update_scatter_dims = [0,1] | |
NOTE: Why do we need the update_window_dims? | |
The update_window_dims say that in %update only dims 2 and 3 hold the window values that are actually written into the operand. The other dims (the scatter dims) are only used to select which index vector of scatter_indices applies. The scatter dims can be rearranged, but dims 2 and 3 must stay intact. For example, the shape of %update is [2,3,2,2]. If we reshape it to [1,6,2,2] and reshape scatter_indices from [2,3,2] to [1,6,2] to match, it should theoretically still work. update_window_dims is still [2,3] and update_scatter_dims is still [0,1]. | |
update_scatter_index = update_index[update_scatter_dims...]. | |
If choose update_index = U23 | |
= (1,2,1,0) | |
-> | |
update_scatter_index = update_index[update_scatter_dims] | |
= update_index[not_update_window_dims] | |
= U23[[0,1]] | |
= [1,2] | |
start_index is defined as: | |
scatter_indices[si0, ..., :, ..., siN] where si are individual elements in update_scatter_index and : is inserted at the index_vector_dim index, if index_vector_dim < rank(scatter_indices). | |
GIVEN | |
update_scatter_index = U23[[0,1]] = [1,2] | |
-> | |
si0 = 1, si1 = 2 | |
GIVEN | |
scatter_indices: [ | |
[[0, 2], [1, 0], [2, 1]], | |
[[0, 1], [1, 0], [0, 2]] | |
] // [2, 3, 2] | |
-> | |
rank(scatter_indices) = 3 | |
GIVEN index_vector_dim = 2 | |
-> | |
index_vector_dim < rank(scatter_indices) | |
2 < 3 | |
-> | |
start_index = scatter_indices[si0,si1,:] | |
= scatter_indices[update_scatter_index..., :] | |
= scatter_indices[1,2,:] | |
= [0, 2] | |
: means take every element along dim 2, i.e. the whole index vector at scatter_indices[1,2] | |
[scatter_indices[update_scatter_index]] otherwise. | |
-> SKIP | |
NOTE: index_vector_dim = 2 means that dim 2 of scatter_indices holds the index vectors: each start index, such as [0,2], is read along dim 2 of scatter_indices. | |
// %input: [ | |
// [[1, 2], [3, 4], [5, 6], [7, 8]], | |
// [[9, 10], [11, 12], [13, 14], [15, 16]], | |
// [[17, 18], [19, 20], [21, 22], [23, 24]] | |
// ] | |
// [3, 4, 2] | |
For d_input in axes(inputs[0]): | |
GIVEN input = inputs[0], | |
-> | |
input.shape = [3,4,2] | |
-> | |
d_input = 0, 1, 2 | |
full_start_index[d_input] = start_index[d_start] if d_input = scatter_dims_to_operand_dims[d_start]. | |
full_start_index[d_input] = 0 otherwise. | |
GIVEN | |
scatter_dims_to_operand_dims = [1, 0] | |
start_index = [0,2] | |
-> | |
If d_input = 0 | |
full_start_index[d_input] = full_start_index[0] = start_index[d_start] | |
If d_start = 0, SKIP | |
-> scatter_dims_to_operand_dims[d_start] | |
= scatter_dims_to_operand_dims[0] | |
= 1 | |
-> d_input != scatter_dims_to_operand_dims[d_start] SKIP | |
If d_start = 1, SUCCESS | |
-> scatter_dims_to_operand_dims[d_start] | |
= scatter_dims_to_operand_dims[1] | |
= 0, | |
-> d_input == scatter_dims_to_operand_dims[d_start] = 0 | |
-> start_index[d_start] | |
= start_index[1] | |
= 2 | |
-> full_start_index[d_input] = start_index[d_start] | |
-> full_start_index[0] = start_index[1] = 2 | |
If d_input = 1 | |
full_start_index[d_input] = full_start_index[1] = start_index[d_start] | |
If d_start = 0, SUCCESS | |
-> scatter_dims_to_operand_dims[d_start] = scatter_dims_to_operand_dims[0]= 1 | |
-> d_input == scatter_dims_to_operand_dims[d_start] = 1 | |
-> full_start_index[d_input] = start_index[d_start] | |
-> full_start_index[1] = start_index[0] = 0 | |
If d_start = 1, SKIP | |
-> scatter_dims_to_operand_dims[d_start] = scatter_dims_to_operand_dims[1]= 0, | |
-> d_input != scatter_dims_to_operand_dims[d_start] SKIP | |
If d_input = 2 | |
full_start_index[d_input] = full_start_index[2] = start_index[d_start] (if some d_start matches) | |
If d_start = 0, SKIP | |
-> scatter_dims_to_operand_dims[d_start] = scatter_dims_to_operand_dims[0]= 1 | |
-> d_input != scatter_dims_to_operand_dims[d_start] SKIP | |
If d_start = 1, SKIP | |
-> scatter_dims_to_operand_dims[d_start] = scatter_dims_to_operand_dims[1]= 0, | |
-> d_input != scatter_dims_to_operand_dims[d_start] SKIP | |
No d_start maps to d_input = 2, so the "otherwise" rule applies: | |
full_start_index[d_input] = full_start_index[2] = 0. | |
SO full_start_index = [2, 0, 0] | |
NOTE: scatter_dims_to_operand_dims is needed because each component of an index vector in scatter_indices may address a different dimension of the operand. For the index vector [0,2], since scatter_dims_to_operand_dims=[1,0], the value 2 selects the position along the input's dim 0 (and the value 0 selects the position along dim 1). | |
If we change scatter_dims_to_operand_dims to [0,1], the value 2 instead selects the position along the input's dim 1, and full_start_index becomes [0,2,0]. | |
scatter_dims_to_operand_dims=[0,1] is the more intuitive setting in most cases. The example intentionally flips it to show that the mapping from index components to operand dims is configurable. | |
update_window_index = update_index[update_window_dims...]. | |
GIVEN | |
update_window_dims: [2,3] | |
update_index = U23 = (1,2,1,0) | |
-> | |
update_window_index = [1,0] | |
full_window_index = [wi0, ..., 0, ..., wiN] where wi are individual elements in update_window_index, and 0 is inserted at indices from inserted_window_dims. | |
GIVEN | |
update_window_index = [1,0] | |
-> | |
wi0 = 1, wi1 = 0 | |
GIVEN | |
inserted_window_dims = [0] | |
-> | |
full_window_index = [0, 1, 0] | |
result_index = full_start_index + full_window_index. | |
GIVEN | |
full_start_index = [2, 0, 0] | |
full_window_index = [0, 1, 0] | |
-> | |
result_index = [2, 0, 0] + [0, 1, 0] = [2, 1, 0] | |
Given that, results = exec(schedule, inputs), where: | |
schedule is an implementation-defined permutation of index_space(updates[0]). | |
exec([update_index, ...], results) = exec([...], updated_results) where: | |
If result_index is in bounds for shape(results...) | |
updated_values = update_computation(results...[result_index], updates...[update_index]). | |
updated_results is a copy of results with results...[result_index] set to updated_values.... | |
Otherwise | |
updated_results = results. | |
exec([], results) = results. | |
results[0][2,1,0] | |
= updated_value | |
= update_computation(results...[result_index], updates...[update_index]) | |
= update_computation(results[0][result_index], updates[0][U23]) | |
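= update_computation(results[0][2,1,0] /*old value*/, updates[0][1,2,1,0]) = add(19, 1) = 20 | |
(This assumes [2,1,0] still holds its input value 19. Update (0,0,1,0) also maps to [2,1,0], so once both updates are applied the final value there is 21.) | |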
// Torch-dialect test input: calls aten._index_put_impl on %input with the
// index list [None, %index] and the values %fillValues.
func.func @torch.aten._index_put_impl(%input: !torch.vtensor<[1,4],si64>, %index: !torch.vtensor<[3],si64>, %fillValues: !torch.vtensor<[3,1],si64>) -> !torch.vtensor<[1,4],si64> {
  // A single `false` is shared by both trailing bool operands of
  // _index_put_impl (presumably accumulate and unsafe, per the aten signature).
  %c_false = torch.constant.bool false
  %c_none = torch.constant.none
  // Index list: None leaves dim 0 unindexed; %index indexes dim 1.
  %index_list = torch.prim.ListConstruct %c_none, %index : (!torch.none, !torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
  %put = torch.aten._index_put_impl %input, %index_list, %fillValues, %c_false, %c_false : !torch.vtensor<[1,4],si64>, !torch.list<optional<vtensor>>, !torch.vtensor<[3,1],si64>, !torch.bool, !torch.bool -> !torch.vtensor<[1,4],si64>
  return %put : !torch.vtensor<[1,4],si64>
}
(mlir_venv) ➜ torch-mlir git:(stablehlo_indexput) ✗ torch-mlir-opt --convert-torch-to-stablehlo /nodclouddata/chi/src/torch-mlir/test_indexput.mlir
scatter_indices:
%8 = stablehlo.concatenate %6, %7, dim = 1 : (tensor<3x1xi64>, tensor<3x1xi64>) -> tensor<3x2xi64>
updates:
%0 = unrealized_conversion_cast %arg2 : !torch.vtensor<[3,1],si64> to tensor<3x1xi64>
input:
%1 = unrealized_conversion_cast %arg0 : !torch.vtensor<[1,4],si64> to tensor<1x4xi64>
indexVectorDim: 1
scatterDimOperandDimMap: 0 1
insertedWindowDims: 0
updateWindowDims: 1
rank-of('scatter_indices'): 2
size-of('update_window_dims'): 1
// Output of `torch-mlir-opt --convert-torch-to-stablehlo`: the
// torch.aten._index_put_impl call is lowered to a single stablehlo.scatter.
module {
func.func @torch.aten._index_put_impl(%arg0: !torch.vtensor<[1,4],si64>, %arg1: !torch.vtensor<[3],si64>, %arg2: !torch.vtensor<[3,1],si64>) -> !torch.vtensor<[1,4],si64> {
// %0 is the scatter operand (from %arg0); %1 is the scatter updates (from %arg2).
%0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,4],si64> -> tensor<1x4xi64>
%1 = torch_c.to_builtin_tensor %arg2 : !torch.vtensor<[3,1],si64> -> tensor<3x1xi64>
// %false, %none and %2 are carried over from the torch input; %2 is not used
// by any later op.
%false = torch.constant.bool false
%none = torch.constant.none
%2 = torch.prim.ListConstruct %none, %arg1 : (!torch.none, !torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
// Build scatter_indices (3x2): column 0 is all zeros (for the None, i.e.
// unindexed, dim 0) and column 1 is the index tensor %arg1, so row i of %7 is
// [0, %arg1[i]].
%3 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[3],si64> -> tensor<3xi64>
%4 = stablehlo.constant dense<0> : tensor<3xi64>
%5 = stablehlo.reshape %4 : (tensor<3xi64>) -> tensor<3x1xi64>
%6 = stablehlo.reshape %3 : (tensor<3xi64>) -> tensor<3x1xi64>
%7 = stablehlo.concatenate %5, %6, dim = 1 : (tensor<3x1xi64>, tensor<3x1xi64>) -> tensor<3x2xi64>
// With index_vector_dim = 1 and scatter_dims_to_operand_dims = [0, 1], row i of
// %7 is a full (dim 0, dim 1) start index. Dim 0 of %1 is the scatter dim. Dim 1
// (update_window_dims = [1]) is a size-1 window along operand dim 1, and operand
// dim 0 is an inserted window dim. So %1[i][0] is written to %0[0, %arg1[i]].
// NOTE(review): negative %arg1 entries (valid in PyTorch) would be out of bounds
// here and skipped by scatter; confirm the lowering normalizes them.
%8 = "stablehlo.scatter"(%0, %7, %1) ({
^bb0(%arg3: tensor<i64>, %arg4: tensor<i64>):
// Return the update value (%arg4): the target element is overwritten, not accumulated.
stablehlo.return %arg4 : tensor<i64>
}) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = false} : (tensor<1x4xi64>, tensor<3x2xi64>, tensor<3x1xi64>) -> tensor<1x4xi64>
%9 = torch_c.from_builtin_tensor %8 : tensor<1x4xi64> -> !torch.vtensor<[1,4],si64>
return %9 : !torch.vtensor<[1,4],si64>
}
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
// RUN: stablehlo-translate --interpret -split-input-file %s
func.func @scatter_op_test() {
  // Operand: a single row [0, 1, 2, 3].
  %inputs = stablehlo.constant dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>
  // index_vector_dim = 1: each row of %scatter_indices is a complete index
  // vector. scatter_dims_to_operand_dims = [0, 1] maps its components to
  // operand dims 0 and 1, so row i is the start index [0, i + 1].
  %scatter_indices = stablehlo.constant dense<[[0, 1], [0, 2], [0, 3]]> : tensor<3x2xi64>
  // Dim 0 of %updates selects the index row (scatter dim). Dim 1 is a size-1
  // window along operand dim 1 (update_window_dims = [1]). Operand dim 0 is an
  // inserted window dim, so it has no counterpart in %updates.
  %updates = stablehlo.constant dense<[[4], [5], [6]]> : tensor<3x1xi64>
  %result = "stablehlo.scatter"(%inputs, %scatter_indices, %updates) ({
    // update_computation receives (current value, update value) as scalar
    // tensors. Returning the update value overwrites the target element; the
    // previous single untyped argument was not a valid update region.
    ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
      stablehlo.return %arg1 : tensor<i64>
  }) {
    scatter_dimension_numbers = #stablehlo.scatter<
      update_window_dims = [1],
      inserted_window_dims = [0],
      scatter_dims_to_operand_dims = [0, 1],
      index_vector_dim = 1>,
    indices_are_sorted = false,
    unique_indices = false
  } : (tensor<1x4xi64>, tensor<3x2xi64>, tensor<3x1xi64>) -> tensor<1x4xi64>
  // updates[i][0] lands at [0, i + 1]; element [0][0] is never targeted.
  check.expect_eq_const %result, dense<[[0, 4, 5, 6]]> : tensor<1x4xi64>
  func.return
}
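Old: [[0, 1, 2, 3]]
updates: [[4], [5], [6]]
Result: [[0, 4, 5, 6]]
(Row i of scatter_indices is the full start index [0, i+1], so updates[i][0] replaces input[0][i+1]; element [0][0] is not targeted and keeps its value 0.)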
detail explanations