Last active
May 22, 2024 05:56
-
-
Save buttercutter/b3331ca1fd9e2f5871b0eded6b758f39 to your computer and use it in GitHub Desktop.
Mamba: Linear-Time Sequence Modeling with Selective State Spaces
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from torch.utils.data import DataLoader, Dataset | |
from torch.nn import functional as F | |
from einops import rearrange, repeat | |
from tqdm import tqdm | |
import math | |
import os | |
import urllib.request | |
from zipfile import ZipFile | |
from transformers import AutoTokenizer | |
# Global switch for verbose debugging output in train()/evaluate().
debugging_is_on = 0
# Anomaly detection records extra information for every autograd op and slows
# training considerably, so only enable it while debugging.
torch.autograd.set_detect_anomaly(bool(debugging_is_on))
def print_tensor_info(tensor_name, tensor):
    """Print ``tensor`` followed by its shape, min/max, mean and standard deviation.

    Non-floating tensors are cast to float before computing mean and std;
    min/max are reported from the original tensor.
    """
    as_float = tensor if tensor.is_floating_point() else tensor.float()
    # Evaluate every statistic before printing anything.
    stats = (
        ("shape", tuple(tensor.shape)),
        ("min/max", (tensor.min().item(), tensor.max().item())),
        ("mean", as_float.mean().item()),
        ("std", as_float.std().item()),
    )
    print(f"{tensor_name} = {tensor}")
    for label, value in stats:
        print(f"{label}: {value}")
# Model-selection switches: exactly one of USE_MAMBA / USE_TRANSFORMER should be truthy.
USE_MAMBA = 1
# Logical (not bitwise) negation: `~1` is -2, which is truthy and would turn
# both switches on at once.
USE_TRANSFORMER = int(not USE_MAMBA)
# When truthy, S6 starts each batch from the state left by the previous batch
# (experimental path; see S6.forward and train()).
DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM = 0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# User hyperparameters
d_model = 16
state_size = 64  # Example state size
seq_len = 100  # Example sequence length
batch_size = 128  # Example batch size
class S6(nn.Module):
    """Selective state-space layer (S6), following Algorithm 2 of the Mamba paper.

    For ``x`` of shape (batch, seq_len, d_model) it computes input-dependent
    ``B`` and ``C`` (batch, seq_len, state_size) and step size ``delta``
    (batch, seq_len, d_model). It discretizes the diagonal state matrix ``A``
    (d_model, state_size) with zero-order hold and runs the recurrence

        h_t = dA_t * h_{t-1} + dB_t * x_t
        y_t = sum_n C_t[n] * h_t[:, n]

    It returns ``y`` of shape (batch, seq_len, d_model).

    The attributes ``B``, ``C``, ``delta``, ``dA``, ``dB``, ``h`` and ``y`` are
    pre-allocated using the module-level ``batch_size`` and are overwritten by
    ``forward``. ``h`` is only read by the experimental
    DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM path.
    """

    def __init__(self, seq_len, d_model, state_size, device):
        super(S6, self).__init__()
        self.fc1 = nn.Linear(d_model, d_model, device=device)     # -> delta
        self.fc2 = nn.Linear(d_model, state_size, device=device)  # -> B
        self.fc3 = nn.Linear(d_model, state_size, device=device)  # -> C
        self.seq_len = seq_len
        self.d_model = d_model
        self.state_size = state_size

        # S4D-real initialization: A[d, n] = n + 1 for every channel d.
        # Mamba removed the imaginary parts of the S4D-Inv / S4D-Lin schemes from
        # "On the Parameterization and Initialization of Diagonal State Space Models"
        # (https://arxiv.org/abs/2206.11893); see
        # https://github.com/state-spaces/mamba/blob/fb7b5310fa865dbd62aa059b1e26f2b431363e2a/mamba_ssm/modules/mamba_simple.py#L103-L108C23
        A = repeat(
            torch.arange(1, state_size + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=d_model,
        ).contiguous()
        # Learn log(A); discretization() uses A = -exp(A_log), which is always negative.
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True
        self.A = torch.zeros_like(self.A_log)

        # Placeholders, overwritten on every forward pass.
        self.B = torch.zeros(batch_size, self.seq_len, self.state_size, device=device)
        self.C = torch.zeros(batch_size, self.seq_len, self.state_size, device=device)
        # Initial delta = inverse_softplus(U(0.001, 0.1)), so softplus(delta) lies in that range.
        uniform_distribution = torch.distributions.Uniform(0.001, 0.1)
        self.delta = self.inverse_softplus(
            uniform_distribution.sample((batch_size, self.seq_len, self.d_model))
        ).to(device)
        self.dA = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)
        self.dB = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)
        # Carried state for the experimental path: [batch_size, seq_len, d_model, state_size]
        self.h = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)
        self.y = torch.zeros(batch_size, self.seq_len, self.d_model, device=device)

    def inverse_softplus(self, y):
        """Return x such that softplus(x) == y (requires y > 0)."""
        return torch.log(torch.exp(y) - 1)

    def discretization(self):
        """Zero-order-hold discretization of (A, B) with the input-dependent step ``delta``.

        Mamba parameterizes the step as delta = softplus(Parameter + s_delta(x_t)):
        the parameter gives a baseline discretization rate, the projection
        s_delta(x_t) makes it input-dependent, and softplus keeps it positive.
        Modulating delta per input is what makes the discrete transitions
        "selective": how quickly the state updates depends on the input.

        Returns ``(dA, dB)``, each (batch, seq_len, d_model, state_size); both are
        also stored on ``self``.
        """
        # ZOH as described on page 28 of the Mamba paper (Section C: Mechanics of
        # Selective SSMs); see also https://studywolf.wordpress.com/tag/zero-order-hold/
        self.A = -torch.exp(self.A_log.float())  # (d_model, state_size)
        # Simplified input discretization dB = delta * B; the exact ZOH form would
        # need (delta * A)^-1, and torch.inverse only supports square matrices.
        self.dB = torch.einsum("bld,bln->bldn", self.delta, self.B)
        # A is diagonal per channel, so exp(delta * A) is taken elementwise.
        # https://github.com/state-spaces/mamba/blob/0131c1e94a46fc9f70bcfc9d57962963bb2f0b9e/mamba_ssm/modules/mamba_simple.py#L240
        self.dA = torch.exp(torch.einsum("bld,dn->bldn", self.delta, self.A))
        return self.dA, self.dB

    def forward(self, x):
        """Apply the selective SSM to ``x`` (batch, seq_len, d_model)."""
        # Input-dependent parameters (Algorithm 2 in the Mamba paper).
        self.B = self.fc2(x)
        self.C = self.fc3(x)
        # "a large ∆ resets the state `h` and focuses on the current input `x`,
        # while a small ∆ persists the state and ignores the current input."
        self.delta = F.softplus(self.fc1(x))
        # ZOH as in Mamba (Hungry Hippo still uses the bilinear transform).
        self.discretization()

        if DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM:
            return self._forward_with_carried_state(x)

        # Selective scan along the sequence axis. Starting from an all-zero h and
        # applying dA once would make dA * h vanish, leaving no information flow
        # between time steps; the loop carries h from step t-1 to step t.
        x_dB = rearrange(x, "b l d -> b l d 1") * self.dB  # (b, l, d, n)
        h = torch.zeros(x.size(0), self.d_model, self.state_size,
                        device=x.device, dtype=x_dB.dtype)
        outputs = []
        for t in range(x.size(1)):
            h = self.dA[:, t] * h + x_dB[:, t]
            outputs.append(torch.einsum("bn,bdn->bd", self.C[:, t], h))
        # Stacking fresh tensors (instead of writing into a buffer) avoids in-place
        # modifications of tensors autograd still needs.
        return torch.stack(outputs, dim=1)

    def _forward_with_carried_state(self, x):
        """Experimental path: combine the new input with ``self.h`` left by a previous batch.

        It sets the module-level globals ``current_batch_size`` and ``temp_buffer``;
        ``train()`` copies ``temp_buffer`` back into ``h`` after ``backward()``.
        Writing ``self.h`` here directly is avoided because, per the original note,
        it triggered autograd's in-place runtime error.
        """
        global current_batch_size, temp_buffer
        current_batch_size = x.shape[0]
        # Slice the stored state to the current (possibly smaller, e.g. last) batch.
        h_prev = self.h[:current_batch_size, ...]
        h_new = self.dA * h_prev + rearrange(x, "b l d -> b l d 1") * self.dB
        # y: [batch_size, seq_len, d_model]
        self.y = torch.einsum('bln,bldn->bld', self.C, h_new)
        temp_buffer = h_new.detach().clone() if not self.h.requires_grad else h_new.clone()
        return self.y
class MambaBlock(nn.Module):
    """One Mamba block (Figure 3 of the paper).

    The data flow is RMSNorm -> in-proj -> causal depthwise conv -> SiLU -> S6.
    The S6 output is gated by SiLU(D(x)) and then projected back to ``d_model``.
    Input and output shape: (batch, seq_len, d_model).
    """

    def __init__(self, seq_len, d_model, state_size, device):
        super(MambaBlock, self).__init__()
        self.inp_proj = nn.Linear(d_model, 2 * d_model, device=device)
        self.out_proj = nn.Linear(2 * d_model, d_model, device=device)
        # Gating branch, multiplied with the SSM output in forward().
        self.D = nn.Linear(d_model, 2 * d_model, device=device)
        # Exclude the out_proj bias from weight decay and initialize it to 1.0.
        self.out_proj.bias._no_weight_decay = True
        nn.init.constant_(self.out_proj.bias, 1.0)
        self.S6 = S6(seq_len, 2 * d_model, state_size, device)
        # Depthwise 1D convolution over the *sequence* axis with kernel size 3.
        # A Conv1d(seq_len, seq_len) would treat positions as channels and mix every
        # position (including future tokens) into every output, which a mask
        # applied afterwards cannot undo.
        self.conv_kernel_size = 3
        self.conv = nn.Conv1d(
            2 * d_model,
            2 * d_model,
            kernel_size=self.conv_kernel_size,
            padding=self.conv_kernel_size - 1,
            groups=2 * d_model,
            device=device,
        )
        self.norm = RMSNorm(d_model, device=device)

    def forward(self, x, attention_mask=None):
        """Apply the block to ``x``; ``attention_mask`` (batch, seq_len) zeroes masked positions."""
        if attention_mask is not None:
            x = x * attention_mask.unsqueeze(-1)
        x = self.norm(x)
        x_proj = self.inp_proj(x)  # (b, l, 2*d_model)
        length = x_proj.size(1)
        # Conv1d expects (b, channels, l). With k-1 padding on both sides, keeping
        # the first `length` outputs means output t depends only on inputs t-k+1..t.
        x_conv = self.conv(x_proj.transpose(1, 2))[..., :length].transpose(1, 2)
        x_conv_act = F.silu(x_conv)  # Swish: x * sigmoid(x)
        x_ssm = self.S6(x_conv_act)
        # Multiplicative gate, which adds a nonlinearity.
        x_residual = F.silu(self.D(x))
        x_combined = x_ssm * x_residual
        return self.out_proj(x_combined)
class Mamba(nn.Module):
    """Three stacked MambaBlocks followed by a linear head.

    The head maps ``d_model`` features to ``vocab_size`` outputs, or back to
    ``d_model`` when ``vocab_size`` is None.
    """

    def __init__(self, seq_len, d_model, state_size, vocab_size, device):
        super(Mamba, self).__init__()
        out_features = d_model if vocab_size is None else vocab_size
        self.mamba_block1 = MambaBlock(seq_len, d_model, state_size, device)
        self.mamba_block2 = MambaBlock(seq_len, d_model, state_size, device)
        self.mamba_block3 = MambaBlock(seq_len, d_model, state_size, device)
        self.final_proj = nn.Linear(d_model, out_features, device=device)

    def forward(self, x, attention_mask=None):
        hidden = x
        # Every block receives the same attention mask.
        for block in (self.mamba_block1, self.mamba_block2, self.mamba_block3):
            hidden = block(hidden, attention_mask)
        return self.final_proj(hidden)
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Computes ``x / sqrt(mean(x**2, dim=-1) + eps) * weight``. ``device`` defaults
    to 'cuda', so pass the target device explicitly on CPU-only machines.
    """

    def __init__(self,
                 d_model: int,
                 eps: float = 1e-5,
                 device: str ='cuda'):
        super().__init__()
        self.eps = eps
        # Learnable per-feature scale, initialized to ones.
        self.weight = nn.Parameter(torch.ones(d_model, device=device))

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = (mean_square + self.eps).rsqrt()
        return x * inv_rms * self.weight
# Example usage: smoke-test the Mamba model on a random input tensor.
if USE_MAMBA:
    x = torch.rand(batch_size, seq_len, d_model, device=device)
    # Create the Mamba model
    mamba = Mamba(seq_len, d_model, state_size, None, device)
    # RMSNorm defaults to device='cuda'; pass the selected device so this also
    # runs on CPU-only machines (and matches x's device on GPU machines).
    norm = RMSNorm(d_model, device=device)
    x = norm(x)
    # Forward pass
    test_output = mamba(x)
    print(f"test_output.shape = {test_output.shape}")  # Should be [batch_size, seq_len, d_model]
class Enwiki8Dataset(Dataset):
    """Map-style dataset over a dict of row-aligned tensors; sample ``i`` is row ``i`` of each."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        # The sample count comes from the 'encoded_inputs' entry.
        return len(self.data['encoded_inputs'])

    def __getitem__(self, idx):
        # Detached copies, so callers cannot mutate the underlying tensors.
        sample = {}
        for name, values in self.data.items():
            sample[name] = values[idx].clone().detach()
        return sample
# Define a function for padding
def pad_sequences_3d(sequences, max_len=None, pad_value=0):
    """Right-pad a batch of sequences along dimension 1.

    Accepts (batch, seq_len) and (batch, seq_len, feature_size) tensors, and more
    generally any tensor of rank >= 2. Trailing dimensions are preserved.

    Args:
        sequences: tensor whose dim 0 is the batch and dim 1 the sequence.
        max_len: target length along dim 1; defaults to ``seq_len + 1``.
        pad_value: fill value for the padded positions.

    Returns:
        A new tensor of shape (batch, max_len, *trailing) with the same dtype and
        device as ``sequences``.

    Raises:
        ValueError: if ``sequences`` has fewer than 2 dims or ``max_len < seq_len``.
    """
    if sequences.ndim < 2:
        raise ValueError(f"expected a tensor with at least 2 dims, got shape {tuple(sequences.shape)}")
    batch_size, seq_len = sequences.shape[0], sequences.shape[1]
    if max_len is None:
        max_len = seq_len + 1
    if max_len < seq_len:
        raise ValueError(f"max_len ({max_len}) is shorter than the sequence length ({seq_len})")
    # new_full inherits dtype and device from `sequences`.
    padded_sequences = sequences.new_full((batch_size, max_len, *sequences.shape[2:]), pad_value)
    padded_sequences[:, :seq_len] = sequences
    return padded_sequences
def train(model, tokenizer, data_loader, optimizer, criterion, device, max_grad_norm=1.0, DEBUGGING_IS_ON=False):
    """Run one training epoch and return the mean per-batch loss.

    Each batch provides:
    - ``input_ids``: token ids, used to build targets;
    - ``encoded_inputs``: model inputs, e.g. precomputed embeddings for Mamba;
    - ``attention_mask``.

    Inputs and targets are shifted by one position for next-token prediction and
    padded back to full length with ``tokenizer.pad_token_id``.

    Returns NaN when ``data_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch in data_loader:
        optimizer.zero_grad()
        original_data = batch['input_ids'].to(device)  # token ids (targets)
        input_data = batch['encoded_inputs'].to(device)  # model inputs (downsized for Mamba)
        attention_mask = batch['attention_mask'].to(device)

        # Next-token prediction: input position t predicts token t+1.
        target = original_data[:, 1:]
        input_data = input_data[:, :-1]
        # Pad both back to the full sequence length.
        input_data = pad_sequences_3d(input_data, pad_value=tokenizer.pad_token_id)
        target = pad_sequences_3d(target, max_len=original_data.size(1), pad_value=tokenizer.pad_token_id)

        output = model(input_data, attention_mask)
        # Use the logits' own last dimension instead of a global `vocab_size`,
        # which only exists when encode_dataset() ran in this process.
        loss = criterion(output.reshape(-1, output.size(-1)), target.reshape(-1))
        # No retain_graph: the graph is not backpropagated a second time, and
        # retaining it kept each batch's activations alive longer than needed.
        loss.backward()

        # Clip each parameter's gradient separately (not one global norm),
        # skipping the out_proj biases.
        for name, param in model.named_parameters():
            if 'out_proj.bias' not in name:
                torch.nn.utils.clip_grad_norm_(param, max_norm=max_grad_norm)

        if DEBUGGING_IS_ON:
            print("DEBUGGING IS ON !!!")
            print_tensor_info("output", output)
            print_tensor_info("target", target)
            for name, parameter in model.named_parameters():
                if parameter.grad is not None:
                    print(f"{name} gradient: {parameter.grad.data.norm(2)}")
                else:
                    print(f"{name} has no gradient")

        if USE_MAMBA and DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM:
            # Copy the state stashed by S6.forward in the module-level `temp_buffer`
            # back into the layer's `h`. Mamba has no `S6` attribute (layers live at
            # model.mamba_blockN.S6). The single global buffer holds the state of the
            # S6 layer that ran last, presumably the last one registered.
            last_s6 = [m for m in model.modules() if isinstance(m, S6)][-1]
            with torch.no_grad():
                last_s6.h[:current_batch_size, ...].copy_(temp_buffer)

        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else float('nan')
def evaluate(model, data_loader, criterion, device, DEBUGGING_IS_ON=False):
    """Compute the mean per-batch loss over ``data_loader`` without gradients.

    Uses the same shift-by-one / re-pad preparation as ``train()``. The padding id
    comes from the module-level ``tokenizer``. Returns NaN for an empty loader.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in data_loader:
            original_data = batch['input_ids'].to(device)  # token ids (targets)
            input_data = batch['encoded_inputs'].to(device)  # model inputs (downsized for Mamba)
            attention_mask = batch['attention_mask'].to(device)

            # Next-token prediction: input position t predicts token t+1.
            target = original_data[:, 1:]
            input_data = input_data[:, :-1]
            # Pad both back to the full sequence length.
            input_data = pad_sequences_3d(input_data, pad_value=tokenizer.pad_token_id)
            target = pad_sequences_3d(target, max_len=original_data.size(1), pad_value=tokenizer.pad_token_id)

            output = model(input_data, attention_mask)
            # Use the logits' own last dimension instead of a global `vocab_size`.
            loss = criterion(output.reshape(-1, output.size(-1)), target.reshape(-1))
            total_loss += loss.item()
            num_batches += 1

            if DEBUGGING_IS_ON:
                print("DEBUGGING IS ON !!!")
                print_tensor_info("output", output)
                print_tensor_info("target", target)
    return total_loss / num_batches if num_batches else float('nan')
def calculate_perplexity(loss):
    """Return exp(loss), the perplexity for a mean cross-entropy loss in nats.

    Returns ``float('inf')`` instead of raising OverflowError when ``loss`` is too
    large for ``math.exp``.
    """
    try:
        return math.exp(loss)
    except OverflowError:
        return float('inf')
def load_enwiki8_dataset(url="http://mattmahoney.net/dc/enwik8.zip", zip_path="enwik8.zip"):
    """Return the enwik8 corpus as a single UTF-8 decoded string.

    The archive is downloaded from ``url`` only if ``zip_path`` does not exist yet,
    so repeated runs reuse the local copy. The download goes to a temporary
    ``.part`` file first, so an interrupted download never leaves a truncated
    archive at ``zip_path``.
    """
    if not os.path.exists(zip_path):
        print("Download and extract enwiki8 data")
        partial_path = zip_path + ".part"
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, zip_path)
    else:
        print(f"Extract enwiki8 data from existing {zip_path}")
    with ZipFile(zip_path) as f:
        data = f.read("enwik8").decode("utf-8")
    return data
# Tokenize and encode the dataset
def encode_dataset(tokenizer, text_data):
    """Tokenize ``text_data`` and build the model inputs.

    Returns ``(encoded_inputs, attention_mask, input_ids)``:
    - ``input_ids`` has shape (num_rows, seq_len);
    - ``attention_mask`` is 1 where the id differs from ``tokenizer.pad_token_id``;
    - for Mamba, ``encoded_inputs`` are embeddings (num_rows, seq_len, d_model)
      from a randomly initialized, frozen ``nn.Embedding``;
    - for the transformer, ``encoded_inputs`` are the ids themselves.

    Also sets the module-level global ``vocab_size``.

    Raises:
        ValueError: if ``text_data`` is empty or neither USE_MAMBA nor
            USE_TRANSFORMER is enabled.
    """
    def batch_encode(tokenizer, text_data, batch_size=1000):
        # Tokenize `text_data` in slices of `batch_size` elements, padding and
        # truncating each result to `seq_len` tokens.
        # NOTE(review): load_enwiki8_dataset() returns one str, so each slice is
        # 1000 characters, presumably encoded as a single row; tokens beyond
        # seq_len in a slice are truncated away.
        batched_input_ids = []
        for i in range(0, len(text_data), batch_size):
            batch = text_data[i:i + batch_size]
            inputs = tokenizer(batch, add_special_tokens=True, truncation=True,
                               padding='max_length', max_length=seq_len,
                               return_tensors='pt')
            batched_input_ids.append(inputs['input_ids'])
        return torch.cat(batched_input_ids)

    def batch_embedding_calls(input_ids, embedding_layer, batch_size=256):
        # Embed `input_ids` in row chunks to bound peak memory (embedding the whole
        # tensor at once is memory-hungry).
        if not isinstance(input_ids, torch.Tensor):
            input_ids = torch.tensor(input_ids, dtype=torch.long)
        num_batches = math.ceil(input_ids.size(0) / batch_size)
        output_embeddings = []
        for i in range(num_batches):
            input_id_batch = input_ids[i * batch_size:(i + 1) * batch_size]
            with torch.no_grad():  # fixed features; no gradients needed
                output_embeddings.append(embedding_layer(input_id_batch))
        return torch.cat(output_embeddings, dim=0)

    if len(text_data) == 0:
        raise ValueError("text_data is empty")
    # Use the argument; the global `enwiki8_data` is not necessarily what the caller passed.
    input_ids = batch_encode(tokenizer, text_data)

    # Number of entries in the tokenizer's vocabulary (read elsewhere as a global).
    global vocab_size
    vocab_size = len(tokenizer.vocab)
    print(f"vocab_size = {vocab_size}")

    # Embedding of size (vocab_size, d_model); d_model is the Mamba model's D.
    embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    if USE_MAMBA:
        # Rows are embedded independently, so the chunk size only trades memory for
        # speed; 1-row chunks were needlessly slow.
        encoded_inputs = batch_embedding_calls(input_ids, embedding_layer, batch_size=256).float()
    elif USE_TRANSFORMER:
        encoded_inputs = input_ids.long()
    else:
        raise ValueError("Enable USE_MAMBA or USE_TRANSFORMER")
    attention_mask = (input_ids != tokenizer.pad_token_id).type(input_ids.dtype)
    return encoded_inputs, attention_mask, input_ids
# Load a pretrained tokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
#tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
# Use an existing special token as the padding token.
#tokenizer.pad_token = tokenizer.eos_token

# The model needs vocab_size even when the data comes from the cache, so compute
# it here rather than relying on encode_dataset() having run.
vocab_size = len(tokenizer.vocab)

if USE_MAMBA:
    encoded_inputs_file = 'encoded_inputs_mamba.pt'
elif USE_TRANSFORMER:
    encoded_inputs_file = 'encoded_inputs_transformer.pt'

# The cache stores all three tensors; older cache files held only
# `encoded_inputs`, from which `input_ids` / `attention_mask` cannot be
# recovered, so such files trigger re-tokenization.
cached = torch.load(encoded_inputs_file) if os.path.exists(encoded_inputs_file) else None
if isinstance(cached, dict) and {'encoded_inputs', 'attention_mask', 'input_ids'} <= cached.keys():
    print("Loading pre-tokenized data...")
    encoded_inputs = cached['encoded_inputs']
    attention_mask = cached['attention_mask']
    input_ids = cached['input_ids']
else:
    print("Tokenizing raw data...")
    enwiki8_data = load_enwiki8_dataset()
    encoded_inputs, attention_mask, input_ids = encode_dataset(tokenizer, enwiki8_data)
    torch.save(
        {'encoded_inputs': encoded_inputs, 'attention_mask': attention_mask, 'input_ids': input_ids},
        encoded_inputs_file,
    )
    print("finished tokenizing data")
# Bundle the three row-aligned tensors so a single row split applies to all of them.
data = dict(
    input_ids=input_ids,
    encoded_inputs=encoded_inputs,
    attention_mask=attention_mask,
)

# Unshuffled 80/20 split along the sample (row) dimension.
total_size = len(data['encoded_inputs'])
train_size = int(total_size * 0.8)
train_rows, val_rows = slice(None, train_size), slice(train_size, None)
train_data = {name: tensor[train_rows] for name, tensor in data.items()}
val_data = {name: tensor[val_rows] for name, tensor in data.items()}

train_dataset, val_dataset = Enwiki8Dataset(train_data), Enwiki8Dataset(val_data)
# Only the training batches are reshuffled each epoch.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
# Build the model selected by the USE_MAMBA / USE_TRANSFORMER switches.
if USE_MAMBA:
    model = Mamba(seq_len, d_model, state_size, vocab_size, device).to(device)
elif USE_TRANSFORMER:
    from transformers import AutoModel

    # Small pretrained BERT encoder ("prajjwal1/bert-tiny") used as the backbone.
    bert_model = AutoModel.from_pretrained("prajjwal1/bert-tiny").to(device)
    print(f"bert_model.config.hidden_size = {bert_model.config.hidden_size}")

    class NextTokenPredictor(nn.Module):
        """BERT encoder followed by a linear head producing per-position vocabulary logits."""

        def __init__(self, bert_model, vocab_size):
            super(NextTokenPredictor, self).__init__()
            self.bert = bert_model
            self.predictor = nn.Linear(bert_model.config.hidden_size, vocab_size)

        def forward(self, input_ids, attention_mask):
            hidden_states = self.bert(input_ids, attention_mask=attention_mask).last_hidden_state
            return self.predictor(hidden_states)

    model = NextTokenPredictor(bert_model, vocab_size).to(device)
# Loss and optimizer shared by every epoch below.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=5e-6)

# Training loop: one training pass and one validation pass per epoch.
num_epochs = 25  # Number of epochs to train for
for epoch in tqdm(range(num_epochs)):
    train_loss = train(
        model,
        tokenizer,
        train_loader,
        optimizer,
        criterion,
        device,
        max_grad_norm=10.0,
        DEBUGGING_IS_ON=debugging_is_on,
    )
    val_loss = evaluate(
        model, val_loader, criterion, device, DEBUGGING_IS_ON=debugging_is_on
    )
    val_perplexity = calculate_perplexity(val_loss)
    print(f'Epoch: {epoch+1}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation Perplexity: {val_perplexity:.4f}')

    # nn.CrossEntropyLoss is non-negative for class-index or probability
    # targets, so a negative epoch loss (assuming train/evaluate report this
    # criterion's value) points at malformed targets. Once triggered, verbose
    # debugging stays on for every remaining epoch.
    if any(loss < 0 for loss in (train_loss, val_loss)):
        debugging_is_on = 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
test_output.shape = torch.Size([256, 100, 8]) | |
/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:88: UserWarning: | |
The secret `HF_TOKEN` does not exist in your Colab secrets. | |
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session. | |
You will be able to reuse this secret in all of your notebooks. | |
Please note that authentication is recommended but still optional to access public models or datasets. | |
warnings.warn( | |
tokenizer_config.json: 100% | |
28.0/28.0 [00:00<00:00, 474B/s] | |
config.json: 100% | |
570/570 [00:00<00:00, 15.4kB/s] | |
vocab.txt: 100% | |
232k/232k [00:00<00:00, 2.79MB/s] | |
tokenizer.json: 100% | |
466k/466k [00:00<00:00, 4.03MB/s] | |
Tokenizing raw data... | |
Download and extract enwiki8 data | |
vocab_size = 30522 | |
finished tokenizing data | |
4%|▍ | 1/25 [01:19<31:51, 79.63s/it] | |
Streaming output truncated to the last 5000 lines. | |
[1.0150, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0052, 1.0064, 0.9966, ..., 1.0035, 1.0033, 0.9953], | |
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9949], | |
..., | |
[1.0123, 1.0031, 0.9993, ..., 1.0050, 1.0097, 0.9956], | |
[1.0247, 1.0138, 0.9932, ..., 1.0348, 1.0349, 0.9925], | |
[1.0150, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0052, 1.0064, 0.9966, ..., 1.0035, 1.0033, 0.9953], | |
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9949], | |
..., | |
[1.0123, 1.0032, 0.9993, ..., 1.0050, 1.0097, 0.9956], | |
[1.0247, 1.0138, 0.9932, ..., 1.0348, 1.0349, 0.9925], | |
[1.0150, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]], | |
..., | |
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0052, 1.0064, 0.9966, ..., 1.0035, 1.0033, 0.9953], | |
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9949], | |
..., | |
[1.0123, 1.0031, 0.9993, ..., 1.0050, 1.0097, 0.9956], | |
[1.0247, 1.0138, 0.9932, ..., 1.0348, 1.0349, 0.9925], | |
[1.0150, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0052, 1.0064, 0.9966, ..., 1.0035, 1.0033, 0.9953], | |
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9949], | |
..., | |
[1.0123, 1.0032, 0.9993, ..., 1.0050, 1.0097, 0.9956], | |
[1.0247, 1.0138, 0.9932, ..., 1.0348, 1.0349, 0.9925], | |
[1.0150, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0052, 1.0064, 0.9967, ..., 1.0035, 1.0033, 0.9953], | |
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9949], | |
..., | |
[1.0123, 1.0032, 0.9992, ..., 1.0050, 1.0097, 0.9956], | |
[1.0247, 1.0138, 0.9932, ..., 1.0348, 1.0350, 0.9925], | |
[1.0150, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.963894784450531, 1.03498375415802) | |
mean: 1.0009433031082153 | |
std: 0.007903832010924816 | |
target = tensor([[[ 0.3699, 0.2790, 0.5842, ..., 1.5279, 0.0930, -1.2546], | |
[ 3.0019, -0.4078, -0.7056, ..., 1.3097, 0.8935, 1.5305], | |
[-0.0055, 1.4312, -0.2068, ..., 0.2403, 0.8108, -0.4160], | |
..., | |
[ 0.2078, 0.4916, -0.6117, ..., -0.0424, -0.4392, -1.6947], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 1.4162, 0.2253, -0.0672, ..., -0.9196, 0.7513, 0.9457], | |
[-0.7880, 0.3277, -0.4625, ..., 1.0912, 0.8847, 0.0261], | |
[-1.0726, -0.8486, 0.7417, ..., 0.2901, 0.5678, 0.3142], | |
..., | |
[ 0.1903, 0.7261, -1.3328, ..., -1.6171, 0.1211, -0.1400], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.4741, -0.1311, 2.4631, ..., 1.1667, 0.0434, 2.2398], | |
[-0.0729, -1.5737, 0.1047, ..., -1.7538, 0.6804, -0.4289], | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
..., | |
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
..., | |
[[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505], | |
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270], | |
[ 2.8159, 0.9576, -0.3607, ..., 2.3174, -0.3391, -0.0629], | |
..., | |
[-1.0850, -0.6438, 0.2743, ..., -1.1642, -0.8742, -0.2776], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.9667, -1.2356, -0.0224, ..., -0.2647, -0.8683, 1.7923], | |
[-2.1527, -0.0821, -0.2856, ..., 0.0990, 1.7970, 0.9253], | |
[ 2.0079, 0.2899, 0.1462, ..., 0.0117, 0.1548, 0.9144], | |
..., | |
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.8051, -1.0119, 0.1091, ..., 1.4222, 0.1646, 0.0119], | |
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505], | |
[ 1.6541, 0.4031, -0.3804, ..., 0.1305, 0.6855, -0.8260], | |
..., | |
[ 0.9954, 0.6389, 0.7271, ..., -0.3038, 0.5158, -1.5865], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]], | |
device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-4.062912464141846, 3.9631638526916504) | |
mean: -0.039700813591480255 | |
std: 0.99717116355896 | |
mamba_block1.inp_proj.weight gradient: 6.656166988250334e-06 | |
mamba_block1.inp_proj.bias gradient: 9.812881216930691e-06 | |
mamba_block1.out_proj.weight gradient: 8.017977961571887e-05 | |
mamba_block1.out_proj.bias gradient: 0.0019019206520169973 | |
mamba_block1.D.weight gradient: 2.1973764887661673e-05 | |
mamba_block1.D.bias gradient: 2.58996915363241e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 2.925547278209706e-06 | |
mamba_block1.S6.fc1.bias gradient: 4.201373030809918e-06 | |
mamba_block1.S6.fc2.weight gradient: 1.4111941709415987e-05 | |
mamba_block1.S6.fc2.bias gradient: 2.2652086045127362e-05 | |
mamba_block1.S6.fc3.weight gradient: 1.3704809134651441e-05 | |
mamba_block1.S6.fc3.bias gradient: 2.2110638383310288e-05 | |
mamba_block1.conv.weight gradient: 4.779999653692357e-05 | |
mamba_block1.conv.bias gradient: 8.194298061425798e-06 | |
mamba_block1.conv_linear.weight gradient: 1.5478439308935776e-05 | |
mamba_block1.conv_linear.bias gradient: 5.2600811613956466e-05 | |
mamba_block1.norm.weight gradient: 4.226156761433231e-06 | |
mamba_block2.inp_proj.weight gradient: 0.007911593653261662 | |
mamba_block2.inp_proj.bias gradient: 0.002791037317365408 | |
mamba_block2.out_proj.weight gradient: 0.008094603195786476 | |
mamba_block2.out_proj.bias gradient: 0.020823726430535316 | |
mamba_block2.D.weight gradient: 0.004664940293878317 | |
mamba_block2.D.bias gradient: 0.0016456706216558814 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.0010565894190222025 | |
mamba_block2.S6.fc1.bias gradient: 0.0013118531787768006 | |
mamba_block2.S6.fc2.weight gradient: 0.0029101786203682423 | |
mamba_block2.S6.fc2.bias gradient: 0.003688430180773139 | |
mamba_block2.S6.fc3.weight gradient: 0.0027016273234039545 | |
mamba_block2.S6.fc3.bias gradient: 0.0033678270410746336 | |
mamba_block2.conv.weight gradient: 0.011834848672151566 | |
mamba_block2.conv.bias gradient: 0.0008222330361604691 | |
mamba_block2.conv_linear.weight gradient: 0.008866168558597565 | |
mamba_block2.conv_linear.bias gradient: 0.005379044450819492 | |
mamba_block2.norm.weight gradient: 0.0031838142313063145 | |
mamba_block3.inp_proj.weight gradient: 0.08878087252378464 | |
mamba_block3.inp_proj.bias gradient: 0.03128563240170479 | |
mamba_block3.out_proj.weight gradient: 0.04247088357806206 | |
mamba_block3.out_proj.bias gradient: 9.189469096781977e-08 | |
mamba_block3.D.weight gradient: 0.04408084601163864 | |
mamba_block3.D.bias gradient: 0.015569724142551422 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.004259995650500059 | |
mamba_block3.S6.fc1.bias gradient: 0.004275763873010874 | |
mamba_block3.S6.fc2.weight gradient: 0.01696809008717537 | |
mamba_block3.S6.fc2.bias gradient: 0.01751020923256874 | |
mamba_block3.S6.fc3.weight gradient: 0.017805999144911766 | |
mamba_block3.S6.fc3.bias gradient: 0.018047377467155457 | |
mamba_block3.conv.weight gradient: 0.1599825918674469 | |
mamba_block3.conv.bias gradient: 0.015678398311138153 | |
mamba_block3.conv_linear.weight gradient: 0.0732274278998375 | |
mamba_block3.conv_linear.bias gradient: 0.04735743626952171 | |
mamba_block3.norm.weight gradient: 0.02320939488708973 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0052, 1.0064, 0.9966, ..., 1.0035, 1.0033, 0.9953], | |
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9949], | |
..., | |
[1.0123, 1.0032, 0.9993, ..., 1.0050, 1.0097, 0.9956], | |
[1.0248, 1.0139, 0.9932, ..., 1.0349, 1.0351, 0.9925], | |
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0052, 1.0064, 0.9966, ..., 1.0035, 1.0033, 0.9952], | |
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9949], | |
..., | |
[1.0123, 1.0032, 0.9993, ..., 1.0050, 1.0097, 0.9956], | |
[1.0248, 1.0138, 0.9932, ..., 1.0349, 1.0351, 0.9925], | |
[1.0150, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0052, 1.0064, 0.9966, ..., 1.0035, 1.0033, 0.9952], | |
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9949], | |
..., | |
[1.0123, 1.0031, 0.9993, ..., 1.0050, 1.0097, 0.9956], | |
[1.0248, 1.0138, 0.9932, ..., 1.0349, 1.0351, 0.9925], | |
[1.0150, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]], | |
..., | |
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0052, 1.0064, 0.9966, ..., 1.0035, 1.0033, 0.9953], | |
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9949], | |
..., | |
[1.0123, 1.0031, 0.9993, ..., 1.0050, 1.0097, 0.9956], | |
[1.0248, 1.0139, 0.9932, ..., 1.0349, 1.0351, 0.9925], | |
[1.0150, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0052, 1.0064, 0.9966, ..., 1.0035, 1.0033, 0.9953], | |
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9948], | |
..., | |
[1.0123, 1.0032, 0.9993, ..., 1.0050, 1.0097, 0.9956], | |
[1.0248, 1.0138, 0.9932, ..., 1.0349, 1.0351, 0.9925], | |
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0052, 1.0064, 0.9966, ..., 1.0035, 1.0033, 0.9952], | |
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9949], | |
..., | |
[1.0123, 1.0032, 0.9993, ..., 1.0050, 1.0097, 0.9956], | |
[1.0248, 1.0138, 0.9932, ..., 1.0349, 1.0350, 0.9925], | |
[1.0150, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9638505578041077, 1.0351523160934448) | |
mean: 1.000943899154663 | |
std: 0.007913809269666672 | |
target = tensor([[[-8.3893e-01, -8.6427e-01, 4.2425e-01, ..., 5.8477e-01, | |
1.5457e+00, -4.3527e-01], | |
[-2.3506e+00, -9.3173e-01, -1.7008e-01, ..., -1.3117e+00, | |
1.3262e+00, -2.5985e-02], | |
[ 2.0079e+00, 2.8988e-01, 1.4619e-01, ..., 1.1677e-02, | |
1.5477e-01, 9.1439e-01], | |
..., | |
[-9.1471e-02, 1.2755e-01, 7.2934e-01, ..., 1.1558e+00, | |
-3.6694e-01, -2.0441e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[ 4.9737e-01, 7.9059e-01, -1.0028e+00, ..., 1.2410e+00, | |
-1.4670e+00, -1.0270e+00], | |
[ 1.6537e-02, 5.9942e-01, -1.0490e+00, ..., -1.0667e+00, | |
-1.8011e-01, -2.0437e-01], | |
[ 2.0079e+00, 2.8988e-01, 1.4619e-01, ..., 1.1677e-02, | |
1.5477e-01, 9.1439e-01], | |
..., | |
[ 6.8187e-01, 6.5033e-01, -7.8437e-02, ..., 1.2032e+00, | |
4.3723e-01, 5.0549e-02], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[ 9.2321e-01, 2.1170e-01, -5.1829e-01, ..., 7.5898e-01, | |
1.7760e+00, 9.7635e-01], | |
[-2.8159e-02, 8.7647e-01, 3.6170e-01, ..., -8.5379e-01, | |
5.3774e-01, -1.6134e+00], | |
[ 6.4561e-01, -1.7245e+00, -5.6855e-01, ..., -4.0166e-01, | |
-1.8768e+00, -1.1828e+00], | |
..., | |
[ 1.3604e+00, 8.3413e-01, 9.7125e-01, ..., -9.8477e-02, | |
-2.4212e-01, 6.4055e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
..., | |
[[ 6.4588e-01, -9.7711e-01, 1.4713e-01, ..., -1.7452e+00, | |
2.4286e-02, 1.4304e-02], | |
[ 4.1251e-01, 8.6066e-01, -2.1138e-01, ..., -5.0017e-03, | |
-4.6324e-02, -1.4117e+00], | |
[ 2.1325e+00, 6.8348e-02, 1.1581e+00, ..., 1.2571e+00, | |
4.6634e-01, -7.2127e-01], | |
..., | |
[ 2.1325e+00, 6.8348e-02, 1.1581e+00, ..., 1.2571e+00, | |
4.6634e-01, -7.2127e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[-1.7780e+00, -3.5378e-01, 5.7481e-01, ..., -4.0297e-01, | |
-3.2478e-01, -6.1987e-01], | |
[-1.7780e+00, -3.5378e-01, 5.7481e-01, ..., -4.0297e-01, | |
-3.2478e-01, -6.1987e-01], | |
[-1.5360e+00, -4.3045e-01, -2.2538e-02, ..., -8.2286e-01, | |
2.1251e-01, 1.5091e-01], | |
..., | |
[-3.8774e-01, 3.0244e-01, -1.1404e+00, ..., 2.0661e+00, | |
-6.1905e-01, -9.3546e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[ 8.4960e-01, -4.3197e-01, 5.5274e-01, ..., -2.7416e-01, | |
2.0447e+00, -5.1754e-01], | |
[ 6.3896e-02, 3.2472e-04, -1.2828e+00, ..., -1.0525e+00, | |
-1.3741e+00, -1.5745e+00], | |
[ 6.8187e-01, 6.5033e-01, -7.8437e-02, ..., 1.2032e+00, | |
4.3723e-01, 5.0549e-02], | |
..., | |
[ 1.8655e-01, -3.5074e-01, 6.4411e-02, ..., 9.5573e-01, | |
1.1114e+00, -1.9372e+00], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]]], device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-4.062912464141846, 4.1066203117370605) | |
mean: -0.044864702969789505 | |
std: 0.9986171722412109 | |
mamba_block1.inp_proj.weight gradient: 6.592638783331495e-06 | |
mamba_block1.inp_proj.bias gradient: 1.3115262845531106e-05 | |
mamba_block1.out_proj.weight gradient: 8.0315483501181e-05 | |
mamba_block1.out_proj.bias gradient: 0.001918964902870357 | |
mamba_block1.D.weight gradient: 2.1023452063673176e-05 | |
mamba_block1.D.bias gradient: 2.4453845981042832e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 2.9185762286942918e-06 | |
mamba_block1.S6.fc1.bias gradient: 4.370540409581736e-06 | |
mamba_block1.S6.fc2.weight gradient: 1.885054552985821e-05 | |
mamba_block1.S6.fc2.bias gradient: 2.9134898795746267e-05 | |
mamba_block1.S6.fc3.weight gradient: 1.7872369426186197e-05 | |
mamba_block1.S6.fc3.bias gradient: 2.769949242065195e-05 | |
mamba_block1.conv.weight gradient: 4.985975465388037e-05 | |
mamba_block1.conv.bias gradient: 8.862216418492608e-06 | |
mamba_block1.conv_linear.weight gradient: 1.7178435882669874e-05 | |
mamba_block1.conv_linear.bias gradient: 5.5030737712513655e-05 | |
mamba_block1.norm.weight gradient: 6.020811724738451e-06 | |
mamba_block2.inp_proj.weight gradient: 0.008193454705178738 | |
mamba_block2.inp_proj.bias gradient: 0.002890463685616851 | |
mamba_block2.out_proj.weight gradient: 0.00791140180081129 | |
mamba_block2.out_proj.bias gradient: 0.019019681960344315 | |
mamba_block2.D.weight gradient: 0.0049378108233213425 | |
mamba_block2.D.bias gradient: 0.0017419286305084825 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.0010571378516033292 | |
mamba_block2.S6.fc1.bias gradient: 0.0013405927456915379 | |
mamba_block2.S6.fc2.weight gradient: 0.0029663534369319677 | |
mamba_block2.S6.fc2.bias gradient: 0.003960576374083757 | |
mamba_block2.S6.fc3.weight gradient: 0.002755584428086877 | |
mamba_block2.S6.fc3.bias gradient: 0.00362204248085618 | |
mamba_block2.conv.weight gradient: 0.011379457078874111 | |
mamba_block2.conv.bias gradient: 0.0008499903487972915 | |
mamba_block2.conv_linear.weight gradient: 0.009108995087444782 | |
mamba_block2.conv_linear.bias gradient: 0.005850357934832573 | |
mamba_block2.norm.weight gradient: 0.0032401597127318382 | |
mamba_block3.inp_proj.weight gradient: 0.08919057995080948 | |
mamba_block3.inp_proj.bias gradient: 0.03143063187599182 | |
mamba_block3.out_proj.weight gradient: 0.04368537664413452 | |
mamba_block3.out_proj.bias gradient: 8.888643066029545e-08 | |
mamba_block3.D.weight gradient: 0.04513612762093544 | |
mamba_block3.D.bias gradient: 0.015944337472319603 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.004324209876358509 | |
mamba_block3.S6.fc1.bias gradient: 0.004632828291505575 | |
mamba_block3.S6.fc2.weight gradient: 0.017361413687467575 | |
mamba_block3.S6.fc2.bias gradient: 0.016848554834723473 | |
mamba_block3.S6.fc3.weight gradient: 0.018281958997249603 | |
mamba_block3.S6.fc3.bias gradient: 0.017590906471014023 | |
mamba_block3.conv.weight gradient: 0.16209356486797333 | |
mamba_block3.conv.bias gradient: 0.015879524871706963 | |
mamba_block3.conv_linear.weight gradient: 0.0732278898358345 | |
mamba_block3.conv_linear.bias gradient: 0.04939649626612663 | |
mamba_block3.norm.weight gradient: 0.022089680656790733 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9949], | |
..., | |
[1.0123, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956], | |
[1.0249, 1.0138, 0.9933, ..., 1.0351, 1.0353, 0.9925], | |
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9949], | |
..., | |
[1.0123, 1.0031, 0.9993, ..., 1.0050, 1.0097, 0.9956], | |
[1.0249, 1.0138, 0.9932, ..., 1.0351, 1.0353, 0.9925], | |
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9953], | |
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9948], | |
..., | |
[1.0123, 1.0031, 0.9993, ..., 1.0050, 1.0097, 0.9956], | |
[1.0249, 1.0139, 0.9932, ..., 1.0350, 1.0352, 0.9925], | |
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]], | |
..., | |
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0052, 1.0064, 0.9966, ..., 1.0035, 1.0033, 0.9952], | |
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9948], | |
..., | |
[1.0123, 1.0032, 0.9992, ..., 1.0050, 1.0098, 0.9956], | |
[1.0249, 1.0138, 0.9932, ..., 1.0351, 1.0353, 0.9925], | |
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0033, 0.9952], | |
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956], | |
[1.0249, 1.0139, 0.9932, ..., 1.0350, 1.0353, 0.9925], | |
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0033, 0.9952], | |
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956], | |
[1.0249, 1.0138, 0.9932, ..., 1.0351, 1.0353, 0.9925], | |
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9965]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9638057351112366, 1.035346508026123) | |
mean: 1.0009446144104004 | |
std: 0.00792401097714901 | |
target = tensor([[[-1.1622e-01, 1.4332e+00, 8.4441e-01, ..., 6.9435e-04, | |
-6.6773e-02, 1.0834e-01], | |
[ 5.7104e-01, -4.6999e-01, 1.1255e+00, ..., -8.9141e-01, | |
1.4730e+00, -9.9213e-02], | |
[ 1.0452e+00, 7.1647e-01, 6.4485e-02, ..., 1.4146e-01, | |
1.8992e-01, -1.2258e+00], | |
..., | |
[-4.3337e-01, -1.1911e-01, 1.6830e+00, ..., 1.7715e+00, | |
2.0065e-01, -1.6473e+00], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[ 5.0676e-01, -9.3567e-01, 5.5966e-01, ..., -2.6422e-01, | |
1.4782e+00, 2.3104e+00], | |
[ 1.4765e+00, -7.6907e-01, 4.0878e-01, ..., 9.7170e-01, | |
7.2011e-01, 5.5136e-01], | |
[ 6.6372e-01, 1.7408e-01, -3.5581e-01, ..., -1.4354e+00, | |
1.1672e+00, 4.4820e-02], | |
..., | |
[ 6.6372e-01, 1.7408e-01, -3.5581e-01, ..., -1.4354e+00, | |
1.1672e+00, 4.4820e-02], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[-3.9287e-01, 1.2244e+00, 1.4819e+00, ..., -8.4328e-01, | |
-1.3749e+00, -6.7026e-01], | |
[ 4.9737e-01, 7.9059e-01, -1.0028e+00, ..., 1.2410e+00, | |
-1.4670e+00, -1.0270e+00], | |
[ 1.1287e+00, -4.2922e-01, -6.2596e-01, ..., 4.3149e-03, | |
-1.7797e+00, -1.4768e+00], | |
..., | |
[ 6.8187e-01, 6.5033e-01, -7.8437e-02, ..., 1.2032e+00, | |
4.3723e-01, 5.0549e-02], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
..., | |
[[-5.5460e-01, -1.8933e-02, 8.0374e-01, ..., 9.7693e-01, | |
4.5635e-01, -1.4246e+00], | |
[ 1.5424e+00, -8.6155e-01, -1.6940e+00, ..., -1.3017e+00, | |
-4.4700e-01, -1.3483e+00], | |
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02, | |
-6.2687e-01, -4.9886e-01], | |
..., | |
[-4.3280e-01, 4.2621e-01, -7.8516e-01, ..., 3.9015e-01, | |
8.2322e-01, -1.1738e+00], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[-6.0796e-03, 7.3052e-02, 1.9578e-01, ..., -5.9691e-01, | |
-9.9734e-01, -2.2435e+00], | |
[-3.7331e-01, 1.8360e+00, -1.2402e+00, ..., 1.2983e+00, | |
-6.1130e-01, -2.7833e-01], | |
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02, | |
-6.2687e-01, -4.9886e-01], | |
..., | |
[-3.0289e-01, -7.8487e-01, 6.5365e-01, ..., 2.1631e-02, | |
-5.1024e-02, 1.3417e+00], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02, | |
-6.2687e-01, -4.9886e-01], | |
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02, | |
-6.2687e-01, -4.9886e-01], | |
[ 1.0189e-01, -3.6331e-01, 1.4970e+00, ..., -1.1070e+00, | |
1.8470e+00, 7.8808e-01], | |
..., | |
[-8.3236e-01, -4.5072e-01, 2.3980e-01, ..., 7.7698e-01, | |
-1.6973e+00, -1.6883e+00], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]]], device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-3.7171173095703125, 4.436890602111816) | |
mean: -0.04337029531598091 | |
std: 0.996184229850769 | |
mamba_block1.inp_proj.weight gradient: 8.927928320190404e-06 | |
mamba_block1.inp_proj.bias gradient: 1.3252603821456432e-05 | |
mamba_block1.out_proj.weight gradient: 6.981792103033513e-05 | |
mamba_block1.out_proj.bias gradient: 0.0018748645670711994 | |
mamba_block1.D.weight gradient: 1.9445253201411106e-05 | |
mamba_block1.D.bias gradient: 2.4798053345875815e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 2.8577871944435174e-06 | |
mamba_block1.S6.fc1.bias gradient: 4.1480857362330426e-06 | |
mamba_block1.S6.fc2.weight gradient: 2.404704355285503e-05 | |
mamba_block1.S6.fc2.bias gradient: 3.7825808249181136e-05 | |
mamba_block1.S6.fc3.weight gradient: 2.3223990865517408e-05 | |
mamba_block1.S6.fc3.bias gradient: 3.674719846458174e-05 | |
mamba_block1.conv.weight gradient: 5.1158986025257036e-05 | |
mamba_block1.conv.bias gradient: 1.1530260053405073e-05 | |
mamba_block1.conv_linear.weight gradient: 1.925104697875213e-05 | |
mamba_block1.conv_linear.bias gradient: 6.093188130762428e-05 | |
mamba_block1.norm.weight gradient: 5.640730250888737e-06 | |
mamba_block2.inp_proj.weight gradient: 0.008388090878725052 | |
mamba_block2.inp_proj.bias gradient: 0.00295907910913229 | |
mamba_block2.out_proj.weight gradient: 0.008359096944332123 | |
mamba_block2.out_proj.bias gradient: 0.018886201083660126 | |
mamba_block2.D.weight gradient: 0.005083959549665451 | |
mamba_block2.D.bias gradient: 0.0017934782663360238 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.0011655916459858418 | |
mamba_block2.S6.fc1.bias gradient: 0.0014712277334183455 | |
mamba_block2.S6.fc2.weight gradient: 0.0031938902102410793 | |
mamba_block2.S6.fc2.bias gradient: 0.004056679084897041 | |
mamba_block2.S6.fc3.weight gradient: 0.0029626914765685797 | |
mamba_block2.S6.fc3.bias gradient: 0.0037037297151982784 | |
mamba_block2.conv.weight gradient: 0.01159916166216135 | |
mamba_block2.conv.bias gradient: 0.00081581249833107 | |
mamba_block2.conv_linear.weight gradient: 0.009459671564400196 | |
mamba_block2.conv_linear.bias gradient: 0.0063721355982124805 | |
mamba_block2.norm.weight gradient: 0.0032281361054629087 | |
mamba_block3.inp_proj.weight gradient: 0.09379231184720993 | |
mamba_block3.inp_proj.bias gradient: 0.03305402398109436 | |
mamba_block3.out_proj.weight gradient: 0.0430489145219326 | |
mamba_block3.out_proj.bias gradient: 5.539216729744112e-08 | |
mamba_block3.D.weight gradient: 0.03742552548646927 | |
mamba_block3.D.bias gradient: 0.013216960243880749 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.00466371001675725 | |
mamba_block3.S6.fc1.bias gradient: 0.005200730636715889 | |
mamba_block3.S6.fc2.weight gradient: 0.01829967088997364 | |
mamba_block3.S6.fc2.bias gradient: 0.018736790865659714 | |
mamba_block3.S6.fc3.weight gradient: 0.01904943399131298 | |
mamba_block3.S6.fc3.bias gradient: 0.019463086500763893 | |
mamba_block3.conv.weight gradient: 0.16161637008190155 | |
mamba_block3.conv.bias gradient: 0.01583053544163704 | |
mamba_block3.conv_linear.weight gradient: 0.07311218231916428 | |
mamba_block3.conv_linear.bias gradient: 0.043246448040008545 | |
mamba_block3.norm.weight gradient: 0.02208569645881653 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0029, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956], | |
[1.0250, 1.0138, 0.9933, ..., 1.0352, 1.0355, 0.9926], | |
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9992, ..., 1.0050, 1.0098, 0.9956], | |
[1.0250, 1.0139, 0.9933, ..., 1.0352, 1.0354, 0.9926], | |
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9964]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0064, 0.9967, ..., 1.0029, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0097, 0.9956], | |
[1.0250, 1.0139, 0.9933, ..., 1.0351, 1.0354, 0.9925], | |
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9964]], | |
..., | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956], | |
[1.0250, 1.0139, 0.9932, ..., 1.0352, 1.0354, 0.9925], | |
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9964]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0029, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0031, 0.9993, ..., 1.0050, 1.0098, 0.9956], | |
[1.0250, 1.0139, 0.9933, ..., 1.0352, 1.0354, 0.9926], | |
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9964]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0029, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0031, 0.9993, ..., 1.0050, 1.0097, 0.9956], | |
[1.0250, 1.0139, 0.9933, ..., 1.0351, 1.0354, 0.9925], | |
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9964]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9637707471847534, 1.0355018377304077) | |
mean: 1.0009452104568481 | |
std: 0.00793476589024067 | |
target = tensor([[[ 1.2582e+00, 6.2747e-01, -1.9484e+00, ..., -7.7599e-01, | |
1.0496e+00, 5.3618e-01], | |
[-1.3360e+00, 1.4383e-01, 1.7031e+00, ..., -1.1077e+00, | |
8.4779e-01, -3.4812e-01], | |
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02, | |
-6.2687e-01, -4.9886e-01], | |
..., | |
[-1.0850e+00, -6.4378e-01, 2.7434e-01, ..., -1.1642e+00, | |
-8.7424e-01, -2.7755e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[ 1.1027e+00, 7.7860e-01, 1.2513e+00, ..., -2.4502e-01, | |
3.2866e-01, -1.6867e+00], | |
[ 2.5271e+00, 3.8280e-01, 4.4642e-01, ..., 1.7231e-01, | |
-5.7369e-01, 2.5980e+00], | |
[ 1.2286e+00, -1.6981e+00, 1.1041e-01, ..., -1.3688e+00, | |
-4.3559e-01, 2.2583e-01], | |
..., | |
[-1.7780e+00, -3.5378e-01, 5.7481e-01, ..., -4.0297e-01, | |
-3.2478e-01, -6.1987e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[ 5.0676e-01, -9.3567e-01, 5.5966e-01, ..., -2.6422e-01, | |
1.4782e+00, 2.3104e+00], | |
[-5.8164e-01, 7.9681e-03, 1.8231e+00, ..., -1.1851e+00, | |
4.1620e-01, -2.9570e-03], | |
[ 9.3353e-01, 1.8774e-01, -2.0042e+00, ..., -1.1503e+00, | |
-1.7980e+00, -5.6396e-01], | |
..., | |
[ 6.3467e-01, -6.0116e-01, 3.4803e-01, ..., 1.5082e+00, | |
-9.4524e-01, 2.0558e+00], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
..., | |
[[-4.1120e-01, 2.8092e-02, 5.1873e-01, ..., 6.3312e-01, | |
3.6938e-01, 1.5776e-01], | |
[ 2.9367e-01, 2.7959e+00, -1.3492e+00, ..., -1.4478e+00, | |
-5.1723e-01, 8.9243e-01], | |
[ 2.6099e-02, -1.0027e-01, -1.5132e+00, ..., -3.9709e-02, | |
1.1886e-01, 1.3587e+00], | |
..., | |
[-1.0028e+00, 1.3193e+00, 1.1326e+00, ..., 1.1135e+00, | |
-2.1063e+00, -1.4438e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[ 4.9737e-01, 7.9059e-01, -1.0028e+00, ..., 1.2410e+00, | |
-1.4670e+00, -1.0270e+00], | |
[-1.0945e-01, 1.1213e+00, 2.1538e-01, ..., -6.1082e-01, | |
1.7132e-01, -1.0861e+00], | |
[-1.6287e+00, -9.6626e-01, -2.6581e-03, ..., 1.2085e+00, | |
-9.0730e-01, 4.8345e-02], | |
..., | |
[-3.8774e-01, 3.0244e-01, -1.1404e+00, ..., 2.0661e+00, | |
-6.1905e-01, -9.3546e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[ 9.2342e-01, -9.4206e-01, 1.0451e+00, ..., -4.7812e-01, | |
4.2050e-02, 2.9336e-01], | |
[-4.0796e-01, 1.0457e+00, 1.2001e-02, ..., -1.2754e-01, | |
2.3795e+00, 1.2947e-01], | |
[-8.9576e-01, -1.3298e+00, 4.7374e-01, ..., -2.1709e-01, | |
-5.5530e-02, -4.8000e-01], | |
..., | |
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02, | |
-6.2687e-01, -4.9886e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]]], device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-3.8195555210113525, 4.436890602111816) | |
mean: -0.03474504500627518 | |
std: 0.9982456564903259 | |
mamba_block1.inp_proj.weight gradient: 6.786232461308828e-06 | |
mamba_block1.inp_proj.bias gradient: 1.2180601515865419e-05 | |
mamba_block1.out_proj.weight gradient: 6.749451131327078e-05 | |
mamba_block1.out_proj.bias gradient: 0.0017811341676861048 | |
mamba_block1.D.weight gradient: 1.8621594790602103e-05 | |
mamba_block1.D.bias gradient: 2.298756589880213e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 2.4946195935626747e-06 | |
mamba_block1.S6.fc1.bias gradient: 3.2899588404688984e-06 | |
mamba_block1.S6.fc2.weight gradient: 1.3449737707560416e-05 | |
mamba_block1.S6.fc2.bias gradient: 2.101602149195969e-05 | |
mamba_block1.S6.fc3.weight gradient: 1.3065160601399839e-05 | |
mamba_block1.S6.fc3.bias gradient: 2.0496405340963975e-05 | |
mamba_block1.conv.weight gradient: 4.914818055112846e-05 | |
mamba_block1.conv.bias gradient: 8.672555850353092e-06 | |
mamba_block1.conv_linear.weight gradient: 1.5568120943498798e-05 | |
mamba_block1.conv_linear.bias gradient: 4.648379763239063e-05 | |
mamba_block1.norm.weight gradient: 3.9241208469320554e-06 | |
mamba_block2.inp_proj.weight gradient: 0.008072205819189548 | |
mamba_block2.inp_proj.bias gradient: 0.002847641473636031 | |
mamba_block2.out_proj.weight gradient: 0.008008691482245922 | |
mamba_block2.out_proj.bias gradient: 0.020124541595578194 | |
mamba_block2.D.weight gradient: 0.0048057264648377895 | |
mamba_block2.D.bias gradient: 0.0016953115118667483 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.0010728761553764343 | |
mamba_block2.S6.fc1.bias gradient: 0.0013542993692681193 | |
mamba_block2.S6.fc2.weight gradient: 0.003033427754417062 | |
mamba_block2.S6.fc2.bias gradient: 0.004091166891157627 | |
mamba_block2.S6.fc3.weight gradient: 0.002817986300215125 | |
mamba_block2.S6.fc3.bias gradient: 0.003747538896277547 | |
mamba_block2.conv.weight gradient: 0.011370806023478508 | |
mamba_block2.conv.bias gradient: 0.0008210184169001877 | |
mamba_block2.conv_linear.weight gradient: 0.008997836150228977 | |
mamba_block2.conv_linear.bias gradient: 0.006177668925374746 | |
mamba_block2.norm.weight gradient: 0.003070503007620573 | |
mamba_block3.inp_proj.weight gradient: 0.09553761035203934 | |
mamba_block3.inp_proj.bias gradient: 0.033675868064165115 | |
mamba_block3.out_proj.weight gradient: 0.043427761644124985 | |
mamba_block3.out_proj.bias gradient: 6.981348121826159e-08 | |
mamba_block3.D.weight gradient: 0.044211406260728836 | |
mamba_block3.D.bias gradient: 0.01561672892421484 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.00466552609577775 | |
mamba_block3.S6.fc1.bias gradient: 0.005186178721487522 | |
mamba_block3.S6.fc2.weight gradient: 0.01796438917517662 | |
mamba_block3.S6.fc2.bias gradient: 0.020860623568296432 | |
mamba_block3.S6.fc3.weight gradient: 0.018770582973957062 | |
mamba_block3.S6.fc3.bias gradient: 0.02151082083582878 | |
mamba_block3.conv.weight gradient: 0.16399376094341278 | |
mamba_block3.conv.bias gradient: 0.015978528186678886 | |
mamba_block3.conv_linear.weight gradient: 0.07053054869174957 | |
mamba_block3.conv_linear.bias gradient: 0.04882459342479706 | |
mamba_block3.norm.weight gradient: 0.022343263030052185 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956], | |
[1.0251, 1.0139, 0.9933, ..., 1.0353, 1.0356, 0.9925], | |
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0029, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956], | |
[1.0251, 1.0139, 0.9933, ..., 1.0353, 1.0356, 0.9926], | |
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9964]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0029, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956], | |
[1.0251, 1.0139, 0.9933, ..., 1.0353, 1.0356, 0.9926], | |
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9964]], | |
..., | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0029, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956], | |
[1.0251, 1.0139, 0.9933, ..., 1.0353, 1.0356, 0.9926], | |
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0029, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9992, ..., 1.0050, 1.0098, 0.9956], | |
[1.0251, 1.0139, 0.9933, ..., 1.0353, 1.0356, 0.9926], | |
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9964]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0029, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0031, 0.9993, ..., 1.0050, 1.0098, 0.9956], | |
[1.0251, 1.0139, 0.9933, ..., 1.0353, 1.0356, 0.9926], | |
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9637242555618286, 1.0356793403625488) | |
mean: 1.0009459257125854 | |
std: 0.007945802994072437 | |
target = tensor([[[-1.0850e+00, -6.4378e-01, 2.7434e-01, ..., -1.1642e+00, | |
-8.7424e-01, -2.7755e-01], | |
[-3.2981e-01, -6.2568e-01, 7.4563e-01, ..., -2.8829e+00, | |
-2.6204e+00, 1.0786e+00], | |
[ 8.9414e-01, -2.4687e+00, 5.5291e-01, ..., 1.8136e-02, | |
2.4835e-01, 5.5237e-02], | |
..., | |
[ 4.7191e-01, 3.6167e-01, -3.5786e-01, ..., -3.8691e-01, | |
1.6128e+00, 2.4838e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[-1.6287e+00, -9.6626e-01, -2.6581e-03, ..., 1.2085e+00, | |
-9.0730e-01, 4.8345e-02], | |
[ 1.3781e-01, -3.8981e-01, 4.6194e-01, ..., 1.9883e-01, | |
-3.7158e-01, 3.5527e-01], | |
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02, | |
-6.2687e-01, -4.9886e-01], | |
..., | |
[-7.6430e-01, -1.8293e+00, 3.4729e-01, ..., -1.8000e-02, | |
-4.8519e-01, -4.4253e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[-2.2150e-02, -1.3493e+00, -6.6053e-01, ..., 9.5004e-01, | |
-2.8410e-01, 1.1236e-01], | |
[ 7.7266e-01, -1.2528e-01, -5.1251e-01, ..., -9.5071e-01, | |
1.0857e+00, 6.4368e-01], | |
[-2.7640e-01, 1.4894e+00, 1.4303e-01, ..., -2.2086e-01, | |
2.4025e+00, 8.1037e-01], | |
..., | |
[ 1.0189e-01, -3.6331e-01, 1.4970e+00, ..., -1.1070e+00, | |
1.8470e+00, 7.8808e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
..., | |
[[ 2.9367e-01, 2.7959e+00, -1.3492e+00, ..., -1.4478e+00, | |
-5.1723e-01, 8.9243e-01], | |
[-5.3807e-01, 1.8558e+00, -1.3125e+00, ..., -2.1141e+00, | |
-5.7919e-01, -8.3718e-02], | |
[-1.0806e-01, 8.7904e-01, 6.7809e-01, ..., -5.8664e-01, | |
-1.6239e-01, 4.4618e-01], | |
..., | |
[-1.7780e+00, -3.5378e-01, 5.7481e-01, ..., -4.0297e-01, | |
-3.2478e-01, -6.1987e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[-9.1818e-01, 9.8258e-02, -4.3746e-01, ..., 1.4176e-01, | |
-7.2111e-01, -1.5051e+00], | |
[-2.6851e-01, -6.7224e-01, -6.0742e-01, ..., 3.7681e-01, | |
-1.0639e+00, -1.6735e+00], | |
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02, | |
-6.2687e-01, -4.9886e-01], | |
..., | |
[-1.7780e+00, -3.5378e-01, 5.7481e-01, ..., -4.0297e-01, | |
-3.2478e-01, -6.1987e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[ 8.2675e-01, -7.5806e-01, -1.7703e+00, ..., -1.0994e+00, | |
5.3112e-02, -9.7968e-01], | |
[ 1.1250e+00, 1.1728e+00, 5.4517e-01, ..., -1.0478e+00, | |
4.3682e-01, 1.5019e+00], | |
[-1.8971e-01, 2.4852e-01, 1.0079e+00, ..., -1.3113e-01, | |
-7.4732e-01, 1.3381e+00], | |
..., | |
[ 2.4386e-01, -1.2999e-01, -1.2611e+00, ..., -5.8786e-01, | |
-1.3674e-02, -1.0314e+00], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]]], device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-3.7811129093170166, 3.9631638526916504) | |
mean: -0.037468329071998596 | |
std: 0.9972341060638428 | |
mamba_block1.inp_proj.weight gradient: 7.708286830165889e-06 | |
mamba_block1.inp_proj.bias gradient: 1.3073607078695204e-05 | |
mamba_block1.out_proj.weight gradient: 8.224073826568201e-05 | |
mamba_block1.out_proj.bias gradient: 0.0018965421477332711 | |
mamba_block1.D.weight gradient: 2.0551829948090017e-05 | |
mamba_block1.D.bias gradient: 2.6780233383760788e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 3.2340644793293905e-06 | |
mamba_block1.S6.fc1.bias gradient: 4.6128470785333775e-06 | |
mamba_block1.S6.fc2.weight gradient: 1.912422703753691e-05 | |
mamba_block1.S6.fc2.bias gradient: 3.044418008357752e-05 | |
mamba_block1.S6.fc3.weight gradient: 1.824636638048105e-05 | |
mamba_block1.S6.fc3.bias gradient: 2.918856989708729e-05 | |
mamba_block1.conv.weight gradient: 4.986266867490485e-05 | |
mamba_block1.conv.bias gradient: 8.942976819525938e-06 | |
mamba_block1.conv_linear.weight gradient: 1.743154825817328e-05 | |
mamba_block1.conv_linear.bias gradient: 6.066836431273259e-05 | |
mamba_block1.norm.weight gradient: 6.70700228511123e-06 | |
mamba_block2.inp_proj.weight gradient: 0.007989523932337761 | |
mamba_block2.inp_proj.bias gradient: 0.0028184521943330765 | |
mamba_block2.out_proj.weight gradient: 0.008060919120907784 | |
mamba_block2.out_proj.bias gradient: 0.018983684480190277 | |
mamba_block2.D.weight gradient: 0.004776414018124342 | |
mamba_block2.D.bias gradient: 0.0016849525272846222 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.0010103124659508467 | |
mamba_block2.S6.fc1.bias gradient: 0.0012939375592395663 | |
mamba_block2.S6.fc2.weight gradient: 0.002913458738476038 | |
mamba_block2.S6.fc2.bias gradient: 0.003868406405672431 | |
mamba_block2.S6.fc3.weight gradient: 0.0027066958136856556 | |
mamba_block2.S6.fc3.bias gradient: 0.00352851883508265 | |
mamba_block2.conv.weight gradient: 0.011430115438997746 | |
mamba_block2.conv.bias gradient: 0.0008375911857001483 | |
mamba_block2.conv_linear.weight gradient: 0.008923036977648735 | |
mamba_block2.conv_linear.bias gradient: 0.005675981752574444 | |
mamba_block2.norm.weight gradient: 0.003188737900927663 | |
mamba_block3.inp_proj.weight gradient: 0.09013670682907104 | |
mamba_block3.inp_proj.bias gradient: 0.03176071494817734 | |
mamba_block3.out_proj.weight gradient: 0.04090064764022827 | |
mamba_block3.out_proj.bias gradient: 8.47963974592858e-08 | |
mamba_block3.D.weight gradient: 0.04176799952983856 | |
mamba_block3.D.bias gradient: 0.014754664152860641 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.00422002375125885 | |
mamba_block3.S6.fc1.bias gradient: 0.0044675562530756 | |
mamba_block3.S6.fc2.weight gradient: 0.0172658059746027 | |
mamba_block3.S6.fc2.bias gradient: 0.019037457183003426 | |
mamba_block3.S6.fc3.weight gradient: 0.01806436851620674 | |
mamba_block3.S6.fc3.bias gradient: 0.019686652347445488 | |
mamba_block3.conv.weight gradient: 0.160682812333107 | |
mamba_block3.conv.bias gradient: 0.01587325893342495 | |
mamba_block3.conv_linear.weight gradient: 0.06943379342556 | |
mamba_block3.conv_linear.bias gradient: 0.05003580451011658 | |
mamba_block3.norm.weight gradient: 0.02130027487874031 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9992, ..., 1.0050, 1.0098, 0.9956], | |
[1.0252, 1.0139, 0.9933, ..., 1.0354, 1.0358, 0.9926], | |
[1.0151, 1.0008, 1.0022, ..., 1.0067, 1.0127, 0.9964]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0029, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956], | |
[1.0252, 1.0138, 0.9933, ..., 1.0354, 1.0358, 0.9926], | |
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0029, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0031, 0.9993, ..., 1.0050, 1.0098, 0.9956], | |
[1.0252, 1.0139, 0.9933, ..., 1.0354, 1.0358, 0.9926], | |
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]], | |
..., | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956], | |
[1.0252, 1.0138, 0.9933, ..., 1.0354, 1.0358, 0.9926], | |
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0029, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956], | |
[1.0252, 1.0139, 0.9933, ..., 1.0354, 1.0358, 0.9926], | |
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956], | |
[1.0252, 1.0139, 0.9933, ..., 1.0355, 1.0358, 0.9926], | |
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9636823534965515, 1.0358315706253052) | |
mean: 1.0009465217590332 | |
std: 0.007956922054290771 | |
target = tensor([[[-0.4988, -0.6918, -0.5971, ..., -1.3568, -0.3844, 0.6915], | |
[ 0.3156, -1.5684, -0.7855, ..., -0.0484, -0.9211, -0.2853], | |
[-0.3282, -0.4495, 0.3974, ..., 0.9546, -0.4394, -0.2031], | |
..., | |
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-1.0574, 0.2834, -0.0056, ..., 1.8479, 0.3408, -0.3568], | |
[ 0.2411, 1.3111, 0.5789, ..., -0.3322, 1.1244, -1.1123], | |
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270], | |
..., | |
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.0383, 0.7346, 1.1566, ..., -1.2615, -1.3133, 0.8579], | |
[-0.7210, 1.5826, 0.4122, ..., 0.3692, 1.2578, 0.0504], | |
[-0.0915, 0.1275, 0.7293, ..., 1.1558, -0.3669, -0.2044], | |
..., | |
[-2.0846, -0.2999, 0.0431, ..., -0.1785, 1.3174, 1.8029], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
..., | |
[[ 0.1019, -0.3633, 1.4970, ..., -1.1070, 1.8470, 0.7881], | |
[-0.7269, -0.3677, -0.0543, ..., 0.4443, 0.2045, 0.0918], | |
[-0.9685, 0.8548, -0.1369, ..., 0.6784, 0.1392, -0.7722], | |
..., | |
[ 0.3352, 1.3790, -1.4903, ..., 0.1442, 0.8230, -0.7261], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 1.3890, 1.7612, -1.5885, ..., 0.6402, 1.1234, 0.6314], | |
[ 0.5255, -0.6071, -0.6983, ..., -0.1975, 0.2420, 0.5352], | |
[-0.3851, -1.0689, 0.9486, ..., -0.4575, 0.1463, -0.2335], | |
..., | |
[ 1.8217, 0.4270, 0.9168, ..., -0.5362, 0.0306, -0.0278], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-0.6884, 0.6784, -1.0951, ..., -1.3010, 0.1661, -0.5777], | |
[-0.6884, 0.6784, -1.0951, ..., -1.3010, 0.1661, -0.5777], | |
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199], | |
..., | |
[ 1.1470, 1.1976, -0.3732, ..., -0.1076, 2.3560, -0.8394], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]], | |
device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-4.062912464141846, 3.9631638526916504) | |
mean: -0.036761801689863205 | |
std: 0.9940363168716431 | |
mamba_block1.inp_proj.weight gradient: 7.884400474722497e-06 | |
mamba_block1.inp_proj.bias gradient: 1.4247218132368289e-05 | |
mamba_block1.out_proj.weight gradient: 7.718842971371487e-05 | |
mamba_block1.out_proj.bias gradient: 0.00190842489246279 | |
mamba_block1.D.weight gradient: 2.390151894360315e-05 | |
mamba_block1.D.bias gradient: 2.5844841729849577e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 3.1179840789263835e-06 | |
mamba_block1.S6.fc1.bias gradient: 4.432289642863907e-06 | |
mamba_block1.S6.fc2.weight gradient: 2.273116297146771e-05 | |
mamba_block1.S6.fc2.bias gradient: 3.4517765016062185e-05 | |
mamba_block1.S6.fc3.weight gradient: 2.1624193323077634e-05 | |
mamba_block1.S6.fc3.bias gradient: 3.2941166864475235e-05 | |
mamba_block1.conv.weight gradient: 4.8649115342414007e-05 | |
mamba_block1.conv.bias gradient: 1.0308258424629457e-05 | |
mamba_block1.conv_linear.weight gradient: 1.780458478606306e-05 | |
mamba_block1.conv_linear.bias gradient: 5.851773312315345e-05 | |
mamba_block1.norm.weight gradient: 6.214230779733043e-06 | |
mamba_block2.inp_proj.weight gradient: 0.008193948306143284 | |
mamba_block2.inp_proj.bias gradient: 0.002890530275180936 | |
mamba_block2.out_proj.weight gradient: 0.008393185213208199 | |
mamba_block2.out_proj.bias gradient: 0.019226713106036186 | |
mamba_block2.D.weight gradient: 0.0050740428268909454 | |
mamba_block2.D.bias gradient: 0.00178993318695575 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.001113535021431744 | |
mamba_block2.S6.fc1.bias gradient: 0.0014299751492217183 | |
mamba_block2.S6.fc2.weight gradient: 0.003115260973572731 | |
mamba_block2.S6.fc2.bias gradient: 0.004271494224667549 | |
mamba_block2.S6.fc3.weight gradient: 0.002891583601012826 | |
mamba_block2.S6.fc3.bias gradient: 0.003908565733581781 | |
mamba_block2.conv.weight gradient: 0.011489537544548512 | |
mamba_block2.conv.bias gradient: 0.0008398296195082366 | |
mamba_block2.conv_linear.weight gradient: 0.009319150820374489 | |
mamba_block2.conv_linear.bias gradient: 0.006470021326094866 | |
mamba_block2.norm.weight gradient: 0.003238578559830785 | |
mamba_block3.inp_proj.weight gradient: 0.09595136344432831 | |
mamba_block3.inp_proj.bias gradient: 0.033816881477832794 | |
mamba_block3.out_proj.weight gradient: 0.044870056211948395 | |
mamba_block3.out_proj.bias gradient: 6.642654426514127e-08 | |
mamba_block3.D.weight gradient: 0.04112648218870163 | |
mamba_block3.D.bias gradient: 0.014526760205626488 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.004705202765762806 | |
mamba_block3.S6.fc1.bias gradient: 0.0051134019158780575 | |
mamba_block3.S6.fc2.weight gradient: 0.01586112007498741 | |
mamba_block3.S6.fc2.bias gradient: 0.013769936747848988 | |
mamba_block3.S6.fc3.weight gradient: 0.016735132783651352 | |
mamba_block3.S6.fc3.bias gradient: 0.014533820562064648 | |
mamba_block3.conv.weight gradient: 0.16413094103336334 | |
mamba_block3.conv.bias gradient: 0.016221188008785248 | |
mamba_block3.conv_linear.weight gradient: 0.07475219666957855 | |
mamba_block3.conv_linear.bias gradient: 0.04689887911081314 | |
mamba_block3.norm.weight gradient: 0.0209010262042284 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956], | |
[1.0253, 1.0139, 0.9934, ..., 1.0355, 1.0359, 0.9926], | |
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9956], | |
[1.0253, 1.0139, 0.9933, ..., 1.0356, 1.0360, 0.9926], | |
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956], | |
[1.0253, 1.0139, 0.9934, ..., 1.0356, 1.0360, 0.9926], | |
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]], | |
..., | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9956], | |
[1.0253, 1.0139, 0.9934, ..., 1.0356, 1.0359, 0.9926], | |
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0052, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9956], | |
[1.0253, 1.0139, 0.9934, ..., 1.0355, 1.0359, 0.9926], | |
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0031, 0.9993, ..., 1.0051, 1.0098, 0.9956], | |
[1.0253, 1.0139, 0.9934, ..., 1.0355, 1.0359, 0.9926], | |
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9636375904083252, 1.0359997749328613) | |
mean: 1.00094735622406 | |
std: 0.007967358455061913 | |
target = tensor([[[-0.4072, -0.5642, 0.8817, ..., 0.7706, -0.4521, -0.3770], | |
[-0.6884, 0.6784, -1.0951, ..., -1.3010, 0.1661, -0.5777], | |
[-0.8324, -0.4507, 0.2398, ..., 0.7770, -1.6973, -1.6883], | |
..., | |
[-0.6713, -0.7524, -0.7726, ..., -0.4873, 0.0152, 1.0856], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.0967, -0.0082, -0.8923, ..., -1.5443, -0.6645, -0.7764], | |
[ 0.1676, 0.6545, -0.4603, ..., 0.9874, 1.3225, 0.1617], | |
[ 1.1250, 1.1728, 0.5452, ..., -1.0478, 0.4368, 1.5019], | |
..., | |
[ 0.3847, -0.3358, -0.6223, ..., -0.8391, 0.2528, 0.4785], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-1.1551, 0.3942, 1.8174, ..., -0.4964, -0.1678, 0.8586], | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
..., | |
[ 0.0378, -0.2052, -0.5975, ..., -0.1757, 0.5491, -1.5124], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
..., | |
[[-0.8834, -1.0507, -0.6479, ..., -0.5122, 0.4084, -0.2526], | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
..., | |
[-0.9680, -1.2858, -1.1414, ..., -0.5307, -0.5660, -0.2579], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-1.7408, 2.1198, 0.1103, ..., 0.3524, 1.0912, -0.2684], | |
[ 1.0274, 0.0333, 0.1309, ..., -0.9998, -0.7694, 0.0330], | |
[ 1.9510, -1.3107, 1.1983, ..., 1.4472, -0.9458, 0.6607], | |
..., | |
[ 0.9335, 0.1877, -2.0042, ..., -1.1503, -1.7980, -0.5640], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-1.9022, -1.0460, -0.8212, ..., -0.6692, -0.9739, -0.3634], | |
[ 0.6606, 0.6995, -1.1284, ..., 0.8394, -0.4208, -0.3543], | |
[-1.4874, 0.0543, 1.0052, ..., 1.6346, -0.3576, 0.3655], | |
..., | |
[-1.3817, 0.1942, -1.2593, ..., 0.0592, 0.4673, 0.3583], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]], | |
device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-4.0426435470581055, 3.9631638526916504) | |
mean: -0.041232164949178696 | |
std: 0.9984654784202576 | |
mamba_block1.inp_proj.weight gradient: 7.257808647409547e-06 | |
mamba_block1.inp_proj.bias gradient: 1.2536821486719418e-05 | |
mamba_block1.out_proj.weight gradient: 7.684003503527492e-05 | |
mamba_block1.out_proj.bias gradient: 0.0019513132283464074 | |
mamba_block1.D.weight gradient: 2.0352830688352697e-05 | |
mamba_block1.D.bias gradient: 2.5774028472369537e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 2.904744633269729e-06 | |
mamba_block1.S6.fc1.bias gradient: 4.214916316414019e-06 | |
mamba_block1.S6.fc2.weight gradient: 1.7458878573961556e-05 | |
mamba_block1.S6.fc2.bias gradient: 2.8126347388024442e-05 | |
mamba_block1.S6.fc3.weight gradient: 1.7111855413531885e-05 | |
mamba_block1.S6.fc3.bias gradient: 2.7683992811944336e-05 | |
mamba_block1.conv.weight gradient: 4.96259490319062e-05 | |
mamba_block1.conv.bias gradient: 7.881315468694083e-06 | |
mamba_block1.conv_linear.weight gradient: 1.8037038898910396e-05 | |
mamba_block1.conv_linear.bias gradient: 5.335741661838256e-05 | |
mamba_block1.norm.weight gradient: 5.022363893658621e-06 | |
mamba_block2.inp_proj.weight gradient: 0.008337417617440224 | |
mamba_block2.inp_proj.bias gradient: 0.0029411145951598883 | |
mamba_block2.out_proj.weight gradient: 0.00849019642919302 | |
mamba_block2.out_proj.bias gradient: 0.018002500757575035 | |
mamba_block2.D.weight gradient: 0.005273017566651106 | |
mamba_block2.D.bias gradient: 0.001860112533904612 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.0011189623037353158 | |
mamba_block2.S6.fc1.bias gradient: 0.001455896534025669 | |
mamba_block2.S6.fc2.weight gradient: 0.0032008709385991096 | |
mamba_block2.S6.fc2.bias gradient: 0.004261344205588102 | |
mamba_block2.S6.fc3.weight gradient: 0.002974285278469324 | |
mamba_block2.S6.fc3.bias gradient: 0.003896206384524703 | |
mamba_block2.conv.weight gradient: 0.011668318882584572 | |
mamba_block2.conv.bias gradient: 0.0008345782989636064 | |
mamba_block2.conv_linear.weight gradient: 0.009310065768659115 | |
mamba_block2.conv_linear.bias gradient: 0.0065779616124928 | |
mamba_block2.norm.weight gradient: 0.003354104468598962 | |
mamba_block3.inp_proj.weight gradient: 0.09024792164564133 | |
mamba_block3.inp_proj.bias gradient: 0.03179505467414856 | |
mamba_block3.out_proj.weight gradient: 0.043840229511260986 | |
mamba_block3.out_proj.bias gradient: 7.639116006430413e-08 | |
mamba_block3.D.weight gradient: 0.041635192930698395 | |
mamba_block3.D.bias gradient: 0.014703379012644291 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.004581212531775236 | |
mamba_block3.S6.fc1.bias gradient: 0.004985094536095858 | |
mamba_block3.S6.fc2.weight gradient: 0.016831787303090096 | |
mamba_block3.S6.fc2.bias gradient: 0.015875842422246933 | |
mamba_block3.S6.fc3.weight gradient: 0.017684169113636017 | |
mamba_block3.S6.fc3.bias gradient: 0.016509560868144035 | |
mamba_block3.conv.weight gradient: 0.16424451768398285 | |
mamba_block3.conv.bias gradient: 0.016224239021539688 | |
mamba_block3.conv_linear.weight gradient: 0.07360465824604034 | |
mamba_block3.conv_linear.bias gradient: 0.04532546550035477 | |
mamba_block3.norm.weight gradient: 0.02086501196026802 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9956], | |
[1.0254, 1.0139, 0.9934, ..., 1.0357, 1.0362, 0.9926], | |
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9956], | |
[1.0254, 1.0139, 0.9934, ..., 1.0357, 1.0361, 0.9926], | |
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956], | |
[1.0254, 1.0138, 0.9934, ..., 1.0357, 1.0362, 0.9926], | |
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]], | |
..., | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956], | |
[1.0254, 1.0139, 0.9934, ..., 1.0357, 1.0361, 0.9926], | |
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0050, 1.0098, 0.9956], | |
[1.0254, 1.0139, 0.9934, ..., 1.0357, 1.0361, 0.9926], | |
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]], | |
[[1.0053, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9951], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9956], | |
[1.0254, 1.0139, 0.9934, ..., 1.0357, 1.0361, 0.9926], | |
[1.0151, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9964]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9636093378067017, 1.036202073097229) | |
mean: 1.0009483098983765 | |
std: 0.007977578788995743 | |
target = tensor([[[-1.1048, -0.1699, 0.3172, ..., 0.6925, 0.7191, 1.4389], | |
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505], | |
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270], | |
..., | |
[-1.3817, 0.1942, -1.2593, ..., 0.0592, 0.4673, 0.3583], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-0.0580, -0.5538, -0.9253, ..., 0.6467, 1.4621, -0.3138], | |
[-1.3817, 0.1942, -1.2593, ..., 0.0592, 0.4673, 0.3583], | |
[-0.7571, -1.6050, -0.0124, ..., -0.9880, -0.9499, 0.8033], | |
..., | |
[ 1.7151, 1.0070, 0.6890, ..., -2.3825, -0.5136, 0.5498], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.6873, 0.3348, 0.9381, ..., -0.6962, 0.4933, -0.2609], | |
[-1.0015, -0.2747, -0.3922, ..., 0.6551, 0.1457, 1.8843], | |
[-0.2476, -0.3332, -0.2145, ..., -0.8714, 0.4179, -0.0367], | |
..., | |
[ 0.6761, -0.8634, -0.7832, ..., 0.2734, -0.3206, -0.2002], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
..., | |
[[ 0.0383, 0.7346, 1.1566, ..., -1.2615, -1.3133, 0.8579], | |
[ 1.0334, -0.8128, -0.3230, ..., 0.2623, 2.1819, 0.4262], | |
[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417], | |
..., | |
[-0.8324, -0.4507, 0.2398, ..., 0.7770, -1.6973, -1.6883], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.5255, -0.6071, -0.6983, ..., -0.1975, 0.2420, 0.5352], | |
[ 0.2451, -0.2897, 0.8116, ..., -0.1863, 0.8451, -1.3344], | |
[ 1.2114, 0.3140, -1.4007, ..., 1.1863, 0.0090, 0.1881], | |
..., | |
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.7789, 0.5709, 1.3162, ..., 0.9926, 0.0632, -1.1557], | |
[ 1.2286, -1.6981, 0.1104, ..., -1.3688, -0.4356, 0.2258], | |
[-0.8187, 1.7120, 1.2602, ..., 0.9032, -1.0293, -0.3666], | |
..., | |
[-1.2926, -0.1770, 0.0189, ..., 0.3937, -0.4130, 1.5345], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]], | |
device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-3.895033836364746, 4.436890602111816) | |
mean: -0.04351688176393509 | |
std: 0.9954912662506104 | |
mamba_block1.inp_proj.weight gradient: 9.100136594497599e-06 | |
mamba_block1.inp_proj.bias gradient: 1.3978528841107618e-05 | |
mamba_block1.out_proj.weight gradient: 8.495710790157318e-05 | |
mamba_block1.out_proj.bias gradient: 0.0020170381758362055 | |
mamba_block1.D.weight gradient: 2.5585110051906668e-05 | |
mamba_block1.D.bias gradient: 2.804783252940979e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 3.4274560221092543e-06 | |
mamba_block1.S6.fc1.bias gradient: 4.966110736859264e-06 | |
mamba_block1.S6.fc2.weight gradient: 2.0056253561051562e-05 | |
mamba_block1.S6.fc2.bias gradient: 3.251164889661595e-05 | |
mamba_block1.S6.fc3.weight gradient: 1.9456374502624385e-05 | |
mamba_block1.S6.fc3.bias gradient: 3.162693974445574e-05 | |
mamba_block1.conv.weight gradient: 5.1763261581072584e-05 | |
mamba_block1.conv.bias gradient: 8.400884325965308e-06 | |
mamba_block1.conv_linear.weight gradient: 1.8744514818536118e-05 | |
mamba_block1.conv_linear.bias gradient: 6.390304770320654e-05 | |
mamba_block1.norm.weight gradient: 7.944043318275362e-06 | |
mamba_block2.inp_proj.weight gradient: 0.008640581741929054 | |
mamba_block2.inp_proj.bias gradient: 0.00304804602637887 | |
mamba_block2.out_proj.weight gradient: 0.008689355105161667 | |
mamba_block2.out_proj.bias gradient: 0.020420508459210396 | |
mamba_block2.D.weight gradient: 0.005131551064550877 | |
mamba_block2.D.bias gradient: 0.0018101985333487391 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.0011199985165148973 | |
mamba_block2.S6.fc1.bias gradient: 0.0014004047261551023 | |
mamba_block2.S6.fc2.weight gradient: 0.0031374646350741386 | |
mamba_block2.S6.fc2.bias gradient: 0.004075405187904835 | |
mamba_block2.S6.fc3.weight gradient: 0.002909448929131031 | |
mamba_block2.S6.fc3.bias gradient: 0.0037189736030995846 | |
mamba_block2.conv.weight gradient: 0.012041107751429081 | |
mamba_block2.conv.bias gradient: 0.0009172795107588172 | |
mamba_block2.conv_linear.weight gradient: 0.009620931930840015 | |
mamba_block2.conv_linear.bias gradient: 0.00613889517262578 | |
mamba_block2.norm.weight gradient: 0.0034102771896868944 | |
mamba_block3.inp_proj.weight gradient: 0.0934479609131813 | |
mamba_block3.inp_proj.bias gradient: 0.03292842581868172 | |
mamba_block3.out_proj.weight gradient: 0.044350285083055496 | |
mamba_block3.out_proj.bias gradient: 1.3198520321111573e-07 | |
mamba_block3.D.weight gradient: 0.04143233224749565 | |
mamba_block3.D.bias gradient: 0.01463618129491806 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.004638255573809147 | |
mamba_block3.S6.fc1.bias gradient: 0.004756370093673468 | |
mamba_block3.S6.fc2.weight gradient: 0.01848497800529003 | |
mamba_block3.S6.fc2.bias gradient: 0.017832543700933456 | |
mamba_block3.S6.fc3.weight gradient: 0.01909303106367588 | |
mamba_block3.S6.fc3.bias gradient: 0.018696237355470657 | |
mamba_block3.conv.weight gradient: 0.1660355180501938 | |
mamba_block3.conv.bias gradient: 0.01661163568496704 | |
mamba_block3.conv_linear.weight gradient: 0.07300027459859848 | |
mamba_block3.conv_linear.bias gradient: 0.04526618868112564 | |
mamba_block3.norm.weight gradient: 0.023428840562701225 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955], | |
[1.0255, 1.0139, 0.9934, ..., 1.0358, 1.0363, 0.9926], | |
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9964]], | |
[[1.0054, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955], | |
[1.0255, 1.0139, 0.9934, ..., 1.0358, 1.0363, 0.9926], | |
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9964]], | |
[[1.0054, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955], | |
[1.0255, 1.0138, 0.9935, ..., 1.0358, 1.0363, 0.9926], | |
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9964]], | |
..., | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0031, 0.9993, ..., 1.0051, 1.0098, 0.9955], | |
[1.0255, 1.0139, 0.9934, ..., 1.0358, 1.0363, 0.9926], | |
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9963]], | |
[[1.0054, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955], | |
[1.0255, 1.0139, 0.9934, ..., 1.0358, 1.0363, 0.9926], | |
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0127, 0.9963]], | |
[[1.0054, 1.0070, 0.9965, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9956], | |
[1.0255, 1.0138, 0.9934, ..., 1.0358, 1.0363, 0.9926], | |
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9964]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9635644555091858, 1.0363487005233765) | |
mean: 1.0009490251541138 | |
std: 0.007988139055669308 | |
target = tensor([[[ 0.1019, -0.3633, 1.4970, ..., -1.1070, 1.8470, 0.7881], | |
[ 0.4941, 0.0856, 0.3690, ..., 1.3915, 2.5161, 0.3218], | |
[ 1.7297, 0.7489, 0.7269, ..., -0.3836, 0.6932, 0.7111], | |
..., | |
[-1.6287, -0.9663, -0.0027, ..., 1.2085, -0.9073, 0.0483], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-0.5515, 0.7887, -1.5313, ..., 1.2504, 0.1500, -1.8818], | |
[ 0.8201, 1.6476, 0.4960, ..., -0.2201, 0.8857, 0.0669], | |
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199], | |
..., | |
[-0.8324, -0.4507, 0.2398, ..., 0.7770, -1.6973, -1.6883], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.3433, -0.6291, 0.8468, ..., -1.2711, -1.2323, -0.1769], | |
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505], | |
[-1.6287, -0.9663, -0.0027, ..., 1.2085, -0.9073, 0.0483], | |
..., | |
[ 0.5068, -0.9357, 0.5597, ..., -0.2642, 1.4782, 2.3104], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
..., | |
[[ 0.0519, -1.6652, 0.4465, ..., 0.6680, 0.7076, -0.9326], | |
[ 1.4714, 0.0124, -0.2384, ..., -0.2375, 1.1155, 0.2285], | |
[ 0.1385, -0.2701, 0.1457, ..., 0.4512, -1.1078, -0.2718], | |
..., | |
[ 0.0368, 0.7456, -1.4815, ..., 0.9900, 1.4748, -0.2182], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.8941, -2.4687, 0.5529, ..., 0.0181, 0.2483, 0.0552], | |
[-1.1859, 0.6207, 1.1728, ..., 0.3623, 0.6124, 0.1387], | |
[-0.5582, -1.8473, -0.1892, ..., -1.3669, 1.0029, -0.2609], | |
..., | |
[-0.2494, 1.6643, 1.0550, ..., 0.0960, -0.4710, 0.4718], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.4882, 0.1704, -1.3309, ..., -1.1073, 0.2595, -0.9865], | |
[ 1.0589, -0.6273, -1.0979, ..., -1.3877, -0.8624, 0.4007], | |
[ 2.1325, 0.0683, 1.1581, ..., 1.2571, 0.4663, -0.7213], | |
..., | |
[ 0.5853, 0.2439, -0.6474, ..., 0.7711, -0.5776, 0.1493], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]], | |
device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-4.203547954559326, 3.9631638526916504) | |
mean: -0.03762376680970192 | |
std: 0.998776376247406 | |
mamba_block1.inp_proj.weight gradient: 8.837168934405781e-06 | |
mamba_block1.inp_proj.bias gradient: 1.3773204045719467e-05 | |
mamba_block1.out_proj.weight gradient: 8.079419058049098e-05 | |
mamba_block1.out_proj.bias gradient: 0.001937799621373415 | |
mamba_block1.D.weight gradient: 2.2958665795158595e-05 | |
mamba_block1.D.bias gradient: 2.63482506852597e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 3.0723977033630945e-06 | |
mamba_block1.S6.fc1.bias gradient: 4.301163698983146e-06 | |
mamba_block1.S6.fc2.weight gradient: 1.544903170724865e-05 | |
mamba_block1.S6.fc2.bias gradient: 2.3952608898980543e-05 | |
mamba_block1.S6.fc3.weight gradient: 1.4936680599930696e-05 | |
mamba_block1.S6.fc3.bias gradient: 2.318737460882403e-05 | |
mamba_block1.conv.weight gradient: 5.149357093614526e-05 | |
mamba_block1.conv.bias gradient: 8.877683285390958e-06 | |
mamba_block1.conv_linear.weight gradient: 1.7505970390629955e-05 | |
mamba_block1.conv_linear.bias gradient: 5.392322418629192e-05 | |
mamba_block1.norm.weight gradient: 4.662304036173737e-06 | |
mamba_block2.inp_proj.weight gradient: 0.008749652653932571 | |
mamba_block2.inp_proj.bias gradient: 0.003086492419242859 | |
mamba_block2.out_proj.weight gradient: 0.008550377562642097 | |
mamba_block2.out_proj.bias gradient: 0.0215672105550766 | |
mamba_block2.D.weight gradient: 0.0049406080506742 | |
mamba_block2.D.bias gradient: 0.001742832944728434 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.0011518856044858694 | |
mamba_block2.S6.fc1.bias gradient: 0.0014069563476368785 | |
mamba_block2.S6.fc2.weight gradient: 0.0031139287166297436 | |
mamba_block2.S6.fc2.bias gradient: 0.003985110204666853 | |
mamba_block2.S6.fc3.weight gradient: 0.002889266237616539 | |
mamba_block2.S6.fc3.bias gradient: 0.003641492687165737 | |
mamba_block2.conv.weight gradient: 0.011927427724003792 | |
mamba_block2.conv.bias gradient: 0.0008801336516626179 | |
mamba_block2.conv_linear.weight gradient: 0.009689133614301682 | |
mamba_block2.conv_linear.bias gradient: 0.005984927993267775 | |
mamba_block2.norm.weight gradient: 0.0032904285471886396 | |
mamba_block3.inp_proj.weight gradient: 0.09736455976963043 | |
mamba_block3.inp_proj.bias gradient: 0.034317947924137115 | |
mamba_block3.out_proj.weight gradient: 0.04285677894949913 | |
mamba_block3.out_proj.bias gradient: 4.898083716398105e-08 | |
mamba_block3.D.weight gradient: 0.0413573682308197 | |
mamba_block3.D.bias gradient: 0.014606538228690624 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.004640404600650072 | |
mamba_block3.S6.fc1.bias gradient: 0.005210277624428272 | |
mamba_block3.S6.fc2.weight gradient: 0.017779408022761345 | |
mamba_block3.S6.fc2.bias gradient: 0.019206494092941284 | |
mamba_block3.S6.fc3.weight gradient: 0.018545418977737427 | |
mamba_block3.S6.fc3.bias gradient: 0.019921310245990753 | |
mamba_block3.conv.weight gradient: 0.16301649808883667 | |
mamba_block3.conv.bias gradient: 0.01633545197546482 | |
mamba_block3.conv_linear.weight gradient: 0.07117603719234467 | |
mamba_block3.conv_linear.bias gradient: 0.04827188700437546 | |
mamba_block3.norm.weight gradient: 0.023814095184206963 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955], | |
[1.0256, 1.0139, 0.9935, ..., 1.0359, 1.0365, 0.9926], | |
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9963]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955], | |
[1.0256, 1.0140, 0.9934, ..., 1.0359, 1.0364, 0.9926], | |
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9963]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955], | |
[1.0256, 1.0139, 0.9935, ..., 1.0359, 1.0364, 0.9926], | |
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9963]], | |
..., | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955], | |
[1.0256, 1.0139, 0.9935, ..., 1.0359, 1.0365, 0.9926], | |
[1.0152, 1.0008, 1.0023, ..., 1.0066, 1.0128, 0.9963]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955], | |
[1.0256, 1.0139, 0.9935, ..., 1.0359, 1.0365, 0.9926], | |
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9963]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0031, 0.9993, ..., 1.0051, 1.0098, 0.9955], | |
[1.0256, 1.0139, 0.9935, ..., 1.0359, 1.0364, 0.9926], | |
[1.0152, 1.0008, 1.0023, ..., 1.0066, 1.0128, 0.9963]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9635269641876221, 1.0365265607833862) | |
mean: 1.0009498596191406 | |
std: 0.007998868823051453 | |
target = tensor([[[-1.2912e-01, 2.7761e-01, 6.5007e-01, ..., 5.3681e-01, | |
1.4878e+00, -6.7947e-01], | |
[ 6.0857e-01, 1.4627e+00, 2.5454e-01, ..., 1.6538e+00, | |
-1.0191e+00, 8.4912e-01], | |
[ 4.9737e-01, 7.9059e-01, -1.0028e+00, ..., 1.2410e+00, | |
-1.4670e+00, -1.0270e+00], | |
..., | |
[-7.9787e-01, 6.7782e-01, 1.4350e-01, ..., 3.0334e-01, | |
6.2231e-01, -9.4687e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[ 1.5454e+00, 9.2277e-01, 3.0021e-01, ..., -8.3794e-01, | |
7.2716e-01, -1.8499e+00], | |
[ 1.3767e+00, 1.0128e-01, 3.5444e-01, ..., 7.6632e-02, | |
-1.6822e+00, -1.4354e+00], | |
[ 1.0189e-01, -3.6331e-01, 1.4970e+00, ..., -1.1070e+00, | |
1.8470e+00, 7.8808e-01], | |
..., | |
[-9.7426e-01, -8.7755e-01, 1.9398e-01, ..., -3.6643e-01, | |
1.9255e-03, 2.0825e+00], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[ 4.1016e-01, -3.8302e-03, -1.2295e-01, ..., -2.5800e-01, | |
1.4403e+00, -2.4625e-01], | |
[ 6.1667e-02, -2.4054e-02, 1.9664e+00, ..., -1.5273e+00, | |
2.2778e-01, -4.6371e-01], | |
[ 2.0079e+00, 2.8988e-01, 1.4619e-01, ..., 1.1677e-02, | |
1.5477e-01, 9.1439e-01], | |
..., | |
[ 5.0676e-01, -9.3567e-01, 5.5966e-01, ..., -2.6422e-01, | |
1.4782e+00, 2.3104e+00], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
..., | |
[[-9.1950e-02, -6.5555e-01, 1.6096e+00, ..., -1.5558e+00, | |
6.1454e-01, 1.4055e+00], | |
[-1.3817e+00, 1.9420e-01, -1.2593e+00, ..., 5.9164e-02, | |
4.6726e-01, 3.5826e-01], | |
[-1.4729e+00, -2.1761e+00, 9.2319e-01, ..., 4.0713e-01, | |
-1.6731e+00, 1.1180e+00], | |
..., | |
[-6.8272e-01, -2.8986e-01, -8.1461e-02, ..., 3.9673e-01, | |
2.5136e-01, 6.9517e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[ 1.1350e+00, -1.6268e+00, 1.5461e+00, ..., -1.8916e+00, | |
-1.9114e+00, 1.1798e+00], | |
[-5.3807e-01, 1.8558e+00, -1.3125e+00, ..., -2.1141e+00, | |
-5.7919e-01, -8.3718e-02], | |
[ 5.2551e-01, -6.0715e-01, -6.9834e-01, ..., -1.9748e-01, | |
2.4198e-01, 5.3519e-01], | |
..., | |
[-1.5331e-01, 5.9299e-01, 1.3224e-01, ..., -1.6126e+00, | |
-6.2350e-01, -1.3132e+00], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[ 3.8803e-01, -1.8462e-01, 3.3831e-01, ..., -6.0497e-01, | |
1.2007e-01, -1.0940e+00], | |
[-2.9351e-01, 9.6275e-01, -1.5990e+00, ..., -1.4156e+00, | |
-1.0206e+00, -1.0802e+00], | |
[ 6.3896e-02, 3.2472e-04, -1.2828e+00, ..., -1.0525e+00, | |
-1.3741e+00, -1.5745e+00], | |
..., | |
[ 1.2286e+00, -1.6981e+00, 1.1041e-01, ..., -1.3688e+00, | |
-4.3559e-01, 2.2583e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]]], device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-4.062912464141846, 3.9631638526916504) | |
mean: -0.03548859432339668 | |
std: 0.9919624924659729 | |
mamba_block1.inp_proj.weight gradient: 7.336026101256721e-06 | |
mamba_block1.inp_proj.bias gradient: 1.0591419595584739e-05 | |
mamba_block1.out_proj.weight gradient: 7.542931416537613e-05 | |
mamba_block1.out_proj.bias gradient: 0.00191353855188936 | |
mamba_block1.D.weight gradient: 1.9951721696997993e-05 | |
mamba_block1.D.bias gradient: 2.5469640604569577e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 2.71867838819162e-06 | |
mamba_block1.S6.fc1.bias gradient: 3.859155185637064e-06 | |
mamba_block1.S6.fc2.weight gradient: 1.6416681319242343e-05 | |
mamba_block1.S6.fc2.bias gradient: 2.4024193407967687e-05 | |
mamba_block1.S6.fc3.weight gradient: 1.577912007633131e-05 | |
mamba_block1.S6.fc3.bias gradient: 2.3290005628950894e-05 | |
mamba_block1.conv.weight gradient: 4.8139336286112666e-05 | |
mamba_block1.conv.bias gradient: 7.358869879681151e-06 | |
mamba_block1.conv_linear.weight gradient: 1.663083821767941e-05 | |
mamba_block1.conv_linear.bias gradient: 5.175057958695106e-05 | |
mamba_block1.norm.weight gradient: 5.831683665746823e-06 | |
mamba_block2.inp_proj.weight gradient: 0.008543197065591812 | |
mamba_block2.inp_proj.bias gradient: 0.003013620851561427 | |
mamba_block2.out_proj.weight gradient: 0.008505250327289104 | |
mamba_block2.out_proj.bias gradient: 0.020350750535726547 | |
mamba_block2.D.weight gradient: 0.004964698106050491 | |
mamba_block2.D.bias gradient: 0.0017513002967461944 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.001128288102336228 | |
mamba_block2.S6.fc1.bias gradient: 0.0014180107973515987 | |
mamba_block2.S6.fc2.weight gradient: 0.0031073689460754395 | |
mamba_block2.S6.fc2.bias gradient: 0.004030835349112749 | |
mamba_block2.S6.fc3.weight gradient: 0.0028815357945859432 | |
mamba_block2.S6.fc3.bias gradient: 0.0036812785547226667 | |
mamba_block2.conv.weight gradient: 0.011601514182984829 | |
mamba_block2.conv.bias gradient: 0.0008414683397859335 | |
mamba_block2.conv_linear.weight gradient: 0.00955396518111229 | |
mamba_block2.conv_linear.bias gradient: 0.006414241157472134 | |
mamba_block2.norm.weight gradient: 0.0032158917747437954 | |
mamba_block3.inp_proj.weight gradient: 0.08929342031478882 | |
mamba_block3.inp_proj.bias gradient: 0.03145138546824455 | |
mamba_block3.out_proj.weight gradient: 0.04311970993876457 | |
mamba_block3.out_proj.bias gradient: 4.8321194157097125e-08 | |
mamba_block3.D.weight gradient: 0.04716562479734421 | |
mamba_block3.D.bias gradient: 0.01665896736085415 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.004517058376222849 | |
mamba_block3.S6.fc1.bias gradient: 0.004935705102980137 | |
mamba_block3.S6.fc2.weight gradient: 0.019176244735717773 | |
mamba_block3.S6.fc2.bias gradient: 0.018631301820278168 | |
mamba_block3.S6.fc3.weight gradient: 0.019713152199983597 | |
mamba_block3.S6.fc3.bias gradient: 0.019606487825512886 | |
mamba_block3.conv.weight gradient: 0.1666284203529358 | |
mamba_block3.conv.bias gradient: 0.016464872285723686 | |
mamba_block3.conv_linear.weight gradient: 0.07397904992103577 | |
mamba_block3.conv_linear.bias gradient: 0.04956180974841118 | |
mamba_block3.norm.weight gradient: 0.0232034083455801 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955], | |
[1.0257, 1.0139, 0.9935, ..., 1.0360, 1.0366, 0.9927], | |
[1.0152, 1.0008, 1.0023, ..., 1.0066, 1.0128, 0.9963]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955], | |
[1.0257, 1.0139, 0.9935, ..., 1.0361, 1.0366, 0.9927], | |
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9963]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955], | |
[1.0257, 1.0139, 0.9935, ..., 1.0361, 1.0367, 0.9927], | |
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9963]], | |
..., | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9956], | |
[1.0257, 1.0139, 0.9935, ..., 1.0361, 1.0366, 0.9927], | |
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9963]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955], | |
[1.0257, 1.0139, 0.9935, ..., 1.0361, 1.0366, 0.9927], | |
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9963]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955], | |
[1.0257, 1.0139, 0.9935, ..., 1.0361, 1.0367, 0.9927], | |
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9963]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9634785056114197, 1.0367209911346436) | |
mean: 1.000950574874878 | |
std: 0.008009559474885464 | |
target = tensor([[[-1.6873, 1.0246, -0.1206, ..., 1.1435, -0.6533, -0.3542], | |
[ 1.1250, 1.1728, 0.5452, ..., -1.0478, 0.4368, 1.5019], | |
[-0.0696, -0.1048, 0.0242, ..., 0.4035, -0.3938, -1.4395], | |
..., | |
[-0.1848, -1.6540, -1.1411, ..., -0.4016, -0.8012, 2.9020], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.1664, 0.1732, -0.6635, ..., -0.2567, -0.0699, -0.5926], | |
[-0.8324, -0.4507, 0.2398, ..., 0.7770, -1.6973, -1.6883], | |
[-1.9692, 0.7451, 1.1040, ..., 1.9508, 0.4760, 0.4680], | |
..., | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 1.1350, -1.6268, 1.5461, ..., -1.8916, -1.9114, 1.1798], | |
[-0.3877, 0.3024, -1.1404, ..., 2.0661, -0.6191, -0.9355], | |
[-0.0229, 1.0938, -0.0923, ..., -0.2372, -0.9342, -0.0119], | |
..., | |
[ 0.1846, -0.2507, 0.2757, ..., -0.6224, -2.2640, 0.0596], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
..., | |
[[ 1.2858, 1.6970, -0.7869, ..., -0.6545, -0.6808, -1.0335], | |
[-0.0877, -0.1239, -1.0846, ..., -0.3003, -0.8652, 2.1638], | |
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270], | |
..., | |
[ 0.6566, -2.1274, -0.8276, ..., 0.4022, -0.3414, 0.9617], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.5679, 1.1681, -0.2152, ..., 0.4324, -0.3278, 0.3071], | |
[-2.0951, -1.0155, 0.1259, ..., 0.1168, 0.8176, -1.6148], | |
[-0.6088, 0.1729, 0.2571, ..., 1.8095, 0.2413, 1.2040], | |
..., | |
[ 1.6005, 0.5397, -0.2096, ..., 0.3994, 1.0095, 0.0273], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-0.3876, 0.3656, 0.3301, ..., 0.5791, -0.6306, 0.7447], | |
[-1.5113, 1.0783, 0.5030, ..., -0.0460, -0.4079, 0.7232], | |
[ 2.0079, 0.2899, 0.1462, ..., 0.0117, 0.1548, 0.9144], | |
..., | |
[-0.4137, -0.8622, -1.2035, ..., -1.6588, 1.6848, 1.1585], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]], | |
device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-4.0426435470581055, 3.9631638526916504) | |
mean: -0.045278389006853104 | |
std: 1.0026400089263916 | |
mamba_block1.inp_proj.weight gradient: 7.799080776749179e-06 | |
mamba_block1.inp_proj.bias gradient: 1.4282920346886385e-05 | |
mamba_block1.out_proj.weight gradient: 7.392667612293735e-05 | |
mamba_block1.out_proj.bias gradient: 0.001936393091455102 | |
mamba_block1.D.weight gradient: 2.073771793220658e-05 | |
mamba_block1.D.bias gradient: 2.6325606086174957e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 2.736860324148438e-06 | |
mamba_block1.S6.fc1.bias gradient: 4.0264562812808435e-06 | |
mamba_block1.S6.fc2.weight gradient: 1.7693038898869418e-05 | |
mamba_block1.S6.fc2.bias gradient: 2.8277174351387657e-05 | |
mamba_block1.S6.fc3.weight gradient: 1.71951214724686e-05 | |
mamba_block1.S6.fc3.bias gradient: 2.7498386771185324e-05 | |
mamba_block1.conv.weight gradient: 5.321425123838708e-05 | |
mamba_block1.conv.bias gradient: 1.0530870895308908e-05 | |
mamba_block1.conv_linear.weight gradient: 1.8596581867313944e-05 | |
mamba_block1.conv_linear.bias gradient: 5.6252407375723124e-05 | |
mamba_block1.norm.weight gradient: 3.8478110582218505e-06 | |
mamba_block2.inp_proj.weight gradient: 0.008295131847262383 | |
mamba_block2.inp_proj.bias gradient: 0.0029260965529829264 | |
mamba_block2.out_proj.weight gradient: 0.008724554441869259 | |
mamba_block2.out_proj.bias gradient: 0.020028317347168922 | |
mamba_block2.D.weight gradient: 0.00506933219730854 | |
mamba_block2.D.bias gradient: 0.001788186258636415 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.0011186042102053761 | |
mamba_block2.S6.fc1.bias gradient: 0.0013990309089422226 | |
mamba_block2.S6.fc2.weight gradient: 0.0030988024082034826 | |
mamba_block2.S6.fc2.bias gradient: 0.003995432984083891 | |
mamba_block2.S6.fc3.weight gradient: 0.0028782477602362633 | |
mamba_block2.S6.fc3.bias gradient: 0.003649807535111904 | |
mamba_block2.conv.weight gradient: 0.011995462700724602 | |
mamba_block2.conv.bias gradient: 0.0008671989198774099 | |
mamba_block2.conv_linear.weight gradient: 0.009438985958695412 | |
mamba_block2.conv_linear.bias gradient: 0.006164009682834148 | |
mamba_block2.norm.weight gradient: 0.003288773586973548 | |
mamba_block3.inp_proj.weight gradient: 0.09197264909744263 | |
mamba_block3.inp_proj.bias gradient: 0.032406117767095566 | |
mamba_block3.out_proj.weight gradient: 0.04633298143744469 | |
mamba_block3.out_proj.bias gradient: 7.247580668945375e-08 | |
mamba_block3.D.weight gradient: 0.04627760872244835 | |
mamba_block3.D.bias gradient: 0.01634623482823372 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.004637931939214468 | |
mamba_block3.S6.fc1.bias gradient: 0.004829443525522947 | |
mamba_block3.S6.fc2.weight gradient: 0.018521852791309357 | |
mamba_block3.S6.fc2.bias gradient: 0.017666202038526535 | |
mamba_block3.S6.fc3.weight gradient: 0.019322099164128304 | |
mamba_block3.S6.fc3.bias gradient: 0.01848628558218479 | |
mamba_block3.conv.weight gradient: 0.1661185473203659 | |
mamba_block3.conv.bias gradient: 0.016495849937200546 | |
mamba_block3.conv_linear.weight gradient: 0.07894043624401093 | |
mamba_block3.conv_linear.bias gradient: 0.055515777319669724 | |
mamba_block3.norm.weight gradient: 0.02291417308151722 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0258, 1.0139, 0.9935, ..., 1.0362, 1.0368, 0.9926], | |
[1.0152, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9963]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955], | |
[1.0258, 1.0139, 0.9935, ..., 1.0362, 1.0368, 0.9927], | |
[1.0152, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9963]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955], | |
[1.0258, 1.0139, 0.9935, ..., 1.0362, 1.0368, 0.9927], | |
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9963]], | |
..., | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955], | |
[1.0258, 1.0139, 0.9935, ..., 1.0362, 1.0368, 0.9927], | |
[1.0152, 1.0008, 1.0023, ..., 1.0066, 1.0128, 0.9963]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0258, 1.0139, 0.9935, ..., 1.0362, 1.0368, 0.9927], | |
[1.0152, 1.0008, 1.0023, ..., 1.0066, 1.0128, 0.9963]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0258, 1.0139, 0.9936, ..., 1.0362, 1.0368, 0.9927], | |
[1.0152, 1.0008, 1.0023, ..., 1.0067, 1.0128, 0.9963]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9634485244750977, 1.0368773937225342) | |
mean: 1.0009514093399048 | |
std: 0.008020208217203617 | |
target = tensor([[[-7.2930e-02, -1.5876e+00, -1.7188e-01, ..., 1.2421e+00, | |
7.0656e-01, 4.5039e-01], | |
[-3.9398e-01, -3.2687e-01, -2.5453e+00, ..., 2.5721e-01, | |
-2.0661e-01, 1.2287e-01], | |
[-2.1060e-01, 8.8564e-01, -2.6071e-01, ..., -3.1098e-01, | |
-9.9241e-01, 2.3331e-01], | |
..., | |
[-1.7780e+00, -3.5378e-01, 5.7481e-01, ..., -4.0297e-01, | |
-3.2478e-01, -6.1987e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[-4.0608e-02, 8.3209e-01, -5.6969e-01, ..., 1.5675e-01, | |
-2.0986e+00, 1.0620e+00], | |
[-6.5043e-01, -1.3069e+00, -1.5379e-01, ..., 7.5101e-01, | |
-1.4239e+00, -2.3928e-01], | |
[-1.0840e+00, 6.8404e-01, 1.2655e-03, ..., -2.0117e-01, | |
7.5984e-01, -4.4518e-01], | |
..., | |
[ 5.0676e-01, -9.3567e-01, 5.5966e-01, ..., -2.6422e-01, | |
1.4782e+00, 2.3104e+00], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02, | |
-6.2687e-01, -4.9886e-01], | |
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02, | |
-6.2687e-01, -4.9886e-01], | |
[-1.3817e+00, 1.9420e-01, -1.2593e+00, ..., 5.9164e-02, | |
4.6726e-01, 3.5826e-01], | |
..., | |
[-1.3817e+00, 1.9420e-01, -1.2593e+00, ..., 5.9164e-02, | |
4.6726e-01, 3.5826e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
..., | |
[[ 6.8187e-01, 6.5033e-01, -7.8437e-02, ..., 1.2032e+00, | |
4.3723e-01, 5.0549e-02], | |
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02, | |
-6.2687e-01, -4.9886e-01], | |
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02, | |
-6.2687e-01, -4.9886e-01], | |
..., | |
[ 1.6843e+00, -9.7226e-01, -1.0947e+00, ..., 2.9206e-01, | |
8.7524e-01, -1.8599e+00], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[-8.9061e-01, -1.8222e-01, -8.0707e-01, ..., 9.1797e-01, | |
5.8479e-01, 8.0782e-01], | |
[-1.6287e+00, -9.6626e-01, -2.6581e-03, ..., 1.2085e+00, | |
-9.0730e-01, 4.8345e-02], | |
[-5.9222e-01, 1.0597e+00, -8.2489e-01, ..., 3.3105e-01, | |
5.1061e-01, -1.4595e-01], | |
..., | |
[ 2.5486e+00, 1.0393e-01, 1.4986e+00, ..., -1.2783e+00, | |
-7.6003e-01, -8.4845e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[-9.1471e-02, 1.2755e-01, 7.2934e-01, ..., 1.1558e+00, | |
-3.6694e-01, -2.0441e-01], | |
[ 4.9737e-01, 7.9059e-01, -1.0028e+00, ..., 1.2410e+00, | |
-1.4670e+00, -1.0270e+00], | |
[-1.7780e+00, -3.5378e-01, 5.7481e-01, ..., -4.0297e-01, | |
-3.2478e-01, -6.1987e-01], | |
..., | |
[ 5.2773e-01, -1.5574e+00, 5.2337e-02, ..., -4.6493e-01, | |
-4.2155e-02, 2.6858e+00], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]]], device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-4.335270404815674, 4.436890602111816) | |
mean: -0.04498155415058136 | |
std: 1.001204490661621 | |
mamba_block1.inp_proj.weight gradient: 7.761200322420336e-06 | |
mamba_block1.inp_proj.bias gradient: 1.3567365385824814e-05 | |
mamba_block1.out_proj.weight gradient: 8.328901458298787e-05 | |
mamba_block1.out_proj.bias gradient: 0.001959998393431306 | |
mamba_block1.D.weight gradient: 2.5373388780280948e-05 | |
mamba_block1.D.bias gradient: 2.6117280867765658e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 2.946576842077775e-06 | |
mamba_block1.S6.fc1.bias gradient: 4.046666617796291e-06 | |
mamba_block1.S6.fc2.weight gradient: 1.7208309145644307e-05 | |
mamba_block1.S6.fc2.bias gradient: 2.567360206739977e-05 | |
mamba_block1.S6.fc3.weight gradient: 1.6707210306776688e-05 | |
mamba_block1.S6.fc3.bias gradient: 2.5138038836303167e-05 | |
mamba_block1.conv.weight gradient: 5.01254471600987e-05 | |
mamba_block1.conv.bias gradient: 8.776757567829918e-06 | |
mamba_block1.conv_linear.weight gradient: 1.8366961739957333e-05 | |
mamba_block1.conv_linear.bias gradient: 5.426110510597937e-05 | |
mamba_block1.norm.weight gradient: 5.980205969535746e-06 | |
mamba_block2.inp_proj.weight gradient: 0.008871041238307953 | |
mamba_block2.inp_proj.bias gradient: 0.003129237564280629 | |
mamba_block2.out_proj.weight gradient: 0.008679620921611786 | |
mamba_block2.out_proj.bias gradient: 0.021518703550100327 | |
mamba_block2.D.weight gradient: 0.005031412001699209 | |
mamba_block2.D.bias gradient: 0.0017748060636222363 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.00114446971565485 | |
mamba_block2.S6.fc1.bias gradient: 0.0014250640524551272 | |
mamba_block2.S6.fc2.weight gradient: 0.0031288242898881435 | |
mamba_block2.S6.fc2.bias gradient: 0.0038880067877471447 | |
mamba_block2.S6.fc3.weight gradient: 0.00290935137309134 | |
mamba_block2.S6.fc3.bias gradient: 0.0035403394140303135 | |
mamba_block2.conv.weight gradient: 0.011712766252458096 | |
mamba_block2.conv.bias gradient: 0.0008701256010681391 | |
mamba_block2.conv_linear.weight gradient: 0.009876110590994358 | |
mamba_block2.conv_linear.bias gradient: 0.0061364308930933475 | |
mamba_block2.norm.weight gradient: 0.0033124934416264296 | |
mamba_block3.inp_proj.weight gradient: 0.09452962875366211 | |
mamba_block3.inp_proj.bias gradient: 0.033306971192359924 | |
mamba_block3.out_proj.weight gradient: 0.04634234309196472 | |
mamba_block3.out_proj.bias gradient: 6.166452948264123e-08 | |
mamba_block3.D.weight gradient: 0.05371489003300667 | |
mamba_block3.D.bias gradient: 0.018968436866998672 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.0050024171359837055 | |
mamba_block3.S6.fc1.bias gradient: 0.005564017221331596 | |
mamba_block3.S6.fc2.weight gradient: 0.01854153349995613 | |
mamba_block3.S6.fc2.bias gradient: 0.01928705908358097 | |
mamba_block3.S6.fc3.weight gradient: 0.019480036571621895 | |
mamba_block3.S6.fc3.bias gradient: 0.020197639241814613 | |
mamba_block3.conv.weight gradient: 0.16645151376724243 | |
mamba_block3.conv.bias gradient: 0.01688520796597004 | |
mamba_block3.conv_linear.weight gradient: 0.07389499992132187 | |
mamba_block3.conv_linear.bias gradient: 0.04942871257662773 | |
mamba_block3.norm.weight gradient: 0.025095799937844276 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0259, 1.0139, 0.9936, ..., 1.0363, 1.0370, 0.9927], | |
[1.0152, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9963]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0124, 1.0032, 0.9993, ..., 1.0051, 1.0098, 0.9955], | |
[1.0259, 1.0139, 0.9936, ..., 1.0363, 1.0370, 0.9927], | |
[1.0152, 1.0008, 1.0023, ..., 1.0066, 1.0128, 0.9963]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0259, 1.0139, 0.9936, ..., 1.0363, 1.0370, 0.9927], | |
[1.0152, 1.0008, 1.0023, ..., 1.0066, 1.0128, 0.9963]], | |
..., | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0259, 1.0139, 0.9936, ..., 1.0363, 1.0370, 0.9927], | |
[1.0152, 1.0008, 1.0023, ..., 1.0066, 1.0128, 0.9963]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0259, 1.0139, 0.9936, ..., 1.0363, 1.0370, 0.9927], | |
[1.0152, 1.0008, 1.0023, ..., 1.0066, 1.0128, 0.9963]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0049, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0259, 1.0138, 0.9936, ..., 1.0363, 1.0370, 0.9927], | |
[1.0152, 1.0008, 1.0024, ..., 1.0067, 1.0128, 0.9963]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9633966684341431, 1.0370630025863647) | |
mean: 1.0009517669677734 | |
std: 0.008031157776713371 | |
target = tensor([[[ 0.0443, 0.8938, -0.9553, ..., -1.4890, -0.4225, -0.0190], | |
[ 0.6640, -0.1311, -0.5633, ..., -0.9016, 0.9427, -2.0075], | |
[ 0.0461, -0.2495, 0.1279, ..., 0.0830, 0.3020, -0.2346], | |
..., | |
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 1.1027, 0.7786, 1.2513, ..., -0.2450, 0.3287, -1.6867], | |
[ 0.4430, 1.5878, 1.1566, ..., 0.9944, -1.4975, -0.3028], | |
[ 0.1019, -0.3633, 1.4970, ..., -1.1070, 1.8470, 0.7881], | |
..., | |
[-1.6522, -1.8433, -0.7053, ..., -0.3111, 1.1393, -0.4392], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-1.4081, -1.4504, -0.2482, ..., 1.1472, -0.2720, 1.4572], | |
[-0.3877, 0.3024, -1.1404, ..., 2.0661, -0.6191, -0.9355], | |
[ 1.7248, -1.1906, -0.3196, ..., 0.4742, 2.1333, 0.7659], | |
..., | |
[-0.6860, -0.8323, 1.5771, ..., 0.9945, -0.1853, 0.8436], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
..., | |
[[-0.3277, -0.7792, 0.8676, ..., 0.2718, -1.9822, 0.4135], | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
..., | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.9245, 1.2509, 0.0326, ..., 0.1467, 1.1677, -0.9162], | |
[-1.3817, 0.1942, -1.2593, ..., 0.0592, 0.4673, 0.3583], | |
[ 1.0915, -1.1170, -0.3247, ..., -0.9190, 1.1993, 0.6716], | |
..., | |
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417], | |
[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417], | |
[ 0.0290, -0.8582, -0.3564, ..., 1.0203, 0.7273, 0.1357], | |
..., | |
[-0.5816, 0.0080, 1.8231, ..., -1.1851, 0.4162, -0.0030], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]], | |
device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-4.0426435470581055, 3.9631638526916504) | |
mean: -0.04280959814786911 | |
std: 0.9980056881904602 | |
mamba_block1.inp_proj.weight gradient: 9.450679499423131e-06 | |
mamba_block1.inp_proj.bias gradient: 1.5937728676362894e-05 | |
mamba_block1.out_proj.weight gradient: 7.879294571466744e-05 | |
mamba_block1.out_proj.bias gradient: 0.001937422202900052 | |
mamba_block1.D.weight gradient: 2.1698680939152837e-05 | |
mamba_block1.D.bias gradient: 2.7504667741595767e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 3.608612132666167e-06 | |
mamba_block1.S6.fc1.bias gradient: 5.172527835384244e-06 | |
mamba_block1.S6.fc2.weight gradient: 2.2741836801287718e-05 | |
mamba_block1.S6.fc2.bias gradient: 3.544551145751029e-05 | |
mamba_block1.S6.fc3.weight gradient: 2.2189064111444168e-05 | |
mamba_block1.S6.fc3.bias gradient: 3.465290501480922e-05 | |
mamba_block1.conv.weight gradient: 5.530563066713512e-05 | |
mamba_block1.conv.bias gradient: 8.797837836027611e-06 | |
mamba_block1.conv_linear.weight gradient: 2.0959831090294756e-05 | |
mamba_block1.conv_linear.bias gradient: 6.4122024923563e-05 | |
mamba_block1.norm.weight gradient: 6.733871941833058e-06 | |
mamba_block2.inp_proj.weight gradient: 0.008612259291112423 | |
mamba_block2.inp_proj.bias gradient: 0.0030379192903637886 | |
mamba_block2.out_proj.weight gradient: 0.008520158939063549 | |
mamba_block2.out_proj.bias gradient: 0.021844662725925446 | |
mamba_block2.D.weight gradient: 0.0048363665118813515 | |
mamba_block2.D.bias gradient: 0.0017059911042451859 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.0011253735283389688 | |
mamba_block2.S6.fc1.bias gradient: 0.0013490341370925307 | |
mamba_block2.S6.fc2.weight gradient: 0.0030624449718743563 | |
mamba_block2.S6.fc2.bias gradient: 0.0040188198909163475 | |
mamba_block2.S6.fc3.weight gradient: 0.0028361633885651827 | |
mamba_block2.S6.fc3.bias gradient: 0.003677345346659422 | |
mamba_block2.conv.weight gradient: 0.01234238687902689 | |
mamba_block2.conv.bias gradient: 0.000914952193852514 | |
mamba_block2.conv_linear.weight gradient: 0.009588218294084072 | |
mamba_block2.conv_linear.bias gradient: 0.0058879912830889225 | |
mamba_block2.norm.weight gradient: 0.0032451574224978685 | |
mamba_block3.inp_proj.weight gradient: 0.09518136829137802 | |
mamba_block3.inp_proj.bias gradient: 0.03354042023420334 | |
mamba_block3.out_proj.weight gradient: 0.04517391696572304 | |
mamba_block3.out_proj.bias gradient: 1.0457387844553523e-07 | |
mamba_block3.D.weight gradient: 0.047621291130781174 | |
mamba_block3.D.bias gradient: 0.016819199547171593 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.0044674440287053585 | |
mamba_block3.S6.fc1.bias gradient: 0.0047797891311347485 | |
mamba_block3.S6.fc2.weight gradient: 0.01866152696311474 | |
mamba_block3.S6.fc2.bias gradient: 0.019014324992895126 | |
mamba_block3.S6.fc3.weight gradient: 0.01939181052148342 | |
mamba_block3.S6.fc3.bias gradient: 0.020019695162773132 | |
mamba_block3.conv.weight gradient: 0.16745711863040924 | |
mamba_block3.conv.bias gradient: 0.01679954305291176 | |
mamba_block3.conv_linear.weight gradient: 0.07562368363142014 | |
mamba_block3.conv_linear.bias gradient: 0.054438941180706024 | |
mamba_block3.norm.weight gradient: 0.024364715442061424 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0260, 1.0139, 0.9936, ..., 1.0364, 1.0372, 0.9927], | |
[1.0153, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9963]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0260, 1.0139, 0.9936, ..., 1.0365, 1.0372, 0.9927], | |
[1.0153, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9963]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0260, 1.0139, 0.9936, ..., 1.0364, 1.0372, 0.9927], | |
[1.0153, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9962]], | |
..., | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0260, 1.0139, 0.9936, ..., 1.0365, 1.0372, 0.9927], | |
[1.0153, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9963]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0260, 1.0138, 0.9936, ..., 1.0365, 1.0372, 0.9927], | |
[1.0153, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9963]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0034, 0.9950], | |
[1.0053, 1.0065, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0065, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0260, 1.0139, 0.9936, ..., 1.0365, 1.0372, 0.9927], | |
[1.0153, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9963]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.963355541229248, 1.0372397899627686) | |
mean: 1.0009523630142212 | |
std: 0.008042216300964355 | |
target = tensor([[[ 6.4588e-01, -9.7711e-01, 1.4713e-01, ..., -1.7452e+00, | |
2.4286e-02, 1.4304e-02], | |
[ 1.2286e+00, -1.6981e+00, 1.1041e-01, ..., -1.3688e+00, | |
-4.3559e-01, 2.2583e-01], | |
[ 5.7543e-01, -1.2123e+00, 1.6030e+00, ..., -8.3098e-01, | |
-2.7845e+00, -2.1074e-02], | |
..., | |
[-1.7780e+00, -3.5378e-01, 5.7481e-01, ..., -4.0297e-01, | |
-3.2478e-01, -6.1987e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[-1.7117e+00, 7.8182e-01, 1.7233e-01, ..., 1.2136e+00, | |
3.9576e-01, -4.3173e-01], | |
[ 4.4899e-01, -2.4530e+00, -1.6500e-01, ..., -1.3791e-01, | |
-4.4953e-02, 3.7787e-01], | |
[-9.2340e-01, 1.3582e+00, 1.4513e+00, ..., -3.5925e-01, | |
-1.2063e+00, -1.5141e-01], | |
..., | |
[ 1.2389e+00, 1.4293e-01, -9.1111e-01, ..., -5.7567e-02, | |
9.1207e-01, 5.4976e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[-1.6287e+00, -9.6626e-01, -2.6581e-03, ..., 1.2085e+00, | |
-9.0730e-01, 4.8345e-02], | |
[ 1.0462e+00, 9.3372e-01, 9.1681e-01, ..., 4.8498e-01, | |
5.8902e-01, -9.3716e-02], | |
[ 5.0676e-01, -9.3567e-01, 5.5966e-01, ..., -2.6422e-01, | |
1.4782e+00, 2.3104e+00], | |
..., | |
[ 1.0110e+00, 5.1812e-01, -9.6063e-01, ..., 6.9258e-01, | |
2.0773e-01, -8.0699e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
..., | |
[[-4.0720e-01, -5.6419e-01, 8.8167e-01, ..., 7.7059e-01, | |
-4.5208e-01, -3.7696e-01], | |
[-2.8159e-02, 8.7647e-01, 3.6170e-01, ..., -8.5379e-01, | |
5.3774e-01, -1.6134e+00], | |
[ 1.0189e-01, -3.6331e-01, 1.4970e+00, ..., -1.1070e+00, | |
1.8470e+00, 7.8808e-01], | |
..., | |
[-3.2770e-01, -7.7916e-01, 8.6764e-01, ..., 2.7178e-01, | |
-1.9822e+00, 4.1346e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[-1.5331e-01, 5.9299e-01, 1.3224e-01, ..., -1.6126e+00, | |
-6.2350e-01, -1.3132e+00], | |
[-7.7284e-01, 7.0358e-01, -2.8840e-02, ..., -1.3084e+00, | |
-3.0288e-01, -8.2964e-01], | |
[ 1.9111e-01, -1.2776e+00, -1.7906e-01, ..., -1.6976e-01, | |
-3.4747e-01, 1.2224e+00], | |
..., | |
[ 1.0189e-01, -3.6331e-01, 1.4970e+00, ..., -1.1070e+00, | |
1.8470e+00, 7.8808e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[ 1.0189e-01, -3.6331e-01, 1.4970e+00, ..., -1.1070e+00, | |
1.8470e+00, 7.8808e-01], | |
[-1.7242e+00, 4.8190e-01, 1.8281e+00, ..., 4.0987e-01, | |
-2.7694e-01, -1.8146e-01], | |
[-7.2930e-02, -1.5876e+00, -1.7188e-01, ..., 1.2421e+00, | |
7.0656e-01, 4.5039e-01], | |
..., | |
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02, | |
-6.2687e-01, -4.9886e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]]], device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-4.062912464141846, 4.436890602111816) | |
mean: -0.045185498893260956 | |
std: 1.000666618347168 | |
mamba_block1.inp_proj.weight gradient: 7.601716788485646e-06 | |
mamba_block1.inp_proj.bias gradient: 1.3704358934774064e-05 | |
mamba_block1.out_proj.weight gradient: 8.511666965205222e-05 | |
mamba_block1.out_proj.bias gradient: 0.0020013893954455853 | |
mamba_block1.D.weight gradient: 2.2376692868419923e-05 | |
mamba_block1.D.bias gradient: 2.6539159080130048e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 3.140715534755145e-06 | |
mamba_block1.S6.fc1.bias gradient: 4.521216851571808e-06 | |
mamba_block1.S6.fc2.weight gradient: 1.8029242710326798e-05 | |
mamba_block1.S6.fc2.bias gradient: 2.9236758564366028e-05 | |
mamba_block1.S6.fc3.weight gradient: 1.7625890905037522e-05 | |
mamba_block1.S6.fc3.bias gradient: 2.8570417271112092e-05 | |
mamba_block1.conv.weight gradient: 5.065084042144008e-05 | |
mamba_block1.conv.bias gradient: 8.323479960381519e-06 | |
mamba_block1.conv_linear.weight gradient: 1.7244665286852978e-05 | |
mamba_block1.conv_linear.bias gradient: 5.684297138941474e-05 | |
mamba_block1.norm.weight gradient: 5.099446298117982e-06 | |
mamba_block2.inp_proj.weight gradient: 0.00868897046893835 | |
mamba_block2.inp_proj.bias gradient: 0.0030649492982774973 | |
mamba_block2.out_proj.weight gradient: 0.008559616282582283 | |
mamba_block2.out_proj.bias gradient: 0.020069818943738937 | |
mamba_block2.D.weight gradient: 0.005039901006966829 | |
mamba_block2.D.bias gradient: 0.0017777711618691683 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.0011198658030480146 | |
mamba_block2.S6.fc1.bias gradient: 0.0014025933342054486 | |
mamba_block2.S6.fc2.weight gradient: 0.003106378484517336 | |
mamba_block2.S6.fc2.bias gradient: 0.004144839011132717 | |
mamba_block2.S6.fc3.weight gradient: 0.002879718318581581 | |
mamba_block2.S6.fc3.bias gradient: 0.003787883324548602 | |
mamba_block2.conv.weight gradient: 0.011892883107066154 | |
mamba_block2.conv.bias gradient: 0.0009201067732647061 | |
mamba_block2.conv_linear.weight gradient: 0.0095598753541708 | |
mamba_block2.conv_linear.bias gradient: 0.006426714826375246 | |
mamba_block2.norm.weight gradient: 0.003324520541355014 | |
mamba_block3.inp_proj.weight gradient: 0.09168072789907455 | |
mamba_block3.inp_proj.bias gradient: 0.03229833021759987 | |
mamba_block3.out_proj.weight gradient: 0.04626988619565964 | |
mamba_block3.out_proj.bias gradient: 7.785879319044398e-08 | |
mamba_block3.D.weight gradient: 0.04269688203930855 | |
mamba_block3.D.bias gradient: 0.015081201680004597 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.004488199949264526 | |
mamba_block3.S6.fc1.bias gradient: 0.004849399905651808 | |
mamba_block3.S6.fc2.weight gradient: 0.015625562518835068 | |
mamba_block3.S6.fc2.bias gradient: 0.013996592722833157 | |
mamba_block3.S6.fc3.weight gradient: 0.016473641619086266 | |
mamba_block3.S6.fc3.bias gradient: 0.014759422279894352 | |
mamba_block3.conv.weight gradient: 0.16740712523460388 | |
mamba_block3.conv.bias gradient: 0.01699674315750599 | |
mamba_block3.conv_linear.weight gradient: 0.07384508848190308 | |
mamba_block3.conv_linear.bias gradient: 0.049471717327833176 | |
mamba_block3.norm.weight gradient: 0.02215823344886303 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0262, 1.0139, 0.9936, ..., 1.0366, 1.0374, 0.9927], | |
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0128, 0.9962]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0261, 1.0139, 0.9937, ..., 1.0366, 1.0374, 0.9927], | |
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0128, 0.9962]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0261, 1.0139, 0.9936, ..., 1.0366, 1.0373, 0.9927], | |
[1.0153, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9962]], | |
..., | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0261, 1.0139, 0.9936, ..., 1.0366, 1.0374, 0.9927], | |
[1.0153, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9962]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0261, 1.0139, 0.9936, ..., 1.0366, 1.0373, 0.9927], | |
[1.0153, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9962]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0261, 1.0139, 0.9936, ..., 1.0366, 1.0373, 0.9927], | |
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0128, 0.9962]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.963310956954956, 1.037413239479065) | |
mean: 1.0009527206420898 | |
std: 0.008052636869251728 | |
target = tensor([[[-0.0729, -1.5876, -0.1719, ..., 1.2421, 0.7066, 0.4504], | |
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270], | |
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199], | |
..., | |
[ 1.0504, 1.2879, 1.0797, ..., -0.6484, -1.7104, 0.3437], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.9335, 0.1877, -2.0042, ..., -1.1503, -1.7980, -0.5640], | |
[ 0.9420, -0.5706, -0.0884, ..., -0.8730, 1.7595, -1.3956], | |
[-1.6287, -0.9663, -0.0027, ..., 1.2085, -0.9073, 0.0483], | |
..., | |
[ 0.7351, -1.4100, 0.0052, ..., 0.4583, 1.4485, -0.0438], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-1.0589, 0.7468, 0.2544, ..., 0.0787, 0.4220, 0.6745], | |
[-1.0592, 0.1051, -0.3675, ..., -0.1518, -0.8563, -1.1461], | |
[ 1.0504, 1.2879, 1.0797, ..., -0.6484, -1.7104, 0.3437], | |
..., | |
[-0.1842, -1.0219, 1.0257, ..., -0.1165, -0.2031, -0.5445], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
..., | |
[[-0.1081, 0.8790, 0.6781, ..., -0.5866, -0.1624, 0.4462], | |
[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417], | |
[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417], | |
..., | |
[-1.1794, -1.5216, 0.0929, ..., 0.0363, 1.0894, 2.1755], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.3699, 0.2790, 0.5842, ..., 1.5279, 0.0930, -1.2546], | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
..., | |
[-0.0282, 0.8765, 0.3617, ..., -0.8538, 0.5377, -1.6134], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.1884, 0.9860, -0.6278, ..., 0.4238, -1.8099, -0.7295], | |
[-0.2976, -2.1353, -0.2941, ..., -0.8635, 0.5327, 0.7513], | |
[-0.7571, -1.6050, -0.0124, ..., -0.9880, -0.9499, 0.8033], | |
..., | |
[-0.7124, -0.1175, 0.4958, ..., 0.8150, 0.6772, 1.4007], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]], | |
device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-4.062912464141846, 4.436890602111816) | |
mean: -0.04132794588804245 | |
std: 0.9972973465919495 | |
mamba_block1.inp_proj.weight gradient: 8.05390391178662e-06 | |
mamba_block1.inp_proj.bias gradient: 1.3198721717344597e-05 | |
mamba_block1.out_proj.weight gradient: 8.135655662044883e-05 | |
mamba_block1.out_proj.bias gradient: 0.0019478988833725452 | |
mamba_block1.D.weight gradient: 2.6859646823140793e-05 | |
mamba_block1.D.bias gradient: 2.88037299469579e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 3.693293592732516e-06 | |
mamba_block1.S6.fc1.bias gradient: 5.4035172070143744e-06 | |
mamba_block1.S6.fc2.weight gradient: 2.4369408492930233e-05 | |
mamba_block1.S6.fc2.bias gradient: 3.828832268482074e-05 | |
mamba_block1.S6.fc3.weight gradient: 2.356448931095656e-05 | |
mamba_block1.S6.fc3.bias gradient: 3.708050280692987e-05 | |
mamba_block1.conv.weight gradient: 5.2881990995956585e-05 | |
mamba_block1.conv.bias gradient: 8.851877282722853e-06 | |
mamba_block1.conv_linear.weight gradient: 1.8447050024406053e-05 | |
mamba_block1.conv_linear.bias gradient: 6.762157136108726e-05 | |
mamba_block1.norm.weight gradient: 6.673816642432939e-06 | |
mamba_block2.inp_proj.weight gradient: 0.008638471364974976 | |
mamba_block2.inp_proj.bias gradient: 0.0030471261125057936 | |
mamba_block2.out_proj.weight gradient: 0.00904333870857954 | |
mamba_block2.out_proj.bias gradient: 0.022471558302640915 | |
mamba_block2.D.weight gradient: 0.005017615854740143 | |
mamba_block2.D.bias gradient: 0.0017699210438877344 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.0011227426584810019 | |
mamba_block2.S6.fc1.bias gradient: 0.0013828028459101915 | |
mamba_block2.S6.fc2.weight gradient: 0.0031488672830164433 | |
mamba_block2.S6.fc2.bias gradient: 0.004178797360509634 | |
mamba_block2.S6.fc3.weight gradient: 0.002923778258264065 | |
mamba_block2.S6.fc3.bias gradient: 0.0038226121105253696 | |
mamba_block2.conv.weight gradient: 0.012409028597176075 | |
mamba_block2.conv.bias gradient: 0.0009441959555260837 | |
mamba_block2.conv_linear.weight gradient: 0.00955372303724289 | |
mamba_block2.conv_linear.bias gradient: 0.006237642839550972 | |
mamba_block2.norm.weight gradient: 0.003364289877936244 | |
mamba_block3.inp_proj.weight gradient: 0.1019379273056984 | |
mamba_block3.inp_proj.bias gradient: 0.03592952340841293 | |
mamba_block3.out_proj.weight gradient: 0.048958100378513336 | |
mamba_block3.out_proj.bias gradient: 4.5786606506226235e-08 | |
mamba_block3.D.weight gradient: 0.05141298845410347 | |
mamba_block3.D.bias gradient: 0.018162697553634644 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.004572192206978798 | |
mamba_block3.S6.fc1.bias gradient: 0.005084225907921791 | |
mamba_block3.S6.fc2.weight gradient: 0.016733428463339806 | |
mamba_block3.S6.fc2.bias gradient: 0.018023964017629623 | |
mamba_block3.S6.fc3.weight gradient: 0.017616454511880875 | |
mamba_block3.S6.fc3.bias gradient: 0.018776189535856247 | |
mamba_block3.conv.weight gradient: 0.16808335483074188 | |
mamba_block3.conv.bias gradient: 0.017074599862098694 | |
mamba_block3.conv_linear.weight gradient: 0.08281093835830688 | |
mamba_block3.conv_linear.bias gradient: 0.060874998569488525 | |
mamba_block3.norm.weight gradient: 0.024700812995433807 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0262, 1.0139, 0.9937, ..., 1.0367, 1.0375, 0.9927], | |
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0128, 0.9962]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0262, 1.0140, 0.9936, ..., 1.0367, 1.0375, 0.9927], | |
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0262, 1.0139, 0.9937, ..., 1.0367, 1.0376, 0.9928], | |
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0128, 0.9962]], | |
..., | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0262, 1.0139, 0.9937, ..., 1.0367, 1.0375, 0.9927], | |
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0013, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0262, 1.0139, 0.9937, ..., 1.0367, 1.0375, 0.9927], | |
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0128, 0.9962]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0262, 1.0139, 0.9937, ..., 1.0367, 1.0375, 0.9927], | |
[1.0153, 1.0009, 1.0023, ..., 1.0066, 1.0128, 0.9962]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9632824659347534, 1.0375888347625732) | |
mean: 1.000953197479248 | |
std: 0.008062897250056267 | |
target = tensor([[[ 0.2981, -0.4210, -1.5597, ..., -2.1300, -0.6522, 1.3287], | |
[ 0.1265, 0.3060, -1.2604, ..., 1.1243, -0.3889, -0.2856], | |
[-1.0045, -0.8447, 0.0927, ..., -0.8352, -1.6738, 0.2916], | |
..., | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417], | |
[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417], | |
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199], | |
..., | |
[-0.5858, 2.4877, 0.2696, ..., -0.1860, 0.7473, 0.5435], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.8496, -0.4320, 0.5527, ..., -0.2742, 2.0447, -0.5175], | |
[-0.2976, -2.1353, -0.2941, ..., -0.8635, 0.5327, 0.7513], | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
..., | |
[ 1.8766, 0.0779, -2.8239, ..., -0.9667, -0.3084, 1.0684], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
..., | |
[[-0.1949, 1.6447, -0.3521, ..., -1.4622, 0.0887, 0.7248], | |
[ 0.2176, -0.9511, 0.3012, ..., -0.8989, -0.1549, 0.8165], | |
[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417], | |
..., | |
[-0.4027, 1.6658, -0.0122, ..., -0.5772, 2.0100, -0.6190], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-1.6138, 0.4046, 0.1097, ..., 0.0617, -0.5610, 0.3161], | |
[-1.9648, 2.5084, 1.4522, ..., -1.1336, -1.4860, 1.4592], | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
..., | |
[ 1.2397, -0.0055, 0.5789, ..., -1.6370, -0.4645, -1.3456], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-0.1453, 2.1408, 0.6240, ..., 0.6000, 0.3859, 1.1016], | |
[ 0.0083, -1.4979, -0.0571, ..., -0.1176, 1.0814, 0.6415], | |
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505], | |
..., | |
[-0.9827, -0.1144, 2.1513, ..., 0.4412, -1.5209, -0.6943], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]], | |
device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-4.062912464141846, 3.9631638526916504) | |
mean: -0.04515944421291351 | |
std: 0.9984781742095947 | |
mamba_block1.inp_proj.weight gradient: 8.875967978383414e-06 | |
mamba_block1.inp_proj.bias gradient: 1.4417765669350047e-05 | |
mamba_block1.out_proj.weight gradient: 8.700188482180238e-05 | |
mamba_block1.out_proj.bias gradient: 0.0019073453731834888 | |
mamba_block1.D.weight gradient: 2.5550840291543864e-05 | |
mamba_block1.D.bias gradient: 2.8005353669868782e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 3.4992071960004978e-06 | |
mamba_block1.S6.fc1.bias gradient: 4.9961317927227356e-06 | |
mamba_block1.S6.fc2.weight gradient: 2.349440001125913e-05 | |
mamba_block1.S6.fc2.bias gradient: 3.7002042517997324e-05 | |
mamba_block1.S6.fc3.weight gradient: 2.271077573823277e-05 | |
mamba_block1.S6.fc3.bias gradient: 3.5931770980823785e-05 | |
mamba_block1.conv.weight gradient: 5.5383403378073126e-05 | |
mamba_block1.conv.bias gradient: 1.014738609228516e-05 | |
mamba_block1.conv_linear.weight gradient: 1.983829861273989e-05 | |
mamba_block1.conv_linear.bias gradient: 6.198877235874534e-05 | |
mamba_block1.norm.weight gradient: 7.389627626253059e-06 | |
mamba_block2.inp_proj.weight gradient: 0.008638354018330574 | |
mamba_block2.inp_proj.bias gradient: 0.0030470658093690872 | |
mamba_block2.out_proj.weight gradient: 0.008586056530475616 | |
mamba_block2.out_proj.bias gradient: 0.021609429270029068 | |
mamba_block2.D.weight gradient: 0.005029033869504929 | |
mamba_block2.D.bias gradient: 0.0017739012837409973 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.0011420527007430792 | |
mamba_block2.S6.fc1.bias gradient: 0.0013999291695654392 | |
mamba_block2.S6.fc2.weight gradient: 0.003110809251666069 | |
mamba_block2.S6.fc2.bias gradient: 0.004045133478939533 | |
mamba_block2.S6.fc3.weight gradient: 0.0028845700435340405 | |
mamba_block2.S6.fc3.bias gradient: 0.003693824866786599 | |
mamba_block2.conv.weight gradient: 0.012067034840583801 | |
mamba_block2.conv.bias gradient: 0.0008926084847189486 | |
mamba_block2.conv_linear.weight gradient: 0.009714269079267979 | |
mamba_block2.conv_linear.bias gradient: 0.00610923208296299 | |
mamba_block2.norm.weight gradient: 0.003266611136496067 | |
mamba_block3.inp_proj.weight gradient: 0.09592823684215546 | |
mamba_block3.inp_proj.bias gradient: 0.033799685537815094 | |
mamba_block3.out_proj.weight gradient: 0.04734755679965019 | |
mamba_block3.out_proj.bias gradient: 1.0258085580971965e-07 | |
mamba_block3.D.weight gradient: 0.051352519541978836 | |
mamba_block3.D.bias gradient: 0.018139084801077843 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.004822483751922846 | |
mamba_block3.S6.fc1.bias gradient: 0.005282685160636902 | |
mamba_block3.S6.fc2.weight gradient: 0.017948182299733162 | |
mamba_block3.S6.fc2.bias gradient: 0.017465978860855103 | |
mamba_block3.S6.fc3.weight gradient: 0.01869887299835682 | |
mamba_block3.S6.fc3.bias gradient: 0.018225835636258125 | |
mamba_block3.conv.weight gradient: 0.16913020610809326 | |
mamba_block3.conv.bias gradient: 0.01695968210697174 | |
mamba_block3.conv_linear.weight gradient: 0.07513276487588882 | |
mamba_block3.conv_linear.bias gradient: 0.04860156029462814 | |
mamba_block3.norm.weight gradient: 0.024453913792967796 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9967, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0263, 1.0139, 0.9937, ..., 1.0369, 1.0377, 0.9928], | |
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0263, 1.0139, 0.9937, ..., 1.0369, 1.0377, 0.9928], | |
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0263, 1.0139, 0.9937, ..., 1.0369, 1.0377, 0.9928], | |
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]], | |
..., | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0263, 1.0139, 0.9937, ..., 1.0369, 1.0377, 0.9927], | |
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0263, 1.0139, 0.9937, ..., 1.0369, 1.0377, 0.9928], | |
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0263, 1.0139, 0.9937, ..., 1.0368, 1.0377, 0.9928], | |
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9632494449615479, 1.037784218788147) | |
mean: 1.0009536743164062 | |
std: 0.008073143661022186 | |
target = tensor([[[ 1.4736, 0.5671, 0.4209, ..., -0.5206, -0.6041, 1.2744], | |
[ 0.4655, -0.2961, -0.1109, ..., 0.1105, 0.1356, -0.1565], | |
[ 1.1350, -1.6268, 1.5461, ..., -1.8916, -1.9114, 1.1798], | |
..., | |
[ 0.0083, -1.4979, -0.0571, ..., -0.1176, 1.0814, 0.6415], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 2.4458, -0.4609, -0.6933, ..., 0.2152, 0.6763, -1.1608], | |
[ 0.6606, 0.6995, -1.1284, ..., 0.8394, -0.4208, -0.3543], | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
..., | |
[ 0.9285, 0.5199, 1.0481, ..., 2.5334, 0.8552, -1.4535], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 2.1325, 0.0683, 1.1581, ..., 1.2571, 0.4663, -0.7213], | |
[-0.1533, 0.5930, 0.1322, ..., -1.6126, -0.6235, -1.3132], | |
[ 1.4308, 0.4979, -3.0519, ..., -0.6231, 0.7584, -0.9699], | |
..., | |
[ 1.1250, 1.1728, 0.5452, ..., -1.0478, 0.4368, 1.5019], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
..., | |
[[ 0.8267, -0.7581, -1.7703, ..., -1.0994, 0.0531, -0.9797], | |
[ 0.1753, -0.7625, 0.1469, ..., 0.3183, 0.1690, -0.1184], | |
[ 1.1250, 1.1728, 0.5452, ..., -1.0478, 0.4368, 1.5019], | |
..., | |
[-0.3877, 0.3024, -1.1404, ..., 2.0661, -0.6191, -0.9355], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-1.1214, 0.1502, -1.9141, ..., -0.7317, 0.2875, 0.2514], | |
[-0.7564, -0.0087, -1.0106, ..., 0.9032, -1.1468, -0.9196], | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
..., | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.2905, -0.4199, -0.1905, ..., -1.0879, 0.5756, 1.0819], | |
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199], | |
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199], | |
..., | |
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]], | |
device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-4.062912464141846, 3.9631638526916504) | |
mean: -0.03772750869393349 | |
std: 0.9975149035453796 | |
mamba_block1.inp_proj.weight gradient: 7.80315986048663e-06 | |
mamba_block1.inp_proj.bias gradient: 1.3307480912772007e-05 | |
mamba_block1.out_proj.weight gradient: 8.432415779680014e-05 | |
mamba_block1.out_proj.bias gradient: 0.0019840870518237352 | |
mamba_block1.D.weight gradient: 2.4237377147073857e-05 | |
mamba_block1.D.bias gradient: 2.91150627163006e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 3.429665866860887e-06 | |
mamba_block1.S6.fc1.bias gradient: 4.9839036364573985e-06 | |
mamba_block1.S6.fc2.weight gradient: 1.875735142675694e-05 | |
mamba_block1.S6.fc2.bias gradient: 2.9167822503950447e-05 | |
mamba_block1.S6.fc3.weight gradient: 1.8030044884653762e-05 | |
mamba_block1.S6.fc3.bias gradient: 2.8111589926993474e-05 | |
mamba_block1.conv.weight gradient: 5.296375456964597e-05 | |
mamba_block1.conv.bias gradient: 9.648112609283999e-06 | |
mamba_block1.conv_linear.weight gradient: 1.8288330466020852e-05 | |
mamba_block1.conv_linear.bias gradient: 6.101214967202395e-05 | |
mamba_block1.norm.weight gradient: 6.388846941263182e-06 | |
mamba_block2.inp_proj.weight gradient: 0.008736726827919483 | |
mamba_block2.inp_proj.bias gradient: 0.003081721253693104 | |
mamba_block2.out_proj.weight gradient: 0.008913129568099976 | |
mamba_block2.out_proj.bias gradient: 0.021359387785196304 | |
mamba_block2.D.weight gradient: 0.005379557143896818 | |
mamba_block2.D.bias gradient: 0.0018975320272147655 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.0012010665377601981 | |
mamba_block2.S6.fc1.bias gradient: 0.0015100709861144423 | |
mamba_block2.S6.fc2.weight gradient: 0.0034038915764540434 | |
mamba_block2.S6.fc2.bias gradient: 0.0046056401915848255 | |
mamba_block2.S6.fc3.weight gradient: 0.0031575202010571957 | |
mamba_block2.S6.fc3.bias gradient: 0.0042189848609268665 | |
mamba_block2.conv.weight gradient: 0.012530266307294369 | |
mamba_block2.conv.bias gradient: 0.0009156710002571344 | |
mamba_block2.conv_linear.weight gradient: 0.009922015480697155 | |
mamba_block2.conv_linear.bias gradient: 0.0066932328045368195 | |
mamba_block2.norm.weight gradient: 0.003500598017126322 | |
mamba_block3.inp_proj.weight gradient: 0.10002566874027252 | |
mamba_block3.inp_proj.bias gradient: 0.03524862602353096 | |
mamba_block3.out_proj.weight gradient: 0.046485137194395065 | |
mamba_block3.out_proj.bias gradient: 5.290740290320173e-08 | |
mamba_block3.D.weight gradient: 0.0468018613755703 | |
mamba_block3.D.bias gradient: 0.016533298417925835 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.005081711336970329 | |
mamba_block3.S6.fc1.bias gradient: 0.005625939462333918 | |
mamba_block3.S6.fc2.weight gradient: 0.018596788868308067 | |
mamba_block3.S6.fc2.bias gradient: 0.019362082704901695 | |
mamba_block3.S6.fc3.weight gradient: 0.019523371011018753 | |
mamba_block3.S6.fc3.bias gradient: 0.020279493182897568 | |
mamba_block3.conv.weight gradient: 0.16787144541740417 | |
mamba_block3.conv.bias gradient: 0.017073169350624084 | |
mamba_block3.conv_linear.weight gradient: 0.07526125758886337 | |
mamba_block3.conv_linear.bias gradient: 0.04669470712542534 | |
mamba_block3.norm.weight gradient: 0.024019045755267143 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0264, 1.0139, 0.9937, ..., 1.0370, 1.0379, 0.9928], | |
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0264, 1.0139, 0.9937, ..., 1.0370, 1.0379, 0.9928], | |
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0264, 1.0139, 0.9937, ..., 1.0370, 1.0379, 0.9928], | |
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]], | |
..., | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0264, 1.0139, 0.9938, ..., 1.0370, 1.0379, 0.9928], | |
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0264, 1.0139, 0.9937, ..., 1.0370, 1.0379, 0.9928], | |
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]], | |
[[1.0054, 1.0070, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0264, 1.0139, 0.9937, ..., 1.0370, 1.0379, 0.9928], | |
[1.0153, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9632093906402588, 1.0379583835601807) | |
mean: 1.000954270362854 | |
std: 0.008083177730441093 | |
target = tensor([[[-0.3148, -2.4389, -0.7981, ..., 1.4565, 0.6902, -2.8516], | |
[-0.8489, -1.0928, -0.8596, ..., -0.1898, -0.6665, -1.0761], | |
[-1.3817, 0.1942, -1.2593, ..., 0.0592, 0.4673, 0.3583], | |
..., | |
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-1.6416, -0.2457, -2.1230, ..., -0.0060, -1.1015, 1.9065], | |
[ 0.6251, -0.3997, -0.4391, ..., 0.7783, -1.3073, -0.5255], | |
[-0.4625, 0.4049, -0.4079, ..., 0.6291, 1.8454, 0.2429], | |
..., | |
[-0.6884, 0.6784, -1.0951, ..., -1.3010, 0.1661, -0.5777], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-0.1575, -1.4267, 1.2486, ..., -0.2827, 0.5434, -0.3321], | |
[ 0.9335, 0.1877, -2.0042, ..., -1.1503, -1.7980, -0.5640], | |
[-0.3470, 1.6160, -1.1352, ..., 1.0317, 1.0726, 0.2802], | |
..., | |
[-1.0850, -0.6438, 0.2743, ..., -1.1642, -0.8742, -0.2776], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
..., | |
[[ 0.9105, 0.0598, -0.7111, ..., 0.9642, -0.3206, 0.5715], | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
..., | |
[-2.0443, 0.7522, -0.2560, ..., 0.3880, 0.9740, 0.8830], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.3006, -1.3258, 0.1337, ..., 0.5020, -1.0170, -1.4881], | |
[-1.3817, 0.1942, -1.2593, ..., 0.0592, 0.4673, 0.3583], | |
[-1.3397, -0.5167, 0.8265, ..., 0.2521, -0.3263, 0.4133], | |
..., | |
[ 2.2013, -0.1434, -0.3354, ..., 0.7899, -1.2002, 0.6800], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.8830, -1.9559, 0.9161, ..., -0.2516, -1.0361, -0.5355], | |
[ 0.9335, 0.1877, -2.0042, ..., -1.1503, -1.7980, -0.5640], | |
[-1.7728, -0.2004, -0.4214, ..., -0.8403, 0.5624, 1.3858], | |
..., | |
[ 0.1098, 1.6962, 1.1069, ..., 0.4857, 0.8313, 2.2824], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]], | |
device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-4.293727397918701, 4.436890602111816) | |
mean: -0.03683827444911003 | |
std: 0.9964061379432678 | |
mamba_block1.inp_proj.weight gradient: 9.050143489730544e-06 | |
mamba_block1.inp_proj.bias gradient: 1.3294705240696203e-05 | |
mamba_block1.out_proj.weight gradient: 8.745616651140153e-05 | |
mamba_block1.out_proj.bias gradient: 0.002077836310490966 | |
mamba_block1.D.weight gradient: 2.5103221560129896e-05 | |
mamba_block1.D.bias gradient: 2.983676859003026e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 3.631815843618824e-06 | |
mamba_block1.S6.fc1.bias gradient: 5.111124210088747e-06 | |
mamba_block1.S6.fc2.weight gradient: 1.814823190215975e-05 | |
mamba_block1.S6.fc2.bias gradient: 2.881790351239033e-05 | |
mamba_block1.S6.fc3.weight gradient: 1.7504753486718982e-05 | |
mamba_block1.S6.fc3.bias gradient: 2.7879737899638712e-05 | |
mamba_block1.conv.weight gradient: 5.321612843545154e-05 | |
mamba_block1.conv.bias gradient: 8.180650183930993e-06 | |
mamba_block1.conv_linear.weight gradient: 1.8778771845973097e-05 | |
mamba_block1.conv_linear.bias gradient: 6.445188046200201e-05 | |
mamba_block1.norm.weight gradient: 6.032561486790655e-06 | |
mamba_block2.inp_proj.weight gradient: 0.009084882214665413 | |
mamba_block2.inp_proj.bias gradient: 0.003204511944204569 | |
mamba_block2.out_proj.weight gradient: 0.009371430613100529 | |
mamba_block2.out_proj.bias gradient: 0.02175474539399147 | |
mamba_block2.D.weight gradient: 0.005289588589221239 | |
mamba_block2.D.bias gradient: 0.0018657789332792163 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.0011827623238787055 | |
mamba_block2.S6.fc1.bias gradient: 0.00147047801874578 | |
mamba_block2.S6.fc2.weight gradient: 0.00330118159763515 | |
mamba_block2.S6.fc2.bias gradient: 0.004464610014110804 | |
mamba_block2.S6.fc3.weight gradient: 0.0030671891290694475 | |
mamba_block2.S6.fc3.bias gradient: 0.00408173305913806 | |
mamba_block2.conv.weight gradient: 0.012712618336081505 | |
mamba_block2.conv.bias gradient: 0.0009551901021040976 | |
mamba_block2.conv_linear.weight gradient: 0.010156966745853424 | |
mamba_block2.conv_linear.bias gradient: 0.0067415423691272736 | |
mamba_block2.norm.weight gradient: 0.0035081824753433466 | |
mamba_block3.inp_proj.weight gradient: 0.09881533682346344 | |
mamba_block3.inp_proj.bias gradient: 0.03481940180063248 | |
mamba_block3.out_proj.weight gradient: 0.04826899245381355 | |
mamba_block3.out_proj.bias gradient: 8.839717935416047e-08 | |
mamba_block3.D.weight gradient: 0.04611104354262352 | |
mamba_block3.D.bias gradient: 0.0162859745323658 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.004664601292461157 | |
mamba_block3.S6.fc1.bias gradient: 0.004930058494210243 | |
mamba_block3.S6.fc2.weight gradient: 0.017418786883354187 | |
mamba_block3.S6.fc2.bias gradient: 0.01641864702105522 | |
mamba_block3.S6.fc3.weight gradient: 0.018391642719507217 | |
mamba_block3.S6.fc3.bias gradient: 0.01718025654554367 | |
mamba_block3.conv.weight gradient: 0.1693553924560547 | |
mamba_block3.conv.bias gradient: 0.017287174239754677 | |
mamba_block3.conv_linear.weight gradient: 0.08135779947042465 | |
mamba_block3.conv_linear.bias gradient: 0.05725223198533058 | |
mamba_block3.norm.weight gradient: 0.024166064336895943 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0266, 1.0139, 0.9938, ..., 1.0371, 1.0381, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0266, 1.0140, 0.9937, ..., 1.0371, 1.0381, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0266, 1.0140, 0.9938, ..., 1.0371, 1.0380, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]], | |
..., | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0265, 1.0139, 0.9938, ..., 1.0371, 1.0381, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0265, 1.0139, 0.9938, ..., 1.0371, 1.0381, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0265, 1.0139, 0.9938, ..., 1.0372, 1.0381, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9962]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9631598591804504, 1.0381572246551514) | |
mean: 1.0009552240371704 | |
std: 0.008093072101473808 | |
target = tensor([[[-1.6287, -0.9663, -0.0027, ..., 1.2085, -0.9073, 0.0483], | |
[-0.2070, -0.1024, -0.5238, ..., 0.6950, -0.0898, 0.7767], | |
[-1.3817, 0.1942, -1.2593, ..., 0.0592, 0.4673, 0.3583], | |
..., | |
[ 1.0039, 1.2015, 1.3542, ..., 0.8332, 0.5095, 0.6952], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199], | |
[ 1.2464, 0.3930, 0.7058, ..., -0.5867, 0.7455, -0.8427], | |
[ 1.2286, -1.6981, 0.1104, ..., -1.3688, -0.4356, 0.2258], | |
..., | |
[-0.1288, 0.0194, 0.3021, ..., -0.5487, -0.8879, 0.6104], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-0.1324, -1.4891, -1.6448, ..., -0.2209, -0.6961, 0.3296], | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
..., | |
[-1.2696, 1.7538, -0.7169, ..., -0.5047, 0.6277, 1.0967], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
..., | |
[[-1.4777, -0.0545, 0.2544, ..., 0.3233, 0.7367, 0.1191], | |
[-0.7127, 0.5620, -2.2520, ..., 0.6136, -1.2390, -0.4233], | |
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199], | |
..., | |
[-0.6169, -0.1545, -0.1991, ..., 1.8318, 0.8822, -0.0214], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199], | |
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199], | |
[-0.0338, 1.4944, -0.6408, ..., -0.5996, 1.3481, -0.3070], | |
..., | |
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
[-0.1816, -0.4553, -1.1590, ..., -0.4902, -0.3588, 1.5264], | |
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199], | |
..., | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]], | |
device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-4.062912464141846, 4.436890602111816) | |
mean: -0.03783148527145386 | |
std: 0.999710202217102 | |
mamba_block1.inp_proj.weight gradient: 9.031292393046897e-06 | |
mamba_block1.inp_proj.bias gradient: 1.4736494449607562e-05 | |
mamba_block1.out_proj.weight gradient: 8.168919157469645e-05 | |
mamba_block1.out_proj.bias gradient: 0.0020221523009240627 | |
mamba_block1.D.weight gradient: 2.2245831132750027e-05 | |
mamba_block1.D.bias gradient: 2.855477032426279e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 3.2786444990051677e-06 | |
mamba_block1.S6.fc1.bias gradient: 4.6202048906707205e-06 | |
mamba_block1.S6.fc2.weight gradient: 2.992124791489914e-05 | |
mamba_block1.S6.fc2.bias gradient: 4.635922232409939e-05 | |
mamba_block1.S6.fc3.weight gradient: 2.8813263270421885e-05 | |
mamba_block1.S6.fc3.bias gradient: 4.471436113817617e-05 | |
mamba_block1.conv.weight gradient: 5.269802568363957e-05 | |
mamba_block1.conv.bias gradient: 9.273887371819e-06 | |
mamba_block1.conv_linear.weight gradient: 2.0372070139274e-05 | |
mamba_block1.conv_linear.bias gradient: 6.995351577643305e-05 | |
mamba_block1.norm.weight gradient: 6.127910182840424e-06 | |
mamba_block2.inp_proj.weight gradient: 0.008746661245822906 | |
mamba_block2.inp_proj.bias gradient: 0.003085183212533593 | |
mamba_block2.out_proj.weight gradient: 0.008963636122643948 | |
mamba_block2.out_proj.bias gradient: 0.021550053730607033 | |
mamba_block2.D.weight gradient: 0.005240651313215494 | |
mamba_block2.D.bias gradient: 0.0018485068576410413 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.001189987058751285 | |
mamba_block2.S6.fc1.bias gradient: 0.0014902404509484768 | |
mamba_block2.S6.fc2.weight gradient: 0.003241603495553136 | |
mamba_block2.S6.fc2.bias gradient: 0.0039949361234903336 | |
mamba_block2.S6.fc3.weight gradient: 0.0030070182401686907 | |
mamba_block2.S6.fc3.bias gradient: 0.003639199770987034 | |
mamba_block2.conv.weight gradient: 0.012088305316865444 | |
mamba_block2.conv.bias gradient: 0.0008949778275564313 | |
mamba_block2.conv_linear.weight gradient: 0.010031497105956078 | |
mamba_block2.conv_linear.bias gradient: 0.0062803542241454124 | |
mamba_block2.norm.weight gradient: 0.0034341691061854362 | |
mamba_block3.inp_proj.weight gradient: 0.09730257093906403 | |
mamba_block3.inp_proj.bias gradient: 0.03428181633353233 | |
mamba_block3.out_proj.weight gradient: 0.047032639384269714 | |
mamba_block3.out_proj.bias gradient: 6.363734428305179e-08 | |
mamba_block3.D.weight gradient: 0.0502166673541069 | |
mamba_block3.D.bias gradient: 0.017735188826918602 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.004838158842176199 | |
mamba_block3.S6.fc1.bias gradient: 0.0052208430133759975 | |
mamba_block3.S6.fc2.weight gradient: 0.019451884552836418 | |
mamba_block3.S6.fc2.bias gradient: 0.018285999074578285 | |
mamba_block3.S6.fc3.weight gradient: 0.02031179703772068 | |
mamba_block3.S6.fc3.bias gradient: 0.01904214359819889 | |
mamba_block3.conv.weight gradient: 0.16995052993297577 | |
mamba_block3.conv.bias gradient: 0.01724228635430336 | |
mamba_block3.conv_linear.weight gradient: 0.07708907872438431 | |
mamba_block3.conv_linear.bias gradient: 0.05492333322763443 | |
mamba_block3.norm.weight gradient: 0.02481512911617756 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0266, 1.0139, 0.9938, ..., 1.0372, 1.0382, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0267, 1.0139, 0.9938, ..., 1.0372, 1.0382, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0266, 1.0139, 0.9938, ..., 1.0373, 1.0383, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]], | |
..., | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0099, 0.9955], | |
[1.0267, 1.0139, 0.9938, ..., 1.0372, 1.0383, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0013, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0266, 1.0139, 0.9938, ..., 1.0373, 1.0383, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0013, 0.9948], | |
..., | |
[1.0125, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0266, 1.0139, 0.9938, ..., 1.0372, 1.0383, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9631233811378479, 1.0383379459381104) | |
mean: 1.0009559392929077 | |
std: 0.008103198371827602 | |
target = tensor([[[-0.3860, -0.5861, 0.8306, ..., 0.6317, -1.0193, -0.9245], | |
[-1.0569, -0.2186, -1.6387, ..., -1.4346, 0.8052, 0.0375], | |
[-0.2675, 0.8185, -0.1607, ..., -0.9674, -0.5626, 1.1895], | |
..., | |
[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.5150, -0.6664, 2.2355, ..., 1.6450, -1.1488, -1.9170], | |
[-0.0915, 0.1275, 0.7293, ..., 1.1558, -0.3669, -0.2044], | |
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270], | |
..., | |
[-0.1322, -1.6870, -0.4999, ..., 0.5103, 0.1246, -0.4105], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 1.1350, -1.6268, 1.5461, ..., -1.8916, -1.9114, 1.1798], | |
[ 1.9510, -1.3107, 1.1983, ..., 1.4472, -0.9458, 0.6607], | |
[-0.0915, 0.1275, 0.7293, ..., 1.1558, -0.3669, -0.2044], | |
..., | |
[ 0.2615, 0.2679, -0.0044, ..., -0.9232, 0.1088, -2.0702], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
..., | |
[[ 0.9234, -0.9421, 1.0451, ..., -0.4781, 0.0420, 0.2934], | |
[ 0.8607, 0.1229, -0.0035, ..., 0.5764, -2.2611, 1.0230], | |
[-0.0915, 0.1275, 0.7293, ..., 1.1558, -0.3669, -0.2044], | |
..., | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 1.2071, 0.1010, 1.8911, ..., 1.6783, 0.7741, 0.0761], | |
[-0.5735, -0.7252, 0.3188, ..., 0.7167, 0.8917, 1.2515], | |
[ 0.8172, 2.1389, 1.0939, ..., 0.7351, -1.6642, 1.7776], | |
..., | |
[-1.3817, 0.1942, -1.2593, ..., 0.0592, 0.4673, 0.3583], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.5068, -0.9357, 0.5597, ..., -0.2642, 1.4782, 2.3104], | |
[ 0.6347, -0.6012, 0.3480, ..., 1.5082, -0.9452, 2.0558], | |
[ 0.9335, 0.1877, -2.0042, ..., -1.1503, -1.7980, -0.5640], | |
..., | |
[ 0.5068, -0.9357, 0.5597, ..., -0.2642, 1.4782, 2.3104], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]], | |
device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-3.895033836364746, 3.9631638526916504) | |
mean: -0.03932111710309982 | |
std: 0.9978443384170532 | |
mamba_block1.inp_proj.weight gradient: 9.64433729677694e-06 | |
mamba_block1.inp_proj.bias gradient: 1.7220121662830934e-05 | |
mamba_block1.out_proj.weight gradient: 8.797573536867276e-05 | |
mamba_block1.out_proj.bias gradient: 0.0020824200473725796 | |
mamba_block1.D.weight gradient: 2.7073385354015045e-05 | |
mamba_block1.D.bias gradient: 2.999591379193589e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 3.646446202765219e-06 | |
mamba_block1.S6.fc1.bias gradient: 5.3650101108360104e-06 | |
mamba_block1.S6.fc2.weight gradient: 2.464946919644717e-05 | |
mamba_block1.S6.fc2.bias gradient: 3.9769645809428766e-05 | |
mamba_block1.S6.fc3.weight gradient: 2.3926391804707237e-05 | |
mamba_block1.S6.fc3.bias gradient: 3.8730660889996216e-05 | |
mamba_block1.conv.weight gradient: 5.6322707678191364e-05 | |
mamba_block1.conv.bias gradient: 1.1165049727424048e-05 | |
mamba_block1.conv_linear.weight gradient: 2.2284859369392507e-05 | |
mamba_block1.conv_linear.bias gradient: 7.244812877615914e-05 | |
mamba_block1.norm.weight gradient: 6.758013569196919e-06 | |
mamba_block2.inp_proj.weight gradient: 0.008704984560608864 | |
mamba_block2.inp_proj.bias gradient: 0.003070454578846693 | |
mamba_block2.out_proj.weight gradient: 0.00883528869599104 | |
mamba_block2.out_proj.bias gradient: 0.021135663613677025 | |
mamba_block2.D.weight gradient: 0.005147555842995644 | |
mamba_block2.D.bias gradient: 0.0018156523583456874 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.001187782734632492 | |
mamba_block2.S6.fc1.bias gradient: 0.0014818564523011446 | |
mamba_block2.S6.fc2.weight gradient: 0.003225495107471943 | |
mamba_block2.S6.fc2.bias gradient: 0.004239422734826803 | |
mamba_block2.S6.fc3.weight gradient: 0.002991206245496869 | |
mamba_block2.S6.fc3.bias gradient: 0.0038775706198066473 | |
mamba_block2.conv.weight gradient: 0.012525566853582859 | |
mamba_block2.conv.bias gradient: 0.0008841158705763519 | |
mamba_block2.conv_linear.weight gradient: 0.009744850918650627 | |
mamba_block2.conv_linear.bias gradient: 0.006499310024082661 | |
mamba_block2.norm.weight gradient: 0.0034381034784018993 | |
mamba_block3.inp_proj.weight gradient: 0.09548863023519516 | |
mamba_block3.inp_proj.bias gradient: 0.03364047408103943 | |
mamba_block3.out_proj.weight gradient: 0.04658520221710205 | |
mamba_block3.out_proj.bias gradient: 7.18145756195554e-08 | |
mamba_block3.D.weight gradient: 0.044484060257673264 | |
mamba_block3.D.bias gradient: 0.015709370374679565 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.004670510068535805 | |
mamba_block3.S6.fc1.bias gradient: 0.0048215556889772415 | |
mamba_block3.S6.fc2.weight gradient: 0.016718773171305656 | |
mamba_block3.S6.fc2.bias gradient: 0.014307697303593159 | |
mamba_block3.S6.fc3.weight gradient: 0.017529338598251343 | |
mamba_block3.S6.fc3.bias gradient: 0.014922278933227062 | |
mamba_block3.conv.weight gradient: 0.17111019790172577 | |
mamba_block3.conv.bias gradient: 0.017431458458304405 | |
mamba_block3.conv_linear.weight gradient: 0.07662271708250046 | |
mamba_block3.conv_linear.bias gradient: 0.04832540079951286 | |
mamba_block3.norm.weight gradient: 0.023405499756336212 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0268, 1.0139, 0.9938, ..., 1.0374, 1.0385, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0268, 1.0139, 0.9938, ..., 1.0374, 1.0385, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0268, 1.0140, 0.9938, ..., 1.0374, 1.0385, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]], | |
..., | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0267, 1.0139, 0.9939, ..., 1.0374, 1.0385, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0268, 1.0140, 0.9938, ..., 1.0374, 1.0384, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9952], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0268, 1.0139, 0.9939, ..., 1.0374, 1.0385, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9630730152130127, 1.0385202169418335) | |
mean: 1.0009565353393555 | |
std: 0.008112873882055283 | |
target = tensor([[[ 0.9143, 0.0766, -1.7929, ..., -0.3747, -0.3347, 1.5366], | |
[ 0.5731, -1.5050, -1.4184, ..., 1.9338, -1.1914, -0.8985], | |
[-1.1330, 1.9570, 0.5161, ..., 0.3537, 0.1684, 0.5828], | |
..., | |
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-0.6279, 2.0758, -0.0780, ..., 1.0208, -0.5319, -0.6121], | |
[ 0.6215, -0.0394, 0.0192, ..., -0.9825, 0.1665, -1.2019], | |
[-0.5596, -0.3416, 0.7026, ..., 0.5689, -0.4135, -0.8946], | |
..., | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 2.1325, 0.0683, 1.1581, ..., 1.2571, 0.4663, -0.7213], | |
[-0.1533, 0.5930, 0.1322, ..., -1.6126, -0.6235, -1.3132], | |
[ 0.7243, -0.4449, -0.2085, ..., -0.3937, 0.7526, -0.2379], | |
..., | |
[-0.8324, -0.4507, 0.2398, ..., 0.7770, -1.6973, -1.6883], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
..., | |
[[-1.4081, -1.4504, -0.2482, ..., 1.1472, -0.2720, 1.4572], | |
[-0.3877, 0.3024, -1.1404, ..., 2.0661, -0.6191, -0.9355], | |
[ 1.5206, 0.1764, 1.2191, ..., 0.9333, -0.5523, -0.3989], | |
..., | |
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-0.1533, 0.5930, 0.1322, ..., -1.6126, -0.6235, -1.3132], | |
[-0.6782, -1.3975, 0.4929, ..., 0.3559, -0.3167, -1.3198], | |
[ 0.4125, 0.8607, -0.2114, ..., -0.0050, -0.0463, -1.4117], | |
..., | |
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-0.0061, 0.0731, 0.1958, ..., -0.5969, -0.9973, -2.2435], | |
[-0.0282, 0.8765, 0.3617, ..., -0.8538, 0.5377, -1.6134], | |
[ 1.2286, -1.6981, 0.1104, ..., -1.3688, -0.4356, 0.2258], | |
..., | |
[ 1.9510, -1.3107, 1.1983, ..., 1.4472, -0.9458, 0.6607], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]], | |
device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-4.062912464141846, 3.9631638526916504) | |
mean: -0.03662058338522911 | |
std: 0.9983726143836975 | |
mamba_block1.inp_proj.weight gradient: 7.730282050033566e-06 | |
mamba_block1.inp_proj.bias gradient: 1.4676887076348066e-05 | |
mamba_block1.out_proj.weight gradient: 8.539092959836125e-05 | |
mamba_block1.out_proj.bias gradient: 0.002003992907702923 | |
mamba_block1.D.weight gradient: 2.3007403797237203e-05 | |
mamba_block1.D.bias gradient: 2.977289659611415e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 3.472130174486665e-06 | |
mamba_block1.S6.fc1.bias gradient: 4.907219135930063e-06 | |
mamba_block1.S6.fc2.weight gradient: 2.2677684683003463e-05 | |
mamba_block1.S6.fc2.bias gradient: 3.604851008276455e-05 | |
mamba_block1.S6.fc3.weight gradient: 2.215725544374436e-05 | |
mamba_block1.S6.fc3.bias gradient: 3.528478919179179e-05 | |
mamba_block1.conv.weight gradient: 5.533275179914199e-05 | |
mamba_block1.conv.bias gradient: 9.216477337758988e-06 | |
mamba_block1.conv_linear.weight gradient: 1.9760547729674727e-05 | |
mamba_block1.conv_linear.bias gradient: 6.318444502539933e-05 | |
mamba_block1.norm.weight gradient: 5.745764610765036e-06 | |
mamba_block2.inp_proj.weight gradient: 0.009204614907503128 | |
mamba_block2.inp_proj.bias gradient: 0.0032466622069478035 | |
mamba_block2.out_proj.weight gradient: 0.009073683060705662 | |
mamba_block2.out_proj.bias gradient: 0.021660299971699715 | |
mamba_block2.D.weight gradient: 0.00532105565071106 | |
mamba_block2.D.bias gradient: 0.0018768273293972015 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.0012535410933196545 | |
mamba_block2.S6.fc1.bias gradient: 0.0015219313791021705 | |
mamba_block2.S6.fc2.weight gradient: 0.00339127192273736 | |
mamba_block2.S6.fc2.bias gradient: 0.004481762647628784 | |
mamba_block2.S6.fc3.weight gradient: 0.0031439471058547497 | |
mamba_block2.S6.fc3.bias gradient: 0.004101179540157318 | |
mamba_block2.conv.weight gradient: 0.01293809525668621 | |
mamba_block2.conv.bias gradient: 0.0009332753252238035 | |
mamba_block2.conv_linear.weight gradient: 0.010446547530591488 | |
mamba_block2.conv_linear.bias gradient: 0.006634898949414492 | |
mamba_block2.norm.weight gradient: 0.0034382217563688755 | |
mamba_block3.inp_proj.weight gradient: 0.10265932977199554 | |
mamba_block3.inp_proj.bias gradient: 0.03616949915885925 | |
mamba_block3.out_proj.weight gradient: 0.04556775465607643 | |
mamba_block3.out_proj.bias gradient: 6.957547071806403e-08 | |
mamba_block3.D.weight gradient: 0.0451095886528492 | |
mamba_block3.D.bias gradient: 0.015929805114865303 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.004865700379014015 | |
mamba_block3.S6.fc1.bias gradient: 0.005076521076261997 | |
mamba_block3.S6.fc2.weight gradient: 0.01791907660663128 | |
mamba_block3.S6.fc2.bias gradient: 0.015313433483242989 | |
mamba_block3.S6.fc3.weight gradient: 0.01880076713860035 | |
mamba_block3.S6.fc3.bias gradient: 0.016120247542858124 | |
mamba_block3.conv.weight gradient: 0.1731170117855072 | |
mamba_block3.conv.bias gradient: 0.017587197944521904 | |
mamba_block3.conv_linear.weight gradient: 0.08034130185842514 | |
mamba_block3.conv_linear.bias gradient: 0.048483993858098984 | |
mamba_block3.norm.weight gradient: 0.023936070501804352 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0268, 1.0140, 0.9939, ..., 1.0375, 1.0386, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0269, 1.0140, 0.9939, ..., 1.0375, 1.0386, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0269, 1.0139, 0.9939, ..., 1.0375, 1.0387, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]], | |
..., | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0269, 1.0139, 0.9939, ..., 1.0375, 1.0386, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0269, 1.0139, 0.9939, ..., 1.0375, 1.0386, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0269, 1.0139, 0.9939, ..., 1.0375, 1.0386, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.963013768196106, 1.0387181043624878) | |
mean: 1.0009573698043823 | |
std: 0.0081221554428339 | |
target = tensor([[[-0.1301, -0.9992, 0.0145, ..., -0.3068, 0.9142, -2.1975], | |
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505], | |
[ 0.5060, 2.4795, 0.5965, ..., -0.9060, -0.1549, 0.4353], | |
..., | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-0.3053, -0.2988, -1.8848, ..., 1.0646, -0.5250, -0.6723], | |
[ 0.1819, -1.5667, -1.6287, ..., -1.8686, 0.2220, 0.7203], | |
[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417], | |
..., | |
[-0.0729, -1.5876, -0.1719, ..., 1.2421, 0.7066, 0.4504], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 2.1325, 0.0683, 1.1581, ..., 1.2571, 0.4663, -0.7213], | |
[-0.6782, -1.3975, 0.4929, ..., 0.3559, -0.3167, -1.3198], | |
[ 0.4125, 0.8607, -0.2114, ..., -0.0050, -0.0463, -1.4117], | |
..., | |
[ 2.1325, 0.0683, 1.1581, ..., 1.2571, 0.4663, -0.7213], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
..., | |
[[-0.1533, 0.5930, 0.1322, ..., -1.6126, -0.6235, -1.3132], | |
[-0.7129, -0.4076, -0.2963, ..., 1.9239, -1.4047, 0.4096], | |
[ 0.4125, 0.8607, -0.2114, ..., -0.0050, -0.0463, -1.4117], | |
..., | |
[-0.1533, 0.5930, 0.1322, ..., -1.6126, -0.6235, -1.3132], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-0.8332, -0.0094, 0.1628, ..., -0.1734, -0.2614, -1.1598], | |
[-1.0951, 0.4374, 1.1074, ..., 0.7239, 0.9897, 0.3390], | |
[-0.8718, 0.4088, -0.5637, ..., 0.8139, 0.3387, -0.3325], | |
..., | |
[-0.0729, -1.5876, -0.1719, ..., 1.2421, 0.7066, 0.4504], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.4021, 1.1737, -0.5641, ..., -0.0357, 0.1684, 0.1134], | |
[ 2.2047, -2.3811, 1.1213, ..., 0.0741, 0.3054, 0.0921], | |
[ 0.3262, -0.7823, -0.1636, ..., 0.3129, -0.0835, 0.3686], | |
..., | |
[ 0.3753, -0.1169, 0.4159, ..., 0.8816, -0.7008, 1.1613], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]], | |
device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-4.0426435470581055, 3.970458745956421) | |
mean: -0.03180283308029175 | |
std: 0.9980447292327881 | |
mamba_block1.inp_proj.weight gradient: 8.23084428702714e-06 | |
mamba_block1.inp_proj.bias gradient: 1.3557030797528569e-05 | |
mamba_block1.out_proj.weight gradient: 8.309278200613335e-05 | |
mamba_block1.out_proj.bias gradient: 0.0020105468574911356 | |
mamba_block1.D.weight gradient: 2.385676998528652e-05 | |
mamba_block1.D.bias gradient: 2.8827647838625126e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 3.5073132949037245e-06 | |
mamba_block1.S6.fc1.bias gradient: 5.155464805284282e-06 | |
mamba_block1.S6.fc2.weight gradient: 1.3048370419710409e-05 | |
mamba_block1.S6.fc2.bias gradient: 2.1127018044353463e-05 | |
mamba_block1.S6.fc3.weight gradient: 1.2743626029987354e-05 | |
mamba_block1.S6.fc3.bias gradient: 2.0642875824705698e-05 | |
mamba_block1.conv.weight gradient: 5.38126150786411e-05 | |
mamba_block1.conv.bias gradient: 9.490980119153392e-06 | |
mamba_block1.conv_linear.weight gradient: 1.9674673239933327e-05 | |
mamba_block1.conv_linear.bias gradient: 5.9852962294826284e-05 | |
mamba_block1.norm.weight gradient: 5.154767222848022e-06 | |
mamba_block2.inp_proj.weight gradient: 0.009204426780343056 | |
mamba_block2.inp_proj.bias gradient: 0.0032465667463839054 | |
mamba_block2.out_proj.weight gradient: 0.00896873977035284 | |
mamba_block2.out_proj.bias gradient: 0.021650521084666252 | |
mamba_block2.D.weight gradient: 0.005361825693398714 | |
mamba_block2.D.bias gradient: 0.0018912054365500808 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.001220333855599165 | |
mamba_block2.S6.fc1.bias gradient: 0.0015079170698300004 | |
mamba_block2.S6.fc2.weight gradient: 0.003332579042762518 | |
mamba_block2.S6.fc2.bias gradient: 0.004285544157028198 | |
mamba_block2.S6.fc3.weight gradient: 0.0030965039040893316 | |
mamba_block2.S6.fc3.bias gradient: 0.003915925044566393 | |
mamba_block2.conv.weight gradient: 0.012820214033126831 | |
mamba_block2.conv.bias gradient: 0.0009201719076372683 | |
mamba_block2.conv_linear.weight gradient: 0.01025470346212387 | |
mamba_block2.conv_linear.bias gradient: 0.00649257143959403 | |
mamba_block2.norm.weight gradient: 0.0035417950712144375 | |
mamba_block3.inp_proj.weight gradient: 0.09837187826633453 | |
mamba_block3.inp_proj.bias gradient: 0.034655362367630005 | |
mamba_block3.out_proj.weight gradient: 0.04687173664569855 | |
mamba_block3.out_proj.bias gradient: 8.309151411367566e-08 | |
mamba_block3.D.weight gradient: 0.053607795387506485 | |
mamba_block3.D.bias gradient: 0.01893215999007225 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.0047625149600207806 | |
mamba_block3.S6.fc1.bias gradient: 0.005159389693289995 | |
mamba_block3.S6.fc2.weight gradient: 0.0174313522875309 | |
mamba_block3.S6.fc2.bias gradient: 0.01905071921646595 | |
mamba_block3.S6.fc3.weight gradient: 0.0183222908526659 | |
mamba_block3.S6.fc3.bias gradient: 0.019951820373535156 | |
mamba_block3.conv.weight gradient: 0.17245595157146454 | |
mamba_block3.conv.bias gradient: 0.017512831836938858 | |
mamba_block3.conv_linear.weight gradient: 0.07903292030096054 | |
mamba_block3.conv_linear.bias gradient: 0.05545603483915329 | |
mamba_block3.norm.weight gradient: 0.02469041757285595 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9965, ..., 1.0036, 1.0034, 0.9951], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0270, 1.0139, 0.9939, ..., 1.0377, 1.0388, 0.9929], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9965, ..., 1.0036, 1.0034, 0.9951], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0270, 1.0139, 0.9939, ..., 1.0377, 1.0388, 0.9929], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0270, 1.0140, 0.9939, ..., 1.0376, 1.0388, 0.9928], | |
[1.0154, 1.0009, 1.0025, ..., 1.0066, 1.0130, 0.9961]], | |
..., | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0270, 1.0139, 0.9939, ..., 1.0377, 1.0388, 0.9929], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0270, 1.0139, 0.9939, ..., 1.0377, 1.0389, 0.9929], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0270, 1.0139, 0.9939, ..., 1.0376, 1.0388, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9629841446876526, 1.0388917922973633) | |
mean: 1.0009583234786987 | |
std: 0.008131702430546284 | |
target = tensor([[[-2.2013, 1.4751, 0.8977, ..., -1.7997, -1.3911, -0.1680], | |
[-1.6287, -0.9663, -0.0027, ..., 1.2085, -0.9073, 0.0483], | |
[ 1.0801, -0.3887, -0.2138, ..., 0.7030, -1.7206, 0.2015], | |
..., | |
[-1.3817, 0.1942, -1.2593, ..., 0.0592, 0.4673, 0.3583], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-1.0850, -0.6438, 0.2743, ..., -1.1642, -0.8742, -0.2776], | |
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505], | |
[-0.4621, 0.4395, 1.3246, ..., -0.5279, 0.6105, 2.4551], | |
..., | |
[ 0.6628, 2.2315, 0.2679, ..., 0.4018, 1.3974, 0.2715], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.3699, 0.2790, 0.5842, ..., 1.5279, 0.0930, -1.2546], | |
[-0.0915, 0.1275, 0.7293, ..., 1.1558, -0.3669, -0.2044], | |
[-0.1491, -0.8928, -1.1765, ..., -0.9342, 2.1916, 0.8451], | |
..., | |
[-0.3843, 0.2086, -1.3855, ..., 0.5185, 0.1296, 0.5115], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
..., | |
[[ 0.4125, 0.8607, -0.2114, ..., -0.0050, -0.0463, -1.4117], | |
[-0.8696, -0.4503, 0.3101, ..., 1.0256, -0.7886, -1.2446], | |
[-0.2637, 0.3004, 0.0593, ..., -1.0608, 0.3555, 0.8021], | |
..., | |
[ 1.4968, 0.2999, 0.0651, ..., -0.6530, -1.8364, 0.4741], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-0.0221, -1.3493, -0.6605, ..., 0.9500, -0.2841, 0.1124], | |
[-1.9098, -0.9699, -1.8455, ..., 0.6946, 1.9096, 0.4540], | |
[ 0.4838, -0.5874, -1.1409, ..., -0.1160, -0.5902, 0.5632], | |
..., | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.4428, 0.8034, -0.7362, ..., -0.6413, 2.3065, -0.3966], | |
[ 0.5679, 1.1681, -0.2152, ..., 0.4324, -0.3278, 0.3071], | |
[-0.6504, -1.3069, -0.1538, ..., 0.7510, -1.4239, -0.2393], | |
..., | |
[ 0.1019, -0.3633, 1.4970, ..., -1.1070, 1.8470, 0.7881], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]], | |
device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-4.203547954559326, 4.436890602111816) | |
mean: -0.03840525448322296 | |
std: 1.0000741481781006 | |
mamba_block1.inp_proj.weight gradient: 8.991964023152832e-06 | |
mamba_block1.inp_proj.bias gradient: 1.7170854334835894e-05 | |
mamba_block1.out_proj.weight gradient: 9.231397416442633e-05 | |
mamba_block1.out_proj.bias gradient: 0.00212532258592546 | |
mamba_block1.D.weight gradient: 2.5002524125739e-05 | |
mamba_block1.D.bias gradient: 3.0523628083756194e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 3.370713557160343e-06 | |
mamba_block1.S6.fc1.bias gradient: 4.826616986974841e-06 | |
mamba_block1.S6.fc2.weight gradient: 1.8368005839874968e-05 | |
mamba_block1.S6.fc2.bias gradient: 2.9517530492739752e-05 | |
mamba_block1.S6.fc3.weight gradient: 1.795031494111754e-05 | |
mamba_block1.S6.fc3.bias gradient: 2.8881711841677316e-05 | |
mamba_block1.conv.weight gradient: 5.597508788923733e-05 | |
mamba_block1.conv.bias gradient: 9.704062904347666e-06 | |
mamba_block1.conv_linear.weight gradient: 2.1619767721858807e-05 | |
mamba_block1.conv_linear.bias gradient: 6.322540866676718e-05 | |
mamba_block1.norm.weight gradient: 5.060106559540145e-06 | |
mamba_block2.inp_proj.weight gradient: 0.009407997131347656 | |
mamba_block2.inp_proj.bias gradient: 0.003318359376862645 | |
mamba_block2.out_proj.weight gradient: 0.00935858953744173 | |
mamba_block2.out_proj.bias gradient: 0.022997798398137093 | |
mamba_block2.D.weight gradient: 0.00555233983322978 | |
mamba_block2.D.bias gradient: 0.001958364387974143 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.0012591216946020722 | |
mamba_block2.S6.fc1.bias gradient: 0.0015584889333695173 | |
mamba_block2.S6.fc2.weight gradient: 0.003501052735373378 | |
mamba_block2.S6.fc2.bias gradient: 0.0046097771264612675 | |
mamba_block2.S6.fc3.weight gradient: 0.0032506112474948168 | |
mamba_block2.S6.fc3.bias gradient: 0.004216910805553198 | |
mamba_block2.conv.weight gradient: 0.012959428131580353 | |
mamba_block2.conv.bias gradient: 0.0009631924913264811 | |
mamba_block2.conv_linear.weight gradient: 0.01051102951169014 | |
mamba_block2.conv_linear.bias gradient: 0.006955367047339678 | |
mamba_block2.norm.weight gradient: 0.0037049155216664076 | |
mamba_block3.inp_proj.weight gradient: 0.10352003574371338 | |
mamba_block3.inp_proj.bias gradient: 0.03647966310381889 | |
mamba_block3.out_proj.weight gradient: 0.048568934202194214 | |
mamba_block3.out_proj.bias gradient: 5.696314886449727e-08 | |
mamba_block3.D.weight gradient: 0.048673514276742935 | |
mamba_block3.D.bias gradient: 0.017192283645272255 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.004940532147884369 | |
mamba_block3.S6.fc1.bias gradient: 0.00535177206620574 | |
mamba_block3.S6.fc2.weight gradient: 0.017920760437846184 | |
mamba_block3.S6.fc2.bias gradient: 0.01845845952630043 | |
mamba_block3.S6.fc3.weight gradient: 0.01888885162770748 | |
mamba_block3.S6.fc3.bias gradient: 0.019463254138827324 | |
mamba_block3.conv.weight gradient: 0.17237348854541779 | |
mamba_block3.conv.bias gradient: 0.01762760430574417 | |
mamba_block3.conv_linear.weight gradient: 0.08468104898929596 | |
mamba_block3.conv_linear.bias gradient: 0.054214172065258026 | |
mamba_block3.norm.weight gradient: 0.025246795266866684 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0271, 1.0139, 0.9939, ..., 1.0378, 1.0390, 0.9929], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0271, 1.0139, 0.9940, ..., 1.0378, 1.0390, 0.9928], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0271, 1.0140, 0.9939, ..., 1.0378, 1.0390, 0.9929], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]], | |
..., | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0053, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0271, 1.0140, 0.9940, ..., 1.0378, 1.0390, 0.9929], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0130, 0.9961]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0066, 0.9966, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0271, 1.0140, 0.9940, ..., 1.0378, 1.0390, 0.9929], | |
[1.0154, 1.0009, 1.0025, ..., 1.0066, 1.0130, 0.9961]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0066, 0.9966, ..., 1.0036, 1.0034, 0.9951], | |
[1.0050, 1.0066, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0271, 1.0139, 0.9940, ..., 1.0378, 1.0390, 0.9929], | |
[1.0154, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9629332423210144, 1.0391098260879517) | |
mean: 1.0009595155715942 | |
std: 0.008141160011291504 | |
target = tensor([[[-1.8136, -0.6341, -0.1093, ..., 0.4384, -0.3662, -0.9972], | |
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270], | |
[ 0.7458, 1.6577, 0.0364, ..., 1.2313, -0.1711, 0.0749], | |
..., | |
[ 1.4597, 1.8877, -0.5288, ..., -0.6553, -0.6894, -0.6879], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 2.1325, 0.0683, 1.1581, ..., 1.2571, 0.4663, -0.7213], | |
[-0.6782, -1.3975, 0.4929, ..., 0.3559, -0.3167, -1.3198], | |
[ 0.4125, 0.8607, -0.2114, ..., -0.0050, -0.0463, -1.4117], | |
..., | |
[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-0.8389, -0.8643, 0.4242, ..., 0.5848, 1.5457, -0.4353], | |
[-0.8738, -0.8262, 0.1785, ..., -0.7729, 0.3997, 1.0206], | |
[-0.0377, -0.0993, -0.3354, ..., -0.4587, 2.1620, -1.0884], | |
..., | |
[-0.8324, -0.4507, 0.2398, ..., 0.7770, -1.6973, -1.6883], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
..., | |
[[ 0.9234, -0.9421, 1.0451, ..., -0.4781, 0.0420, 0.2934], | |
[-1.6287, -0.9663, -0.0027, ..., 1.2085, -0.9073, 0.0483], | |
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270], | |
..., | |
[ 0.8195, 1.3160, 0.7905, ..., 0.3638, -0.4126, 0.3174], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.5068, -0.9357, 0.5597, ..., -0.2642, 1.4782, 2.3104], | |
[ 0.6347, -0.6012, 0.3480, ..., 1.5082, -0.9452, 2.0558], | |
[ 0.9335, 0.1877, -2.0042, ..., -1.1503, -1.7980, -0.5640], | |
..., | |
[-0.3877, 0.3024, -1.1404, ..., 2.0661, -0.6191, -0.9355], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.8417, 0.6482, 2.4525, ..., -0.0736, -1.3844, -1.5417], | |
[-1.6539, -0.2380, 1.2548, ..., -0.2029, -0.3846, -1.4885], | |
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505], | |
..., | |
[ 0.5459, -1.0019, 1.6465, ..., -0.7943, 1.1101, 1.4487], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]], | |
device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-4.062912464141846, 3.9631638526916504) | |
mean: -0.043465666472911835 | |
std: 0.9981724619865417 | |
mamba_block1.inp_proj.weight gradient: 9.746233445184771e-06 | |
mamba_block1.inp_proj.bias gradient: 1.613192944205366e-05 | |
mamba_block1.out_proj.weight gradient: 9.361501724924892e-05 | |
mamba_block1.out_proj.bias gradient: 0.002090686932206154 | |
mamba_block1.D.weight gradient: 2.397046955593396e-05 | |
mamba_block1.D.bias gradient: 2.898801540140994e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 3.917689809895819e-06 | |
mamba_block1.S6.fc1.bias gradient: 5.5823461480031256e-06 | |
mamba_block1.S6.fc2.weight gradient: 2.6315385184716433e-05 | |
mamba_block1.S6.fc2.bias gradient: 4.1125862480839714e-05 | |
mamba_block1.S6.fc3.weight gradient: 2.5236464352929033e-05 | |
mamba_block1.S6.fc3.bias gradient: 3.954580824938603e-05 | |
mamba_block1.conv.weight gradient: 5.663771298713982e-05 | |
mamba_block1.conv.bias gradient: 1.0728360393841285e-05 | |
mamba_block1.conv_linear.weight gradient: 2.2806825654697604e-05 | |
mamba_block1.conv_linear.bias gradient: 7.061885844450444e-05 | |
mamba_block1.norm.weight gradient: 6.072532869438874e-06 | |
mamba_block2.inp_proj.weight gradient: 0.009219370782375336 | |
mamba_block2.inp_proj.bias gradient: 0.003251793095842004 | |
mamba_block2.out_proj.weight gradient: 0.009203227236866951 | |
mamba_block2.out_proj.bias gradient: 0.022098751738667488 | |
mamba_block2.D.weight gradient: 0.00542529858648777 | |
mamba_block2.D.bias gradient: 0.0019135409966111183 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.0012004896998405457 | |
mamba_block2.S6.fc1.bias gradient: 0.0015050115762278438 | |
mamba_block2.S6.fc2.weight gradient: 0.0033298747148364782 | |
mamba_block2.S6.fc2.bias gradient: 0.0042640469036996365 | |
mamba_block2.S6.fc3.weight gradient: 0.0030927204061299562 | |
mamba_block2.S6.fc3.bias gradient: 0.0038947444409132004 | |
mamba_block2.conv.weight gradient: 0.012392286211252213 | |
mamba_block2.conv.bias gradient: 0.0009247513953596354 | |
mamba_block2.conv_linear.weight gradient: 0.010260899551212788 | |
mamba_block2.conv_linear.bias gradient: 0.006446880754083395 | |
mamba_block2.norm.weight gradient: 0.0036036702804267406 | |
mamba_block3.inp_proj.weight gradient: 0.09711047261953354 | |
mamba_block3.inp_proj.bias gradient: 0.034207314252853394 | |
mamba_block3.out_proj.weight gradient: 0.04665215313434601 | |
mamba_block3.out_proj.bias gradient: 8.800382289564368e-08 | |
mamba_block3.D.weight gradient: 0.04661861062049866 | |
mamba_block3.D.bias gradient: 0.016462555155158043 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.004771151579916477 | |
mamba_block3.S6.fc1.bias gradient: 0.004979342687875032 | |
mamba_block3.S6.fc2.weight gradient: 0.017470696941018105 | |
mamba_block3.S6.fc2.bias gradient: 0.018941262736916542 | |
mamba_block3.S6.fc3.weight gradient: 0.018347127363085747 | |
mamba_block3.S6.fc3.bias gradient: 0.01973014324903488 | |
mamba_block3.conv.weight gradient: 0.176979660987854 | |
mamba_block3.conv.bias gradient: 0.017966141924262047 | |
mamba_block3.conv_linear.weight gradient: 0.0760805532336235 | |
mamba_block3.conv_linear.bias gradient: 0.0522383488714695 | |
mamba_block3.norm.weight gradient: 0.024708587676286697 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0066, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0272, 1.0140, 0.9940, ..., 1.0379, 1.0392, 0.9929], | |
[1.0155, 1.0009, 1.0025, ..., 1.0066, 1.0130, 0.9961]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0066, 0.9966, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0272, 1.0139, 0.9940, ..., 1.0379, 1.0392, 0.9929], | |
[1.0155, 1.0009, 1.0024, ..., 1.0066, 1.0130, 0.9961]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0066, 0.9966, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0271, 1.0139, 0.9940, ..., 1.0379, 1.0391, 0.9929], | |
[1.0155, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]], | |
..., | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0066, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0013, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0051, 1.0100, 0.9955], | |
[1.0272, 1.0139, 0.9940, ..., 1.0379, 1.0392, 0.9929], | |
[1.0155, 1.0009, 1.0024, ..., 1.0066, 1.0129, 0.9961]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0066, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0272, 1.0139, 0.9940, ..., 1.0379, 1.0392, 0.9929], | |
[1.0155, 1.0009, 1.0024, ..., 1.0066, 1.0130, 0.9961]], | |
[[1.0054, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0066, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0272, 1.0140, 0.9940, ..., 1.0379, 1.0392, 0.9929], | |
[1.0155, 1.0009, 1.0024, ..., 1.0066, 1.0130, 0.9961]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9628950357437134, 1.0392659902572632) | |
mean: 1.0009607076644897 | |
std: 0.008150782436132431 | |
target = tensor([[[-1.6287e+00, -9.6626e-01, -2.6581e-03, ..., 1.2085e+00, | |
-9.0730e-01, 4.8345e-02], | |
[ 1.3781e-01, -3.8981e-01, 4.6194e-01, ..., 1.9883e-01, | |
-3.7158e-01, 3.5527e-01], | |
[-1.8160e-01, -4.5527e-01, -1.1590e+00, ..., -4.9020e-01, | |
-3.5882e-01, 1.5264e+00], | |
..., | |
[-1.3817e+00, 1.9420e-01, -1.2593e+00, ..., 5.9164e-02, | |
4.6726e-01, 3.5826e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[-1.8078e-01, 2.4252e-01, -6.3844e-01, ..., -7.8935e-02, | |
6.0249e-01, -5.8976e-01], | |
[-1.5419e+00, 1.8212e+00, 1.8157e+00, ..., -9.8702e-03, | |
-7.1342e-01, -2.5576e-01], | |
[ 6.8187e-01, 6.5033e-01, -7.8437e-02, ..., 1.2032e+00, | |
4.3723e-01, 5.0549e-02], | |
..., | |
[ 2.0079e+00, 2.8988e-01, 1.4619e-01, ..., 1.1677e-02, | |
1.5477e-01, 9.1439e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[ 1.1616e+00, -1.3193e+00, -6.1403e-01, ..., 1.9209e+00, | |
1.3262e+00, 5.9860e-02], | |
[ 3.5665e-01, -3.1154e-01, -1.2586e+00, ..., -9.5706e-01, | |
-2.0711e+00, -1.0293e+00], | |
[-6.5829e-01, -2.2524e-01, 2.0800e+00, ..., 7.8087e-01, | |
7.4104e-01, -1.9717e+00], | |
..., | |
[-5.4846e-01, -1.0681e+00, -1.4576e+00, ..., -9.8265e-01, | |
2.3030e+00, 1.3365e+00], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
..., | |
[[-1.2654e-01, -6.2001e-01, -2.4523e+00, ..., -1.2879e+00, | |
-2.9765e-01, 1.6772e+00], | |
[ 2.0127e-01, 3.0885e-01, 1.0572e+00, ..., 2.7429e-01, | |
-7.5508e-01, 3.9383e-01], | |
[ 6.6372e-01, 1.7408e-01, -3.5581e-01, ..., -1.4354e+00, | |
1.1672e+00, 4.4820e-02], | |
..., | |
[ 3.2237e-01, 5.8461e-01, -3.4121e-03, ..., -2.5780e-01, | |
-1.3302e+00, -5.8217e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[ 8.9742e-01, -8.8120e-01, 7.8961e-01, ..., -8.5716e-01, | |
-1.6618e+00, -3.6012e-01], | |
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02, | |
-6.2687e-01, -4.9886e-01], | |
[ 2.5504e+00, -5.1241e-01, -2.8401e-01, ..., -9.7184e-02, | |
-6.2687e-01, -4.9886e-01], | |
..., | |
[ 1.7525e-01, -7.6253e-01, 1.4695e-01, ..., 3.1829e-01, | |
1.6904e-01, -1.1843e-01], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]], | |
[[-1.5811e-01, 1.7191e-01, 1.5170e+00, ..., -4.7812e-01, | |
4.5400e-02, 1.0040e+00], | |
[-7.9920e-01, 1.3202e+00, -1.7626e-01, ..., 1.0812e-01, | |
-7.9432e-02, -3.9323e-01], | |
[-3.4205e-01, 1.4121e+00, 2.6875e+00, ..., 2.3489e-01, | |
1.2428e-01, -6.1681e-02], | |
..., | |
[-1.6287e+00, -9.6626e-01, -2.6581e-03, ..., 1.2085e+00, | |
-9.0730e-01, 4.8345e-02], | |
[ 9.1893e-01, 5.5449e-01, 1.3656e+00, ..., 7.6945e-01, | |
1.3140e+00, 2.3112e-01], | |
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, | |
0.0000e+00, 0.0000e+00]]], device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-4.203547954559326, 4.436890602111816) | |
mean: -0.04427793249487877 | |
std: 0.9962211847305298 | |
mamba_block1.inp_proj.weight gradient: 9.651295840740204e-06 | |
mamba_block1.inp_proj.bias gradient: 1.6675130609655753e-05 | |
mamba_block1.out_proj.weight gradient: 9.673438034951687e-05 | |
mamba_block1.out_proj.bias gradient: 0.002231176942586899 | |
mamba_block1.D.weight gradient: 2.4628068786114454e-05 | |
mamba_block1.D.bias gradient: 3.227106572012417e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 3.960161848226562e-06 | |
mamba_block1.S6.fc1.bias gradient: 5.747951036028098e-06 | |
mamba_block1.S6.fc2.weight gradient: 2.7824979042634368e-05 | |
mamba_block1.S6.fc2.bias gradient: 4.349627124611288e-05 | |
mamba_block1.S6.fc3.weight gradient: 2.6914129193755798e-05 | |
mamba_block1.S6.fc3.bias gradient: 4.2169736843788996e-05 | |
mamba_block1.conv.weight gradient: 5.887298539164476e-05 | |
mamba_block1.conv.bias gradient: 1.2363690075289924e-05 | |
mamba_block1.conv_linear.weight gradient: 2.250147554150317e-05 | |
mamba_block1.conv_linear.bias gradient: 7.393556734314188e-05 | |
mamba_block1.norm.weight gradient: 7.986875061760657e-06 | |
mamba_block2.inp_proj.weight gradient: 0.00992958340793848 | |
mamba_block2.inp_proj.bias gradient: 0.003502229694277048 | |
mamba_block2.out_proj.weight gradient: 0.00934018474072218 | |
mamba_block2.out_proj.bias gradient: 0.021850064396858215 | |
mamba_block2.D.weight gradient: 0.005657918751239777 | |
mamba_block2.D.bias gradient: 0.0019955879542976618 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.0012419703416526318 | |
mamba_block2.S6.fc1.bias gradient: 0.0015505834016948938 | |
mamba_block2.S6.fc2.weight gradient: 0.0034699649550020695 | |
mamba_block2.S6.fc2.bias gradient: 0.0044254641979932785 | |
mamba_block2.S6.fc3.weight gradient: 0.0032270450610667467 | |
mamba_block2.S6.fc3.bias gradient: 0.0040438235737383366 | |
mamba_block2.conv.weight gradient: 0.01314946822822094 | |
mamba_block2.conv.bias gradient: 0.000992308254353702 | |
mamba_block2.conv_linear.weight gradient: 0.01082244236022234 | |
mamba_block2.conv_linear.bias gradient: 0.006649897433817387 | |
mamba_block2.norm.weight gradient: 0.003846563631668687 | |
mamba_block3.inp_proj.weight gradient: 0.09649848937988281 | |
mamba_block3.inp_proj.bias gradient: 0.033982012420892715 | |
mamba_block3.out_proj.weight gradient: 0.04800761863589287 | |
mamba_block3.out_proj.bias gradient: 5.8161077731710975e-08 | |
mamba_block3.D.weight gradient: 0.04944797605276108 | |
mamba_block3.D.bias gradient: 0.017467014491558075 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.004728706553578377 | |
mamba_block3.S6.fc1.bias gradient: 0.004932350944727659 | |
mamba_block3.S6.fc2.weight gradient: 0.017793580889701843 | |
mamba_block3.S6.fc2.bias gradient: 0.018198303878307343 | |
mamba_block3.S6.fc3.weight gradient: 0.018742457032203674 | |
mamba_block3.S6.fc3.bias gradient: 0.01912083476781845 | |
mamba_block3.conv.weight gradient: 0.17541822791099548 | |
mamba_block3.conv.bias gradient: 0.01787564344704151 | |
mamba_block3.conv_linear.weight gradient: 0.07878080010414124 | |
mamba_block3.conv_linear.bias gradient: 0.056361325085163116 | |
mamba_block3.norm.weight gradient: 0.024798491969704628 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0273, 1.0139, 0.9940, ..., 1.0381, 1.0394, 0.9929], | |
[1.0155, 1.0009, 1.0025, ..., 1.0066, 1.0130, 0.9960]], | |
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0066, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0273, 1.0140, 0.9940, ..., 1.0380, 1.0393, 0.9929], | |
[1.0155, 1.0009, 1.0025, ..., 1.0066, 1.0130, 0.9960]], | |
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0066, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0273, 1.0140, 0.9940, ..., 1.0381, 1.0394, 0.9929], | |
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]], | |
..., | |
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9948], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0273, 1.0140, 0.9940, ..., 1.0381, 1.0394, 0.9929], | |
[1.0155, 1.0009, 1.0025, ..., 1.0066, 1.0130, 0.9960]], | |
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0066, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0273, 1.0140, 0.9940, ..., 1.0381, 1.0394, 0.9929], | |
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]], | |
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0273, 1.0140, 0.9940, ..., 1.0380, 1.0394, 0.9929], | |
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9628592133522034, 1.0394682884216309) | |
mean: 1.0009617805480957 | |
std: 0.008160555735230446 | |
target = tensor([[[ 0.1249, 0.1479, 0.4132, ..., -0.2172, -0.6020, 0.3062], | |
[-1.3817, 0.1942, -1.2593, ..., 0.0592, 0.4673, 0.3583], | |
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270], | |
..., | |
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199], | |
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199], | |
[ 2.3047, 0.6387, -1.2971, ..., -0.9008, -0.7687, -0.1274], | |
..., | |
[-2.1382, -0.4375, 0.1092, ..., -0.3279, 0.5643, 0.2475], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.0803, -0.3617, -0.1729, ..., -0.1031, 1.7060, -0.8089], | |
[ 0.4838, -0.5874, -1.1409, ..., -0.1160, -0.5902, 0.5632], | |
[-0.0287, -1.1162, -0.5596, ..., 1.2069, -1.2071, 0.0189], | |
..., | |
[-0.5417, 0.3940, 0.3822, ..., -0.3933, -1.1325, -0.0510], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
..., | |
[[ 1.1773, 0.2168, 0.4060, ..., -0.1085, 0.1342, 0.5152], | |
[ 0.1019, -0.3633, 1.4970, ..., -1.1070, 1.8470, 0.7881], | |
[-0.4072, 0.9312, 1.0190, ..., -0.9175, 0.1262, -0.9890], | |
..., | |
[ 0.8941, -2.4687, 0.5529, ..., 0.0181, 0.2483, 0.0552], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 1.0778, -0.1816, -0.6237, ..., 0.5324, -0.4506, -0.2228], | |
[-0.0915, 0.1275, 0.7293, ..., 1.1558, -0.3669, -0.2044], | |
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270], | |
..., | |
[ 0.9335, 0.1877, -2.0042, ..., -1.1503, -1.7980, -0.5640], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417], | |
[-0.3029, -0.7849, 0.6537, ..., 0.0216, -0.0510, 1.3417], | |
[-0.3334, 1.1392, 0.6457, ..., -0.8796, -1.0417, -0.8816], | |
..., | |
[ 0.0445, -1.4234, -0.5175, ..., 1.3668, -0.1756, 0.9730], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]], | |
device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-4.0426435470581055, 3.9631638526916504) | |
mean: -0.03936820104718208 | |
std: 0.9987026453018188 | |
mamba_block1.inp_proj.weight gradient: 7.3608516686363146e-06 | |
mamba_block1.inp_proj.bias gradient: 1.3926341125625186e-05 | |
mamba_block1.out_proj.weight gradient: 9.16551289265044e-05 | |
mamba_block1.out_proj.bias gradient: 0.0020849842112511396 | |
mamba_block1.D.weight gradient: 2.415824019408319e-05 | |
mamba_block1.D.bias gradient: 2.9187509426265024e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 3.2261905289487913e-06 | |
mamba_block1.S6.fc1.bias gradient: 4.610864834830863e-06 | |
mamba_block1.S6.fc2.weight gradient: 1.9801318558165804e-05 | |
mamba_block1.S6.fc2.bias gradient: 3.303638368379325e-05 | |
mamba_block1.S6.fc3.weight gradient: 1.9410605091252364e-05 | |
mamba_block1.S6.fc3.bias gradient: 3.262169411755167e-05 | |
mamba_block1.conv.weight gradient: 5.535763193620369e-05 | |
mamba_block1.conv.bias gradient: 1.002511453407351e-05 | |
mamba_block1.conv_linear.weight gradient: 2.053168100246694e-05 | |
mamba_block1.conv_linear.bias gradient: 6.233315070858225e-05 | |
mamba_block1.norm.weight gradient: 6.316646249615587e-06 | |
mamba_block2.inp_proj.weight gradient: 0.009400204755365849 | |
mamba_block2.inp_proj.bias gradient: 0.003315509995445609 | |
mamba_block2.out_proj.weight gradient: 0.009219340980052948 | |
mamba_block2.out_proj.bias gradient: 0.021291621029376984 | |
mamba_block2.D.weight gradient: 0.005432978272438049 | |
mamba_block2.D.bias gradient: 0.0019162470707669854 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.0012191222049295902 | |
mamba_block2.S6.fc1.bias gradient: 0.00152978312689811 | |
mamba_block2.S6.fc2.weight gradient: 0.003414183622226119 | |
mamba_block2.S6.fc2.bias gradient: 0.004413718823343515 | |
mamba_block2.S6.fc3.weight gradient: 0.003169463714584708 | |
mamba_block2.S6.fc3.bias gradient: 0.004032541066408157 | |
mamba_block2.conv.weight gradient: 0.012530333362519741 | |
mamba_block2.conv.bias gradient: 0.0009309174492955208 | |
mamba_block2.conv_linear.weight gradient: 0.010265973396599293 | |
mamba_block2.conv_linear.bias gradient: 0.0068965875543653965 | |
mamba_block2.norm.weight gradient: 0.0036086307372897863 | |
mamba_block3.inp_proj.weight gradient: 0.09822450578212738 | |
mamba_block3.inp_proj.bias gradient: 0.034595269709825516 | |
mamba_block3.out_proj.weight gradient: 0.04629985988140106 | |
mamba_block3.out_proj.bias gradient: 7.177833083460428e-08 | |
mamba_block3.D.weight gradient: 0.04594907537102699 | |
mamba_block3.D.bias gradient: 0.01622462458908558 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.004941641353070736 | |
mamba_block3.S6.fc1.bias gradient: 0.005250551737844944 | |
mamba_block3.S6.fc2.weight gradient: 0.01758144423365593 | |
mamba_block3.S6.fc2.bias gradient: 0.015057405456900597 | |
mamba_block3.S6.fc3.weight gradient: 0.018423432484269142 | |
mamba_block3.S6.fc3.bias gradient: 0.015630599111318588 | |
mamba_block3.conv.weight gradient: 0.17396844923496246 | |
mamba_block3.conv.bias gradient: 0.018070943653583527 | |
mamba_block3.conv_linear.weight gradient: 0.07747268676757812 | |
mamba_block3.conv_linear.bias gradient: 0.04878745600581169 | |
mamba_block3.norm.weight gradient: 0.023946603760123253 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0274, 1.0140, 0.9941, ..., 1.0382, 1.0396, 0.9929], | |
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]], | |
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0274, 1.0140, 0.9941, ..., 1.0382, 1.0396, 0.9930], | |
[1.0155, 1.0009, 1.0025, ..., 1.0066, 1.0130, 0.9960]], | |
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0274, 1.0140, 0.9941, ..., 1.0382, 1.0395, 0.9929], | |
[1.0155, 1.0009, 1.0025, ..., 1.0066, 1.0130, 0.9960]], | |
..., | |
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0274, 1.0139, 0.9941, ..., 1.0382, 1.0396, 0.9929], | |
[1.0155, 1.0009, 1.0025, ..., 1.0066, 1.0130, 0.9960]], | |
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0274, 1.0139, 0.9941, ..., 1.0382, 1.0396, 0.9929], | |
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]], | |
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0274, 1.0140, 0.9941, ..., 1.0382, 1.0396, 0.9929], | |
[1.0155, 1.0009, 1.0025, ..., 1.0066, 1.0130, 0.9960]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9628267288208008, 1.0396476984024048) | |
mean: 1.000962495803833 | |
std: 0.008170326240360737 | |
target = tensor([[[-1.5137, 0.0657, -0.9680, ..., 1.6269, -0.2294, 0.1420], | |
[-0.0282, 0.8765, 0.3617, ..., -0.8538, 0.5377, -1.6134], | |
[-0.1533, 0.5930, 0.1322, ..., -1.6126, -0.6235, -1.3132], | |
..., | |
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270], | |
[-0.3501, 0.5389, -0.7310, ..., -0.0815, 0.4691, 0.4229], | |
[ 1.8674, -0.6901, -1.5037, ..., 0.8689, 1.6506, 0.1824], | |
..., | |
[ 1.7151, 1.0070, 0.6890, ..., -2.3825, -0.5136, 0.5498], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 1.1616, -1.3193, -0.6140, ..., 1.9209, 1.3262, 0.0599], | |
[ 0.3868, -1.0279, -0.3675, ..., -0.6507, 0.5047, -0.5453], | |
[-0.0137, -1.7578, 0.7203, ..., -0.7771, 1.8718, -0.1505], | |
..., | |
[-0.8427, 1.6156, 1.2061, ..., 0.4317, 1.9322, 0.3907], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
..., | |
[[ 0.8524, -0.2206, 0.9268, ..., -0.7495, -0.6237, 0.3975], | |
[-1.7090, -1.0052, -1.0034, ..., -0.9609, -1.5528, -0.8253], | |
[ 1.1688, 0.4326, 0.6992, ..., -0.6485, -0.1625, 1.0952], | |
..., | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.3843, -1.3493, -1.1372, ..., -1.0553, 0.6164, 1.1378], | |
[-1.3817, 0.1942, -1.2593, ..., 0.0592, 0.4673, 0.3583], | |
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270], | |
..., | |
[-1.8815, 0.8164, 0.5484, ..., -0.3336, 0.1990, 1.8763], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-2.3747, -0.7223, 0.2557, ..., 0.5687, -0.0835, -2.1125], | |
[-1.0612, -0.9659, 0.0180, ..., 2.1914, -2.7829, 1.4622], | |
[-1.3817, 0.1942, -1.2593, ..., 0.0592, 0.4673, 0.3583], | |
..., | |
[-1.6361, 0.3128, -1.7070, ..., 1.2532, -1.0657, 0.4411], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]], | |
device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-3.7949037551879883, 3.9631638526916504) | |
mean: -0.04261881858110428 | |
std: 0.9969062209129333 | |
mamba_block1.inp_proj.weight gradient: 1.1255859135417268e-05 | |
mamba_block1.inp_proj.bias gradient: 1.5913874449324794e-05 | |
mamba_block1.out_proj.weight gradient: 8.910077303880826e-05 | |
mamba_block1.out_proj.bias gradient: 0.0021355818025767803 | |
mamba_block1.D.weight gradient: 2.525844865886029e-05 | |
mamba_block1.D.bias gradient: 3.0218390747904778e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 3.48461890098406e-06 | |
mamba_block1.S6.fc1.bias gradient: 4.768642611452378e-06 | |
mamba_block1.S6.fc2.weight gradient: 1.4115617887000553e-05 | |
mamba_block1.S6.fc2.bias gradient: 2.272063647978939e-05 | |
mamba_block1.S6.fc3.weight gradient: 1.4028752957528923e-05 | |
mamba_block1.S6.fc3.bias gradient: 2.2550859284820035e-05 | |
mamba_block1.conv.weight gradient: 5.71208875044249e-05 | |
mamba_block1.conv.bias gradient: 1.0420963008073159e-05 | |
mamba_block1.conv_linear.weight gradient: 2.3161373974289745e-05 | |
mamba_block1.conv_linear.bias gradient: 6.114652205724269e-05 | |
mamba_block1.norm.weight gradient: 3.6366443509905366e-06 | |
mamba_block2.inp_proj.weight gradient: 0.009616881608963013 | |
mamba_block2.inp_proj.bias gradient: 0.0033918858971446753 | |
mamba_block2.out_proj.weight gradient: 0.009563383646309376 | |
mamba_block2.out_proj.bias gradient: 0.02152535691857338 | |
mamba_block2.D.weight gradient: 0.005726755131036043 | |
mamba_block2.D.bias gradient: 0.0020198312122374773 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.0013312151422724128 | |
mamba_block2.S6.fc1.bias gradient: 0.0016568585997447371 | |
mamba_block2.S6.fc2.weight gradient: 0.0035273043904453516 | |
mamba_block2.S6.fc2.bias gradient: 0.004576574545353651 | |
mamba_block2.S6.fc3.weight gradient: 0.0032715476118028164 | |
mamba_block2.S6.fc3.bias gradient: 0.0041853212751448154 | |
mamba_block2.conv.weight gradient: 0.012951940298080444 | |
mamba_block2.conv.bias gradient: 0.0009483825415372849 | |
mamba_block2.conv_linear.weight gradient: 0.010837584733963013 | |
mamba_block2.conv_linear.bias gradient: 0.007166210561990738 | |
mamba_block2.norm.weight gradient: 0.0036460179835557938 | |
mamba_block3.inp_proj.weight gradient: 0.09701595455408096 | |
mamba_block3.inp_proj.bias gradient: 0.034169118851423264 | |
mamba_block3.out_proj.weight gradient: 0.047714706510305405 | |
mamba_block3.out_proj.bias gradient: 3.2355920609461464e-08 | |
mamba_block3.D.weight gradient: 0.05303538590669632 | |
mamba_block3.D.bias gradient: 0.018727842718362808 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.005017312243580818 | |
mamba_block3.S6.fc1.bias gradient: 0.0054448735900223255 | |
mamba_block3.S6.fc2.weight gradient: 0.016375340521335602 | |
mamba_block3.S6.fc2.bias gradient: 0.01577238366007805 | |
mamba_block3.S6.fc3.weight gradient: 0.017361776903271675 | |
mamba_block3.S6.fc3.bias gradient: 0.016659947112202644 | |
mamba_block3.conv.weight gradient: 0.17826542258262634 | |
mamba_block3.conv.bias gradient: 0.018036192283034325 | |
mamba_block3.conv_linear.weight gradient: 0.0783582255244255 | |
mamba_block3.conv_linear.bias gradient: 0.04873950034379959 | |
mamba_block3.norm.weight gradient: 0.023995602503418922 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0276, 1.0140, 0.9941, ..., 1.0383, 1.0398, 0.9930], | |
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]], | |
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0276, 1.0140, 0.9941, ..., 1.0383, 1.0398, 0.9930], | |
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]], | |
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0275, 1.0140, 0.9941, ..., 1.0383, 1.0398, 0.9929], | |
[1.0155, 1.0009, 1.0025, ..., 1.0066, 1.0130, 0.9960]], | |
..., | |
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0276, 1.0140, 0.9941, ..., 1.0383, 1.0398, 0.9930], | |
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]], | |
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0275, 1.0140, 0.9941, ..., 1.0383, 1.0398, 0.9930], | |
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]], | |
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0035, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0275, 1.0139, 0.9941, ..., 1.0383, 1.0398, 0.9929], | |
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9627799391746521, 1.0398532152175903) | |
mean: 1.0009632110595703 | |
std: 0.008179647848010063 | |
target = tensor([[[ 0.1461, -1.5373, 1.8414, ..., 2.1830, 2.1411, -0.5229], | |
[-0.8324, -0.4507, 0.2398, ..., 0.7770, -1.6973, -1.6883], | |
[-0.1043, 1.6985, 0.2488, ..., 0.7312, 1.5784, 2.1510], | |
..., | |
[ 0.6566, -1.6146, 0.8445, ..., -0.1293, -0.1222, -0.1538], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.4719, 0.3617, -0.3579, ..., -0.3869, 1.6128, 0.2484], | |
[ 0.4135, 0.7320, -0.6458, ..., 1.4670, 0.7813, -1.1558], | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
..., | |
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-0.8389, -0.8643, 0.4242, ..., 0.5848, 1.5457, -0.4353], | |
[-0.1063, -1.2993, 1.8509, ..., -1.6608, 1.7566, -1.0788], | |
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505], | |
..., | |
[ 0.8352, 0.9417, -0.3653, ..., -0.0158, -0.0074, 0.4276], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
..., | |
[[ 0.4102, -0.0038, -0.1229, ..., -0.2580, 1.4403, -0.2463], | |
[-0.9544, -0.8690, 0.8258, ..., -1.0243, 1.2432, 0.9506], | |
[ 2.0079, 0.2899, 0.1462, ..., 0.0117, 0.1548, 0.9144], | |
..., | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
[ 0.6819, 0.6503, -0.0784, ..., 1.2032, 0.4372, 0.0505], | |
[-0.6884, 0.6784, -1.0951, ..., -1.3010, 0.1661, -0.5777], | |
..., | |
[ 1.6601, -0.1799, 0.7201, ..., 0.6700, 0.4782, -0.1476], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.8417, 0.6482, 2.4525, ..., -0.0736, -1.3844, -1.5417], | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
[ 2.5504, -0.5124, -0.2840, ..., -0.0972, -0.6269, -0.4989], | |
..., | |
[ 0.0948, -1.3712, -1.2927, ..., 0.6679, 0.6076, 0.3466], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]], | |
device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-4.062912464141846, 3.9631638526916504) | |
mean: -0.0383220836520195 | |
std: 0.9961226582527161 | |
mamba_block1.inp_proj.weight gradient: 1.0743613529484719e-05 | |
mamba_block1.inp_proj.bias gradient: 1.3487114301824477e-05 | |
mamba_block1.out_proj.weight gradient: 9.076731657842174e-05 | |
mamba_block1.out_proj.bias gradient: 0.0021569449454545975 | |
mamba_block1.D.weight gradient: 2.2917021851753816e-05 | |
mamba_block1.D.bias gradient: 2.967903492390178e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 3.3216970223293174e-06 | |
mamba_block1.S6.fc1.bias gradient: 4.622655069397297e-06 | |
mamba_block1.S6.fc2.weight gradient: 2.1657657271134667e-05 | |
mamba_block1.S6.fc2.bias gradient: 3.360168557264842e-05 | |
mamba_block1.S6.fc3.weight gradient: 2.1000263586756773e-05 | |
mamba_block1.S6.fc3.bias gradient: 3.272424510214478e-05 | |
mamba_block1.conv.weight gradient: 5.6427117669954896e-05 | |
mamba_block1.conv.bias gradient: 9.617741852707695e-06 | |
mamba_block1.conv_linear.weight gradient: 2.0858946299995296e-05 | |
mamba_block1.conv_linear.bias gradient: 6.406073953257874e-05 | |
mamba_block1.norm.weight gradient: 7.72745261201635e-06 | |
mamba_block2.inp_proj.weight gradient: 0.009284038096666336 | |
mamba_block2.inp_proj.bias gradient: 0.0032744971103966236 | |
mamba_block2.out_proj.weight gradient: 0.009324286133050919 | |
mamba_block2.out_proj.bias gradient: 0.022122588008642197 | |
mamba_block2.D.weight gradient: 0.00526660680770874 | |
mamba_block2.D.bias gradient: 0.0018575439462438226 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.0011835441691800952 | |
mamba_block2.S6.fc1.bias gradient: 0.0014753035502508283 | |
mamba_block2.S6.fc2.weight gradient: 0.0032423110678792 | |
mamba_block2.S6.fc2.bias gradient: 0.004185546655207872 | |
mamba_block2.S6.fc3.weight gradient: 0.0030104692559689283 | |
mamba_block2.S6.fc3.bias gradient: 0.0038207483012229204 | |
mamba_block2.conv.weight gradient: 0.012847086414694786 | |
mamba_block2.conv.bias gradient: 0.0009625973762013018 | |
mamba_block2.conv_linear.weight gradient: 0.010192912071943283 | |
mamba_block2.conv_linear.bias gradient: 0.006669995374977589 | |
mamba_block2.norm.weight gradient: 0.003584688063710928 | |
mamba_block3.inp_proj.weight gradient: 0.10175695270299911 | |
mamba_block3.inp_proj.bias gradient: 0.03584553673863411 | |
mamba_block3.out_proj.weight gradient: 0.04716596007347107 | |
mamba_block3.out_proj.bias gradient: 7.46511830129748e-08 | |
mamba_block3.D.weight gradient: 0.04783427715301514 | |
mamba_block3.D.bias gradient: 0.016893167048692703 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.0046433014795184135 | |
mamba_block3.S6.fc1.bias gradient: 0.004779241979122162 | |
mamba_block3.S6.fc2.weight gradient: 0.016801055520772934 | |
mamba_block3.S6.fc2.bias gradient: 0.016340354457497597 | |
mamba_block3.S6.fc3.weight gradient: 0.017698541283607483 | |
mamba_block3.S6.fc3.bias gradient: 0.017167871817946434 | |
mamba_block3.conv.weight gradient: 0.17702537775039673 | |
mamba_block3.conv.bias gradient: 0.01801113784313202 | |
mamba_block3.conv_linear.weight gradient: 0.08223015069961548 | |
mamba_block3.conv_linear.bias gradient: 0.05006503686308861 | |
mamba_block3.norm.weight gradient: 0.02444547973573208 | |
DEBUGGING IS ON !!! | |
output = tensor([[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0035, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0033, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0277, 1.0139, 0.9942, ..., 1.0385, 1.0400, 0.9930], | |
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]], | |
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0035, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0033, 0.9993, ..., 1.0052, 1.0101, 0.9955], | |
[1.0276, 1.0140, 0.9941, ..., 1.0385, 1.0399, 0.9930], | |
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]], | |
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0035, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0101, 0.9955], | |
[1.0276, 1.0139, 0.9941, ..., 1.0385, 1.0400, 0.9930], | |
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]], | |
..., | |
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0035, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0101, 0.9955], | |
[1.0276, 1.0140, 0.9941, ..., 1.0385, 1.0400, 0.9930], | |
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]], | |
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0034, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0100, 0.9955], | |
[1.0276, 1.0140, 0.9941, ..., 1.0385, 1.0399, 0.9930], | |
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]], | |
[[1.0055, 1.0071, 0.9964, ..., 1.0038, 1.0035, 0.9950], | |
[1.0054, 1.0067, 0.9965, ..., 1.0037, 1.0035, 0.9951], | |
[1.0050, 1.0067, 0.9966, ..., 1.0030, 1.0012, 0.9947], | |
..., | |
[1.0126, 1.0032, 0.9993, ..., 1.0052, 1.0101, 0.9955], | |
[1.0276, 1.0140, 0.9941, ..., 1.0385, 1.0400, 0.9930], | |
[1.0155, 1.0010, 1.0025, ..., 1.0066, 1.0130, 0.9960]]], | |
device='cuda:0', grad_fn=<ViewBackward0>) | |
shape: (256, 100, 8) | |
min/max: (0.9627537727355957, 1.0400527715682983) | |
mean: 1.0009636878967285 | |
std: 0.008188863284885883 | |
target = tensor([[[ 0.7047, -0.1636, -1.4103, ..., 0.0981, 0.1269, 0.2884], | |
[-0.0752, -0.2943, -0.5152, ..., -1.0968, 0.3245, -0.6512], | |
[ 2.5271, 0.3828, 0.4464, ..., 0.1723, -0.5737, 2.5980], | |
..., | |
[-0.3420, -0.8499, 1.2070, ..., -0.0510, -0.3249, -0.1682], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-0.0061, 0.0731, 0.1958, ..., -0.5969, -0.9973, -2.2435], | |
[ 1.1250, 1.1728, 0.5452, ..., -1.0478, 0.4368, 1.5019], | |
[-1.0631, -0.8864, 0.3796, ..., -0.0263, -1.2731, -1.4496], | |
..., | |
[-0.1816, -0.4553, -1.1590, ..., -0.4902, -0.3588, 1.5264], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.9234, -0.9421, 1.0451, ..., -0.4781, 0.0420, 0.2934], | |
[-1.6287, -0.9663, -0.0027, ..., 1.2085, -0.9073, 0.0483], | |
[ 0.2567, 1.9508, 2.2958, ..., 0.1782, 0.4551, -1.1158], | |
..., | |
[ 0.8881, -0.0437, -1.5893, ..., -0.5971, -0.4100, 1.8774], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
..., | |
[[-0.6347, 1.8580, -0.0971, ..., 1.7939, 0.2032, -0.1249], | |
[-0.2951, -0.1044, -1.3054, ..., -0.6431, -0.4934, 0.8809], | |
[ 0.4335, 0.5802, 1.1978, ..., 1.2420, 0.1373, -0.8186], | |
..., | |
[-0.6884, 0.6784, -1.0951, ..., -1.3010, 0.1661, -0.5777], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[ 0.3804, -0.0668, -0.8042, ..., 0.1360, 0.0233, -1.5833], | |
[ 1.5597, -0.4917, -0.4323, ..., -1.5830, -0.4509, -0.0552], | |
[-0.6078, 0.9015, 0.9592, ..., -0.3502, -0.7853, 1.1148], | |
..., | |
[ 0.6637, 0.1741, -0.3558, ..., -1.4354, 1.1672, 0.0448], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], | |
[[-0.0729, -1.5876, -0.1719, ..., 1.2421, 0.7066, 0.4504], | |
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270], | |
[-1.7780, -0.3538, 0.5748, ..., -0.4030, -0.3248, -0.6199], | |
..., | |
[ 0.4974, 0.7906, -1.0028, ..., 1.2410, -1.4670, -1.0270], | |
[ 0.9189, 0.5545, 1.3656, ..., 0.7694, 1.3140, 0.2311], | |
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]], | |
device='cuda:0') | |
shape: (256, 100, 8) | |
min/max: (-3.77933669090271, 3.9631638526916504) | |
mean: -0.04762903228402138 | |
std: 1.0014238357543945 | |
mamba_block1.inp_proj.weight gradient: 8.953470569394995e-06 | |
mamba_block1.inp_proj.bias gradient: 1.7052547264029272e-05 | |
mamba_block1.out_proj.weight gradient: 9.578504250384867e-05 | |
mamba_block1.out_proj.bias gradient: 0.0022079788614064455 | |
mamba_block1.D.weight gradient: 2.5630950403865427e-05 | |
mamba_block1.D.bias gradient: 3.2377221941715106e-05 | |
mamba_block1.S6.A_log gradient: 0.0 | |
mamba_block1.S6.fc1.weight gradient: 3.6499291127256583e-06 | |
mamba_block1.S6.fc1.bias gradient: 5.3718267736257985e-06 | |
mamba_block1.S6.fc2.weight gradient: 2.8647153158090077e-05 | |
mamba_block1.S6.fc2.bias gradient: 4.695907045970671e-05 | |
mamba_block1.S6.fc3.weight gradient: 2.7555784981814213e-05 | |
mamba_block1.S6.fc3.bias gradient: 4.5302771468413994e-05 | |
mamba_block1.conv.weight gradient: 5.673680789186619e-05 | |
mamba_block1.conv.bias gradient: 1.0635959370119963e-05 | |
mamba_block1.conv_linear.weight gradient: 2.1541560272453353e-05 | |
mamba_block1.conv_linear.bias gradient: 7.382209150819108e-05 | |
mamba_block1.norm.weight gradient: 6.72381838739966e-06 | |
mamba_block2.inp_proj.weight gradient: 0.01039169728755951 | |
mamba_block2.inp_proj.bias gradient: 0.0036651124246418476 | |
mamba_block2.out_proj.weight gradient: 0.009969900362193584 | |
mamba_block2.out_proj.bias gradient: 0.022011322900652885 | |
mamba_block2.D.weight gradient: 0.005989375524222851 | |
mamba_block2.D.bias gradient: 0.0021124156191945076 | |
mamba_block2.S6.A_log gradient: 0.0 | |
mamba_block2.S6.fc1.weight gradient: 0.0013773165410384536 | |
mamba_block2.S6.fc1.bias gradient: 0.001692585414275527 | |
mamba_block2.S6.fc2.weight gradient: 0.0037395297549664974 | |
mamba_block2.S6.fc2.bias gradient: 0.004679436329752207 | |
mamba_block2.S6.fc3.weight gradient: 0.0034694750793278217 | |
mamba_block2.S6.fc3.bias gradient: 0.004271483514457941 | |
mamba_block2.conv.weight gradient: 0.013511927798390388 | |
mamba_block2.conv.bias gradient: 0.0010048352414742112 | |
mamba_block2.conv_linear.weight gradient: 0.011564495973289013 | |
mamba_block2.conv_linear.bias gradient: 0.007294516544789076 | |
mamba_block2.norm.weight gradient: 0.003915737848728895 | |
mamba_block3.inp_proj.weight gradient: 0.10254763066768646 | |
mamba_block3.inp_proj.bias gradient: 0.03612254932522774 | |
mamba_block3.out_proj.weight gradient: 0.048949189484119415 | |
mamba_block3.out_proj.bias gradient: 7.861819284471494e-08 | |
mamba_block3.D.weight gradient: 0.049613695591688156 | |
mamba_block3.D.bias gradient: 0.017519617453217506 | |
mamba_block3.S6.A_log gradient: 0.0 | |
mamba_block3.S6.fc1.weight gradient: 0.005079771392047405 | |
mamba_block3.S6.fc1.bias gradient: 0.005594966001808643 | |
mamba_block3.S6.fc2.weight gradient: 0.01893662102520466 | |
mamba_block3.S6.fc2.bias gradient: 0.0198878962546587 | |
mamba_block3.S6.fc3.weight gradient: 0.019843893125653267 | |
mamba_block3.S6.fc3.bias gradient: 0.02052515186369419 | |
mamba_block3.conv.weight gradient: 0.17641305923461914 | |
mamba_block3.conv.bias gradient: 0.01822531968355179 | |
mamba_block3.conv_linear.weight gradient: 0.08087118715047836 | |
mamba_block3.conv_linear.bias gradient: 0.057623837143182755 | |
mamba_block3.norm.weight gradient: 0.024983780458569527 | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment