This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using REPL | |
using REPL.LineEdit | |
macro keyboard() | |
quote | |
debugprompt(@__MODULE__, Base.@locals) | |
println() | |
end | |
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# | |
# This uses the Nabla.jl-style interception mechanism whereby | |
# we wrap things that are to be differentiated w.r.t. in a | |
# thin wrapper. There are lots of thing that you can't | |
# propoagate derivative information through with this kind of | |
# approach without quite a lot of extra machinery, but the | |
# examples at the bottom do work. | |
# | |
using ChainRules, Cassette |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
struct CachedConv | |
conv::Conv | |
cache::Ref{Tuple} | |
end | |
CachedConv(c::Conv) = CachedConv(c, ()) | |
Flux.@treelike CachedConv | |
function (m::CachedConv)(x::AbstractArray) | |
# Has the user changed batch size on us? If so, clear our cache and re-up! | |
if !isempty(m.cache[]) && size(m.cache[][2], 4) != size(x, 4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Zygote, Statistics, Flux | |
# We modify (the implementation of) batchnorm to be more ammenable to CPUs pretending to be TPUs. | |
struct ZygoteBatchNorm{F,V,W} | |
λ::F # activation function | |
β::V # bias | |
γ::V # scale | |
μ::W # moving mean | |
σ::W # moving std | |
ϵ::Float32 |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env bash | |
function red_echo() | |
{ | |
tput setaf 1 | |
echo "$*" | |
tput sgr0 | |
} | |
if [[ "$*" == *--help* ]]; then |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- The Datafun runtime. | |
module Runtime where | |
import qualified Data.Set as Set | |
import Data.Set (Set) | |
class Eq a => Preord a where | |
(<:) :: a -> a -> Bool | |
class Preord a => Semilat a where |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial | |
import numpy.random as npr | |
import jax.numpy as np | |
from jax import lax | |
from jax import grad, pjit, papply | |
### set up some synthetic data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
const BRAILLE = split("⠀⠁⠂⠄⡀⠈⠐⠠⢀", "") .|> s -> Int(s[1]) | |
function show_any_nonzero(S::SparseMatrixCSC; maxw = displaysize(stdout)[2], maxh = displaysize(stdout)[1]-3) | |
h,w = size(S) | |
h > 4maxh && (w = max(1, (w*4maxh+h÷2)÷h); h = 4maxh) | |
w > 2maxw && (h = max(1, (h*2maxw+w÷2)÷w); w = 2maxw) | |
P = fill(BRAILLE[1], (w+3)÷2, (h+3)÷4) | |
P[end, :].=10 | |
@inbounds for c = 0:w-1, r = 0:h-1 | |
_anynz(S, r*S.m÷h+1, c*S.n÷w+1, (r+1)*S.m÷h, (c+1)*S.n÷w) && |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module TraceCalls | |
using Cassette | |
mutable struct Trace | |
level::Int | |
cutoff::Int | |
end | |
Cassette.@context TraceCtx |
NewerOlder