"""
Simplified Implementation of the Linear Recurrent Unit
------------------------------------------------------
We present here a simplified JAX implementation of the Linear Recurrent Unit (LRU).
The state of the LRU is driven by the input $(u_k)_{k=1}^L$ of sequence length $L$
according to the following formula (and efficiently parallelized using an associative scan):
$x_{k} = \Lambda x_{k-1} +\exp(\gamma^{\log})\odot (B u_{k})$,
and the output is computed at each timestep $k$ as follows: $y_k = C x_k + D u_k$.
In our code, $B,C$ follow Glorot initialization, with $B$ scaled additionally by a factor of 2
to account for halving the state variance by taking the real part of the output projection.
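The recurrence above can be illustrated with a small pure-Python sketch. It is not the gist's JAX code: it assumes $\Lambda$ is diagonal (so all products are elementwise) and that the initial state is zero, and the names `lru_sequential`, `binary_operator`, `inclusive_scan`, and `lru_scan` are illustrative. Each step is treated as an affine map $x \mapsto a x + b$; composing two such maps is associative, which is what lets the states be computed with a prefix scan (in JAX this role would typically be played by `jax.lax.associative_scan`).

```python
import cmath
import math


def lru_sequential(lam, gamma_log, Bu):
    """Reference loop: x_k = lam * x_{k-1} + exp(gamma_log) * (B u_k), x_0 = 0."""
    x = [0j] * len(lam)
    states = []
    for bu_k in Bu:
        x = [l * xi + math.exp(g) * b
             for l, xi, g, b in zip(lam, x, gamma_log, bu_k)]
        states.append(x)
    return states


def binary_operator(e_i, e_j):
    """Compose two affine maps (a, b): apply e_i first, then e_j."""
    a_i, b_i = e_i
    a_j, b_j = e_j
    a = [aj * ai for ai, aj in zip(a_i, a_j)]
    b = [aj * bi + bj for aj, bi, bj in zip(a_j, b_i, b_j)]
    return a, b


def inclusive_scan(op, elems):
    """Doubling (Hillis-Steele style) prefix scan; valid because op is associative.

    Every update inside one pass is independent of the others, which is
    the property a parallel implementation exploits.
    """
    out = list(elems)
    shift = 1
    while shift < len(out):
        out = [out[k] if k < shift else op(out[k - shift], out[k])
               for k in range(len(out))]
        shift *= 2
    return out


def lru_scan(lam, gamma_log, Bu):
    """Same states as lru_sequential, obtained through the prefix scan."""
    elems = [(list(lam), [math.exp(g) * b for g, b in zip(gamma_log, bu_k)])
             for bu_k in Bu]
    # With x_0 = 0, the offset b of each composed map is the state x_k.
    return [b for _, b in inclusive_scan(binary_operator, elems)]


# Hypothetical toy data: state size 2, sequence length 5.
lam = [0.9 * cmath.exp(0.3j), 0.5 * cmath.exp(-1.1j)]
gamma_log = [math.log(0.4), math.log(0.8)]
Bu = [[1 + 0j, 2 - 1j], [0.5j, -1 + 0j], [2 + 0j, 0.25 + 0.25j],
      [-1 - 1j, 1 + 0j], [0.3 + 0j, 0j]]

seq_states = lru_sequential(lam, gamma_log, Bu)
scan_states = lru_scan(lam, gamma_log, Bu)
```

The two functions should agree up to floating-point rounding. The scan version performs more total work than the loop, but it needs only about $\log_2 L$ dependent passes, which is the point of the parallelization.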
using SimpleChains
function f(x)
    N = Base.isqrt(length(x))
    A = reshape(view(x, 1:N*N), (N,N))
    expA = exp(A)
    vec(expA)
end
T = Float32;
D = 2 # 2x2 matrices
# install tinycudann via
# pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
import torch
import tinycudann as tcnn
import time
class TCNNMatrixExponentEstimator1(torch.nn.Module):
    def __init__(self, hidden=16) -> None:
        super().__init__()
using PyCall
using CUDA
using DLPack
using Test
#using Zygote
#using ChainRulesCore
@show DLPack.PYCALL_NOOP_DELETER
jax = pyimport("jax")
using PyCall
using DLPack
using Test
using Zygote
using ChainRulesCore
torch = pyimport("torch")
functorch = pyimport("functorch")
dlpack = pyimport("torch.utils.dlpack")
py"""
using Flux
qdim = 2
nn = Chain(Dense(qdim, 32, tanh), Dense(32, 2));
q = rand(2, 5);
function jac(x)
    o = nn(x)
    return reduce(hcat, [o[:, i] for i in 1:size(x)[end]])
end
#!/usr/bin/env python3
#
# File: test_dircol.py
#
import numpy as np
import torch
from optimalcontrol.dircolproblem import DIRCOLProblem
from mechamodlearn import utils
using MuJoCo
modelfile = "test/humanoid.xml"
pm = mj_loadXML(modelfile) # Raw C pointer to mjModel
pd = mj_makeData(pm) # Raw C pointer to mjData
m, d = mj.mapmujoco(pm, pd) # wrap with our jlModel, jlData types
# we can manipulate data in the raw C structs now
nq = mj.get(m, :nq)
\newcommand{\ba}{\mathbf{a}}
\newcommand{\bb}{\mathbf{b}}
\newcommand{\bc}{\mathbf{c}}
\newcommand{\bd}{\mathbf{d}}
\newcommand{\be}{\mathbf{e}}
\newcommand{\bg}{\mathbf{g}}
\newcommand{\bh}{\mathbf{h}}
\newcommand{\bi}{\mathbf{i}}
\newcommand{\bj}{\mathbf{j}}
\newcommand{\bk}{\mathbf{k}}