Created
April 15, 2024 02:12
-
-
Save mahm/f8f5a14838bbcf44f378290fb4d21999 to your computer and use it in GitHub Desktop.
More Agents Is All You Need
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import numpy as np
import openai
import torch
from transformers import AutoTokenizer, AutoModel
# Configure the OpenAI API key. The OPENAI_API_KEY environment variable takes
# precedence so a real secret never has to be written into this file; the
# placeholder string is kept as the fallback when the variable is unset.
openai.api_key = os.environ.get("OPENAI_API_KEY", "your_api_key")
def get_completion(prompt, model="gpt-3.5-turbo", temperature=0):
    """Return the text of one chat completion for ``prompt``.

    Args:
        prompt: Text sent as the only message, with role ``"user"``.
        model: Name of the chat model to query.
        temperature: Sampling temperature forwarded to the API. Defaults to 0;
            pass a positive value when repeated calls with the same prompt
            should be able to produce different answers.

    Returns:
        The ``content`` field of the first choice in the API response.
    """
    messages = [{"role": "user", "content": prompt}]
    # NOTE(review): ``openai.ChatCompletion`` appears to be the pre-1.0
    # interface of the openai package -- confirm the installed version.
    response = openai.ChatCompletion.create(
        model=model,
        messages=messages,
        temperature=temperature,
    )
    return response.choices[0].message["content"]
def cosine_similarity(a, b):
    """Return the cosine similarity of vectors ``a`` and ``b``.

    Args:
        a: First vector (anything ``np.dot`` / ``np.linalg.norm`` accept).
        b: Second vector, same length as ``a``.

    Returns:
        ``dot(a, b) / (|a| * |b|)``, or ``0.0`` when either vector has zero
        norm, since the similarity is undefined there and a NaN would
        otherwise propagate into any score built from it.
    """
    denominator = np.linalg.norm(a) * np.linalg.norm(b)
    if denominator == 0:
        return 0.0
    return np.dot(a, b) / denominator
def sampling_and_voting(prompt, num_samples, embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
    """Sample several completions and return the one most similar to the rest.

    Each completion is embedded, every completion is scored by the sum of its
    cosine similarities to all other completions, and the highest-scoring
    completion is returned.

    Args:
        prompt: Prompt passed to ``get_completion`` for every sample.
        num_samples: Number of completions to request; must be at least 1.
        embedding_model: Model identifier handed to
            ``AutoTokenizer.from_pretrained`` and ``AutoModel.from_pretrained``.

    Returns:
        The selected completion text.

    Raises:
        ValueError: If ``num_samples`` is less than 1.
    """
    # Fail before spending any API calls; with no samples there is nothing to
    # vote on and np.argmax would reject the empty score list anyway.
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    # Sampling phase.
    # NOTE(review): get_completion requests temperature=0 here, so the samples
    # are likely near-identical; voting is only useful with diverse samples.
    samples = [get_completion(prompt) for _ in range(num_samples)]
    if num_samples == 1:
        # A single candidate wins by default; skip loading the embedding model.
        return samples[0]

    # Load the embedding model (done on every call).
    tokenizer = AutoTokenizer.from_pretrained(embedding_model)
    model = AutoModel.from_pretrained(embedding_model)

    # Compute an embedding for each sample.
    embeddings = []
    for sample in samples:
        inputs = tokenizer(sample, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        # The hidden state of the first token of the only sequence in the
        # batch serves as the sample's embedding.
        embeddings.append(outputs.last_hidden_state[0, 0].numpy())

    # Similarity computation and voting. Cosine similarity is symmetric, so
    # each unordered pair is evaluated once and credited to both members.
    similarity_scores = [0.0] * num_samples
    for i in range(num_samples):
        for j in range(i + 1, num_samples):
            score = cosine_similarity(embeddings[i], embeddings[j])
            similarity_scores[i] += score
            similarity_scores[j] += score

    most_similar_index = int(np.argmax(similarity_scores))
    return samples[most_similar_index]
# Usage example. Guarded so that importing this module does not trigger API
# calls or a model download; running the file as a script behaves as before.
if __name__ == "__main__":
    prompt = "What is the capital of France?"
    num_samples = 5
    answer = sampling_and_voting(prompt, num_samples)
    print(f"Question: {prompt}")
    print(f"Answer: {answer}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment