Created October 20, 2022 04:14
-
-
Save cycyyy/db2cd361bb96f275fca4ad11595060e6 to your computer and use it in GitHub Desktop.
tiling_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
from jax import random
from neural_tangents import stax
# Sizes shared by test_mlp and test_conv. 100 is a multiple of 25, so every
# tile produced by get_kernel_batch in those tests is full-size.
SAMPLE_SIZE = 100  # number of input points drawn per test
BATCH_SIZE = 25    # tile length along each kernel axis in the tiled run

# Fixed seed so every run draws the same random samples.
random_key = random.PRNGKey(42)
def get_mlp_kernel_fn():
    """Build the analytic kernel function of a fully connected ReLU network.

    Architecture (assembled with ``stax.serial``):
    Dense(512) -> Relu -> Dense(512) -> Relu -> Dense(1).

    Returns:
        The ``kernel_fn`` element of the ``(init_fn, apply_fn, kernel_fn)``
        triple returned by ``stax.serial``; the other two are discarded.
    """
    hidden = [stax.Dense(512), stax.Relu(), stax.Dense(512), stax.Relu()]
    _, _, kernel_fn = stax.serial(*hidden, stax.Dense(1))
    return kernel_fn
def get_conv_kernel_fn():
    """Build the analytic kernel function of a small convolutional network.

    Architecture (assembled with ``stax.serial``):
    Conv(128, 3x3) -> Relu -> Conv(256, 3x3) -> Relu -> Conv(512, 3x3)
    -> Flatten -> Dense(2). Note there is no Relu between the last Conv and
    Flatten.

    Returns:
        The ``kernel_fn`` element of the ``(init_fn, apply_fn, kernel_fn)``
        triple returned by ``stax.serial``; the other two are discarded.
    """
    conv_stack = [
        stax.Conv(128, (3, 3)),
        stax.Relu(),
        stax.Conv(256, (3, 3)),
        stax.Relu(),
        stax.Conv(512, (3, 3)),
    ]
    readout = [stax.Flatten(), stax.Dense(2)]
    _, _, kernel_fn = stax.serial(*conv_stack, *readout)
    return kernel_fn
def get_kernel(x1, x2, kernel_fn):
    """Evaluate ``kernel_fn`` on the full inputs in a single call.

    Both the NNGP and the NTK kernels are requested by passing the tuple
    ``('nngp', 'ntk')`` as the third positional argument. Whatever
    ``kernel_fn`` returns for that request is returned unchanged.
    """
    wanted = ('nngp', 'ntk')
    result = kernel_fn(x1, x2, wanted)
    return result
def get_kernel_batch(x1, x2, kernel_fn, x1_batch, x2_batch):
    """Compute the NNGP and NTK kernels between ``x1`` and ``x2`` tile by tile.

    Rows of the result are processed in consecutive slices of ``x1`` of
    length ``x1_batch`` and columns in slices of ``x2`` of length
    ``x2_batch``. ``kernel_fn`` is called once per (row tile, column tile)
    pair, and each returned tile is written into the matching block of the
    output. When a length is not a multiple of its batch size, the last tile
    along that axis is simply shorter.

    Args:
        x1, x2: inputs supporting ``len()`` and slicing along the first axis.
        kernel_fn: callable invoked as ``kernel_fn(a, b, ('nngp', 'ntk'))``
            that returns a pair ``(nngp, ntk)`` whose first two axes are
            ``(len(a), len(b))``.
        x1_batch, x2_batch: positive tile lengths along ``x1`` and ``x2``.

    Returns:
        ``(kernel_nngp, kernel_ntk)`` as float64 numpy arrays of shape
        ``(len(x1), len(x2)) + extra``. ``extra`` is any trailing axes the
        tiles carry beyond the first two; it is empty for 2-D tiles.

    Raises:
        ValueError: if ``x1_batch`` or ``x2_batch`` is not positive. A
            negative step would otherwise skip the loops and silently return
            zeros.
    """
    if x1_batch <= 0 or x2_batch <= 0:
        raise ValueError(
            "Batch sizes must be positive, got x1_batch=%r x2_batch=%r"
            % (x1_batch, x2_batch))
    n1, n2 = len(x1), len(x2)
    kernel_nngp = kernel_ntk = None
    for i in range(0, n1, x1_batch):
        # The row tile does not depend on j, so slice it once per row band.
        x1_tile = x1[i:i + x1_batch]
        for j in range(0, n2, x2_batch):
            x2_tile = x2[j:j + x2_batch]
            nngp_tile, ntk_tile = kernel_fn(x1_tile, x2_tile, ('nngp', 'ntk'))
            if kernel_nngp is None:
                # Allocate from the first tile so any trailing kernel axes
                # are preserved in the full output.
                kernel_nngp = np.zeros((n1, n2) + np.shape(nngp_tile)[2:])
                kernel_ntk = np.zeros((n1, n2) + np.shape(ntk_tile)[2:])
            # Slices past the end are clipped, so short final tiles fit.
            kernel_nngp[i:i + x1_batch, j:j + x2_batch] = nngp_tile
            kernel_ntk[i:i + x1_batch, j:j + x2_batch] = ntk_tile
    if kernel_nngp is None:
        # Empty input: no tile was computed, so return plain 2-D zeros.
        kernel_nngp = np.zeros((n1, n2))
        kernel_ntk = np.zeros((n1, n2))
    return kernel_nngp, kernel_ntk
