// RUN: stablehlo-translate --interpret -split-input-file %s
func.func @scatter_op_test() {
%inputs = stablehlo.constant dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>
%scatter_indices = stablehlo.constant dense<[[0, 1], [0, 2], [0, 3]]> : tensor<3x2xi64>
%updates = stablehlo.constant dense<[[4], [5], [6]]> : tensor<3x1xi64>
%result = "stablehlo.scatter"(%inputs, %scatter_indices, %updates) ({
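// %arg0 is the old value in the result, %arg1 is the update value.
// Returning %arg1 makes the update computation a plain "replace".
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
stablehlo.return %arg1 : tensor<i64>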
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [1],
inserted_window_dims = [0],
scatter_dims_to_operand_dims = [0, 1],
index_vector_dim = 1>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<1x4xi64>, tensor<3x2xi64>, tensor<3x1xi64>) -> tensor<1x4xi64>
check.expect_eq_const %result, dense<[[0, 4, 5, 6]]> : tensor<1x4xi64>
func.return
}
size(inputs) = size(updates) = size(results) = 1 = N
input = inputs[0]
update = updates[0]
result = results[0]
// %input:
// [
// [0, 1, 2, 3]
// ]
// shape: [1, 4]
// %update:
// [
// [4], [5], [6]
// ]
// shape of update=updates[0] is [3,1]
// %scatter_indices:
// [
// [0, 1],
// [0, 2],
// [0, 3]
// ]
// shape: [3, 2]
The index space of updates[0] has 3 (= 3 x 1) entries:
(0,0), (1,0), (2,0)
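The index space can be enumerated mechanically. Below is a minimal Python sketch; `index_space` is a hypothetical helper written for these notes, not a StableHLO API, and it assumes row-major enumeration order.

```python
from itertools import product

def index_space(shape):
    """Return every index tuple of a tensor with the given shape (row-major order)."""
    return list(product(*(range(n) for n in shape)))

# updates[0] has shape [3, 1], so there are 3 x 1 = 3 update indices.
update_indices = index_space((3, 1))
print(update_indices)  # [(0, 0), (1, 0), (2, 0)]
```

Each of these tuples plays the role of update_index (U) in the derivation.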
Index in updates[0]: U
U1 = (0,0), U2 = (1,0), U3 = (2,0)
Index in results[0]: I
I1, I2, I3 (one result index per update index)
Then the updated value of results[0][I] is
update_computation(results[0][I] /*old value*/, updates[0][U] /*new value*/).
// %scatter_indices: [[0, 1], [0, 2], [0, 3]]
// shape: [3, 2]
U -> I
update_index = U2 = (1,0)
update_scatter_dims = [0]
update_scatter_index = [1]
start_index = [0,2]
full_start_index = [0, 2]
update_window_index = [0]
full_window_index = [0, 0]
result_index = [0, 2] + [0, 0] = [0, 2]
results[0][0, 2] = update_computation(results[0][0,2] /*old value*/, updates[0][1,0] /*new value*/) = replace(2, 5) = 5
STEP-BY-STEP DERIVATION:
More formally, for all update_index(U) in index_space(updates[0]):
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].
GIVEN:
// %update: [
// [4], [5], [6]
// ]
// shape of update=updates[0] is [3,1]
->
axes(updates[0]) = axes(update) = [0, 1], so d ranges over 0 and 1
GIVEN update_window_dims: [1]
->
update_scatter_dims = [0]
NOTE: Why do we need the update_window_dims?
The update_window_dims means that, in %update, only dim 1 holds the actual values that will be written. The other dims are only used to index into scatter_indices. The update could be rearranged into a different shape, as long as the values along the window dim are unchanged. For example, the shape of %update is [3,1], but if we reshape it to [1,3,1], theoretically it should also work (scatter_indices would then need a matching leading dim, e.g. shape [1,3,2] with index_vector_dim = 2). The update_window_dims then becomes [2], and the update_scatter_dims becomes [0,1].
update_scatter_index = update_index[update_scatter_dims...].
If we choose update_index = U2
= (1,0)
->
update_scatter_index = update_index[update_scatter_dims]
= update_index[not_update_window_dims]
= U2[[0]]
= [1]
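The two definitions above can be sketched in a few lines of Python. The helper names `update_scatter_dims_of` and `select` are made up for illustration; only the formulas come from the spec text quoted in these notes.

```python
def update_scatter_dims_of(updates_rank, update_window_dims):
    """Axes of updates[0] that are NOT window dims."""
    return [d for d in range(updates_rank) if d not in update_window_dims]

def select(index, dims):
    """index[dims...]: pick the components of an index tuple at the given axes."""
    return [index[d] for d in dims]

# The gist's example: updates[0] has shape [3, 1], update_window_dims = [1].
scatter_dims = update_scatter_dims_of(2, [1])
update_scatter_index = select((1, 0), scatter_dims)  # update_index = U2 = (1, 0)
print(scatter_dims, update_scatter_index)  # [0] [1]

# The reshaped variant from the NOTE: shape [1, 3, 1], update_window_dims = [2].
print(update_scatter_dims_of(3, [2]))  # [0, 1]
```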
IF index_vector_dim = 1 (AS GIVEN IN THE TEST)
start_index is defined as:
scatter_indices[si0, ..., :, ..., siN] where si are individual elements in update_scatter_index and : is inserted at the index_vector_dim index, if index_vector_dim < rank(scatter_indices).
GIVEN
update_scatter_index = U2[[0]] = [1]
->
si0 = 1
GIVEN
scatter_indices: [
[0, 1], [0, 2], [0, 3]
] // [3, 2]
->
rank(scatter_indices) = 2
GIVEN index_vector_dim = 1
->
index_vector_dim < rank(scatter_indices)
1 < 2
->
start_index = scatter_indices[si0,:]
= scatter_indices[update_scatter_index..., :]
= scatter_indices[1,:]
= [0, 2]
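Here ":" means take every element along dim 1 of scatter_indices, at row 1.
[scatter_indices[update_scatter_index]] otherwise.
-> SKIP (this branch does not apply, because index_vector_dim < rank(scatter_indices))
NOTE: index_vector_dim = 1 means that each index vector, such as [0, 2], lies along dim 1 of scatter_indices; the remaining dim 0 enumerates the 3 index vectors.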
IF WE INSTEAD SET index_vector_dim = 2
start_index is defined as:
scatter_indices[si0, ..., :, ..., siN] where si are individual elements in update_scatter_index and : is inserted at the index_vector_dim index, if index_vector_dim < rank(scatter_indices).
GIVEN
update_scatter_index = U2[[0]] = [1]
->
si0 = 1
GIVEN
scatter_indices: [
[0, 1], [0, 2], [0, 3]
] // [3, 2]
->
rank(scatter_indices) = 2
GIVEN index_vector_dim = 2
->
index_vector_dim < rank(scatter_indices)
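2 !< 2
[scatter_indices[update_scatter_index]] otherwise.
-> start_index = [scatter_indices[update_scatter_index]], a one-element vector holding a single scalar.
NOTE: this branch needs update_scatter_index to have rank(scatter_indices) = 2 elements, but here update_scatter_index = [1] has only one, so scatter_indices[[1]] would select the whole row [0, 2] rather than a scalar. index_vector_dim = 2 is therefore not valid for this example's shapes (updates[0] would need two scatter dims, e.g. shape [3, 2, ...], and scatter_dims_to_operand_dims would need a single entry). The test uses index_vector_dim = 1.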
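Both branches of the start_index definition can be sketched in Python over nested lists. The helpers `rank_of`, `dim_size`, `element` and `start_index_of` are hypothetical names for this note, and the second call uses a made-up rank-1 scatter_indices to exercise the "otherwise" branch.

```python
def rank_of(t):
    """Rank of a (non-ragged) nested list."""
    r = 0
    while isinstance(t, list):
        r, t = r + 1, t[0]
    return r

def dim_size(t, axis):
    """Size of the given axis of a nested list."""
    for _ in range(axis):
        t = t[0]
    return len(t)

def element(t, index):
    """t[index...] for a nested list."""
    for i in index:
        t = t[i]
    return t

def start_index_of(scatter_indices, update_scatter_index, index_vector_dim):
    si = list(update_scatter_index)
    if index_vector_dim < rank_of(scatter_indices):
        # Insert ":" at index_vector_dim, i.e. sweep k over that axis.
        n = dim_size(scatter_indices, index_vector_dim)
        return [element(scatter_indices, si[:index_vector_dim] + [k] + si[index_vector_dim:])
                for k in range(n)]
    # Otherwise: a one-element vector holding a single scalar.
    return [element(scatter_indices, si)]

print(start_index_of([[0, 1], [0, 2], [0, 3]], [1], 1))  # [0, 2]
print(start_index_of([1, 2, 3], [1], 1))                 # [2]
```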
// %input: [
// [0, 1, 2, 3]
// ]
// shape: [1, 4]
For d_input in axes(inputs[0]):
GIVEN input = inputs[0],
->
input.shape = [1,4]
->
d_input = 0, 1
full_start_index[d_input] = start_index[d_start] if d_input = scatter_dims_to_operand_dims[d_start].
full_start_index[d_input] = 0 otherwise.
GIVEN
scatter_dims_to_operand_dims = [0, 1]
start_index = [0,2]
->
If d_input = 0
full_start_index[d_input] = full_start_index[0] = start_index[d_start]
If d_start = 0, SUCCESS
-> scatter_dims_to_operand_dims[d_start]
= scatter_dims_to_operand_dims[0]
= 0
-> d_input=0 == scatter_dims_to_operand_dims[d_start]=0 SUCCESS
-> full_start_index[d_input] = start_index[d_start]
-> full_start_index[0] = start_index[0] = 0
If d_start = 1, SKIP
-> scatter_dims_to_operand_dims[d_start]
= scatter_dims_to_operand_dims[1]
= 1,
-> d_input=0 != scatter_dims_to_operand_dims[d_start]=1 SKIP
If d_input = 1
full_start_index[d_input] = full_start_index[1] = start_index[d_start]
If d_start = 0, SKIP
-> scatter_dims_to_operand_dims[d_start] = scatter_dims_to_operand_dims[0]= 0
-> d_input=1 != scatter_dims_to_operand_dims[d_start]=0 SKIP
If d_start = 1, SUCCESS
-> scatter_dims_to_operand_dims[d_start] = scatter_dims_to_operand_dims[1]= 1,
-> d_input=1 == scatter_dims_to_operand_dims[d_start]=1
-> full_start_index[d_input] = start_index[d_start]
-> full_start_index[1] = start_index[1] = 2
SO full_start_index = [0, 2]
NOTE: scatter_dims_to_operand_dims is needed because the values in a scatter_indices index vector may refer to different dims of the operand. Take the index vector [0, 2]: since scatter_dims_to_operand_dims = [0, 1], the value 2 is the index to select along the input's dim 1.
If we change scatter_dims_to_operand_dims to [1, 0], the value 2 becomes the index to select along the input's dim 0, and full_start_index = [2, 0]. For this 1x4 input that index is out of bounds, so the update would be skipped.
scatter_dims_to_operand_dims = [0, 1] is the more intuitive choice in most cases.
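The d_input / d_start case analysis above collapses into one loop. A minimal Python sketch, where `full_start_index_of` is a hypothetical helper name:

```python
def full_start_index_of(start_index, scatter_dims_to_operand_dims, input_rank):
    """Place start_index[d_start] at operand axis scatter_dims_to_operand_dims[d_start]; 0 elsewhere."""
    full = [0] * input_rank
    for d_start, d_input in enumerate(scatter_dims_to_operand_dims):
        full[d_input] = start_index[d_start]
    return full

# The gist's configuration.
print(full_start_index_of([0, 2], [0, 1], 2))  # [0, 2]
# The swapped mapping discussed in the NOTE.
print(full_start_index_of([0, 2], [1, 0], 2))  # [2, 0]
```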
update_window_index = update_index[update_window_dims...].
GIVEN
update_window_dims: [1]
update_index = U2 = (1,0)
->
update_window_index = [0]
full_window_index = [wi0, ..., 0, ..., wiN] where wi are individual elements in update_window_index, and 0 is inserted at indices from inserted_window_dims.
GIVEN
update_window_index = [0]
->
wi0 = 0
GIVEN
inserted_window_dims = [0]
->
full_window_index = [0, 0]
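A short Python sketch of the same construction: walk the operand axes, emit 0 at each inserted window dim, and otherwise consume the next element of update_window_index. `full_window_index_of` is a hypothetical helper, and the sketch ignores batching dims, which this example does not use.

```python
def full_window_index_of(update_window_index, inserted_window_dims, input_rank):
    """Expand update_window_index to operand rank, with 0 at inserted_window_dims."""
    wi = iter(update_window_index)
    return [0 if d in inserted_window_dims else next(wi) for d in range(input_rank)]

# The gist's configuration: update_window_index = [0], inserted_window_dims = [0].
print(full_window_index_of([0], [0], 2))  # [0, 0]
# A made-up window offset of 3, to show where wi0 lands.
print(full_window_index_of([3], [0], 2))  # [0, 3]
```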
result_index = full_start_index + full_window_index.
GIVEN
full_start_index = [0, 2]
full_window_index = [0, 0]
->
result_index = [0, 2] + [0, 0] = [0, 2]
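The addition is element-wise, and the resulting index only takes effect if it is in bounds for the result shape. A minimal Python sketch with hypothetical helpers `add_indices` and `in_bounds`:

```python
def add_indices(a, b):
    """Element-wise sum of two index vectors of equal length."""
    return [x + y for x, y in zip(a, b)]

def in_bounds(index, shape):
    """True if every component of index is a valid position for shape."""
    return all(0 <= i < n for i, n in zip(index, shape))

result_index = add_indices([0, 2], [0, 0])
print(result_index, in_bounds(result_index, (1, 4)))  # [0, 2] True
# The swapped-mapping case gives [2, 0], which falls outside a 1x4 result.
print(in_bounds(add_indices([2, 0], [0, 0]), (1, 4)))  # False
```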
Given that, results = exec(schedule, inputs), where:
schedule is an implementation-defined permutation of index_space(updates[0]).
exec([update_index, ...], results) = exec([...], updated_results) where:
If result_index is in bounds for shape(results...)
updated_values = update_computation(results...[result_index], updates...[update_index]).
updated_results is a copy of results with results...[result_index] set to updated_values....
Otherwise
updated_results = results.
exec([], results) = results.
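Because the schedule is an implementation-defined permutation, the outcome can depend on the order when several updates land on the same result index. The sketch below simplifies results to a flat 1-D list and assumes the (result_index, update_value) pairs are already resolved; `exec_schedule` and the sample data are made up for illustration.

```python
from itertools import permutations

def exec_schedule(schedule, results, update_computation):
    """Apply resolved (result_index, update_value) pairs in order; skip out-of-bounds ones."""
    results = list(results)  # work on a copy, as in the spec's updated_results
    for result_index, update_value in schedule:
        if 0 <= result_index < len(results):
            results[result_index] = update_computation(results[result_index], update_value)
    return results

# Two updates hit index 1; the third is out of bounds and is ignored.
pending = [(1, 10), (1, 20), (7, 99)]
replace_outcomes = {tuple(exec_schedule(p, [0, 0, 0], lambda old, new: new))
                    for p in permutations(pending)}
add_outcomes = {tuple(exec_schedule(p, [0, 0, 0], lambda old, new: old + new))
                for p in permutations(pending)}
print(sorted(replace_outcomes))  # [(0, 10, 0), (0, 20, 0)]
print(sorted(add_outcomes))      # [(0, 30, 0)]
```

With a replace computation the last writer wins, so two outcomes are possible; with addition every order gives the same result. In the gist's test all three result indices are distinct, so the order does not matter.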
results[0][0, 2]
= updated_value
= update_computation(results...[result_index], updates...[update_index])
= update_computation(results[0][result_index], updates[0][U2])
= update_computation(results[0][0,2] /*old value*/, updates[0][1,0] /*new value*/) = replace(2, 5) = 5
Old:     [[0, 1, 2, 3]]
Updates: [[4], [5], [6]]
Result:  [[0, 4, 5, 6]]
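Putting the steps together, here is a minimal reference sketch in Python over nested lists, for a single input/updates pair and without batching dims. It follows the formulas walked through in these notes and is not the StableHLO interpreter; the function and helper names are hypothetical.

```python
from copy import deepcopy
from itertools import product

def shape_of(t):
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0]
    return shape

def get(t, index):
    for i in index:
        t = t[i]
    return t

def put(t, index, value):
    for i in index[:-1]:
        t = t[i]
    t[index[-1]] = value

def scatter(operand, scatter_indices, updates, update_window_dims, inserted_window_dims,
            scatter_dims_to_operand_dims, index_vector_dim, update_computation):
    result = deepcopy(operand)
    operand_shape, indices_shape = shape_of(operand), shape_of(scatter_indices)
    updates_shape = shape_of(updates)
    scatter_dims = [d for d in range(len(updates_shape)) if d not in update_window_dims]
    for update_index in product(*(range(n) for n in updates_shape)):
        si = [update_index[d] for d in scatter_dims]
        if index_vector_dim < len(indices_shape):
            start = [get(scatter_indices, si[:index_vector_dim] + [k] + si[index_vector_dim:])
                     for k in range(indices_shape[index_vector_dim])]
        else:
            start = [get(scatter_indices, si)]
        full_start = [0] * len(operand_shape)
        for d_start, d_input in enumerate(scatter_dims_to_operand_dims):
            full_start[d_input] = start[d_start]
        wi = iter(update_index[d] for d in update_window_dims)
        full_window = [0 if d in inserted_window_dims else next(wi)
                       for d in range(len(operand_shape))]
        result_index = [s + w for s, w in zip(full_start, full_window)]
        # Out-of-bounds result indices are skipped, as in exec() above.
        if all(0 <= i < n for i, n in zip(result_index, operand_shape)):
            put(result, result_index,
                update_computation(get(result, result_index), get(updates, update_index)))
    return result

result = scatter([[0, 1, 2, 3]], [[0, 1], [0, 2], [0, 3]], [[4], [5], [6]],
                 update_window_dims=[1], inserted_window_dims=[0],
                 scatter_dims_to_operand_dims=[0, 1], index_vector_dim=1,
                 update_computation=lambda old, new: new)
print(result)  # [[0, 4, 5, 6]]
```

Swapping the lambda for `lambda old, new: old + new` turns the same call into an accumulating scatter.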