This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Union, List | |
from pathlib import Path | |
import torch | |
from torch.utils.data import Dataset | |
from scipy.io import loadmat | |
import itertools | |
import os | |
def znormalize(train_x, train_y, test_x, test_y): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from typing import List | |
import torch.nn.functional as F | |
def receptive_field(kernel_size: int, dilation: int) -> int:
    """Return the receptive field of one dilated convolution along one axis.

    A kernel with ``kernel_size`` taps whose consecutive taps are spaced
    ``dilation`` positions apart spans ``1 + (kernel_size - 1) * dilation``
    input positions.

    Args:
        kernel_size: Number of taps in the kernel.
        dilation: Spacing between consecutive kernel taps.

    Returns:
        Number of input positions covered by a single output position.
    """
    return 1 + (kernel_size - 1) * dilation
class Seq2SeqConv1d(torch.nn.Module): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax | |
import jax.numpy as jnp | |
def tree_stack(trees): | |
"""Takes a list of trees and stacks every corresponding leaf. | |
For example, given two trees ((a, b), c) and ((a', b'), c'), returns | |
((stack(a, a'), stack(b, b')), stack(c, c')). | |
Useful for turning a list of objects into something you can feed to a | |
vmapped function. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.utils.data import Dataset | |
class TimeSeriesDataset(Dataset): | |
def __init__( | |
self, | |
ts: torch.Tensor, | |
x_ts: torch.Tensor, | |
normalize=True, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import OrderedDict | |
import torch | |
from torch import Tensor, Size | |
from torch.nn import Linear | |
class MLP(torch.nn.Sequential): | |
"""Multi-layered perception, i.e. fully-connected neural network | |
Args: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import pickle | |
from pathlib import Path | |
from wandb.errors import CommError | |
import wandb | |
def get_history(user="", project="", query={}, **kwargs): | |
api = wandb.Api() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import namedtuple | |
class Data(namedtuple("Data", ("x", "y"))): | |
def to(self, device, non_blocking=False): | |
x = self.x.to(device, non_blocking=non_blocking) | |
y = self.y.to(device, non_blocking=non_blocking) | |
return Data(x, y) | |
def contiguous(self): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# %% | |
import jax.numpy as jnp | |
import jax | |
import equinox as eqx | |
from typing import Union, Any | |
from abc import ABC, abstractmethod | |
# Type alias: either a plain JAX array or a ParameterizedArray.
# "ParameterizedArray" is a string forward reference; presumably that class
# is defined later in this module — confirm.
MaybeParameterizedArray = Union[jax.Array, "ParameterizedArray"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
from typing import Optional, List | |
def array_to_dataframe(array, axis_names: Optional[List[str]]=None): | |
"""Based on https://stackoverflow.com/questions/35525028/how-to-transform-a-3d-arrays-into-a-dataframe-in-python""" | |
if axis_names is None: | |
axis_names = list(range(array.ndim)) | |
index = pd.MultiIndex.from_product([range(s) for s in array.shape], names=names) | |
df = pd.DataFrame({"array": array.flatten()}, index=index)["array"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
# Mapping from category name to hex colour string, presumably for
# time-in-range (TIR) categories ordered from lowest to highest — confirm.
tir_palette = {
    "very_low": "#A61D2A",
    "low": "#EE1D23",
    "in_range": "#26B257",
    "high": "#FAAB1A",
    "very_high": "#F47D21",
}
def color_bg(bgs): | |
bins = [0, 54, 70, 180, 250, 1000] |
OlderNewer