Skip to content

Instantly share code, notes, and snippets.

@Advaitgaur004
Last active August 1, 2025 07:51
Show Gist options
  • Save Advaitgaur004/12d5a8a9c6833c65beac3e2cb9a60af6 to your computer and use it in GitHub Desktop.
Save Advaitgaur004/12d5a8a9c6833c65beac3e2cb9a60af6 to your computer and use it in GitHub Desktop.
Ctensor softmax testing in src/main.c
#include "cten.h"
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>
#include <time.h>
/* Forward declaration only: no definition of, or call to, _cten_malloc appears
 * in the visible part of this file.
 * NOTE(review): presumably supplied by the cten library -- confirm, and prefer
 * the declaration from the library header if one exists. */
void* _cten_malloc(size_t size);
/*
 * Integer identifiers for memory pools.
 *
 * PoolId_Default is the id handed to cten_begin_malloc() in main(); the other
 * two ids are not referenced in the visible part of this file.
 * The numeric values are 0, 1 and 2 (sequential from the first enumerator).
 */
enum MemoryPoolIds {
    PoolId_Default = 0, /* value 0 */
    PoolId_Model,       /* value 1 */
    PoolId_Optimizer    /* value 2 */
};
/*
 * Parameter set consumed by Model_forward(): two weight tensors and two bias
 * tensors, one weight/bias pair per nn_linear() call.
 * Member order is significant for positional initialisers; keep it as is.
 */
typedef struct Model {
    Tensor weight_1; /* weight passed to the first nn_linear() call */
    Tensor weight_2; /* weight passed to the second nn_linear() call */
    Tensor bias_1;   /* bias passed to the first nn_linear() call */
    Tensor bias_2;   /* bias passed to the second nn_linear() call */
} Model;
/*
 * Runs the model's forward pass on x.
 *
 * Pipeline: nn_linear(weight_1, bias_1) -> nn_relu -> nn_linear(weight_2, bias_2).
 *
 * model: holds the weight and bias tensors used by the two linear calls.
 * x:     input tensor.
 * Returns the tensor produced by the second nn_linear() call.
 */
Tensor Model_forward(Model* model, Tensor x) {
    Tensor first_linear = nn_linear(x, model->weight_1, model->bias_1);
    Tensor activated = nn_relu(first_linear);
    return nn_linear(activated, model->weight_2, model->bias_2);
}
/*
 * Asserts that tensors a and b agree element by element.
 *
 * Checks, in order:
 *   1. both tensors report the same element count (data->numel);
 *   2. their shapes match, via cten_assert_shape();
 *   3. every pair of elements in data->flex differs by strictly less than
 *      1e-3 in absolute value.
 *
 * Any failing check trips assert(), so nothing is verified when the program
 * is built with NDEBUG defined.
 */
void assert_equal(Tensor a, Tensor b) {
    const float epsilon = 1e-3;

    assert(a.data->numel == b.data->numel);
    cten_assert_shape("Shape not equal", a.shape, b.shape);

    int idx = 0;
    while (idx < a.data->numel) {
        /* Two-sided bound on the signed difference; a NaN difference fails
         * both comparisons and therefore fails the assertion. */
        assert((a.data->flex[idx] - b.data->flex[idx] > -epsilon) &&
               (a.data->flex[idx] - b.data->flex[idx] < epsilon));
        ++idx;
    }
}
/*
 * Builds a new tensor of the given shape and fills it from a flat array.
 *
 * shape:         shape forwarded to Tensor_new().
 * data:          source values; must provide at least as many elements as the
 *                new tensor's data->numel, since exactly that many are read.
 * requires_grad: forwarded unchanged to Tensor_new().
 *
 * Returns the tensor created by Tensor_new() with data->flex[0..numel-1]
 * overwritten by data[0..numel-1].
 */
Tensor create_tensor(TensorShape shape, float* data, bool requires_grad) {
    Tensor out = Tensor_new(shape, requires_grad);
    const int count = out.data->numel;
    int pos = 0;

    while (pos < count) {
        out.data->flex[pos] = data[pos];
        ++pos;
    }
    return out;
}
int main() {
cten_initilize();
cten_begin_malloc(PoolId_Default);
//dim1 simple softmax
float a_data[] = { 1,2,3,4 };
Tensor a = create_tensor((TensorShape) { 4 }, a_data, true);
int last_dim_a = TensorShape_dim(a.shape) - 1;
Tensor b = nn_softmax(a, last_dim_a);
Tensor c = Tensor_sum(b);
Tensor_backward(c, (Tensor) {NULL});
float b_data[] = { 0.0321, 0.0871, 0.2369, 0.6439 };
Tensor b_answer = create_tensor((TensorShape) { 4 }, b_data, false);
assert_equal(b, b_answer);
Tensor a_grad_answer = Tensor_zeros((TensorShape) { 4 }, false);
assert_equal(a.node->grad, a_grad_answer);
//dim1 hard softmax
Tensor a1 = create_tensor((TensorShape) { 4 }, a_data, true);
int last_dim_a1 = TensorShape_dim(a1.shape) - 1;
Tensor b1 = nn_softmax(a1, last_dim_a1);
Tensor z1 = Tensor_mul(a1, b1);
Tensor c1 = Tensor_sum(z1);
Tensor_backward(c1, (Tensor) { NULL });
float b_grad_data[] = { 1,2,3,4 };
Tensor b1_grad_answer = create_tensor((TensorShape) { 4 }, b_grad_data, false);
assert_equal(b1.node->grad, b1_grad_answer);
float a_grad_data[] = { -0.0479, -0.0429, 0.1202, 0.9706 };
Tensor a1_grad_answer = create_tensor((TensorShape) { 4 }, a_grad_data, false);
assert_equal(a1.node->grad, a1_grad_answer);
//dim2 simple softmax
float dim2_a_data[] = { 1,2,3,4,5,6 };
Tensor a2 = create_tensor((TensorShape) { 2,3 }, dim2_a_data, true);
int last_dim_a2 = TensorShape_dim(a2.shape) - 1;
Tensor b2 = nn_softmax(a2, last_dim_a2);
Tensor c2 = Tensor_sum(b2);
Tensor_backward(c2, (Tensor) { NULL });
float dim2_b_data[] = { 0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652 };
Tensor b2_answer = create_tensor((TensorShape) { 2,3 }, dim2_b_data, false);
assert_equal(b2, b2_answer);
Tensor a2_grad_answer = Tensor_zeros((TensorShape) { 2,3 }, false);
assert_equal(a2.node->grad, a2_grad_answer);
//dim2 hard softmax
Tensor a3 = create_tensor((TensorShape) { 2,3 }, dim2_a_data, true);
int last_dim_a3 = TensorShape_dim(a3.shape) - 1;
Tensor b3 = nn_softmax(a3, last_dim_a3);
Tensor z3 = Tensor_mul(a3, b3);
Tensor c3 = Tensor_sum(z3);
Tensor_backward(c3, (Tensor) { NULL });
float dim2_b_grad_data[] = { 1,2,3,4,5,6 };
Tensor b3_grad_answer = create_tensor((TensorShape) { 2,3 }, dim2_b_grad_data, false);
assert_equal(b3.node->grad, b3_grad_answer);
float dim2_a_grad_data[] = { -0.0518, 0.1040, 0.9478, -0.0518, 0.1040, 0.9478 };
Tensor a3_grad_answer = create_tensor((TensorShape) { 2,3 }, dim2_a_grad_data, false);
assert_equal(a3.node->grad, a3_grad_answer);
//dim3 simple softmax
float dim3_a_data[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 };
Tensor a4 = create_tensor((TensorShape) { 2,3,4 }, dim3_a_data, true);
int last_dim_a4 = TensorShape_dim(a4.shape) - 1;
Tensor b4 = nn_softmax(a4, last_dim_a4);
Tensor c4 = Tensor_sum(b4);
Tensor_backward(c4, (Tensor) { NULL });
float dim3_b_data[] = { 0.0321,0.0871,0.2369,0.6439,0.0321,0.0871,0.2369,0.6439,
0.0321,0.0871,0.2369,0.6439,0.0321,0.0871,0.2369,0.6439,
0.0321,0.0871,0.2369,0.6439,0.0321,0.0871,0.2369,0.6439 };
Tensor b4_answer = create_tensor((TensorShape) { 2,3,4 }, dim3_b_data, false);
assert_equal(b4, b4_answer);
Tensor a4_grad_answer = Tensor_zeros((TensorShape) { 2,3,4 }, false);
assert_equal(a4.node->grad, a4_grad_answer);
//dim3 hard softmax
Tensor a5 = create_tensor((TensorShape) { 2,3,4 }, dim3_a_data, true);
int last_dim_a5 = TensorShape_dim(a5.shape) - 1;
Tensor b5 = nn_softmax(a5, last_dim_a5);
Tensor z5 = Tensor_mul(a5, b5);
Tensor c5 = Tensor_sum(z5);
Tensor_backward(c5, (Tensor) { NULL });
float dim3_b_grad_data[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 };
Tensor b5_grad_answer = create_tensor((TensorShape) { 2,3,4 }, dim3_b_grad_data, false);
assert_equal(b5.node->grad, b5_grad_answer);
float dim3_a_grad_data[] = { -0.0479,-0.0429,0.1202,0.9706,-0.0479,-0.0429,0.1202,0.9706,
-0.0479,-0.0429,0.1202,0.9706,-0.0479,-0.0429,0.1202,0.9706,
-0.0479,-0.0429,0.1202,0.9706,-0.0479,-0.0429,0.1202,0.9706 };
Tensor a5_grad_answer = create_tensor((TensorShape) { 2,3,4 }, dim3_a_grad_data, false);
assert_equal(a5.node->grad, a5_grad_answer);
// (FORWARD PASS)
// --- Test Case 1: Shape=[6], Dim=0 ---
TensorShape shape_1 = { 6, 0, 0, 0 };
int dim_1 = 0;
float input_data_1[] = {
-1.211608f, -0.344907f, 1.322972f, -0.427439f, -1.496399f, 0.685451f
};
float expected_output_1[] = {
0.039064f, 0.092935f, 0.492638f, 0.085572f, 0.029383f, 0.260409f
};
Tensor t_input_1 = create_tensor(shape_1, input_data_1, false);
Tensor t_output_1 = nn_softmax(t_input_1, dim_1);
Tensor t_expected_1 = create_tensor(shape_1, expected_output_1, false);
assert_equal(t_output_1, t_expected_1);
// --- Test Case 2: Shape=[4, 5], Dim=0 ---
TensorShape shape_2 = { 4, 5, 0, 0 };
int dim_2 = 0;
float input_data_2[] = {
-0.973596f, -0.593090f, 0.240839f, 0.778621f, -0.619067f, 1.254894f, -0.395984f, -1.496162f,
0.154189f, 0.167212f, 0.130392f, -0.652786f, 0.904415f, -0.958059f, -1.114413f, -0.863093f,
-0.971413f, 0.044263f, -0.548822f, 0.153366f
};
float expected_output_2[] = {
0.069354f, 0.260082f, 0.253852f, 0.505862f, 0.167515f, 0.644007f, 0.316748f, 0.044690f,
0.270922f, 0.367732f, 0.209183f, 0.245011f, 0.492909f, 0.089084f, 0.102077f, 0.077457f,
0.178159f, 0.208549f, 0.134131f, 0.362676f
};
Tensor t_input_2 = create_tensor(shape_2, input_data_2, false);
Tensor t_output_2 = nn_softmax(t_input_2, dim_2);
Tensor t_expected_2 = create_tensor(shape_2, expected_output_2, false);
assert_equal(t_output_2, t_expected_2);
// --- Test Case 3: Shape=[4, 5], Dim=1 ---
TensorShape shape_3 = { 4, 5, 0, 0 };
int dim_3 = 1;
float input_data_3[] = {
-0.635548f, 1.540740f, -0.242792f, 0.434214f, -0.923128f, 0.454926f, 0.762836f, 0.451787f,
-0.751669f, 2.009526f, 0.965157f, -0.943850f, 0.146616f, 0.078624f, -0.412194f, 0.908799f,
-0.616340f, 0.549718f, 0.507248f, 0.581007f
};
float expected_output_3[] = {
0.066848f, 0.589165f, 0.099006f, 0.194840f, 0.050141f, 0.119192f, 0.162170f, 0.118818f,
0.035664f, 0.564157f, 0.443729f, 0.065773f, 0.195717f, 0.182852f, 0.111929f, 0.302508f,
0.065823f, 0.211246f, 0.202463f, 0.217961f
};
Tensor t_input_3 = create_tensor(shape_3, input_data_3, false);
Tensor t_output_3 = nn_softmax(t_input_3, dim_3);
Tensor t_expected_3 = create_tensor(shape_3, expected_output_3, false);
assert_equal(t_output_3, t_expected_3);
// --- Test Case 5: Shape=[3, 4, 5], Dim=1 ---
TensorShape shape_5 = { 3, 4, 5, 0 };
int dim_5 = 1;
float input_data_5[] = {
0.329450f, -0.516812f, -0.198207f, 0.948730f, -1.939420f, -0.845908f, -1.289681f, 0.117924f,
-1.209618f, -0.385438f, -1.246859f, -0.602803f, 0.499607f, 1.421608f, -0.082139f, -0.354326f,
1.332360f, -1.326013f, -1.678908f, 0.702517f, -0.563351f, -0.613228f, 1.338767f, -1.208467f,
0.090329f, 0.869903f, -0.215624f, 0.413125f, 0.446294f, -0.154964f, 0.473783f, -0.962919f,
-0.257307f, -0.566939f, 0.528619f, -0.467904f, -0.606036f, 1.747805f, 0.583291f, -0.186993f,
-1.803139f, 2.192201f, 1.364973f, 0.921052f, -0.799728f, 1.227652f, 1.203330f, 0.735748f,
-0.198611f, 0.797526f, -1.293098f, -0.937443f, 1.848753f, -0.077266f, 0.413816f, -2.428300f,
0.277431f, -0.883364f, -1.196831f, 1.152895f
};
float expected_output_5[] = {
0.495012f, 0.114497f, 0.212544f, 0.358119f, 0.038202f, 0.152814f, 0.052862f, 0.291570f,
0.041368f, 0.180704f, 0.102337f, 0.105063f, 0.427076f, 0.574639f, 0.244730f, 0.249836f,
0.727578f, 0.068810f, 0.025874f, 0.536364f, 0.109726f, 0.238074f, 0.322130f, 0.070765f,
0.244478f, 0.460007f, 0.354314f, 0.127653f, 0.370233f, 0.191298f, 0.309551f, 0.167820f,
0.065293f, 0.134410f, 0.378955f, 0.120715f, 0.239792f, 0.484925f, 0.424592f, 0.185268f,
0.041816f, 0.639752f, 0.306676f, 0.550911f, 0.061155f, 0.866166f, 0.237985f, 0.163460f,
0.179812f, 0.302070f, 0.069639f, 0.027979f, 0.497488f, 0.203010f, 0.205809f, 0.022379f,
0.094284f, 0.032377f, 0.066267f, 0.430966f
};
Tensor t_input_5 = create_tensor(shape_5, input_data_5, false);
Tensor t_output_5 = nn_softmax(t_input_5, dim_5);
Tensor t_expected_5 = create_tensor(shape_5, expected_output_5, false);
assert_equal(t_output_5, t_expected_5);
// --- Test Case 6: Shape=[3, 4, 5], Dim=2 ---
TensorShape shape_6 = { 3, 4, 5, 0 };
int dim_6 = 2;
float input_data_6[] = {
-0.756514f, -1.056839f, 1.597832f, -2.225006f, -0.249725f, 0.087490f, 1.345537f, -0.307392f,
0.428569f, 0.085686f, 0.741215f, 0.588628f, -0.053020f, -0.708825f, -0.952247f, 1.000654f,
1.107366f, 0.641223f, 0.423863f, -1.611730f, -1.041236f, 0.094752f, -2.314693f, -0.517261f,
0.995701f, 1.425699f, 1.413595f, 0.511550f, -0.910978f, 2.286082f, -0.168117f, 0.493978f,
-1.621782f, 1.217807f, -0.100802f, 0.590658f, 0.337501f, 1.765542f, 0.486405f, 1.376526f,
-1.236085f, 0.626381f, -1.855668f, -0.734225f, 0.681713f, 0.901065f, -2.439517f, -0.263396f,
-1.291715f, 0.858222f, 0.146793f, 0.921066f, -0.744419f, -1.882994f, 0.447125f, 1.785266f,
-0.253973f, -0.175359f, -1.915487f, -0.002335f
};
float expected_output_6[] = {
0.070611f, 0.052293f, 0.743625f, 0.016260f, 0.117211f, 0.131632f, 0.463151f, 0.088688f,
0.185135f, 0.131394f, 0.366454f, 0.314595f, 0.165611f, 0.085956f, 0.067384f, 0.290213f,
0.322895f, 0.202590f, 0.163012f, 0.021290f, 0.072728f, 0.226492f, 0.020354f, 0.122817f,
0.557609f, 0.206205f, 0.203724f, 0.082659f, 0.019929f, 0.487482f, 0.121349f, 0.235277f,
0.028361f, 0.485215f, 0.129798f, 0.123313f, 0.095734f, 0.399260f, 0.111105f, 0.270588f,
0.060844f, 0.391810f, 0.032744f, 0.100501f, 0.414100f, 0.413706f, 0.014652f, 0.129114f,
0.046172f, 0.396356f, 0.197595f, 0.428588f, 0.081045f, 0.025957f, 0.266815f, 0.683544f,
0.088948f, 0.096223f, 0.016887f, 0.114399f
};
Tensor t_input_6 = create_tensor(shape_6, input_data_6, false);
Tensor t_output_6 = nn_softmax(t_input_6, dim_6);
Tensor t_expected_6 = create_tensor(shape_6, expected_output_6, false);
assert_equal(t_output_6, t_expected_6);
// --- Test Case 7: Shape=[2, 3, 4, 5], Dim=0 ---
TensorShape shape_7 = { 2, 3, 4, 5 };
int dim_7 = 0;
float input_data_7[] = {
-0.511257f, -0.486208f, 1.811264f, -1.887858f, 0.125982f, -1.844349f, -0.887010f, 0.964408f,
-0.424253f, -0.920038f, 0.508277f, 0.178656f, -0.036295f, 1.510589f, 0.413179f, 1.068661f,
-2.145023f, 0.378915f, 1.998491f, -2.765419f, 1.537244f, -0.811624f, -1.216795f, 1.327383f,
-0.726409f, 0.241636f, 1.059124f, 1.986931f, -0.199572f, -1.392530f, -1.172699f, 0.018050f,
-0.371864f, 1.119579f, -0.214250f, -0.692109f, 1.547399f, 0.853238f, 0.812299f, -0.275152f,
-1.434234f, 1.395362f, -0.498836f, -0.007882f, -1.192756f, 1.306992f, -0.337672f, 2.039225f,
1.120167f, 0.087211f, 1.003423f, -0.288729f, -0.824022f, -0.524947f, -0.489327f, 0.323384f,
3.372605f, -0.329193f, 0.681299f, -1.262865f, -0.836464f, 1.061926f, -0.135600f, -0.231079f,
-1.152552f, 0.053548f, -0.834973f, -1.011236f, 0.081576f, -0.580494f, -0.799639f, -0.602497f,
2.375473f, 0.757348f, 0.427160f, 0.556822f, 0.972614f, 0.664390f, 0.395206f, -1.041777f,
1.490402f, -0.157482f, 2.521987f, 0.666604f, 0.563484f, 2.161451f, 0.480465f, -0.581542f,
0.612467f, 2.106347f, -0.821566f, 0.057422f, -0.882044f, -0.337474f, 0.275213f, 0.519093f,
0.629499f, -0.058555f, -1.126014f, -1.360813f, -2.175084f, -2.276683f, -1.839429f, -0.702685f,
-0.197460f, 0.301871f, 2.437877f, -1.052195f, 2.052355f, 0.165766f, -0.066026f, 0.735303f,
0.877751f, -0.804196f, -1.014561f, -0.524483f, 0.728983f, -0.361918f, 0.156797f, -1.103459f
};
float expected_output_7[] = {
0.580593f, 0.175356f, 0.875104f, 0.160195f, 0.782200f, 0.130347f, 0.486994f, 0.878216f,
0.376172f, 0.415920f, 0.787164f, 0.685929f, 0.082280f, 0.679884f, 0.496505f, 0.625237f,
0.042386f, 0.429112f, 0.832477f, 0.151403f, 0.511708f, 0.342057f, 0.023231f, 0.659435f,
0.215871f, 0.127882f, 0.640759f, 0.928805f, 0.307456f, 0.029344f, 0.413108f, 0.490158f,
0.624849f, 0.811082f, 0.380020f, 0.229488f, 0.714614f, 0.713367f, 0.874167f, 0.747564f,
0.677182f, 0.975206f, 0.792587f, 0.667035f, 0.269867f, 0.732064f, 0.058660f, 0.956537f,
0.282481f, 0.480372f, 0.744492f, 0.264243f, 0.154234f, 0.569362f, 0.628371f, 0.700119f,
0.933617f, 0.508181f, 0.628200f, 0.460233f, 0.419407f, 0.824644f, 0.124896f, 0.839805f,
0.217800f, 0.869653f, 0.513006f, 0.121784f, 0.623828f, 0.584080f, 0.212836f, 0.314071f,
0.917720f, 0.320115f, 0.503495f, 0.374763f, 0.957614f, 0.570888f, 0.167523f, 0.848597f,
0.488292f, 0.657943f, 0.976769f, 0.340565f, 0.784129f, 0.872118f, 0.359241f, 0.071195f,
0.692544f, 0.970656f, 0.586892f, 0.509842f, 0.375151f, 0.188918f, 0.619980f, 0.770512f,
0.285386f, 0.286633f, 0.125833f, 0.252436f, 0.322818f, 0.024794f, 0.207413f, 0.332965f,
0.730133f, 0.267936f, 0.941340f, 0.043463f, 0.717519f, 0.519628f, 0.255508f, 0.735757f,
0.845766f, 0.430638f, 0.371629f, 0.299881f, 0.066383f, 0.491819f, 0.371800f, 0.539767f
};
Tensor t_input_7 = create_tensor(shape_7, input_data_7, false);
Tensor t_output_7 = nn_softmax(t_input_7, dim_7);
Tensor t_expected_7 = create_tensor(shape_7, expected_output_7, false);
assert_equal(t_output_7, t_expected_7);
// --- Test Case 8: Shape=[2, 3, 4, 5], Dim=1 ---
TensorShape shape_8 = { 2, 3, 4, 5 };
int dim_8 = 1;
float input_data_8[] = {
-1.257495f, -0.301707f, -1.078206f, -0.869783f, -0.487602f, -0.677509f, 0.162030f, -0.905193f,
-0.224179f, 0.913948f, 0.087469f, 0.984896f, -0.373979f, 0.421240f, 1.695008f, -0.253288f,
-2.464733f, 0.410508f, 0.693939f, 0.704088f, -0.715867f, -0.383032f, -0.120659f, 0.230918f,
-0.320296f, 0.395113f, 0.191242f, -0.567990f, 0.270543f, -0.721094f, 0.464662f, -0.195229f,
1.703504f, 0.929006f, -0.398823f, -1.703876f, -0.537868f, 0.680445f, 0.548786f, 0.262192f,
0.891745f, 1.893878f, -1.034683f, 0.883484f, -0.858102f, 0.530082f, -0.637206f, -0.330277f,
0.298938f, 0.188639f, 0.477610f, 0.458072f, 1.039523f, 0.251319f, -0.154824f, 0.237238f,
0.253144f, 1.160356f, 0.298290f, -0.897417f, -2.845256f, 0.800772f, -0.719587f, 0.205090f,
0.121262f, -0.199399f, -0.823564f, 0.250895f, -0.711147f, 0.262902f, 0.009090f, 1.922074f,
1.151204f, -0.220231f, 0.648669f, 0.130951f, -1.111660f, -1.108665f, 0.972070f, 1.388218f,
1.991043f, 0.932965f, -0.278309f, 1.057666f, 0.982868f, -0.582407f, 0.422493f, -1.071779f,
-1.034922f, 0.807841f, -0.172142f, -0.949420f, -0.186462f, 1.273863f, -0.234617f, 0.575369f,
-0.792975f, -2.354831f, 0.391219f, 1.058017f, -1.233542f, 2.169375f, -0.878724f, -0.985356f,
0.411549f, -0.461113f, -2.248060f, -0.861590f, 0.260213f, 0.978174f, 1.910362f, 1.583116f,
0.565735f, 0.682193f, 2.205384f, -0.690529f, 0.997477f, 1.452847f, 1.229271f, -0.092104f
};
float expected_output_8[] = {
0.088518f, 0.091683f, 0.215064f, 0.102253f, 0.348128f, 0.137581f, 0.403337f, 0.239349f,
0.231090f, 0.595550f, 0.254105f, 0.526949f, 0.076367f, 0.285282f, 0.780959f, 0.348723f,
0.043448f, 0.225909f, 0.393990f, 0.542179f, 0.152145f, 0.084522f, 0.560305f, 0.307399f,
0.411528f, 0.402154f, 0.415293f, 0.335333f, 0.378997f, 0.116099f, 0.370533f, 0.161900f,
0.609740f, 0.474017f, 0.096225f, 0.081752f, 0.298398f, 0.295914f, 0.340758f, 0.348522f,
0.759337f, 0.823795f, 0.224631f, 0.590348f, 0.240344f, 0.460265f, 0.181370f, 0.425318f,
0.389913f, 0.288351f, 0.375362f, 0.311151f, 0.313893f, 0.240701f, 0.122816f, 0.569525f,
0.658154f, 0.478177f, 0.265251f, 0.109299f, 0.007575f, 0.164713f, 0.293464f, 0.273990f,
0.212596f, 0.407907f, 0.211989f, 0.626895f, 0.229098f, 0.209677f, 0.117252f, 0.565251f,
0.549661f, 0.126249f, 0.162426f, 0.333404f, 0.094193f, 0.070210f, 0.350542f, 0.513784f,
0.954464f, 0.187991f, 0.456246f, 0.642694f, 0.503205f, 0.278114f, 0.737001f, 0.167019f,
0.165732f, 0.361588f, 0.097816f, 0.032001f, 0.144263f, 0.562476f, 0.067151f, 0.519971f,
0.129546f, 0.020193f, 0.196101f, 0.369297f, 0.037961f, 0.647296f, 0.250289f, 0.083316f,
0.284200f, 0.313979f, 0.051011f, 0.206086f, 0.605170f, 0.428735f, 0.784932f, 0.402748f,
0.306076f, 0.311275f, 0.770423f, 0.146625f, 0.776261f, 0.909597f, 0.453357f, 0.116919f
};
Tensor t_input_8 = create_tensor(shape_8, input_data_8, false);
Tensor t_output_8 = nn_softmax(t_input_8, dim_8);
Tensor t_expected_8 = create_tensor(shape_8, expected_output_8, false);
assert_equal(t_output_8, t_expected_8);
// --- Test Case 9: Shape=[2, 3, 4, 5], Dim=2 ---
TensorShape shape_9 = { 2, 3, 4, 5 };
int dim_9 = 2;
float input_data_9[] = {
-1.236507f, 0.666192f, 2.592585f, 1.202181f, 0.778915f, -0.988539f, -1.170790f, -0.664889f,
-0.133236f, 0.807191f, 0.587307f, 1.008587f, -0.632058f, -0.340653f, -0.811564f, -0.874503f,
-1.419854f, -0.396616f, 0.403904f, 0.289703f, 0.134876f, -0.566567f, -1.015988f, 0.774370f,
-0.585283f, 1.023082f, -3.720424f, 0.528280f, 0.629004f, 0.682999f, 1.163303f, -0.648755f,
-0.093579f, -0.965785f, 0.615421f, -0.579691f, 0.879971f, -1.569465f, 0.572907f, 0.395161f,
0.707526f, -0.710514f, 0.541310f, -1.303612f, -0.096734f, -0.719989f, 0.291270f, 2.981184f,
-0.085921f, -1.003435f, 1.393750f, -1.186577f, -0.097973f, 0.483390f, -0.115178f, 1.874009f,
0.961650f, -0.905726f, 1.063592f, -1.107064f, -0.975180f, 0.405832f, 0.111302f, 0.681526f,
0.707183f, 0.418216f, -1.533945f, 0.452553f, -0.524506f, -0.095175f, 1.186117f, 0.202994f,
-1.421134f, -0.358176f, 0.012797f, 2.383064f, 0.235132f, -1.752441f, 0.091681f, 3.775636f,
-0.661333f, -0.440233f, -0.432518f, -2.818887f, -0.088851f, -1.265370f, 1.239664f, 0.049627f,
-0.196688f, -1.137375f, 1.724149f, -0.581866f, -0.415181f, 0.563136f, 0.445962f, -0.329975f,
-0.647561f, 1.692823f, 0.457386f, -0.711722f, -0.125222f, -0.381787f, 0.581898f, 0.475846f,
1.356747f, 1.357815f, 0.333107f, 0.230856f, -0.038589f, 0.075892f, -1.058006f, 0.743106f,
0.841856f, 1.096734f, 2.172618f, -0.493073f, 0.838112f, -0.760886f, 1.512828f, -0.775099f
};
float expected_output_9[] = {
0.100877f, 0.371500f, 0.886067f, 0.518961f, 0.351418f, 0.129266f, 0.059179f, 0.034101f,
0.136512f, 0.361496f, 0.624978f, 0.523190f, 0.035239f, 0.110941f, 0.071629f, 0.144880f,
0.046132f, 0.044594f, 0.233586f, 0.215458f, 0.148880f, 0.160974f, 0.113962f, 0.349928f,
0.094851f, 0.361892f, 0.006872f, 0.533861f, 0.302585f, 0.337171f, 0.416366f, 0.148273f,
0.286654f, 0.061410f, 0.315139f, 0.072862f, 0.683881f, 0.065522f, 0.286078f, 0.252839f,
0.155357f, 0.103434f, 0.075560f, 0.047577f, 0.363679f, 0.037271f, 0.281664f, 0.866793f,
0.160782f, 0.146873f, 0.308570f, 0.064255f, 0.039871f, 0.284109f, 0.357033f, 0.498802f,
0.550647f, 0.017777f, 0.507532f, 0.132415f, 0.023557f, 0.356732f, 0.359996f, 0.453031f,
0.042633f, 0.094901f, 0.051275f, 0.506409f, 0.135630f, 0.019111f, 0.204534f, 0.291240f,
0.077762f, 0.160173f, 0.021290f, 0.677007f, 0.300752f, 0.055832f, 0.251166f, 0.916967f,
0.072444f, 0.124290f, 0.083245f, 0.014150f, 0.278243f, 0.039598f, 0.666817f, 0.134818f,
0.194784f, 0.097511f, 0.787054f, 0.107876f, 0.084700f, 0.416428f, 0.474996f, 0.100904f,
0.101017f, 0.697236f, 0.374639f, 0.149250f, 0.154036f, 0.105145f, 0.306565f, 0.159258f,
0.273406f, 0.678730f, 0.214913f, 0.215808f, 0.095211f, 0.075952f, 0.060607f, 0.323834f,
0.397577f, 0.296313f, 0.618211f, 0.106627f, 0.356109f, 0.080050f, 0.449218f, 0.032431f
};
Tensor t_input_9 = create_tensor(shape_9, input_data_9, false);
Tensor t_output_9 = nn_softmax(t_input_9, dim_9);
Tensor t_expected_9 = create_tensor(shape_9, expected_output_9, false);
assert_equal(t_output_9, t_expected_9);
// --- Test Case 10: Shape=[2, 3, 4, 5], Dim=3 ---
TensorShape shape_10 = { 2, 3, 4, 5 };
int dim_10 = 3;
float input_data_10[] = {
-0.134424f, -0.480058f, -0.321428f, -0.705847f, 0.022239f, 0.145085f, 0.526791f, 0.674307f,
0.288172f, 0.167516f, -0.010255f, 1.812615f, -0.116002f, 1.495559f, 1.053778f, -0.978276f,
0.572919f, -0.345361f, 0.304904f, 0.770001f, 0.326273f, -0.870936f, -0.690999f, -0.568904f,
2.087554f, -0.191169f, 1.430569f, -1.059834f, 0.374943f, -0.534675f, -0.422606f, -1.609942f,
-0.922619f, 2.244947f, -0.097903f, 0.121999f, -0.149312f, -0.649967f, -1.045418f, -0.361698f,
-0.163933f, -0.846074f, -0.327313f, 0.547455f, 2.043521f, 0.119833f, -0.708708f, 1.374049f,
0.976319f, -0.034966f, 0.579193f, 0.462073f, -0.338565f, 0.910230f, 0.656344f, 1.844478f,
-0.310301f, -0.241161f, -1.556111f, -1.152509f, 1.286693f, 0.303133f, 0.007123f, 1.411441f,
2.773875f, -1.916589f, 0.379059f, -0.869618f, 0.754217f, 1.803710f, -1.470451f, 1.252759f,
-1.259440f, 1.148975f, -0.794163f, 0.289906f, -1.423498f, 1.268674f, 0.049445f, 0.448380f,
0.195606f, 0.325677f, 0.705436f, 0.019392f, -0.275582f, -0.821970f, -0.685265f, -1.903032f,
0.322047f, 0.067131f, 1.327628f, 0.337458f, 0.622310f, 0.498616f, -1.797895f, 0.285230f,
0.391524f, -1.209723f, 1.226868f, 0.358081f, 1.487907f, 0.660574f, -0.156783f, -0.198020f,
-0.855000f, -1.496671f, -0.828395f, 1.279418f, 0.645864f, -1.312350f, -0.903121f, 1.460771f,
-0.527595f, -1.623403f, -0.513170f, 0.545358f, -0.146408f, -0.963346f, -0.131965f, -0.910921f
};
float expected_output_10[] = {
0.234108f, 0.165695f, 0.194178f, 0.132206f, 0.273813f, 0.157756f, 0.231078f, 0.267808f,
0.182023f, 0.161335f, 0.064537f, 0.399455f, 0.058060f, 0.290919f, 0.187029f, 0.058987f,
0.278246f, 0.111077f, 0.212830f, 0.338860f, 0.126710f, 0.038271f, 0.045816f, 0.051766f,
0.737438f, 0.111706f, 0.565442f, 0.046862f, 0.196759f, 0.079231f, 0.056498f, 0.017234f,
0.034267f, 0.813830f, 0.078171f, 0.317244f, 0.241860f, 0.146599f, 0.098717f, 0.195581f,
0.074161f, 0.037491f, 0.062983f, 0.151053f, 0.674313f, 0.122651f, 0.053560f, 0.429902f,
0.288827f, 0.105061f, 0.210018f, 0.186806f, 0.083884f, 0.292431f, 0.226862f, 0.755602f,
0.087596f, 0.093867f, 0.025202f, 0.037733f, 0.138704f, 0.051872f, 0.038581f, 0.157132f,
0.613710f, 0.014387f, 0.142873f, 0.040988f, 0.207912f, 0.593840f, 0.030158f, 0.459283f,
0.037243f, 0.414007f, 0.059308f, 0.172431f, 0.031081f, 0.458870f, 0.135577f, 0.202041f,
0.189875f, 0.216251f, 0.316144f, 0.159199f, 0.118532f, 0.124100f, 0.142279f, 0.042099f,
0.389594f, 0.301928f, 0.426283f, 0.158370f, 0.210563f, 0.186064f, 0.018720f, 0.167332f,
0.186098f, 0.037526f, 0.429068f, 0.179977f, 0.523120f, 0.228715f, 0.101000f, 0.096920f,
0.050245f, 0.034806f, 0.067903f, 0.558854f, 0.296585f, 0.041851f, 0.066438f, 0.706390f,
0.096718f, 0.032330f, 0.098123f, 0.406015f, 0.203288f, 0.089809f, 0.206245f, 0.094643f
};
Tensor t_input_10 = create_tensor(shape_10, input_data_10, false);
Tensor t_output_10 = nn_softmax(t_input_10, dim_10);
Tensor t_expected_10 = create_tensor(shape_10, expected_output_10, false);
assert_equal(t_output_10, t_expected_10);
// Backward Pass
// --- Test Case 1: Shape=[6], Dim=0 ---
TensorShape shape_11 = { 6, 0, 0, 0 };
int dim_11 = 0;
float softmax_output_data_11[] = {
0.491293f, 0.290196f, 0.099553f, 0.041380f, 0.034337f, 0.043241f
};
float upstream_grad_data_11[] = {
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f
};
float expected_grad_11[] = {
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f
};
Tensor t_softmax_output_11 = create_tensor(shape_11, softmax_output_data_11, false);
Tensor t_upstream_grad_11 = create_tensor(shape_11, upstream_grad_data_11, false);
Tensor t_expected_grad_11 = create_tensor(shape_11, expected_grad_11, false);
Tensor t_softmax = nn_softmax(t_softmax_output_11, dim_11);
Tensor_backward(t_softmax, t_upstream_grad_11);
assert_equal(t_softmax.node->grad, t_expected_grad_11);
// --- Test Case 2: Shape=[4, 5], Dim=0 ---
TensorShape shape_22 = { 4, 5, 0, 0 };
int dim_22 = 0;
float softmax_output_data_2[] = {
0.259828f, 0.047797f, 0.277548f, 0.280166f, 0.185578f, 0.131360f, 0.137473f, 0.089268f,
0.040785f, 0.266524f, 0.194610f, 0.431466f, 0.113199f, 0.140672f, 0.308902f, 0.414202f,
0.383265f, 0.519984f, 0.538377f, 0.238995f
};
float upstream_grad_data_2[] = {
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f
};
float expected_grad_2[] = {
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f
};
Tensor t_softmax_output_22 = create_tensor(shape_22, softmax_output_data_2, false);
Tensor t_upstream_grad_22 = create_tensor(shape_22, upstream_grad_data_2, false);
Tensor t_expected_grad_22 = create_tensor(shape_22, expected_grad_2, false);
Tensor t_softmax_22 = nn_softmax(t_softmax_output_22, dim_22);
Tensor_backward(t_softmax, t_upstream_grad_22);
assert_equal(t_softmax.node->grad, t_expected_grad_22);
// --- Test Case 3: Shape=[4, 5], Dim=1 ---
TensorShape shape_33 = { 4, 5, 0, 0 };
int dim_33 = 1;
float softmax_output_data_3[] = {
0.281295f, 0.241706f, 0.252463f, 0.049890f, 0.174646f, 0.172607f, 0.073895f, 0.330703f,
0.359079f, 0.063717f, 0.308131f, 0.021912f, 0.235711f, 0.376043f, 0.058204f, 0.071114f,
0.383512f, 0.142979f, 0.187754f, 0.214642f
};
float upstream_grad_data_3[] = {
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f
};
float expected_grad_3[] = {
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f
};
Tensor t_softmax_output_33 = create_tensor(shape_33, softmax_output_data_3, false);
Tensor t_upstream_grad_33 = create_tensor(shape_33, upstream_grad_data_3, false);
Tensor t_expected_grad_33 = create_tensor(shape_33, expected_grad_3, false);
Tensor t_softmax_33 = nn_softmax(t_softmax_output_33, dim_33);
Tensor_backward(t_softmax, t_upstream_grad_33);
assert_equal(t_softmax.node->grad, t_expected_grad_33);
// --- Test Case 4: Shape=[3, 4, 5], Dim=0 ---
TensorShape shape_44 = { 3, 4, 5, 0 };
int dim_44 = 0;
float softmax_output_data_44[] = {
0.300884f, 0.460076f, 0.079808f, 0.188380f, 0.535354f, 0.090242f, 0.294035f, 0.029022f,
0.035824f, 0.898965f, 0.096504f, 0.737632f, 0.189031f, 0.509308f, 0.434179f, 0.458033f,
0.195747f, 0.209742f, 0.071491f, 0.216761f, 0.584676f, 0.350448f, 0.065936f, 0.223344f,
0.361413f, 0.644221f, 0.637309f, 0.854081f, 0.859709f, 0.015247f, 0.299948f, 0.168507f,
0.094690f, 0.174303f, 0.254352f, 0.198134f, 0.048436f, 0.084396f, 0.566486f, 0.164950f,
0.114440f, 0.189476f, 0.854256f, 0.588275f, 0.103233f, 0.265537f, 0.068656f, 0.116896f,
0.104467f, 0.085789f, 0.603549f, 0.093861f, 0.716278f, 0.316389f, 0.311469f, 0.343833f,
0.755817f, 0.705862f, 0.362023f, 0.618290f
};
float upstream_grad_data_4[] = {
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f
};
float expected_grad_4[] = {
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f
};
Tensor t_softmax_output_44 = create_tensor(shape_44, softmax_output_data_44, false);
Tensor t_upstream_grad_44 = create_tensor(shape_44, upstream_grad_data_4, false);
Tensor t_expected_grad_44 = create_tensor(shape_44, expected_grad_4, false);
Tensor t_softmax_44 = nn_softmax(t_softmax_output_44, dim_44);
Tensor_backward(t_softmax, t_upstream_grad_44);
assert_equal(t_softmax.node->grad, t_expected_grad_44);
// --- Test Case 5: Shape=[3, 4, 5], Dim=1 ---
TensorShape shape_55 = { 3, 4, 5, 0 };
int dim_55 = 1;
float softmax_output_data_55[] = {
0.213736f, 0.071841f, 0.493278f, 0.286840f, 0.112850f, 0.434872f, 0.138537f, 0.265375f,
0.182145f, 0.133135f, 0.103070f, 0.659365f, 0.101280f, 0.097179f, 0.500021f, 0.248322f,
0.130257f, 0.140068f, 0.433836f, 0.253994f, 0.491888f, 0.530057f, 0.069191f, 0.268460f,
0.311649f, 0.198072f, 0.174156f, 0.688571f, 0.316698f, 0.235733f, 0.201305f, 0.279905f,
0.094012f, 0.241565f, 0.223111f, 0.108736f, 0.015882f, 0.148226f, 0.173276f, 0.229507f,
0.111989f, 0.140138f, 0.177901f, 0.145886f, 0.395023f, 0.659009f, 0.336318f, 0.038804f,
0.182494f, 0.167875f, 0.200232f, 0.187889f, 0.192922f, 0.349600f, 0.325372f, 0.028770f,
0.335654f, 0.590373f, 0.322020f, 0.111729f
};
float upstream_grad_data_5[] = {
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f
};
float expected_grad_5[] = {
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
-0.000000f, -0.000000f, 0.000000f, 0.000000f, 0.000000f, -0.000000f, -0.000000f, 0.000000f,
0.000000f, 0.000000f, -0.000000f, -0.000000f, 0.000000f, 0.000000f, 0.000000f, -0.000000f,
-0.000000f, 0.000000f, 0.000000f, 0.000000f
};
Tensor t_softmax_output_55 = create_tensor(shape_55, softmax_output_data_55, false);
Tensor t_upstream_grad_55 = create_tensor(shape_55, upstream_grad_data_5, false);
Tensor t_expected_grad_55 = create_tensor(shape_55, expected_grad_5, false);
Tensor t_softmax_55 = nn_softmax(t_softmax_output_55, dim_55);
Tensor_backward(t_softmax_55, t_upstream_grad_55);
assert_equal(t_softmax_55.node->grad, t_expected_grad_55);
// --- Test Case 6: Shape=[3, 4, 5], Dim=2 ---
TensorShape shape_66 = { 3, 4, 5, 0 };
int dim_66 = 2;
float softmax_output_data_66[] = {
0.274572f, 0.327944f, 0.033119f, 0.192682f, 0.171683f, 0.329886f, 0.043537f, 0.133894f,
0.428686f, 0.063996f, 0.276949f, 0.062181f, 0.250740f, 0.281050f, 0.129080f, 0.316415f,
0.133936f, 0.131813f, 0.277588f, 0.140248f, 0.163021f, 0.158821f, 0.459108f, 0.069788f,
0.149263f, 0.025376f, 0.065427f, 0.082601f, 0.661098f, 0.165497f, 0.092270f, 0.146426f,
0.070573f, 0.609448f, 0.081284f, 0.203390f, 0.181041f, 0.153127f, 0.200254f, 0.262189f,
0.068847f, 0.058354f, 0.542534f, 0.274029f, 0.056236f, 0.212700f, 0.053144f, 0.120847f,
0.483704f, 0.129605f, 0.206116f, 0.459948f, 0.105415f, 0.144628f, 0.083892f, 0.379806f,
0.127531f, 0.188214f, 0.244900f, 0.059550f
};
float upstream_grad_data_6[] = {
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f
};
float expected_grad_6[] = {
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f
};
Tensor t_softmax_output_66 = create_tensor(shape_66, softmax_output_data_66, false);
Tensor t_upstream_grad_66 = create_tensor(shape_66, upstream_grad_data_6, false);
Tensor t_expected_grad_66 = create_tensor(shape_66, expected_grad_6, false);
Tensor t_softmax_66 = nn_softmax(t_softmax_output_66, dim_66);
Tensor_backward(t_softmax_66, t_upstream_grad_66);
assert_equal(t_softmax_66.node->grad, t_expected_grad_66);
// --- Test Case 7: Shape=[2, 3, 4, 5], Dim=0 ---
TensorShape shape_77 = { 2, 3, 4, 5 };
int dim_77 = 0;
float softmax_output_data_77[] = {
0.569625f, 0.125877f, 0.471346f, 0.500301f, 0.591379f, 0.292845f, 0.130029f, 0.866195f,
0.602569f, 0.597063f, 0.841322f, 0.162126f, 0.406428f, 0.866877f, 0.471768f, 0.615089f,
0.137583f, 0.670704f, 0.636926f, 0.389366f, 0.384487f, 0.651319f, 0.090905f, 0.821931f,
0.172317f, 0.785127f, 0.268635f, 0.526123f, 0.093560f, 0.311492f, 0.276829f, 0.195481f,
0.266013f, 0.279443f, 0.762119f, 0.471888f, 0.114459f, 0.624089f, 0.681086f, 0.256820f,
0.577499f, 0.522314f, 0.455161f, 0.100279f, 0.845585f, 0.194294f, 0.543203f, 0.569020f,
0.474847f, 0.273765f, 0.303435f, 0.703826f, 0.921722f, 0.701926f, 0.449861f, 0.601013f,
0.036246f, 0.620161f, 0.322930f, 0.856022f, 0.430375f, 0.874123f, 0.528654f, 0.499700f,
0.408621f, 0.707155f, 0.869971f, 0.133805f, 0.397431f, 0.402938f, 0.158678f, 0.837874f,
0.593572f, 0.133123f, 0.528232f, 0.384911f, 0.862417f, 0.329296f, 0.363074f, 0.610634f,
0.615513f, 0.348681f, 0.909096f, 0.178069f, 0.827683f, 0.214873f, 0.731365f, 0.473877f,
0.906440f, 0.688508f, 0.723171f, 0.804519f, 0.733987f, 0.720557f, 0.237881f, 0.528112f,
0.885541f, 0.375911f, 0.318914f, 0.743180f, 0.422501f, 0.477685f, 0.544839f, 0.899721f,
0.154414f, 0.805706f, 0.456797f, 0.430980f, 0.525153f, 0.726235f, 0.696565f, 0.296174f,
0.078278f, 0.298074f, 0.550139f, 0.398987f, 0.963754f, 0.379839f, 0.677070f, 0.143978f
};
float upstream_grad_data_77[] = {
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f
};
float expected_grad_77[] = {
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, -0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, -0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f
};
Tensor t_softmax_output_77 = create_tensor(shape_77, softmax_output_data_77, false);
Tensor t_upstream_grad_77 = create_tensor(shape_77, upstream_grad_data_77, false);
Tensor t_expected_grad_77 = create_tensor(shape_77, expected_grad_77, false);
Tensor t_softmax_77 = nn_softmax(t_softmax_output_77, dim_77);
Tensor_backward(t_softmax_77, t_upstream_grad_77);
assert_equal(t_softmax_77.node->grad, t_expected_grad_77);
// --- Test Case 8: Shape=[2, 3, 4, 5], Dim=1 ---
TensorShape shape_88 = { 2, 3, 4, 5 };
int dim_88 = 1;
float softmax_output_data_88[] = {
0.349329f, 0.468174f, 0.446961f, 0.105402f, 0.353047f, 0.247731f, 0.227469f, 0.284667f,
0.345296f, 0.315446f, 0.267886f, 0.386768f, 0.446913f, 0.150942f, 0.266561f, 0.462933f,
0.302654f, 0.082296f, 0.472487f, 0.324861f, 0.430109f, 0.470517f, 0.173953f, 0.380958f,
0.458720f, 0.374105f, 0.346718f, 0.266339f, 0.318630f, 0.459433f, 0.351063f, 0.358115f,
0.157742f, 0.671078f, 0.147977f, 0.020363f, 0.021255f, 0.822086f, 0.452001f, 0.456552f,
0.220562f, 0.061309f, 0.379086f, 0.513640f, 0.188233f, 0.378163f, 0.425812f, 0.448994f,
0.336075f, 0.225121f, 0.381051f, 0.255117f, 0.395346f, 0.177980f, 0.585462f, 0.516704f,
0.676090f, 0.095618f, 0.075512f, 0.218587f, 0.100579f, 0.208679f, 0.587373f, 0.037468f,
0.122700f, 0.586855f, 0.079803f, 0.348755f, 0.087391f, 0.544201f, 0.154357f, 0.195805f,
0.740805f, 0.489011f, 0.171004f, 0.098661f, 0.695403f, 0.564536f, 0.744950f, 0.145165f,
0.638400f, 0.394755f, 0.251423f, 0.197920f, 0.448250f, 0.279957f, 0.066754f, 0.514423f,
0.298992f, 0.173771f, 0.390687f, 0.100924f, 0.181796f, 0.484352f, 0.293591f, 0.673578f,
0.139546f, 0.187736f, 0.109743f, 0.504062f, 0.261021f, 0.396566f, 0.161204f, 0.764612f,
0.429050f, 0.133188f, 0.853443f, 0.136822f, 0.613616f, 0.282029f, 0.454956f, 0.703271f,
0.077399f, 0.026637f, 0.535405f, 0.227761f, 0.165051f, 0.247727f, 0.145307f, 0.350772f
};
float upstream_grad_data_88[] = {
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f
};
float expected_grad_88[] = {
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, -0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
-0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, -0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f
};
Tensor t_softmax_output_88 = create_tensor(shape_88, softmax_output_data_88, false);
Tensor t_upstream_grad_88 = create_tensor(shape_88, upstream_grad_data_88, false);
Tensor t_expected_grad_88 = create_tensor(shape_88, expected_grad_88, false);
Tensor t_softmax_88 = nn_softmax(t_softmax_output_88, dim_88);
Tensor_backward(t_softmax_88, t_upstream_grad_88);
assert_equal(t_softmax_88.node->grad, t_expected_grad_88);
// --- Test Case 9: Shape=[2, 3, 4, 5], Dim=2 ---
TensorShape shape_99 = { 2, 3, 4, 5 };
int dim_99 = 2;
float softmax_output_data_99[] = {
0.291880f, 0.125233f, 0.182044f, 0.498073f, 0.404432f, 0.169677f, 0.511194f, 0.562892f,
0.223964f, 0.301325f, 0.283817f, 0.230031f, 0.109788f, 0.121532f, 0.102276f, 0.254626f,
0.133543f, 0.145276f, 0.156432f, 0.191967f, 0.602902f, 0.167705f, 0.505784f, 0.006449f,
0.153610f, 0.213243f, 0.144857f, 0.216105f, 0.271212f, 0.187101f, 0.057155f, 0.034944f,
0.202119f, 0.621340f, 0.582470f, 0.126700f, 0.652493f, 0.075992f, 0.100999f, 0.076819f,
0.025284f, 0.535498f, 0.321159f, 0.547032f, 0.042252f, 0.055964f, 0.051368f, 0.214569f,
0.171173f, 0.068723f, 0.155868f, 0.315531f, 0.373057f, 0.090525f, 0.074880f, 0.762884f,
0.097602f, 0.091214f, 0.191271f, 0.814145f, 0.364001f, 0.375110f, 0.335413f, 0.776141f,
0.336847f, 0.021910f, 0.342610f, 0.216313f, 0.024883f, 0.461198f, 0.331734f, 0.255358f,
0.047576f, 0.170723f, 0.174370f, 0.282355f, 0.026922f, 0.400698f, 0.028253f, 0.027585f,
0.126244f, 0.380872f, 0.172078f, 0.081861f, 0.299249f, 0.453514f, 0.559052f, 0.085891f,
0.174117f, 0.077363f, 0.349205f, 0.027823f, 0.148069f, 0.249145f, 0.076060f, 0.071037f,
0.032253f, 0.593962f, 0.494877f, 0.547328f, 0.372310f, 0.493528f, 0.478135f, 0.109682f,
0.117111f, 0.478561f, 0.015935f, 0.303751f, 0.041923f, 0.087090f, 0.030095f, 0.353678f,
0.081628f, 0.497504f, 0.069155f, 0.119034f, 0.136859f, 0.136486f, 0.350890f, 0.726643f
};
float upstream_grad_data_9[] = {
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f
};
float expected_grad_9[] = {
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f
};
Tensor t_softmax_output_99 = create_tensor(shape_99, softmax_output_data_99, false);
Tensor t_upstream_grad_99 = create_tensor(shape_99, upstream_grad_data_9, false);
Tensor t_expected_grad_99 = create_tensor(shape_99, expected_grad_9, false);
Tensor t_softmax_99 = nn_softmax(t_softmax_output_99, dim_99);
Tensor_backward(t_softmax_99, t_upstream_grad_99);
assert_equal(t_softmax_99.node->grad, t_expected_grad_99);
// --- Test Case 10: Shape=[2, 3, 4, 5], Dim=3 ---
TensorShape shape_1010 = { 2, 3, 4, 5 };
int dim_1010 = 3;
float softmax_output_data_1010[] = {
0.066401f, 0.158825f, 0.325174f, 0.193684f, 0.255916f, 0.211079f, 0.241442f, 0.225608f,
0.269596f, 0.052274f, 0.098790f, 0.040214f, 0.116746f, 0.541744f, 0.202506f, 0.135448f,
0.114021f, 0.116338f, 0.496727f, 0.137465f, 0.108314f, 0.121477f, 0.275101f, 0.061351f,
0.433757f, 0.047564f, 0.289791f, 0.223113f, 0.371671f, 0.067861f, 0.035685f, 0.048276f,
0.131985f, 0.378139f, 0.405915f, 0.070004f, 0.062254f, 0.583989f, 0.100436f, 0.183317f,
0.305591f, 0.033705f, 0.279647f, 0.223692f, 0.157365f, 0.109443f, 0.196573f, 0.185457f,
0.190898f, 0.317629f, 0.212507f, 0.346504f, 0.137915f, 0.219323f, 0.083751f, 0.187271f,
0.066957f, 0.069592f, 0.567064f, 0.109115f, 0.042001f, 0.062575f, 0.132275f, 0.720434f,
0.042715f, 0.376084f, 0.302346f, 0.042666f, 0.191719f, 0.087185f, 0.117984f, 0.152122f,
0.220522f, 0.292303f, 0.217070f, 0.194241f, 0.051929f, 0.007380f, 0.015124f, 0.731326f,
0.243276f, 0.393607f, 0.231850f, 0.077948f, 0.053319f, 0.202010f, 0.274588f, 0.056622f,
0.054226f, 0.412553f, 0.129608f, 0.203751f, 0.300948f, 0.043970f, 0.321724f, 0.038599f,
0.337018f, 0.276438f, 0.029686f, 0.318260f, 0.810219f, 0.025352f, 0.018812f, 0.045306f,
0.100312f, 0.231369f, 0.149545f, 0.032563f, 0.462638f, 0.123885f, 0.029724f, 0.117840f,
0.011977f, 0.047307f, 0.793153f, 0.020167f, 0.540207f, 0.078490f, 0.079111f, 0.282025f
};
float upstream_grad_data_10[] = {
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f
};
float expected_grad_10[] = {
-0.000000f, -0.000000f, -0.000000f, -0.000000f, -0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, -0.000000f, -0.000000f, -0.000000f, -0.000000f, -0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f,
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f
};
Tensor t_softmax_output_1010 = create_tensor(shape_1010, softmax_output_data_1010, false);
Tensor t_upstream_grad_1010 = create_tensor(shape_1010, upstream_grad_data_10, false);
Tensor t_expected_grad_1010 = create_tensor(shape_1010, expected_grad_10, false);
Tensor t_softmax_1010 = nn_softmax(t_softmax_output_1010, dim_1010);
Tensor_backward(t_softmax_1010, t_upstream_grad_1010);
assert_equal(t_softmax_1010.node->grad, t_expected_grad_1010);
cten_end_malloc();
cten_finalize();
return 0;
}
/* GitHub page footer (scrape residue, not part of the source file): "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment" */