A ZSH theme optimized for people who use:
- Solarized
- Git
- Unicode-compatible fonts and terminals (I use iTerm2 + Menlo)
For Mac users, I highly recommend iTerm2 with the Solarized Dark color preset.
import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Args:
        logits: logits distribution shape (vocabulary size)
        top_k > 0: keep only the top k tokens with highest probability (top-k filtering).
        top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
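    # --- The snippet above ends at the safety check. The lines below are a hedged
    # sketch of the standard top-k and nucleus masking steps this docstring describes,
    # written from the usual PyTorch formulation; they are not taken verbatim from the
    # original source.
    if top_k > 0:
        # Remove all tokens whose logit is below the k-th largest logit
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift indices to the right so the first token above the threshold is kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits

# Illustrative usage (values and the model call are assumptions, not from the original):
# logits = model(input_ids)[0][0, -1, :]          # last-step logits for a batch of 1
# filtered = top_k_top_p_filtering(logits, top_k=50, top_p=0.9)
# next_token = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)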
import sys
from collections import OrderedDict

PY2 = sys.version_info[0] == 2

# Attribute names that torch.nn.Module manages internally via __setattr__
# rather than treating as ordinary instance attributes.
_internal_attrs = {'_backend', '_parameters', '_buffers', '_backward_hooks',
                   '_forward_hooks', '_forward_pre_hooks', '_modules'}


class Scope(object):
    def __init__(self):
        self._modules = OrderedDict()
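A minimal sketch of how a Scope instance can be used; this usage is an assumption for illustration, not taken from the original source. It only shows that Scope exposes the same `_modules` registry that nn.Module keeps for its child modules:

import torch.nn as nn

scope = Scope()
scope._modules['embed'] = nn.Embedding(100, 16)  # register child modules by name
scope._modules['proj'] = nn.Linear(16, 100)
for name, module in scope._modules.items():
    print(name, type(module).__name__)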
import torch
from torch.autograd import Variable  # Variable is deprecated since PyTorch 0.4; plain tensors now suffice


def train_fn(model, optimizer, criterion, batch):
    x, y, lengths = batch
    x = Variable(x.cuda())
    y = Variable(y.cuda(), requires_grad=False)
    # Mask shaped like x (seq_len, batch, features): initialized to 1 everywhere,
    # then the first `l` timesteps of each sequence k are zeroed below.
    mask = Variable(torch.ByteTensor(x.size()).fill_(1).cuda(),
                    requires_grad=False)
    for k, l in enumerate(lengths):
        mask[:l, k, :] = 0
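    # --- The original snippet stops after building the mask. The lines below are a
    # hedged, illustrative sketch of how a training step like this typically continues
    # (forward pass, loss, backward pass, optimizer update). The model/criterion call
    # signatures and any use of `mask` inside the loss are assumptions, not taken from
    # the original source.
    optimizer.zero_grad()
    output = model(x)            # assumed call signature
    loss = criterion(output, y)  # the mask would typically be applied here or inside the criterion
    loss.backward()
    optimizer.step()
    return loss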