Skip to content

Instantly share code, notes, and snippets.

@austin362667
Last active November 5, 2024 14:45
Show Gist options
  • Save austin362667/762acb712abeba8d425329b4bf0da55b to your computer and use it in GitHub Desktop.
Save austin362667/762acb712abeba8d425329b4bf0da55b to your computer and use it in GitHub Desktop.
LLM Sampling with UCB
from dataclasses import dataclass, field
from typing import Dict, List, Optional

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
class HuggingFaceLLM:
    """Thin wrapper around a Hugging Face causal LM for next-token prediction."""

    def __init__(self, model_name: str = "HuggingFaceTB/SmolLM2-135M-Instruct", device: str = 'mps'):
        # device_map places the weights on the requested device ('mps', 'cuda', 'cpu', ...).
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, device_map=device)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device)
        self.device = device

    def get_logits(self, context_tokens: List[int]) -> np.ndarray:
        """Return the vocabulary logits for the token following `context_tokens`."""
        input_ids = torch.tensor([context_tokens]).to(self.device)
        with torch.no_grad():
            outputs = self.model(input_ids)
        # Last position of the single batch element -> 1-D vocab-sized array.
        logits = outputs.logits[:, -1, :].cpu().numpy()[0]
        return logits

    def sample_token(self, logits: np.ndarray, temperature: float = 1.0) -> int:
        """Sample a token id from `logits`; temperature == 0 means greedy argmax."""
        if temperature == 0:
            return int(np.argmax(logits))
        # Numerically stable softmax: subtract the max before exponentiating.
        # The original np.exp(logits / temperature) overflows to inf/nan for
        # large logits or small temperatures, which makes np.random.choice
        # raise on an invalid probability vector.
        scaled = logits / temperature
        probs = np.exp(scaled - np.max(scaled))
        probs = probs / np.sum(probs)
        return int(np.random.choice(len(probs), p=probs))
@dataclass
class MCTSNode:
    """A node in the MCTS search tree; one node per sampled token."""

    token_id: int  # vocabulary id; -1 marks the root sentinel
    parent: Optional['MCTSNode']
    # default_factory gives each node its own fresh dict, so callers no
    # longer have to pass children={} explicitly (backward compatible).
    children: Dict[int, 'MCTSNode'] = field(default_factory=dict)
    visits: int = 0
    total_value: float = 0.0
    prior_probability: float = 0.0

    @property
    def value(self) -> float:
        """Mean simulation value; the epsilon keeps unvisited nodes at ~0 instead of dividing by zero."""
        return self.total_value / (self.visits + 1e-8)

    def ucb_score(self, exploration_constant: float = 1.0) -> float:
        """PUCT-style score: mean value plus a prior-weighted exploration bonus."""
        if self.parent is None:
            # The root is never compared against siblings, so its score is moot.
            return 0.0
        exploitation = self.value
        exploration = exploration_constant * self.prior_probability * \
            np.sqrt(self.parent.visits) / (1 + self.visits)
        return exploitation + exploration
class MCTSLLM:
def __init__(self, model_name: str = "HuggingFaceTB/SmolLM2-135M-Instruct",
device: str = 'mps', exploration_constant: float = 1.0):
self.llm = HuggingFaceLLM(model_name, device)
self.exploration_constant = exploration_constant
self.root = MCTSNode(token_id=-1, parent=None, children={})
def select(self, node: MCTSNode) -> List[MCTSNode]:
path = []
while node.children:
node = max(node.children.values(),
key=lambda n: n.ucb_score(self.exploration_constant))
path.append(node)
return path
def expand(self, leaf: MCTSNode, context_tokens: List[int]) -> MCTSNode:
logits = self.llm.get_logits(context_tokens)
probs = np.exp(logits) / np.sum(np.exp(logits))
new_token = self.llm.sample_token(logits)
if new_token not in leaf.children:
leaf.children[new_token] = MCTSNode(
token_id=new_token,
parent=leaf,
children={},
prior_probability=probs[new_token]
)
return leaf.children[new_token]
def simulate(self, node: MCTSNode, context_tokens: List[int]) -> float:
current_tokens = context_tokens + [node.token_id]
value = 0.0
for _ in range(3):
logits = self.llm.get_logits(current_tokens)
next_token = self.llm.sample_token(logits)
current_tokens.append(next_token)
probs = np.exp(logits) / np.sum(np.exp(logits))
value += np.max(probs)
return value / 3.0
def backpropagate(self, path: List[MCTSNode], value: float):
for node in reversed(path):
node.visits += 1
node.total_value += value
def search(self, context_tokens: List[int], n_iterations: int = 10) -> int:
for _ in range(n_iterations):
path = self.select(self.root)
leaf = path[-1] if path else self.root
child = self.expand(leaf, context_tokens)
path.append(child)
value = self.simulate(child, context_tokens)
self.backpropagate(path, value)
best_child = max(self.root.children.values(), key=lambda n: n.visits)
return best_child.token_id
def generate_text(self, prompt: str, max_tokens: int = 100) -> str:
"""Generate text from a prompt using MCTS-guided sampling"""
initial_tokens = self.llm.tokenizer.encode(prompt)
current_tokens = initial_tokens.copy()
for _ in range(max_tokens):
next_token = self.search(current_tokens)
current_tokens.append(next_token)
if next_token == self.llm.tokenizer.eos_token_id:
break
self.root = MCTSNode(token_id=-1, parent=None, children={})
return self.llm.tokenizer.decode(current_tokens)
def generate_text_streaming(self, prompt: str, max_tokens: int = 100):
"""Generate text from a prompt using MCTS-guided sampling, streaming each token as it's generated."""
initial_tokens = self.llm.tokenizer.encode(prompt)
current_tokens = initial_tokens.copy()
for _ in range(max_tokens):
next_token = self.search(current_tokens)
token_text = self.llm.tokenizer.decode(next_token)
yield token_text
current_tokens.append(next_token)
if next_token == self.llm.tokenizer.eos_token_id:
break
self.root = MCTSNode(token_id=-1, parent=None, children={})
# Optional: Final yield for full generated text if desired
yield self.llm.tokenizer.decode(current_tokens)
def print_trie(node: MCTSNode, tokenizer, depth: int = 0):
    """Recursively print the structure of the MCTS tree (trie)."""
    # -1 is the root sentinel, which has no real token to decode.
    label = "<root>" if node.token_id == -1 else tokenizer.decode([node.token_id])
    pad = " " * depth
    print(f"{pad}Token: {label} | ID: {node.token_id} | Visits: {node.visits} | Value: {node.value:.4f}")
    for subtree in node.children.values():
        print_trie(subtree, tokenizer, depth + 1)
@austin362667
Copy link
Author

austin362667 commented Nov 4, 2024

Usage

Create an MCTSLLM sampler

mcts = MCTSLLM(
    model_name="HuggingFaceTB/SmolLM2-135M-Instruct",
    device='mps',
    exploration_constant=1.0
)
print_trie(mcts.root, mcts.llm.tokenizer) 
Token: <root> | ID: -1 | Visits: 0 | Value: 0.0000

Sampling *10

prompt = "Paris is the capital of"
for e in range(10):
    print(mcts.generate_text(prompt, max_tokens=10))
Paris is the capital of France but famous for its gastronomy, fashion, and
Paris is the capital of the City of Paris, France. Paris is known
Paris is the capital of a French ethnological colony. St Augustine has estates
Paris is the capital of France. This city also serves another important location -
Paris is the capital of France, and 4 crunches tower The
Paris is the capital of France, and its economic and cultural influence is significant
Paris is the capital of Li-Jiaoing, a province in present
Paris is the capital of 
	Poland - Szczecin

Paris is the capital of Maine-Yachts and Cannery Bays
Paris is the capital of the state and also a touristic place from which

@austin362667
Copy link
Author

Usage

Create an MCTSLLM sampler (lower exploration constant)

mcts = MCTSLLM(
    model_name="HuggingFaceTB/SmolLM2-135M-Instruct",
    device='mps',
    exploration_constant=0.01
)
print_trie(mcts.root, mcts.llm.tokenizer) 
Token: <root> | ID: -1 | Visits: 0 | Value: 0.0000

Sampling *10

prompt = "Paris is the capital of"
for e in range(10):
    print(mcts.generate_text(prompt, max_tokens=10))
Paris is the capital of France, known for its historic sites, art exhibitions
Paris is the capital of the French Republic, and it is a major center
Paris is the capital of France. It has a population of 62
Paris is the capital of the Commonwealth of Independent States in the Baltic region of
Paris is the capital of the Dominican Republic, known for its rich history,
Paris is the capital of Italy, and it is the largest city in southern
Paris is the capital of the province of Zinsburn and is the largest
Paris is the capital of France."

Should we clarify the sentence about
Paris is the capital of France, located in the south and center of the
Paris is the capital of 14 communes. Some of the cities

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment