TransformerChatbot
A minimal PyTorch nn.Transformer chatbot: it builds a tiny vocabulary from four hard-coded phrase pairs, trains (or loads) a small encoder-decoder model, and then chats via greedy decoding.
import torch
import torch.nn as nn
import torch.optim as optim
import os
# Tokenization: naive whitespace split (punctuation is not stripped)
def tokenize(sentence):
    return sentence.lower().split()
# Vocabulary
training_data = [
    ("hello", "hi"),
    ("how are you", "i am fine"),
    ("what is your name", "i am a chatbot"),
    ("bye", "goodbye"),
]
# sorted() keeps word indices stable across runs; a bare list(set(...)) would
# reorder the vocabulary every run and silently break any saved checkpoint.
all_words = sorted(set(word for pair in training_data for sentence in pair for word in tokenize(sentence)))
word_to_idx = {word: idx for idx, word in enumerate(all_words)}
idx_to_word = {idx: word for word, idx in word_to_idx.items()}
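# Sanity check: the four pairs above yield 16 unique words, so vocab_size
# below is 16, and the two mappings round-trip:
#   idx_to_word[word_to_idx["hello"]] == "hello"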
# Model Hyperparameters
d_model = 16       # embedding size; must be divisible by num_heads
num_heads = 2
num_layers = 2
hidden_dim = 32    # feedforward width inside each transformer layer
vocab_size = len(word_to_idx)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Transformer Model
# Note: for brevity this demo skips positional encodings, so the model has
# no real notion of word order.
class TransformerChatbot(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=num_heads, num_encoder_layers=num_layers,
            num_decoder_layers=num_layers, dim_feedforward=hidden_dim, batch_first=True
        )
        self.fc = nn.Linear(d_model, vocab_size)  # project back to vocabulary logits

    def forward(self, src, tgt):
        src_emb = self.embedding(src)
        tgt_emb = self.embedding(tgt)
        transformer_out = self.transformer(src_emb, tgt_emb)
        return self.fc(transformer_out)
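# The decoder above can also attend to "future" target tokens because no
# causal mask is passed. A minimal sketch of adding one inside forward(),
# using PyTorch's built-in helper, would be:
#
#     tgt_mask = nn.Transformer.generate_square_subsequent_mask(
#         tgt_emb.size(1)).to(tgt_emb.device)
#     transformer_out = self.transformer(src_emb, tgt_emb, tgt_mask=tgt_mask)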
# Create model
model = TransformerChatbot(vocab_size, d_model, num_heads, num_layers, hidden_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()  # expects raw logits, so no softmax in the model
MODEL_PATH = "chatbot_model.pth"
# Load model if a checkpoint exists; otherwise train one below
# (the original retrained even after loading; the flag avoids that)
needs_training = not os.path.exists(MODEL_PATH)
if not needs_training:
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    print("✅ Model loaded from checkpoint!")
else:
    print("⚡ Training new model...")
# Prepare Data
def prepare_batch(input_text, target_text):
    """Convert an (input, target) text pair to index tensors of shape (1, seq_len)."""
    # Raises KeyError on out-of-vocabulary words; fine for the fixed training set.
    input_indices = torch.tensor([[word_to_idx[word] for word in tokenize(input_text)]]).to(device)
    target_indices = torch.tensor([[word_to_idx[word] for word in tokenize(target_text)]]).to(device)
    return input_indices, target_indices
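# For example (shapes only; the exact index values depend on the vocabulary):
#   src, tgt = prepare_batch("hello", "hi")
#   src.shape  # torch.Size([1, 1]) - a batch of one single-token sentence
#   tgt.shape  # torch.Size([1, 1])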
# Training Loop
# Note: the full target is fed to the decoder with no shift and no causal
# mask, so this is a toy memorization setup rather than proper seq2seq
# teacher forcing.
if needs_training:
    epochs = 500
    for epoch in range(epochs):
        total_loss = 0
        for input_text, target_text in training_data:
            src, tgt = prepare_batch(input_text, target_text)
            optimizer.zero_grad()
            output = model(src, tgt)  # (1, tgt_len, vocab_size) logits
            loss = criterion(output.view(-1, vocab_size), tgt.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        if epoch % 50 == 0:
            print(f"Epoch {epoch}, Loss: {total_loss:.4f}")

    # Save the trained model
    torch.save({"model_state": model.state_dict(), "optimizer_state": optimizer.state_dict()}, MODEL_PATH)
    print("✅ Model saved!")
# Inference
def generate_response(input_text):
    model.eval()  # evaluation mode
    with torch.no_grad():  # no gradients needed at inference time
        # Convert the input text to a tensor of indices
        # (unknown words fall back to index 0, which maps to an arbitrary vocabulary word)
        src = torch.tensor([[word_to_idx.get(word, 0) for word in tokenize(input_text)]]).to(device)
        tgt = torch.zeros_like(src)  # crude placeholder target, same length as the input
        output = model(src, tgt)  # (1, src_len, vocab_size) logits
        # Greedy decode: take the highest-scoring word at each position
        # (indexing with [0] always yields a 1-D tensor, so .tolist() is always a list)
        predicted_ids = output.argmax(dim=-1)[0].tolist()
        # Map indices back to words and join them into the response
        return " ".join(idx_to_word[idx] for idx in predicted_ids if idx in idx_to_word)
# Chatbot Loop
print("Chatbot: Hello! Type something to chat. Type 'exit' to stop.")
while True:
    user_input = input("You: ").strip().lower()
    if user_input == "exit":
        print("Chatbot: Goodbye!")
        break
    response = generate_response(user_input)
    print(f"Chatbot: {response}")