This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Save/load pytrees to disk.""" | |
import collections | |
import h5py | |
import jax | |
import numpy as np | |
def save(filepath, tree): | |
"""Saves a pytree to an hdf5 file. |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Dark mode | |
color0 #282828 | |
color1 #cc241d | |
color2 #98971a | |
color3 #d79921 | |
color4 #458588 | |
color5 #b16286 | |
color6 #689d6a | |
color7 #a89984 | |
color8 #928374 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def tf_graph_wrapper(func): | |
"""Wraps a class method with a tf.Graph context manager""" | |
@wraps(func) | |
def wrapper(self, *args, **kwargs): | |
with self._graph.as_default(): | |
return func(self, *args, **kwargs) | |
return wrapper | |
def tf_init(func): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class AdaptiveIAF(tf.nn.rnn_cell.RNNCell): | |
def __init__(self, num_units, dt, reuse=False): | |
self._dt = tf.constant(dt, dtype=tf.float32) | |
self._num_units = num_units | |
self._reuse = reuse | |
@property | |
def state_size(self): | |
return (self._num_units, self._num_units) | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
# initialize variable x=10 | |
x = tf.Variable(10.0, dtype=tf.float32) | |
# objective is x ** 3 which has a saddle point at x=0 | |
f = x ** 3 | |
# create optimizer and compute gradients |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from moviepy.editor import ImageSequenceClip | |
def gif(filename, array, fps=10, scale=1.0): | |
"""Creates a gif given a stack of images using moviepy | |
Notes | |
----- | |
works with current Github version of moviepy (not the pip version) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function | |
from time import perf_counter | |
from IPython.core.magics.execution import _format_time as fmt | |
class Timer(object): | |
""" | |
Timer is a simple class to keep track of elapsed time. | |
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from numpy.random import rand, randn, shuffle, choice, sample | |
from numpy import arange, zeros, ones, eye, linspace, pi, inf, nan, cov, array | |
from scipy.linalg import * | |
from matplotlib.pyplot import * |
NewerOlder