Skip to content

Instantly share code, notes, and snippets.

View cgarciae's full-sized avatar

Cristian Garcia cgarciae

View GitHub Profile
import numpy as np
import jax
import jax.numpy as jnp
import elegy
import optax
import numpy as np
import jax
import jax.numpy as jnp
import elegy
import optax
class MixtureModel(elegy.Module):
def __init__(self, k: int):
super().__init__()
self.k = k
def call(self, x):
x = elegy.nn.Linear(64, name="backbone")(x)
x = jax.nn.relu(x)
class MixtureNLL(elegy.Loss):
def call(self, y_true, y_pred):
y, probs = y_pred
y_true = jnp.broadcast_to(y_true, (y_true.shape[0], y.shape[1]))
return -safe_log(
jnp.sum(
probs
* jax.scipy.stats.norm.pdf(y_true, loc=y[..., 0], scale=y[..., 1]),
axis=1,
model = elegy.Model(
module=MixtureModel(k=k),
loss=MixtureNLL(),
optimizer=optax.adam(3e-4),
)
model.summary(X_train[:batch_size], depth=1)
model.fit(
x=X_train,
@cgarciae
cgarciae / resnet_random.py
Created November 16, 2020 03:26
memory leak
# PARAMETERS
# Module-level configuration constants for this script.

# Model selection and where its artifacts are written.
MODEL: str = "ResNet50"
OUTPUT_DIRECTORY: str = "models/resnet50"

# Schedule: number of passes over the data and examples per step.
EPOCHS: int = 90
BATCH_SIZE: int = 6

# Input data.
IMAGE_SIZE: int = 224  # image size in pixels
DATASET: str = "imagenet2012:5.1.*"  # TFDS dataset name and version
DTYPE: str = "float16"  # float16 for mixed_precision or float32 for normal mode

# Optimization: a base rate of 0.1 per 256 examples, scaled linearly by BATCH_SIZE.
LEARNING_RATE: float = BATCH_SIZE * 0.1 / 256.0
MOMENTUM: float = 0.9
@cgarciae
cgarciae / print.txt
Last active August 25, 2021 20:01
Module print
MyModule:
b: list
- Initializer
- Parameter((5, 13), float32)
a: dict
mlps: list
- MLP:
linear1: Linear
w: Parameter((2, 3), float32)
b: Parameter((3,), float32)
@cgarciae
cgarciae / fasthtml_and_fastapi_integration.py
Last active August 3, 2024 15:04
FastHTMLResponse class for FastAPI
from typing import Iterable
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
import fasthtml.common as fasthtml
from fasthtml.components import *
# FastAPI application instance. Per this gist's description, the
# FastHTMLResponse class is meant to serve as a response class for FastAPI.
app = FastAPI()
class FastHTMLResponse(HTMLResponse):
def render(self, content) -> bytes: