@rejuvyesh
rejuvyesh / lru.py
Created April 13, 2023 02:45 — forked from Ryu1845/lru.py
Linear Recurrent Unit (LRU) from the paper ["Resurrecting Recurrent Neural Networks for Long Sequences"](https://arxiv.org/abs/2303.06349)
"""
Simplified Implementation of the Linear Recurrent Unit
------------------------------------------------------
We present here a simplified JAX implementation of the Linear Recurrent Unit (LRU).
The state of the LRU is driven by the input $(u_k)_{k=1}^L$ of sequence length $L$
according to the following formula (and efficiently parallelized using an associative scan):
$x_{k} = \Lambda x_{k-1} +\exp(\gamma^{\log})\odot (B u_{k})$,
and the output is computed at each time step $k$ as follows: $y_k = C x_k + D u_k$.
In our code, $B,C$ follow Glorot initialization, with $B$ additionally scaled by a factor of 2
to account for the halving of the state variance caused by taking the real part of the output projection.
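
The recurrence above is linear with a diagonal $\Lambda$, so each step is an elementwise affine map $x \mapsto a \odot x + b$, and composing two such maps is associative. The following is a minimal pure-Python sketch of that idea, not the gist's JAX code: it runs a Hillis-Steele inclusive scan over the pairs $(\Lambda, \exp(\gamma^{\log}) \odot B u_k)$ and checks the result against a plain sequential loop. The helper names (`binary_operator`, `inclusive_scan`, `lru_states_scan`, `lru_states_loop`), the example values, the choice $\gamma = \sqrt{1 - |\lambda|^2}$, and the assumption $x_0 = 0$ are all illustrative. The input projection $B u_k$ is taken as precomputed, and the output projection $y_k$ is omitted.

```python
import cmath


def binary_operator(left, right):
    """Compose two elementwise affine steps x -> a*x + b (left is applied first)."""
    a_l, b_l = left
    a_r, b_r = right
    a = [ar * al for al, ar in zip(a_l, a_r)]
    b = [ar * bl + br for bl, ar, br in zip(b_l, a_r, b_r)]
    return a, b


def inclusive_scan(elems, op):
    """Hillis-Steele inclusive scan; within a pass, every index is independent."""
    out = list(elems)
    step = 1
    while step < len(out):
        out = [
            out[i] if i < step else op(out[i - step], out[i])
            for i in range(len(out))
        ]
        step *= 2
    return out


def lru_states_scan(lam, gamma, Bu):
    """States x_1..x_L via the scan, assuming x_0 = 0 and diagonal Lambda."""
    elems = [(lam, [g * v for g, v in zip(gamma, bu_k)]) for bu_k in Bu]
    return [b for _, b in inclusive_scan(elems, binary_operator)]


def lru_states_loop(lam, gamma, Bu):
    """Reference: the same recurrence evaluated one step at a time."""
    x = [0j] * len(lam)
    states = []
    for bu_k in Bu:
        x = [l * xi + g * v for l, xi, g, v in zip(lam, x, gamma, bu_k)]
        states.append(x)
    return states


# Illustrative values: two complex eigenvalues inside the unit disk.
lam = [0.9 * cmath.exp(0.3j), 0.5 * cmath.exp(1.0j)]
# gamma plays the role of exp(gamma_log); this normalization is an assumption.
gamma = [(1 - abs(l) ** 2) ** 0.5 for l in lam]
# Stand-ins for the projected inputs B u_k, for a sequence of length 5.
Bu = [[complex(k + 1, -k), complex(0.5 * k, 1.0)] for k in range(5)]

scan_states = lru_states_scan(lam, gamma, Bu)
loop_states = lru_states_loop(lam, gamma, Bu)
max_diff = max(
    abs(s - r) for xs, xr in zip(scan_states, loop_states) for s, r in zip(xs, xr)
)
print("max abs difference between scan and loop:", max_diff)
```

The scan needs about $\log_2 L$ passes instead of $L$ sequential steps, which is the reason a parallel backend can evaluate the recurrence efficiently on long sequences.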
using SimpleChains
function f(x)
    N = Base.isqrt(length(x))
    A = reshape(view(x, 1:N*N), (N,N))
    expA = exp(A)
    vec(expA)
end
T = Float32;
D = 2 # 2x2 matrices
# install tinycudann via
# pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
import torch
import tinycudann as tcnn
import time
class TCNNMatrixExponentEstimator1(torch.nn.Module):
    def __init__(self, hidden=16) -> None:
        super().__init__()
@rejuvyesh
rejuvyesh / stresstest_jax_dlpack.jl
Last active February 9, 2022 17:32
Reproducing a DLPack segfault with CUDA + JAX
using PyCall
using CUDA
using DLPack
using Test
#using Zygote
#using ChainRulesCore
@show DLPack.PYCALL_NOOP_DELETER
jax = pyimport("jax")
@rejuvyesh
rejuvyesh / stresstest_dlpack.jl
Created February 7, 2022 22:12
Reproducing a DLPack segfault
using PyCall
using DLPack
using Test
using Zygote
using ChainRulesCore
torch = pyimport("torch")
functorch = pyimport("functorch")
dlpack = pyimport("torch.utils.dlpack")
py"""
using Flux
qdim = 2
nn = Chain(Dense(qdim, 32, tanh), Dense(32, 2));
q = rand(2, 5);
function jac(x)
    o = nn(x)
    return reduce(hcat, [o[:, i] for i in 1:size(x)[end]])
end
#!/usr/bin/env python3
#
# File: test_dircol.py
#
import numpy as np
import torch
from optimalcontrol.dircolproblem import DIRCOLProblem
from mechamodlearn import utils
using MuJoCo
modelfile = "test/humanoid.xml"
pm = mj_loadXML(modelfile) # Raw C pointer to mjModel
pd = mj_makeData(pm) # Raw C pointer to mjData
m, d = mj.mapmujoco(pm, pd) # wrap with our jlModel, jlData types
# we can manipulate data in the raw C structs now
nq = mj.get(m, :nq)
@rejuvyesh
rejuvyesh / notebook.ipynb
Created July 19, 2017 00:52 — forked from eamartin/notebook.ipynb
Understanding & Visualizing Self-Normalizing Neural Networks
\newcommand{\ba}{\mathbf{a}}
\newcommand{\bb}{\mathbf{b}}
\newcommand{\bc}{\mathbf{c}}
\newcommand{\bd}{\mathbf{d}}
\newcommand{\be}{\mathbf{e}}
\newcommand{\bg}{\mathbf{g}}
\newcommand{\bh}{\mathbf{h}}
\newcommand{\bi}{\mathbf{i}}
\newcommand{\bj}{\mathbf{j}}
\newcommand{\bk}{\mathbf{k}}