Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save Varriount/352df3399e13f889d989f620d16fa762 to your computer and use it in GitHub Desktop.
Save Varriount/352df3399e13f889d989f620d16fa762 to your computer and use it in GitHub Desktop.
# Copyright 2017 Mamy André-Ratsimbazafy
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
## This file deals with slicing.
## Foo being:
# Tensor of shape 5x5 of type "int" on backend "Cpu"
# |1 1 1 1 1|
# |2 4 8 16 32|
# |3 9 27 81 243|
# |4 16 64 256 1024|
# |5 25 125 625 3125|
## Target supported syntax is:
#
## Full Integer
# foo[2, 3]
# foo[1+1, 2*4*1]
#
## Slice
# echo foo[1+1..4,3]
# echo foo[1..2, 3]
#
## Span
# echo foo[_, 3]
# echo foo[_.._|2, 3]
# echo foo[1.._, 3]
# echo foo[1.._|1, 3]
# echo foo[1..|1, 3]
# echo foo[_..3, 3]
# echo foo[_..3|1, 3]
# echo foo[..3|1, 3]
# echo foo[_..^2, 3]
# echo foo[_..<4|2, 3]
#
## Combinations
# echo foo[1..2+1|2,3]
#
## End step from last
# echo foo[0..^4|1, 3]
# echo foo[1..^1|2, 3]
#
## Prefix
# echo foo[..^2|2, 3]
# echo foo[..4, 3]
# echo foo[..<4, 3]
# echo foo[..^10|2, 3]
#
## Start from last
# echo foo[^1..0|-1, 3]
# echo foo[^4..2*2, 3]
# echo foo[^1..2|-1, 3]
# echo foo[^1..2*2|-1, 3]
#
## NOT working yet
# echo foo[^3..^2, 3]
# echo foo[^(1..^2*3), 3]
######################################
type SteppedSlice* = object
  ## A range with a step. `a` and `b` are the two ends of the range and
  ## `step` is the stride between selected positions. Each end carries a
  ## flag telling whether it is counted from the beginning (false) or from
  ## the end (true) of the axis.
  a, b, step: int
  a_from_end, b_from_end: bool
## Necessary to avoid parentheses due to operator precedence of | over ..
## [0..10|1] is interpreted as [0..(10|1)]
type Step* = object
  ## Holds the end of a range (`b`) and its step.
  b, step: int
template check_steps(a,b, step:int) =
  ## Validates that the bounds of a slice and its step agree.
  ##
  ## `a` is the start, `b` the stop and `step` the stride of the slice.
  ## Raises IndexError when `step` is zero, or when the sign of `step` does
  ## not match the direction going from `a` to `b`.
  ##
  ## Though it might be convenient to automatically step in the correct direction like in Python
  ## I choose not to do it as this might introduce the typical silent bugs typechecking/Nim is helping avoid.
  if step == 0:
    # A zero step can never advance along the axis; reject it here with a
    # clear message instead of letting a later division by the step fail.
    raise newException(IndexError, "Your slice step must not be zero.")
  if ((b-a) * step < 0):
    raise newException(IndexError, "Your slice start: " &
                $(a) &
                ", and stop: " &
                $(b) &
                ", or your step: " &
                $(step) &
                """, are not correct. If your step is positive
start must be inferior to stop and inversely if your step is negative
start must be superior to stop.""")
## Procs to manage all integer, slice, SteppedSlice
## TODO: change the macro so we don't need to export the following symbols?
proc `|+`*(s: Slice[int], step: int): SteppedSlice {.noSideEffect, inline.}=
  ## Builds a SteppedSlice from the ends of the integer slice `s`, using
  ## `step` as given. Both "from end" flags keep their default (false).
  result = SteppedSlice(a: s.a, b: s.b, step: step)
proc `|+`*(b, step: int): Step {.noSideEffect, inline.}=
  ## Packs the end of a range `b` together with `step` into a Step.
  result = Step(b: b, step: step)
proc `|+`*(ss: SteppedSlice, step: int): SteppedSlice {.noSideEffect, inline.}=
  ## Returns a copy of `ss` whose step is replaced by `step`;
  ## every other field is carried over unchanged.
  var updated = ss
  updated.step = step
  return updated
proc `|`*(s: Slice[int], step: int): SteppedSlice {.noSideEffect, inline.}=
  ## Builds a SteppedSlice from the ends of the integer slice `s`, using
  ## `step` as given. Both "from end" flags keep their default (false).
  result = SteppedSlice(a: s.a, b: s.b, step: step)
proc `|`*(b, step: int): Step {.noSideEffect, inline.}=
  ## Packs the end of a range `b` together with `step` into a Step.
  result = Step(b: b, step: step)
proc `|`*(ss: SteppedSlice, step: int): SteppedSlice {.noSideEffect, inline.}=
  ## Returns a copy of `ss` whose step is replaced by `step`;
  ## every other field is carried over unchanged.
  var updated = ss
  updated.step = step
  return updated
proc `|-`*(s: Slice[int], step: int): SteppedSlice {.noSideEffect, inline.}=
  ## Builds a SteppedSlice from the ends of the integer slice `s`, storing
  ## the negated `step`. Both "from end" flags keep their default (false).
  result = SteppedSlice(a: s.a, b: s.b, step: -step)
proc `|-`*(b, step: int): Step {.noSideEffect, inline.}=
  ## Packs the end of a range `b` together with the negated `step`
  ## into a Step.
  result = Step(b: b, step: -step)
proc `|-`*(ss: SteppedSlice, step: int): SteppedSlice {.noSideEffect, inline.}=
  ## Returns a copy of `ss` whose step is replaced by the negated `step`;
  ## every other field is carried over unchanged.
  var negated = ss
  negated.step = -step
  return negated
proc `..`*(a: int, s: Step): SteppedSlice {.noSideEffect, inline.} =
  ## Combines a start `a` with a Step (end + step) into a SteppedSlice.
  ## Both ends are counted from the beginning of the axis.
  result = SteppedSlice(a: a, b: s.b, step: s.step)
proc `..|`*(s: int): SteppedSlice {.noSideEffect, inline.} =
  ## Prefix form: starts at 0 and stops 1 before the end of the axis
  ## (`b` is counted from the end), with step `s`.
  result = SteppedSlice(a: 0, b: 1, step: s, b_from_end: true)
proc `..|`*(a,s: int): SteppedSlice {.noSideEffect, inline.} =
  ## Infix form: starts at `a` and stops 1 before the end of the axis
  ## (`b` is counted from the end), with step `s`.
  result = SteppedSlice(a: a, b: 1, step: s, b_from_end: true)
proc `..<`*(a: int, s: Step): SteppedSlice {.noSideEffect, inline.} =
  ## Combines a start `a` with a Step whose end is exclusive: the stored
  ## end is `s.b - 1`. Both ends are counted from the beginning of the axis.
  # `s.b - 1` replaces the deprecated unary `<` operator (predecessor),
  # which is no longer available in current Nim versions.
  return SteppedSlice(a: a, b: s.b - 1, step: s.step)
proc `..^`*(a: int, s: Step): SteppedSlice {.noSideEffect, inline.} =
  ## Combines a start `a` with a Step whose end `s.b` is counted from the
  ## end of the axis.
  result = SteppedSlice(a: a, b: s.b, step: s.step, b_from_end: true)
proc `^`*(s: SteppedSlice): SteppedSlice {.noSideEffect, inline.} =
  ## Returns a copy of `s` with its "start counted from the end" flag
  ## toggled.
  ## Note: This does not automatically inverse stepping, what if we want ^5..^1
  var flipped = s
  flipped.a_from_end = not s.a_from_end
  return flipped
proc `^`*(s: Slice): SteppedSlice {.noSideEffect, inline.} =
  ## Turns a plain slice into a SteppedSlice of step 1 whose start is
  ## counted from the end of the axis.
  ## Note: This does not automatically inverse stepping, what if we want ^5..^1
  result = SteppedSlice(a: s.a, b: s.b, step: 1, a_from_end: true)
## span is equivalent to `:` in Python. It returns the whole axis range.
## Tensor[_, 3] will be replaced by Tensor[span, 3]
## It starts at 0 and stops 1 before the end of the axis, with step 1.
const span* = SteppedSlice(a: 0, b: 1, step: 1,
                           a_from_end: false, b_from_end: true)
proc slicer*[B, T](t: Tensor[B, T], slices: varargs[SteppedSlice]): Tensor[B, T] {.noSideEffect.}=
  ## Take a Tensor and SteppedSlices, the i-th slice applying to the i-th
  ## dimension.
  ## Returns:
  ##    A view of the original Tensor
  ##    Data is not changed, only offset, strides and shape are changed to
  ##    achieve the desired effect.
  ## TODO: Currently, BLAS needs C-contiguous data and does not work with "Universal" strides.
  ##       Provide a way to convert from Universal to C-contiguous. (Strided iterator should make that easy)

  result = t # For t.data, seq semantics should copy only on write. TODO: Test

  for dim, sl in slices:
    # Turn ends counted from the end of the axis into absolute positions
    let first =
      if sl.a_from_end: result.shape[dim] - sl.a
      else: sl.a
    let last =
      if sl.b_from_end: result.shape[dim] - sl.b
      else: sl.b

    # Only the consistency between start, stop and step is verified here
    when compileOption("boundChecks"):
      check_steps(first, last, sl.step)

    # The offset must use the stride of the axis *before* it is rescaled
    result.offset += first * result.strides[dim]

    # Rescale the stride and recompute the number of selected elements
    result.strides[dim] *= sl.step
    result.shape[dim] = abs((last - first) div sl.step) + 1
macro desugar(args: untyped): typed =
  ## Transform all syntactic sugar in arguments to integer or SteppedSlices
  ## It will then be dispatched to "atIndex" (if specific integers)
  ## or slicer if there are SteppedSlices
  ##
  ## `args` holds the untyped index expressions. Each one is rewritten (or
  ## kept as is) and appended, in order, to an nnkArglist node which is
  ## returned.
  ## Raises IndexError during macro evaluation for `start.._|-step`, where
  ## the end of the range is left implicit while the step is negative.

  # echo "\n------------------\nOriginal tree"
  # echo args.treerepr
  var r = newNimNode(nnkArglist)

  for nnk in children(args):

    # Traverse tree and one-hot-encode the different conditions

    # Node is "_"
    let nnk_joker = nnk == ident("_")

    # Node is of the form "* .. *"
    let nnk0_inf_dotdot = (
      nnk.kind == nnkInfix and
      nnk[0] == ident("..")
    )

    # Node is of the form "* ..< *" or "* ..^ *"
    let nnk0_inf_dotdot_alt = (
      nnk.kind == nnkInfix and (
        nnk[0] == ident("..<") or
        nnk[0] == ident("..^")
      )
    )

    # Node is of the form "* .. *", "* ..< *" or "* ..^ *"
    let nnk0_inf_dotdot_all = (
      nnk0_inf_dotdot or
      nnk0_inf_dotdot_alt
    )

    # Node is of the form "* | *", "* |+ *", "* |- *"
    let nnk0_inf_bar_all = (
      nnk.kind == nnkInfix and (
        nnk[0] == ident("|") or
        nnk[0] == ident("|+") or
        nnk[0] == ident("|-")
      )
    )

    # Node is of the form ".. *", "..< *" or "..^ *" (prefix operator).
    # This must test nnkPrefix: the branch consuming this flag reads nnk[1]
    # as the end of the range, which is the layout of a prefix node.
    let nnk0_pre_dotdot_all = (
      nnk.kind == nnkPrefix and (
        nnk[0] == ident("..") or
        nnk[0] == ident("..<") or
        nnk[0] == ident("..^")
      )
    )

    # Node is of the form "_ `op` *"
    let nnk1_joker = (
      nnk.kind == nnkInfix and
      nnk[1] == ident("_")
    )

    # Node is of the form "^* `op` *"
    let nnk10_hat = (
      nnk.kind == nnkInfix and
      nnk[1].kind == nnkPrefix and
      nnk[1][0] == ident("^")
    )

    # Node is of the form "..^* `op` *" or "..<* `op` *"
    let nnk10_dotdot_pre_alt = (
      nnk.kind == nnkInfix and
      nnk[1].kind == nnkPrefix and (
        nnk[1][0] == ident("..^") or
        nnk[1][0] == ident("..<")
      )
    )

    # Node is of the form "* `op` _"
    let nnk2_joker = (
      nnk.kind == nnkInfix and
      nnk[2] == ident("_")
    )

    # Node is of the form "* `op` * | *" or "* `op` * |+ *"
    let nnk20_bar_pos = (
      nnk.kind == nnkInfix and
      nnk[2].kind == nnkInfix and (
        nnk[2][0] == ident("|") or
        nnk[2][0] == ident("|+")
      )
    )

    # Node is of the form "* `op` * |- *"
    let nnk20_bar_min = (
      nnk.kind == nnkInfix and
      nnk[2].kind == nnkInfix and
      nnk[2][0] == ident("|-")
    )

    # Node is of the form "* `op` * | *", "* `op` * |+ *" or "* `op` * |- *"
    let nnk20_bar_all = nnk20_bar_pos or nnk20_bar_min

    # Node is of the form "* `op1` _ `op2` *"
    let nnk21_joker = (
      nnk.kind == nnkInfix and
      nnk[2].kind == nnkInfix and
      nnk[2][1] == ident("_")
    )

    # Core logic
    if nnk_joker:
      # [_, 3] into [span, 3]
      r.add(ident("span"))
    elif nnk0_inf_dotdot and nnk1_joker and nnk2_joker:
      # [_.._, 3] into [span, 3]
      r.add(ident("span"))
    elif nnk0_inf_dotdot and nnk1_joker and nnk20_bar_pos and nnk21_joker:
      # [_.._|2, 3] into [..|2, 3]
      # [_.._|+2, 3] into [..|2, 3]
      r.add(prefix(nnk[2][2], "..|"))
    elif nnk0_inf_dotdot and nnk1_joker and nnk20_bar_min and nnk21_joker:
      # [_.._|-2, 3] into [..|-2, 3]
      # NOTE(review): no `..|-` operator is defined in the visible part of
      # this file - confirm that the generated call can resolve.
      r.add(prefix(nnk[2][2], "..|-"))
    elif nnk0_inf_dotdot_all and nnk1_joker and nnk20_bar_all:
      # [_..10|1, 3] into [0..10|1, 3]
      # [_..^10|1, 3] into [0..^10|1, 3]   # ..^ directly creating SteppedSlices may introduce issues in seq[0..^10]
      #                                    # Furthermore ..^10|1, would have ..^ with precedence over |
      # [_..<10|1, 3] into [0..<10|1, 3]
      r.add(infix(newIntLitNode(0), $nnk[0], infix(nnk[2][1], $nnk[2][0], nnk[2][2])))
    elif nnk0_inf_dotdot_all and nnk1_joker:
      # [_..10, 3] into [0..10|1, 3]
      # [_..^10, 3] into [0..^10|1, 3]   # ..^ directly creating SteppedSlices from int in 0..^10 may introduce issues in seq[0..^10]
      # [_..<10, 3] into [0..<10|1, 3]
      r.add(infix(newIntLitNode(0), $nnk[0], infix(nnk[2], "|", newIntLitNode(1))))
    elif nnk0_inf_dotdot and nnk2_joker:
      # [1.._, 3] into [1..|1, 3]
      r.add(infix(nnk[1], "..|", newIntLitNode(1)))
    elif nnk0_inf_dotdot and nnk20_bar_pos and nnk21_joker:
      # [1.._|1, 3] into [1..|1, 3]
      # [1.._|+1, 3] into [1..|1, 3]
      r.add(infix(nnk[1], "..|", nnk[2][2]))
    elif nnk0_inf_dotdot and nnk20_bar_min and nnk21_joker:
      # Raise error on [5.._|-1, 3]
      raise newException(IndexError, "Please use explicit end of range " &
                       "instead of `_` " &
                       "when the steps are negative")
    elif nnk0_inf_dotdot_all and nnk10_hat and nnk20_bar_all:
      # We can skip the parenthesis in the AST
      # [^1..2|-1, 3] into [^(1..2|-1), 3]
      r.add(prefix(infix(nnk[1][1], $nnk[0], nnk[2]), "^"))
    elif nnk0_inf_dotdot_all and nnk10_hat:
      # We can skip the parenthesis in the AST
      # [^1..2*3, 3] into [^(1..2*3|1), 3]
      # [^1..0, 3] into [^(1..0|1), 3]
      # [^1..<10, 3] into [^(1..<10|1), 3]
      # [^10..^1, 3] into [^(10..^1|1), 3]
      r.add(prefix(infix(nnk[1][1], $nnk[0], infix(nnk[2],"|",newIntLitNode(1))), "^"))
    elif nnk0_inf_dotdot_all and nnk20_bar_all:
      # [1..10|1] as is
      # [1..^10|1] as is
      r.add(nnk)
    elif nnk0_inf_dotdot_all:
      # [1..10, 3] to [1..10|1, 3]
      # [1..^10, 3] to [1..^10|1, 3]
      # [1..<10, 3] to [1..<10|1, 3]
      r.add(infix(nnk[1], $nnk[0], infix(nnk[2], "|", newIntLitNode(1))))
    elif nnk0_inf_bar_all and nnk10_dotdot_pre_alt:
      # [..^10|2, 3] into [0..^10|2, 3]
      # Not needed for `..` it already creates a slice without precedence/Step issues
      r.add(infix(newIntLitNode(0), $nnk[1][0], infix(nnk[1][1], $nnk[0], nnk[2])))
    elif nnk0_pre_dotdot_all:
      # [..10, 3] to [0..10|1, 3]
      r.add(infix(newIntLitNode(0), $nnk[0], infix(nnk[1], "|", newIntLitNode(1))))
    else:
      r.add(nnk)
  # echo "\nAfter modif"
  # echo r.treerepr
  return r
proc hasType(x: NimNode, t: static[string]): bool {. compileTime .} =
  ## Compile-time type checking: compares the type of node `x` with the
  ## symbol named `t`, looked up through bindSym.
  result = sameType(x, bindSym(t))
proc isInt(x: NimNode): bool {. compileTime .} =
  ## Compile-time type checking: true when node `x` has type "int".
  result = hasType(x, "int")
proc isAllInt(slice_args: NimNode): bool {. compileTime .} =
  ## Compile-time type checking: true when every child of `slice_args`
  ## is an int (also true when there are no children).
  for child in slice_args:
    if not isInt(child):
      return false
  return true
macro inner_typed_dispatch(t: typed, args: varargs[typed]): untyped =
  ## Typed macro so that isAllInt has typed context and we can dispatch.
  ## If args are all int, we build a call to atIndex with the args as is.
  ## Else, we build a call to slicer where every int argument `i` is
  ## rewritten as `i..i|1` and the other arguments are passed unchanged.
  ## Note, normal slices and `_` were already converted in the `[]` macro
  ## TODO in total we do 3 passes over the list of arguments :/. It is done only at compile time though
  let allInts = isAllInt(args)
  result = newCall(if allInts: "atIndex" else: "slicer", t)

  for arg in args:
    if allInts or not isInt(arg):
      # Plain index for atIndex, or an argument that is not an int
      result.add(arg)
    else:
      # Convert [10, 1..10|1] to [10..10|1, 1..10|1]
      result.add(infix(arg, "..", infix(arg, "|", newIntLitNode(1))))
macro `[]`*[B, T](t: Tensor[B,T], args: varargs[untyped]): untyped =
  ## Indexing / slicing entry point: the untyped arguments are first
  ## rewritten by `desugar`, then handed to `inner_typed_dispatch`
  ## together with the tensor.
  let desugared = getAST(desugar(args))
  result = quote do:
    inner_typed_dispatch(`t`, `desugared`)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment