Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import pickle | |
import math | |
print ("git revision:", torch.__version__) | |
with open('/tmp/1.0.0a0-058c128_timings.pkl', "rb") as f: | |
reference_timings_dict = pickle.load(f) | |
print (""" | |
input shape = (bs, channels) + features | |
mode = (training/eval)-(forward+backward)/forward |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
csrc = """ | |
#include <torch/extension.h> | |
#include <THC/THCDeviceUtils.cuh> | |
#include <THC/THCGeneral.h> | |
#include "ATen/ATen.h" | |
#include "ATen/AccumulateType.h" | |
#include "ATen/cuda/CUDAContext.h" | |
using namespace at; | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload | |
import builtins | |
import math | |
import pickle | |
class dtype: ... | |
_dtype = dtype |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from PIL import Image | |
from torch.utils.data import DataLoader | |
import torchvision | |
from torchvision import transforms, datasets | |
from torch.autograd import Variable | |
import torch.nn as nn | |
import torch.optim | |
import torch.backends.cudnn as cudnn; cudnn.benchmark = True |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
indices = torch.LongTensor([[1,1,1], | |
[2,1,1]]) # must be two dimensional with one row per dimension | |
values = torch.arange(1,4) | |
size = torch.Size((3,3)) | |
a = torch.sparse.FloatTensor(indices, values, size) | |
b = torch.eye(3) | |
b += a | |
print (b) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.autograd import Variable | |
def linear_with_sumsq(inp, weight, bias=None): | |
def provide_sumsq(inp,w,b): | |
def _h(i): | |
if not hasattr(w, 'grad_sumsq'): | |
w.grad_sumsq = 0 | |
w.grad_sumsq += ((i**2).t().matmul(inp**2))*i.size(0) | |
if b is not None: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from timeit import default_timer as time | |
import numpy as np | |
from numba import cuda | |
import os | |
os.environ['NUMBAPRO_LIBDEVICE']='/usr/lib/nvidia-cuda-toolkit/libdevice/' | |
os.environ['NUMBAPRO_NVVM']='/usr/lib/x86_64-linux-gnu/libnvvm.so.3.1.0' | |
import numpy | |
import torch | |
import ctypes |
NewerOlder