Created
May 20, 2017 07:02
-
-
Save Varriount/352df3399e13f889d989f620d16fa762 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2017 Mamy André-Ratsimbazafy | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
## This file deals with slicing. | |
## Foo being: | |
# Tensor of shape 5x5 of type "int" on backend "Cpu" | |
# |1 1 1 1 1| | |
# |2 4 8 16 32| | |
# |3 9 27 81 243| | |
# |4 16 64 256 1024| | |
# |5 25 125 625 3125| | |
## Target supported syntax is: | |
# | |
## Full Integer | |
# foo[2, 3] | |
# foo[1+1, 2*4*1] | |
# | |
## Slice | |
# echo foo[1+1..4,3] | |
# echo foo[1..2, 3] | |
# | |
## Span | |
# echo foo[_, 3] | |
# echo foo[_.._|2, 3] | |
# echo foo[1.._, 3] | |
# echo foo[1.._|1, 3] | |
# echo foo[1..|1, 3] | |
# echo foo[_..3, 3] | |
# echo foo[_..3|1, 3] | |
# echo foo[..3|1, 3] | |
# echo foo[_..^2, 3] | |
# echo foo[_..<4|2, 3] | |
# | |
## Combinations | |
# echo foo[1..2+1|2,3] | |
# | |
## End step from last | |
# echo foo[0..^4|1, 3] | |
# echo foo[1..^1|2, 3] | |
# | |
## Prefix | |
# echo foo[..^2|2, 3] | |
# echo foo[..4, 3] | |
# echo foo[..<4, 3] | |
# echo foo[..^10|2, 3] | |
# | |
## Start from last | |
# echo foo[^1..0|-1, 3] | |
# echo foo[^4..2*2, 3] | |
# echo foo[^1..2|-1, 3] | |
# echo foo[^1..2*2|-1, 3] | |
# | |
## NOT working yet | |
# echo foo[^3..^2, 3] | |
# echo foo[^(1..^2*3), 3] | |
###################################### | |
type
  SteppedSlice* = object
    ## A slice carrying a step. Each endpoint is either an index counted
    ## from the beginning of the axis or, when its `*_from_end` flag is
    ## set, an offset subtracted from the axis length.
    a, b: int         ## start and stop indices (stop is inclusive)
    step: int         ## distance between selected elements; may be negative
    a_from_end: bool  ## `a` is counted from the end of the axis
    b_from_end: bool  ## `b` is counted from the end of the axis

  Step* = object
    ## Holds the end of a range together with its step.
    ## It exists because `|` binds tighter than `..`: `[0..10|1]` is
    ## interpreted as `[0..(10|1)]`, so `10|1` must produce a value that
    ## `..` can consume without the user writing parentheses.
    b: int
    step: int
template check_steps(a, b, step: int) =
  ## Checks that `step` actually walks from `a` towards `b`.
  ##
  ## Raises `IndexError` if:
  ## - `step` is zero: such a slice can never advance, and the shape
  ##   computation in `slicer` divides by the step;
  ## - the sign of `step` disagrees with the direction from `a` to `b`.
  ##
  ## Though it might be convenient to automatically step in the correct
  ## direction like in Python, this is deliberately not done as it might
  ## introduce the typical silent bugs that type checking helps avoid.
  # Bind each argument once: the template body uses them several times and
  # the caller's expressions should only be evaluated a single time.
  let
    checkedStart = a
    checkedStop = b
    checkedStep = step
  if checkedStep == 0:
    raise newException(IndexError, "Your slice step must not be zero.")
  if (checkedStop - checkedStart) * checkedStep < 0:
    raise newException(IndexError, "Your slice start: " &
      $checkedStart &
      ", and stop: " &
      $checkedStop &
      ", or your step: " &
      $checkedStep &
      ", are not correct. If your step is positive\n" &
      "start must be inferior to stop and inversely if your step is negative\n" &
      "start must be superior to stop.")
## Procs to manage all integer, slice, SteppedSlice | |
## TODO: change the macro so we don't need to export the following symbols? | |
proc `|`*(s: Slice[int], step: int): SteppedSlice {.noSideEffect, inline.} =
  ## Attaches a step to a plain slice: `(0..10) | 2`.
  SteppedSlice(a: s.a, b: s.b, step: step)

proc `|`*(b, step: int): Step {.noSideEffect, inline.} =
  ## Pairs the end of a range with a step: `10 | 2`.
  ## The resulting `Step` is consumed by `..`, `..<` and `..^`.
  Step(b: b, step: step)

proc `|`*(ss: SteppedSlice, step: int): SteppedSlice {.noSideEffect, inline.} =
  ## Returns a copy of `ss` whose step is replaced by `step`.
  result = ss
  result.step = step

proc `|+`*(s: Slice[int], step: int): SteppedSlice {.noSideEffect, inline.} =
  ## Explicitly positive spelling of `|`; identical behaviour.
  s | step

proc `|+`*(b, step: int): Step {.noSideEffect, inline.} =
  ## Explicitly positive spelling of `|`; identical behaviour.
  b | step

proc `|+`*(ss: SteppedSlice, step: int): SteppedSlice {.noSideEffect, inline.} =
  ## Explicitly positive spelling of `|`; identical behaviour.
  ss | step

proc `|-`*(s: Slice[int], step: int): SteppedSlice {.noSideEffect, inline.} =
  ## Like `|` but stores the negated step, e.g. `(10..0) |- 2` steps by -2.
  s | (-step)

proc `|-`*(b, step: int): Step {.noSideEffect, inline.} =
  ## Like `|` but stores the negated step, e.g. `0 |- 1` steps by -1.
  b | (-step)

proc `|-`*(ss: SteppedSlice, step: int): SteppedSlice {.noSideEffect, inline.} =
  ## Returns a copy of `ss` whose step is replaced by `-step`.
  ss | (-step)
proc `..`*(a: int, s: Step): SteppedSlice {.noSideEffect, inline.} =
  ## `a..b|step`: from `a` up to `b` (inclusive) with the given step.
  return SteppedSlice(a: a, b: s.b, step: s.step)

proc `..|`*(s: int): SteppedSlice {.noSideEffect, inline.} =
  ## `..|step`: the whole axis, from index 0 to the last index, with step `s`.
  # `b: 1` counted from the end resolves to `axis length - 1` in `slicer`.
  return SteppedSlice(a: 0, b: 1, step: s, b_from_end: true)

proc `..|`*(a, s: int): SteppedSlice {.noSideEffect, inline.} =
  ## `a..|step`: from `a` to the last index of the axis, with step `s`.
  return SteppedSlice(a: a, b: 1, step: s, b_from_end: true)

proc `..<`*(a: int, s: Step): SteppedSlice {.noSideEffect, inline.} =
  ## `a..<b|step`: from `a` up to `b - 1` (inclusive), with the given step.
  # `pred` replaces the unary `<` operator, which is deprecated and removed
  # in later Nim versions; both compute `s.b - 1`.
  # NOTE(review): the exclusive end is always `b - 1`, even for a negative
  # step (e.g. `10..<2|-1` walks 10 down to 1) — confirm this is intended.
  return SteppedSlice(a: a, b: pred(s.b), step: s.step)

proc `..^`*(a: int, s: Step): SteppedSlice {.noSideEffect, inline.} =
  ## `a..^b|step`: from `a` up to `axis length - b`, with the given step.
  return SteppedSlice(a: a, b: s.b, step: s.step, b_from_end: true)
proc `^`*(s: SteppedSlice): SteppedSlice {.noSideEffect, inline.} =
  ## Toggles whether the start `a` is counted from the end of the axis.
  ## Stepping is not inverted automatically: think of `^5..^1`, which
  ## still needs a positive step.
  result = s
  result.a_from_end = not s.a_from_end

proc `^`*(s: Slice): SteppedSlice {.noSideEffect, inline.} =
  ## Turns a plain slice into a unit-step SteppedSlice whose start `a` is
  ## counted from the end of the axis. Stepping is not inverted
  ## automatically: think of `^5..^1`, which still needs a positive step.
  SteppedSlice(a: s.a, b: s.b, step: 1, a_from_end: true)
## `span` is the equivalent of `:` in Python: it selects the whole axis.
## `Tensor[_, 3]` is rewritten into `Tensor[span, 3]`.
# Start at index 0 and stop at `axis length - 1` (`b: 1` counted from the end).
const span* = SteppedSlice(a: 0, b: 1, step: 1, b_from_end: true)
proc slicer*[B, T](t: Tensor[B, T], slices: varargs[SteppedSlice]): Tensor[B, T] {.noSideEffect.} =
  ## Takes a Tensor and one SteppedSlice per leading axis.
  ## Returns:
  ##   A view of the original Tensor.
  ##   Data is not changed, only `offset`, `shape` and `strides` are changed
  ##   to achieve the desired effect.
  ##
  ## When bound checks are enabled, raises `IndexError` if a slice's step
  ## disagrees with its direction (see `check_steps`) or if a resolved
  ## start/stop index lies outside `0 ..< shape[i]`.
  ##
  ## TODO: Currently, BLAS needs C-contiguous data and does not work with "Universal" strides.
  ## Provide a way to convert from Universal to C-contiguous. (Strided iterator should make that easy)
  # NOTE(review): Nim `seq`s have value semantics, so if `data` is a plain
  # `seq` this assignment copies it rather than sharing it. TODO: Test
  result = t
  for i, slice in slices:
    let dim = result.shape[i]

    # Resolve endpoints that are counted from the end of the axis.
    let a = if slice.a_from_end: dim - slice.a else: slice.a
    let b = if slice.b_from_end: dim - slice.b else: slice.b

    # Bounds checking: direction vs. step, then both endpoints inside the
    # axis (the stop index is inclusive, hence `>= dim` is out of range).
    when compileOption("boundChecks"):
      check_steps(a, b, slice.step)
      if a < 0 or a >= dim or b < 0 or b >= dim:
        raise newException(IndexError, "Your slice start: " & $a &
          ", and stop: " & $b &
          ", are out of bounds for an axis of length " & $dim & ".")

    # Compute offset:
    result.offset += a * result.strides[i]
    # Now change shape and strides
    result.strides[i] *= slice.step
    result.shape[i] = abs((b - a) div slice.step) + 1
macro desugar(args: untyped): typed =
  ## Transforms all syntactic sugar in the arguments into plain integers or
  ## SteppedSlice-producing expressions. The resulting argument list is then
  ## dispatched to `atIndex` (if all arguments are integers) or to `slicer`
  ## (if there are SteppedSlices).
  ##
  ## The AST shapes matched below follow Nim's operator precedence:
  ## - `|`, `|+` and `|-` bind tighter than `..`, `..<` and `..^`, so
  ##   `1..10|2` is parsed as `1..(10|2)`;
  ## - unary operators bind tighter than binary ones, so `..^10|2` is parsed
  ##   as `(..^10)|2` and `^1..2` as `(^1)..2`.
  # echo "\n------------------\nOriginal tree"
  # echo args.treerepr
  var r = newNimNode(nnkArglist)
  for nnk in children(args):
    ###### Traverse tree and one-hot-encode the different conditions
    let nnk_joker = nnk == ident("_")

    # Node is of the form "* .. *"
    let nnk0_inf_dotdot = (
      nnk.kind == nnkInfix and
      nnk[0] == ident("..")
    )

    # Node is of the form "* ..< *" or "* ..^ *"
    let nnk0_inf_dotdot_alt = (
      nnk.kind == nnkInfix and (
        nnk[0] == ident("..<") or
        nnk[0] == ident("..^")
      )
    )

    # Node is of the form "* .. *", "* ..< *" or "* ..^ *"
    let nnk0_inf_dotdot_all = (
      nnk0_inf_dotdot or
      nnk0_inf_dotdot_alt
    )

    # Node is of the form "* | *", "* |+ *", "* |- *"
    let nnk0_inf_bar_all = (
      nnk.kind == nnkInfix and (
        nnk[0] == ident("|") or
        nnk[0] == ident("|+") or
        nnk[0] == ident("|-")
      )
    )

    # Node is of the form ".. *", "..< *" or "..^ *" (prefix operator,
    # operand in nnk[1])
    let nnk0_pre_dotdot_all = (
      nnk.kind == nnkPrefix and (
        nnk[0] == ident("..") or
        nnk[0] == ident("..<") or
        nnk[0] == ident("..^")
      )
    )

    # Left operand is "_"
    let nnk1_joker = (
      nnk.kind == nnkInfix and
      nnk[1] == ident("_")
    )

    # Left operand is "^ *"
    let nnk10_hat = (
      nnk.kind == nnkInfix and
      nnk[1].kind == nnkPrefix and
      nnk[1][0] == ident("^")
    )

    # Left operand is "..^ *" or "..< *"
    let nnk10_dotdot_pre_alt = (
      nnk.kind == nnkInfix and
      nnk[1].kind == nnkPrefix and (
        nnk[1][0] == ident("..^") or
        nnk[1][0] == ident("..<")
      )
    )

    # Right operand is "_"
    let nnk2_joker = (
      nnk.kind == nnkInfix and
      nnk[2] == ident("_")
    )

    # Right operand is "* | *" or "* |+ *"
    let nnk20_bar_pos = (
      nnk.kind == nnkInfix and
      nnk[2].kind == nnkInfix and (
        nnk[2][0] == ident("|") or
        nnk[2][0] == ident("|+")
      )
    )

    # Right operand is "* |- *"
    let nnk20_bar_min = (
      nnk.kind == nnkInfix and
      nnk[2].kind == nnkInfix and
      nnk[2][0] == ident("|-")
    )

    let nnk20_bar_all = nnk20_bar_pos or nnk20_bar_min

    # Right operand is "_ | *" (or "|+" / "|-")
    let nnk21_joker = (
      nnk.kind == nnkInfix and
      nnk[2].kind == nnkInfix and
      nnk[2][1] == ident("_")
    )

    ###### Core logic
    if nnk_joker:
      ## [_, 3] into [span, 3]
      r.add(ident("span"))
    elif nnk0_inf_dotdot and nnk1_joker and nnk2_joker:
      ## [_.._, 3] into [span, 3]
      r.add(ident("span"))
    elif nnk0_inf_dotdot and nnk1_joker and nnk20_bar_pos and nnk21_joker:
      ## [_.._|2, 3] into [..|2, 3]
      ## [_.._|+2, 3] into [..|2, 3]
      r.add(prefix(nnk[2][2], "..|"))
    elif nnk0_inf_dotdot and nnk1_joker and nnk20_bar_min and nnk21_joker:
      ## [_.._|-2, 3] into [^(1..0|-2), 3]
      # Walk the whole axis backwards: start at the last index (1 counted
      # from the end) and stop at index 0. There is no `..|-` operator, so
      # the reversed span is spelled with existing operators.
      r.add(prefix(infix(newIntLitNode(1), "..",
                         infix(newIntLitNode(0), "|-", nnk[2][2])), "^"))
    elif nnk0_inf_dotdot_all and nnk1_joker and nnk20_bar_all:
      ## [_..10|1, 3] into [0..10|1, 3]
      ## [_..^10|1, 3] into [0..^10|1, 3]
      ## [_..<10|1, 3] into [0..<10|1, 3]
      # `..^` only builds a SteppedSlice from a `Step` operand: making it
      # build one directly from ints (0..^10) may interfere with seq[0..^10].
      r.add(infix(newIntLitNode(0), $nnk[0], infix(nnk[2][1], $nnk[2][0], nnk[2][2])))
    elif nnk0_inf_dotdot_all and nnk1_joker:
      ## [_..10, 3] into [0..10|1, 3]
      ## [_..^10, 3] into [0..^10|1, 3]
      ## [_..<10, 3] into [0..<10|1, 3]
      r.add(infix(newIntLitNode(0), $nnk[0], infix(nnk[2], "|", newIntLitNode(1))))
    elif nnk0_inf_dotdot and nnk2_joker:
      ## [1.._, 3] into [1..|1, 3]
      r.add(infix(nnk[1], "..|", newIntLitNode(1)))
    elif nnk0_inf_dotdot and nnk20_bar_pos and nnk21_joker:
      ## [1.._|1, 3] into [1..|1, 3]
      ## [1.._|+1, 3] into [1..|1, 3]
      r.add(infix(nnk[1], "..|", nnk[2][2]))
    elif nnk0_inf_dotdot and nnk20_bar_min and nnk21_joker:
      ## Raise error on [5.._|-1, 3]
      raise newException(IndexError, "Please use explicit end of range " &
                         "instead of `_` " &
                         "when the steps are negative")
    elif nnk0_inf_dotdot_all and nnk10_hat and nnk20_bar_all:
      # We can skip the parenthesis in the AST
      ## [^1..2|-1, 3] into [^(1..2|-1), 3]
      r.add(prefix(infix(nnk[1][1], $nnk[0], nnk[2]), "^"))
    elif nnk0_inf_dotdot_all and nnk10_hat:
      # We can skip the parenthesis in the AST
      ## [^1..2*3, 3] into [^(1..2*3|1), 3]
      ## [^1..0, 3] into [^(1..0|1), 3]
      ## [^1..<10, 3] into [^(1..<10|1), 3]
      ## [^10..^1, 3] into [^(10..^1|1), 3]
      r.add(prefix(infix(nnk[1][1], $nnk[0], infix(nnk[2], "|", newIntLitNode(1))), "^"))
    elif nnk0_inf_dotdot_all and nnk20_bar_all:
      ## [1..10|1] as is
      ## [1..^10|1] as is
      r.add(nnk)
    elif nnk0_inf_dotdot_all:
      ## [1..10, 3] to [1..10|1, 3]
      ## [1..^10, 3] to [1..^10|1, 3]
      ## [1..<10, 3] to [1..<10|1, 3]
      r.add(infix(nnk[1], $nnk[0], infix(nnk[2], "|", newIntLitNode(1))))
    elif nnk0_inf_bar_all and nnk10_dotdot_pre_alt:
      ## [..^10|2, 3] into [0..^10|2, 3]
      # Not needed for `..`: it already creates a slice without precedence/Step issues
      r.add(infix(newIntLitNode(0), $nnk[1][0], infix(nnk[1][1], $nnk[0], nnk[2])))
    elif nnk0_pre_dotdot_all:
      ## [..10, 3] to [0..10|1, 3]
      ## [..<10, 3] to [0..<10|1, 3]
      ## [..^10, 3] to [0..^10|1, 3]
      r.add(infix(newIntLitNode(0), $nnk[0], infix(nnk[1], "|", newIntLitNode(1))))
    else:
      r.add(nnk)
  # echo "\nAfter modif"
  # echo r.treerepr
  result = r
