Decoder config: https://wandb.ai/rom1504/dalle2_train_decoder/runs/mic5buox/files/decoder_config.json
Get dalle2.
Get the config file from the link above.
Get these 2 .sh scripts.
Run sbatch start_big.sh.
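For reference, here is a hedged sketch of fetching that config programmatically with the wandb Python client; the run path rom1504/dalle2_train_decoder/mic5buox is taken from the URL above, and it assumes you are already logged in to wandb.

import wandb

# Download decoder_config.json from the run linked above (sketch).
api = wandb.Api()
run = api.run("rom1504/dalle2_train_decoder/mic5buox")
run.file("decoder_config.json").download(replace=True)  # saves ./decoder_config.json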
#include <stdio.h>
// Check tensor core's warp register layout
// nvcc -arch=sm_75 tensorcore_mapping.cu -o mapping
// ./mapping

// Define some error checking macros.
#define cudaErrCheck(stat) { cudaErrCheck_((stat), __FILE__, __LINE__); }
void cudaErrCheck_(cudaError_t stat, const char *file, int line) {
    if (stat != cudaSuccess) {
        fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(stat), file, line);
    }
}
Issue title: (working implementation) Fused multi-head attention for arbitrary sequence lengths.
TL;DR: you can run multi-head attention (forward + backward) faster and with no extra memory, for any sequence length and head dim. We'd love to make it available via apex, and we need your advice on how best to do that.
Why should I care? Here's how it compares against standard multi-head attention (blue) for one multi-head attention layer of GPT-J on an RTX 3080Ti.
[Benchmark plots: time with backward (ms) and peak VRAM allocated (MB), fused vs. standard attention]
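For context, this is a hedged sketch of how the baseline side of such a comparison can be measured in plain PyTorch: time forward plus backward and record peak allocated memory for vanilla softmax(QK^T/sqrt(d))V attention. The dimensions assume GPT-J-6B's 16 heads with head dim 256; none of this is the fused kernel itself.

import torch
from timeit import default_timer as timer

def bench_standard_mha(seq_len, n_heads=16, head_dim=256, n_iters=10):
    # Random Q, K, V for one attention layer, fp16 on the GPU.
    q, k, v = (torch.randn(1, n_heads, seq_len, head_dim, device="cuda",
                           dtype=torch.float16, requires_grad=True) for _ in range(3))
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = timer()
    for _ in range(n_iters):
        attn = torch.softmax(q @ k.transpose(-2, -1) / head_dim ** 0.5, dim=-1)
        out = attn @ v
        out.sum().backward()
        q.grad = k.grad = v.grad = None
    torch.cuda.synchronize()
    ms_per_iter = (timer() - start) / n_iters * 1e3
    peak_mb = torch.cuda.max_memory_allocated() / 2 ** 20
    return ms_per_iter, peak_mb

for s in (512, 1024, 2048, 4096):
    print(s, bench_standard_mha(s))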
from typing import Tuple
import torch
import torch.nn.functional as F
import itertools
from timeit import default_timer as timer

class SoftmaxWeightedMean(torch.autograd.Function):
    @staticmethod
import torch
import torch.utils.dlpack
import jax
import jax.dlpack

# A generic mechanism for turning a JAX function into a PyTorch function.

def j2t(x_jax):
    # Zero-copy JAX -> PyTorch conversion via DLPack.
    x_torch = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x_jax))
    return x_torch

def t2j(x_torch):
    # Zero-copy PyTorch -> JAX conversion via DLPack.
    x_jax = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x_torch.contiguous()))
    return x_jax
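As an illustration of the "generic mechanism" mentioned in the comment above, here is a hedged sketch of wrapping a differentiable JAX function as a PyTorch op using j2t, the t2j counterpart, and jax.vjp; the jax2torch name and the single-argument restriction are illustrative, not from the original file.

import jax
import torch

def jax2torch(fn):
    # Wrap a single-argument JAX function as a differentiable PyTorch op (sketch).
    class _Wrapped(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x_torch):
            y_jax, vjp_fn = jax.vjp(fn, t2j(x_torch.detach()))
            ctx.vjp_fn = vjp_fn  # keep the JAX pullback for the backward pass
            return j2t(y_jax)

        @staticmethod
        def backward(ctx, grad_out):
            (grad_x_jax,) = ctx.vjp_fn(t2j(grad_out))
            return j2t(grad_x_jax)

    return _Wrapped.apply

# Usage: a JAX softplus that participates in PyTorch autograd.
softplus_pt = jax2torch(jax.nn.softplus)
x = torch.randn(4, requires_grad=True)
softplus_pt(x).sum().backward()
print(x.grad)  # sigmoid(x), computed through JAX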
"""Complex momentum SGD and Adam. See https://arxiv.org/abs/2102.08431.""" | |
import math | |
import torch | |
from torch import optim | |
class ComplexSGD(optim.Optimizer): | |
def __init__(self, params, lr=1e-2, momentum=0.9, angle=math.pi / 8, weight_decay=0.): |
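To make the idea concrete, here is a hedged, self-contained sketch of what a complex-momentum SGD step can look like, following the cited paper's idea: the momentum coefficient is momentum * exp(i * angle), the buffer is kept complex, and only its real part is applied to the (assumed real-valued fp32) parameters. The class name and the exact update are illustrative, not the original author's code.

import math
import torch
from torch import optim

class ComplexMomentumSGDSketch(optim.Optimizer):
    def __init__(self, params, lr=1e-2, momentum=0.9, angle=math.pi / 8, weight_decay=0.):
        super().__init__(params, dict(lr=lr, momentum=momentum, angle=angle,
                                      weight_decay=weight_decay))

    @torch.no_grad()
    def step(self, closure=None):
        loss = None if closure is None else closure()
        for group in self.param_groups:
            # Complex momentum coefficient beta = momentum * e^{i * angle}.
            beta = group['momentum'] * complex(math.cos(group['angle']), math.sin(group['angle']))
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if group['weight_decay'] != 0:
                    grad = grad + group['weight_decay'] * p
                state = self.state[p]
                if 'momentum_buffer' not in state:
                    state['momentum_buffer'] = torch.zeros_like(p, dtype=torch.complex64)
                buf = state['momentum_buffer']
                buf.mul_(beta).add_(grad)             # complex running average of gradients
                p.add_(buf.real, alpha=-group['lr'])  # apply only the real part
        return loss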
from functools import partial
import torch

def _const(example, val):
    # Scalar tensor with the same dtype as `example`.
    return torch.tensor(val, dtype=example.dtype)

def pad(x, axis, side):
    shape = list(x.size())
    # Normalize axis -1 to a positive index.
    if axis == -1:
        axis = len(shape) - 1