albanD / mem_tracker.py
Created July 7, 2023 22:50
Tracking the time and stack traces of when Tensors are created, used, and destroyed
import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakTensorKeyDictionary
import time
import warnings
import weakref
import traceback
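
The preview stops at the imports. Below is a minimal sketch, assuming a TorchDispatchMode plus WeakTensorKeyDictionary and weakref finalizers, of how tensor creation and death can be stamped with a timestamp and stack trace; the LifetimeTracker class and its behaviour are illustrative, not the gist's actual implementation.

import time
import traceback
import weakref

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakTensorKeyDictionary


class LifetimeTracker(TorchDispatchMode):
    """Illustrative only: stamp every op output with a creation time and stack,
    and log when the tensor object is finally garbage collected."""

    def __init__(self):
        self.created = WeakTensorKeyDictionary()  # tensor -> (timestamp, stack)

    def _track(self, t):
        if t not in self.created:
            self.created[t] = (time.time(), traceback.format_stack())
            # The lambda must not capture `t`, otherwise it would keep it alive.
            weakref.finalize(t, lambda: print("a tracked tensor died at", time.time()))

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        out = func(*args, **(kwargs or {}))
        tree_map_only(torch.Tensor, self._track, out)
        return out


if __name__ == "__main__":
    with LifetimeTracker():
        a = torch.rand(10)
        b = a * 2
    del b  # the finalizer fires as soon as the last reference goes away
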
albanD / pytreeify.py
Created January 24, 2023 19:34
Make a PyTorch custom Function unpack inputs and outputs using pytree.
import torch
from torch.autograd import Function
import torch.utils._pytree as pytree
# Wraps inputs and outputs (flatten/unflatten) around the real forward/backward that the user defined.
def pytreeify(cls):
    assert issubclass(cls, Function)
    orig_fw = cls.forward
    orig_bw = cls.backward
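
The preview ends right after stashing the original forward/backward. A rough sketch of how the rest of such a decorator could look: flatten the pytree inputs before apply, unflatten the gradients in backward. The argument order, the _out_spec attribute, and the ScaleDict usage example are my assumptions, not the gist's code.

import torch
import torch.utils._pytree as pytree
from torch.autograd import Function


def pytreeify(cls):
    """Sketch: let a custom Function take and return arbitrary pytrees by
    flattening around the user-defined forward/backward."""
    assert issubclass(cls, Function)
    orig_fw, orig_bw, orig_apply = cls.forward, cls.backward, cls.apply

    def new_apply(*inp):
        flat_inp, inp_spec = pytree.tree_flatten(inp)
        out_spec_holder = []
        flat_out = orig_apply(inp_spec, out_spec_holder, *flat_inp)
        return pytree.tree_unflatten(list(flat_out), out_spec_holder[0])

    def new_forward(ctx, inp_spec, out_spec_holder, *flat_inp):
        inp = pytree.tree_unflatten(list(flat_inp), inp_spec)
        out = orig_fw(ctx, *inp)
        flat_out, out_spec = pytree.tree_flatten(out)
        ctx._out_spec = out_spec
        out_spec_holder.append(out_spec)
        return tuple(flat_out)

    def new_backward(ctx, *flat_grad_out):
        grad_out = pytree.tree_unflatten(list(flat_grad_out), ctx._out_spec)
        grad_inp = orig_bw(ctx, *grad_out) if isinstance(grad_out, tuple) else orig_bw(ctx, grad_out)
        flat_grad_inp, _ = pytree.tree_flatten(grad_inp)
        # Two leading Nones: no gradient for the inp_spec / out_spec_holder args.
        return (None, None) + tuple(flat_grad_inp)

    cls.apply, cls.forward, cls.backward = new_apply, new_forward, new_backward
    return cls


# Usage sketch: the Function sees and returns dicts directly.
@pytreeify
class ScaleDict(Function):
    @staticmethod
    def forward(ctx, d):
        return {k: 2 * v for k, v in d.items()}

    @staticmethod
    def backward(ctx, grad_d):
        return {k: 2 * g for k, g in grad_d.items()}
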
# Implements Alban's idea of making the forward traceback that corresponds to
# the execution of the current backward node available as a global.
# Updated version of https://gist.github.com/soulitzer/28140cc4cd7d26828ff7f07b1235d9f5
# to add inter-op tracking.
import torch
from torch import autograd
from torch.utils._python_dispatch import TorchDispatchMode
current_metadata = None
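
Only the globals and imports are visible above. Here is a rough approximation of the idea, not the gist itself: it uses a TorchFunctionMode (which runs above autograd, so outputs already carry their grad metadata) and plain tensor hooks instead of whatever the gist actually wires up; only the name current_metadata comes from the preview.

import traceback

import torch
from torch.overrides import TorchFunctionMode
from torch.utils._pytree import tree_map_only

current_metadata = None  # forward stack trace for the op whose backward is currently running


class ForwardStackRecorder(TorchFunctionMode):
    """Illustrative: capture the Python stack for every forward op and republish
    it through `current_metadata` when that op's output receives its gradient."""

    def __torch_function__(self, func, types, args=(), kwargs=None):
        out = func(*args, **(kwargs or {}))
        stack = traceback.format_stack()

        def attach(t):
            if t.requires_grad:
                def hook(grad):
                    global current_metadata
                    current_metadata = stack
                    return grad
                t.register_hook(hook)

        tree_map_only(torch.Tensor, attach, out)
        return out


if __name__ == "__main__":
    with ForwardStackRecorder():
        x = torch.rand(3, requires_grad=True)
        y = (x * 2).sum()
    y.backward()
    print("last recorded forward frame:\n", current_metadata[-1])
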
albanD / opt_as_hook.py
Last active August 8, 2023 07:49
PyTorch optimizer as hook
import torch
from torch import nn
from torch.optim.sgd import sgd
import gc
import objgraph
import weakref
def all():
    # Only a subset of the args you could have
    def set_sgd_hook(mod, p, lr, weight_decay, momentum):
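
The preview cuts off at set_sgd_hook's signature. A minimal sketch of the "optimizer as hook" idea follows: run the SGD update from inside a gradient hook so the step happens during backward and the gradient is freed immediately. It uses register_post_accumulate_grad_hook (PyTorch >= 2.1) and a hand-rolled update instead of the functional sgd() the gist imports, so the exact mechanics differ from the gist.

import torch
from torch import nn


def set_sgd_hook(p, lr, weight_decay=0.0, momentum=0.0):
    """Illustrative: run the SGD step as soon as p's gradient has been
    accumulated, then free the gradient right away."""
    buf = torch.zeros_like(p) if momentum != 0.0 else None

    def hook(param):
        with torch.no_grad():
            d_p = param.grad
            if weight_decay != 0.0:
                d_p = d_p.add(param, alpha=weight_decay)
            if momentum != 0.0:
                buf.mul_(momentum).add_(d_p)
                d_p = buf
            param.add_(d_p, alpha=-lr)
            param.grad = None  # the grad is consumed immediately, never stored

    return p.register_post_accumulate_grad_hook(hook)


if __name__ == "__main__":
    mod = nn.Linear(4, 2)
    handles = [set_sgd_hook(p, lr=0.1) for p in mod.parameters()]
    loss = mod(torch.rand(8, 4)).pow(2).mean()
    loss.backward()  # parameters are updated during backward
    assert all(p.grad is None for p in mod.parameters())
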
albanD / common_dtype.md
Last active May 18, 2020 19:21
Python function common dtype

Ops to test on the Python side

If nothing is specified, all argument combinations should be considered

CPU and GPU

  • copy_ no_sparse && no_quantize && self!=source && not_copy_transpose
  • gather
  • gather(out=)
  • scatter_(Tensor)
  • scatter(Tensor)
  • scatter_(value)
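
For context, the common-dtype behaviour these ops are tested against is the usual type-promotion machinery. A tiny illustration (mine, not part of the gist) using torch.result_type and torch.promote_types:

import torch

# torch.promote_types gives the common dtype two dtypes promote to;
# torch.result_type does the same for concrete tensors / Python scalars.
print(torch.promote_types(torch.float32, torch.int64))              # torch.float32
print(torch.result_type(torch.ones(3, dtype=torch.float16), 1.0))   # torch.float16
print(torch.result_type(torch.ones(3, dtype=torch.int32),
                        torch.ones(3, dtype=torch.int64)))          # torch.int64
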
from patch_convolution import *
import torch
import torch.nn as nn
import time
# ---------------
# Parameters
# ---------------
# Number of profile iterations to run
itt = 30
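
The preview stops at the parameter block. Below is a bare-bones version of the kind of timing loop such a profiling script runs; since the patch_convolution module and its classes are the gist's own, a stock nn.Conv2d stands in here.

import time

import torch
import torch.nn as nn

itt = 30                     # number of profile iterations to run
x = torch.randn(8, 3, 224, 224)
conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)

# Warm up, then time `itt` forward passes on CPU.
for _ in range(5):
    conv(x)
start = time.time()
for _ in range(itt):
    conv(x)
print(f"avg forward time: {(time.time() - start) / itt * 1e3:.2f} ms")
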
import torch
from torch import nn
from torch.nn import functional as F
class EasyDataParallel(nn.Module):
    def __init__(self, gpus):
        super().__init__()
        # Handle cpu / 1 gpu case better
        assert isinstance(gpus, list)
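
The preview ends inside __init__. Here is a guess at the overall shape of such a wrapper's forward, built on the replicate / scatter / parallel_apply / gather helpers from torch.nn.parallel; note that this sketch also passes the wrapped module to the constructor, which the preview's signature does not show.

import torch
from torch import nn
from torch.nn.parallel import replicate, scatter, parallel_apply, gather


class EasyDataParallel(nn.Module):
    def __init__(self, module, gpus):
        super().__init__()
        assert isinstance(gpus, list)
        self.module = module
        self.gpus = gpus

    def forward(self, x):
        if len(self.gpus) <= 1:
            # CPU or single-device case: nothing to parallelize.
            return self.module(x)
        replicas = replicate(self.module, self.gpus)
        inputs = scatter(x, self.gpus)           # split the batch across devices
        outputs = parallel_apply(replicas, inputs)
        return gather(outputs, self.gpus[0])     # bring results back to the first device
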
albanD / linear_jit_debug.md
Last active October 16, 2019 19:27
Autodiff linear debugging

Debugging code

std::cout << "Forwarding into jit module" << std::endl;
std::cout << "Forward code:" << std::endl;
std::cout << *grad.f.get() << std::endl;
std::cout << "Backward code:" << std::endl;
std::cout << *grad.df.get() << std::endl;
std::cout << "End print !" << std::endl;
albanD / hessian.py
Created September 25, 2019 21:03
Compute full Hessian of a network
import torch
from torch import nn
from torchviz import make_dot
from torch.autograd.gradcheck import gradcheck
torch.set_default_tensor_type(torch.DoubleTensor)
my_mod = nn.Sequential(nn.Linear(2, 2, bias=False), nn.Sigmoid(), nn.Linear(2, 2, bias=False), nn.Sigmoid(), nn.Linear(2, 1, bias=False))
params = list(my_mod.parameters())
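
The preview stops after collecting the parameters. One standard way to finish the job (my sketch, not necessarily how the gist proceeds) is double backward: take the gradients with create_graph=True, then differentiate each gradient entry again to fill the Hessian row by row.

import torch
from torch import nn

torch.set_default_tensor_type(torch.DoubleTensor)  # double precision, as in the gist

my_mod = nn.Sequential(nn.Linear(2, 2, bias=False), nn.Sigmoid(),
                       nn.Linear(2, 2, bias=False), nn.Sigmoid(),
                       nn.Linear(2, 1, bias=False))
params = list(my_mod.parameters())
n = sum(p.numel() for p in params)

x = torch.rand(5, 2)
loss = my_mod(x).sum()

# First-order gradients, kept in the graph so they can be differentiated again.
grads = torch.autograd.grad(loss, params, create_graph=True)
flat_grad = torch.cat([g.reshape(-1) for g in grads])

# Hessian[i, j] = d(flat_grad[i]) / d(param_j); one row per gradient entry.
hessian = torch.zeros(n, n)
for i in range(n):
    rows = torch.autograd.grad(flat_grad[i], params, retain_graph=True, allow_unused=True)
    hessian[i] = torch.cat([torch.zeros(p.numel()) if r is None else r.reshape(-1)
                            for r, p in zip(rows, params)])

print(hessian.shape)  # (n, n); symmetric up to numerical error
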
local threads = require "threads"
threads.Threads.serialization('threads.sharedserialize')
n_task = 3
local pools = {}
for task=1,n_task do
    pools[task] = threads.Threads(5,
        function()
            -- Needed only for serialized elements