def test_mlp():
    """Check that tiled MLP kernels match the kernels computed in one call.

    Draws ``SAMPLE_SIZE`` samples with 4 features from ``random_key``. It then
    computes the NNGP/NTK kernels once with ``get_kernel`` and once with
    ``get_kernel_batch`` using ``BATCH_SIZE`` x ``BATCH_SIZE`` tiles, and
    compares the two results elementwise.
    """
    print("test mlp")
    samples = random.normal(random_key, (SAMPLE_SIZE, 4))
    kernel_fn = get_mlp_kernel_fn()
    nngp_gt, ntk_gt = get_kernel(samples, samples, kernel_fn)
    nngp_tile, ntk_tile = get_kernel_batch(
        samples, samples, kernel_fn, BATCH_SIZE, BATCH_SIZE)
    # Same tolerance semantics as np.allclose's defaults (rtol=1e-5,
    # atol=1e-8, NaN never equal). Unlike a bare assert, np.testing reports
    # the mismatching entries and is not stripped under `python -O`.
    np.testing.assert_allclose(
        nngp_gt, nngp_tile, rtol=1e-5, atol=1e-8, equal_nan=False)
    np.testing.assert_allclose(
        ntk_gt, ntk_tile, rtol=1e-5, atol=1e-8, equal_nan=False)
def test_conv():
    """Check that tiled conv-net kernels match the kernels computed in one call.

    Draws ``SAMPLE_SIZE`` samples of shape ``(16, 16, 3)`` from
    ``random_key``. It then computes the NNGP/NTK kernels once with
    ``get_kernel`` and once with ``get_kernel_batch`` using
    ``BATCH_SIZE`` x ``BATCH_SIZE`` tiles, and compares the two results
    elementwise.
    """
    print("test conv")
    samples = random.normal(random_key, (SAMPLE_SIZE, 16, 16, 3))
    kernel_fn = get_conv_kernel_fn()
    nngp_gt, ntk_gt = get_kernel(samples, samples, kernel_fn)
    nngp_tile, ntk_tile = get_kernel_batch(
        samples, samples, kernel_fn, BATCH_SIZE, BATCH_SIZE)
    # Same tolerance semantics as np.allclose's defaults (rtol=1e-5,
    # atol=1e-8, NaN never equal). Unlike a bare assert, np.testing reports
    # the mismatching entries and is not stripped under `python -O`.
    np.testing.assert_allclose(
        nngp_gt, nngp_tile, rtol=1e-5, atol=1e-8, equal_nan=False)
    np.testing.assert_allclose(
        ntk_gt, ntk_tile, rtol=1e-5, atol=1e-8, equal_nan=False)
if __name__ == "__main__":
    # Run both checks only when executed as a script. Importing the module,
    # e.g. by a test runner that collects the test_* functions itself, no
    # longer runs them as a side effect, so they are not executed twice.
    test_mlp()
    test_conv()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment