Created
April 24, 2023 08:33
-
-
Save cnh/1628fd64396372d12048a725f10178b4 to your computer and use it in GitHub Desktop.
Text clustering using Cohere's Embed API
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# 1. Read every .txt file in the directory into a list of strings.
import glob
import os


def load_texts(directory, pattern="*.txt", encoding="cp1252", errors="replace"):
    """Return the contents of every file in *directory* matching *pattern*.

    Files are read in sorted path order so the resulting list (and anything
    derived from it, such as embeddings and plot positions) is reproducible
    between runs; ``glob.glob`` itself returns matches in an order that
    depends on the file system.

    Args:
        directory: Directory to search (non-recursive).
        pattern: Glob pattern selecting the files to read.
        encoding: Text encoding used to decode each file.
        errors: Decoding error handler passed to ``open``. The default
            ``"replace"`` substitutes U+FFFD for undecodable bytes (cp1252
            leaves a few byte values undefined) instead of aborting the
            whole run on one bad file.

    Returns:
        A list with one string per matching file. It is empty when nothing
        matches or the directory does not exist.
    """
    texts = []
    for file_path in sorted(glob.glob(os.path.join(directory, pattern))):
        with open(file_path, "r", encoding=encoding, errors=errors) as f:
            texts.append(f.read())
    return texts


path_to_dir_of_path_reports = "path/to/text/files"
# One string per report, in sorted file-name order.
path_files_texts = load_texts(path_to_dir_of_path_reports)
# 2. Call the Cohere embed API to generate one embedding per text.
import os

import cohere

# Prefer the COHERE_API_KEY environment variable so a real key never has to be
# committed to source; the placeholder remains as a fallback.
apiKey = os.environ.get("COHERE_API_KEY", "api_key")
co = cohere.Client(apiKey)

if not path_files_texts:
    # Nothing to embed: fail early with an actionable message rather than
    # sending an empty request.
    raise ValueError(
        "No .txt files were loaded; check path_to_dir_of_path_reports."
    )

# This call was previously disabled by wrapping it in a string literal, which
# left `embeds` undefined and made the UMAP step fail with a NameError.
embeds = co.embed(
    texts=path_files_texts,
    model="small",
    truncate="START",
).embeddings
# 3. Project the embeddings down to 2-D with UMAP and collect the coordinates
# in a DataFrame for plotting. The `umap` module comes from the `umap-learn`
# package (pip3 install umap-learn).
import umap
import pandas as pd
import altair as alt

reducer = umap.UMAP(n_neighbors=100)
umap_embeds = reducer.fit_transform(embeds)

# Name the text column explicitly: pd.DataFrame(list_of_str) would label it
# with the integer 0, and Altair rejects DataFrames whose column names are
# not strings.
df = pd.DataFrame({"text": path_files_texts})
df["x"] = umap_embeds[:, 0]
df["y"] = umap_embeds[:, 1]
# 4. Scatter-plot the 2-D UMAP coordinates, one circle per text.
# Axis labels, ticks and domain lines are hidden, and zero=False lets each
# scale fit the data range instead of being anchored at zero.
_hidden_axis = alt.Axis(labels=False, ticks=False, domain=False)
_unanchored_scale = alt.Scale(zero=False)

chart = alt.Chart(df).mark_circle(size=60).encode(
    x=alt.X("x", scale=_unanchored_scale, axis=_hidden_axis),
    y=alt.Y("y", scale=_unanchored_scale, axis=_hidden_axis),
)
chart = chart.configure(background="#FDF7F0")
chart = chart.properties(width=700, height=400, title="Pathology reports")

# Presumably run in a notebook, where this final expression is displayed;
# per the Altair docs, interactive() makes the axis scales interactive.
chart.interactive()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment