Last active
July 31, 2023 22:52
-
-
Save AmosLewis/1ef4e7c549d5a3d3906e6b6f128d4f36 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// RUN: stablehlo-translate --interpret -split-input-file %s

// Scatters three scalar updates into a 1x4 operand and checks the result.
// With index_vector_dim = 1, each row of %scatter_indices is a complete
// [dim0, dim1] index into %inputs, so rows 0..2 of %updates are written to
// positions [0, 1], [0, 2] and [0, 3] respectively.
func.func @scatter_op_test() {
  %inputs = stablehlo.constant dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>
  %scatter_indices = stablehlo.constant dense<[[0, 1], [0, 2], [0, 3]]> : tensor<3x2xi64>
  %updates = stablehlo.constant dense<[[4], [5], [6]]> : tensor<3x1xi64>
  %result = "stablehlo.scatter"(%inputs, %scatter_indices, %updates) ({
    // The update computation takes two scalars per input: %arg0 is the value
    // currently stored in the result, %arg1 is the incoming update. Returning
    // %arg1 overwrites the old value, which is what the expectation below
    // requires.
    ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
      stablehlo.return %arg1 : tensor<i64>
  }) {
    scatter_dimension_numbers = #stablehlo.scatter<
      update_window_dims = [1],
      inserted_window_dims = [0],
      scatter_dims_to_operand_dims = [0, 1],
      index_vector_dim = 1>,
    indices_are_sorted = false,
    unique_indices = false
  } : (tensor<1x4xi64>, tensor<3x2xi64>, tensor<3x1xi64>) -> tensor<1x4xi64>
  // Element [0, 0] is never addressed by %scatter_indices and keeps its value.
  check.expect_eq_const %result, dense<[[0, 4, 5, 6]]> : tensor<1x4xi64>
  func.return
}
size(inputs) = size(updates) = size(results) = 1 = N | |
input = inputs[0] | |
update = updates[0] | |
result = results[0] | |
// %input: | |
// [ | |
// [0, 1, 2, 3] | |
// ] | |
// [1, 4] | |
// %update: | |
// [ | |
// [4], [5], [6]
// ] | |
// shape of update=updates[0] is [3,1] | |
// %scatter_indices: | |
// [ | |
// [0, 1], | |
// [0, 2], | |
// [0, 3] | |
// ] | |
// [3, 2] | |
index space of updates[0] has 3 (=3x1) entries. | |
For example, (0,0), (1,0), (2,0) | |
Index in update: U | |
U1,U2....,U3 | |
Index in result: I | |
I1,I2...,I3 | |
Then the updated value of results[0][I] is | |
update_computation(results[0][I] /old value/, updates[0][U]/new value/). | |
// %scatter_indices: [[0, 1], [0, 2], [0, 3]] | |
// [3, 2] | |
U -> I | |
update_index = U2 = (1,0) | |
update_scatter_dims = [0] | |
update_scatter_index = [1] | |
start_index = [0,2] | |
full_start_index = [0, 2] | |
update_window_index = [0] | |
full_window_index = [0, 0]
result_index = [0, 2] + [0, 0] = [0, 2] | |
results[0][0, 2] = update_computation(results[0][0,2] /*old value*/, updates[0][1,0]) = replace(2, 5) = 5
STEP BY STEP DERIVED: | |
More formally, for all update_index(U) in index_space(updates[0]): | |
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]. | |
GIVEN: | |
// %update: [ | |
// [4], [5], [6]
// ] | |
// shape of update=updates[0] is [3,1] | |
-> | |
d = axes(updates[0]) = axes(update) :[0,1] | |
GIVEN update_window_dims: [1] | |
-> | |
update_scatter_dims = [0] | |
NOTE: Why do we need the update_window_dims? | |
The update_window_dims means that, in %update, only dim 1 holds the actual values (the window) that will be written. The other dims are only used to look up the scatter index. The update could be rearranged across different dimensions, but the contents of the window dim should not change. For example, the shape of %update is [3,1]; if we reshape it to [1,3,1], theoretically it should also work (provided scatter_indices is reshaped to match). In that case update_window_dims becomes [2] and update_scatter_dims becomes [0,1].
update_scatter_index = update_index[update_scatter_dims...]. | |
If choose update_index = U2 | |
= (1,0) | |
-> | |
update_scatter_index = update_index[update_scatter_dims] | |
= update_index[not_update_window_dims] | |
= U2[[0]] | |
= [1] | |
IF WE SET GIVEN index_vector_dim = 1 | |
start_index is defined as: | |
scatter_indices[si0, ..., :, ..., siN] where si are individual elements in update_scatter_index and : is inserted at the index_vector_dim index, if index_vector_dim < rank(scatter_indices). | |
GIVEN | |
update_scatter_index = U2[[0]] = [1] | |
-> | |
si0 = 1 | |
GIVEN | |
scatter_indices: [ | |
[0, 1], [0, 2], [0, 3] | |
] // [3, 2] | |
-> | |
rank(scatter_indices) = 2 | |
GIVEN index_vector_dim = 1 ??? | |
-> | |
index_vector_dim < rank(scatter_indices) | |
1 < 2 | |
-> | |
start_index = scatter_indices[si0,:] | |
= scatter_indices(update_scatter_index...,:) | |
= scatter_indices[1,:] | |
= [0, 2] | |
: means take everything in dim==1 in scatter_indices[1] | |
[scatter_indices[update_scatter_index]] otherwise. | |
-> SKIP | |
NOTE: index_vector_dim = 1 means that dimension 1 of scatter_indices holds the components of each index vector (e.g. [0, 2]); the remaining dimension(s) of scatter_indices enumerate the scatter positions.
IF WE SET GIVEN index_vector_dim = 2 | |
start_index is defined as: | |
scatter_indices[si0, ..., :, ..., siN] where si are individual elements in update_scatter_index and : is inserted at the index_vector_dim index, if index_vector_dim < rank(scatter_indices). | |
GIVEN | |
update_scatter_index = U2[[0]] = [1] | |
-> | |
si0 = 1 | |
GIVEN | |
scatter_indices: [ | |
[0, 1], [0, 2], [0, 3] | |
] // [3, 2] | |
-> | |
rank(scatter_indices) = 2 | |
GIVEN index_vector_dim = 2 | |
-> | |
index_vector_dim < rank(scatter_indices) | |
2 !< 2 | |
[scatter_indices[update_scatter_index]] otherwise. | |
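-> start_index = [scatter_indices[update_scatter_index]], i.e. a one-element vector. NOTE: this case does not fit the example as written: with index_vector_dim = 2 = rank(scatter_indices), update_scatter_index must have 2 elements (one per dimension of scatter_indices). For example, update_scatter_index = [1, 1] gives start_index = [scatter_indices[1, 1]] = [2], not [0,2].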
// %input: [ | |
// [0, 1, 2, 3]
// ] | |
// [1, 4] | |
For d_input in axes(inputs[0]): | |
GIVEN input = inputs[0], | |
-> | |
input.shape = [1,4] | |
-> | |
d_input = 0, 1 | |
full_start_index[d_input] = start_index[d_start] if d_input = scatter_dims_to_operand_dims[d_start]. | |
full_start_index[d_input] = 0 otherwise. | |
GIVEN | |
scatter_dims_to_operand_dims = [0, 1] | |
start_index = [0,2] | |
-> | |
If d_input = 0 | |
full_start_index[d_input] = full_start_index[0] = start_index[d_start] | |
If d_start = 0, SUCCESS | |
-> scatter_dims_to_operand_dims[d_start] | |
= scatter_dims_to_operand_dims[0] | |
= 0 | |
-> d_input=0 == scatter_dims_to_operand_dims[d_start]=0 SUCCESS | |
-> full_start_index[d_input] = start_index[d_start] | |
-> full_start_index[0] = start_index[0] = 0 | |
If d_start = 1, SKIP | |
-> scatter_dims_to_operand_dims[d_start] | |
= scatter_dims_to_operand_dims[1] | |
= 1, | |
-> d_input=0 != scatter_dims_to_operand_dims[d_start]=1 SKIP | |
If d_input = 1 | |
full_start_index[d_input] = full_start_index[1] = start_index[d_start]
If d_start = 0, SKIP | |
-> scatter_dims_to_operand_dims[d_start] = scatter_dims_to_operand_dims[0]= 0 | |
-> d_input=1 != scatter_dims_to_operand_dims[d_start]=0 SKIP
If d_start = 1, SUCCESS | |
-> scatter_dims_to_operand_dims[d_start] = scatter_dims_to_operand_dims[1]= 1, | |
-> d_input=1 == scatter_dims_to_operand_dims[d_start]=1 | |
-> full_start_index[d_input] = start_index[d_start] | |
-> full_start_index[1] = start_index[1] = 2 | |
SO full_start_index = [0, 2] | |
NOTE: scatter_dims_to_operand_dims is needed because each component of an index vector in scatter_indices may refer to a different dim of the operand. Take the value [0,2] in scatter_indices: since scatter_dims_to_operand_dims=[0,1], the index 2 selects a position along the input's dim 1.
If we change scatter_dims_to_operand_dims to [1,0], the index 2 instead selects a position along the input's dim 0. The resulting full_start_index = [2,0].
The scatter_dims_to_operand_dims=[0,1] is more intuitive in most cases. | |
update_window_index = update_index[update_window_dims...]. | |
GIVEN | |
update_window_dims: [1] | |
update_index = U2 = (1,0) | |
-> | |
update_window_index = [0] | |
full_window_index = [wi0, ..., 0, ..., wiN] where wi are individual elements in update_window_index, and 0 is inserted at indices from inserted_window_dims. | |
GIVEN | |
update_window_index = [0] | |
-> | |
wi0 = 0 | |
GIVEN | |
inserted_window_dims = [0] | |
-> | |
full_window_index = [0, 0] | |
result_index = full_start_index + full_window_index. | |
GIVEN | |
full_start_index = [0, 2] | |
full_window_index = [0, 0] | |
-> | |
result_index = [0, 2] + [0, 0] = [0, 2] | |
Given that, results = exec(schedule, inputs), where: | |
schedule is an implementation-defined permutation of index_space(updates[0]). | |
exec([update_index, ...], results) = exec([...], updated_results) where: | |
If result_index is in bounds for shape(results...) | |
updated_values = update_computation(results...[result_index], updates...[update_index]). | |
updated_results is a copy of results with results...[result_index] set to updated_values.... | |
Otherwise | |
updated_results = results. | |
exec([], results) = results. | |
results[0][0, 2] | |
= updated_value | |
= update_computation(results...[result_index], updates...[update_index]) | |
= update_computation(results[0][result_index], updates[0][U2]) | |
= update_computation(results[0][0,2] /*old value*/, updates[0][1,0]) = replace(2, 5) = 5
Old: [[0, 1, 2, 3]] | |
updates:[[4], [5], [6]] | |
Result, [[0, 4, 5, 6]] | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment