Skip to content

Instantly share code, notes, and snippets.

@dmarx
Last active June 12, 2024 16:41
import numpy as np
from openai import OpenAI
import plotly
import plotly.graph_objs as go
import umap
# Address of the OpenAI-compatible server this script talks to.
url = "http://localhost:80"

# A literal placeholder key is passed instead of reading OPENAI_API_KEY.
# NOTE(review): presumably the local server does not validate it -- confirm.
client = OpenAI(api_key="123", base_url=f"{url}/v1")
def get_model_name():
    """Return the ``id`` of the first entry in the server's model listing."""
    listing = client.models.list().to_dict()
    first_model = listing['data'][0]
    return first_model['id']
# Looked up once when this line runs; used as the default ``model`` argument
# of generate() and embed().
MODEL_NAME = get_model_name()
def generate(prompt,
             model=MODEL_NAME,
             max_tokens=1024,
             temperature=0.1,
             **kargs):
    """Request a text completion for ``prompt`` and return it stripped.

    Args:
        prompt: Forwarded to ``client.completions.create`` as ``prompt``.
        model: Model id to use; a falsy value triggers a fresh
            ``get_model_name()`` lookup.
        max_tokens: Forwarded unchanged to the completions call.
        temperature: Forwarded unchanged to the completions call.
        **kargs: Any further keyword arguments, forwarded unchanged.

    Returns:
        The ``text`` of the first returned choice, with leading and
        trailing whitespace removed.
    """
    completion = client.completions.create(
        prompt=prompt,
        model=model or get_model_name(),
        max_tokens=max_tokens,
        temperature=temperature,
        **kargs,
    )
    first_choice = completion.choices[0]
    return first_choice.text.strip()
def embed(content,
          model=MODEL_NAME,
          **kargs):
    """Request embeddings for ``content`` and return the raw API response.

    Args:
        content: Forwarded to ``client.embeddings.create`` as ``input``.
        model: Model id to use; a falsy value triggers a fresh
            ``get_model_name()`` lookup.
        **kargs: Any further keyword arguments, forwarded unchanged.

    Returns:
        The response object from ``client.embeddings.create``; callers in
        this file read the vector from ``.data[0].embedding``.
    """
    return client.embeddings.create(
        input=content,
        model=model or get_model_name(),
        encoding_format='float',
        **kargs,
    )
# Character cap applied to each article's text before it is embedded.
# NOTE(review): hard-coded; it would be nicer to obtain the model's actual
# limit from the API.
MAX_EMBED_CHARS = 32768

# Embed every article and store the vector on the article under 'vect'.
for i, a in enumerate(articles):
    # Slicing leaves text that is already within the cap unchanged.
    content = a['content'][:MAX_EMBED_CHARS]
    a['vect'] = embed(content).data[0].embedding
    if i % 50 == 0:
        # Progress marker every 50 articles; the title may be absent, so
        # fall back to an empty string as the plotting code does.
        print(f"{i}\t{a['metadata'].get('inferred_article_title', '')}")
# Stack the per-article embedding vectors into one 2-D matrix, one row per
# article (each vector is flattened first).
X = np.array([np.array(a['vect']).ravel() for a in articles])

# Fit a 3-component UMAP projection using cosine distance; the fixed
# random_state is there to make the layout repeatable.
trans = umap.UMAP(n_neighbors=10, metric='cosine', n_components=3, random_state=42).fit(X)

# The fitted coordinates of the training rows live in ``trans.embedding_``
# (one row per article, in the same order as ``articles``). Store each row on
# its article as a (1, 3) array so that ``a['umap'][:, k]`` indexing works.
for a, coords in zip(articles, trans.embedding_):
    a['umap'] = coords.reshape(1, -1)

# One flat coordinate array per axis.
xs, ys, zs = trans.embedding_.T

# Hover text (title, empty when missing) and colour value (creation time).
ts = [a['metadata'].get('inferred_article_title', '') for a in articles]
cs = [a['metadata']['create_time'] for a in articles]
# 3-D trace of the projected articles: hover shows only the `ts` text, and
# both the markers and the connecting line are coloured by `cs`.
scattered = go.Scatter3d(
    x=xs,
    y=ys,
    z=zs,
    text=ts,
    hoverinfo='text',
    marker=dict(size=2, color=cs, colorscale='Spectral'),
    line=dict(width=.5, color=cs, colorscale='Spectral'),
)

fig = go.Figure(data=scattered)

# Hide the legend and all three axes (including their background planes).
fig.update_layout(
    showlegend=False,
    height=700,
    scene={
        axis: dict(showbackground=False, visible=False)
        for axis in ('xaxis', 'yaxis', 'zaxis')
    },
)
fig.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment