Last active
December 29, 2020 09:58
-
-
Save Vindaar/ac49c9caf063c04caaf74dcae08eca79 to your computer and use it in GitHub Desktop.
Pretty printing of N-dimensional tensors, supporting N > 3
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sequtils | |
type
  Tensor[T] = object
    ## Dense N-dimensional tensor stored as one flat `seq`.
    size: int                ## element count; product of all `shape` extents
    shape, strides: seq[int] ## per-axis extents, and per-axis steps into `data`
    data: seq[T]             ## flat element storage, `size` entries long
func parseStrides(shape: openArray[int]): seq[int] =
  ## Returns the per-axis strides for a tensor of the given `shape`, with
  ## axis 0 varying fastest: `result[0] == 1` and
  ## `result[k] == shape[0] * shape[1] * ... * shape[k-1]`.
  ## An empty `shape` yields an empty seq.
  result = newSeqOfCap[int](shape.len)
  var stride = 1
  for extent in shape:
    # The stride of this axis is the product of all extents before it.
    result.add stride
    stride *= extent
proc initTensor[T](shape: varargs[int]): Tensor[T] =
  ## Creates a tensor of the given `shape` whose elements all hold the
  ## default value of `T`.
  ##
  ## `size` is the product of all extents (1 for an empty `shape`), `data`
  ## holds `size` elements and `strides` is computed by `parseStrides`.
  ## The freshly built tensor is echoed only when compiled with
  ## `-d:tensorDebug`, so normal use produces no stray output.
  result.shape = @shape
  result.size = result.shape.foldl(a * b, 1)
  result.data = newSeq[T](result.size)
  result.strides = parseStrides(result.shape)
  when defined(tensorDebug):
    echo result
proc toTensor[T](x: openArray[T], shape: varargs[int]): Tensor[T] =
  ## Builds a tensor of the given `shape` from the flat (1-D) values `x`,
  ## copying them into `data` in order.
  ##
  ## `x.len` must equal the product of `shape`. A mismatch fails a
  ## `doAssert` with a descriptive message. Without the check, a longer `x`
  ## would overrun `data` and a shorter one would silently leave the
  ## remaining elements at their default value.
  result = initTensor[T](shape)
  doAssert x.len == result.size,
    "toTensor: got " & $x.len & " values for shape " & $(@shape) &
    ", which needs " & $result.size
  for i, el in x:
    result.data[i] = el
func computeIdx[T](t: Tensor[T], axis: int, i, j: int): int =
  ## Maps (fiber `i`, position `j`) to a flat index into `t.data`.
  ##
  ## A fiber is a 1-D run of elements along `axis`. Callers in this file
  ## iterate `i` over `0 ..< t.size div t.shape[axis]` and `j` over
  ## `0 ..< t.shape[axis]`. Fibers are numbered with the axes below `axis`
  ## varying fastest, matching the axis-0-first strides from `parseStrides`.
  ##
  ## Assumes `t.size > 0`. For an empty tensor the stride may be 0.
  let stride = t.strides[axis]
  # `inner` is the offset contributed by the axes below `axis`. `outer`
  # enumerates the combinations of the axes above it; each combination spans
  # a block of `stride * shape[axis]` flat elements. Omitting the `outer`
  # term (as before) made middle axes of N >= 3 tensors revisit elements.
  # For axis 0 (stride 1) this reduces to `i * shape[0] + j`; for the last
  # axis, `outer` is always 0 and it reduces to `i + stride * j`.
  let inner = i mod stride
  let outer = i div stride
  result = inner + j * stride + outer * stride * t.shape[axis]
proc stridedPrinting[T](t: Tensor[T], axis: int) =
  ## Prints `t` one fiber per line. A fiber is the run of elements obtained
  ## by stepping along `axis`; positions come from `computeIdx`. Each element
  ## is followed by ", ". The whole text is emitted with a single `echo`, so
  ## a blank line follows the last fiber.
  ##
  ## A zero-length `axis` prints just an empty line instead of dividing by
  ## zero.
  let axLen = t.shape[axis]
  let numFibers = if axLen == 0: 0 else: t.size div axLen
  var res = ""
  for i in 0 ..< numFibers:
    for j in 0 ..< axLen:
      res.add $t.data[computeIdx(t, axis, i, j)]
      res.add ", "
    res.add "\n"
  echo res
proc pretty[T](t: Tensor[T]): string =
  ## Renders `t` as text: `t.shape[0]` elements per line in flat `data`
  ## order, each element followed by ", ".
  ##
  ## For tensors with more than two axes, a header is inserted before every
  ## element whose flat index is a multiple of `shape[0] * shape[1]`, i.e.
  ## the start of a new slice over axes 0 and 1. The header is preceded by
  ## an empty line, ends with a newline, and lists the position along every
  ## remaining axis, e.g. `(dim: 2, idx: 1)(dim: 3, idx: 0)`.
  ##
  ## `t.shape` must be non-empty. A tensor whose first axis has length 0
  ## renders as "" instead of dividing by zero.
  let axLen = t.shape[0]
  let ndim = t.shape.len
  let numRows = if axLen == 0: 0 else: t.size div axLen
  # blockLen[l] = shape[0] * ... * shape[l-1]: the number of flat elements
  # spanned by one step along axis `l`. Computed once, not per element.
  var blockLen = newSeq[int](ndim)
  var acc = 1
  for l in 0 ..< ndim:
    blockLen[l] = acc
    acc *= t.shape[l]
  for i in 0 ..< numRows:
    for j in 0 ..< axLen:
      let idx = computeIdx(t, 0, i, j)
      # Every boundary of a higher axis is also a multiple of blockLen[2],
      # so this single test decides whether a header is due.
      if ndim > 2 and idx mod blockLen[2] == 0:
        result.add "\n"
        for l in 2 ..< ndim:
          result.add("(dim: " & $l & ", idx: " &
                     $((idx div blockLen[l]) mod t.shape[l]) & ")")
        result.add "\n"
      result.add($t.data[idx] & ", ")
    result.add "\n"
# Demo: strided printing of 2-D, 3-D and 4-D tensors along each of their
# axes in ascending order, then the "pretty" rendering of a 6-D tensor.
let t = toSeq(1 .. 9).toTensor(3, 3)
for ax in 0 ..< t.shape.len:
  stridedPrinting(t, axis = ax)

let t2 = toSeq(1 .. 18).toTensor(3, 3, 2)
for ax in 0 ..< t2.shape.len:
  stridedPrinting(t2, axis = ax)

let t3 = toSeq(1 .. 36).toTensor(3, 3, 2, 2)
for ax in 0 ..< t3.shape.len:
  stridedPrinting(t3, axis = ax)

let t4 = toSeq(1 .. 144).toTensor(3, 3, 2, 2, 2, 2)
echo t4.pretty()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.
Output of `stridedPrinting` for `t`:
Output of strided printing for `t2`:
Output of strided printing for `t3`:
"Pretty" printing of `t4`: