# GitHub Gist HughPH/c3f2ba64e5dff6a5f8b1fb6549a2272a:
# Run a HuggingFace-converted GPT-J-6B checkpoint using FastAPI and Ngrok on a local GPU (3090 or Titan).
# Note (from the GitHub page): this file may contain bidirectional Unicode text that is
# interpreted or compiled differently than it appears; review it in an editor that reveals
# hidden Unicode characters.
# So you want to run GPT-J-6B using HuggingFace+FastAPI on a local rig (3090 or TITAN) ... tricky. | |
# special help from the Kolob Colab server https://colab.research.google.com/drive/1VFh5DOkCJjWIrQ6eB82lxGKKPgXmsO5D?usp=sharing#scrollTo=iCHgJvfL4alW | |
# Conversion to HF format (12.6GB tar image) found at https://drive.google.com/u/0/uc?id=1NXP75l1Xa5s9K18yf3qLoZcR6p4Wced1&export=download | |
# Uses GDOWN to get the image | |
# You will need 26 GB of space, 12+GB for the tar and 12+GB expanded (you can nuke the tar after expansion) | |
# HPPH: Not sure where you'll find this file, the links I found didn't work and the GDOWN was returning unauthorised errors. Maybe I'll make it a torrent. | |
# HPPH: I also dumped the kobold endpoint. And added one for getting token counts so you can prune your prompt if necessary. | |
# HPPH: And finally... Now the prompt goes in the POST body, which simplifies matters significantly. | |
# Near Simplest Language model API, with room to expand! | |
# runs GPT-J-6B on 3090 and TITAN and serves it using FastAPI
# change "seq" (which is the context size) to adjust footprint | |
# | |
# JAX-based | |
# seq vram usage | |
# 512 14.7G | |
# 900 15.3G | |
# | |
# HF-based | |
# seq vram usage | |
# 512 15.6 G | |
# 900 --.- G | |
# | |
# uses FastAPI, so install that | |
# https://fastapi.tiangolo.com/tutorial/ | |
# pip install fastapi | |
# pip install uvicorn[standard] | |
# pip install git+https://github.com/finetuneanon/transformers@gpt-neo-localattention3 | |
# pip install termcolor | |
# pip install gdown | |
# gdown --id 1NXP75l1Xa5s9K18yf3qLoZcR6p4Wced1 --output ../j6b_ckpt.tar | |
# (result: 12.6GB downloaded in 18:19 at 11.4MB/s)
# HPPH: I removed the dependency on ngrok | |
# note: for my setup I needed to perform the symlink suggested by myjr52 in https://github.com/google/jax/issues/5231
# HPPH: I did not need to do this. | |
# https://pytorch.org/get-started/previous-versions/ | |
# for cuda 10.1 | |
# pip install torch==1.8.1+cu101 torchvision==0.9.1+cu101 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html | |
### HPPH: The below didn't work for me, couldn't find any such torch version. I used cu111 | |
# for cuda 11.2 | |
# pip install torch==1.8.1+cu112 torchvision==0.9.1+cu112 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html | |
# conda install python-multipart | |
# -------------------------------------- | |
# check pyngrok — https://github.com/alexdlaird/pyngrok
# install | |
# pip install pyngrok | |
# | |
# Set up your ngrok Authtoken | |
# ngrok authtoken xxxxxxxxxxxxx | |
# GO: local execution | |
# XLA_PYTHON_CLIENT_PREALLOCATE=false XLA_PYTHON_CLIENT_ALLOCATOR=platform CUDA_VISIBLE_DEVICES=0 python3 jserv_hf_fast.py | |
# When done try | |
# http://localhost:9995/docs#/default/read_completions_engines_completions_post   (the port is SERVER_PORT)
# now you are in FastAPI + EleutherAI land | |
# note: needs async on the read_completions otherwise jax gets upset | |
# REMEMBER: adjust the location of the checkpoint image TAR_PATH | |
# | |
# Using plain HF instead of Jax so can comment out JAX related for this install | |
# ----------------------------------------- | |
# # uses https://github.com/kingoflolz/mesh-transformer-jax | |
# # so install jax on your system so recommend you get it working with your GPU first | |
# # !apt install zstd | |
# | |
# # | |
# # the "slim" version contain only bf16 weights and no optimizer parameters, which minimizes bandwidth and memory | |
# # wget https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd | |
# # tar -I zstd -xf step_383500_slim.tar.zstd | |
# # git clone https://github.com/kingoflolz/mesh-transformer-jax.git | |
# # pip install -r mesh-transformer-jax/requirements.txt | |
# # jax 0.2.12 is required due to a regression with xmap in 0.2.13 | |
# # pip install mesh-transformer-jax/ jax==0.2.12 | |
# # I have cuda 10.1 and python 3.9 so had to update | |
# # pip3 install --upgrade "https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.66+cuda101-cp39-none-manylinux2010_x86_64.whl" | |
# ----------------------------------------- | |
# | |
# Started 2021-06-19 (USA Juneteenth) and released to freedom under MIT | |
# Continued 2021-09-09 (September the 9th) HPPH | |
from termcolor import colored | |
import json | |
import torch | |
import requests | |
import subprocess | |
import tarfile | |
import io | |
import os | |
import re | |
import time | |
import pickle | |
from threading import Timer | |
from typing import Optional | |
from typing import Dict | |
from fastapi import FastAPI, Request, Body | |
import uvicorn | |
import threading | |
import numpy as np | |
import torch | |
import requests | |
from PIL import Image | |
from CLIP import clip | |
from antarcticcaptions import model | |
from antarcticcaptions import utils | |
import argparse | |
import transformers | |
from transformers import GPTNeoForCausalLM, AutoConfig, AutoTokenizer, GPT2Tokenizer | |
print(colored("Server Initialization ...", "magenta"))
# Assigned later, once the tokenizer and model are loaded. They are declared here so the
# request handlers further down can refer to them as module globals.
gptmodel = None
tokenizer = None
# ------------------------------------------
# REMEMBER: Change these settings to local values
active_model = ''                     # declared global in read_completions but currently unused
runtime_gpu = "cuda:0"                # device the model and the prompt token ids are moved to
TAR_PATH = "model/"                   # directory that holds j6b_ckpt.tar
check_point_dir = "model/j6b_ckpt"    # expanded checkpoint directory loaded by the model
SERVER_PORT = 9995
# -----------------------------------------
# Report the CUDA environment.
# https://stackoverflow.com/questions/48152674/how-to-check-if-pytorch-is-using-the-gpu
# The device-specific queries (current_device, get_device_name, memory stats) raise when
# CUDA is unavailable. They are therefore only issued when a GPU is present; otherwise
# the red "is_available() = False" line is the whole report instead of a crash mid-report.
cuda_available = torch.cuda.is_available()
report_color = "green" if cuda_available else "red"
print(colored(" torch.cuda.is_available() = " + str(cuda_available), report_color))
if cuda_available:
    print(colored(" torch.cuda.current_device() = " + str(torch.cuda.current_device()), report_color))
    print(colored(" torch.cuda.device_count() = " + str(torch.cuda.device_count()), report_color))
    print(colored(" torch.cuda.get_device_name(0) = " + str(torch.cuda.get_device_name(0)), report_color))
    print(colored(" Mem Allocated:{}GB".format(round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)), report_color))
    print(colored(" Mem Cached: {}GB".format(round(torch.cuda.memory_reserved(0) / 1024 ** 3, 1)), report_color))
# Locate the checkpoint tarball and unpack it if the expanded directory is missing.
model_on_drive = TAR_PATH + "j6b_ckpt.tar"
print(colored("Checking j6b_ckpt ...", "magenta"))
print(colored(" TAR_PATH ={}".format(TAR_PATH), "green"))
print(colored(" check_point_dir ={}".format(check_point_dir), "green"))
print(colored(" model_on_drive ={}".format(model_on_drive), "green"))
if not os.path.isdir(check_point_dir):
    print(colored("Unpacking tar file, please wait...", "magenta"))
    # Extract into the directory that is expected to contain check_point_dir (its parent),
    # not into the current working directory. Otherwise, with check_point_dir =
    # "model/j6b_ckpt", the isdir() check above never succeeds and the model load fails.
    # NOTE(review): assumes the tar's top-level folder is the basename of check_point_dir
    # (e.g. j6b_ckpt/) — confirm against the tarball you use.
    # extractall() trusts member paths inside the archive; only use a trusted tarball.
    extract_to = os.path.dirname(os.path.normpath(check_point_dir)) or "."
    with tarfile.open(model_on_drive, "r") as tar:
        tar.extractall(path=extract_to)
    if not os.path.isdir(check_point_dir):
        raise FileNotFoundError(
            "Extracted {} into {} but {} still does not exist; check the tar layout "
            "or adjust check_point_dir".format(model_on_drive, extract_to, check_point_dir)
        )
else:
    print(colored("Expanded Checkpoint directory found", "green"))
# Initialize the model configuration: start from the GPT-Neo 2.7B config and override the
# architecture fields for the 28-layer, 16-head J6B checkpoint.
print(colored("Initializing model, please wait...", "magenta"))
config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-2.7B")
config.num_layers = 28
config.attention_layers = ["global"] * 28
# attention_types is a list of [pattern, repeat_count] pairs (GPT-Neo's own default is
# [[["global", "local"], 12]]). It must expand to the same 28 x "global" as
# attention_layers above. The previous value [["global"], 28] was missing the outer list,
# and expanding it would fail (item[1] of ["global"] does not exist).
config.attention_types = [[["global"], 28]]
config.num_heads = 16
config.hidden_size = 256 * config.num_heads  # 4096
config.vocab_size = 50400
# NOTE(review): rotary / rotary_dim / jax are not standard GPTNeoConfig fields; they are
# presumably consumed by the finetuneanon transformers fork installed per the header.
config.rotary = True
config.rotary_dim = 64
config.jax = True
config.output_hidden_states = True
try: | |
from collections.abc import MutableMapping | |
except ImportError: | |
from collections import MutableMapping | |
from pathlib import Path | |
class Checkpoint(MutableMapping):
    """Lazily-loading, effectively read-only view of a split PyTorch checkpoint directory.

    ``chkpt_dir/m.pt`` holds an index mapping each parameter name to the path of the file
    that contains its tensor. Only the file *name* of that path is used, resolved relative
    to ``chkpt_dir``. Each ``ckpt[key]`` access loads the tensor from disk (mapped to
    ``device``); tensors are not cached by this object.

    Assignments and deletions are silently ignored. NOTE(review): presumably so that code
    which treats this as a mutable state dict (e.g. renaming keys) does not fail — confirm.
    """

    def __init__(self, chkpt_dir, device="cpu"):
        self.device = device
        self.chkpt_dir = Path(chkpt_dir)
        # Index: parameter name -> original tensor file path.
        self.checkpoint = torch.load(str(self.chkpt_dir / "m.pt"))

    def __len__(self):
        return len(self.checkpoint)

    def __getitem__(self, key):
        # Unknown names raise KeyError from the index lookup.
        path = self.chkpt_dir / Path(self.checkpoint[key]).name
        return torch.load(str(path), map_location=self.device)

    def __setitem__(self, key, value):
        # Intentionally a no-op: the checkpoint on disk is never modified.
        return

    def __delitem__(self, key):
        # Intentionally a no-op, like __setitem__.
        return

    def __contains__(self, key):
        # Consult the index only. The inherited Mapping.__contains__ calls __getitem__,
        # i.e. it would load the whole tensor from disk just to test membership.
        return key in self.checkpoint

    def keys(self):
        return self.checkpoint.keys()

    def __iter__(self):
        # Yield parameter names, as the Mapping protocol requires. Yielding (key, tensor)
        # pairs broke the inherited items()/values()/__eq__ (they look every iterated
        # element up as a key) and loaded every tensor just to iterate.
        # Use items() to get (name, tensor) pairs.
        return iter(self.checkpoint)

    def __copy__(self):
        return Checkpoint(self.chkpt_dir, device=self.device)

    def copy(self):
        # Builds a fresh view by re-reading m.pt; tensors are still loaded lazily.
        return Checkpoint(self.chkpt_dir, device=self.device)
# Load the tokenizer used to encode prompts and decode generated sequences.
print(colored("loading GPT Neo 2.7B tokenizer.", "magenta"))
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
vocab = tokenizer.get_vocab()
vocab_keys = vocab.keys()


def find_keys(char):
    """Return every vocabulary token string that contains ``char``, in vocabulary order."""
    return [token for token in vocab_keys if char in token]


# Token strings collected as single-token "bad words": everything containing "[", then
# everything containing " [", then everything containing the end-of-text marker
# (duplicates across the three searches are kept).
# NOTE(review): bad_words_ids is not passed to generate() anywhere in this file.
bad_words = [token for pattern in ("[", " [", "<|endoftext|>") for token in find_keys(pattern)]
bad_words_ids = [[vocab[token]] for token in bad_words]
# Build the model from the converted J6B checkpoint and move it onto the configured GPU.
print(colored("loading GPTNeoForCausalLM.from_pretrained", "magenta"))
print(colored(f" loading from {check_point_dir}", "green"))
# No model name/path is given, so the model is built from `config` and its weights are
# taken from the supplied state_dict. The Checkpoint mapping reads each tensor file from
# check_point_dir when it is looked up.
gptmodel = GPTNeoForCausalLM.from_pretrained(
    pretrained_model_name_or_path=None,
    config=config,
    state_dict=Checkpoint(check_point_dir),
)
print(colored(" move to GPU", "magenta"))
gptmodel.to(runtime_gpu)
print(colored(" >>>> DONE! <<<<", "green"))

# ASGI application on which the endpoint decorators below register their routes.
app = FastAPI()
from fastapi.responses import HTMLResponse


@app.get("/", response_class=HTMLResponse)
def home():
    """Landing page confirming that the service is up.

    Registered with ``app.get`` instead of ``app.route``. ``route`` is Starlette's low-level
    decorator: it calls the endpoint with the ``Request`` as its only argument and expects
    a ``Response`` back, so this zero-argument function returning a plain string failed on
    every request. ``response_class=HTMLResponse`` serves the string as ``text/html``.
    """
    return "<h1>EleutherAI J6B Service Running!</h1>"
from pydantic import BaseModel | |
class Prompt(BaseModel):
    """JSON request body carrying the prompt text, e.g. ``{"prompt": "Once upon a time"}``.

    Taken as the body of /engines/count_tokens and as the optional body of
    /engines/completions.
    """

    prompt: Optional[str] = None  # may be omitted; the endpoints convert it with str()
@app.post("/engines/count_tokens")
async def count_tokens(promptobj: Prompt):
    """Return the number of tokens the prompt in the request body encodes to.

    The text is normalised the same way /engines/completions normalises it (every "|"
    becomes "\\r\\n"), so the count matches what generation actually tokenizes and can be
    used to prune a prompt to fit the context window. Only the length of the token ids is
    needed, so they are not copied to the GPU.
    """
    text = str(promptobj.prompt).replace("|", "\r\n")
    return len(tokenizer(text).input_ids)
@app.post("/engines/completions")
async def read_completions(
    prompt: Optional[str] = None,
    max_tokens: Optional[int] = 16,
    temperature: Optional[float] = 1.0,
    top_p: Optional[float] = 1.0,
    top_k: Optional[int] = 40,
    presence_penalty: Optional[float] = 0.0001,
    repetition_penalty: Optional[float] = 1.0000,
    promptobj: Optional[Prompt] = None,
    request: Request = None
):
    """Generate a completion for a prompt and return it in an OpenAI-like shape.

    The prompt is taken from the ``prompt`` query parameter or, if that is empty, from the
    JSON body (``Prompt``). Every "|" in it becomes "\\r\\n", so multi-line prompts can be
    passed on one query-string line. The sampling settings are query parameters.
    ``min_length`` is set equal to ``max_length``, so generation is meant to produce
    exactly ``max_tokens`` new tokens (``max_time=60`` seconds may stop it earlier).

    Returns ``{"params": <query params>, "choices": [{"prompt", "text", "index"}, ...]}``,
    where ``text`` is the decoded full sequence, prompt included.

    Raises HTTPException(422) when no non-empty prompt is supplied.
    """
    from fastapi import HTTPException

    response = {}
    response['params'] = dict(request.query_params) if request is not None else {}
    print(response)

    # Fall back to the body only when one was sent; a missing body used to raise
    # AttributeError (HTTP 500).
    if not prompt and promptobj is not None:
        prompt = promptobj.prompt
    if not prompt:
        raise HTTPException(
            status_code=422,
            detail="A non-empty prompt is required (query parameter or JSON body).",
        )

    text = str(prompt).replace("|", "\r\n")
    prompt_len = len(text)
    ids = tokenizer(text, return_tensors="pt").input_ids.to(runtime_gpu)
    max_length = max_tokens + ids.shape[1]

    start = time.time()
    # NOTE: this handler is async and generate() blocks, so the event loop serves requests
    # one at a time while the GPU is busy.
    gen_tokens = gptmodel.generate(
        ids,
        do_sample=True,
        min_length=max_length,
        max_length=max_length,
        temperature=temperature,
        use_cache=True,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        presence_penalty=presence_penalty,
        no_repeat_ngram_size=6,
        max_time=60,
        num_beams=1,
        # Previously the undefined name `eos_token`, which raised NameError on every call.
        eos_token_id=tokenizer.eos_token_id,
        #length_penalty=0.3,
        early_stopping=True,
        output_hidden_states=True,
        pad_token_id=0
    )

    choices = []
    for i, out_seq in enumerate(gen_tokens):
        choice_text = tokenizer.decode(out_seq, skip_special_tokens=True)
        choices.append({'prompt': text, 'text': choice_text, 'index': i})
        print("GenText[{}]:{}".format(i, choice_text))
    gen_text = ''.join(choice['text'] for choice in choices)

    elapsed = time.time() - start
    # Approximate characters generated per second, excluding the echoed prompt. This is
    # approximate because decoding may not reproduce the prompt character-for-character.
    cps = (len(gen_text) - prompt_len) / elapsed if elapsed > 0 else 0.0
    print("elapsed:{} len:{} cps:{}".format(elapsed, len(gen_text), cps))
    response['choices'] = choices
    return response
if __name__ == "__main__":
    # Start the server only when this file is run as a script (python3 jserv_hf_fast.py).
    # Importing the module, e.g. from an external ASGI runner, no longer launches a
    # second server.
    print(colored("Model startup complete! Starting web service....", "green"))
    print(colored("Ready to Serve!", "green"))
    # Listens on all interfaces (0.0.0.0), i.e. the API is reachable from the local network.
    # uvicorn.run blocks until the server is stopped (e.g. Ctrl+C).
    uvicorn.run(app, host="0.0.0.0", port=SERVER_PORT)
    print(colored("Toodlepip!", "green"))
# (End of the GitHub Gist page: sign in to GitHub to comment on this gist.)