arith.py
""" | |
Update 7: triple head loss, btw Claude 4 is vibe | |
""" | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import numpy as np | |
import math | |
import random | |
import copy | |
import json | |
import os | |
from typing import Optional, Tuple, Dict, List | |
from dataclasses import dataclass, asdict | |
from collections import deque, defaultdict | |
from tqdm import tqdm | |
import datetime | |
import re | |
# Handle optional dependencies gracefully | |
try: | |
import wandb | |
HAS_WANDB = True | |
except ImportError: | |
HAS_WANDB = False | |
# Mock wandb | |
class MockWandb: | |
def init(self, **kwargs): | |
pass | |
def log(self, data): | |
pass | |
def save(self, path): | |
pass | |
wandb = MockWandb() | |
# ============= CoT Regression Test Cases from original code ============= | |
COT_REGRESSION_TEST_CASES_HUMAN_FRIENDLY = """ | |
# sanity check | |
1 + 1 | |
= 2 | |
# negative result | |
5 - 10 | |
= -5 | |
# parenthesized | |
(1 + 2) + 3 | |
= 1 + 2 + 3 | |
= 3 + 3 | |
= 6 | |
# double parentheses | |
((1 + 2) + 3) | |
= ( 1 + 2 ) + 3 | |
= 1 + 2 + 3 | |
= 3 + 3 | |
= 6 | |
# complex ops | |
((1168 - 9099) + 68) - -90 | |
= ( 1168 - 9099 ) + 68 - -90 | |
= 1168 - 9099 + 68 - -90 | |
= -7931 + 68 - -90 | |
= -7863 - -90 | |
= -7773 | |
# de-parenthesize whenever possible | |
29 + (-1055133 - -6651) | |
= 29 + -1055133 - -6651 | |
= 29 + -1048482 | |
= -1048453 | |
# important example | |
29 - (-1055133 - -6651) | |
= 29 - -1048482 | |
= 1048511 | |
# add negative number | |
10 + -5 - 2 | |
= 5 - 2 | |
= 3 | |
# single number | |
123 | |
= 123 | |
# single negative number | |
-42 | |
= -42 | |
# single parenthesized number | |
(7) | |
= 7 | |
# minus negative number | |
10 - -5 | |
= 15 | |
# ensure step by step de-parenthesization | |
((((1+1)))) | |
= ( ( ( 1 + 1 ) ) ) | |
= ( ( 1 + 1 ) ) | |
= ( 1 + 1 ) | |
= 1 + 1 | |
= 2 | |
# unary minus | |
-(1 + 1) | |
= -2 | |
""" | |
def parse_regression_test_cases(test_cases_str): | |
"""Parse the regression test cases into structured format""" | |
lines = test_cases_str.strip().split("\n") | |
test_cases = [] | |
i = 0 | |
current_name = "Unnamed Test" | |
while i < len(lines): | |
line = lines[i].strip() | |
# Skip empty lines | |
if not line: | |
i += 1 | |
continue | |
# Found a comment (test name) | |
if line.startswith("#"): | |
current_name = line[1:].strip() | |
i += 1 | |
continue | |
# Found an input expression | |
if not line.startswith("="): | |
input_expr = line | |
expected_output = [] | |
i += 1 | |
# Collect expected output lines | |
while ( | |
i < len(lines) | |
and lines[i].strip() | |
and not lines[i].strip().startswith("#") | |
): | |
if lines[i].strip().startswith("="): | |
expected_output.append(lines[i].strip()) | |
i += 1 | |
if expected_output: | |
# Join expected output | |
expected_cot = "\n".join(expected_output) | |
# Ensure it ends with <eos> | |
if not expected_cot.endswith("<eos>"): | |
expected_cot += "<eos>" | |
test_cases.append( | |
{ | |
"input": input_expr, | |
"expected": expected_cot, | |
"name": current_name, | |
} | |
) | |
return test_cases | |
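# A minimal usage sketch (hypothetical _demo helper, not called by the
# pipeline): parse the human-friendly regression string above and inspect
# the first structured case it yields.
def _demo_parse_regression_cases():
    cases = parse_regression_test_cases(COT_REGRESSION_TEST_CASES_HUMAN_FRIENDLY)
    first = cases[0]
    assert first["name"] == "sanity check"
    assert first["input"] == "1 + 1"
    assert first["expected"] == "= 2<eos>"
    return cases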
# ============= AST Node Creation Functions ============= | |
def make_number_node(value): | |
return {"type": "number", "value": value} | |
def make_operation_node(left, op_char, right): | |
return { | |
"type": "operation", | |
"left": left, | |
"op": op_char, | |
"right": right, | |
} | |
def make_paren_node(child_node): | |
return {"type": "paren", "child": child_node} | |
# ============= ArithmeticTokenizer from original code ============= | |
class ArithmeticTokenizer: | |
def __init__(self): | |
self.digits = list("0123456789") | |
self.special_tokens = ["<bos>", "<eos>", "<pad>", "+", "-", "=", "\n", "(", ")"] | |
self.vocab = self.special_tokens + self.digits | |
self.token_to_id = {token: idx for idx, token in enumerate(self.vocab)} | |
self.id_to_token = {idx: token for idx, token in enumerate(self.vocab)} | |
self.ignore_tokens = [" "] | |
# Special token IDs | |
self.bos_token_id = self.token_to_id["<bos>"] | |
self.eos_token_id = self.token_to_id["<eos>"] | |
self.pad_token_id = self.token_to_id["<pad>"] | |
def detokenize_compound_numbers(self, coarse_tokens): | |
fine_tokens = [] | |
for token in coarse_tokens: | |
if token in self.token_to_id: | |
fine_tokens.append(token) | |
else: | |
for char_token in list(token): | |
if char_token in self.token_to_id: | |
fine_tokens.append(char_token) | |
return fine_tokens | |
def tokenize(self, text): | |
sorted_special_tokens = sorted(self.special_tokens, key=len, reverse=True) | |
tokenized_output = [] | |
i = 0 | |
while i < len(text): | |
if i < len(text) and text[i] in self.ignore_tokens: | |
i += 1 | |
continue | |
matched_special_token = False | |
for special_token in sorted_special_tokens: | |
if text.startswith(special_token, i): | |
if special_token == "-": | |
# Check if this is a unary minus for a number | |
is_unary_minus_for_num = False | |
if i + 1 < len(text) and text[i + 1] in self.digits: | |
if not tokenized_output or tokenized_output[-1] in [ | |
"+", | |
"-", | |
"(", | |
"=", | |
]: | |
is_unary_minus_for_num = True | |
if is_unary_minus_for_num: | |
pass # Let number parsing logic handle it | |
else: | |
tokenized_output.append(special_token) | |
i += len(special_token) | |
matched_special_token = True | |
break | |
else: | |
tokenized_output.append(special_token) | |
i += len(special_token) | |
matched_special_token = True | |
break | |
if matched_special_token: | |
continue | |
# Try to match a number (potentially signed, potentially multi-digit) | |
if text[i] in self.digits or ( | |
text[i] == "-" and i + 1 < len(text) and text[i + 1] in self.digits | |
): | |
is_negative_prefix = False | |
if text[i] == "-": | |
if not tokenized_output or tokenized_output[-1] in [ | |
"+", | |
"-", | |
"(", | |
"=", | |
]: | |
is_negative_prefix = True | |
current_num_str = "" | |
if is_negative_prefix: | |
current_num_str += text[i] | |
i += 1 | |
start_num_idx = i | |
while i < len(text) and text[i] in self.digits: | |
i += 1 | |
num_str = text[start_num_idx:i] | |
                if current_num_str and not num_str:
                    # Defensive (unreachable in practice): a '-' was consumed
                    # but no digits followed; emit the bare '-' and leave i
                    # pointing at the next character.
                    tokenized_output.append(current_num_str)
                    continue
elif num_str: | |
tokenized_output.append(current_num_str + num_str) | |
continue | |
            # Fallback: emit the current character as a single token
            tokenized_output.append(text[i])
            i += 1
return tokenized_output | |
def convert_tokens_to_ids(self, tokens): | |
return [self.token_to_id[token] for token in tokens] | |
def convert_ids_to_tokens(self, ids): | |
return [self.id_to_token[int(id)] for id in ids] | |
def get_place_value_positions(self, tokens): | |
place_values = [-1] * len(tokens) | |
current_num_start_idx = -1 | |
for i, token in enumerate(tokens): | |
if token in self.digits: | |
if current_num_start_idx == -1: | |
current_num_start_idx = i | |
if i == len(tokens) - 1 or tokens[i + 1] not in self.digits: | |
num_len = i - current_num_start_idx + 1 | |
for k in range(num_len): | |
place_values[current_num_start_idx + k] = num_len - 1 - k | |
current_num_start_idx = -1 | |
return place_values | |
def get_level_ids(self, tokens, is_target_sequence=False): | |
"""Assigns a structural level ID to each token.""" | |
if not tokens: | |
return [] | |
level_ids = [-1] * len(tokens) | |
# Simplified level assignment | |
paren_depth = 0 | |
for i, token in enumerate(tokens): | |
if token == "(": | |
level_ids[i] = paren_depth | |
paren_depth += 1 | |
elif token == ")": | |
paren_depth -= 1 | |
level_ids[i] = paren_depth | |
elif ( | |
token in ["+", "-"] | |
and i > 0 | |
and tokens[i - 1] not in ["+", "-", "(", "="] | |
): | |
level_ids[i] = paren_depth + 10 # Operations get higher level | |
else: | |
level_ids[i] = paren_depth + 20 # Numbers/other get even higher | |
return level_ids | |
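# A round-trip sketch (hypothetical _demo helper, not used elsewhere): coarse
# tokenization keeps signed multi-digit numbers together, and
# detokenize_compound_numbers then splits them into per-character vocab tokens.
def _demo_tokenizer_roundtrip():
    tok = ArithmeticTokenizer()
    coarse = tok.tokenize("12 + -3")  # ['12', '+', '-3']
    fine = tok.detokenize_compound_numbers(coarse)  # ['1', '2', '+', '-', '3']
    ids = tok.convert_tokens_to_ids(fine)
    assert "".join(tok.convert_ids_to_tokens(ids)) == "12+-3"
    # Place values count digit positions right-to-left within each number:
    # '1' in '12' is the tens place (1), '2' the ones place (0); non-digits get -1.
    assert tok.get_place_value_positions(fine) == [1, 0, -1, -1, 0]
    return fine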
# ============= Enhanced Parser from Diffusion Model ============= | |
class _TokenStream: | |
def __init__(self, tokens): | |
self.tokens = tokens | |
self.pos = 0 | |
def peek(self): | |
if self.pos < len(self.tokens): | |
return self.tokens[self.pos] | |
return None | |
def next(self): | |
token = self.peek() | |
if token is not None: | |
self.pos += 1 | |
return token | |
def is_eof(self): | |
return self.pos >= len(self.tokens) | |
def expect(self, expected_token): | |
token = self.next() | |
if token != expected_token: | |
context_start = max(0, self.pos - 5) | |
context_end = min(len(self.tokens), self.pos + 4) | |
context = ( | |
self.tokens[context_start : self.pos - 1] | |
+ [f">>'{token}'<<"] | |
+ self.tokens[self.pos : context_end] | |
) | |
raise ValueError( | |
f"Expected token '{expected_token}' but got '{token}' at pos {self.pos-1}. Context: {' '.join(context)}" | |
) | |
return token | |
def _parse_factor_expr(token_stream, tokenizer): | |
"""Parse factor expressions including unary minus""" | |
token = token_stream.peek() | |
if token == "-": | |
token_stream.next() # Consume '-' | |
operand_ast = _parse_factor_expr(token_stream, tokenizer) | |
if operand_ast["type"] == "number": | |
return make_number_node(-operand_ast["value"]) | |
elif ( | |
operand_ast["type"] == "paren" and operand_ast["child"]["type"] == "number" | |
): | |
return make_number_node(-operand_ast["child"]["value"]) | |
else: | |
return make_operation_node(make_number_node(0), "-", operand_ast) | |
else: | |
return _parse_atom_expr(token_stream, tokenizer) | |
def _parse_atom_expr(token_stream, tokenizer): | |
"""Parse atomic expressions (numbers or parenthesized expressions)""" | |
token = token_stream.peek() | |
if token is None: | |
raise ValueError("Unexpected EOF while parsing atom") | |
if token == "(": | |
token_stream.next() # consume '(' | |
inner_expr_ast = _parse_add_sub_expr(token_stream, tokenizer) | |
token_stream.expect(")") # consume ')' | |
return make_paren_node(inner_expr_ast) | |
# Check if the token is a number | |
is_number = False | |
if token[0] == "-": | |
if len(token) > 1 and all(c in tokenizer.digits for c in token[1:]): | |
is_number = True | |
elif all(c in tokenizer.digits for c in token): | |
is_number = True | |
if is_number: | |
return make_number_node(int(token_stream.next())) | |
else: | |
raise ValueError( | |
f"Unexpected token '{token}' at pos {token_stream.pos} when expecting an atom" | |
) | |
def _parse_add_sub_expr(token_stream, tokenizer): | |
"""Parse addition and subtraction expressions (left-associative)""" | |
lhs = _parse_factor_expr(token_stream, tokenizer) | |
while not token_stream.is_eof() and token_stream.peek() in ["+", "-"]: | |
operator = token_stream.next() | |
next_peek = token_stream.peek() | |
if next_peek is None: | |
raise ValueError(f"Unexpected EOF after operator '{operator}'") | |
rhs = _parse_factor_expr(token_stream, tokenizer) | |
lhs = make_operation_node(lhs, operator, rhs) | |
return lhs | |
def _parse_expression_string_to_ast(expression_string): | |
"""Parse an arithmetic expression string into an AST""" | |
tokenizer = ArithmeticTokenizer() | |
raw_tokens = tokenizer.tokenize(expression_string) | |
# Filter out newline tokens | |
tokens = [t for t in raw_tokens if t != "\n"] | |
if not tokens: | |
raise ValueError("Cannot parse an empty expression string.") | |
token_stream = _TokenStream(tokens) | |
ast_root = _parse_add_sub_expr(token_stream, tokenizer) | |
if not token_stream.is_eof(): | |
remaining_tokens = token_stream.tokens[token_stream.pos :] | |
raise ValueError( | |
f"Unexpected tokens remaining after parsing: {remaining_tokens}" | |
) | |
return ast_root | |
class ArithmeticParser: | |
"""Wrapper for the working parser implementation""" | |
def __init__(self, tokenizer): | |
self.tokenizer = tokenizer | |
def parse(self, expression): | |
"""Parse arithmetic expression into AST""" | |
        # Clean expression (strip() already removes any trailing newline)
        expression = expression.strip()
try: | |
return _parse_expression_string_to_ast(expression) | |
except Exception as e: | |
# Fallback to simple number | |
try: | |
value = int(expression) | |
return {"type": "number", "value": value} | |
            except ValueError:
raise ValueError(f"Could not parse expression: {expression}") | |
# ============= Enhanced CoT Generation from Diffusion Model ============= | |
def generate_enhanced_cot_from_tree(root_node): | |
"""Generate CoT from AST with comprehensive rule handling""" | |
cot_lines = [] | |
current_ast = copy.deepcopy(root_node) | |
# Helper for ground truth numerical evaluation | |
def evaluate_ast_numerically(node_eval_numeric): | |
if node_eval_numeric["type"] == "number": | |
return node_eval_numeric["value"] | |
if node_eval_numeric["type"] == "paren": | |
return evaluate_ast_numerically(node_eval_numeric["child"]) | |
if node_eval_numeric["type"] == "operation": | |
left_val_num = evaluate_ast_numerically(node_eval_numeric["left"]) | |
right_val_num = evaluate_ast_numerically(node_eval_numeric["right"]) | |
op_char_eval_num = node_eval_numeric["op"] | |
if op_char_eval_num == "+": | |
return left_val_num + right_val_num | |
if op_char_eval_num == "-": | |
return left_val_num - right_val_num | |
raise ValueError(f"Unknown node type or op: {node_eval_numeric}") | |
# Helper to flatten AST to string parts | |
def flatten(node): | |
if node["type"] == "number": | |
return [str(node["value"])] | |
if node["type"] == "operation": | |
left_parts = flatten(node["left"]) | |
right_parts = flatten(node["right"]) | |
return left_parts + [node["op"]] + right_parts | |
if node["type"] == "paren": | |
return ["("] + flatten(node["child"]) + [")"] | |
return [] | |
# Main CoT generation logic | |
if current_ast["type"] == "number": | |
cot_lines.append(f"= {current_ast['value']}<eos>") | |
return "".join(cot_lines), evaluate_ast_numerically(root_node) | |
# Main simplification loop | |
while current_ast["type"] != "number": | |
ast_state_before = copy.deepcopy(current_ast) | |
action_taken = False | |
# Rule 1: If current_ast is Paren(child), simplify to child | |
if current_ast["type"] == "paren": | |
current_ast = current_ast["child"] | |
action_taken = True | |
# Rule 2: If current_ast is OperationNode | |
elif current_ast["type"] == "operation": | |
# Rule 2a: Critical evaluation for A - Paren(Op(N,N)) | |
is_critical_subtraction = False | |
if ( | |
current_ast["op"] == "-" | |
and current_ast["right"]["type"] == "paren" | |
and current_ast["right"]["child"]["type"] == "operation" | |
and current_ast["right"]["child"]["left"]["type"] == "number" | |
and current_ast["right"]["child"]["right"]["type"] == "number" | |
): | |
paren_op_nn = current_ast["right"] | |
op_node = paren_op_nn["child"] | |
left_val_rhs = op_node["left"]["value"] | |
right_val_rhs = op_node["right"]["value"] | |
op_rhs = op_node["op"] | |
result_rhs = ( | |
left_val_rhs + right_val_rhs | |
if op_rhs == "+" | |
else left_val_rhs - right_val_rhs | |
) | |
# Check if it's the unary minus case: 0 - Paren(Op(N,N)) | |
if ( | |
current_ast["left"]["type"] == "number" | |
and current_ast["left"]["value"] == 0 | |
): | |
current_ast = make_number_node(-result_rhs) | |
action_taken = True | |
is_critical_subtraction = True | |
else: | |
current_ast["right"] = make_number_node(result_rhs) | |
action_taken = True | |
is_critical_subtraction = True | |
# Rule 2b: Find and process innermost Op(N,N) | |
if not is_critical_subtraction: | |
# Rule 2b.1: If current_ast itself is Op(N,N) | |
if ( | |
current_ast["type"] == "operation" | |
and current_ast["left"]["type"] == "number" | |
and current_ast["right"]["type"] == "number" | |
): | |
left_val = current_ast["left"]["value"] | |
right_val = current_ast["right"]["value"] | |
op = current_ast["op"] | |
result_val = ( | |
left_val + right_val if op == "+" else left_val - right_val | |
) | |
current_ast = make_number_node(result_val) | |
action_taken = True | |
else: | |
# Rule 2b.2: Strip immediate children's parentheses | |
if current_ast["left"]["type"] == "paren": | |
current_ast["left"] = current_ast["left"]["child"] | |
action_taken = True | |
elif current_ast["right"]["type"] == "paren": | |
current_ast["right"] = current_ast["right"]["child"] | |
action_taken = True | |
# Rule 2b.3: Recursive simplification if no immediate action | |
if not action_taken: | |
# Helper 1: Strip innermost Paren(Op(N,N)) | |
def strip_innermost_paren(node): | |
nonlocal action_taken | |
if action_taken: | |
return node | |
if node["type"] == "paren": | |
if ( | |
node["child"]["type"] == "operation" | |
and node["child"]["left"]["type"] == "number" | |
and node["child"]["right"]["type"] == "number" | |
): | |
action_taken = True | |
return node["child"] | |
else: | |
modified_child = strip_innermost_paren( | |
node["child"] | |
) | |
if action_taken: | |
node["child"] = modified_child | |
return node | |
elif node["type"] == "operation": | |
original_left = node["left"] | |
node["left"] = strip_innermost_paren(original_left) | |
if action_taken: | |
return node | |
original_right = node["right"] | |
node["right"] = strip_innermost_paren(original_right) | |
return node | |
return node | |
# Helper 2: Evaluate innermost Op(N,N) | |
def eval_innermost_opnn(node): | |
nonlocal action_taken | |
if action_taken: | |
return node | |
if node["type"] == "paren": | |
original_child = node["child"] | |
node["child"] = eval_innermost_opnn(original_child) | |
return node | |
elif node["type"] == "operation": | |
if ( | |
node["left"]["type"] == "number" | |
and node["right"]["type"] == "number" | |
): | |
left_val = node["left"]["value"] | |
right_val = node["right"]["value"] | |
op = node["op"] | |
result_val = ( | |
left_val + right_val | |
if op == "+" | |
else left_val - right_val | |
) | |
action_taken = True | |
return make_number_node(result_val) | |
else: | |
original_left = node["left"] | |
node["left"] = eval_innermost_opnn(original_left) | |
if action_taken: | |
return node | |
original_right = node["right"] | |
node["right"] = eval_innermost_opnn(original_right) | |
return node | |
return node | |
# Apply helpers | |
if current_ast["type"] == "operation": | |
current_ast = strip_innermost_paren(current_ast) | |
if not action_taken and current_ast["type"] == "operation": | |
current_ast = eval_innermost_opnn(current_ast) | |
# Generate output line if action was taken | |
if action_taken: | |
is_final = current_ast["type"] == "number" | |
marker = "<eos>" if is_final else "\n" | |
current_line = f"= {' '.join(flatten(current_ast))}{marker}" | |
if not cot_lines: | |
cot_lines.append(current_line) | |
else: | |
# Avoid duplicates | |
last_content = cot_lines[-1].replace("<eos>", "").strip() | |
current_content = ( | |
current_line.replace("<eos>", "").replace("\n", "").strip() | |
) | |
if last_content != current_content: | |
cot_lines.append(current_line) | |
elif is_final and not cot_lines[-1].endswith("<eos>"): | |
cot_lines[-1] = current_line | |
# Break if no action taken | |
if not action_taken: | |
if current_ast["type"] != "number": | |
final_val = evaluate_ast_numerically(root_node) | |
final_line = f"= {final_val}<eos>" | |
if not cot_lines or cot_lines[-1].strip() != final_line.strip(): | |
cot_lines.append(final_line) | |
current_ast = make_number_node(final_val) | |
break | |
# Final check | |
if not cot_lines: | |
final_val = evaluate_ast_numerically(root_node) | |
cot_lines.append(f"= {final_val}<eos>") | |
elif cot_lines and not cot_lines[-1].endswith("<eos>"): | |
final_val = evaluate_ast_numerically(root_node) | |
if cot_lines[-1].strip().replace("\n", "") == f"= {final_val}": | |
cot_lines[-1] = f"= {final_val}<eos>" | |
else: | |
cot_lines.append(f"= {final_val}<eos>") | |
return "".join(cot_lines), evaluate_ast_numerically(root_node) | |
def validate_tokenizer(tokenizer, verbose=True): | |
"""Validate tokenizer with test cases""" | |
test_inputs = [ | |
"123", | |
"-42", | |
"1 + 2", | |
"(1 + 2)", | |
"10 - -5", | |
"((1 + 2) + 3)", | |
"-123 + 456", | |
] | |
if verbose: | |
print("\n" + "=" * 60) | |
print("Validating Tokenizer") | |
print("=" * 60) | |
all_pass = True | |
for test_input in test_inputs: | |
tokens = tokenizer.tokenize(test_input) | |
fine_tokens = tokenizer.detokenize_compound_numbers(tokens) | |
ids = tokenizer.convert_tokens_to_ids(fine_tokens) | |
place_values = tokenizer.get_place_value_positions(fine_tokens) | |
level_ids = tokenizer.get_level_ids(fine_tokens) | |
reconstructed = tokenizer.convert_ids_to_tokens(ids) | |
if verbose: | |
print(f"\nInput: '{test_input}'") | |
print(f"Tokens: {fine_tokens}") | |
print(f"IDs: {ids}") | |
print(f"Place values: {place_values}") | |
print(f"Level IDs: {level_ids}") | |
# Check reconstruction | |
reconstructed_str = "".join(reconstructed) | |
original_str = test_input.replace(" ", "") | |
if reconstructed_str == original_str: | |
if verbose: | |
print("✓ Reconstruction successful") | |
else: | |
if verbose: | |
print( | |
f"✗ Reconstruction failed: '{reconstructed_str}' != '{original_str}'" | |
) | |
all_pass = False | |
if verbose: | |
print("\n" + "=" * 60 + "\n") | |
return all_pass | |
def validate_cot_generation(tokenizer, parser=None, verbose=True): | |
"""Validate CoT generation against regression test cases using actual parser""" | |
test_cases = parse_regression_test_cases(COT_REGRESSION_TEST_CASES_HUMAN_FRIENDLY) | |
if verbose: | |
print("\n" + "=" * 60) | |
print("Validating CoT Generation Against Regression Tests") | |
print("=" * 60) | |
passed = 0 | |
failed = 0 | |
for i, test_case in enumerate(test_cases): | |
input_expr = test_case["input"].strip() | |
expected = test_case["expected"] | |
try: | |
if parser is not None: | |
# Use the actual parser and CoT generation | |
ast = parser.parse(input_expr) | |
generated_cot, _ = generate_enhanced_cot_from_tree(ast) | |
# Compare | |
if generated_cot.strip() == expected.strip(): | |
passed += 1 | |
if verbose: | |
print(f"✓ Test {i+1}: {test_case['name']}") | |
else: | |
failed += 1 | |
if verbose: | |
print(f"✗ Test {i+1}: {test_case['name']}") | |
print(f" Expected: {expected}") | |
print(f" Generated: {generated_cot}") | |
else: | |
# Fallback: assume success for now if no parser provided | |
passed += 1 | |
if verbose: | |
print(f"? Test {i+1}: {test_case['name']} (no parser - skipped)") | |
except Exception as e: | |
failed += 1 | |
if verbose: | |
print(f"✗ Test {i+1}: {test_case['name']} - ERROR: {e}") | |
if verbose: | |
print( | |
f"\nResults: {passed}/{len(test_cases)} passed ({passed/len(test_cases)*100:.1f}%)" | |
) | |
print("=" * 60 + "\n") | |
    return passed, failed, (passed / len(test_cases) if test_cases else 0.0)
# ============= Embedding modules ============= | |
class SinusoidalPositionEncoding(nn.Module): | |
"""Standard sinusoidal encoding for positions""" | |
def __init__(self, d_model, max_len=5000): | |
super().__init__() | |
pe = torch.zeros(max_len, d_model) | |
position = torch.arange(0, max_len).unsqueeze(1).float() | |
div_term = torch.exp( | |
torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model) | |
) | |
pe[:, 0::2] = torch.sin(position * div_term) | |
pe[:, 1::2] = torch.cos(position * div_term) | |
self.register_buffer("pe", pe.unsqueeze(0)) | |
def forward(self, seq_len): | |
return self.pe[:, :seq_len] | |
class ConcatProjectEmbeddings(nn.Module): | |
"""Concatenate embeddings then project - preserves all information""" | |
def __init__(self, vocab_size, d_model, max_place_value=10, max_level_id=30): | |
super().__init__() | |
self.d_per_emb = d_model // 4 | |
self.token_embedding = nn.Embedding(vocab_size, self.d_per_emb) | |
self.position_encoding = SinusoidalPositionEncoding(self.d_per_emb) | |
self.place_value_embedding = nn.Embedding( | |
max_place_value + 2, self.d_per_emb | |
) # +2 for -1 and padding | |
self.level_id_embedding = nn.Embedding(max_level_id + 2, self.d_per_emb) | |
self.projection = nn.Sequential( | |
nn.Linear(d_model, d_model), | |
nn.LayerNorm(d_model), | |
nn.GELU(), | |
nn.Dropout(0.1), | |
) | |
self._init_weights() | |
def _init_weights(self): | |
std = 1.0 / math.sqrt(self.d_per_emb) | |
nn.init.normal_(self.token_embedding.weight, mean=0, std=std) | |
nn.init.normal_(self.place_value_embedding.weight, mean=0, std=std) | |
nn.init.normal_(self.level_id_embedding.weight, mean=0, std=std) | |
nn.init.xavier_uniform_(self.projection[0].weight) | |
nn.init.zeros_(self.projection[0].bias) | |
def forward(self, tokens, place_values, level_ids): | |
batch_size, seq_len = tokens.shape | |
token_emb = self.token_embedding(tokens) | |
pos_emb = self.position_encoding(seq_len).expand(batch_size, -1, -1) | |
# For place values, adjust -1 to 0 and shift others up | |
place_values_adjusted = (place_values + 1).clamp( | |
0, self.place_value_embedding.num_embeddings - 1 | |
) | |
pv_emb = self.place_value_embedding(place_values_adjusted) | |
level_adjusted = (level_ids + 1).clamp( | |
0, self.level_id_embedding.num_embeddings - 1 | |
) | |
level_emb = self.level_id_embedding(level_adjusted) | |
combined = torch.cat([token_emb, pos_emb, pv_emb, level_emb], dim=-1) | |
return self.projection(combined) | |
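# A shape sketch (hypothetical _demo helper, not called by the pipeline):
# with d_model divisible by 4, the four d_model//4-dim embeddings (token,
# position, place value, level id) are concatenated back into a
# (batch, seq, d_model) tensor before projection.
def _demo_embedding_shapes():
    tok = ArithmeticTokenizer()
    emb = ConcatProjectEmbeddings(vocab_size=len(tok.vocab), d_model=128)
    tokens = torch.randint(0, len(tok.vocab), (2, 7))
    place_values = torch.randint(-1, 10, (2, 7))
    level_ids = torch.randint(-1, 30, (2, 7))
    out = emb(tokens, place_values, level_ids)
    assert out.shape == (2, 7, 128)
    return out.shape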
# ============= Curriculum components ============= | |
@dataclass | |
class DifficultyState: | |
"""Current difficulty parameters""" | |
max_digits: int = 1 | |
max_ops: int = 1 | |
parentheses_prob: float = 0.0 | |
negative_prob: float = 0.0 | |
nested_parentheses_depth: int = 0 | |
mixed_digit_sizes: bool = False | |
def complexity_score(self): | |
return ( | |
self.max_digits * 10 | |
+ self.max_ops * 5 | |
+ self.parentheses_prob * 3 | |
+ self.negative_prob * 2 | |
+ self.nested_parentheses_depth * 4 | |
+ (1 if self.mixed_digit_sizes else 0) | |
) | |
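# A worked example (hypothetical _demo helper, not called by the pipeline):
# the default state (1 digit, 1 op, no parentheses or negatives) scores
# 1*10 + 1*5 = 15; enabling everything at the caps used by the curriculum
# below pushes the score past the 50 threshold that triggers model growth.
def _demo_complexity_score():
    assert DifficultyState().complexity_score() == 15
    hard = DifficultyState(
        max_digits=4,
        max_ops=3,
        parentheses_prob=0.4,
        negative_prob=0.3,
        mixed_digit_sizes=True,
    )
    assert hard.complexity_score() == 4 * 10 + 3 * 5 + 0.4 * 3 + 0.3 * 2 + 1
    return hard.complexity_score()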
# ============= Main model ============= | |
class GrowableMultiTaskTransformer(nn.Module): | |
"""Transformer that can grow dynamically""" | |
def __init__( | |
self, | |
vocab_size, | |
d_model=128, | |
nhead=4, | |
num_layers=2, | |
dim_feedforward=512, | |
dropout=0.1, | |
max_place_value=10, | |
max_level_id=30, | |
): | |
super().__init__() | |
self.vocab_size = vocab_size | |
self.d_model = d_model | |
self.nhead = nhead | |
self.num_layers = num_layers | |
self.dim_feedforward = dim_feedforward | |
self.dropout = dropout | |
self.max_place_value = max_place_value | |
self.max_level_id = max_level_id | |
self.embeddings = ConcatProjectEmbeddings( | |
vocab_size, d_model, max_place_value, max_level_id | |
) | |
self._build_transformer() | |
self.token_head = nn.Linear(d_model, vocab_size) | |
self.place_value_head = nn.Linear(d_model, max_place_value + 2) | |
self.level_id_head = nn.Linear(d_model, max_level_id + 2) | |
self._init_heads() | |
def _build_transformer(self): | |
"""Build transformer layers""" | |
encoder_layer = nn.TransformerEncoderLayer( | |
d_model=self.d_model, | |
nhead=self.nhead, | |
dim_feedforward=self.dim_feedforward, | |
dropout=self.dropout, | |
activation="gelu", | |
batch_first=True, | |
norm_first=True, | |
) | |
self.transformer = nn.TransformerEncoder( | |
encoder_layer, num_layers=self.num_layers, enable_nested_tensor=False | |
) | |
def _init_heads(self): | |
for head in [self.token_head, self.place_value_head, self.level_id_head]: | |
nn.init.xavier_uniform_(head.weight) | |
nn.init.zeros_(head.bias) | |
def grow_model(self, new_config: Dict): | |
"""Grow model architecture""" | |
grown = False | |
if "num_layers" in new_config and new_config["num_layers"] > self.num_layers: | |
print(f"Growing layers: {self.num_layers} → {new_config['num_layers']}") | |
old_state = { | |
k: v.cpu().clone() for k, v in self.transformer.state_dict().items() | |
} | |
self.num_layers = new_config["num_layers"] | |
self._build_transformer() | |
new_state = self.transformer.state_dict() | |
for k, v in old_state.items(): | |
if k in new_state and new_state[k].shape == v.shape: | |
new_state[k] = v | |
self.transformer.load_state_dict(new_state) | |
grown = True | |
return grown | |
def create_causal_mask(self, seq_len, device): | |
"""Create causal attention mask""" | |
mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1) | |
mask = mask.masked_fill(mask == 1, float("-inf")) | |
return mask | |
def forward( | |
self, | |
input_ids, | |
place_values, | |
level_ids, | |
attention_mask=None, | |
src_key_padding_mask=None, | |
): | |
"""Forward pass with proper masking""" | |
batch_size, seq_len = input_ids.shape | |
device = input_ids.device | |
embeddings = self.embeddings(input_ids, place_values, level_ids) | |
# Create causal mask (seq_len x seq_len) | |
causal_mask = self.create_causal_mask(seq_len, device) | |
# For transformer, we need src_key_padding_mask (batch_size x seq_len) | |
# True values are ignored, False values are attended to | |
if attention_mask is not None: | |
# attention_mask: 1 for real tokens, 0 for padding | |
# src_key_padding_mask: True for padding, False for real tokens | |
padding_mask = ~attention_mask.bool() # Invert: True for padding | |
else: | |
padding_mask = None | |
hidden_states = self.transformer( | |
embeddings, mask=causal_mask, src_key_padding_mask=padding_mask | |
) | |
token_logits = self.token_head(hidden_states) | |
pv_logits = self.place_value_head(hidden_states) | |
level_logits = self.level_id_head(hidden_states) | |
return token_logits, pv_logits, level_logits | |
def generate(self, input_text, tokenizer, max_length=100): | |
"""Generate CoT for given input""" | |
self.eval() | |
device = next(self.parameters()).device | |
# Tokenize input | |
tokens_prompt = tokenizer.tokenize(input_text) | |
fine_tokens_prompt = tokenizer.detokenize_compound_numbers(tokens_prompt) | |
input_ids_prompt = tokenizer.convert_tokens_to_ids(fine_tokens_prompt) | |
# Compute structure for prompt | |
place_values_prompt = tokenizer.get_place_value_positions(fine_tokens_prompt) | |
level_ids_prompt = tokenizer.get_level_ids(fine_tokens_prompt) | |
# Prepend BOS token and its features | |
current_input_ids_list = [tokenizer.bos_token_id] + input_ids_prompt | |
current_place_values_list = [-1] + place_values_prompt # BOS has neutral PV | |
current_level_ids_list = [-1] + level_ids_prompt # BOS has neutral Level ID | |
# Convert to tensors | |
generated_ids = torch.tensor([current_input_ids_list], device=device) | |
generated_pv = torch.tensor([current_place_values_list], device=device) | |
generated_level = torch.tensor([current_level_ids_list], device=device) | |
with torch.no_grad(): | |
for _ in range(max_length): | |
# Get predictions | |
token_logits, pv_logits, level_logits = self.forward( | |
generated_ids, generated_pv, generated_level | |
) | |
# Get next token | |
next_token = token_logits[0, -1].argmax() | |
next_pv = pv_logits[0, -1].argmax() - 1 | |
next_level = level_logits[0, -1].argmax() - 1 | |
# Append | |
generated_ids = torch.cat( | |
[generated_ids, next_token.unsqueeze(0).unsqueeze(0)], dim=1 | |
) | |
generated_pv = torch.cat( | |
[generated_pv, next_pv.unsqueeze(0).unsqueeze(0)], dim=1 | |
) | |
generated_level = torch.cat( | |
[generated_level, next_level.unsqueeze(0).unsqueeze(0)], dim=1 | |
) | |
# Check for EOS | |
if next_token == tokenizer.eos_token_id: | |
break | |
# Convert to text, removing the initial BOS token from the output | |
output_token_ids_with_bos = generated_ids[0].tolist() | |
final_output_token_ids = [] | |
if output_token_ids_with_bos: | |
if output_token_ids_with_bos[0] == tokenizer.bos_token_id: | |
final_output_token_ids = output_token_ids_with_bos[1:] | |
else: | |
# Should not happen if BOS was prepended, but handle defensively | |
final_output_token_ids = output_token_ids_with_bos | |
generated_tokens = tokenizer.convert_ids_to_tokens(final_output_token_ids) | |
return "".join(generated_tokens) | |
# ============= Data preparation helper ============= | |
def prepare_batch_for_training( | |
batch: List[Dict], tokenizer: ArithmeticTokenizer, device="cuda" | |
): | |
"""Prepare batch for autoregressive training""" | |
all_input_ids = [] | |
all_target_ids = [] | |
all_input_pv = [] | |
all_target_pv = [] | |
all_input_level = [] | |
all_target_level = [] | |
attention_masks = [] | |
for item in batch: | |
# Combine input and target sequences | |
full_ids = item["input_ids"] + item["target_ids"] | |
full_pv = item["input_place_values"] + item["target_place_values"] | |
full_level = item["input_level_ids"] + item["target_level_ids"] | |
# For autoregressive training: predict position i from positions 0...i-1 | |
if len(full_ids) > 1: | |
input_ids = [tokenizer.bos_token_id] + full_ids[:-1] | |
target_ids = full_ids | |
input_pv = [-1] + full_pv[:-1] | |
target_pv = full_pv | |
input_level = [-1] + full_level[:-1] | |
target_level = full_level | |
mask = [1] * len(input_ids) | |
else: | |
input_ids = [tokenizer.bos_token_id] | |
target_ids = full_ids | |
input_pv = [-1] | |
target_pv = full_pv | |
input_level = [-1] | |
target_level = full_level | |
mask = [1] | |
all_input_ids.append(input_ids) | |
all_target_ids.append(target_ids) | |
all_input_pv.append(input_pv) | |
all_target_pv.append(target_pv) | |
all_input_level.append(input_level) | |
all_target_level.append(target_level) | |
attention_masks.append(mask) | |
# Find max length and pad | |
max_len = max(len(seq) for seq in all_input_ids) | |
for i in range(len(all_input_ids)): | |
pad_len = max_len - len(all_input_ids[i]) | |
all_input_ids[i] += [tokenizer.pad_token_id] * pad_len | |
all_target_ids[i] += [tokenizer.pad_token_id] * pad_len | |
all_input_pv[i] += [-1] * pad_len | |
all_target_pv[i] += [-1] * pad_len | |
all_input_level[i] += [-1] * pad_len | |
all_target_level[i] += [-1] * pad_len | |
attention_masks[i] += [0] * pad_len | |
return { | |
"input_ids": torch.tensor(all_input_ids, device=device), | |
"target_ids": torch.tensor(all_target_ids, device=device), | |
"input_place_values": torch.tensor(all_input_pv, device=device), | |
"target_place_values": torch.tensor(all_target_pv, device=device), | |
"input_level_ids": torch.tensor(all_input_level, device=device), | |
"target_level_ids": torch.tensor(all_target_level, device=device), | |
"attention_mask": torch.tensor( | |
attention_masks, device=device, dtype=torch.bool | |
), | |
} | |
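# An alignment sketch (hypothetical _demo helper, not called by the
# pipeline): inputs are the full sequence shifted right behind a <bos>, so
# position i of input_ids is used to predict target_ids[i], and padded
# positions are masked out via attention_mask.
def _demo_batch_alignment():
    tok = ArithmeticTokenizer()
    fine_in = ["1", "+", "1", "\n"]
    fine_tgt = ["=", "2", "<eos>"]
    item = {
        "input_ids": tok.convert_tokens_to_ids(fine_in),
        "target_ids": tok.convert_tokens_to_ids(fine_tgt),
        "input_place_values": tok.get_place_value_positions(fine_in),
        "target_place_values": tok.get_place_value_positions(fine_tgt),
        "input_level_ids": tok.get_level_ids(fine_in),
        "target_level_ids": tok.get_level_ids(fine_tgt, is_target_sequence=True),
    }
    batch = prepare_batch_for_training([item], tok, device="cpu")
    # input starts with <bos> and drops the final token of the target
    assert batch["input_ids"][0, 0].item() == tok.bos_token_id
    assert batch["input_ids"].shape == batch["target_ids"].shape == (1, 7)
    return batch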
# ============= Enhanced Data Generation ============= | |
def generate_enhanced_arithmetic_example(config: Dict, tokenizer, parser) -> Dict: | |
"""Enhanced example generation with better error handling""" | |
max_retries = 3 | |
for attempt in range(max_retries): | |
try: | |
# Generate base numbers | |
num_ops = random.randint(1, min(config.get("max_ops", 2), 3)) | |
numbers = [] | |
for i in range(num_ops + 1): | |
max_digits = config.get("max_digits", 2) | |
if config.get("mixed_digits", False) and random.random() < 0.3: | |
digits = random.randint(1, max(1, max_digits - 1)) | |
else: | |
digits = random.randint(1, max_digits) | |
if digits == 1: | |
num = random.randint(0, 9) | |
else: | |
num = random.randint( | |
10 ** (digits - 1), min(10**digits - 1, 9999) | |
) # Cap at 9999 | |
# Apply negatives more carefully | |
if ( | |
random.random() < config.get("negative_prob", 0) * 0.5 | |
): # Reduce negative probability | |
num = -num | |
numbers.append(num) | |
# Generate operations | |
ops = [] | |
for _ in range(num_ops): | |
ops.append(random.choice(["+", "-"])) | |
# Build simple expression string (avoid complex parentheses for now) | |
if len(numbers) == 2: | |
# Simple binary operation | |
expression = f"{numbers[0]} {ops[0]} {numbers[1]}" | |
else: | |
# Left-associative for multiple operations | |
expression = str(numbers[0]) | |
for i, op in enumerate(ops): | |
expression += f" {op} {numbers[i + 1]}" | |
# Add parentheses occasionally for simple cases | |
if ( | |
random.random() < config.get("parentheses_prob", 0) * 0.5 | |
and len(numbers) == 3 | |
and all(abs(n) < 100 for n in numbers) | |
): | |
# Simple case: (a op b) op c | |
expression = ( | |
f"({numbers[0]} {ops[0]} {numbers[1]}) {ops[1]} {numbers[2]}" | |
) | |
# Test parse and generate | |
ast = parser.parse(expression.strip()) | |
cot, result = generate_enhanced_cot_from_tree(ast) | |
# Validate result is reasonable | |
if abs(result) > 100000: # Skip very large results | |
continue | |
input_str = expression + "\n" | |
# Tokenize | |
input_tokens = tokenizer.tokenize(input_str) | |
input_fine = tokenizer.detokenize_compound_numbers(input_tokens) | |
input_ids = tokenizer.convert_tokens_to_ids(input_fine) | |
input_pv = tokenizer.get_place_value_positions(input_fine) | |
input_level = tokenizer.get_level_ids(input_fine) | |
target_tokens = tokenizer.tokenize(cot) | |
target_fine = tokenizer.detokenize_compound_numbers(target_tokens) | |
target_ids = tokenizer.convert_tokens_to_ids(target_fine) | |
target_pv = tokenizer.get_place_value_positions(target_fine) | |
target_level = tokenizer.get_level_ids(target_fine, is_target_sequence=True) | |
return { | |
"input": input_str, | |
"target": cot, | |
"result": result, | |
"input_ids": input_ids, | |
"target_ids": target_ids, | |
"input_place_values": input_pv, | |
"target_place_values": target_pv, | |
"input_level_ids": input_level, | |
"target_level_ids": target_level, | |
} | |
except Exception as e: | |
if attempt == max_retries - 1: | |
# Final fallback: very simple case | |
try: | |
a, b = random.randint(1, 9), random.randint(1, 9) | |
op = random.choice(["+", "-"]) | |
simple_expr = f"{a} {op} {b}" | |
ast = parser.parse(simple_expr) | |
cot, result = generate_enhanced_cot_from_tree(ast) | |
input_str = simple_expr + "\n" | |
input_tokens = tokenizer.tokenize(input_str) | |
input_fine = tokenizer.detokenize_compound_numbers(input_tokens) | |
return { | |
"input": input_str, | |
"target": cot, | |
"result": result, | |
"input_ids": tokenizer.convert_tokens_to_ids(input_fine), | |
"target_ids": tokenizer.convert_tokens_to_ids( | |
tokenizer.detokenize_compound_numbers( | |
tokenizer.tokenize(cot) | |
) | |
), | |
"input_place_values": tokenizer.get_place_value_positions( | |
input_fine | |
), | |
"target_place_values": tokenizer.get_place_value_positions( | |
tokenizer.detokenize_compound_numbers( | |
tokenizer.tokenize(cot) | |
) | |
), | |
"input_level_ids": tokenizer.get_level_ids(input_fine), | |
"target_level_ids": tokenizer.get_level_ids( | |
tokenizer.detokenize_compound_numbers( | |
tokenizer.tokenize(cot) | |
), | |
is_target_sequence=True, | |
), | |
} | |
                except Exception:
# Ultimate fallback with correct token IDs | |
tokenizer_temp = ArithmeticTokenizer() | |
simple_input = "1 + 1\n" | |
simple_target = "= 2<eos>" | |
input_tokens = tokenizer_temp.tokenize(simple_input) | |
input_fine = tokenizer_temp.detokenize_compound_numbers( | |
input_tokens | |
) | |
target_tokens = tokenizer_temp.tokenize(simple_target) | |
target_fine = tokenizer_temp.detokenize_compound_numbers( | |
target_tokens | |
) | |
return { | |
"input": simple_input, | |
"target": simple_target, | |
"result": 2, | |
"input_ids": tokenizer_temp.convert_tokens_to_ids(input_fine), | |
"target_ids": tokenizer_temp.convert_tokens_to_ids(target_fine), | |
"input_place_values": tokenizer_temp.get_place_value_positions( | |
input_fine | |
), | |
"target_place_values": tokenizer_temp.get_place_value_positions( | |
target_fine | |
), | |
"input_level_ids": tokenizer_temp.get_level_ids(input_fine), | |
"target_level_ids": tokenizer_temp.get_level_ids( | |
target_fine, is_target_sequence=True | |
), | |
} | |
continue | |
# Should never reach here, but just in case | |
return generate_enhanced_arithmetic_example( | |
{"max_digits": 1, "max_ops": 1, "parentheses_prob": 0, "negative_prob": 0}, | |
tokenizer, | |
parser, | |
) | |
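# A data-generation sketch (hypothetical _demo helper, not called by the
# pipeline): one training example drawn at the easiest difficulty; the
# returned dict pairs the raw strings with the tokenized ids and the two
# auxiliary label streams.
def _demo_generate_example():
    tok = ArithmeticTokenizer()
    parser = ArithmeticParser(tok)
    config = {
        "max_digits": 1,
        "max_ops": 1,
        "parentheses_prob": 0.0,
        "negative_prob": 0.0,
    }
    example = generate_enhanced_arithmetic_example(config, tok, parser)
    assert example["target"].endswith("<eos>")
    assert len(example["input_ids"]) == len(example["input_place_values"])
    return example["input"], example["result"]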
# ============= Enhanced Learning Rate Scheduler ============= | |
class CurriculumAwareLRScheduler: | |
"""Learning rate scheduler that adapts to curriculum changes""" | |
def __init__(self, optimizer, base_lr=0.001, warmup_steps=1000): | |
self.optimizer = optimizer | |
self.base_lr = base_lr | |
self.warmup_steps = warmup_steps | |
self.step_count = 0 | |
self.last_difficulty_change = 0 | |
def step(self, difficulty_just_changed=False): | |
self.step_count += 1 | |
if difficulty_just_changed: | |
self.last_difficulty_change = self.step_count | |
# Warmup | |
if self.step_count <= self.warmup_steps: | |
lr = self.base_lr * (self.step_count / self.warmup_steps) | |
else: | |
# Cosine annealing with restarts after difficulty changes | |
steps_since_change = self.step_count - self.last_difficulty_change | |
cycle_length = 5000 | |
t = (steps_since_change % cycle_length) / cycle_length | |
lr = self.base_lr * 0.5 * (1 + math.cos(math.pi * t)) | |
# Boost LR temporarily after difficulty increase | |
if steps_since_change < 100: | |
lr *= 1.5 | |
for param_group in self.optimizer.param_groups: | |
param_group["lr"] = lr | |
return lr | |
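# A scheduler sketch (hypothetical _demo helper, not called by the pipeline):
# linear warmup to base_lr over warmup_steps, then a cosine cycle that
# restarts (with a brief 1.5x boost) whenever the curriculum difficulty
# changes.
def _demo_lr_schedule():
    model = nn.Linear(4, 4)
    opt = torch.optim.AdamW(model.parameters(), lr=0.001)
    sched = CurriculumAwareLRScheduler(opt, base_lr=0.001, warmup_steps=10)
    lrs = [sched.step() for _ in range(10)]
    assert abs(lrs[-1] - 0.001) < 1e-9  # warmup finishes at base_lr
    boosted = sched.step(difficulty_just_changed=True)
    assert abs(boosted - 0.001 * 1.5) < 1e-9  # fresh cosine cycle, boosted
    return lrs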
# ============= Enhanced Trainer ============= | |
class EnhancedPerpetualTrainer: | |
"""Enhanced trainer optimized for RTX 5090 Blackwell architecture""" | |
def __init__(self, device="cuda", use_mixed_precision=True, use_fp8=False): | |
self.device = torch.device(device if torch.cuda.is_available() else "cpu") | |
self.use_mixed_precision = use_mixed_precision and torch.cuda.is_available() | |
self.use_fp8 = use_fp8 and torch.cuda.is_available() | |
# Detect GPU capabilities more accurately | |
self.is_blackwell = False | |
self.has_fp8 = False | |
self.tensor_core_gen = "Unknown" | |
if torch.cuda.is_available(): | |
gpu_name = torch.cuda.get_device_name() | |
compute_capability = torch.cuda.get_device_capability() | |
# Detect architecture and FP8 support | |
if "RTX 50" in gpu_name or "5090" in gpu_name: | |
self.is_blackwell = True | |
self.has_fp8 = True | |
self.tensor_core_gen = "5th Gen (Blackwell)" | |
print(f"🔥 Blackwell GPU detected: {gpu_name}") | |
print("💫 5th Gen Tensor Cores with enhanced FP8!") | |
elif "RTX 40" in gpu_name or "4090" in gpu_name or "4080" in gpu_name: | |
self.has_fp8 = True | |
self.tensor_core_gen = "4th Gen (Ada Lovelace)" | |
print(f"🚀 Ada Lovelace GPU detected: {gpu_name}") | |
print("⚡ 4th Gen Tensor Cores with FP8 support!") | |
elif "H100" in gpu_name: | |
self.has_fp8 = True | |
self.tensor_core_gen = "4th Gen (Hopper)" | |
print(f"🔥 Hopper GPU detected: {gpu_name}") | |
print("⚡ Native FP8 Tensor Cores!") | |
elif "RTX 30" in gpu_name or "3090" in gpu_name or "3080" in gpu_name: | |
self.tensor_core_gen = "3rd Gen (Ampere)" | |
print(f"🚀 Ampere GPU detected: {gpu_name}") | |
print("⚡ 3rd Gen Tensor Cores (FP16 optimized)") | |
elif "A100" in gpu_name: | |
self.has_fp8 = True # Software emulated | |
self.tensor_core_gen = "3rd Gen (Ampere)" | |
print(f"🚀 A100 detected: {gpu_name}") | |
print("⚡ 3rd Gen Tensor Cores with software FP8") | |
else: | |
print(f"🔥 GPU detected: {gpu_name}") | |
print( | |
f" Compute Capability: {compute_capability[0]}.{compute_capability[1]}" | |
) | |
if self.has_fp8: | |
print(f" ✅ FP8 Support Available ({self.tensor_core_gen})") | |
else: | |
print(f" 📊 FP16 Optimized ({self.tensor_core_gen})") | |
self.tokenizer = ArithmeticTokenizer() | |
self.parser = ArithmeticParser(self.tokenizer) | |
# Validate components | |
self._validate_components() | |
# Start small, grow smart - optimize for learning, not just GPU capacity | |
if self.is_blackwell: | |
# Blackwell: Start reasonable, room to grow HUGE | |
base_d_model = 256 # Start modest | |
base_layers = 3 # Start simple | |
max_batch_size = 128 # Can use big batches from start | |
self.growth_potential = "massive" # Can grow to 768d × 8 layers | |
print( | |
f"🚀 Blackwell: Starting smart (d_model={base_d_model}, layers={base_layers}) with massive growth potential" | |
) | |
elif self.has_fp8: | |
# FP8-capable: Start small, good growth room | |
base_d_model = 224 | |
base_layers = 3 | |
max_batch_size = 96 | |
self.growth_potential = "large" # Can grow to 512d × 6 layers | |
print( | |
f"🚀 FP8-capable: Starting smart (d_model={base_d_model}, layers={base_layers}) with large growth potential" | |
) | |
elif self.use_mixed_precision: | |
# Mixed precision: Conservative start | |
base_d_model = 192 | |
base_layers = 2 | |
max_batch_size = 64 | |
self.growth_potential = "medium" # Can grow to 384d × 4 layers | |
print( | |
f"🚀 Mixed precision: Starting conservative (d_model={base_d_model}, layers={base_layers})" | |
) | |
else: | |
# Standard: Very conservative | |
base_d_model = 128 | |
base_layers = 2 | |
max_batch_size = 32 | |
self.growth_potential = "limited" | |
print( | |
f"💻 Standard: Starting small (d_model={base_d_model}, layers={base_layers})" | |
) | |
self.max_batch_size = max_batch_size | |
self.model = GrowableMultiTaskTransformer( | |
vocab_size=len(self.tokenizer.vocab), | |
d_model=base_d_model, | |
nhead=max(4, base_d_model // 32), # Scale heads with model size | |
num_layers=base_layers, | |
dim_feedforward=base_d_model * 4, | |
).to(self.device) | |
# Store growth limits based on GPU capability | |
if self.is_blackwell: | |
self.max_d_model = 768 | |
self.max_layers = 8 | |
elif self.has_fp8: | |
self.max_d_model = 512 | |
self.max_layers = 6 | |
elif self.use_mixed_precision: | |
self.max_d_model = 384 | |
self.max_layers = 4 | |
else: | |
self.max_d_model = 256 | |
self.max_layers = 4 | |
# Enhanced curriculum - start MUCH simpler | |
self.difficulty = DifficultyState() | |
# Reset to absolute basics | |
self.difficulty.max_digits = 1 | |
self.difficulty.max_ops = 1 | |
self.difficulty.parentheses_prob = 0.0 # No parentheses at start | |
self.difficulty.negative_prob = 0.0 # No negatives at start | |
self.difficulty.mixed_digit_sizes = False | |
self.performance_history = deque(maxlen=200) | |
self.difficulty_history = [] | |
self.steps_since_change = 0 | |
self.global_step = 0 | |
self._last_eval_accuracy = 0.0 # Track evaluation accuracy | |
# Conservative optimizer settings - start stable | |
self.optimizer = torch.optim.AdamW( | |
self.model.parameters(), | |
lr=0.001, # Start with standard LR | |
weight_decay=0.01, | |
betas=(0.9, 0.95), # Conservative betas | |
) | |
self.lr_scheduler = CurriculumAwareLRScheduler(self.optimizer) | |
# Mixed precision scaler with Blackwell optimizations | |
if self.use_mixed_precision: | |
if self.is_blackwell: | |
# Optimized scaler settings for Blackwell | |
self.scaler = torch.cuda.amp.GradScaler( | |
init_scale=2.0**10, # Lower initial scale for stability | |
growth_factor=2.0, | |
backoff_factor=0.5, | |
growth_interval=1000, # Less frequent scaling updates | |
) | |
else: | |
self.scaler = torch.cuda.amp.GradScaler() | |
else: | |
self.scaler = None | |
# Enhanced tracking | |
self.best_accuracy = 0.0 | |
self.plateau_counter = 0 | |
self.regression_test_scores = deque(maxlen=10) | |
# Print smart starting configuration | |
print(f"📚 Smart Start Strategy:") | |
print(f" • Starting: {base_d_model}d × {base_layers} layers") | |
print(f" • Max growth: {self.max_d_model}d × {self.max_layers} layers") | |
print(f" • Growth potential: {self.growth_potential}") | |
print(f" • Max batch size: {max_batch_size}") | |
print(f" • Learning approach: Start small → Grow intelligently") | |
# Initialize wandb only if available | |
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") | |
if HAS_WANDB: | |
wandb.init( | |
project="enhanced-perpetual-arithmetic", | |
name=( | |
f"blackwell_run_{timestamp}" | |
if self.is_blackwell | |
else f"enhanced_run_{timestamp}" | |
), | |
config={ | |
"initial_difficulty": asdict(self.difficulty), | |
"model_config": { | |
"d_model": self.model.d_model, | |
"nhead": self.model.nhead, | |
"num_layers": self.model.num_layers, | |
}, | |
"optimizer": "AdamW", | |
"lr_scheduler": "CurriculumAware", | |
"mixed_precision": self.use_mixed_precision, | |
"blackwell_optimized": self.is_blackwell, | |
"has_fp8": self.has_fp8, | |
"tensor_core_gen": self.tensor_core_gen, | |
"max_batch_size": self.max_batch_size, | |
}, | |
) | |
else: | |
print("Warning: wandb not available. Running without experiment tracking.") | |
def _validate_components(self): | |
"""Enhanced validation with parser testing""" | |
print("Validating enhanced components...") | |
# Validate tokenizer | |
tokenizer_valid = validate_tokenizer(self.tokenizer, verbose=True) | |
if not tokenizer_valid: | |
raise ValueError("Tokenizer validation failed!") | |
# Test parser on various expressions | |
test_expressions = [ | |
"1 + 2", | |
"(1 + 2) + 3", | |
"10 - -5", | |
"-42", | |
"((1 + 2) + 3)", | |
"123 + 456 - 789", | |
] | |
parser_valid = True | |
for expr in test_expressions: | |
try: | |
ast = self.parser.parse(expr) | |
cot, result = generate_enhanced_cot_from_tree(ast) | |
print(f"✓ Parsed '{expr}' → {result}") | |
except Exception as e: | |
print(f"✗ Failed to parse '{expr}': {e}") | |
parser_valid = False | |
if not parser_valid: | |
raise ValueError("Enhanced parser validation failed!") | |
# Run CoT regression tests (once at startup) | |
print("\n" + "=" * 60) | |
print("Running CoT Regression Tests (Parser Validation)") | |
print("=" * 60) | |
passed, failed, accuracy = validate_cot_generation( | |
self.tokenizer, self.parser, verbose=True | |
) | |
if accuracy < 1.0: # Require 100% accuracy | |
print(f"❌ CoT regression test accuracy is {accuracy:.1%}") | |
print(f"Failed {failed} out of {passed + failed} tests") | |
raise ValueError( | |
"CoT regression tests must pass with 100% accuracy before training!" | |
) | |
else: | |
print(f"✅ CoT regression tests passed with {accuracy:.1%} accuracy") | |
print("Enhanced component validation complete!\n") | |
def load_checkpoint(self, checkpoint_path): | |
"""Load training checkpoint""" | |
try: | |
checkpoint = torch.load(checkpoint_path, map_location=self.device) | |
# Load model state | |
self.model.load_state_dict(checkpoint["model_state"]) | |
# Load optimizer state | |
self.optimizer.load_state_dict(checkpoint["optimizer_state"]) | |
# Load curriculum state | |
if "difficulty" in checkpoint: | |
difficulty_dict = checkpoint["difficulty"] | |
self.difficulty = DifficultyState(**difficulty_dict) | |
# Load training state | |
if "step" in checkpoint: | |
self.global_step = checkpoint["step"] | |
if "performance_history" in checkpoint: | |
self.performance_history = deque( | |
checkpoint["performance_history"], maxlen=200 | |
) | |
if "difficulty_history" in checkpoint: | |
self.difficulty_history = checkpoint["difficulty_history"] | |
# Reinitialize LR scheduler | |
self.lr_scheduler = CurriculumAwareLRScheduler(self.optimizer) | |
print(f"✅ Loaded checkpoint from {checkpoint_path}") | |
print(f"Resumed at step {self.global_step}") | |
print(f"Current difficulty: {asdict(self.difficulty)}") | |
return True | |
except Exception as e: | |
print(f"❌ Failed to load checkpoint: {e}") | |
return False | |
def evaluate_on_difficulty_level(self, difficulty_config, num_examples=50): | |
"""Evaluate model on specific difficulty level""" | |
self.model.eval() | |
correct = 0 | |
total = 0 | |
losses = [] | |
# Debug: show a few examples | |
debug_examples = min(3, num_examples) | |
with torch.no_grad(): | |
for i in range(num_examples): | |
try: | |
example = generate_enhanced_arithmetic_example( | |
difficulty_config, self.tokenizer, self.parser | |
) | |
# Quick generation test | |
generated_output_str = self.model.generate( | |
example["input"], self.tokenizer, max_length=100 | |
).strip() | |
# Debug output for first few examples | |
if i < debug_examples: | |
print(f"\n Example {i+1}:") | |
print(f" Input: {example['input'].strip()}") | |
print(f" Expected CoT (sample): {example['target'].strip()}") | |
print(f" Generated CoT: {generated_output_str}") | |
print(f" Expected numerical result: {example['result']}") | |
# Check if result is correctly in generated text | |
expected_numerical_result = str(example["result"]) | |
# Check for both patterns: "= {result}<eos>" and "={result}<eos>" | |
# This check focuses on the presence of the final correct answer string. | |
# For a more robust check, one might parse the last line. | |
# However, the model is trained to produce <eos> at the end of the CoT. | |
correct_pattern_with_space = f"= {expected_numerical_result}<eos>" | |
correct_pattern_no_space = f"={expected_numerical_result}<eos>" | |
is_match = False | |
if ( | |
correct_pattern_with_space in generated_output_str | |
or correct_pattern_no_space in generated_output_str | |
): | |
# Further check: ensure it's at the end of a line or the string | |
lines = generated_output_str.split("\n") | |
last_line = lines[-1].strip() | |
if ( | |
last_line == correct_pattern_with_space | |
or last_line == correct_pattern_no_space | |
): | |
is_match = True | |
if is_match: | |
correct += 1 | |
if i < debug_examples: | |
print(f" ✓ Correct!") | |
else: | |
if i < debug_examples: | |
print(f" ✗ Incorrect") | |
total += 1 | |
except Exception as e: | |
# Skip problematic examples | |
if i < debug_examples: | |
print(f"\n Example {i+1} generation/check failed: {e}") | |
continue | |
accuracy = correct / total if total > 0 else 0.0 | |
self.model.train() | |
return accuracy | |
def get_teacher_forcing_ratio(self): | |
"""No teacher forcing - return 0 always""" | |
return 0.0 | |
def smart_curriculum_update(self, current_accuracy): | |
"""Smarter curriculum adaptation""" | |
self.performance_history.append(current_accuracy) | |
self.steps_since_change += 1 | |
if len(self.performance_history) < 50: | |
return False | |
# Calculate trends | |
recent_performance = np.mean(list(self.performance_history)[-50:]) | |
older_performance = ( | |
np.mean(list(self.performance_history)[-100:-50]) | |
if len(self.performance_history) >= 100 | |
else recent_performance | |
) | |
trend = recent_performance - older_performance | |
variance = np.std(list(self.performance_history)[-50:]) | |
difficulty_changed = False | |
# CRITICAL: Don't increase difficulty if we haven't mastered the current level | |
# Check actual evaluation performance, not just training accuracy | |
if hasattr(self, "_last_eval_accuracy") and self._last_eval_accuracy < 0.85: | |
# If evaluation shows poor performance, decrease difficulty | |
if recent_performance < 0.85 and self.steps_since_change > 100: | |
print( | |
f"\n📉 Poor evaluation performance! Decreasing difficulty (train acc: {recent_performance:.3f}, eval acc: {self._last_eval_accuracy:.3f})" | |
) | |
self._decrease_difficulty() | |
difficulty_changed = True | |
else: | |
# Conditions for difficulty increase - much stricter | |
if ( | |
recent_performance > 0.95 | |
and variance < 0.03 | |
and self.steps_since_change > 500 | |
and trend >= -0.01 | |
): | |
print( | |
f"\n📈 Mastery detected! Increasing difficulty (acc: {recent_performance:.3f}, var: {variance:.3f})" | |
) | |
self._increase_difficulty() | |
difficulty_changed = True | |
# Conditions for difficulty decrease | |
elif recent_performance < 0.70 and self.steps_since_change > 100: | |
print( | |
f"\n📉 Struggling detected! Decreasing difficulty (acc: {recent_performance:.3f})" | |
) | |
self._decrease_difficulty() | |
difficulty_changed = True | |
# Plateau detection - but only if performance is good | |
elif ( | |
0.85 <= recent_performance <= 0.94 | |
and variance < 0.02 | |
and abs(trend) < 0.005 | |
and self.steps_since_change > 400 | |
): | |
print( | |
f"\n🔄 Plateau detected! Making lateral change (acc: {recent_performance:.3f})" | |
) | |
self._lateral_change() | |
difficulty_changed = True | |
if difficulty_changed: | |
self.steps_since_change = 0 | |
self.performance_history.clear() | |
self.difficulty_history.append(asdict(self.difficulty)) | |
return difficulty_changed | |
def _increase_difficulty(self): | |
"""Smarter difficulty increase""" | |
old_complexity = self.difficulty.complexity_score() | |
# Priority order for increasing difficulty | |
if self.difficulty.max_digits < 4: | |
self.difficulty.max_digits += 1 | |
elif self.difficulty.max_ops < 3: | |
self.difficulty.max_ops += 1 | |
elif self.difficulty.parentheses_prob < 0.4: | |
self.difficulty.parentheses_prob = min( | |
0.4, self.difficulty.parentheses_prob + 0.1 | |
) | |
elif self.difficulty.negative_prob < 0.3: | |
self.difficulty.negative_prob = min( | |
0.3, self.difficulty.negative_prob + 0.1 | |
) | |
elif not self.difficulty.mixed_digit_sizes: | |
self.difficulty.mixed_digit_sizes = True | |
else: | |
# Try model growth when curriculum gets complex | |
if ( | |
self.difficulty.complexity_score() > 50 | |
and self.model.num_layers < self.max_layers | |
): | |
self._try_model_growth() | |
new_complexity = self.difficulty.complexity_score() | |
print(f"Difficulty complexity: {old_complexity} → {new_complexity}") | |
def _decrease_difficulty(self): | |
"""Smarter difficulty decrease""" | |
if self.difficulty.mixed_digit_sizes: | |
self.difficulty.mixed_digit_sizes = False | |
elif self.difficulty.negative_prob > 0: | |
self.difficulty.negative_prob = max(0, self.difficulty.negative_prob - 0.1) | |
elif self.difficulty.parentheses_prob > 0: | |
self.difficulty.parentheses_prob = max( | |
0, self.difficulty.parentheses_prob - 0.1 | |
) | |
elif self.difficulty.max_ops > 1: | |
self.difficulty.max_ops = max(1, self.difficulty.max_ops - 1) | |
elif self.difficulty.max_digits > 1: | |
self.difficulty.max_digits = max(1, self.difficulty.max_digits - 1) | |
def _lateral_change(self): | |
"""Make lateral changes to explore different aspects""" | |
changes = [] | |
# Only make lateral changes if we're at a reasonable difficulty level | |
if self.difficulty.max_digits == 1 and self.difficulty.max_ops == 1: | |
# At the most basic level, don't add complexity | |
print("At basic level - no lateral changes needed") | |
return | |
# Try different combinations | |
if self.difficulty.parentheses_prob < 0.3 and self.difficulty.max_ops >= 2: | |
self.difficulty.parentheses_prob += 0.1 | |
changes.append("parentheses") | |
if self.difficulty.negative_prob < 0.2 and len(changes) == 0: | |
self.difficulty.negative_prob += 0.1 | |
changes.append("negatives") | |
if ( | |
not changes | |
and not self.difficulty.mixed_digit_sizes | |
and self.difficulty.max_digits >= 2 | |
): | |
self.difficulty.mixed_digit_sizes = True | |
changes.append("mixed_digits") | |
if changes: | |
print(f"Applied lateral changes: {changes}") | |
else: | |
print("No lateral changes available at current difficulty") | |
def _try_model_growth(self): | |
"""Enhanced model growth""" | |
if self.model.num_layers < self.max_layers: | |
new_config = {"num_layers": self.model.num_layers + 1} | |
if self.model.grow_model(new_config): | |
                # Reinitialize optimizer (and scaler, if used) for the grown model | |
                self.optimizer = torch.optim.AdamW( | |
                    self.model.parameters(), | |
                    lr=0.001, | |
                    weight_decay=0.01, | |
                    betas=(0.9, 0.95), | |
                ) | |
                if self.use_mixed_precision: | |
                    # Reset the scaler for the new parameters; torch.amp.GradScaler | |
                    # matches the torch.amp.autocast API used in the train step | |
                    self.scaler = torch.amp.GradScaler("cuda") | |
self.lr_scheduler = CurriculumAwareLRScheduler(self.optimizer) | |
print(f"🚀 Model grown to {new_config['num_layers']} layers") | |
return True | |
return False | |
def enhanced_train_step(self, batch_size=32): | |
"""Enhanced training step optimized for RTX 5090 Blackwell""" | |
# Dynamic batch sizing based on GPU capability | |
if self.is_blackwell: | |
# Blackwell: Massive batches | |
batch_size = min(batch_size * 4, self.max_batch_size) | |
elif self.has_fp8: | |
# FP8-capable: Large batches | |
batch_size = min(batch_size * 3, self.max_batch_size) | |
elif self.use_mixed_precision: | |
# Standard mixed precision | |
batch_size = min(batch_size * 2, 64) | |
# else: keep original batch_size | |
# Generate batch with current difficulty | |
config = { | |
"max_digits": self.difficulty.max_digits, | |
"max_ops": self.difficulty.max_ops, | |
"parentheses_prob": self.difficulty.parentheses_prob, | |
"negative_prob": self.difficulty.negative_prob, | |
"mixed_digits": self.difficulty.mixed_digit_sizes, | |
} | |
batch = [] | |
for _ in range(batch_size): | |
try: | |
example = generate_enhanced_arithmetic_example( | |
config, self.tokenizer, self.parser | |
) | |
batch.append(example) | |
except Exception:  # fall back to a trivial example if generation fails | |
simple_config = { | |
"max_digits": 1, | |
"max_ops": 1, | |
"parentheses_prob": 0, | |
"negative_prob": 0, | |
} | |
example = generate_enhanced_arithmetic_example( | |
simple_config, self.tokenizer, self.parser | |
) | |
batch.append(example) | |
# Store first example for logging | |
example_to_log_input = None | |
example_to_log_target = None | |
if batch: | |
example_to_log_input = batch[0]["input"] | |
example_to_log_target = batch[0]["target"] | |
# Prepare batch | |
batch_dict = prepare_batch_for_training(batch, self.tokenizer, self.device) | |
# NO TEACHER FORCING - just use the prepared inputs directly | |
input_ids = batch_dict["input_ids"] | |
input_pv = batch_dict["input_place_values"] | |
input_level = batch_dict["input_level_ids"] | |
# Forward pass with Blackwell optimizations | |
self.model.train() | |
if self.use_mixed_precision: | |
# GPU-optimized autocast settings | |
autocast_kwargs = {} | |
if self.is_blackwell: | |
# Blackwell: Enhanced FP8/FP16 mixing | |
autocast_kwargs = { | |
"dtype": torch.float16, | |
} | |
elif self.has_fp8: | |
# FP8-capable: Optimized for 4th gen Tensor Cores | |
autocast_kwargs = { | |
"dtype": torch.float16, | |
} | |
with torch.amp.autocast("cuda", **autocast_kwargs): | |
token_logits, pv_logits, level_logits = self.model( | |
input_ids, | |
input_pv, | |
input_level, | |
attention_mask=batch_dict["attention_mask"], | |
) | |
# Multi-task loss computation | |
token_loss = F.cross_entropy( | |
token_logits.reshape(-1, self.model.vocab_size), | |
batch_dict["target_ids"].reshape(-1), | |
ignore_index=self.tokenizer.pad_token_id, | |
label_smoothing=0.1, | |
) | |
# Auxiliary losses | |
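                # Targets for the auxiliary heads use -1 to mean "not applicable"; | |
                # shifting by +1 maps them into [0, max+1] so index 0 becomes the | |
                # "none" class and cross_entropy can be applied directly. | |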
pv_targets = (batch_dict["target_place_values"] + 1).clamp( | |
0, self.model.max_place_value + 1 | |
) | |
valid_pv_mask = (batch_dict["target_place_values"] >= 0) & ( | |
batch_dict["target_ids"] != self.tokenizer.pad_token_id | |
) | |
if valid_pv_mask.any(): | |
pv_loss = F.cross_entropy( | |
pv_logits.reshape(-1, self.model.max_place_value + 2)[ | |
valid_pv_mask.reshape(-1) | |
], | |
pv_targets.reshape(-1)[valid_pv_mask.reshape(-1)], | |
) | |
else: | |
pv_loss = torch.tensor(0.0, device=self.device) | |
level_targets = (batch_dict["target_level_ids"] + 1).clamp( | |
0, self.model.max_level_id + 1 | |
) | |
valid_level_mask = (batch_dict["target_level_ids"] >= 0) & ( | |
batch_dict["target_ids"] != self.tokenizer.pad_token_id | |
) | |
if valid_level_mask.any(): | |
level_loss = F.cross_entropy( | |
level_logits.reshape(-1, self.model.max_level_id + 2)[ | |
valid_level_mask.reshape(-1) | |
], | |
level_targets.reshape(-1)[valid_level_mask.reshape(-1)], | |
) | |
else: | |
level_loss = torch.tensor(0.0, device=self.device) | |
total_loss = token_loss + 0.1 * pv_loss + 0.1 * level_loss | |
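                # Triple-head objective: token prediction dominates, while the | |
                # place-value and level heads act as 0.1-weighted auxiliary losses | |
                # that shape the shared representation. | |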
# GPU-optimized backward pass | |
self.scaler.scale(total_loss).backward() | |
self.scaler.unscale_(self.optimizer) | |
# Gradient clipping based on model size | |
if self.is_blackwell: | |
clip_value = 2.0 # Higher for massive models | |
elif self.has_fp8: | |
clip_value = 1.5 # Medium for large models | |
else: | |
clip_value = 1.0 # Standard | |
torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip_value) | |
self.scaler.step(self.optimizer) | |
self.scaler.update() | |
self.optimizer.zero_grad() | |
else: | |
# Standard precision path | |
token_logits, pv_logits, level_logits = self.model( | |
input_ids, | |
input_pv, | |
input_level, | |
attention_mask=batch_dict["attention_mask"], | |
) | |
token_loss = F.cross_entropy( | |
token_logits.reshape(-1, self.model.vocab_size), | |
batch_dict["target_ids"].reshape(-1), | |
ignore_index=self.tokenizer.pad_token_id, | |
label_smoothing=0.1, | |
) | |
pv_targets = (batch_dict["target_place_values"] + 1).clamp( | |
0, self.model.max_place_value + 1 | |
) | |
valid_pv_mask = (batch_dict["target_place_values"] >= 0) & ( | |
batch_dict["target_ids"] != self.tokenizer.pad_token_id | |
) | |
if valid_pv_mask.any(): | |
pv_loss = F.cross_entropy( | |
pv_logits.reshape(-1, self.model.max_place_value + 2)[ | |
valid_pv_mask.reshape(-1) | |
], | |
pv_targets.reshape(-1)[valid_pv_mask.reshape(-1)], | |
) | |
else: | |
pv_loss = torch.tensor(0.0, device=self.device) | |
level_targets = (batch_dict["target_level_ids"] + 1).clamp( | |
0, self.model.max_level_id + 1 | |
) | |
valid_level_mask = (batch_dict["target_level_ids"] >= 0) & ( | |
batch_dict["target_ids"] != self.tokenizer.pad_token_id | |
) | |
if valid_level_mask.any(): | |
level_loss = F.cross_entropy( | |
level_logits.reshape(-1, self.model.max_level_id + 2)[ | |
valid_level_mask.reshape(-1) | |
], | |
level_targets.reshape(-1)[valid_level_mask.reshape(-1)], | |
) | |
else: | |
level_loss = torch.tensor(0.0, device=self.device) | |
total_loss = token_loss + 0.1 * pv_loss + 0.1 * level_loss | |
total_loss.backward() | |
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) | |
self.optimizer.step() | |
self.optimizer.zero_grad() | |
# LR scheduler step | |
lr = self.lr_scheduler.step() | |
# Calculate accuracies | |
with torch.no_grad(): | |
mask = batch_dict["target_ids"] != self.tokenizer.pad_token_id | |
token_preds = token_logits.float().argmax(dim=-1) | |
token_acc = ( | |
(token_preds == batch_dict["target_ids"])[mask].float().mean().item() | |
) | |
if valid_pv_mask.any(): | |
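                # Subtract 1 to undo the +1 target shift, so predictions are | |
                # compared against raw place values (-1 = "none"). | |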
pv_preds = pv_logits.float().argmax(dim=-1) - 1 | |
pv_acc = ( | |
(pv_preds == batch_dict["target_place_values"])[valid_pv_mask] | |
.float() | |
.mean() | |
.item() | |
) | |
else: | |
pv_acc = 1.0 | |
if valid_level_mask.any(): | |
level_preds = level_logits.float().argmax(dim=-1) - 1 | |
level_acc = ( | |
(level_preds == batch_dict["target_level_ids"])[valid_level_mask] | |
.float() | |
.mean() | |
.item() | |
) | |
else: | |
level_acc = 1.0 | |
return { | |
"loss": total_loss.item(), | |
"token_loss": token_loss.item(), | |
"pv_loss": pv_loss.item(), | |
"level_loss": level_loss.item(), | |
"token_acc": token_acc, | |
"pv_acc": pv_acc, | |
"level_acc": level_acc, | |
"teacher_forcing": 0.0, # Always 0 now | |
"learning_rate": lr, | |
"batch_size": batch_size, | |
"mixed_precision": self.use_mixed_precision, | |
"blackwell_optimized": self.is_blackwell, | |
"example_input_for_log": example_to_log_input, | |
"example_target_for_log": example_to_log_target, | |
} | |
def comprehensive_evaluation(self): | |
"""Run comprehensive evaluation across difficulty levels""" | |
print("\n" + "=" * 60) | |
print("Comprehensive Evaluation") | |
print("=" * 60) | |
# Test on multiple difficulty levels | |
test_configs = [ | |
{"max_digits": 1, "max_ops": 1, "parentheses_prob": 0, "negative_prob": 0}, | |
{ | |
"max_digits": 2, | |
"max_ops": 2, | |
"parentheses_prob": 0.2, | |
"negative_prob": 0.1, | |
}, | |
{ | |
"max_digits": 3, | |
"max_ops": 3, | |
"parentheses_prob": 0.4, | |
"negative_prob": 0.2, | |
}, | |
{ | |
"max_digits": 4, | |
"max_ops": 4, | |
"parentheses_prob": 0.6, | |
"negative_prob": 0.3, | |
}, | |
] | |
results = {} | |
for i, config in enumerate(test_configs): | |
acc = self.evaluate_on_difficulty_level(config, num_examples=30) | |
results[f"level_{i+1}"] = acc | |
print( | |
f"Level {i+1} (max_digits={config['max_digits']}, max_ops={config['max_ops']}): {acc:.3f}" | |
) | |
avg_acc = np.mean(list(results.values())) | |
print(f"Average accuracy: {avg_acc:.3f}") | |
print("=" * 60 + "\n") | |
# Store the evaluation accuracy for curriculum decisions | |
self._last_eval_accuracy = results.get( | |
"level_1", 0.0 | |
) # Focus on basic level performance | |
# Log to wandb | |
wandb.log( | |
{ | |
**{f"eval/{k}": v for k, v in results.items()}, | |
"eval/average": avg_acc, | |
"step": self.global_step, | |
} | |
) | |
return avg_acc, results | |
def train(self, max_steps=None): | |
"""Enhanced training loop""" | |
print("Starting enhanced perpetual training...") | |
print(f"Initial difficulty: {asdict(self.difficulty)}") | |
print(f"Model: {self.model.num_layers} layers, {self.model.d_model} d_model") | |
step = 0 | |
last_comprehensive_eval = 0 | |
while max_steps is None or step < max_steps: | |
# Training step | |
metrics = self.enhanced_train_step() | |
# Update curriculum | |
difficulty_changed = self.smart_curriculum_update(metrics["token_acc"]) | |
if difficulty_changed: | |
self.lr_scheduler.step(difficulty_just_changed=True) | |
# Logging | |
if step % 50 == 0: | |
recent_acc = ( | |
np.mean(list(self.performance_history)[-10:]) | |
if self.performance_history | |
else metrics["token_acc"] | |
) | |
wandb.log( | |
{ | |
"train/loss": metrics["loss"], | |
"train/token_loss": metrics["token_loss"], | |
"train/pv_loss": metrics["pv_loss"], | |
"train/level_loss": metrics["level_loss"], | |
"train/token_acc": metrics["token_acc"], | |
"train/pv_acc": metrics["pv_acc"], | |
"train/level_acc": metrics["level_acc"], | |
"train/recent_acc": recent_acc, | |
"train/teacher_forcing": metrics["teacher_forcing"], | |
"train/learning_rate": metrics["learning_rate"], | |
"train/batch_size": metrics["batch_size"], | |
"train/mixed_precision": metrics["mixed_precision"], | |
"curriculum/max_digits": self.difficulty.max_digits, | |
"curriculum/max_ops": self.difficulty.max_ops, | |
"curriculum/parentheses_prob": self.difficulty.parentheses_prob, | |
"curriculum/negative_prob": self.difficulty.negative_prob, | |
"curriculum/complexity": self.difficulty.complexity_score(), | |
"model/num_layers": self.model.num_layers, | |
"training/steps_since_change": self.steps_since_change, | |
"step": step, | |
} | |
) | |
if step % 200 == 0: | |
if metrics.get("blackwell_optimized", False): | |
mode_str = "BLACKWELL" | |
elif metrics.get("has_fp8", False): | |
mode_str = "FP8" | |
elif metrics["mixed_precision"]: | |
mode_str = "MP" | |
else: | |
mode_str = "FP32" | |
# Show current model size in logs | |
model_size = f"{self.model.d_model}d×{self.model.num_layers}L" | |
print( | |
f"Step {step}: Loss={metrics['loss']:.4f}, TokenAcc={metrics['token_acc']:.3f}, " | |
f"PVAcc={metrics['pv_acc']:.3f}, LevelAcc={metrics['level_acc']:.3f}, " | |
f"BS={metrics['batch_size']}, {mode_str}, {model_size}, " | |
f"LR={metrics['learning_rate']:.6f}, Complexity={self.difficulty.complexity_score():.1f}" | |
) | |
# Print a training example and model output | |
if ( | |
metrics["example_input_for_log"] | |
and metrics["example_target_for_log"] | |
): | |
print( | |
" ┌─ Training Example & Model Output (Current Difficulty) ─┐" | |
) | |
example_input = metrics["example_input_for_log"] | |
example_target = metrics["example_target_for_log"] | |
print(f" │ Input: {example_input.strip()}") | |
print(f" │ Expected: {example_target.strip()}") | |
self.model.eval() # Switch to eval mode for generation | |
with torch.no_grad(): | |
generated_output = self.model.generate( | |
example_input, self.tokenizer, max_length=100 | |
) | |
self.model.train() # Switch back to train mode | |
print(f" │ Generated:{generated_output.strip()}") | |
print( | |
" └────────────────────────────────────────────────────────┘" | |
) | |
# Comprehensive evaluation - more frequent at the beginning | |
eval_interval = 1000 if step < 10000 else 2000 | |
if step - last_comprehensive_eval >= eval_interval and step > 0: | |
avg_acc, eval_results = self.comprehensive_evaluation() | |
last_comprehensive_eval = step | |
# Track best performance | |
if avg_acc > self.best_accuracy: | |
self.best_accuracy = avg_acc | |
self.save_checkpoint(step, is_best=True) | |
# If we're doing poorly on basic level, force a curriculum reset | |
if ( | |
eval_results.get("level_1", 0) < 0.5 | |
and self.difficulty.complexity_score() > 10 | |
): | |
print( | |
"\n⚠️ CRITICAL: Basic arithmetic accuracy too low! Resetting to fundamentals." | |
) | |
self.difficulty = DifficultyState() | |
self.difficulty.max_digits = 1 | |
self.difficulty.max_ops = 1 | |
self.difficulty.parentheses_prob = 0.0 | |
self.difficulty.negative_prob = 0.0 | |
self.steps_since_change = 0 | |
self.performance_history.clear() | |
# Regular checkpoints | |
if step % 10000 == 0 and step > 0: | |
self.save_checkpoint(step) | |
step += 1 | |
self.global_step = step | |
def save_checkpoint(self, step, is_best=False): | |
"""Enhanced checkpoint saving""" | |
checkpoint = { | |
"model_state": self.model.state_dict(), | |
"optimizer_state": self.optimizer.state_dict(), | |
"difficulty": asdict(self.difficulty), | |
"difficulty_history": self.difficulty_history, | |
"step": step, | |
"performance_history": list(self.performance_history), | |
"best_accuracy": self.best_accuracy, | |
"mixed_precision": self.use_mixed_precision, | |
"model_config": { | |
"vocab_size": self.model.vocab_size, | |
"d_model": self.model.d_model, | |
"nhead": self.model.nhead, | |
"num_layers": self.model.num_layers, | |
"dim_feedforward": self.model.dim_feedforward, | |
}, | |
} | |
suffix = "_best" if is_best else "" | |
path = f"enhanced_checkpoint_step_{step}{suffix}.pt" | |
torch.save(checkpoint, path) | |
print(f"💾 Saved checkpoint: {path}") | |
# Upload to wandb if available | |
if HAS_WANDB: | |
wandb.save(path) | |
# ============= Interactive Testing Mode ============= | |
class InteractiveArithmeticREPL: | |
"""Interactive REPL for testing trained models""" | |
def __init__(self, model_path=None, device="cuda"): | |
self.device = torch.device(device if torch.cuda.is_available() else "cpu") | |
self.tokenizer = ArithmeticTokenizer() | |
self.parser = ArithmeticParser(self.tokenizer) | |
if model_path and os.path.exists(model_path): | |
self.load_model(model_path) | |
else: | |
print("No model loaded. Use 'load <path>' to load a checkpoint.") | |
self.model = None | |
def load_model(self, checkpoint_path): | |
"""Load model from checkpoint""" | |
try: | |
checkpoint = torch.load(checkpoint_path, map_location=self.device) | |
model_config = checkpoint.get( | |
"model_config", | |
{ | |
"vocab_size": len(self.tokenizer.vocab), | |
"d_model": 256, | |
"nhead": 8, | |
"num_layers": 3, | |
"dim_feedforward": 1024, | |
}, | |
) | |
self.model = GrowableMultiTaskTransformer(**model_config).to(self.device) | |
self.model.load_state_dict(checkpoint["model_state"]) | |
self.model.eval() | |
print(f"✅ Loaded model from {checkpoint_path}") | |
print( | |
f"Model: {model_config['num_layers']} layers, {model_config['d_model']} d_model" | |
) | |
if "difficulty" in checkpoint: | |
print(f"Last difficulty: {checkpoint['difficulty']}") | |
if "step" in checkpoint: | |
print(f"Training step: {checkpoint['step']}") | |
except Exception as e: | |
print(f"❌ Failed to load model: {e}") | |
self.model = None | |
def generate_solution(self, expression, max_length=200): | |
"""Generate CoT solution for expression""" | |
if self.model is None: | |
print("❌ No model loaded!") | |
return None | |
try: | |
# Clean input | |
expression = expression.strip() | |
if not expression.endswith("\n"): | |
expression += "\n" | |
# Generate using model | |
with torch.no_grad(): | |
result = self.model.generate(expression, self.tokenizer, max_length) | |
# Extract just the CoT part (remove input echo) | |
if expression.strip() in result: | |
cot_part = result.split(expression.strip(), 1)[1] | |
else: | |
cot_part = result | |
return cot_part.strip() | |
except Exception as e: | |
print(f"❌ Generation failed: {e}") | |
return None | |
def generate_ground_truth(self, expression): | |
"""Generate ground truth CoT using parser""" | |
try: | |
ast = self.parser.parse(expression.strip()) | |
cot, final_result = generate_enhanced_cot_from_tree(ast) | |
return cot.strip(), final_result | |
except Exception as e: | |
print(f"❌ Parser failed: {e}") | |
return None, None | |
def test_expression(self, expression): | |
"""Test expression with both model and ground truth""" | |
print(f"\n🧮 Testing: {expression}") | |
print("=" * 50) | |
# Ground truth | |
gt_cot, gt_result = self.generate_ground_truth(expression) | |
if gt_cot: | |
print(f"📚 Ground Truth:") | |
print(f" {gt_cot}") | |
print(f" Final Result: {gt_result}") | |
# Model prediction | |
if self.model: | |
model_cot = self.generate_solution(expression) | |
if model_cot: | |
print(f"\n🤖 Model Prediction:") | |
print(f" {model_cot}") | |
# Check if correct | |
if gt_cot and model_cot == gt_cot: | |
print(" ✅ CORRECT!") | |
elif gt_cot: | |
print(" ❌ INCORRECT") | |
# Try to extract final number | |
try: | |
if "<eos>" in model_cot: | |
final_line = model_cot.split("<eos>")[0].split("\n")[-1] | |
if "=" in final_line: | |
predicted_result = final_line.split("=")[-1].strip() | |
print(f" Model Result: {predicted_result}") | |
if ( | |
gt_result is not None | |
and str(gt_result) == predicted_result | |
): | |
print(" ✅ Final answer is correct!") | |
elif gt_result is not None: | |
print( | |
f" ❌ Final answer wrong (expected {gt_result})" | |
) | |
except Exception: | |
pass | |
print("=" * 50) | |
def run_test_suite(self): | |
"""Run a suite of test expressions""" | |
test_expressions = [ | |
"1 + 1", | |
"5 - 3", | |
"10 + -5", | |
"(2 + 3) + 4", | |
"10 - -5", | |
"123 + 456", | |
"((1 + 2) + 3)", | |
"29 - (-10 + 5)", | |
"-42", | |
"(7)", | |
] | |
print("\n🧪 Running Test Suite") | |
print("=" * 60) | |
correct = 0 | |
total = 0 | |
for expr in test_expressions: | |
try: | |
gt_cot, gt_result = self.generate_ground_truth(expr) | |
if self.model and gt_cot: | |
model_cot = self.generate_solution(expr) | |
if model_cot == gt_cot: | |
correct += 1 | |
print(f"✅ {expr}") | |
else: | |
print(f"❌ {expr}") | |
total += 1 | |
else: | |
print(f"? {expr} (skipped)") | |
except Exception: | |
print(f"❌ {expr} (error)") | |
if total > 0: | |
print(f"\n📊 Results: {correct}/{total} correct ({correct/total*100:.1f}%)") | |
print("=" * 60) | |
def repl(self): | |
"""Interactive REPL mode""" | |
print("\n🎯 Interactive Arithmetic REPL") | |
print("Commands:") | |
print(" <expression> - Test arithmetic expression (e.g., '1 + 2')") | |
print(" load <path> - Load model checkpoint") | |
print(" test - Run test suite") | |
print(" gt <expression> - Show ground truth only") | |
print(" help - Show this help") | |
print(" exit/quit - Exit REPL") | |
print("-" * 50) | |
while True: | |
try: | |
user_input = input("\n➤ ").strip() | |
if not user_input: | |
continue | |
elif user_input.lower() in ["exit", "quit", "q"]: | |
print("👋 Goodbye!") | |
break | |
elif user_input.lower() in ["help", "h"]: | |
print( | |
"Commands: <expression>, load <path>, test, gt <expr>, help, exit" | |
) | |
elif user_input.lower() == "test": | |
self.run_test_suite() | |
elif user_input.lower().startswith("load "): | |
path = user_input[5:].strip() | |
self.load_model(path) | |
elif user_input.lower().startswith("gt "): | |
expr = user_input[3:].strip() | |
gt_cot, gt_result = self.generate_ground_truth(expr) | |
if gt_cot: | |
print(f"📚 Ground Truth for '{expr}':") | |
print(f" {gt_cot}") | |
print(f" Final Result: {gt_result}") | |
else: | |
print(f"❌ Could not parse '{expr}'") | |
else: | |
# Treat as arithmetic expression | |
self.test_expression(user_input) | |
except KeyboardInterrupt: | |
print("\n👋 Goodbye!") | |
break | |
except Exception as e: | |
print(f"❌ Error: {e}") | |
def single_expression_test(expression, model_path=None): | |
"""Quick single expression test""" | |
repl = InteractiveArithmeticREPL(model_path) | |
repl.test_expression(expression) | |
# ============= Example Usage ============= | |
# ============= Example Usage with Error Handling ============= | |
if __name__ == "__main__": | |
import sys | |
if len(sys.argv) > 1: | |
if sys.argv[1] == "repl": | |
# Interactive REPL mode | |
model_path = sys.argv[2] if len(sys.argv) > 2 else None | |
repl = InteractiveArithmeticREPL(model_path) | |
repl.repl() | |
elif sys.argv[1] == "test": | |
# Single expression test | |
if len(sys.argv) < 3: | |
print("Usage: python script.py test '<expression>' [model_path]") | |
sys.exit(1) | |
expression = sys.argv[2] | |
model_path = sys.argv[3] if len(sys.argv) > 3 else None | |
single_expression_test(expression, model_path) | |
elif sys.argv[1] == "train": | |
# Training mode with optional mixed precision flag | |
use_mp = "--mixed-precision" in sys.argv or "--mp" in sys.argv | |
no_mp = "--no-mixed-precision" in sys.argv or "--no-mp" in sys.argv | |
if no_mp: | |
use_mixed_precision = False | |
elif use_mp: | |
use_mixed_precision = True | |
else: | |
use_mixed_precision = True # Default to enabled | |
try: | |
trainer = EnhancedPerpetualTrainer( | |
use_mixed_precision=use_mixed_precision | |
) | |
# Parse max_steps from args (skip flags) | |
max_steps = None | |
for arg in sys.argv[2:]: | |
if not arg.startswith("--"): | |
try: | |
max_steps = int(arg) | |
break | |
except ValueError: | |
pass | |
trainer.train(max_steps=max_steps) | |
except Exception as e: | |
print(f"Training failed: {e}") | |
import traceback | |
traceback.print_exc() | |
print("Make sure all dependencies are installed:") | |
print("pip install torch tqdm numpy") | |
elif sys.argv[1] == "validate": | |
# Just run validation | |
use_mp = "--mixed-precision" in sys.argv or "--mp" in sys.argv | |
no_mp = "--no-mixed-precision" in sys.argv or "--no-mp" in sys.argv | |
if no_mp: | |
use_mixed_precision = False | |
elif use_mp: | |
use_mixed_precision = True | |
else: | |
use_mixed_precision = True | |
try: | |
trainer = EnhancedPerpetualTrainer( | |
use_mixed_precision=use_mixed_precision | |
) | |
print("✅ All components validated successfully!") | |
except Exception as e: | |
print(f"❌ Validation failed: {e}") | |
import traceback | |
traceback.print_exc() | |
else: | |
print("Usage:") | |
print( | |
" python script.py train [max_steps] [--mp/--no-mp] - Start training" | |
) | |
print( | |
" python script.py repl [model_path] - Interactive REPL" | |
) | |
print( | |
" python script.py test '<expr>' [model] - Test single expression" | |
) | |
print( | |
" python script.py validate [--mp/--no-mp] - Validate components only" | |
) | |
print("\nMixed Precision Options:") | |
print(" --mixed-precision, --mp - Enable mixed precision (default)") | |
print(" --no-mixed-precision, --no-mp - Disable mixed precision") | |
else: | |
# Default: show usage | |
print("Enhanced Perpetual Arithmetic Transformer") | |
print("Usage:") | |
print(" python script.py train [max_steps] [--mp/--no-mp] - Start training") | |
print(" python script.py repl [model_path] - Interactive REPL") | |
print( | |
" python script.py test '<expr>' [model] - Test single expression" | |
) | |
print( | |
" python script.py validate [--mp/--no-mp] - Validate components only" | |
) | |
print("\nMixed Precision Options:") | |
print(" --mixed-precision, --mp - Enable mixed precision (default)") | |
print(" --no-mixed-precision, --no-mp - Disable mixed precision") | |
print("\nExample:") | |
print(" python script.py validate") | |
print(" python script.py train 1000 --mp") | |
print(" python script.py train --no-mp") | |
print(" python script.py repl checkpoint.pt") | |
print(" python script.py test '(1 + 2) + 3'") | |
if not HAS_WANDB: | |
print( | |
"\nNote: wandb not installed. Install with 'pip install wandb' for experiment tracking." | |
) | |
# Show GPU info | |
if torch.cuda.is_available(): | |
gpu_name = torch.cuda.get_device_name() | |
print(f"\n🔥 GPU detected: {gpu_name}") | |
if "RTX 50" in gpu_name or "5090" in gpu_name: | |
print(" 🚀 BLACKWELL ARCHITECTURE DETECTED! 🚀") | |
print(" ⚡ 5th Gen Tensor Cores + 32GB GDDR7") | |
print(" 💫 Enabling maximum performance optimizations!") | |
print(" 🎯 Expected 2.5x speedup vs RTX 4090") | |
elif "RTX" in gpu_name or "A100" in gpu_name or "V100" in gpu_name: | |
print( | |
" ✅ Tensor Cores available - mixed precision will provide significant speedup!" | |
) | |
else: | |
print( | |
" ⚠️ Older GPU - mixed precision may provide memory savings but limited speedup" | |
) | |
else: | |
print( | |
"\n💻 No GPU detected - training will be slow, mixed precision disabled" | |
) |
Processing: '382108209482142142189-421382832321'
--- Python eval() Ground Truth: 382108209060759309868 ---
--- Generated CoT (Our System) ---
= 382108209060759309868<eos>
--- Our System Evaluated Result: 382108209060759309868 ---
(Our system's result MATCHES Python eval)
expr> 123290429049210392320-13218294829184921312+84284908021380213802183028302132132
Processing: '123290429049210392320-13218294829184921312+84284908021380213802183028302132132'
--- Python eval() Ground Truth: 84284908021380323874317248327603140 ---
--- Generated CoT (Our System) ---
= 110072134220025471008 + 84284908021380213802183028302132132
= 84284908021380323874317248327603140<eos>
--- Our System Evaluated Result: 84284908021380323874317248327603140 ---
(Our system's result MATCHES Python eval)
expr>
OK, the current version of the code is wrong.
I am testing: removing position encoding entirely, adding sinusoidally mapped level IDs and place values into the token embedding instead, removing the embedding projection, training with teacher forcing, and predicting token_id / place value / level ID at the same time.
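For reference, a minimal sketch of that embedding variant (assuming an even d_model; the names sinusoidal_table and AdditiveStructuralEmbedding are illustrative, not from the code above). Token embeddings plus fixed sinusoidal encodings of place value and level ID are simply summed, with no positional encoding and no projection:

import math
import torch
import torch.nn as nn

def sinusoidal_table(num_ids: int, d_model: int) -> torch.Tensor:
    """Fixed sinusoidal encoding table, indexed by an integer ID instead of position."""
    table = torch.zeros(num_ids, d_model)
    ids = torch.arange(num_ids, dtype=torch.float32).unsqueeze(1)
    div = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
    )
    table[:, 0::2] = torch.sin(ids * div)
    table[:, 1::2] = torch.cos(ids * div)
    return table

class AdditiveStructuralEmbedding(nn.Module):
    """Token embedding + sinusoidal(place value) + sinusoidal(level ID), summed."""

    def __init__(self, vocab_size, d_model, max_place_value, max_level_id):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, d_model)
        # +2 slots so index 0 can hold the shifted "not applicable" value (-1 + 1)
        self.register_buffer("pv_table", sinusoidal_table(max_place_value + 2, d_model))
        self.register_buffer("level_table", sinusoidal_table(max_level_id + 2, d_model))

    def forward(self, token_ids, place_values, level_ids):
        # Shift -1 ("no place value / no level") to index 0, mirroring the loss targets
        pv = (place_values + 1).clamp(min=0)
        lv = (level_ids + 1).clamp(min=0)
        return self.tok(token_ids) + self.pv_table[pv] + self.level_table[lv]

Summing rather than concatenating keeps the width at d_model, which is what makes dropping the embedding projection possible.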
<1 hour with RTX 5090: