@stellarpower
Created April 25, 2024 13:50
Brief and Messy Script for Comparing Keras Soft-DTW against Torch pysdtw
.
├── Common.py
├── pysdtw
│   ├── ... Rest of checkout here
│   ├── Test Arrays -> /Some/Common/Path
│   └── test.py
├── pytorch-softdtw
│   ├── ... Rest of checkout here
│   ├── Test Arrays -> /Some/Common/Path
│   └── test.py # Test script is very similar but not identical; the two implementations have already been checked to agree, so this one is not strictly needed
└── Soft-DTW
    ├── Test Arrays -> /Some/Common/Path
    ├── TestImplementation.py
    └── ...
The Test Arrays directories are all symlinked to the same folder.
pytorch-softdtw probably doesn't need to be tested, as pysdtw is a rework of it, but it could be useful.
_Testing is easiest if you comment out the numba.jit decorator on softdtw, so that plain Python code runs on the CPU and can be inspected easily in a debugger._
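For illustration only (the exact function and decorator names differ between the checkouts), disabling the JIT is just a matter of commenting out the decorator line above the CPU kernel:

    # @numba.jit(nopython=True)   # <- commented out: the kernel below now runs as
    #                             #    plain, debuggable Python rather than compiled code
    def compute_softdtw(D, gamma):
        ...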
import pysdtw
import os, sys, pickle, torch
import numpy as np
import torch.nn as nn
# Get the path of the current script
ScriptDirectory = os.path.dirname(os.path.realpath(__file__))
CacheDirectory = f"{ ScriptDirectory }/Test Arrays"
# Common.py lives in the directory above the pysdtw checkout, so add that to the import path.
sys.path.insert(0, f"{ScriptDirectory}/..")
from Common import *
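# Common.py is assumed to provide `load` (unpickle a file from disk) and
# `save3DCSV` (dump a 3D array to CSV). `load` is presumably something along the lines of:
#     def load(path):
#         with open(path, "rb") as file:
#             return pickle.load(file)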
# Used for debugging how exactly the gradient tape works.
class L1LossFunction(torch.nn.Module): #torch.autograd.Function):
    def __init__(self):
        super().__init__()

    #@staticmethod
    #def forward(self, ctx, y_true, y_pred):
    #    ctx.save_for_backward(y_true, y_pred)
    #    return y_true - y_pred

    @staticmethod
    #def forward(ctx, input, target):
    def forward(input, target):
        #ctx.save_for_backward(input, target)
        return input - target
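
# For reference only (not used by the test below): the commented-out lines above
# were aiming at the standard torch.autograd.Function pattern, which looks like this.
class L1Function(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, target):
        ctx.save_for_backward(input, target)
        return input - target

    @staticmethod
    def backward(ctx, grad_output):
        # The saved tensors are not needed for this simple difference, but
        # ctx.saved_tensors is where they would be retrieved.
        # d(input - target)/d(input) = 1; d(input - target)/d(target) = -1
        return grad_output, -grad_output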
UseCUDA = False
Device = torch.device('cuda' if UseCUDA else 'cpu')
# the input data includes a batch dimension
#X = torch.rand((4, 8, 2), device=device, requires_grad=True)
#Y = torch.rand((4, 8, 2), device=device)
y_true = load(f"{ CacheDirectory }/y_true.pkl" ).astype('float32')
y_pred = load(f"{ CacheDirectory }/y_pred.pkl" ).astype('float32')
# These are dictionaries of {Gamma => scalar/array}
referenceLosses = load(f"{ CacheDirectory }/TestingLosses.pkl" )
referenceGradients = load(f"{ CacheDirectory }/TestingGradients.pkl")
y_true = torch.from_numpy(y_true).to(Device)
# We only need gradients for y_pred; y_true is a constant.
y_pred = torch.from_numpy(y_pred).to(Device)
y_pred.requires_grad = True
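# Note: requires_grad has to be set before the forward pass so that autograd
# records the operations on y_pred and can populate y_pred.grad.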
# optionally choose a pairwise distance function
#distanceMetric = pysdtw.distance.pairwise_l2_squared
distanceMetric = pysdtw.distance.pairwise_l2_squared_exact
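# Both distance functions compute the pairwise squared Euclidean cost matrix,
# D[i, j] = ||a_i - b_j||^2, over all pairs of time steps in each batch element.
# The "_exact" variant presumably avoids the a^2 - 2ab + b^2 expansion for
# better numerical accuracy (assumption based on the name).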
g = {}
GammaValues = [1.0, 0.1, 0.01]
for gamma in GammaValues:
    # create the SoftDTW distance function
    sdtw = pysdtw.SoftDTW(gamma = gamma, dist_func = distanceMetric, use_cuda = UseCUDA)
    #sdtw = nn.MSELoss()
    #sdtw = L1LossFunction()

    # This is still "attached" to the graph, so apparently carries around all the intermediate calculations etc.
    # We need the attached version for the gradients.
    lossesPerSequenceGraph = sdtw(y_true, y_pred)

    # Have to reduce to a scalar with sum() before running the backwards pass.
    lossesPerSequenceGraph.sum().backward()

    # We only want the gradient on y_pred; y_true is constant and thus doesn't affect them.
    torchGradients = y_pred.grad

    # Be very careful here about the order in which things are detached!
    # The copy seems necessary - otherwise the array ends up referencing junk on the heap.
    torchGradients = np.copy(torchGradients.detach().cpu().numpy())

    lossesPerSequence = lossesPerSequenceGraph.detach().cpu().numpy()
    summedLoss = lossesPerSequence.sum()

    # Check against the version previously calculated _or_ created through Keras.
    assert np.allclose(summedLoss, referenceLosses[gamma]), "Tests failed - losses differ"

    g[gamma] = torchGradients

    # We MUST zero out the gradients before the next iteration - PyTorch accumulates
    # into .grad by default. Unlike TensorFlow, Torch does not handle this through a context manager.
    y_pred.grad.zero_()

    # We need a slightly larger tolerance here.
    assert np.allclose(torchGradients, referenceGradients[gamma], rtol = 1e-4), "Tests failed - gradients differ"

    _break = "hold" # set breakpoint here so that results don't get swallowed up
# Note: alignmentMatrices is not part of the upstream pysdtw API; it is assumed
# to be exposed by a local modification to the checkout used here.
allAlignmentMatrices = pysdtw.alignmentMatrices

for gamma in GammaValues:
    matrices = allAlignmentMatrices[np.float32(gamma)]
    matrices = np.array(matrices)
    save3DCSV(f"{ CacheDirectory }/tonison - gamma is {gamma}.csv", matrices)

    _break = "hold" # set breakpoint here so that results don't get swallowed up
#input()