This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.optim import Optimizer | |
import math | |
import torch | |
from torch import Tensor | |
from typing import List, Optional, Callable | |
def adaprox(params: List[Tensor], | |
proxes: List[Callable[[Tensor, float], Tensor]], | |
grads: List[Tensor], | |
exp_avgs: List[Tensor], |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect | |
from IPython.display import Code | |
def get_source_code(obj, display=False): | |
if inspect.isclass(obj): | |
this_class = obj | |
else: | |
# get class from instance | |
this_class = type(obj) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cProfile, os, tempfile | |
from IPython.display import Image | |
def cprof2png(command): | |
# create tmp file to store cprof | |
temp = tempfile.NamedTemporaryFile() | |
# run profiler | |
cProfile.run(command, temp.name) | |
# parse to dot and make png figure from temp | |
os.system(f"gprof2dot -f pstats {temp.name} | dot -Tpng -o {temp.name}") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.optim.sgd import SGD | |
from torch.optim.optimizer import required | |
class PGM(SGD): | |
def __init__(self, params, proxs, lr=required, momentum=0, dampening=0, | |
nesterov=False): | |
kwargs = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=0, nesterov=nesterov) | |
super().__init__(params, **kwargs) | |
if len(proxs) != len(self.param_groups): | |
raise ValueError("Invalid length of argument proxs: {} instead of {}".format(len(proxs), len(self.param_groups))) |