```python
from torch.utils.data.sampler import Sampler
import itertools

class SequentialRandomSampler(Sampler):
    """Samples elements sequentially, starting from a random location.

    For when you want to sequentially sample a random subset.

    Usage:
        loader = torch.utils.data.DataLoader(
            dataset, sampler=SequentialRandomSampler(dataset))
    """
```
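The gist is truncated here, but the core index logic — pick a random start, then walk the dataset sequentially, wrapping past the end — can be sketched without torch. The function name and the wrap-around-with-`itertools.cycle` choice are my assumptions, not the gist's:

```python
import itertools
import random

def sequential_random_indices(n, num_samples):
    # Hypothetical sketch: start at a random offset, then yield indices
    # sequentially, wrapping past the end of the dataset with itertools.cycle.
    start = random.randrange(n)
    return list(itertools.islice(itertools.cycle(range(n)), start, start + num_samples))
```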
```python
def isfinite(x):
    """
    Quick pytorch test that a tensor contains no NaNs or Infs.

    Note: torch now has torch.isnan (and torch.isfinite), which should be preferred.
    url: https://gist.github.com/wassname/df8bc03e60f81ff081e1895aabe1f519
    """
    not_inf = ((x + 1) != x)  # inf + 1 == inf
    not_nan = (x == x)        # NaN compares unequal to itself
    return not_inf & not_nan
```
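The same two tricks work on plain Python floats, which makes the behaviour easy to check. One caveat worth knowing (the sketch and the caveat are mine): the `(x + 1) != x` test misfires for large finite floats, where adding 1 is lost to rounding.

```python
import math

def isfinite_scalar(x):
    not_inf = (x + 1) != x   # inf + 1 == inf
    not_nan = (x == x)       # NaN is the only value unequal to itself
    return not_inf and not_nan

# Caveat: for |x| >= 2**53 in float64, x + 1 rounds back to x,
# so a large but finite value is wrongly reported as non-finite.
```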
```python
# Code for a question on reddit:
# https://www.reddit.com/r/MachineLearning/comments/8poc3z/r_blog_post_on_world_models_for_sonic/e0cwb5v/

# from this
def forward(self, x):
    self.lstm.flatten_parameters()
    x = F.relu(self.fc1(x))
    z, self.hidden = self.lstm(x, self.hidden)
    sequence = x.size()[1]
```
| """ | |
| Pytorch sampler that samples ordered indices from unordered sequences. | |
| Good for use with dask and RNN's, because | |
| 1. Dask will slow down if sampling between chunks, so we must do one chunk at a time | |
| 2. RNN's need sequences so we must have seqences e.g. 1,2,3 | |
| 3. But RNN's train better with batches that are uncorrelated so we want each batch to be sequence from a different part of a chunk. | |
| For example, given each chunk is `range(12)`. Our seq_len is 3. We might end up with these indices: | |
| - [[1,2,3],[9,10,11],[4,5,6]] |
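A minimal sketch of that index pattern in pure Python (names and the run-alignment are mine): cut each chunk into consecutive runs of `seq_len`, then shuffle the runs so neighbouring batches come from different parts of the chunk.

```python
import random

def chunked_sequence_indices(chunk_size, n_chunks, seq_len):
    batches = []
    for c in range(n_chunks):
        lo = c * chunk_size
        # consecutive runs of seq_len, kept inside this chunk only
        starts = list(range(lo, lo + chunk_size - seq_len + 1, seq_len))
        random.shuffle(starts)  # decorrelate batches within the chunk
        batches.extend([list(range(s, s + seq_len)) for s in starts])
    return batches
```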
```python
import torch.utils.data

class NumpyDataset(torch.utils.data.Dataset):
    """Dataset wrapping arrays.

    Each sample will be retrieved by indexing arrays along the first dimension.

    Arguments:
        *arrays (numpy.array): arrays that have the same size in the first dimension.
    """
    # The gist is truncated here; completion below follows torch's TensorDataset pattern.
    def __init__(self, *arrays):
        self.arrays = arrays
    def __getitem__(self, index):
        return tuple(a[index] for a in self.arrays)
    def __len__(self):
        return len(self.arrays[0])
```
```python
from dask.callbacks import Callback
from tqdm.auto import tqdm

class TQDMDaskProgressBar(Callback, object):
    """
    A tqdm progress bar for dask.

    Usage:
        with TQDMDaskProgressBar():
            result = collection.compute()  # any dask computation
    """
```
| """ | |
| In jupyter notebook simple logging to console | |
| """ | |
| import logging | |
| import sys | |
| logging.basicConfig(stream=sys.stdout, level=logging.INFO) | |
| # Test | |
| logger = logging.getLogger('LOGGER_NAME') |
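One gotcha with this snippet: `logging.basicConfig` silently does nothing if the root logger already has handlers, which is common in notebooks. Since Python 3.8 you can pass `force=True` to replace any existing handlers. A sketch using a `StringIO` stream (my stand-in for `sys.stdout`) so the effect is observable:

```python
import io
import logging

stream = io.StringIO()  # stand-in for sys.stdout
# force=True (Python 3.8+) removes existing root handlers before configuring
logging.basicConfig(stream=stream, level=logging.INFO, force=True)
logging.getLogger("LOGGER_NAME").info("hello")
```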
```python
import torch

class AdamStepLR(torch.optim.Adam):
    """Combine Adam and lr_scheduler.StepLR so we can use it as a normal optimiser."""
    def __init__(self, params, lr=0.001, betas=(0.9, 0.999), eps=1e-08,
                 weight_decay=0, step_size=50000, gamma=0.5):
        super().__init__(params, lr, betas, eps, weight_decay)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self, step_size, gamma)

    def step(self):
        loss = super().step()
        # since torch 1.1, optimizer.step() must run before scheduler.step()
        self.scheduler.step()
        return loss
```
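The schedule StepLR wires in here, shown standalone so the defaults are concrete (class name and shape are mine): the value is multiplied by `gamma` once every `step_size` calls to `step()`.

```python
class StepDecay:
    """value = base * gamma ** (num_steps // step_size) -- StepLR's decay rule."""
    def __init__(self, base, step_size=50000, gamma=0.5):
        self.base, self.step_size, self.gamma = base, step_size, gamma
        self.num_steps = 0
    def step(self):
        self.num_steps += 1
        return self.base * self.gamma ** (self.num_steps // self.step_size)
```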
```python
def window_stack(x, window=4, pad=True):
    """
    Stack along a moving window of a pytorch timeseries.

    Inputs:
        x: tensor of dims (batches/time, channels)
        pad: if True, the left side is padded so the output length matches the input
    Outputs:
        if pad=True: a tensor of size (batches, channels, window)
        else: a tensor of size (batches - window, channels, window)
    """
```
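The gist is truncated before the implementation, but the windowing itself (ignoring the torch transpose to `(batches, channels, window)`) can be sketched on a plain list. The left-padding scheme — repeating the first element — is my assumption:

```python
def window_stack_list(x, window=4, pad=True):
    if pad:
        # left-pad by repeating the first element so output length == input length
        x = [x[0]] * (window - 1) + list(x)
    # one length-`window` slice ending at each valid position
    return [x[i:i + window] for i in range(len(x) - window + 1)]
```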