Last active
August 9, 2025 13:52
-
-
Save kinoc/f3225092092e07b843e3a2798f7b3986 to your computer and use it in GitHub Desktop.
Simplest FastAPI endpoint for EleutherAI GPT-J-6B
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Near Simplest Language model API, with room to expand! | |
# Runs GPT-J-6B on an RTX 3090 or TITAN and serves it using FastAPI.
# change "seq" (which is the context size) to adjust footprint | |
# | |
# seq vram usage | |
# 512 14.7G | |
# 900 15.3G | |
# uses FastAPI, so install that | |
# https://fastapi.tiangolo.com/tutorial/ | |
# pip install fastapi | |
# pip install uvicorn[standard] | |
# uses https://github.com/kingoflolz/mesh-transformer-jax | |
# so install JAX on your system first; it is best to get JAX working with your GPU before continuing.
# !apt install zstd | |
# the "slim" version contains only bf16 weights and no optimizer parameters, which minimizes bandwidth and memory
# wget https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd | |
# tar -I zstd -xf step_383500_slim.tar.zstd | |
# git clone https://github.com/kingoflolz/mesh-transformer-jax.git | |
# pip install -r mesh-transformer-jax/requirements.txt | |
# jax 0.2.12 is required due to a regression with xmap in 0.2.13 | |
# pip install mesh-transformer-jax/ jax==0.2.12 | |
# I have CUDA 10.1 and Python 3.9, so I had to upgrade jaxlib:
# pip3 install --upgrade "https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.66+cuda101-cp39-none-manylinux2010_x86_64.whl" | |
# GO: local execution | |
# XLA_PYTHON_CLIENT_PREALLOCATE=false XLA_PYTHON_CLIENT_ALLOCATOR=platform CUDA_VISIBLE_DEVICES=0 python3 jserv.py | |
# When done try | |
# http://localhost:8000/docs#/default/read_completions_engines_completions_post | |
# now you are in FastAPI + EleutherAI land | |
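# Note: read_completions must be declared async, otherwise JAX gets upset (presumably because FastAPI runs plain `def` endpoints in a worker thread, which does not see the mesh assigned to maps.thread_resources.env on the main thread).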
# remember to adjust the location of the checkpoint image | |
import argparse | |
import time | |
from typing import Optional | |
from typing import Dict | |
from fastapi import FastAPI | |
import uvicorn | |
import os | |
import requests | |
import threading | |
import jax | |
from jax.experimental import maps | |
from jax.config import config | |
import numpy as np | |
import optax | |
import transformers | |
from mesh_transformer.checkpoint import read_ckpt | |
from mesh_transformer.sampling import nucleaus_sample | |
from mesh_transformer.transformer_shard import CausalTransformer | |
app = FastAPI()

# GPT-J-6B hyper-parameters in the form mesh_transformer's CausalTransformer
# expects. "seq" is the context size; lower it to reduce VRAM use.
params = {
    "layers": 28,
    "d_model": 4096,
    "n_heads": 16,
    "n_vocab": 50400,
    "norm": "layernorm",
    "pe": "rotary",
    "pe_rotary_dims": 64,
    "early_cast": True,
    "seq": 768,
    "cores_per_replica": 1,
    "per_replica_batch": 1,
}

# >> INFO <<: adjust the location of the checkpoint image, or set the
# GPTJ_CHECKPOINT_DIR environment variable instead of editing this file.
check_point_dir = os.environ.get("GPTJ_CHECKPOINT_DIR", "../step_383500/")

per_replica_batch = params["per_replica_batch"]
cores_per_replica = params["cores_per_replica"]
seq = params["seq"]

params["sampler"] = nucleaus_sample
# Here we "remove" the optimizer parameters from the model (as we don't need
# them for inference).
params["optimizer"] = optax.scale(0)

print("jax.device_count ", jax.device_count())
print("jax.devices ", jax.devices())
print("cores_per_replica ", cores_per_replica)

# Mesh shape that would use every visible device; only needed by the
# commented-out multi-device alternative below.
mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
# devices = np.array(jax.devices()).reshape(mesh_shape)
# Use only the first device, as a 1x1 ('dp', 'mp') mesh.
devices = np.array([jax.devices()[0]]).reshape((1, 1))
maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))

tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

# Size the batch from the mesh that was actually built (devices.shape[0]
# data-parallel rows), not from jax.device_count(). The mesh above always
# holds a single device, so counting every visible GPU would make the batch
# larger than the mesh it runs on. With the multi-device line restored, this
# equals the old per_replica_batch * device_count // cores_per_replica.
total_batch = per_replica_batch * devices.shape[0]

print("CausalTransformer")
network = CausalTransformer(params)

# Load a checkpoint that was written with 8 shards into cores_per_replica shards.
print("read_ckpt")
network.state = read_ckpt(network.state, check_point_dir, 8, shards_out=cores_per_replica)
# network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
# Move the state to CPU/system memory so it's not duplicated by xmap.
network.state = jax.device_put(network.state, jax.devices("cpu")[0])
def infer(context, top_k=40, top_p=0.9, temp=1.0, gen_len=512):
    """Sample completions of *context* from the loaded GPT-J network.

    The prompt is tokenized, left-padded with zeros to the model context size
    ``seq``, replicated ``total_batch`` times and passed to
    ``network.generate`` together with per-row sampler settings.

    Args:
        context: prompt text.
        top_k: top-k cutoff for the sampler. ``None`` or a value <= 0
            disables it; 0 was already treated as "disabled" before.
        top_p: nucleus (top-p) cutoff for the sampler.
        temp: sampling temperature.
        gen_len: generation length handed to ``network.generate``.

    Returns:
        A list of decoded strings, one per row of the generated batch.
    """
    tokens = tokenizer.encode(context)
    # Keep only the most recent `seq` tokens. A longer prompt would make the
    # pad width below negative, and np.pad rejects that with a ValueError.
    tokens = tokens[-seq:]
    pad_amount = seq - len(tokens)
    padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
    batched_tokens = np.array([padded_tokens] * total_batch)
    length = np.ones(total_batch, dtype=np.uint32) * len(tokens)

    # Use an explicit conditional rather than `cond and arr or None`. The
    # truth value of a multi-element array raises ValueError, and a 0 cutoff
    # was being turned into None only by accident.
    if top_k is None or top_k <= 0:
        top_k_option = None
    else:
        top_k_option = np.ones(total_batch, dtype=np.int32) * top_k

    # Every sampler array gets one entry per row of batched_tokens, so all of
    # them are sized with total_batch, like the tokens themselves.
    sampler_options = {
        "top_p": np.ones(total_batch) * top_p,
        "top_k": top_k_option,
        "temp": np.ones(total_batch) * temp,
    }

    start = time.time()
    output = network.generate(batched_tokens, length, gen_len, sampler_options)
    decoded_tokens = output[1][0]
    samples = [tokenizer.decode(o) for o in decoded_tokens[:, :, 0]]
    print(f"completion done in {time.time() - start:06}s")
    return samples
def recursive_infer(initial_context, current_context=None, top_k=40, top_p=0.9, temp=1.0, gen_len=512, depth=0, max_depth=5, recursive_refresh=0):
    """Yield a chain of completions, each one seeding the next step's prompt.

    One completion is yielded for every depth from ``depth`` up to and
    including ``max_depth``, and always at least one. The context for each
    step is:

    * ``initial_context`` when the current context is None or empty;
    * ``initial_context``, a CRLF " ... " CRLF separator, then the current
      context, when ``recursive_refresh == 1``;
    * otherwise just the current context (the previous completion).

    Sampling arguments are forwarded unchanged to ``infer``, and only its
    first sample is used.
    """
    # Iterate instead of recursing: max_depth comes from the request, and a
    # `yield from` chain that deep could exceed the interpreter's recursion
    # limit. recursive_refresh now applies to every step; the recursive call
    # used to drop it, so it never took effect.
    while True:
        prev_len = len(current_context) if current_context else 0
        print("recursive_infer:{} {} {} {}".format(len(initial_context), prev_len, depth, max_depth))
        if not current_context:
            context = initial_context
        elif recursive_refresh == 1:
            context = initial_context + "\r\n ... \r\n" + current_context
        else:
            context = current_context
        print("cc:{}".format(context))
        completion = infer(context, top_k, top_p, temp, gen_len)[0]
        yield completion
        if depth >= max_depth:
            return
        current_context = completion
        depth += 1
print("PRETEST")
# Run one generation at startup so the first real request does not pay the
# warm-up cost.
pre_prompt = "I am the EleutherAI / GPT-J-6B based AI language model server. I will"
print(pre_prompt)
_warmup = infer(pre_prompt)
print(_warmup[0])
del _warmup
print("SERVER SERVING")
@app.post("/engines/completions")
async def read_completions(
    # engine_id:str,
    prompt: Optional[str] = None,
    max_tokens: Optional[int] = 16,
    temperature: Optional[float] = 1.0,
    top_p: Optional[float] = 1.0,
    top_k: Optional[int] = 40,
    n: Optional[int] = 1,
    stream: Optional[bool] = False,
    logprobs: Optional[int] = None,
    echo: Optional[bool] = False,
    stop: Optional[list] = None,
    presence_penalty: Optional[float] = 0.0001,
    frequency_penalty: Optional[float] = 0.0001,
    best_of: Optional[int] = 1,
    recursive_depth: Optional[int] = 0,
    recursive_refresh: Optional[int] = 0,
    logit_bias: Optional[Dict[str, float]] = None,
):
    """OpenAI-style text completion endpoint backed by GPT-J-6B.

    Only ``prompt``, ``max_tokens``, ``temperature``, ``top_p``, ``top_k``,
    ``recursive_depth`` and ``recursive_refresh`` affect generation. The other
    parameters are accepted for API compatibility but ignored.

    With ``recursive_depth == 0``, one choice is returned per generated batch
    row. Otherwise the completions come from ``recursive_infer``, and each
    choice's ``prompt`` records the context that step was generated from.

    Returns:
        A dict with ``id``, ``object``, ``created``, ``model`` and
        ``choices`` keys.
    """
    # NOTE: a missing prompt becomes the literal text "None".
    text = str(prompt)
    # Each "|" in the prompt is turned into a CRLF line break.
    text = text.replace("|", "\r\n")

    # max_tokens counts new tokens. mesh_transformer's generate() length is,
    # as far as we know, also a count of tokens to generate, so the prompt
    # length is not added (the old max_tokens + len(tokens) followed
    # HF-style max_length semantics and over-generated). Confirm against
    # CausalTransformer.generate.
    start = time.time()
    if recursive_depth == 0:
        gtext = infer(context=text, top_p=top_p, top_k=top_k, temp=temperature, gen_len=max_tokens)
    else:
        gtext = recursive_infer(
            initial_context=text,
            current_context=None,
            top_p=top_p,
            top_k=top_k,
            temp=temperature,
            gen_len=max_tokens,
            depth=0,
            max_depth=recursive_depth,
            recursive_refresh=recursive_refresh,
        )

    choices = []
    last_prompt = text
    for i, out_seq in enumerate(gtext):
        choices.append({
            "prompt": last_prompt,
            "text": out_seq,
            "index": i,
            "logprobs": None,
            "finish_reason": "length",
        })
        print("GenText[{}]:{}".format(i, out_seq))
        if recursive_depth != 0:
            # The next recursive step continues from this output, optionally
            # prefixed by the original prompt (recursive_refresh == 1).
            if recursive_refresh == 1:
                last_prompt = text + "\r\n ... \r\n" + out_seq
            else:
                last_prompt = out_seq
    gen_text = "".join(choice["text"] for choice in choices)

    elapsed = time.time() - start
    # Characters of generated text per second. The prompt length is no longer
    # subtracted: the generated text is presumably only the continuation, and
    # the old subtraction could make this negative.
    cps = len(gen_text) / elapsed if elapsed > 0 else 0.0
    print("elapsed:{} len:{} cps:{}".format(elapsed, len(gen_text), cps))
    return {
        "id": "",
        "object": "text_completion",
        "created": "",
        "model": "GPT-J-6B",  # args.model
        "choices": choices,
    }
# Start the server only when this file is run as a script (python3 jserv.py).
# Importing the module, e.g. via `uvicorn jserv:app`, no longer blocks inside
# the import by starting a second server of its own on port 8000.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
    print("Happy Service!")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hello,
This is more of an informational post for folks trying to run this on Windows. I just wanted anyone who is using Windows to know that it is actually possible to get this going with pretty solid results. Thanks to those who put it together.
I have a RTX 3090 and was trying to get this to work on Windows. I installed all the dependencies (Jaxlib first) using a specifically created Anaconda Python 3.9 environment (Conda is 3.84 by default), along with the latest version of CUDA (11.4 at this writing) from Nvidia.
I decompressed the pretrained model archive with PeaZip, using three separate extractions instead of the Linux command-line method.
Jaxlib was a particular challenge when working through the dependencies, but I managed to find a compiled wheel here that supported GPU acceleration (I used 1.68):
https://github.com/erwincoumans/jax/tags
Numpy threw an error on the next run of the script, but I got around it by upgrading Numpy.
JAX initially could not find the NVIDIA library files. I fixed this by copying them into one of the search locations listed in the error message, the "CUDA_V11.0" directory on my D: drive.
I had to lower the seq to 512 to get it to run without throwing "IMAGE_REL_AMD64_ADDR32NB relocation requires an ordered section layout" as an error.
I had to hardcode the pretrained model directory in order for it to be found on Windows. It was important to use forward slashes rather than back slashes in the path.
After that, the model ran and I was able to open:
http://localhost:8000/docs#/default/read_completions_engines_completions_post
and get it going.