Skip to content

Instantly share code, notes, and snippets.

@wweic
Last active March 30, 2019 04:44
Show Gist options
  • Save wweic/42f121822466868baf7fbc7eddcd4871 to your computer and use it in GitHub Desktop.
Save wweic/42f121822466868baf7fbc7eddcd4871 to your computer and use it in GitHub Desktop.
TensorArray in Relay

Relay Tensor/TensorArray definition

// Supports tensors up to rank 6. I'm using int to represent a dimension just to
// make the code compile in OCaml while I experiment.
(** Prototype tensor type indexed by rank (0 through 6).
    Constructor [TensorK] carries K + 1 [int] fields; the ints are stand-ins
    for real dimension values so the prototype compiles as plain OCaml. *)
type dynamic_tensor =
  | Tensor0 of int  (** rank 0 *)
  | Tensor1 of int * int  (** rank 1 *)
  | Tensor2 of int * int * int  (** rank 2 *)
  | Tensor3 of int * int * int * int  (** rank 3 *)
  | Tensor4 of int * int * int * int * int  (** rank 4 *)
  | Tensor5 of int * int * int * int * int * int  (** rank 5 *)
  | Tensor6 of int * int * int * int * int * int * int  (** rank 6 *)
;;

// This is the real (intended) type definition. It is pseudo-code and does not compile as OCaml.
(* Intended Relay-level definition of [dynamic_tensor].
   This is pseudo-code: [TensorType(shape=...)] is not OCaml syntax, so this
   declaration does not compile as written.
   Constructor [TensorK] wraps a tensor type of rank K whose K dimensions are
   all [Any] (presumably Relay's "dimension unknown until runtime" — confirm
   against the Relay type system). [Tensor0] is the scalar case, shape ().
   NOTE(review): this reuses the constructor names of [dynamic_tensor]; if both
   types were declared in one OCaml module, these constructors would shadow the
   earlier ones. *)
type dynamic_tensor_real =
    Tensor0 of TensorType(shape=())
  | Tensor1 of TensorType(shape=(Any))
  | Tensor2 of TensorType(shape=(Any, Any))
  | Tensor3 of TensorType(shape=(Any, Any, Any))
  | Tensor4 of TensorType(shape=(Any, Any, Any, Any))
  | Tensor5 of TensorType(shape=(Any, Any, Any, Any, Any))
  | Tensor6 of TensorType(shape=(Any, Any, Any, Any, Any, Any))
;;

(** A tensor array: an ordered, immutable OCaml list of [dynamic_tensor]
    slots, addressed by zero-based position. *)
type tensor_array_t = dynamic_tensor list;;

TensorArray

How to create a tensor array:

// For operator TensorArrayV3
(* Intended signature: val tensor_array : int -> tensor_array_t
   (expressed as a type annotation below, because a top-level [val] is only
   legal inside a signature / .mli, not in an implementation). *)

(** [tensor_array size] creates a tensor array with [size] slots, each
    initialised to [Tensor0 0].
    @raise Invalid_argument if [size] is negative. *)
let tensor_array (size : int) : tensor_array_t =
  (* Without this check a negative size would recurse forever in the naive
     recursive version; List.init would also reject it, but with a less
     specific message. *)
  if size < 0 then invalid_arg "tensor_array: negative size";
  List.init size (fun _ -> Tensor0 0)
;;
   
// Example
tensor_array 10;;
- : dynamic_tensor list =
[Tensor0 0; Tensor0 0; Tensor0 0; Tensor0 0; Tensor0 0; Tensor0 0; Tensor0 0;
 Tensor0 0; Tensor0 0; Tensor0 0]    

TensorArrayRead

Read from a tensor array.

// For operator TensorArrayRead
(** [tensor_array_read ta index] returns the slot of [ta] at zero-based
    position [index].
    @raise Failure if [ta] has [index] or fewer elements.
    @raise Invalid_argument if [index] is negative. *)
let tensor_array_read (ta : 'a list) (index : int) : 'a =
  List.nth ta index
;;

// Example
tensor_array_read (tensor_array 10) 5;;
- : dynamic_tensor = Tensor0 0  

TensorArrayWrite

Write to a tensor array.

// For operator TensorArrayWrite
(* Intended signature:
     val tensor_array_write :
       tensor_array_t -> int -> dynamic_tensor -> tensor_array_t
   A top-level [val] is only legal in a signature, so the constraint is kept
   here as documentation; the function itself works on any list. *)

(** [tensor_array_write ta n v] returns a copy of [ta] whose slot at
    zero-based position [n] is replaced by [v]. [ta] itself is not modified;
    the prefix before [n] is rebuilt and the suffix after [n] is shared.
    @raise Invalid_argument if [n < 0] or [n >= List.length ta]. *)
let tensor_array_write ta n v =
  if n < 0 then invalid_arg "tensor_array_write: negative index";
  let rec go i = function
    | [] ->
      (* Ran off the end: n was past the last slot. *)
      invalid_arg "tensor_array_write: index out of bounds"
    | x :: rest -> if i = 0 then v :: rest else x :: go (i - 1) rest
  in
  go n ta
;;

// Example
tensor_array_write (tensor_array 2) 1 (Tensor1 (0,0));;
- : dynamic_tensor list = [Tensor0 0; Tensor1 (0, 0)]

TensorArrayScatter

// For operator TensorArrayScatter
(* Intended signature:
     val tensor_array_scatter :
       tensor_array_t -> int list -> dynamic_tensor list -> tensor_array_t
   Kept as documentation: a top-level [val] is only legal in a signature. *)

(** [tensor_array_scatter ta indices values] writes the k-th element of
    [values] into the slot of [ta] named by the k-th element of [indices],
    processing pairs left to right (so a repeated index keeps the last value).
    Returns the updated copy; [ta] is not modified.
    @raise Invalid_argument if [indices] and [values] have different lengths,
    or if any index is out of range for [ta] (via [tensor_array_write]). *)
let tensor_array_scatter ta indices values =
  (* Previously a length mismatch was silently ignored and the extra indices
     or values were dropped; reject it explicitly instead. *)
  if List.compare_lengths indices values <> 0 then
    invalid_arg "tensor_array_scatter: indices and values differ in length";
  List.fold_left2 tensor_array_write ta indices values
;;

// Example
tensor_array_scatter (tensor_array 3) [1;2] [Tensor1 (1,1); Tensor1 (2,2)];;
- : dynamic_tensor list = [Tensor0 0; Tensor1 (1, 1); Tensor1 (2, 2)]

TensorArrayGather

// For operator TensorArrayGather
(* Intended signature:
     val tensor_array_gather_helper : tensor_array_t -> int list -> tensor_array_t
   Kept as documentation: a top-level [val] is only legal in a signature. *)

(** [tensor_array_gather_helper ta indices] returns the slots of [ta] at
    each position in [indices], in the order given; indices may repeat.
    Raises whatever [tensor_array_read] raises for an out-of-range index. *)
let tensor_array_gather_helper ta indices =
  List.map (tensor_array_read ta) indices
;;

(* For operator TensorArrayGather: selects the slots of [ta] at [indices]
   (via [tensor_array_gather_helper]) and passes them to [tensor_array_stack],
   presumably to combine them into a single tensor.
   NOTE(review): [tensor_array_stack] is not defined anywhere in this file,
   and TensorArrayStack is listed under "Difficult to implement", so this
   definition does not compile as written. *)
let tensor_array_gather ta indices =
    tensor_array_stack (tensor_array_gather_helper ta indices);;
    
// Example
tensor_array_gather_helper (tensor_array 3) [1];;
- : dynamic_tensor list = [Tensor0 0]

Difficult to implement

These operators are currently infeasible to implement because each one requires converting a tensor array into a single tensor, and the current TVM operators do not recognize the Relay VM's VMObject.

  • TensorArrayConcat
  • TensorArrayStack
  • TensorArrayUnstack
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment