Brief and Messy Script for Comparing Keras Soft-DTW against Torch pysdtw
.
├── Common.py
├── pysdtw
│   ├── ... Rest of checkout here
│   ├── Test Arrays -> /Some/Common/Path
│   └── test.py
├── pytorch-softdtw
│   ├── ... Rest of checkout here
│   ├── Test Arrays -> /Some/Common/Path
│   └── test.py  # Very similar to the pysdtw test script, but not identical; the two implementations have already been tested as identical, so this one is not strictly needed
└── Soft-DTW
    ├── Test Arrays -> /Some/Common/Path
    ├── TestImplementation.py
    └── ...

The Test Arrays directories are all symlinked to the same folder.
pytorch-softdtw probably doesn't need to be tested, as pysdtw is a rework of it, but it could be useful as a cross-check.
Common.py sits above the three checkouts and provides the `load` and `save3DCSV` helpers imported by the test scripts; a sketch of it is given below.
_Testing is easiest if you comment out the numba.jit decorator on softdtw, to run plain Python code on the CPU that can be inspected easily in a debugger._
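
Common.py itself is not included in this gist, so the sketch below is only an illustration of what its two helpers could look like: `load` as a thin wrapper around pickle, and `save3DCSV` dumping a 3D array to a single CSV file. The function names come from the test script; everything else here is an assumption, not the original file.

```python
# Hypothetical sketch of Common.py -- not the original file, just a guess at the
# minimal helpers the test scripts rely on.
import pickle
import numpy as np

def load(path):
    # Deserialise a previously pickled object, e.g. the cached y_true/y_pred
    # arrays or the {gamma: value} reference dictionaries.
    with open(path, "rb") as handle:
        return pickle.load(handle)

def save3DCSV(path, array):
    # Write a 3D array (e.g. one alignment matrix per sequence) to a single CSV
    # file, one 2D slice at a time, each preceded by a comment line.
    array = np.asarray(array)
    with open(path, "w") as handle:
        for index, matrix in enumerate(array):
            handle.write(f"# slice {index}\n")
            np.savetxt(handle, matrix, delimiter=",")
            handle.write("\n")
```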
import pysdtw
import os, sys, pickle, torch
import numpy as np
import torch.nn as nn

# Get the path of the current script
ScriptDirectory = os.path.dirname(os.path.realpath(__file__))
CacheDirectory  = f"{ ScriptDirectory }/Test Arrays"

# For the Common file in the directory above the pysdtw repo.
sys.path.insert(0, f"{ScriptDirectory}/..")
from Common import *

# Used for debugging how exactly the gradient tape works.
class L1LossFunction(torch.nn.Module): #torch.autograd.Function):

    def __init__(self):
        super().__init__()

    #@staticmethod
    #def forward(self, ctx, y_true, y_pred):
    #    ctx.save_for_backward(y_true, y_pred)
    #    return y_true - y_pred

    @staticmethod
    #def forward(ctx, input, target):
    def forward(input, target):
        #ctx.save_for_backward(input, target)
        return input - target
UseCUDA = False
Device  = torch.device('cuda' if UseCUDA else 'cpu')

# the input data includes a batch dimension
#X = torch.rand((4, 8, 2), device=device, requires_grad=True)
#Y = torch.rand((4, 8, 2), device=device)

y_true = load(f"{ CacheDirectory }/y_true.pkl").astype('float32')
y_pred = load(f"{ CacheDirectory }/y_pred.pkl").astype('float32')

# These are dictionaries of {Gamma => scalar/array}
referenceLosses    = load(f"{ CacheDirectory }/TestingLosses.pkl"   )
referenceGradients = load(f"{ CacheDirectory }/TestingGradients.pkl")

y_true = torch.from_numpy(y_true).to(Device)

# We only need gradients for y_pred; y_true is a constant.
y_pred = torch.from_numpy(y_pred).to(Device)
y_pred.requires_grad = True

# optionally choose a pairwise distance function
#distanceMetric = pysdtw.distance.pairwise_l2_squared
distanceMetric = pysdtw.distance.pairwise_l2_squared_exact

g = {}
GammaValues = [1.0, 0.1, 0.01]
for gamma in GammaValues:

    # create the SoftDTW distance function
    sdtw = pysdtw.SoftDTW(gamma = gamma, dist_func = distanceMetric, use_cuda = UseCUDA)
    #sdtw = nn.MSELoss()
    #sdtw = L1LossFunction()

    # This is still "attached" to the graph, so apparently carries around all the intermediate calculations etc.
    # We need the attached version for the gradients.
    lossesPerSequenceGraph = sdtw(y_true, y_pred)

    # Have to call sum before running the backwards pass.
    lossesPerSequenceGraph.sum().backward()

    # We only want the gradient on y_pred; y_true is constant and thus doesn't affect them.
    torchGradients = y_pred.grad

    # Be very careful here about the order in which things are detached!
    # The copy seems necessary - otherwise the array ends up referencing junk on the heap.
    torchGradients    = np.copy(torchGradients.detach().cpu().numpy())
    # .cpu() is a no-op with UseCUDA = False, but keeps this correct if CUDA is enabled.
    lossesPerSequence = lossesPerSequenceGraph.detach().cpu().numpy()
    summedLoss        = lossesPerSequence.sum()

    # Check against the version previously calculated _or_ created through Keras.
    assert np.allclose(summedLoss, referenceLosses[gamma]), "Tests failed - losses differ"

    g[gamma] = torchGradients

    # We MUST zero out the gradients before the next iteration.
    # Unlike TensorFlow, Torch does not handle this through a context manager.
    y_pred.grad.zero_()

    # We need a slightly larger tolerance here.
    assert np.allclose(torchGradients, referenceGradients[gamma], rtol = 1e-4), "Tests failed - gradients differ"

    _break = "hold" # set breakpoint here so that results don't get swallowed up
# Note: alignmentMatrices is not part of stock pysdtw; this assumes a locally
# modified checkout that records the alignment matrices for each gamma.
allAlignmentMatrices = pysdtw.alignmentMatrices

for gamma in GammaValues:
    matrices = allAlignmentMatrices[np.float32(gamma)]
    matrices = np.array(matrices)
    save3DCSV(f"{ CacheDirectory }/tonison - gamma is {gamma}.csv", matrices)

_break = "hold" # set breakpoint here so that results don't get swallowed up
#input()