proc hasType(x: NimNode, t: static[string]): bool {.compileTime.} =
  ## Compile-time check: is the type of the typed node `x` the same as the
  ## type bound to the symbol named `t`?
  result = sameType(x, bindSym(t))

proc isInt(x: NimNode): bool {.compileTime.} =
  ## Compile-time check that the typed node `x` has type `int`.
  result = hasType(x, "int")

proc isAllInt(slice_args: NimNode): bool {.compileTime.} =
  ## Compile-time check that every child of `slice_args` has type `int`.
  ## An argument list without children counts as all-int.
  result = true
  for arg in slice_args:
    # Every child is inspected, even after a non-int one has been seen.
    if not isInt(arg):
      result = false
macro inner_typed_dispatch(t: typed, args: varargs[typed]): untyped =
  ## Typed macro: its arguments are already type-checked, so `isAllInt` can
  ## inspect their types and choose the call to emit.
  ## - Every argument is an `int`: emits `atIndex(t, args...)` (returns T).
  ## - Otherwise: emits `slicer(t, ...)` (returns a Tensor), with each
  ##   remaining `int` argument `i` rewritten into `i..i|1`.
  ## Normal slices and `_` were already converted in the `[]` macro.
  ## TODO: in total we do 3 passes over the list of arguments :/. It is done
  ## only at compile time though.
  if isAllInt(args):
    result = newCall("atIndex", t)
    for arg in args:
      result.add(arg)
    return

  result = newCall("slicer", t)
  for arg in args:
    let converted =
      if isInt(arg):
        # Convert [10, 1..10|1] to [10..10|1, 1..10|1]
        infix(arg, "..", infix(arg, "|", newIntLitNode(1)))
      else:
        arg
    result.add(converted)
macro `[]`*[B, T](t: Tensor[B,T], args: varargs[untyped]): untyped =
  ## Slicing / indexing entry point, e.g. `foo[2, 3]`, `foo[1..2, _]`.
  ## The untyped arguments are first desugared (`_`, `..`, `..<`, `..^`,
  ## `|`, `^`) into ints or SteppedSlices, then handed to a typed macro
  ## that picks between `atIndex` and `slicer`.
  let desugared = getAST(desugar(args))
  result = quote do:
    inner_typed_dispatch(`t`, `desugared`)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment