Skip to content

Instantly share code, notes, and snippets.

@sebinsua
Last active July 21, 2023 22:47
Show Gist options
  • Save sebinsua/e756554faaa2d346aff0341374ee792e to your computer and use it in GitHub Desktop.
Save sebinsua/e756554faaa2d346aff0341374ee792e to your computer and use it in GitHub Desktop.
// Builds a readonly tuple of exactly `N` elements of type `T`: starting from
// the empty accumulator `R`, one more `T` is prepended on each recursive step
// until `R["length"]` matches `N`.
// NOTE(review): `N` presumably has to be a non-negative integer literal (or
// `number`); a literal such as `2.5` is never matched, so the recursion would
// run until the compiler's instantiation-depth limit.
type _Tuple<
T,
N extends number,
R extends readonly T[] = []
> = R["length"] extends N ? R : _Tuple<T, N, readonly [T, ...R]>;
// A fixed-length tuple of `N` elements of type `T`. On top of `_Tuple`, the
// intersection states the `length` as the literal `N` and spells out numeric
// indexing and iteration.
type Tuple<T, N extends number> = _Tuple<T, N> & {
readonly length: N;
[I: number]: T;
[Symbol.iterator]: () => IterableIterator<T>;
};
/**
 * Pairs up the elements of two arrays of equal length.
 *
 * `b` is constrained to a `Tuple` of length `Length`, which is bounded by
 * `A["length"]`. Literal tuples of different lengths are therefore rejected
 * at compile time.
 *
 * @param a - The array supplying the first element of each pair.
 * @param b - The array supplying the second element of each pair.
 * @returns A tuple of `[a[i], b[i]]` pairs, one per index.
 * @throws Error if the two arrays' runtime lengths differ.
 */
function zip<
  A extends readonly any[],
  Length extends A["length"],
  B extends Tuple<any, Length>
>(a: A, b: B): Tuple<readonly [A[number], B[number]], Length> {
  const { length: leftLength } = a;
  const { length: rightLength } = b;
  // Arrays whose lengths are not literal types get past the type check, so
  // the lengths are compared again at runtime.
  if (leftLength !== rightLength) {
    throw new Error(`zip cannot operate on different length arrays; ${leftLength} !== ${rightLength}`);
  }
  const pairs = a.map((left, position) => [left, b[position]]);
  return pairs as Tuple<readonly [A[number], B[number]], Length>;
}
const a = [1, 2, 3] as const;
const b1 = [1, 2, 6, 2, 4] as const;
const b2 = [1, 2, 6] as const;
// `zip(a, b1)` is rejected at compile time (lengths 3 vs 5). The runtime
// length guard in `zip` also throws for it, so the call is wrapped to let the
// rest of the demo run.
try {
  // @ts-expect-error Source has 5 element(s) but target allows only 3.
  zip(a, b1);
} catch (error) {
  console.error(error);
}
const c2 = zip(a, b2);
console.log(c2);
// ^?
@sebinsua
Copy link
Author

sebinsua commented May 2, 2023

// Utilities
// Turns a union `A | B` into the intersection `A & B`. The distributive
// conditional builds a union of functions, one taking each member. Inferring a
// single parameter type `I` from that union (a contravariant position) yields
// the intersection of the members.
type UnionToIntersection<U> = (U extends any ? (k: U) => void : never) extends (
  k: infer I
) => void
  ? I
  : never;
// `true` when `T` is a union. `T` is wrapped in a one-element tuple so the
// check does not distribute. A union is not assignable to the intersection of
// its own members, while a single type is.
type IsUnion<T> = [T] extends [UnionToIntersection<T>] ? false : true;

// Type-level boolean combinators. Only an input that extends `true` counts as
// true; anything else is treated as false. The inputs are naked type
// parameters, so a union input is distributed over.
type And<A, B> = A extends true ? (B extends true ? true : false) : false;
type Or<A, B> = A extends true ? true : B extends true ? true : false;
type Not<A> = A extends true ? false : true;

// A dimension whose size is only known at runtime. At runtime it is a plain
// `number`; the type additionally carries a string `Label`, so two
// runtime-sized dimensions with different labels are not interchangeable.
export type Var<Label extends string> = number & { label: Label };
/**
 * Brands the runtime size `d` as a `Var` dimension labelled `label`.
 *
 * @param d - The dimension's size.
 * @param label - The label recorded in the type; it is not stored at runtime.
 * @returns `d` unchanged, typed as `Var<Label>`.
 */
export const Var = <Label extends string>(d: number, label: Label): Var<Label> =>
  d as Var<Label>;

// We check whether `T` is a numeric literal by checking that `number` does not
// extend from it but that it does extend from `number`.
type IsNumericLiteral<T> = number extends T
  ? false
  : T extends number
  ? true
  : false;
// `true` when `T` carries the `Var` brand, i.e. a labelled runtime-sized dimension.
type IsVar<T> = T extends Var<string> ? true : false;

// For type-checking of tensors to work they can only be created using
// numeric literals (e.g. `5`) or `Var<string>` and not types like
// `number` or `1 | 2 | 3`.
// Array types satisfy the constraint but always evaluate to `false`.
type IsNumericLiteralOrVar<T extends number | readonly number[]> = And<
  // We disallow `T` from being a union of types.
  Not<IsUnion<T>>,
  Or<
    // We allow `T` to be a numeric literal but not `number` itself.
    IsNumericLiteral<T>,
    // We allow `T` to be a `Var`.
    IsVar<T>
  >
>;
// Resolves to the shape `T` itself when every dimension in it is a numeric
// literal or a `Var`, and to `never` otherwise. The outer `T extends ...`
// check distributes over a union of shapes.
type IsShapeContainingOnlyNumericLiteralsOrVarDimensions<
  T extends ReadonlyArray<number | Var<string>>
> = T extends ReadonlyArray<unknown>
  ? { [K in keyof T]: IsNumericLiteralOrVar<T[K]> } extends {
      [K in keyof T]: true;
    }
    ? T
    : never
  : never;

export type Dimension = number | Var<string>;
export type Tensor<Shape extends readonly Dimension[]> = {
  data: Float32Array;
  shape: Shape;
};
export type InvalidTensor<Shape extends readonly Dimension[]> = [
  never,
  "Invalid tensor: please provide an array of only numeric literals or `Var`s.",
  Shape
];
/**
 * Creates a tensor with the given `shape`, backed by a `Float32Array`.
 *
 * The return type resolves to `Tensor<Shape>` only when every dimension is a
 * numeric literal or a `Var`. Otherwise it resolves to `InvalidTensor<Shape>`,
 * which surfaces the problem at the call site.
 *
 * @param shape - The size of each dimension; the element count is their product
 *   (1 for an empty shape).
 * @param init - Optional flat initial values. It must hold exactly as many
 *   elements as the shape describes. When omitted, the data is zero-filled.
 * @throws Error if `init` is given and its length differs from the element count.
 */
export function tensor<const Shape extends readonly Dimension[]>(
  shape: Shape,
  init?: readonly number[]
): Shape extends IsShapeContainingOnlyNumericLiteralsOrVarDimensions<Shape>
  ? Tensor<Shape>
  : InvalidTensor<Shape> {
  const size = shape.reduce<number>((product, dimension) => product * dimension, 1);
  // Data whose length disagrees with the shape would silently break every
  // consumer that indexes by shape, so reject it up front.
  if (init != null && init.length !== size) {
    throw new Error(
      `tensor init has ${init.length} element(s) but shape [${shape.join(", ")}] requires ${size}`
    );
  }
  const data = init == null ? new Float32Array(size) : new Float32Array(init);
  // The conditional return type cannot be narrowed inside the body, so the
  // result is asserted; callers still see the precise type from the signature.
  return { data, shape } as any;
}

// A two-dimensional tensor of `Rows` by `Columns`.
export type Matrix<Rows extends Dimension, Columns extends Dimension> = Tensor<[Rows, Columns]>;
// Returned in place of a `Matrix` when a dimension is not a numeric literal or `Var`.
export type InvalidMatrix<Shape extends readonly [Dimension, Dimension]> = [
  never,
  "Invalid matrix: please provide an array of only numeric literals or `Var`s.",
  Shape
];
/**
 * Creates a matrix of `shape[0]` rows by `shape[1]` columns.
 *
 * The return type is `Matrix<Shape[0], Shape[1]>` only when both dimensions
 * are numeric literals or `Var`s; otherwise it is `InvalidMatrix<Shape>`.
 *
 * @param shape - The `[rows, columns]` pair.
 * @param init - Optional flat initial values, forwarded to `tensor`.
 */
export function matrix<const Shape extends readonly [Dimension, Dimension]>(
  shape: Shape,
  init?: number[]
): Shape extends IsShapeContainingOnlyNumericLiteralsOrVarDimensions<Shape>
  ? Matrix<Shape[0], Shape[1]>
  : InvalidMatrix<Shape> {
  // All validation lives in the return type; at runtime a matrix is just a tensor.
  const created: any = tensor(shape, init);
  return created;
}

// A row vector: a tensor of shape `[1, Size]`.
export type Vector<Size extends number> = Tensor<[1, Size]>;
// Returned in place of a `Vector` when the size is not a numeric literal or `Var`.
export type InvalidVector<Size extends Dimension | readonly number[]> = [
  never,
  "Invalid vector: please provide either a numeric literal or a `Var`.",
  Size
];
/**
 * Creates a row vector, i.e. a tensor of shape `[1, length]`.
 *
 * - `vector(size, init?)`: `size` is a numeric literal or a `Var`. For anything
 *   else (e.g. plain `number`), the return type is `InvalidVector<Size>`.
 * - `vector(values)`: the length comes from the non-empty `values` tuple,
 *   which also supplies the data.
 *
 * NOTE(review): `Dimension` is `number | Var<string>` and `Var<string>` is
 * itself a `number`, so the second overload appears to accept the same
 * arguments as the first.
 *
 * @throws Error if `size` is neither a number nor an array.
 */
export function vector<const Size extends number>(
  size: Size,
  init?: number[]
): true extends IsNumericLiteralOrVar<Size>
  ? Vector<Size>
  : InvalidVector<Size>;
export function vector<const Size extends Dimension>(
  size: Size,
  init?: number[]
): true extends IsNumericLiteralOrVar<Size>
  ? Vector<Size>
  : InvalidVector<Size>;
export function vector<const Array extends readonly [number, ...number[]]>(
  init: Array
): Vector<Array["length"]>;
export function vector<const Size extends number | Var<string> | readonly number[]>(
  size: Size,
  init?: number[]
): Vector<any> {
  // Values form: the array supplies both the length and the data.
  if (Array.isArray(size)) {
    const fromValues: Dimension[] = [1, size.length];
    return tensor(fromValues, size) as any;
  }
  // Length form: a numeric literal or `Var`, optionally with initial data.
  if (typeof size === "number") {
    const fromLength: Dimension[] = [1, size];
    return tensor(fromLength, init) as any;
  }
  throw new Error("Invalid input type for vector.");
}

/**
 * Pairs up two vectors into a `length × 2` matrix whose row `i` is `[a[i], b[i]]`.
 *
 * Both arguments share the single type parameter `SameVector`, so vectors of
 * different literal (or differently labelled `Var`) lengths fail to compile.
 *
 * @throws Error if the two vectors' runtime lengths differ.
 */
function zip<SameVector extends Vector<number>>(a: SameVector, b: SameVector): Matrix<SameVector["shape"][1], 2> {
  const aLength = a.shape[1];
  const bLength = b.shape[1];
  if (aLength !== bLength) {
    throw new Error(
      `zip cannot operate on different length vectors; ${aLength} !== ${bLength}`
    );
  }

  // Interleave the two vectors so each consecutive pair becomes one row.
  const interleaved: number[] = [];
  for (let row = 0; row < aLength; row += 1) {
    interleaved.push(a.data[row], b.data[row]);
  }

  return matrix([aLength, 2], interleaved) as any;
}

const t1 = tensor([5, Var(3, "three"), 10]);

const m1 = matrix([5, Var(3, "three")]);

const v1 = vector(2);
const v2 = vector(Var(3, "three"));
const v3 = vector([1, 2, 3]);
const v4 = vector([1, 2, 3]);
const v5 = vector([4, 5, 6]);
const v6 = vector([7, 8, 9, 10]);
const v7 = vector(Var(3, "three"), [1, 2, 3]);
const v8 = vector(Var(3, "three"), [5, 10, 15]);
const v9 = vector(Var(4, "four"), [10, 11, 12, 13]);

const zipped = zip(v4, v5); // Works fine
// Lengths 3 and 4 differ, so this is a compile-time error. `zip`'s runtime
// guard would also throw here, so the call is wrapped to keep the demo running.
try {
  // @ts-expect-error Vectors of length 3 and 4 do not share a type.
  zip(v4, v6);
} catch (error) {
  console.error(error);
}
const zipped2 = zip(v7, v8); // Works fine
// `Var<"three">` and `Var<"four">` are distinct dimensions, so this is also a
// compile-time error; at runtime the guard throws because 3 !== 4.
try {
  // @ts-expect-error Differently labelled `Var` dimensions do not share a type.
  zip(v7, v9);
} catch (error) {
  console.error(error);
}

@sebinsua
Copy link
Author

I improved upon this and wrote it up as a blog post here: https://twitter.com/sebinsua/status/1656297294008819712

@sebinsua
Copy link
Author

sebinsua commented Jul 21, 2023

// Maps over the keys of the first array (`A[0]`). For each row index, it
// builds an object keyed like `A` that collects the element at that index from
// every input array, i.e. one column of the grid. The `length` key is mapped
// to `number` rather than to a column.
type Zipped<A extends readonly (readonly any[])[]> = {
  [RowIndex in keyof A[0]]: RowIndex extends "length" ? number : {
    [ArrayIndex in keyof A]: A[ArrayIndex] extends readonly any[] ? A[ArrayIndex][RowIndex] : never;
  };
};

// Builds the tuple `[0, 1, ..., TotalLength - 1]` by repeatedly appending the
// accumulator's current length.
type NumericRange<TotalLength extends number, TempRange extends any[] = []> =
  TempRange['length'] extends TotalLength ? TempRange : NumericRange<TotalLength, [...TempRange, TempRange['length']]>;

// Looks up each key listed in `K` on `T`, producing a tuple of the values in
// the same order as `K`.
type TupleFromKeys<T, K extends readonly (keyof T)[]> = {
  [I in keyof K]: K[I] extends keyof T ? T[K[I]] : never;
};

// The typed result of `zip`: the columns of `Zipped<T>` for row indices
// 0 .. T[0]['length'] - 1.
// NOTE(review): the length comes from the FIRST array only, whereas the
// runtime `zip` truncates to the shortest input, so the type can overstate the
// result's length when a later array is shorter.
type Tuple<T extends readonly (readonly any[])[]> = TupleFromKeys<Zipped<T>, NumericRange<T[0]['length']>>;

/**
 * Transposes any number of arrays: element `i` of the result collects element
 * `i` of every input array, in argument order.
 *
 * At runtime the result is truncated to the shortest input. With no input
 * arrays at all, the result is an empty array.
 *
 * @param arrays - The arrays (e.g. the rows of a grid) to transpose.
 * @returns One array per index (e.g. the columns of the grid), typed by `Tuple<T>`.
 */
function zip<const T extends readonly (readonly any[])[]>(...arrays: T): Tuple<T> {
  // `Math.min()` with no arguments is `Infinity`, which would make the row
  // loop unbounded, so an empty argument list yields zero rows instead.
  const minLength =
    arrays.length === 0 ? 0 : Math.min(...arrays.map((arr) => arr.length));
  const result: any[][] = Array.from({ length: minLength }, (_, i) =>
    arrays.map((arr) => arr[i])
  );

  return result as Tuple<T>;
}

// Three rows of mixed-type cells; `as const` keeps every cell's literal type.
const grid = [
  ['a', 1, true],
  ['b', 2, false],
  ['c', 3, true]
] as const;

// Transposing the rows yields one tuple per column.
const columns = zip(...grid);
const [col1, col2, col3] = columns;
[col1, col2, col3].forEach((column) => console.log(column));

^ This approach is useful when you need to convert a grid of rows into a list of columns (i.e. to transpose it).

See here: https://www.typescriptlang.org/play?#code/C4TwDgpgBAWglmSATAPAQShAHsCA7JAZygCcIBDJAezwBsQoAKMym+qcvEAbQF0BKPgD4oAXigBvAFBQo3AEpUA7gEkC2KHDxQA1hBBUAZlDTcADL14AuKItXqsmHPiJQARLXwBzYAAs3UAD8UHgArgC2AEYQJFA20rKy3GgkJOQgakgaWrr6RibWJsmp6ZnYvE64BMQs1HQMnDwVwaYpaRkOvArKZVgVNngQAG4xANwyUAC+49NSUqCQUAByETFwAMbynF4QKAAqVMDktAAy3n6VLsRhUTEANFB7EOFgW3g7l9UcXHxicrwiUQTJ4vN47bgAck87z8EIq2CqrgOR1O518QUez1e22gNhW4TWmxx+0OxzOMN8D24ADpaSDse8IA96WCIJDoT5fHCAeN5uBoHtQmBPAAxEhUcIAaX0hH2D0ln1ctTYDEYegMxj2ggBfwSchUmm06vyksKku4Knhzi+xs1GL23HNlv6IWGYyksz5i0Fwt2e0VNQodXYzCDKu+TW1gMeQtF4qlMpQ8EQEFQeyED3xhNZ+3MXShaLhQiEvMMoTw62AcBoUAAXggUOsaIRgI8A6Qw-UmMqu40+FHGLTqeQSiBCDYtRPY36RHqm3gW1BwlpyZy-gBZch+anLvCD2kj9qEHfkMCMQ9iESH6kcvz8fjjWTzxdkQihWjAGx9roVcR8XmyIYVCxIwnitnAfxmKMmhQCgS4rmi0FwAA1Mh-CSBMT7Nq29YpkgKi4OEfyHukx7hKe56pJeHCpNwcACI+iSvu+wDUmAoSEL4jC4cgBHPA+EyTHMshkMAoQkNozEfhwxA+p4+wlh6czPq2XgkHASB-NwEyQuQEIPAAjA8wAkKEEC8HcOkQpE+lQAATA8hjHIQ5mWUkELrLZADMxmmeZUgVOQxAqbyKlyE2tBGVAEUOdFVC0F5v51gg+7UmpGkCc+8UQDeVBeIwEUGZlzbZbl+UxcVC6lbQeUFfFXkPkAA

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment