Skip to content

Instantly share code, notes, and snippets.

@mratsim
Last active May 10, 2018 17:20
Show Gist options
  • Save mratsim/e8867753ee5bbe0193ea4015cacecd02 to your computer and use it in GitHub Desktop.
Tensor splitting experiment
import ../src/arraymancer
import macros
# dumpASTGen:
# let b = (a[_, 0..<2], a[_, 2..<4], a[_, 4..<6])
macro split[T](x: Tensor[T], n_chunks: static[int], axis = 0): untyped =
  ## Splits a Tensor into `n_chunks` chunks along `axis`.
  ##
  ## `n_chunks` is required at compile-time so the result can be returned
  ## as a tuple of sub-tensors. The split is done without copy: the
  ## original tensor and every chunk share the same underlying data.
  ##
  ## When the axis length `la` is not evenly divisible, the first
  ## `la mod n_chunks` chunks have size `(la div n_chunks) + 1` and the
  ## remaining ones size `la div n_chunks`, consistent with numpy's
  ## `array_split`.
  let chunkLen = genSym(nskLet, "split_chunk_size_")
  let extras = genSym(nskLet, "split_reminder_")
  # Emit the runtime size computations once, then build a tuple
  # constructor with one `atAxisIndex` slice per chunk.
  result = newStmtList()
  result.add quote do:
    let `chunkLen` = `x`.shape[`axis`] div `n_chunks`
    let `extras` = `x`.shape[`axis`] mod `n_chunks`
  var chunkTuple = nnkPar.newTree
  for idx in 0 ..< n_chunks:
    chunkTuple.add quote do:
      if `idx` < `extras`:
        # The first `extras` chunks each absorb one extra element,
        # so chunk `idx` starts at idx*(chunkLen+1) == idx*chunkLen + idx.
        `x`.atAxisIndex(
          `axis`,
          `idx`*`chunkLen` + `idx`,
          `chunkLen` + 1)
      else:
        # Start offset: extras*(chunkLen+1) + (idx-extras)*chunkLen,
        # which simplifies to idx*chunkLen + extras.
        `x`.atAxisIndex(
          `axis`,
          `idx`*`chunkLen` + `extras`,
          `chunkLen`)
  result.add chunkTuple
# 3x6 demo tensor; split into 4 uneven column chunks (sizes 2, 2, 1, 1).
let a = toTensor([[ 1,  2,  3,  4,  5,  6],
                  [ 7,  8,  9, 10, 11, 12],
                  [13, 14, 15, 16, 17, 18]])
echo split(a, 4, 1)
import ../src/arraymancer
import macros
# dumpASTGen:
# let b = (a[_, 0..<2], a[_, 2..<4], a[_, 4..<6])
macro split[T](x: Tensor[T], n_chunks: static[int], axis = 0): untyped =
  ## Splits a Tensor into `n_chunks` chunks along `axis`.
  ##
  ## `n_chunks` is required at compile-time so the result can be returned
  ## as a tuple of sub-tensors. The split is done without copy: the
  ## original tensor and every chunk share the same underlying data.
  ##
  ## When the axis length `la` is not evenly divisible, the first
  ## `la mod n_chunks` chunks have size `(la div n_chunks) + 1` and the
  ## remaining ones size `la div n_chunks`, consistent with numpy's
  ## `array_split`.
  var splitted = nnkPar.newTree
  let chunk_size = genSym(nskLet, "split_chunk_size_")
  let reminder = genSym(nskLet, "split_reminder_")
  let pos = genSym(nskVar, "split_position_")
  let length = genSym(nskVar, "split_length_")
  result = newStmtList()
  result.add quote do:
    let `chunk_size` = `x`.shape[`axis`] div `n_chunks`
    let `reminder` = `x`.shape[`axis`] mod `n_chunks`
    var `pos`, `length`: Natural
  for i in 0 ..< n_chunks:
    # BUG FIX: `pos` was never advanced and the start offset was computed
    # as `i*length + pos`, which is wrong for uneven splits (for
    # i >= reminder it gave i*chunk_size instead of i*chunk_size + reminder).
    # Instead, accumulate the running position: advance `pos` past this
    # chunk, then slice starting at `pos - length` (the chunk's start).
    splitted.add quote do:
      `length` = `chunk_size` + int(`i` < `reminder`)
      `pos` += `length`
      `x`.atAxisIndex(`axis`, `pos` - `length`, `length`)
  result.add splitted
# 3x6 demo tensor; split into 4 uneven column chunks (sizes 2, 2, 1, 1).
let a = toTensor([[ 1,  2,  3,  4,  5,  6],
                  [ 7,  8,  9, 10, 11, 12],
                  [13, 14, 15, 16, 17, 18]])
echo split(a, 4, 1)
import ../src/arraymancer
import macros
# dumpASTGen:
# let b = (a[_, 0..<2], a[_, 2..<4], a[_, 4..<6])
macro split[T](x: Tensor[T], n_chunks: static[int], axis = 0): untyped =
  ## Splits a Tensor into `n_chunks` chunks along `axis`.
  ##
  ## `n_chunks` is required at compile-time so the result can be returned
  ## as a tuple of sub-tensors. The split is done without copy: the
  ## original tensor and every chunk share the same underlying data.
  ##
  ## When the axis length `la` is not evenly divisible, the first
  ## `la mod n_chunks` chunks have size `(la div n_chunks) + 1` and the
  ## remaining ones size `la div n_chunks`, consistent with numpy's
  ## `array_split`.
  var splitted = nnkPar.newTree
  let chunk_size = genSym(nskLet, "split_chunk_size_")
  let reminder = genSym(nskLet, "split_reminder_")
  result = newStmtList()
  result.add quote do:
    let `chunk_size` = `x`.shape[`axis`] div `n_chunks`
    let `reminder` = `x`.shape[`axis`] mod `n_chunks`
  for i in 0 ..< n_chunks:
    # BUG FIX: the start index passed to atAxisIndex was the raw chunk
    # number `i`, not an element offset — e.g. splitting 6 columns into
    # 3 chunks sliced at 0, 1, 2 instead of 0, 2, 4. The correct start is
    # i*(chunk_size+1) for the first `reminder` (oversized) chunks, and
    # i*chunk_size + reminder afterwards.
    splitted.add quote do:
      if `i` < `reminder`:
        `x`.atAxisIndex(`axis`, `i`*(`chunk_size` + 1), `chunk_size` + 1)
      else:
        `x`.atAxisIndex(`axis`, `i`*`chunk_size` + `reminder`, `chunk_size`)
  result.add splitted
# 3x6 demo tensor; split into 3 even column chunks (2 columns each).
let a = toTensor([[ 1,  2,  3,  4,  5,  6],
                  [ 7,  8,  9, 10, 11, 12],
                  [13, 14, 15, 16, 17, 18]])
echo split(a, 3, 1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment