Run a LLaMA 2 inference server on Slurm
import json
import random
import subprocess

import openai


def get_slurm_job_by_name(name):
    """Return all Slurm jobs matching `name`, with their comment field parsed into a dict."""
    command = ["squeue", "-h", f"--name={name}", "--format=%A %j %T %P %U %k %N"]
    output = subprocess.check_output(command, text=True)

    jobs = []
    for line in output.splitlines():
        job_id, job_name, status, partition, user, comment, nodes = line.split(' ')

        # The comment field carries server metadata encoded as "key=value|key=value|...".
        data = dict()
        if comment != "(null)":
            items = comment.split('|')
            for kv in items:
                try:
                    k, v = kv.split('=', maxsplit=1)
                    data[k] = v
                except ValueError:
                    # Ignore malformed entries that do not contain '='.
                    pass

        jobs.append({
            "job_id": job_id,
            "job_name": job_name,
            "status": status,
            "partition": partition,
            "user": user,
            "comment": data,
            "nodes": nodes,
        })

    return jobs


def find_suitable_inference_server(jobs, model):
    """Keep only running, shared servers that serve `model` (or any model if `model` is None)."""
    selected = []

    def is_shared(job):
        return job["comment"].get("shared", 'y') == 'y'

    def is_running(job):
        return job['status'] == "RUNNING"

    def has_model(job, model):
        if model is None:
            return True
        return job['comment']['model'] == model

    def select(job):
        selected.append({
            "model": job['comment']["model"],
            "host": job["comment"]["host"],
            "port": job["comment"]["port"],
        })

    for job in jobs:
        if is_shared(job) and is_running(job):
            if has_model(job, model):
                select(job)

    return selected


def get_inference_server(model=None):
    """Pick one suitable server at random, or return None if none is available."""
    jobs = get_slurm_job_by_name('inference_server_SHARED.sh')
    servers = find_suitable_inference_server(jobs, model)

    try:
        return random.choice(servers)
    except IndexError:
        return None


def get_endpoint(model):
    # Assumes at least one matching server is currently running.
    server = get_inference_server(model)
    return f"http://{server['host']}:{server['port']}/v1"


model = "/network/weights//llama.var/llama2//Llama-2-7b-chat-hf"

# Modify OpenAI's API key and API base to use vLLM's API server.
openai.api_key = "EMPTY"
openai.api_base = get_endpoint(model)

completion = openai.Completion.create(
    model=model,
    prompt="San Francisco is a",
)

print("Completion result:", completion)