This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys, functools | |
def private(member): | |
@functools.wraps(member) | |
def wrapper(*function_args): | |
myself = member.__name__ | |
caller = sys._getframe(1).f_code.co_name | |
if (not caller in dir(function_args[0]) and not caller is myself): | |
raise PermissionError(f"Cannot call private function {caller}.{myself}") | |
return member(*function_args) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
def _first_element(value):
    """Return ``value[0]`` for an indexable argument, otherwise ``value`` itself.

    Lets the conv-shape helpers accept either a plain scalar or a sequence
    (tuple, list, torch.Size), such as the ``(k,)`` tuples stored on
    ``torch.nn.Conv1d`` modules.
    """
    try:
        return value[0]
    except TypeError:  # scalars (int/float) are not subscriptable
        return value


def conv1d_out_shape(input_size, kernel_size, stride=1, padding=0, dilation=1):
    """Compute the output length of a 1-D convolution.

    Applies
        L_out = floor((L_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1

    Args:
        input_size: length of the input signal (scalar).
        kernel_size: scalar, or a sequence whose first element is used.
        stride: scalar, or a sequence whose first element is used.
        padding: scalar, or a sequence whose first element is used.
        dilation: scalar, or a sequence whose first element is used.

    Returns:
        int: the output length. It is <= 0 when the dilated kernel does not fit
        in the padded input.
    """
    # Normalise to scalars so the int defaults (stride=1, padding=0,
    # dilation=1) work. Indexing them directly raised TypeError.
    kernel_size = _first_element(kernel_size)
    stride = _first_element(stride)
    padding = _first_element(padding)
    dilation = _first_element(dilation)
    output_size = ((input_size + (2 * padding) - dilation * (kernel_size - 1) - 1) / stride) + 1
    return math.floor(output_size)
def transpose_conv1d_out_shape(input_size, kernel_size, stride=1, padding=0, dilation=1): | |
output_size = ((input_size - 1) * stride) - (2 * padding) + (dilation * (kernel_size - 1)) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.utils.data import Subset | |
from sklearn.model_selection import train_test_split | |
class SubsetProxy(Subset): | |
def __init__(self, dataset, indices): | |
super().__init__(dataset, indices) | |
def __getattr__(self, name): | |
try: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def profiled(func): | |
def decorated(*args, **kwargs): | |
pr = cProfile.Profile() | |
pr.enable() | |
result = func(*args, **kwargs) | |
pr.disable() | |
pr.print_stats(sort="tottime") | |
return result |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import hashlib | |
funcs = [ | |
hashlib.blake2b, | |
hashlib.blake2s, | |
hashlib.md5, | |
hashlib.sha1, | |
hashlib.sha224, | |
hashlib.sha256, | |
hashlib.sha256, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
def is_supported(): | |
CUDA_VERSION = torch._C._cuda_getCompiledVersion() | |
supported = True | |
for d in range(torch.cuda.device_count()): | |
capability = torch.cuda.get_device_capability(d) | |
major = capability[0] | |
minor = capability[1] | |
supported &= major > 3 # too old |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
An implementation of the PyTorch Subset that returns an instance of the original dataset with a reduced number of items. | |
This has two benefits: | |
- It still allows access to the attributes of the Dataset class, such as methods or properties. | |
- You can use the usual Python index notation with slices to chunk the dataset, rather than creating a list of indices. | |
""" | |
class Dataset(object): | |
def __init__(self, iterable): | |
self.items = iterable |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This gist is an extract from: | |
https://github.com/epignatelli/reinforcement-learning-an-introduction | |
""" | |
import matplotlib | |
matplotlib.use("Agg") | |
import numpy as np | |
import matplotlib.pyplot as plt |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This gist is an extract from: | |
https://github.com/epignatelli/reinforcement-learning-an-introduction | |
""" | |
def bellman_expectation(self, state, probs, discount): | |
""" | |
Makes a one step lookahead and applies the bellman expectation equation to the state self.state_value[state] | |
Args: | |
state (Tuple[int, int]): the x, y indices that define the address on the value table |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This gist is an extract from: | |
https://github.com/epignatelli/reinforcement-learning-an-introduction | |
""" | |
def policy_evaluation(env, policy=None, steps=1, discount=1., in_place=False): | |
""" | |
Args: | |
policy (numpy.array): a numpy 3-D numpy array, where the first two dimensions identify a state and the third dimension identifies the actions. | |
The array stores the probability of taking each action. |
OlderNewer