Last active
May 10, 2018 17:20
-
-
Save mratsim/e8867753ee5bbe0193ea4015cacecd02 to your computer and use it in GitHub Desktop.
Tensor splitting experiment
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ../src/arraymancer | |
import macros | |
# dumpASTGen: | |
# let b = (a[_, 0..<2], a[_, 2..<4], a[_, 4..<6]) | |
macro split[T](x: Tensor[T], n_chunks: static[int], axis = 0): untyped =
  ## Splits a Tensor into `n_chunks` chunks along `axis`.
  ##
  ## `n_chunks` must be known at compile time: the macro unrolls into a
  ## tuple literal with exactly `n_chunks` elements.
  ## Split is done without copy: each chunk is built with `atAxisIndex`,
  ## so the original and each chunk share data.
  ##
  ## In case a tensor cannot be split evenly,
  ## with la == length_axis, n = n_chunks
  ## it returns la mod n subtensors of size `(la div n) + 1`
  ## the rest of size `la div n`, consistent with numpy array_split
  ##
  ## NOTE(review): `x` is spliced into the generated code several times and
  ## is evaluated at each use; pass a plain symbol, not an expression with
  ## side effects.
  ## NOTE(review): with `n_chunks == 1` the generated `nnkPar` has a single
  ## child and may be read as a parenthesised expression (a bare Tensor)
  ## rather than a 1-tuple — confirm.
  if n_chunks < 1:
    # Reject at compile time: 0 would become a runtime `div 0` in the
    # generated code, and a negative count produces no chunks at all.
    error("split: n_chunks must be at least 1, got " & $n_chunks, x)

  let chunk_size = genSym(nskLet, "split_chunk_size_")
  let remainder = genSym(nskLet, "split_remainder_")
  result = newStmtList()
  result.add quote do:
    let `chunk_size` = `x`.shape[`axis`] div `n_chunks`
    let `remainder` = `x`.shape[`axis`] mod `n_chunks`

  var chunks = nnkPar.newTree
  for i in 0 ..< n_chunks:
    # The first `remainder` chunks are one element longer, so chunk i starts
    # after i regular chunks plus one extra element per preceding long chunk:
    # i*chunk_size + min(i, remainder). This equals i*(chunk_size+1) when
    # i < remainder and i*chunk_size + remainder otherwise.
    chunks.add quote do:
      `x`.atAxisIndex(
        `axis`,
        `i` * `chunk_size` + min(`i`, `remainder`),
        `chunk_size` + int(`i` < `remainder`))
  result.add chunks
# Demo: a 3x6 tensor, split along axis 1 (its 6 columns) into 4 chunks.
let a = toTensor([
  [ 1,  2,  3,  4,  5,  6],
  [ 7,  8,  9, 10, 11, 12],
  [13, 14, 15, 16, 17, 18]
])
echo split(a, 4, 1)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ../src/arraymancer | |
import macros | |
# dumpASTGen: | |
# let b = (a[_, 0..<2], a[_, 2..<4], a[_, 4..<6]) | |
macro split[T](x: Tensor[T], n_chunks: static[int], axis = 0): untyped =
  ## Splits a Tensor into `n_chunks` chunks along `axis`.
  ##
  ## `n_chunks` must be known at compile time: the macro unrolls into a
  ## tuple literal with exactly `n_chunks` elements.
  ## Split is done without copy: each chunk is built with `atAxisIndex`,
  ## so the original and each chunk share data.
  ##
  ## In case a tensor cannot be split evenly,
  ## with la == length_axis, n = n_chunks
  ## it returns la mod n subtensors of size `(la div n) + 1`
  ## the rest of size `la div n`, consistent with numpy array_split
  ##
  ## NOTE(review): `x` is spliced into the generated code several times and
  ## is evaluated at each use; pass a plain symbol, not an expression with
  ## side effects.
  ## NOTE(review): with `n_chunks == 1` the generated `nnkPar` has a single
  ## child and may be read as a parenthesised expression (a bare Tensor)
  ## rather than a 1-tuple — confirm.
  if n_chunks < 1:
    # Reject at compile time: 0 would become a runtime `div 0` in the
    # generated code, and a negative count produces no chunks at all.
    error("split: n_chunks must be at least 1, got " & $n_chunks, x)

  let chunk_size = genSym(nskLet, "split_chunk_size_")
  let remainder = genSym(nskLet, "split_remainder_")
  result = newStmtList()
  result.add quote do:
    let `chunk_size` = `x`.shape[`axis`] div `n_chunks`
    let `remainder` = `x`.shape[`axis`] mod `n_chunks`

  var chunks = nnkPar.newTree
  for i in 0 ..< n_chunks:
    # The first `remainder` chunks are one element longer, so chunk i starts
    # after i regular chunks plus one extra element per preceding long chunk:
    # i*chunk_size + min(i, remainder). The previous `i*length + pos` form
    # never advanced `pos` and dropped the `+ remainder` shift for short
    # chunks, producing overlapping views.
    chunks.add quote do:
      `x`.atAxisIndex(
        `axis`,
        `i` * `chunk_size` + min(`i`, `remainder`),
        `chunk_size` + int(`i` < `remainder`))
  result.add chunks
# Demo: a 3x6 tensor, split along axis 1 (its 6 columns) into 4 chunks.
let a = toTensor([
  [ 1,  2,  3,  4,  5,  6],
  [ 7,  8,  9, 10, 11, 12],
  [13, 14, 15, 16, 17, 18]
])
echo split(a, 4, 1)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ../src/arraymancer | |
import macros | |
# dumpASTGen: | |
# let b = (a[_, 0..<2], a[_, 2..<4], a[_, 4..<6]) | |
macro split[T](x: Tensor[T], n_chunks: static[int], axis = 0): untyped =
  ## Splits a Tensor into `n_chunks` chunks along `axis`.
  ##
  ## `n_chunks` must be known at compile time: the macro unrolls into a
  ## tuple literal with exactly `n_chunks` elements.
  ## Split is done without copy: each chunk is built with `atAxisIndex`,
  ## so the original and each chunk share data.
  ##
  ## In case a tensor cannot be split evenly,
  ## with la == length_axis, n = n_chunks
  ## it returns la mod n subtensors of size `(la div n) + 1`
  ## the rest of size `la div n`, consistent with numpy array_split
  ##
  ## NOTE(review): `x` is spliced into the generated code several times and
  ## is evaluated at each use; pass a plain symbol, not an expression with
  ## side effects.
  ## NOTE(review): with `n_chunks == 1` the generated `nnkPar` has a single
  ## child and may be read as a parenthesised expression (a bare Tensor)
  ## rather than a 1-tuple — confirm.
  if n_chunks < 1:
    # Reject at compile time: 0 would become a runtime `div 0` in the
    # generated code, and a negative count produces no chunks at all.
    error("split: n_chunks must be at least 1, got " & $n_chunks, x)

  let chunk_size = genSym(nskLet, "split_chunk_size_")
  let remainder = genSym(nskLet, "split_remainder_")
  result = newStmtList()
  result.add quote do:
    let `chunk_size` = `x`.shape[`axis`] div `n_chunks`
    let `remainder` = `x`.shape[`axis`] mod `n_chunks`

  var chunks = nnkPar.newTree
  for i in 0 ..< n_chunks:
    # Chunk i must start at its element offset, not at index i (using `i`
    # directly made consecutive chunks overlap). The first `remainder`
    # chunks are one element longer, so the offset is
    # i*chunk_size + min(i, remainder).
    chunks.add quote do:
      `x`.atAxisIndex(
        `axis`,
        `i` * `chunk_size` + min(`i`, `remainder`),
        `chunk_size` + int(`i` < `remainder`))
  result.add chunks
# Demo: a 3x6 tensor, split along axis 1 (its 6 columns) into 3 chunks.
let a = toTensor([
  [ 1,  2,  3,  4,  5,  6],
  [ 7,  8,  9, 10, 11, 12],
  [13, 14, 15, 16, 17, 18]
])
echo split(a, 3, 1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment