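Calling a TorchScript model from C++ with keyword arguments: `Method::operator()` takes a vector of positional inputs plus an `unordered_map` from argument name to `IValue`, and the method's schema lists everything `forward` accepts.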
```cpp
#include <torch/script.h>

torch::jit::IValue myIvalue;
torch::jit::IValue myIvalue2;
torch::jit::script::Module module;  // assumed to be loaded via torch::jit::load

// Keyword arguments are passed as a map from argument name to IValue.
std::unordered_map<std::string, torch::jit::IValue> umap = {{"x", myIvalue}, {"opt", myIvalue2}};
auto result = module.get_method("forward")({}, umap);

// Print the schema of forward, which shows all potential arguments to the model.
std::cout << module.get_method("forward").function().getSchema() << std::endl;
```
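The same schema is visible from Python; a minimal sketch, assuming `model.pt` is a hypothetical saved TorchScript module:

```python
import torch

module = torch.jit.load("model.pt")  # hypothetical saved module
# ScriptMethod exposes its schema, including argument names and types.
print(module.forward.schema)
```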
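Down-sampling a pandas DataFrame into a balanced dataset: draw the same number of rows from every category, then concatenate the per-category samples.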
```python
import pandas as pd

df = ...  # assumes you already have a DataFrame
category_column = 'main_category'
number = 1000  # rows to keep per category

# Sample `number` rows from each category and stitch the pieces back together.
dfz = []
for cat in df[category_column].unique():
    dfz.append(df[df[category_column] == cat].sample(n=number))
final_df = pd.concat(dfz)
```
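With pandas >= 1.1 the same thing can be written as a one-liner via `DataFrameGroupBy.sample`; note that both versions raise if a category has fewer than `number` rows, in which case `replace=True` lets you oversample small categories:

```python
# Equivalent one-liner: sample `number` rows per group directly.
final_df = df.groupby(category_column).sample(n=number)
```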
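A binary search compiled with Cython's cell magic: declaring the index variables as `cdef int` keeps the loop arithmetic in C.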
```python
# uncomment to run in a notebook (in a cell before the %%cython cell)
# %load_ext Cython
%%cython
cdef int round(int n, int d):
    # integer division rounded to nearest: (n + d/2) // d
    return (n + d // 2) // d

def search(it: int, target) -> bool:
    cdef int first = 0
    cdef int last = len(target) - 1  # inclusive upper bound
    cdef int middle
    while first <= last:
        middle = round(first + last, 2)
        if target[middle] == it:
            return True
        elif target[middle] < it:
            first = middle + 1
        else:
            last = middle - 1
    return False
```
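Once the cell is compiled, `search(7, [1, 3, 7, 9])` returns `True` and `search(4, [1, 3, 7, 9])` returns `False`; the sequence must already be sorted.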
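An n-hot encoder for integer label tensors that use -1 as a masking value: shift the labels up by one, scatter into `num_classes + 1` slots, then slice the mask class back off so masked entries encode to all zeros.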
```python
import torch

def n_hot(y, num_classes, scatter_dim):
    # we assume the masking value is always -1
    # add an extra class and shift y so the mask maps to class 0
    nc = num_classes + 1
    y = y + 1  # avoid mutating the caller's tensor
    y_tensor = y.view(*y.size()[:scatter_dim], -1)
    zeros = torch.zeros(*y.size()[:scatter_dim], nc, dtype=y.dtype, device=y.device)
    res = zeros.scatter(scatter_dim, y_tensor, 1)
    # drop the mask class (index 0) again
    return res.index_select(scatter_dim, torch.arange(1, nc, device=y.device))
```
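A quick shape check, a sketch assuming 1-d labels (so `scatter_dim=1`):

```python
y = torch.tensor([0, 2, -1])  # -1 marks a masked position
print(n_hot(y, num_classes=3, scatter_dim=1))
# tensor([[1, 0, 0],
#         [0, 0, 1],
#         [0, 0, 0]])  <- the masked entry becomes an all-zero row
```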
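A custom `torch.utils.data.Sampler` that yields batches of indices over a 1-d data sequence.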
```python
import numpy as np
from torch.utils.data import Sampler

class BatchSampler(Sampler):
    def __init__(self, num_samples, batch_size, shuffle=True):
        '''
        Samples a 1d sequence as batches of indices.
        :param num_samples: total number of datapoints (1d data sequence) to be sampled from.
        :param batch_size: number of indices per batch.
        :param shuffle: if True, shuffle the order in which batches are yielded.
        '''
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.shuffle = shuffle
```
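The iteration logic is not shown above; one plausible sketch, assuming consecutive indices within a batch and an optionally shuffled batch order, would be:

```python
# Hypothetical completion, not the original gist: two methods for BatchSampler.
def __iter__(self):
    starts = np.arange(0, self.num_samples, self.batch_size)
    if self.shuffle:
        np.random.shuffle(starts)  # shuffle batch order, keep batches contiguous
    for s in starts:
        yield list(range(s, min(s + self.batch_size, self.num_samples)))

def __len__(self):
    return (self.num_samples + self.batch_size - 1) // self.batch_size
```

A sampler like this is then passed to `DataLoader(dataset, batch_sampler=BatchSampler(...))`.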
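A reminder of Python's calling semantics: unpacking a dict with `f(**d)` binds its values to the function's parameter names, so rebinding those names inside the function leaves the dict untouched.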
```python
def f(x, y, z):
    x = "hello"              # rebinds the local name only
    y, z = "swag", "master"

d = {'x': "bleh", 'y': "bleh", 'z': "bleh"}
f(**d)
print(d)
# {'x': 'bleh', 'y': 'bleh', 'z': 'bleh'} -- unchanged
```
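The contrast case: mutating the object a parameter refers to (rather than rebinding the name) is visible to the caller, because both names point at the same object:

```python
def g(x):
    x.append("hello")  # mutates the shared list in place

lst = []
g(lst)
print(lst)  # ['hello']
```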
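The learning-rate range test: sweep the learning rate exponentially from `init_value` to `final_value` over one epoch of `trn_loader` while tracking a smoothed loss, then pick a rate just below where the loss starts to blow up.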
```python
import matplotlib.pyplot as plt
%matplotlib inline

def find_lr(net, criterion, optimizer, trn_loader, init_value=1e-8, final_value=10., beta=0.98):
    num = len(trn_loader) - 1
    # multiplicative step that takes lr from init_value to final_value in num steps
    mult = (final_value / init_value) ** (1 / num)
    lr = init_value
    optimizer.param_groups[0]['lr'] = lr
    avg_loss = 0.
    best_loss = 0.
```
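The body of the loop is not shown; a sketch of how such a range test typically continues, based on the widely shared fast.ai-style version and assuming `trn_loader` yields `(inputs, labels)` pairs (so not necessarily this gist's exact code):

```python
import math

losses, log_lrs = [], []
batch_num = 0
for inputs, labels in trn_loader:
    batch_num += 1
    optimizer.zero_grad()
    loss = criterion(net(inputs), labels)
    # exponentially smoothed, bias-corrected loss
    avg_loss = beta * avg_loss + (1 - beta) * loss.item()
    smoothed_loss = avg_loss / (1 - beta ** batch_num)
    if batch_num > 1 and smoothed_loss > 4 * best_loss:
        break  # loss is diverging; stop the sweep
    if smoothed_loss < best_loss or batch_num == 1:
        best_loss = smoothed_loss
    losses.append(smoothed_loss)
    log_lrs.append(math.log10(lr))
    loss.backward()
    optimizer.step()
    lr *= mult
    optimizer.param_groups[0]['lr'] = lr

plt.plot(log_lrs, losses)  # the usable range sits before the loss explodes
```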
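A 1cycle-style scheduler: the learning rate climbs and then descends over `nb` iterations, with the final `pct` percent reserved for annealing and momentum optionally cycled as well.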
```python
# Circular LR as implemented in fast.ai, but without depending on its internals
class CircularLR:
    def __init__(self, optimizer, nb, div=10, pct=10, momentums=None):
        self.nb, self.div, self.pct = nb, div, pct
        # half the iterations that remain after reserving the last pct% for annealing
        self.cycle_nb = int(nb * (1 - pct / 100) / 2)
        self.opt = optimizer
        self.init_lr = self.opt.param_groups[0]['lr']
        if momentums is not None:
            self.moms = momentums
```
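A helper for negative-sampling-style loops: for each row of a score matrix, find the first column where the score is positive, offset by 1 so that 0 can later mark rows that never went positive.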
```python
def num_tries_gt_zero(scores, batch_size, max_trials, max_num, device):
    '''
    returns: [1 x batch_size] the lowest index per row at which scores
    first exceeded 0, plus 1
    '''
    tmp = scores.gt(0).nonzero().t()  # row/column indices of positive scores
    # We offset these values by 1 so unset entries (zeros) can be spotted later
    values = tmp[1] + 1
    # TODO: just allocate a normal zero tensor and fill it?
    # Sparse tensors can't be moved with .to() or .cuda(), so when the inputs
    # are CUDA tensors the sparse tensor has to be built on the GPU directly
    if device.type == 'cuda':
```
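A one-hot encoder feeding a custom autograd `Function` (presumably a log-sum-exp pairwise, LSEP, ranking loss); note that it is hard-wired to CUDA tensor types.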
```python
import torch
from torch.autograd import Function

def _to_one_hot(y, n_dims, dtype=torch.cuda.FloatTensor):
    # append a trailing dim so scatter can write one index per element
    scatter_dim = len(y.size())
    y_tensor = y.type(torch.cuda.LongTensor).view(*y.size(), -1)
    zeros = torch.zeros(*y.size(), n_dims).type(dtype)
    return zeros.scatter(scatter_dim, y_tensor, 1)

class LSEP2(Function):
```