Skip to content

Instantly share code, notes, and snippets.

@shrimo
Created March 1, 2025 23:37
Show Gist options
  • Save shrimo/f5e0503e16d57cbbc308e0c8270ab9e2 to your computer and use it in GitHub Desktop.
Save shrimo/f5e0503e16d57cbbc308e0c8270ab9e2 to your computer and use it in GitHub Desktop.
TransformerChatbot
import torch
import torch.nn as nn
import torch.optim as optim
import os
# Tokenization
def tokenize(sentence: str) -> list:
    """Return the lower-cased, whitespace-separated tokens of *sentence*.

    Punctuation is not stripped, so ``"you?"`` and ``"you"`` are different
    tokens. An empty or whitespace-only string yields an empty list.
    """
    return sentence.lower().split()
# Vocabulary
# (input sentence, reply sentence) pairs used both to build the vocabulary
# and to train the model.
training_data = [
    ("hello", "hi"),
    ("how are you", "i am fine"),
    ("what is your name", "i am a chatbot"),
    ("bye", "goodbye"),
]

# Sort the unique tokens so every run assigns the same id to the same word.
# Iterating a bare set of strings can yield a different order in each process
# (string hashing is randomized), which would silently scramble the meaning
# of the embedding rows stored in a checkpoint written by an earlier run.
all_words = sorted(
    {
        word
        for pair in training_data
        for sentence in pair
        for word in tokenize(sentence)
    }
)
word_to_idx = {word: idx for idx, word in enumerate(all_words)}
idx_to_word = {idx: word for word, idx in word_to_idx.items()}
# Model Hyperparameters
# Sizes handed to TransformerChatbot when the model is constructed.
d_model = 16       # embedding width, also the transformer feature size
num_heads = 2      # passed as ``nhead``
num_layers = 2     # used for both the encoder and the decoder stack
hidden_dim = 32    # passed as ``dim_feedforward``
vocab_size = len(word_to_idx)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Transformer Model
class TransformerChatbot(nn.Module):
    """Encoder-decoder transformer that maps token ids to vocabulary logits.

    One embedding table is shared by the source and the target sequence.
    The embeddings are fed to ``nn.Transformer`` as-is: no positional
    encoding is added and no attention or padding masks are supplied.

    Args:
        vocab_size: Number of distinct token ids (embedding rows and output
            logits).
        d_model: Embedding width and transformer feature size.
        num_heads: Number of attention heads (``nhead``).
        num_layers: Number of layers in both the encoder and the decoder.
        hidden_dim: Width of the transformer feed-forward sub-layers
            (``dim_feedforward``).
    """

    def __init__(self, vocab_size, d_model, num_heads, num_layers, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=hidden_dim,
            batch_first=True,
        )
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt):
        """Return logits over the vocabulary for every target position.

        Args:
            src: Integer token ids for the encoder, batch dimension first.
            tgt: Integer token ids for the decoder, batch dimension first.

        Returns:
            A tensor with ``tgt``'s leading dimensions plus a trailing
            dimension of size ``vocab_size``.
        """
        src_emb = self.embedding(src)
        tgt_emb = self.embedding(tgt)
        transformer_out = self.transformer(src_emb, tgt_emb)
        return self.fc(transformer_out)
# Create model
model = TransformerChatbot(vocab_size, d_model, num_heads, num_layers, hidden_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
MODEL_PATH = "chatbot_model.pth"

# Resume from a previous run when a checkpoint file is present; both the
# model weights and the optimizer state are restored.
if os.path.exists(MODEL_PATH):
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # that this script wrote itself. Consider weights_only=True on PyTorch
    # versions that support it.
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    print("✅ Model loaded from checkpoint!")
else:
    print("⚡ Training new model...")
# Prepare Data
def prepare_batch(input_text, target_text):
    """Convert a sentence pair into two batch-of-one tensors of token ids.

    Args:
        input_text: Sentence fed to the encoder.
        target_text: Sentence fed to the decoder and used as the label.

    Returns:
        ``(input_indices, target_indices)``, each a ``1 x num_tokens`` integer
        tensor placed on ``device``.

    Raises:
        KeyError: If either sentence contains a word missing from
            ``word_to_idx``.
    """
    input_ids = [word_to_idx[word] for word in tokenize(input_text)]
    target_ids = [word_to_idx[word] for word in tokenize(target_text)]
    # Pin the dtype: an empty token list would otherwise produce a float
    # tensor, which nn.Embedding does not accept as indices.
    input_indices = torch.tensor([input_ids], dtype=torch.long, device=device)
    target_indices = torch.tensor([target_ids], dtype=torch.long, device=device)
    return input_indices, target_indices
# Training Loop
epochs = 500
# Make train mode explicit so dropout inside the transformer is active.
model.train()
for epoch in range(epochs):
    total_loss = 0.0
    for input_text, target_text in training_data:
        src, tgt = prepare_batch(input_text, target_text)
        optimizer.zero_grad()
        # NOTE(review): the decoder is fed the same sequence it is scored
        # against (no shift, no causal mask), so it can lower the loss by
        # copying its input. At inference the decoder input is all zeros,
        # so this should be revisited.
        output = model(src, tgt)
        # Flatten (batch, positions) so each position is one classification.
        loss = criterion(output.view(-1, vocab_size), tgt.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Report the summed per-pair loss every 50 epochs.
    if epoch % 50 == 0:
        print(f"Epoch {epoch}, Loss: {total_loss:.4f}")

# Save the trained model together with the optimizer state so a later run
# can resume from it.
torch.save({"model_state": model.state_dict(), "optimizer_state": optimizer.state_dict()}, MODEL_PATH)
print("✅ Model saved!")
# Inference
def generate_response(input_text):
    """Return the model's reply to *input_text* as a space-joined string.

    The input is tokenized and mapped to ids, with unknown words falling back
    to id 0. The model is run once with an all-zero decoder input of the same
    length, and the highest-scoring word at each position is emitted.

    Args:
        input_text: The user's message.

    Returns:
        The predicted words joined by single spaces, or ``""`` when
        *input_text* is empty or contains only whitespace.
    """
    # With no tokens torch.tensor([[]]) would be a float tensor, which the
    # embedding layer rejects, so there is nothing meaningful to run.
    if not input_text or not input_text.strip():
        return ""

    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # Disable gradient calculation during inference
        token_ids = [word_to_idx.get(word, 0) for word in tokenize(input_text)]
        src = torch.tensor([token_ids], dtype=torch.long, device=device)
        tgt = torch.zeros_like(src)  # Placeholder for the target sequence
        output = model(src, tgt)
        # Index the single batch row rather than squeeze(), so a one-token
        # input still yields a list instead of a bare int.
        predicted_idx = output.argmax(dim=-1)[0].tolist()

    return " ".join(idx_to_word[idx] for idx in predicted_idx if idx in idx_to_word)
# Chatbot Loop
print("Chatbot: Hello! Type something to chat. Type 'exit' to stop.")
while True:
    try:
        user_input = input("You: ").strip().lower()
    except EOFError:
        # stdin was closed (Ctrl-D or end of piped input): leave cleanly
        # instead of dying with a traceback.
        print()
        print("Chatbot: Goodbye!")
        break
    if user_input == "exit":
        print("Chatbot: Goodbye!")
        break
    if not user_input:
        # A blank line has no tokens to feed the model; prompt again.
        continue
    response = generate_response(user_input)
    print(f"Chatbot: {response}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment