Skip to content

Instantly share code, notes, and snippets.

View cgarciae's full-sized avatar

Cristian Garcia cgarciae

View GitHub Profile
@cgarciae
cgarciae / fasthtml_and_fastapi_integration.py
Last active August 3, 2024 15:04
FastHTMLResponse class for FastAPI
from typing import Iterable
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
import fasthtml.common as fasthtml
from fasthtml.components import *
# FastAPI application instance. FastHTMLResponse below is presumably intended
# as a `response_class` for its routes — confirm; no routes are visible in
# this excerpt.
app = FastAPI()
class FastHTMLResponse(HTMLResponse):
def render(self, content) -> bytes:
@cgarciae
cgarciae / print.txt
Last active August 25, 2021 20:01
Module print
MyModule:
b: list
- Initializer
- Parameter((5, 13), float32)
a: dict
mlps: list
- MLP:
linear1: Linear
w: Parameter((2, 3), float32)
b: Parameter((3,), float32)
@cgarciae
cgarciae / resnet_random.py
Created November 16, 2020 03:26
memory leak
# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------

# Model selection and where its artifacts are written.
MODEL: str = "ResNet50"
OUTPUT_DIRECTORY: str = "models/resnet50"

# Input data.
DATASET: str = "imagenet2012:5.1.*"  # TFDS "<name>:<version>" specifier
IMAGE_SIZE: int = 224  # input image size, in pixels
DTYPE: str = "float16"  # "float16" -> mixed precision, "float32" -> normal mode

# Optimisation schedule.
EPOCHS: int = 90
BATCH_SIZE: int = 6
# Learning rate scaled linearly with the batch size: 0.1 per 256 examples.
LEARNING_RATE: float = 0.1 * BATCH_SIZE / 256.0
MOMENTUM: float = 0.9
# Assemble the elegy Model from three pieces:
#   - the network: MixtureModel, built with `k` (presumably the number of
#     mixture components; `k` is defined elsewhere — confirm),
#   - the loss: MixtureNLL, the negative log of the probability-weighted sum
#     of normal densities (see MixtureNLL in this file),
#   - the optimizer: Adam with a learning rate of 3e-4.
model = elegy.Model(
    module=MixtureModel(k=k),
    loss=MixtureNLL(),
    optimizer=optax.adam(3e-4),
)
# Print a model summary, using the first `batch_size` training rows as sample
# input. depth=1 presumably limits how many levels of nested submodules are
# listed — confirm against elegy's Model.summary documentation.
model.summary(X_train[:batch_size], depth=1)
model.fit(
x=X_train,
class MixtureNLL(elegy.Loss):
def call(self, y_true, y_pred):
y, probs = y_pred
y_true = jnp.broadcast_to(y_true, (y_true.shape[0], y.shape[1]))
return -safe_log(
jnp.sum(
probs
* jax.scipy.stats.norm.pdf(y_true, loc=y[..., 0], scale=y[..., 1]),
axis=1,
class MixtureModel(elegy.Module):
def __init__(self, k: int):
super().__init__()
self.k = k
def call(self, x):
x = elegy.nn.Linear(64, name="backbone")(x)
x = jax.nn.relu(x)
import numpy as np
import jax
import jax.numpy as jnp
import elegy
import optax
import numpy as np
import jax
import jax.numpy as jnp
import elegy
import optax
@cgarciae
cgarciae / haiku-accuracy.py
Last active July 8, 2020 01:04
Simple cumulative accuracy metric using Haiku and JAX.
class Accuracy(hk.Module):
def __init__(self, y_true: jnp.ndarray, y_pred: jnp.ndarray) -> jnp.ndarray:
total = hk.get_state(
"total", shape=[], dtype=jnp.float32, init=hk.initializers.Constant(0)
)
count = hk.get_state(
"count", shape=[], dtype=jnp.int64, init=hk.initializers.Constant(0)
)
from aiohttp import ClientSession, TCPConnector
import asyncio
import sys
import pypeln as pl
# Upper bound handed to the HTTP client — presumably the connector's maximum
# number of simultaneous connections (TCPConnector is imported above); confirm
# at the use site.
limit = 1000

# Lazily produce one local URL per request. The number of requests comes from
# the first command-line argument, which is parsed immediately (the range is
# built when this line runs, not on first iteration).
urls = (f"http://localhost:8080/{idx}" for idx in range(int(sys.argv[1])))