Created
April 11, 2023 13:16
-
-
Save sergei-mironov/a40d20edf2bf8f7debc046ac41a08cd9 to your computer and use it in GitHub Desktop.
Reference implementation of MHLO::Scatter in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from copy import deepcopy
from typing import Any, Callable, Dict, List

from frozendict import frozendict
Dim = int                    # Dimension "names"
Index = Dict[Dim, int]       # `Index :: Dimension -> Coordinate`
# (Intended to be a `frozendict` at runtime so that an Index is hashable
#  and can therefore be used as a key of a `Tensor`.)
Value = complex              # A value
Tensor = Dict[Index, Value]  # `Tensor :: Index -> Value`
def make_tensor(shape:List[int], lst:list) -> Tensor:
    """ Constructs a tensor.

    NOTE(review): presumably `shape` gives the size of each dimension and
    `lst` the values to place into it (element order still to be decided);
    confirm once implemented.

    Raises:
      NotImplementedError: always -- this is still a stub. Raising makes the
        gap explicit instead of silently returning `None`.
    """
    raise NotImplementedError("make_tensor is not implemented yet")
def tensor_slice(t:Tensor, i:Index) -> List[Value]:
    """ Aka `t[i[0], ... , : , ... , i[N-1]]`.

    NOTE(review): presumably the dimension that is absent from `i` is the one
    taken whole (the `:`); confirm once implemented.

    Raises:
      NotImplementedError: always -- this is still a stub. Raising makes the
        gap explicit instead of silently returning `None`.
    """
    raise NotImplementedError("tensor_slice is not implemented yet")
def tensor_update(t:Tensor, i:Index, v:List[Value]) -> Tensor:
    """ Aka `t2=copy(t); t2[i[0], ... , : , ... , i[N-1]] = v[:]`.

    Meant to return the updated copy `t2`, leaving `t` itself unchanged.

    Raises:
      NotImplementedError: always -- this is still a stub. Raising makes the
        gap explicit instead of silently returning `None`.
    """
    # The annotation previously read `List[Values]`, an undefined name that
    # raised NameError when the module was imported.
    raise NotImplementedError("tensor_update is not implemented yet")
def sctter(inputs:Tensor,
           scatter_indices:Tensor,
           updates:Tensor,
           update_computation:Callable[[List[Value],List[Value]],List[Value]],
           attrs) -> Tensor:
    """ A reference `MHLO::Scatter` implementation in simple Python.

    Ref. https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter

    Starting from `inputs`, visits every index of `updates` in a fixed
    (sorted) order. Each index is mapped to a result index via
    `_result_index`. The result slice at that index is then replaced with
    `update_computation(current_result_slice, update_slice)`.

    Args:
      inputs: the tensor the result starts from.
      scatter_indices: NOTE(review): not read yet; presumably consumed by
        `_result_index` once it is implemented.
      updates: the tensor whose indices are visited and sliced.
      update_computation: combines the current result slice with the update
        slice and returns the new result slice.
      attrs: NOTE(review): not read yet; presumably consumed by
        `_result_index` once it is implemented.

    Returns:
      The resulting tensor, threaded through successive `tensor_update` calls.
      Returns `inputs` unchanged when `updates` is empty.

    Raises:
      NotImplementedError: while `_result_index` is still a stub (as soon as
        `updates` is non-empty).
    """
    def _result_index(update_index: Index) -> Index:
        # Computes the `result_index` based on `update_index`,
        # `scatter_indices` and `attrs`.
        # TODO: per the spec, an out-of-bounds result index must leave the
        # results unchanged; that case still has to be handled.
        raise NotImplementedError("_result_index is not implemented yet")

    # The spec's `schedule` is a permutation of the index space of `updates`,
    # not of `inputs`: every scheduled index is used to slice `updates` below.
    # Sorting the (dim, coord) pairs yields one deterministic permutation.
    schedule = sorted(updates, key=lambda i: sorted(i.items()))

    # Equivalent to the spec's recursive `exec(schedule, results)`, written as
    # a fold. Recursion with `schedule[1:]` would exceed Python's recursion
    # limit for large `updates` and re-copy the list at every step.
    results = inputs
    for update_index in schedule:
        result_index = _result_index(update_index)
        updated_values = update_computation(tensor_slice(results, result_index),
                                            tensor_slice(updates, update_index))
        results = tensor_update(results, result_index, updated_values)
    return results


# Correctly spelled alias; `sctter` is kept so existing callers keep working.
scatter = sctter
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment