@vwxyzjn
Created March 20, 2025 12:09
from collections import deque
import queue
import time
import numpy as np
import ray
from vllm import SamplingParams, LLM
import wandb
from open_instruct.dataset_transformation import TokenizerConfig, get_cached_dataset_rlvr
from open_instruct.vllm_utils3 import create_vllm_engines
from transformers import HfArgumentParser
from dataclasses import asdict, dataclass
from typing import Literal, NamedTuple, Optional
from rich.pretty import pprint
"""
python scripts/benchmark_batch_generate.py \
--generate_method chunked \
--max_tokens 2048 \
--max_chunk_size 1024 \
--num_prompts 100 \
--model_name_or_path HuggingFaceTB/SmolLM2-135M-Instruct \
--revision main \
--num_engines 8 \
--debug True
python scripts/benchmark_batch_generate.py \
--generate_method v2 \
--max_tokens 2048 \
--max_chunk_size 10 \
--num_prompts 100 \
--model_name_or_path HuggingFaceTB/SmolLM2-135M-Instruct \
--revision main \
--num_engines 2 \
--debug True
python scripts/benchmark_batch_generate2.py \
--generate_method v3 \
--max_tokens 2048 \
--max_chunk_size 512 \
--num_prompts 100 \
--max_batch_size 50 \
--model_name_or_path HuggingFaceTB/SmolLM2-135M-Instruct \
--revision main \
--num_engines 2 \
--eager True
# ----------------------------------------------------------------------------
# benchmark for 8 H100s
# chunked
python scripts/benchmark_batch_generate.py \
--generate_method chunked \
--max_tokens 8192 \
--max_chunk_size 1024 \
--num_prompts 100 \
--model_name_or_path Qwen/Qwen2.5-7B \
--revision main \
--num_engines 8
# naive
python scripts/benchmark_batch_generate.py \
--generate_method naive \
--max_tokens 8192 \
--num_prompts 100 \
--model_name_or_path Qwen/Qwen2.5-7B \
--revision main \
--num_engines 8
# 32B chunked
python scripts/benchmark_batch_generate.py \
--generate_method chunked \
--max_tokens 8192 \
--max_chunk_size 1024 \
--num_prompts 100 \
--model_name_or_path Qwen/Qwen2.5-32B \
--revision main \
--num_engines 4 \
--tensor_parallel_size 2
# 32B naive
python scripts/benchmark_batch_generate.py \
--generate_method naive \
--max_tokens 8192 \
--num_prompts 100 \
--model_name_or_path Qwen/Qwen2.5-32B \
--revision main \
--num_engines 4 \
--tensor_parallel_size 2
"""
@dataclass
class Args:
generate_method: Literal["chunked", "naive", "v2", "v3"] = "chunked"
max_chunk_size: int = 1000
num_prompts: int = 100
max_tokens: int = 8192
num_engines: int = 2
tensor_parallel_size: int = 1
max_batch_size: int = 100
debug: bool = False
eager: bool = False
def generate_with_engines(prompts: list[list[int]], sampling_params: SamplingParams, vllm_engines: list[LLM]):
# Split queries between engines
queries_per_engine = (len(prompts) + len(vllm_engines) - 1) // len(vllm_engines)
split_queries = [prompts[i : i + queries_per_engine] for i in range(0, len(prompts), queries_per_engine)]
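# e.g., 100 prompts across 8 engines -> ceil(100 / 8) = 13 prompts for each of the first
# 7 engines, with the remaining 9 going to the last engine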
# Generate responses in parallel across engines
start_time = time.time()
futures = [
vllm_engine.generate.remote(sampling_params=sampling_params, prompt_token_ids=queries, use_tqdm=False)
for vllm_engine, queries in zip(vllm_engines, split_queries)
]
future_execution_times = [None] * len(futures)
num_done_futures = 0
while num_done_futures < len(futures):
for i in range(len(futures)):
if future_execution_times[i] is not None:
continue
is_ready = len(ray.wait([futures[i]], timeout=0.01)[0]) > 0
if is_ready:
future_execution_times[i] = time.time() - start_time
num_done_futures += 1
print(f"🔥 {future_execution_times=}")
print(f"🔥 discrepency: {max(future_execution_times) - min(future_execution_times)=}")
# Gather all responses
all_outputs = ray.get(futures)
response_ids = []
finish_reasons = [] # either "stop" or "length"
for outputs in all_outputs:
response_ids.extend([list(out.token_ids) for output in outputs for out in output.outputs])
finish_reasons.extend([out.finish_reason for output in outputs for out in output.outputs])
return response_ids, finish_reasons
def chunked_generate_with_engines(prompts: list[list[int]], sampling_params: SamplingParams, max_chunk_size: int, vllm_engines: list[LLM]):
max_new_tokens = sampling_params.max_tokens
chunked_sampling_params = SamplingParams(
max_tokens=max_chunk_size,
temperature=sampling_params.temperature,
top_p=sampling_params.top_p,
top_k=sampling_params.top_k,
repetition_penalty=sampling_params.repetition_penalty,
)
prompt_and_responses = [prompt.copy() for prompt in prompts]
idxs = list(range(len(prompt_and_responses)))
dones = [0] * len(prompt_and_responses)
finish_reasons = ["length"] * len(prompt_and_responses)
max_iterations = max_new_tokens // max_chunk_size # don't do ceil div here because we don't want to over generate
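# e.g., max_tokens=8192 with max_chunk_size=1024 -> 8 passes; max_tokens=2048 with
# max_chunk_size=1000 -> 2 passes (at most 2000 new tokens), so we slightly
# under-generate rather than exceed max_tokens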
for _ in range(max_iterations):
if all(dones):
break
not_done_idxs = [i for i in idxs if dones[i] == 0]
cur_prompt_and_responses = [prompt_and_responses[i] for i in not_done_idxs]
samples_per_engine = (len(cur_prompt_and_responses) + len(vllm_engines) - 1) // len(vllm_engines)
print(f"🔥 {samples_per_engine=}")
split_prompt_and_responses = [cur_prompt_and_responses[i : i + samples_per_engine] for i in range(0, len(cur_prompt_and_responses), samples_per_engine)]
futures = [
vllm_engine.generate.remote(sampling_params=chunked_sampling_params, prompt_token_ids=queries, use_tqdm=True)
for vllm_engine, queries in zip(vllm_engines, split_prompt_and_responses)
]
all_outputs = ray.get(futures)
for i, outputs in enumerate(all_outputs):
for j, output in enumerate(outputs):
seq_idx = not_done_idxs[i*samples_per_engine + j]
out = output.outputs[0] # we assume num_samples_per_prompt_rollout == 1
prompt_and_responses[seq_idx].extend(list(out.token_ids))
if out.finish_reason == "stop":
dones[seq_idx] = 1
finish_reasons[seq_idx] = out.finish_reason
response_ids = [prompt_and_response[len(prompt):] for prompt, prompt_and_response in zip(prompts, prompt_and_responses)]
return response_ids, finish_reasons
@dataclass
class PromptAndResponse:
id: int
response_length: int
prompt_and_response: list[int]
finish_reason: Optional[str] = None
def chunked_generate_with_engines_v2(prompts: list[list[int]], sampling_params: SamplingParams, max_chunk_size: int, max_batch_size: int, vllm_engines: list[LLM]):
"""split the generations in to batch sizes and chunks"""
max_new_tokens = sampling_params.max_tokens
chunked_sampling_params = SamplingParams(
max_tokens=max_chunk_size,
temperature=sampling_params.temperature,
top_p=sampling_params.top_p,
top_k=sampling_params.top_k,
repetition_penalty=sampling_params.repetition_penalty,
)
unfinished_prompt_and_responses = deque([PromptAndResponse(i, 0, prompt.copy()) for i, prompt in enumerate(prompts)])
finished_prompt_and_responses = []
busy_engines = [False] * len(vllm_engines)
while len(finished_prompt_and_responses) < len(prompts):
# check if there is anything in the output queue of each engine
for i, vllm_engine in enumerate(vllm_engines):
try:
outputs = ray.get(vllm_engine.output_queue_get.remote(block=True, timeout=0.01))
busy_engines[i] = False
for prompt_and_response in outputs:
if prompt_and_response.response_length + max_chunk_size > max_new_tokens:
prompt_and_response.finish_reason = "length"
finished_prompt_and_responses.append(prompt_and_response)
elif prompt_and_response.finish_reason == "stop":
finished_prompt_and_responses.append(prompt_and_response)
else:
unfinished_prompt_and_responses.append(prompt_and_response)
except queue.Empty:
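# Ray re-raises the actor's exception as a RayTaskError that also subclasses the
# original type, so the remote queue's queue.Empty lands here; an empty output queue
# on an idle engine means we can top it up with a fresh batch of requests.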
if not busy_engines[i]:
per_engine_data = []
for j in range(max_batch_size):
if unfinished_prompt_and_responses:
per_engine_data.append(unfinished_prompt_and_responses.popleft())
else:
break
if per_engine_data:
print(f"🔥 putting {i=}")
vllm_engine.input_queue_put.remote((chunked_sampling_params, per_engine_data))
busy_engines[i] = True
continue
assert len(finished_prompt_and_responses) == len(prompts)
finished_prompt_and_responses.sort(key=lambda x: x.id)  # restore the original prompt order before zipping with `prompts`
response_ids = [prompt_and_response.prompt_and_response[len(prompt):] for prompt, prompt_and_response in zip(prompts, finished_prompt_and_responses)]
finish_reasons = [prompt_and_response.finish_reason for prompt_and_response in finished_prompt_and_responses]
return response_ids, finish_reasons
def generate_with_engines_v3(prompts: list[list[int]], sampling_params: SamplingParams, max_batch_size: int, vllm_engines: list[LLM]):
"""split the generations in to batch sizes and chunks"""
total_prompts = len(prompts)
finished_prompt_and_responses = []
prompts_q = deque([(i, prompt.copy()) for i, prompt in enumerate(prompts)])
# prompts = deque([prompt.copy() for i, prompt in enumerate(prompts)])
while len(finished_prompt_and_responses) < total_prompts:
if len(prompts_q) > 0:
for i, vllm_engine in enumerate(vllm_engines):
num_unfinished_requests = ray.get(vllm_engine.get_num_unfinished_requests.remote()) # , timeout=0.5 # this should error out if this takes too long
if num_unfinished_requests < max_batch_size:
cur_prompts = []
for _ in range(max_batch_size - num_unfinished_requests):
if prompts_q:
cur_prompts.append(prompts_q.popleft())
else:
break
if len(cur_prompts) > 0:
vllm_engine.input_queue_put.remote((cur_prompts, sampling_params))
for i, vllm_engine in enumerate(vllm_engines):
try:
outputs = ray.get(vllm_engine.output_queue_get.remote(block=True, timeout=0.01))
for output in outputs:
finished_prompt_and_responses.append(output)
except queue.Empty:
continue
assert len(finished_prompt_and_responses) == total_prompts
finished_prompt_and_responses = sorted(finished_prompt_and_responses, key=lambda x: int(x.request_id))
response_ids = [list(prompt_and_response.outputs[0].token_ids)[len(prompt):] for prompt, prompt_and_response in zip(prompts, finished_prompt_and_responses)]
finish_reasons = [prompt_and_response.outputs[0].finish_reason for prompt_and_response in finished_prompt_and_responses]
return response_ids, finish_reasons
def main(args: Args, tc: TokenizerConfig):
all_configs = {}
all_configs.update(**asdict(args), **asdict(tc))
run = wandb.init(
project="open_instruct_internal",
name=f"{args.generate_method}_{args.max_chunk_size}_{args.max_batch_size}",
config=all_configs,
)
tokenizer = tc.tokenizer
vllm_engines = create_vllm_engines(
num_engines=args.num_engines,
tensor_parallel_size=args.tensor_parallel_size,
enforce_eager=args.eager,
pretrain=tc.model_name_or_path,
revision=tc.revision,
seed=42,
enable_prefix_caching=False,
max_model_len=args.max_tokens,
)
ray.get([vllm_engine.ping.remote() for vllm_engine in vllm_engines])
train_dataset = get_cached_dataset_rlvr(
dataset_mixer_list=["ai2-adapt-dev/rlvr_open_reasoner_math", "1.0"],
dataset_mixer_list_splits=["train"],
tc=tc,
max_prompt_token_length=1024,
max_token_length=2048,
hf_entity="allenai",
)
prompts = train_dataset[:args.num_prompts]["input_ids_prompt"]
sampling_params = SamplingParams(
max_tokens=args.max_tokens,
temperature=0.7,
)
if args.debug:
prompts = [
[{"role": "user", "content": "Could you explain the PPO algorithm?"}],
[{"role": "user", "content": "What is nuclear physics?"}]
]
prompts = [tokenizer.apply_chat_template(item, add_generation_prompt=True) for item in prompts]
start_time = time.time()
if args.generate_method == "chunked":
response_ids, finish_reasons = chunked_generate_with_engines(prompts, sampling_params, max_chunk_size=args.max_chunk_size, vllm_engines=vllm_engines)
elif args.generate_method == "naive":
response_ids, finish_reasons = generate_with_engines(prompts, sampling_params, vllm_engines)
elif args.generate_method == "v2":
response_ids, finish_reasons = chunked_generate_with_engines_v2(prompts, sampling_params, max_chunk_size=args.max_chunk_size, max_batch_size=args.max_batch_size, vllm_engines=vllm_engines)
elif args.generate_method == "v3":
response_ids, finish_reasons = generate_with_engines_v3(prompts, sampling_params, max_batch_size=args.max_batch_size, vllm_engines=vllm_engines)
else:
raise ValueError(f"Invalid generate method: {args.generate_method}")
end_time = time.time()
total_time = end_time - start_time
generated_tokens = sum([len(item) for item in response_ids])
run.log(
{
"total_time": total_time,
"generated_tokens": generated_tokens,
"output_tokens_per_second": generated_tokens / total_time,
}
)
pprint([args, tc])
print(f"🔥 Generation: {total_time:.2f} seconds")
print(f"🔥 {generated_tokens=}")
print(f"🔥 output tokens per second: {generated_tokens / total_time:.2f}")
if __name__ == "__main__":
main(*HfArgumentParser((Args, TokenizerConfig)).parse_args_into_dataclasses())
# this file deals with dataset pre-processing before training
# 1. PPO (prompt)
# 2. SFT (prompt + demonstration), there is also packing.
# 3. ✅ RM / DPO (chosen and rejected)
# 4. ✅ Visualization of length distributions?
# 5. ✅ Filter?
# 6. ✅ dataset_num_proc
# 7. ✅ check EOS token
# 8. dataset mixer?
# 9. ✅ pretty print that show tokenization?
# 10. ✅ hashable tokenization?
# 11. inputs / labels / attention_mask
# 12. ✅ always set a `tokenizer.pad_token_id`?
# 13. a new DataCollatorForLanguageModeling?
# 14. ✅ `add_bos_token` and `add_eos_token`? E.g., LLAMA models
# 15. ✅ generate properties: has eos_token, bos_token (through chat template)
# ✅ get tokenizer revision
# ✅ get dataset revision
# create a cached tokenized dataset, with tokenized revision, dataset revision, tokenization function name.
# too many names related to "maximum length":
# * `max_seq_length` in SFT
# * `max_length`, `max_target_length` in RM / DPO,
# * `max_prompt_length` in DPO
# TODO: note that tokenizer doesn't change but model name does change. Should be mindful of this.
"""
This file contains the utility to transform and cache datasets with different configurations.
The main things we are looking for are:
* handle dataset mixing
* handle different tokenization functions
* **cache** the tokenized dataset so we don't have to re-tokenize every time
* This is especially important when we have 405B SFT models: 32 nodes each spend roughly
5 minutes tokenizing the dataset, which translates to 32 * 5 * 8 = 1280 GPU-minutes ≈ 21 hours
of wasted H100 time.
* Sometimes we also launch on clusters that don't have a shared cache (e.g., GCP), so we would
download the individual datasets 32 times and wait for concatenation and tokenization twice,
because `with accelerator.main_process_first()` assumes a shared cache.
"""
import copy
import hashlib
import json
import multiprocessing
import os
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional
import torch
import transformers
from datasets import Dataset, concatenate_datasets, load_dataset
from huggingface_hub import HfApi, ModelCard, revision_exists
from rich.console import Console
from rich.text import Text
from transformers import (
AutoConfig,
AutoTokenizer,
GPTNeoXTokenizerFast,
LlamaTokenizer,
LlamaTokenizerFast,
PreTrainedTokenizer,
)
from transformers.utils.hub import cached_file, extract_commit_hash
# ----------------------------------------------------------------------------
# Utilities
def get_commit_hash(model_name_or_path: str, revision: str, filename: str = "config.json", repo_type: str = "model"):
file = cached_file(model_name_or_path, revision=revision, filename=filename, repo_type=repo_type)
commit_hash = extract_commit_hash(file, None)
return commit_hash
# Performance tuning. Some rough numbers:
APPLY_CHAT_TEMPLATE_EXAMPLE_PER_SECOND_PER_CPU = 400
FILTER_EXAMPLE_PER_SECOND_PER_CPU = 1130
def get_num_proc(dataset_len: int, num_available_cpus: int, example_per_second_per_cpu) -> int:
num_required_cpus = max(1, dataset_len // example_per_second_per_cpu)
return min(num_required_cpus, num_available_cpus)
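# e.g., mapping 100_000 examples at ~400 examples/s per CPU wants ~250 CPUs,
# so with 64 CPUs available get_num_proc returns min(250, 64) = 64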
COLORS = ["on red", "on green", "on blue", "on yellow", "on magenta"]
def visualize_token(tokens: list[int], tokenizer: PreTrainedTokenizer):
i = 0
console = Console()
rich_text = Text()
for i, token in enumerate(tokens):
color = COLORS[i % len(COLORS)]
decoded_token = tokenizer.decode(token)
rich_text.append(f"{decoded_token}", style=color)
console.print(rich_text)
# ----------------------------------------------------------------------------
# Tokenization
# Chat templates
# flake8: noqa
# note we added `{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}`
# because we want the template to not output eos_token if `add_generation_prompt=True`
CHAT_TEMPLATES = {
"simple_concat_with_space": (
"{% for message in messages %}"
"{{ ' ' if not loop.first else '' }}"
"{{ message['content'] }}"
"{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}"
"{% endfor %}"
),
"simple_concat_with_new_line": (
"{% for message in messages %}"
"{{ '\n' if not loop.first else '' }}"
"{{ message['content'] }}"
"{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}"
"{% endfor %}"
),
"simple_chat": (
"{% for message in messages %}"
"{{ '\n\n' if not loop.first else '' }}"
"{{ message['role'].capitalize() + ': ' + message['content'] }}"
"{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}"
"{% endfor %}"
),
"assistant_message_only": (
"{% for message in messages %}"
"{% if message['role'] == 'assistant' %}"
"{{ message['content'] }}"
"{% endif %}"
"{% endfor %}"
),
"zephyr": (
"{% for message in messages %}"
"{% if message['role'] == 'user' %}"
"{{ '<|user|>\n' + message['content'] + eos_token + '\n' }}"
"{% elif message['role'] == 'system' %}"
"{{ '<|system|>\n' + message['content'] + eos_token + '\n' }}"
"{% elif message['role'] == 'assistant' %}"
"{{ '<|assistant|>\n' + message['content'] + eos_token + '\n' }}"
"{% endif %}"
"{% if loop.last and add_generation_prompt %}"
"{{ '<|assistant|>\n' }}"
"{% endif %}"
"{% endfor %}"
),
"tulu": (
"{% for message in messages %}"
"{% if message['role'] == 'system' %}"
"{{ '<|system|>\n' + message['content'] + '\n' }}"
"{% elif message['role'] == 'user' %}"
"{{ '<|user|>\n' + message['content'] + '\n' }}"
"{% elif message['role'] == 'assistant' %}"
"{% if not loop.last %}"
"{{ '<|assistant|>\n' + message['content'] + eos_token + '\n' }}"
"{% else %}"
"{{ '<|assistant|>\n' + message['content'] + eos_token }}"
"{% endif %}"
"{% endif %}"
"{% if loop.last and add_generation_prompt %}"
"{{ '<|assistant|>\n' }}"
"{% endif %}"
"{% endfor %}"
),
# template is taken from https://arxiv.org/abs/2501.12948.
"r1_simple_chat": (
"A conversation between User and Assistant. "
"The user asks a question, and the Assistant solves it. "
"The assistant first thinks about the reasoning process in "
"the mind and then provides the user with the answer. "
"The reasoning process and answer are enclosed within <think> </think> "
"and <answer> </answer> tags, respectively, "
"i.e., <think> reasoning process here </think> "
"<answer> answer here </answer>."
"\n\n"
"{% for message in messages %}"
"{{ '\n\n' if not loop.first else '' }}"
"{{ message['role'].capitalize() + ': ' + message['content'] + '\n' }}"
"{% if loop.last and add_generation_prompt %}"
"{{ 'Assistant:' }}"
"{% endif %}"
"{% endfor %}"
),
"r1_simple_chat_postpend_think": (
"A conversation between User and Assistant. "
"The user asks a question, and the Assistant solves it. "
"The assistant first thinks about the reasoning process in "
"the mind and then provides the user with the answer. "
"The reasoning process and answer are enclosed within <think> </think> "
"and <answer> </answer> tags, respectively, "
"i.e., <think> reasoning process here </think> "
"<answer> answer here </answer>."
"\n\n"
"{% for message in messages %}"
"{{ '\n\n' if not loop.first else '' }}"
"{{ message['role'].capitalize() + ': ' + message['content'] + '\n' }}"
"{% if loop.last and add_generation_prompt %}"
"{{ 'Assistant: <think>' }}"
"{% endif %}"
"{% endfor %}"
),
"qwen_countdown": (
"<|im_start|>system\n"
"Please reason step by step, and put your final answer within <answer> </answer> tags."
"The goal here is to reach a target number by combining integers using basic arithmetic operations. "
"Write your thoughts in <think> </think> tags. "
"The answer is a series of arithmetic operations (+, -, *, /) that results in the target number. "
"Write the final answer within <answer> </answer> tags. "
"Make sure each step in the final answer is written as <answer> (number1 [+-*/] number2) [+-*/] number3 </answer>. "
"The answer should be a valid mathematical expression using only the given numbers, NOT the target number."
"<|im_end|>\n\n"
"{% for message in messages %}"
"<|im_start|>{{ message['role'] }}\n{{ message['content'] }}<|im_end|>\n"
"{% endfor %}\n\n"
"{% if add_generation_prompt %}"
"<|im_start|>assistant\nLet me solve this step by step.\n<think>"
"{% endif %}"
),
}
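# Illustration of the eos_token guard (assuming eos_token is "</s>"): applying
# "simple_chat" to [{"role": "user", "content": "Hi"}] renders
#   "User: Hi</s>"  with add_generation_prompt=False (training target ends with eos)
#   "User: Hi"      with add_generation_prompt=True  (no eos before the model generates)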
# flake8: noqa
def get_tokenizer_simple_v1(tc: "TokenizerConfig"):
tokenizer = AutoTokenizer.from_pretrained(
tc.model_name_or_path,
revision=tc.revision,
trust_remote_code=tc.trust_remote_code,
use_fast=tc.use_fast,
)
return tokenizer
def get_tokenizer_tulu_v1(tc: "TokenizerConfig"):
tokenizer = AutoTokenizer.from_pretrained(
tc.model_name_or_path,
revision=tc.revision,
trust_remote_code=tc.trust_remote_code,
use_fast=tc.use_fast,
)
# no default pad token for llama!
# here we add all special tokens again, because the default ones are not in the special_tokens_map
# only add if the pad token is not present already.
if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, LlamaTokenizerFast):
num_added_tokens = tokenizer.add_special_tokens(
{
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
"pad_token": "<pad>",
}
)
assert num_added_tokens in [
0,
1,
], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present."
elif isinstance(tokenizer, GPTNeoXTokenizerFast):
# OLMo newer models use this tokenizer
if tokenizer.bos_token is None:
tokenizer.bos_token = tokenizer.eos_token
assert tc.add_bos, "For OLMo with GPTNeoX, you must add bos token to the beginning of the input sequence."
# else, pythia / other models
else:
num_added_tokens = tokenizer.add_special_tokens(
{
"pad_token": "<pad>",
}
)
assert (
num_added_tokens <= 1
), "GPTNeoXTokenizer should only add one special token - the pad_token (or no tokens if already set in SFT)."
# NOTE: (Costa) I just commented out the `OPTForCausalLM` branch because we are not likely to use it.
# elif isinstance(tokenizer, GPT2Tokenizer) and isinstance(model, OPTForCausalLM):
# num_added_tokens = tokenizer.add_special_tokens({"unk_token": "<unk>"})
elif isinstance(tokenizer, transformers.PreTrainedTokenizerFast) and tokenizer.pad_token is None:
num_added_tokens = tokenizer.add_special_tokens({"pad_token": "<pad>"})
assert num_added_tokens == 1, "We detected no padding token but add_special_tokens did not add one."
# set the tokenizer chat template to the training format
# this will be used for encoding the training examples
# and saved together with the tokenizer to be used later.
if tc.chat_template_name in CHAT_TEMPLATES:
tokenizer.chat_template = CHAT_TEMPLATES[tc.chat_template_name]
else:
try:
tokenizer.chat_template = AutoTokenizer.from_pretrained(tc.model_name_or_path).chat_template
except Exception:
raise ValueError(f"Could not find chat template for {tc.model_name_or_path}.")
if tc.add_bos:
if tokenizer.chat_template.startswith("{{ bos_token }}") or (
tokenizer.bos_token is not None and tokenizer.chat_template.startswith(tokenizer.bos_token)
):
raise ValueError(
"You specified add_bos=True, but the chat template already has a bos_token at the beginning."
)
# also add bos in the chat template if not already there
tokenizer.chat_template = "{{ bos_token }}" + tokenizer.chat_template
return tokenizer
def get_tokenizer_tulu_v2_1(tc: "TokenizerConfig"):
tokenizer = AutoTokenizer.from_pretrained(
tc.model_name_or_path,
revision=tc.revision,
trust_remote_code=tc.trust_remote_code,
use_fast=tc.use_fast,
)
# no default pad token for llama!
# here we add all special tokens again, because the default ones are not in the special_tokens_map
# only add if the pad token is not present already, or if the current one is set to eos_token_id.
if tokenizer.pad_token_id is None or tokenizer.pad_token_id == tokenizer.eos_token_id:
if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, LlamaTokenizerFast):
num_added_tokens = tokenizer.add_special_tokens({"pad_token": "<pad>"})
assert num_added_tokens in [
0,
1,
], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present."
elif isinstance(tokenizer, GPTNeoXTokenizerFast):
# OLMo newer models use this tokenizer
if tokenizer.bos_token is None:
tokenizer.bos_token = tokenizer.eos_token
assert (
tc.add_bos
), "For OLMo with GPTNeoX, you must add bos token to the beginning of the input sequence."
# else, pythia / other models
else:
num_added_tokens = tokenizer.add_special_tokens(
{
"pad_token": "<pad>",
}
)
assert (
num_added_tokens <= 1
), "GPTNeoXTokenizer should only add one special token - the pad_token (or no tokens if already set in SFT)."
# NOTE: (Costa) I just commented out the `OPTForCausalLM` branch because we are not likely to use it.
# elif isinstance(tokenizer, GPT2Tokenizer) and isinstance(model, OPTForCausalLM):
# num_added_tokens = tokenizer.add_special_tokens({"unk_token": "<unk>"})
elif isinstance(tokenizer, transformers.PreTrainedTokenizerFast):
num_added_tokens = tokenizer.add_special_tokens({"pad_token": "<pad>"})
assert num_added_tokens == 1, "We detected no padding token but add_special_tokens did not add one."
assert (
tokenizer.pad_token_id != tokenizer.eos_token_id
), "pad token and eos token matching causes issues in our setup."
# set the tokenizer chat template to the training format
# this will be used for encoding the training examples
# and saved together with the tokenizer to be used later.
if tc.chat_template_name in CHAT_TEMPLATES:
tokenizer.chat_template = CHAT_TEMPLATES[tc.chat_template_name]
else:
try:
tokenizer.chat_template = AutoTokenizer.from_pretrained(tc.model_name_or_path).chat_template
except Exception:
raise ValueError(f"Could not find chat template for {tc.model_name_or_path}.")
if tc.add_bos:
if tokenizer.chat_template.startswith("{{ bos_token }}") or (
tokenizer.bos_token is not None and tokenizer.chat_template.startswith(tokenizer.bos_token)
):
raise ValueError(
"You specified add_bos=True, but the chat template already has a bos_token at the beginning."
)
# also add bos in the chat template if not already there
tokenizer.chat_template = "{{ bos_token }}" + tokenizer.chat_template
return tokenizer
def get_tokenizer_tulu_v2_2(tc: "TokenizerConfig"):
config = AutoConfig.from_pretrained(tc.model_name_or_path, revision=tc.revision)
# @vwxyzjn: "olmo" handles both `olmo2` and `olmoe`.
if "olmo" in config.model_type:
assert tc.add_bos, "For OLMo, you must run with `--add_bos`."
assert tc.use_fast, "For OLMo, you must use fast tokenizer."
tokenizer = AutoTokenizer.from_pretrained(
tc.model_name_or_path,
revision=tc.revision,
trust_remote_code=tc.trust_remote_code,
use_fast=tc.use_fast,
)
# no default pad token for llama!
# here we add all special tokens again, because the default ones are not in the special_tokens_map
# only add if the pad token is not present already, or if the current one is set to eos_token_id.
if tokenizer.pad_token_id is None or tokenizer.pad_token_id == tokenizer.eos_token_id:
if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, LlamaTokenizerFast):
num_added_tokens = tokenizer.add_special_tokens({"pad_token": "<pad>"})
assert num_added_tokens in [
0,
1,
], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present."
elif isinstance(tokenizer, GPTNeoXTokenizerFast):
# OLMo newer models use this tokenizer
if tokenizer.bos_token is None:
tokenizer.bos_token = tokenizer.eos_token
assert (
tc.add_bos
), "For OLMo with GPTNeoX, you must add bos token to the beginning of the input sequence."
# else, pythia / other models
else:
num_added_tokens = tokenizer.add_special_tokens(
{
"pad_token": "<pad>",
}
)
assert (
num_added_tokens <= 1
), "GPTNeoXTokenizer should only add one special token - the pad_token (or no tokens if already set in SFT)."
# NOTE: (Costa) I just commented out the `OPTForCausalLM` branch because we are not likely to use it.
# elif isinstance(tokenizer, GPT2Tokenizer) and isinstance(model, OPTForCausalLM):
# num_added_tokens = tokenizer.add_special_tokens({"unk_token": "<unk>"})
elif isinstance(tokenizer, transformers.PreTrainedTokenizerFast):
num_added_tokens = tokenizer.add_special_tokens({"pad_token": "<pad>"})
assert num_added_tokens == 1, "We detected no padding token but add_special_tokens did not add one."
assert (
tokenizer.pad_token_id != tokenizer.eos_token_id
), "pad token and eos token matching causes issues in our setup."
# set the tokenizer chat template to the training format
# this will be used for encoding the training examples
# and saved together with the tokenizer to be used later.
if tc.chat_template_name in CHAT_TEMPLATES:
tokenizer.chat_template = CHAT_TEMPLATES[tc.chat_template_name]
else:
try:
tokenizer.chat_template = AutoTokenizer.from_pretrained(tc.model_name_or_path).chat_template
except Exception:
raise ValueError(f"Could not find chat template for {tc.model_name_or_path}.")
if tc.add_bos:
if tokenizer.chat_template.startswith("{{ bos_token }}") or (
tokenizer.bos_token is not None and tokenizer.chat_template.startswith(tokenizer.bos_token)
):
raise ValueError(
"You specified add_bos=True, but the chat template already has a bos_token at the beginning."
)
# also add bos in the chat template if not already there
tokenizer.chat_template = "{{ bos_token }}" + tokenizer.chat_template
return tokenizer
def get_tokenizer_tulu_v2_3(tc: "TokenizerConfig"):
config = AutoConfig.from_pretrained(tc.model_name_or_path, revision=tc.revision)
# @vwxyzjn: "olmo" handles both `olmo2` and `olmoe`.
if "olmo" in config.model_type:
assert tc.add_bos, "For OLMo, you must run with `--add_bos`."
assert tc.use_fast, "For OLMo, you must use fast tokenizer."
tokenizer = AutoTokenizer.from_pretrained(
tc.model_name_or_path,
revision=tc.revision,
trust_remote_code=tc.trust_remote_code,
use_fast=tc.use_fast,
)
# no default pad token for llama!
# here we add all special tokens again, because the default ones are not in the special_tokens_map
# only add if the pad token is not present already, or if the current one is set to eos_token_id.
if tokenizer.pad_token_id is None or tokenizer.pad_token_id == tokenizer.eos_token_id:
if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, LlamaTokenizerFast):
num_added_tokens = tokenizer.add_special_tokens({"pad_token": "<pad>"})
assert num_added_tokens in [
0,
1,
], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present."
elif isinstance(tokenizer, GPTNeoXTokenizerFast):
# OLMo newer models use this tokenizer
if tokenizer.bos_token is None:
tokenizer.bos_token = tokenizer.eos_token
assert (
tc.add_bos
), "For OLMo with GPTNeoX, you must add bos token to the beginning of the input sequence."
# else, pythia / other models
else:
num_added_tokens = tokenizer.add_special_tokens(
{
"pad_token": "<pad>",
}
)
assert (
num_added_tokens <= 1
), "GPTNeoXTokenizer should only add one special token - the pad_token (or no tokens if already set in SFT)."
# NOTE: (Costa) I just commented out the `OPTForCausalLM` branch because we are not likely to use it.
# elif isinstance(tokenizer, GPT2Tokenizer) and isinstance(model, OPTForCausalLM):
# num_added_tokens = tokenizer.add_special_tokens({"unk_token": "<unk>"})
elif isinstance(tokenizer, transformers.PreTrainedTokenizerFast):
num_added_tokens = tokenizer.add_special_tokens({"pad_token": "<pad>"})
assert num_added_tokens == 1, "We detected no padding token but add_special_tokens did not add one."
assert (
tokenizer.pad_token_id != tokenizer.eos_token_id
), "pad token and eos token matching causes issues in our setup."
# set the tokenizer chat template to the training format
# this will be used for encoding the training examples
# and saved together with the tokenizer to be used later.
if tc.chat_template_name in CHAT_TEMPLATES:
tokenizer.chat_template = CHAT_TEMPLATES[tc.chat_template_name]
elif tc.chat_template_name == "default":
try:
tokenizer.chat_template = AutoTokenizer.from_pretrained(tc.model_name_or_path).chat_template
except Exception:
raise ValueError(f"Could not find chat template for the original model.")
else:
raise ValueError(f"Could not find chat template for {tc.model_name_or_path}.")
if tc.add_bos:
if tokenizer.chat_template.startswith("{{ bos_token }}") or (
tokenizer.bos_token is not None and tokenizer.chat_template.startswith(tokenizer.bos_token)
):
raise ValueError(
"You specified add_bos=True, but the chat template already has a bos_token at the beginning."
)
# also add bos in the chat template if not already there
tokenizer.chat_template = "{{ bos_token }}" + tokenizer.chat_template
return tokenizer
GET_TOKENIZER_FN = {
"get_tokenizer_simple_v1": get_tokenizer_simple_v1,
"get_tokenizer_tulu_v1": get_tokenizer_tulu_v1, # old version, see https://github.com/allenai/open-instruct/pull/570
"get_tokenizer_tulu_v2_1": get_tokenizer_tulu_v2_1,
"get_tokenizer_tulu_v2_2": get_tokenizer_tulu_v2_2,
"get_tokenizer_tulu_v2_3": get_tokenizer_tulu_v2_3,
}
@dataclass
class TokenizerConfig:
model_name_or_path: str
revision: str
trust_remote_code: bool = True
use_fast: bool = True
chat_template_name: Optional[str] = None # TODO: should I give an option to force override?
add_bos: bool = False
get_tokenizer_fn: str = "get_tokenizer_tulu_v2_3"
# for tracking purposes
tokenizer_commit_hash: Optional[str] = None
def __post_init__(self):
self.tokenizer_commit_hash = get_commit_hash(
self.model_name_or_path, self.revision, filename="tokenizer_config.json"
)
self.tokenizer = GET_TOKENIZER_FN[self.get_tokenizer_fn](self)
# TODO: for testing, we should load the tokenizer from the sft / dpo / rl and make sure they are all the same.
# ----------------------------------------------------------------------------
# Dataset Transformation
# SFT dataset
DEFAULT_SFT_MESSAGES_KEY = "messages"
INPUT_IDS_KEY = "input_ids"
ATTENTION_MASK_KEY = "attention_mask"
LABELS_KEY = "labels"
TOKENIZED_SFT_DATASET_KEYS = [
INPUT_IDS_KEY,
ATTENTION_MASK_KEY,
LABELS_KEY,
]
# Preference dataset
# NOTE (Costa): the `INPUT_IDS_PROMPT_KEY` is for visualization purposes only
# also we don't really need `CHOSEN_ATTENTION_MASK_KEY` and `REJECTED_ATTENTION_MASK_KEY`
# since we are always padding from the right with a collator; however they might become
# more useful if we want to do some sort of packing in the future. The nice thing is
# that the tokenization logic would work for both DPO and RM training.
DEFAULT_CHOSEN_KEY = "chosen"
DEFAULT_REJECTED_KEY = "rejected"
CHOSEN_INPUT_IDS_KEY = "chosen_input_ids"
CHOSEN_ATTENTION_MASK_KEY = "chosen_attention_mask"
CHOSEN_LABELS_KEY = "chosen_labels"
REJECTED_INPUT_IDS_KEY = "rejected_input_ids"
REJECTED_ATTENTION_MASK_KEY = "rejected_attention_mask"
REJECTED_LABELS_KEY = "rejected_labels"
INPUT_IDS_PROMPT_KEY = "input_ids_prompt"
ATTENTION_MASK_PROMPT_KEY = "attention_mask_prompt"
GROUND_TRUTHS_KEY = "ground_truth"
DATASET_SOURCE_KEY = "dataset"
TOKENIZED_PREFERENCE_DATASET_KEYS = [
CHOSEN_INPUT_IDS_KEY,
CHOSEN_LABELS_KEY,
CHOSEN_ATTENTION_MASK_KEY,
REJECTED_INPUT_IDS_KEY,
REJECTED_LABELS_KEY,
REJECTED_ATTENTION_MASK_KEY,
]
# TODO: allow passing in sft_message key, so we can train on "chosen" of pref dataset.
def sft_tokenize_v1(
row: Dict[str, Any], tokenizer: PreTrainedTokenizer, sft_messages_key: str = DEFAULT_SFT_MESSAGES_KEY
):
if len(row[sft_messages_key]) == 1:
prompt = row[sft_messages_key]
else:
prompt = row[sft_messages_key][:-1]
row[INPUT_IDS_PROMPT_KEY] = tokenizer.apply_chat_template(
prompt,
add_generation_prompt=True,
)
row[INPUT_IDS_KEY] = tokenizer.apply_chat_template(row[sft_messages_key])
row[ATTENTION_MASK_KEY] = [1] * len(row[INPUT_IDS_KEY])
labels = copy.deepcopy(row[INPUT_IDS_KEY])
row[LABELS_KEY] = labels
return row
def sft_tokenize_mask_out_prompt_v1(
row: Dict[str, Any], tokenizer: PreTrainedTokenizer, sft_messages_key: str = DEFAULT_SFT_MESSAGES_KEY
):
"""mask out the prompt tokens by manipulating labels"""
if len(row[sft_messages_key]) == 1:
prompt = row[sft_messages_key]
else:
prompt = row[sft_messages_key][:-1]
row[INPUT_IDS_PROMPT_KEY] = tokenizer.apply_chat_template(
prompt,
add_generation_prompt=True,
)
row[INPUT_IDS_KEY] = tokenizer.apply_chat_template(row[sft_messages_key])
row[ATTENTION_MASK_KEY] = [1] * len(row[INPUT_IDS_KEY])
labels = copy.deepcopy(row[INPUT_IDS_KEY])
labels[: len(row[INPUT_IDS_PROMPT_KEY])] = [-100] * len(row[INPUT_IDS_PROMPT_KEY])
row[LABELS_KEY] = labels
return row
def sft_filter_v1(
row: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
max_prompt_token_length: Optional[int] = None,
max_token_length: Optional[int] = None,
need_contain_labels: bool = True,
):
max_prompt_token_length_ok = True
if max_prompt_token_length is not None:
max_prompt_token_length_ok = len(row[INPUT_IDS_PROMPT_KEY]) <= max_prompt_token_length
max_token_length_ok = True
if max_token_length is not None:
max_token_length_ok = len(row[INPUT_IDS_KEY]) <= max_token_length
contain_some_labels = any(x != -100 for x in row[LABELS_KEY])
return max_prompt_token_length_ok and max_token_length_ok and (contain_some_labels or not need_contain_labels)
def sft_tulu_tokenize_and_truncate_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_seq_length: int):
"""taken directly from https://github.com/allenai/open-instruct/blob/ba11286e5b9eb00d4ce5b40ef4cac1389888416a/open_instruct/finetune.py#L385"""
messages = row["messages"]
if len(messages) == 0:
raise ValueError("messages field is empty.")
input_ids = tokenizer.apply_chat_template(
conversation=messages,
tokenize=True,
return_tensors="pt",
padding=False,
truncation=True,
max_length=max_seq_length,
add_generation_prompt=False,
)
labels = input_ids.clone()
# mask the non-assistant part for avoiding loss
for message_idx, message in enumerate(messages):
if message["role"] != "assistant":
# we calculate the start index of this non-assistant message
if message_idx == 0:
message_start_idx = 0
else:
message_start_idx = tokenizer.apply_chat_template(
conversation=messages[:message_idx], # here marks the end of the previous messages
tokenize=True,
return_tensors="pt",
padding=False,
truncation=True,
max_length=max_seq_length,
add_generation_prompt=False,
).shape[1]
# next, we calculate the end index of this non-assistant message
if message_idx < len(messages) - 1 and messages[message_idx + 1]["role"] == "assistant":
# for intermediate messages that follow with an assistant message, we need to
# set `add_generation_prompt=True` to avoid the assistant generation prefix being included in the loss
# (e.g., `<|assistant|>`)
message_end_idx = tokenizer.apply_chat_template(
conversation=messages[: message_idx + 1],
tokenize=True,
return_tensors="pt",
padding=False,
truncation=True,
max_length=max_seq_length,
add_generation_prompt=True,
).shape[1]
else:
# for the last message or the message that doesn't follow with an assistant message,
# we don't need to add the assistant generation prefix
message_end_idx = tokenizer.apply_chat_template(
conversation=messages[: message_idx + 1],
tokenize=True,
return_tensors="pt",
padding=False,
truncation=True,
max_length=max_seq_length,
add_generation_prompt=False,
).shape[1]
# set the label to -100 for the non-assistant part
labels[:, message_start_idx:message_end_idx] = -100
if max_seq_length and message_end_idx >= max_seq_length:
break
attention_mask = torch.ones_like(input_ids)
row[INPUT_IDS_KEY] = input_ids.flatten()
row[LABELS_KEY] = labels.flatten()
row[ATTENTION_MASK_KEY] = attention_mask.flatten()
return row
def sft_tulu_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer):
return any(x != -100 for x in row[LABELS_KEY])
def preference_tokenize_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer):
# Extract prompt (all messages except the last one)
prompt = row["chosen"][:-1]
# Tokenize prompt
row[INPUT_IDS_PROMPT_KEY] = tokenizer.apply_chat_template(
prompt,
add_generation_prompt=True,
)
row[ATTENTION_MASK_PROMPT_KEY] = [1] * len(row[INPUT_IDS_PROMPT_KEY])
# Tokenize chosen completion
row[CHOSEN_INPUT_IDS_KEY] = tokenizer.apply_chat_template(row["chosen"])
row[CHOSEN_ATTENTION_MASK_KEY] = [1] * len(row[CHOSEN_INPUT_IDS_KEY])
# Tokenize rejected completion
row[REJECTED_INPUT_IDS_KEY] = tokenizer.apply_chat_template(row["rejected"])
row[REJECTED_ATTENTION_MASK_KEY] = [1] * len(row[REJECTED_INPUT_IDS_KEY])
return row
def preference_filter_v1(
row: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
max_prompt_token_length: Optional[int] = None,
max_token_length: Optional[int] = None,
):
# Check prompt length if specified
if max_prompt_token_length is not None:
if len(row[INPUT_IDS_PROMPT_KEY]) > max_prompt_token_length:
return False
# Check total sequence lengths if specified
if max_token_length is not None:
if len(row[CHOSEN_INPUT_IDS_KEY]) > max_token_length:
return False
if len(row[REJECTED_INPUT_IDS_KEY]) > max_token_length:
return False
return True
def preference_tulu_tokenize_and_truncate_v1(
row: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
max_seq_length: int,
chosen_key: str = DEFAULT_CHOSEN_KEY,
rejected_key: str = DEFAULT_REJECTED_KEY,
):
"""
Here we assume each example has a rejected and chosen field, both of which are a list of messages.
Each message is a dict with 'role' and 'content' fields.
We assume only the last message is different, and the prompt is contained in the list of messages.
"""
chosen_messages = row[chosen_key]
rejected_messages = row[rejected_key]
if len(chosen_messages) == 0:
raise ValueError("chosen messages field is empty.")
if len(rejected_messages) == 0:
raise ValueError("rejected messages field is empty.")
chosen_encoded = sft_tulu_tokenize_and_truncate_v1(
{DEFAULT_SFT_MESSAGES_KEY: chosen_messages}, tokenizer, max_seq_length
)
rejected_encoded = sft_tulu_tokenize_and_truncate_v1(
{DEFAULT_SFT_MESSAGES_KEY: rejected_messages}, tokenizer, max_seq_length
)
return {
CHOSEN_INPUT_IDS_KEY: chosen_encoded["input_ids"],
CHOSEN_LABELS_KEY: chosen_encoded["labels"],
CHOSEN_ATTENTION_MASK_KEY: chosen_encoded["attention_mask"],
REJECTED_INPUT_IDS_KEY: rejected_encoded["input_ids"],
REJECTED_LABELS_KEY: rejected_encoded["labels"],
REJECTED_ATTENTION_MASK_KEY: rejected_encoded["attention_mask"],
}
def preference_tulu_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer):
return any(x != -100 for x in row[CHOSEN_LABELS_KEY]) and any(x != -100 for x in row[REJECTED_LABELS_KEY])
def rlvr_tokenize_v1(
row: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
sft_messages_key: str = DEFAULT_SFT_MESSAGES_KEY,
ground_truths_key: str = GROUND_TRUTHS_KEY,
dataset_source_key: str = DATASET_SOURCE_KEY,
):
if len(row[sft_messages_key]) == 1:
prompt = row[sft_messages_key]
else:
prompt = row[sft_messages_key][:-1]
row[INPUT_IDS_PROMPT_KEY] = tokenizer.apply_chat_template(
prompt,
add_generation_prompt=True,
)
row[INPUT_IDS_KEY] = tokenizer.apply_chat_template(row[sft_messages_key])
row[ATTENTION_MASK_KEY] = [1] * len(row[INPUT_IDS_KEY])
labels = copy.deepcopy(row[INPUT_IDS_KEY])
row[LABELS_KEY] = labels
row[GROUND_TRUTHS_KEY] = row[ground_truths_key]
row[DATASET_SOURCE_KEY] = row[dataset_source_key]
return row
def rlvr_filter_v1(
row: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
need_contain_labels: bool = True,
max_prompt_token_length: Optional[int] = None,
max_token_length: Optional[int] = None,
):
max_prompt_token_length_ok = True
if max_prompt_token_length is not None:
max_prompt_token_length_ok = len(row[INPUT_IDS_PROMPT_KEY]) <= max_prompt_token_length
max_token_length_ok = True
if max_token_length is not None:
max_token_length_ok = len(row[INPUT_IDS_KEY]) <= max_token_length
contain_some_labels = any(x != -100 for x in row[LABELS_KEY])
return max_prompt_token_length_ok and max_token_length_ok and (contain_some_labels or not need_contain_labels)
TRANSFORM_FNS = {
"sft_tokenize_v1": (sft_tokenize_v1, "map"),
"sft_tokenize_mask_out_prompt_v1": (sft_tokenize_mask_out_prompt_v1, "map"),
"sft_filter_v1": (sft_filter_v1, "filter"),
"sft_tulu_tokenize_and_truncate_v1": (sft_tulu_tokenize_and_truncate_v1, "map"),
"sft_tulu_filter_v1": (sft_tulu_filter_v1, "filter"),
"preference_tokenize_v1": (preference_tokenize_v1, "map"),
"preference_filter_v1": (preference_filter_v1, "filter"),
"preference_tulu_tokenize_and_truncate_v1": (preference_tulu_tokenize_and_truncate_v1, "map"),
"preference_tulu_filter_v1": (preference_tulu_filter_v1, "filter"),
"rlvr_tokenize_v1": (rlvr_tokenize_v1, "map"),
"rlvr_filter_v1": (rlvr_filter_v1, "filter"),
}
# ----------------------------------------------------------------------------
# Dataset Configuration and Caching
@dataclass
class DatasetConfig:
dataset_name: str
dataset_split: str
dataset_revision: str
dataset_range: Optional[int] = None
transform_fn: List[str] = field(default_factory=list)
transform_fn_args: Dict[str, Dict[str, Any]] = field(default_factory=dict)
# for tracking purposes
dataset_commit_hash: Optional[str] = None
def __post_init__(self):
# if the file exists locally, use the local file
if os.path.exists(self.dataset_name) and self.dataset_name.endswith(".jsonl"):
assert self.dataset_split == "train", "Only train split is supported for local jsonl files."
self.dataset = load_dataset(
"json",
data_files=self.dataset_name,
split=self.dataset_split,
)
else:
# commit hash only works for hf datasets
self.dataset_commit_hash = get_commit_hash(
self.dataset_name, self.dataset_revision, "README.md", "dataset"
)
self.dataset = load_dataset(
self.dataset_name,
split=self.dataset_split,
revision=self.dataset_revision,
)
if self.dataset_range is None:
dataset_range = len(self.dataset)
self.update_range(dataset_range)
def update_range(self, dataset_range: int):
self.dataset_range = dataset_range
if self.dataset_range > len(self.dataset):
raise ValueError("Dataset range exceeds dataset length")
self.dataset = self.dataset.select(range(self.dataset_range))
def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig):
# beaker specific logic; we may get assigned 15.5 CPU, so we convert it to float then int
num_proc = int(float(os.environ.get("BEAKER_ASSIGNED_CPU_COUNT", multiprocessing.cpu_count())))
tokenizer = tc.tokenizer
dataset = dc.dataset
for fn_name in dc.transform_fn:
fn, fn_type = TRANSFORM_FNS[fn_name]
# always pass in tokenizer and other args if needed
fn_kwargs = {"tokenizer": tokenizer}
target_columns = dataset.column_names
if fn_name in dc.transform_fn_args:
target_columns = dc.transform_fn_args[fn_name].pop("target_columns", dataset.column_names)
fn_kwargs.update(dc.transform_fn_args[fn_name])
# perform the transformation
if fn_type == "map":
dataset = dataset.map(
fn,
fn_kwargs=fn_kwargs,
remove_columns=[col for col in dataset.column_names if col not in target_columns],
num_proc=get_num_proc(len(dataset), num_proc, APPLY_CHAT_TEMPLATE_EXAMPLE_PER_SECOND_PER_CPU),
)
elif fn_type == "filter":
dataset = dataset.filter(
fn,
fn_kwargs=fn_kwargs,
num_proc=get_num_proc(len(dataset), num_proc, FILTER_EXAMPLE_PER_SECOND_PER_CPU),
)
# NOTE: elif we can implement packing here to create a packed SFT dataset. Low priority for now.
else:
raise ValueError(f"Unknown transform function type: {fn_type}")
if len(dataset) == 0:
raise ValueError("No examples left after transformation")
return dataset
class DatasetTransformationCache:
def __init__(self, hf_entity: Optional[str] = None):
self.hf_entity = hf_entity or HfApi().whoami()["name"]
def compute_config_hash(self, dcs: List[DatasetConfig], tc: TokenizerConfig) -> str:
"""Compute a deterministic hash of both configs for caching."""
dc_dicts = [{k: v for k, v in asdict(dc).items() if v is not None} for dc in dcs]
tc_dict = {k: v for k, v in asdict(tc).items() if v is not None}
combined_dict = {"dataset_configs": dc_dicts, "tokenizer_config": tc_dict}
config_str = json.dumps(combined_dict, sort_keys=True)
return hashlib.sha256(config_str.encode()).hexdigest()[:10]
def load_or_transform_dataset(self, dcs: List[DatasetConfig], tc: TokenizerConfig) -> Dataset:
"""Load dataset from cache if it exists, otherwise transform and cache it."""
config_hash = self.compute_config_hash(dcs, tc)
repo_name = f"{self.hf_entity}/dataset-mix-cached"
# NOTE: the cached dataset is always train split
DEFAULT_SPLIT_FOR_CACHED_DATASET = "train"
# Check if the revision exists
if revision_exists(repo_name, config_hash, repo_type="dataset"):
print(f"✅ Found cached dataset at https://huggingface.co/datasets/{repo_name}/tree/{config_hash}")
# Use the split from the first dataset config as default
return load_dataset(repo_name, split=DEFAULT_SPLIT_FOR_CACHED_DATASET, revision=config_hash)
print(f"Cache not found, transforming datasets...")
# Transform each dataset
transformed_datasets = []
for dc in dcs:
dataset = get_dataset_v1(dc, tc)
transformed_datasets.append(dataset)
# Combine datasets
combined_dataset = concatenate_datasets(transformed_datasets)
# Push to hub with config hash as revision
combined_dataset.push_to_hub(
repo_name,
private=True,
revision=config_hash,
commit_message=f"Cache combined dataset with configs hash: {config_hash}",
)
print(f"🚀 Pushed transformed dataset to https://huggingface.co/datasets/{repo_name}/tree/{config_hash}")
model_card = ModelCard(
f"""\
---
tags: [open-instruct]
---
# Cached Tokenized Datasets
## Summary
This is a cached dataset produced by https://github.com/allenai/open-instruct
## Configuration
`TokenizerConfig`:
```json
{json.dumps(asdict(tc), indent=2)}
```
`List[DatasetConfig]`:
```json
{json.dumps([asdict(dc) for dc in dcs], indent=2)}
```
"""
)
model_card.push_to_hub(repo_name, repo_type="dataset", revision=config_hash)
# NOTE: Load the dataset again to make sure it's downloaded to the HF cache
print(f"✅ Found cached dataset at https://huggingface.co/datasets/{repo_name}/tree/{config_hash}")
return load_dataset(repo_name, split=DEFAULT_SPLIT_FOR_CACHED_DATASET, revision=config_hash)
def get_cached_dataset(dcs: List[DatasetConfig], tc: TokenizerConfig, hf_entity: Optional[str] = None) -> Dataset:
cache = DatasetTransformationCache(hf_entity=hf_entity)
return cache.load_or_transform_dataset(dcs, tc)
def get_cached_dataset_tulu_sft(
dataset_mixer_list: List[str],
tc: TokenizerConfig,
max_seq_length: int,
hf_entity: Optional[str] = None,
) -> Dataset:
dcs = []
assert len(dataset_mixer_list) % 2 == 0, f"Data mixer list length is not even: {dataset_mixer_list}"
for i in range(0, len(dataset_mixer_list), 2):
dataset_name = dataset_mixer_list[i]
frac_or_num_samples = dataset_mixer_list[i + 1]
if "." in frac_or_num_samples:
frac_or_num_samples = float(frac_or_num_samples)
else:
frac_or_num_samples = int(frac_or_num_samples)
dataset_config = DatasetConfig(
dataset_name=dataset_name,
dataset_split="train",
dataset_revision="main",
transform_fn=["sft_tulu_tokenize_and_truncate_v1", "sft_tulu_filter_v1"],
transform_fn_args={
"sft_tulu_tokenize_and_truncate_v1": {
"max_seq_length": max_seq_length,
"target_columns": TOKENIZED_SFT_DATASET_KEYS,
}
},
)
if frac_or_num_samples > 1.0:
new_range = int(frac_or_num_samples)
else:
new_range = int(frac_or_num_samples * len(dataset_config.dataset))
dataset_config.update_range(new_range)
dcs.append(dataset_config)
cache = DatasetTransformationCache(hf_entity=hf_entity)
return cache.load_or_transform_dataset(dcs, tc)
def get_cached_dataset_tulu_preference(
dataset_mixer_list: List[str], tc: TokenizerConfig, max_seq_length: int, hf_entity: Optional[str] = None
) -> Dataset:
dcs = []
assert len(dataset_mixer_list) % 2 == 0, f"Data mixer list length is not even: {dataset_mixer_list}"
for i in range(0, len(dataset_mixer_list), 2):
dataset_name = dataset_mixer_list[i]
frac_or_num_samples = dataset_mixer_list[i + 1]
if "." in frac_or_num_samples:
frac_or_num_samples = float(frac_or_num_samples)
else:
frac_or_num_samples = int(frac_or_num_samples)
dataset_config = DatasetConfig(
dataset_name=dataset_name,
dataset_split="train",
dataset_revision="main",
transform_fn=["preference_tulu_tokenize_and_truncate_v1", "preference_tulu_filter_v1"],
transform_fn_args={
"preference_tulu_tokenize_and_truncate_v1": {
"max_seq_length": max_seq_length,
"target_columns": TOKENIZED_PREFERENCE_DATASET_KEYS,
}
},
)
if frac_or_num_samples > 1.0:
new_range = int(frac_or_num_samples)
else:
new_range = int(frac_or_num_samples * len(dataset_config.dataset))
dataset_config.update_range(new_range)
dcs.append(dataset_config)
cache = DatasetTransformationCache(hf_entity=hf_entity)
return cache.load_or_transform_dataset(dcs, tc)
def get_cached_dataset_rlvr(
dataset_mixer_list: List[str],
dataset_mixer_list_splits: List[str],
tc: TokenizerConfig,
max_token_length: Optional[int] = None,
max_prompt_token_length: Optional[int] = None,
hf_entity: Optional[str] = None,
) -> Dataset:
if len(dataset_mixer_list_splits) == 1:
print("by default, we will use the same split for all datasets")
dataset_mixer_list_splits = [dataset_mixer_list_splits[0]] * len(dataset_mixer_list)
dcs = []
assert len(dataset_mixer_list) % 2 == 0, f"Data mixer list length is not even: {dataset_mixer_list}"
for i in range(0, len(dataset_mixer_list), 2):
dataset_name = dataset_mixer_list[i]
frac_or_num_samples = dataset_mixer_list[i + 1]
if "." in frac_or_num_samples:
frac_or_num_samples = float(frac_or_num_samples)
else:
frac_or_num_samples = int(frac_or_num_samples)
dataset_config = DatasetConfig(
dataset_name=dataset_name,
dataset_split=dataset_mixer_list_splits[i],
dataset_revision="main",
transform_fn=["rlvr_tokenize_v1", "rlvr_filter_v1"],
transform_fn_args={
"rlvr_filter_v1": {
"max_token_length": max_token_length,
"max_prompt_token_length": max_prompt_token_length,
}
},
)
if frac_or_num_samples > 1.0:
new_range = int(frac_or_num_samples)
else:
new_range = int(frac_or_num_samples * len(dataset_config.dataset))
dataset_config.update_range(new_range)
dcs.append(dataset_config)
cache = DatasetTransformationCache(hf_entity=hf_entity)
return cache.load_or_transform_dataset(dcs, tc)
def test_sft_dpo_same_tokenizer():
base_to_sft_tc = TokenizerConfig(
model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu"
)
sft_to_dpo_tc = TokenizerConfig(
model_name_or_path="allenai/Llama-3.1-Tulu-3-8B-SFT", revision="main", chat_template_name="tulu"
)
dpo_to_rl_tc = TokenizerConfig(
model_name_or_path="allenai/Llama-3.1-Tulu-3-8B-DPO", revision="main", chat_template_name="tulu"
)
def equal_tokenizer(tc1, tc2):
tok1 = tc1.tokenizer
tok2 = tc2.tokenizer
assert tok1.vocab_size == tok2.vocab_size, "Vocab size should be the same"
assert tok1.model_max_length == tok2.model_max_length, "Model max length should be the same"
assert tok1.is_fast == tok2.is_fast, "is_fast should be the same"
assert tok1.padding_side == tok2.padding_side, "padding_side should be the same"
assert tok1.truncation_side == tok2.truncation_side, "truncation_side should be the same"
assert (
tok1.clean_up_tokenization_spaces == tok2.clean_up_tokenization_spaces
), "clean_up_tokenization_spaces should be the same"
assert tok1.added_tokens_decoder == tok2.added_tokens_decoder, "added_tokens_decoder should be the same"
equal_tokenizer(base_to_sft_tc, sft_to_dpo_tc)
equal_tokenizer(sft_to_dpo_tc, dpo_to_rl_tc)
equal_tokenizer(base_to_sft_tc, dpo_to_rl_tc)
def test_config_hash_different():
"""Test that different configurations produce different hashes."""
tc = TokenizerConfig(model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu")
dcs1 = [
DatasetConfig(
dataset_name="allenai/tulu-3-sft-personas-algebra",
dataset_split="train",
dataset_revision="main",
transform_fn=["sft_tokenize_v1"],
transform_fn_args={},
)
]
dcs2 = [
DatasetConfig(
dataset_name="allenai/tulu-3-sft-personas-algebra",
dataset_split="train",
dataset_revision="main",
transform_fn=["sft_tokenize_mask_out_prompt_v1"],
transform_fn_args={},
)
]
cache = DatasetTransformationCache()
hash1 = cache.compute_config_hash(dcs1, tc)
hash2 = cache.compute_config_hash(dcs2, tc)
assert hash1 != hash2, "Different configs should have different hashes"
def test_sft_dataset_caching():
"""Test caching functionality for SFT datasets."""
tc = TokenizerConfig(model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu")
dcs = [
DatasetConfig(
dataset_name="allenai/tulu-3-sft-personas-algebra",
dataset_split="train",
dataset_revision="main",
transform_fn=["sft_tokenize_v1"],
transform_fn_args={},
),
DatasetConfig(
dataset_name="allenai/tulu-3-hard-coded-10x",
dataset_split="train",
dataset_revision="main",
transform_fn=["sft_tokenize_v1"],
transform_fn_args={},
),
]
# First transformation should cache
dataset1 = get_cached_dataset(dcs, tc)
# Second load should use cache
dataset1_cached = get_cached_dataset(dcs, tc)
# Verify the datasets are the same
assert len(dataset1) == len(dataset1_cached), "Cached dataset should have same length"
def test_sft_different_transform():
"""Test different transform functions produce different cached datasets."""
tc = TokenizerConfig(model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu")
dcs = [
DatasetConfig(
dataset_name="allenai/tulu-3-sft-personas-algebra",
dataset_split="train",
dataset_revision="main",
transform_fn=["sft_tokenize_mask_out_prompt_v1"],
transform_fn_args={},
),
DatasetConfig(
dataset_name="allenai/tulu-3-hard-coded-10x",
dataset_split="train",
dataset_revision="main",
transform_fn=["sft_tokenize_mask_out_prompt_v1"],
transform_fn_args={},
),
]
dataset = get_cached_dataset(dcs, tc)
assert dataset is not None, "Should successfully create dataset with different transform"
def test_sft_filter():
"""Test different transform functions produce different cached datasets."""
tc = TokenizerConfig(model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu")
ARBITRARY_MAX_LENGTH = 1000
dcs = [
DatasetConfig(
dataset_name="allenai/tulu-3-sft-personas-algebra",
dataset_split="train",
dataset_revision="main",
transform_fn=["sft_tokenize_v1", "sft_filter_v1"], # First tokenize, then filter
transform_fn_args={
"sft_filter_v1": {
"max_token_length": ARBITRARY_MAX_LENGTH # Filter to sequences <= ARBITRARY_MAX_LENGTH tokens
}
},
)
]
filtered_dataset = get_cached_dataset(dcs, tc)
# Verify that all sequences are <= ARBITRARY_MAX_LENGTH tokens
max_length = max(len(example[INPUT_IDS_KEY]) for example in filtered_dataset)
assert max_length <= ARBITRARY_MAX_LENGTH, f"Found sequence with length {max_length} > {ARBITRARY_MAX_LENGTH}"
print("Filter test passed! Max sequence length:", max_length)
print("All tests passed!")
assert filtered_dataset is not None, "Should successfully create dataset with different transform"
def test_preference_dataset():
"""Test caching functionality for preference datasets."""
tc = TokenizerConfig(model_name_or_path="meta-llama/Llama-3.1-8B", revision="main", chat_template_name="tulu")
dcs_pref = [
DatasetConfig(
dataset_name="allenai/tulu-3-pref-personas-instruction-following",
dataset_split="train",
dataset_revision="main",
transform_fn=["preference_tokenize_v1"],
transform_fn_args={},
),
DatasetConfig(
dataset_name="allenai/tulu-3-wildchat-reused-on-policy-70b",
dataset_split="train",
dataset_revision="main",
transform_fn=["preference_tokenize_v1"],
transform_fn_args={},
),
]
dataset_pref = get_cached_dataset(dcs_pref, tc)
assert dataset_pref is not None, "Should successfully create preference dataset"
if __name__ == "__main__":
test_sft_dpo_same_tokenizer()
test_config_hash_different()
test_sft_dataset_caching()
test_sft_different_transform()
test_preference_dataset()
test_sft_filter()
print("All tests passed!")
# Taken and modified from https://github.com/huggingface/trl
# Copyright 2024 The AllenAI Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file is copied from https://github.com/OpenRLHF/OpenRLHF"""
from datetime import timedelta
import queue
import threading
from typing import Any, Optional, Union
import ray
import torch
import torch.distributed
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from torch.distributed.distributed_c10d import (
Backend,
PrefixStore,
Store,
_new_process_group_helper,
_world,
default_pg_timeout,
rendezvous,
)
# monkey patch for olmo2 32B
from vllm.model_executor.models.olmo2 import AttentionMetadata, Olmo2Attention
from vllm.worker.worker import Worker
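# The patched forward below uses split([q_size, kv_size, kv_size]) instead of the
# upstream chunk(chunks=3), which only works when the query and key/value projections
# have equal width; the split is apparently needed for OLMo2 32B's attention layout.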
def custom_forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
# q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
Olmo2Attention.forward = custom_forward
# Copy from pytorch to allow creating multiple main groups.
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
def init_process_group(
backend: Union[str, Backend] = None,
init_method: Optional[str] = None,
timeout: Optional[timedelta] = None,
world_size: int = -1,
rank: int = -1,
store: Optional[Store] = None,
group_name: str = None,
pg_options: Optional[Any] = None,
):
assert (store is None) or (init_method is None), "Cannot specify both init_method and store."
if store is not None:
assert world_size > 0, "world_size must be positive if using store"
assert rank >= 0, "rank must be non-negative if using store"
elif init_method is None:
init_method = "env://"
if backend:
backend = Backend(backend)
else:
backend = Backend("undefined")
if timeout is None:
timeout = default_pg_timeout
# backward compatible API
if store is None:
rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
store, rank, world_size = next(rendezvous_iterator)
store.set_timeout(timeout)
# Use a PrefixStore to avoid accidental overrides of keys used by
# different systems (e.g. RPC) in case the store is multi-tenant.
store = PrefixStore(group_name, store)
pg, _ = _new_process_group_helper(
world_size,
rank,
[],
backend,
store,
group_name=group_name,
pg_options=pg_options,
timeout=timeout,
)
_world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
return pg
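# Usage sketch (illustrative, not part of the original file): a trainer process can
# join the weight-update group that WorkerWrap.init_process_group creates below,
# typically as rank 0 so that update_weight's broadcast originates from the actor model:
#   model_update_group = init_process_group(
#       backend="nccl",
#       init_method=f"tcp://{master_address}:{master_port}",
#       world_size=world_size,
#       rank=0,
#       group_name=group_name,
#   )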
class WorkerWrap(Worker):
def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend="nccl"):
"""Init torch process group for model weights update"""
assert torch.distributed.is_initialized(), "default torch process group must be initialized"
assert group_name != "", "group name must not be empty"
rank = torch.distributed.get_rank() + rank_offset
self._model_update_group = init_process_group(
backend=backend,
init_method=f"tcp://{master_address}:{master_port}",
world_size=world_size,
rank=rank,
group_name=group_name,
)
print(
f"init_process_group: master_address={master_address}, master_port={master_port}, ",
f"rank={rank}, world_size={world_size}, group_name={group_name}",
)
def update_weight(self, name, dtype, shape, empty_cache=False):
"""Broadcast weight to all vllm workers from source rank 0 (actor model)"""
# print(f"update_weight: {name}, dtype: {dtype}, shape: {shape}, rank: {torch.distributed.get_rank()}, world_size: {torch.distributed.get_world_size()}")
# if torch.distributed.get_rank() == 0:
# print(f"update weight: {name}, dtype: {dtype}, shape: {shape}")
assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}"
weight = torch.empty(shape, dtype=dtype, device="cuda")
torch.distributed.broadcast(weight, 0, group=self._model_update_group)
self.model_runner.model.load_weights(weights=[(name, weight)])
del weight
# TODO: should we empty cache if all weights have updated?
# if empty_cache:
# torch.cuda.empty_cache()
@ray.remote
class LLMRayActor:
def __init__(self, *args, **kwargs):
import vllm
self.__version__ = vllm.__version__
assert self.__version__ >= "0.4.1", "OpenRLHF only supports vLLM >= 0.4.1"
self.use_gpu_executor = kwargs["tensor_parallel_size"] == 1
# See https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py
if self.use_gpu_executor:
vllm.worker.worker.Worker = WorkerWrap
else:
# RayGPUExecutor
# See the patch https://github.com/vllm-project/vllm/commit/479d69fad0538f04cb22bf13e76ff91cfeb8a4e5
kwargs["worker_use_ray"] = True
if vllm.__version__ > "0.4.1":
RayWorkerWrapperPath = vllm.executor.ray_utils
else:
RayWorkerWrapperPath = vllm.engine.ray_utils
# patch for newer vllm from openrlhf:
# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/trainer/ray/vllm_engine.py#L40
if vllm.__version__ > "0.6.4.post1":
# https://github.com/vllm-project/vllm/pull/10555
kwargs["worker_cls"] = "open_instruct.vllm_utils2.WorkerWrap"
else:
RayWorkerWrapperPath = vllm.executor.ray_utils
class RayWorkerWrapper(RayWorkerWrapperPath.RayWorkerWrapper):
def __init__(self, *args, **kwargs) -> None:
kwargs["worker_module_name"] = "open_instruct.vllm_utils2"
kwargs["worker_class_name"] = "WorkerWrap"
super().__init__(*args, **kwargs)
RayWorkerWrapperPath.RayWorkerWrapper = RayWorkerWrapper
self.llm = vllm.LLM(*args, **kwargs)
self.input_queue = queue.Queue()
self.output_queue = queue.Queue()
self.busy_generate_th = threading.Thread(target=self.busy_generate)
self.busy_generate_th.start()
def ping(self):
return "pong"
def generate(self, *args, **kwargs):
return self.llm.generate(*args, **kwargs)
def output_queue_get(self, *args, **kwargs):
return self.output_queue.get(*args, **kwargs)
def input_queue_put(self, *args, **kwargs):
return self.input_queue.put(*args, **kwargs)
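# Background loop: each queue item is (sampling_params, per_engine_data), where every
# element already carries the token ids generated so far. The engine extends each
# sequence by one chunk, appends the new tokens back onto prompt_and_response, and
# pushes the batch to output_queue so the caller can schedule the next chunk.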
def busy_generate(self):
while True:
item = self.input_queue.get()
sampling_params, per_engine_data = item
prompt_token_ids = [prompt_and_response.prompt_and_response for prompt_and_response in per_engine_data]
outputs = self.llm.generate(
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
)
for i in range(len(per_engine_data)):
prompt_and_response = per_engine_data[i]
tokens = outputs[i].outputs[0].token_ids
finish_reason = outputs[i].outputs[0].finish_reason
prompt_and_response.prompt_and_response.extend(list(tokens))
prompt_and_response.response_length += len(tokens)
if finish_reason == "stop":
prompt_and_response.finish_reason = finish_reason
self.output_queue.put(per_engine_data)
def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend):
if self.use_gpu_executor:
return self.llm.llm_engine.model_executor.driver_worker.init_process_group(
master_address, master_port, rank_offset, world_size, group_name, backend
)
else:
return self.llm.llm_engine.model_executor._run_workers(
"init_process_group", master_address, master_port, rank_offset, world_size, group_name, backend
)
def update_weight(self, name, dtype, shape, empty_cache=False):
self.stop_remote_worker_execution_loop()
if self.use_gpu_executor:
return self.llm.llm_engine.model_executor.driver_worker.update_weight(name, dtype, shape, empty_cache)
else:
return self.llm.llm_engine.model_executor._run_workers("update_weight", name, dtype, shape, empty_cache)
def stop_remote_worker_execution_loop(self):
# Fix error when using 2 communication groups
# https://github.com/vllm-project/vllm/commit/eb6d3c264d0cd8e44dec16bca7947fbe96415ce9#diff-e1ad69e38e033accddfa5480ec808c4740eb39244d1ef51cc3407e20dde8cfd4
if self.__version__ > "0.4.2":
self.llm.llm_engine.model_executor.stop_remote_worker_execution_loop()
def create_vllm_engines(
num_engines: int,
tensor_parallel_size: int,
enforce_eager: bool,
pretrain: str,
revision: str,
seed: int,
enable_prefix_caching: bool,
max_model_len: int,
vllm_gpu_memory_utilization: float = 0.9,
single_gpu_mode: bool = False,
pg: Optional[ray.util.placement_group] = None,
):
vllm_engines = []
for i in range(num_engines):
# When tensor_parallel_size=1, vLLM initializes the model in the LLMEngine process directly, so assign 1 GPU to it.
num_gpus = int(tensor_parallel_size == 1)
scheduling_strategy = None
if pg is not None:
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=pg, placement_group_capture_child_tasks=True
)
elif tensor_parallel_size > 1:
bundles = [{"GPU": 1, "CPU": 4}] * tensor_parallel_size
pg = placement_group(bundles)
ray.get(pg.ready())
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=pg, placement_group_capture_child_tasks=True, placement_group_bundle_index=0
)
print(f"vllm: {num_gpus=}, {num_engines=}")
vllm_engines.append(
LLMRayActor.options(
num_cpus=4,
num_gpus=0.48 if single_gpu_mode else num_gpus,
scheduling_strategy=scheduling_strategy,
).remote(
pretrain,
revision=revision,
tokenizer_revision=revision,
trust_remote_code=True,
tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager,
dtype="bfloat16",
seed=seed + i,
enable_prefix_caching=enable_prefix_caching,
max_model_len=max_model_len,
gpu_memory_utilization=vllm_gpu_memory_utilization,
)
)
return vllm_engines
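# Usage sketch (argument values are illustrative, mirroring the small benchmark setup):
#   engines = create_vllm_engines(
#       num_engines=2, tensor_parallel_size=1, enforce_eager=True,
#       pretrain="HuggingFaceTB/SmolLM2-135M-Instruct", revision="main",
#       seed=0, enable_prefix_caching=False, max_model_len=4096,
#   )
#   ray.get([engine.ping.remote() for engine in engines])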
if __name__ == "__main__":
llm = LLMRayActor.remote("Qwen/Qwen2.5-7B", tensor_parallel_size=2)
output = ray.get(llm.generate.remote("San Francisco is a"))
print(f"output: {output}")
# Taken and modified from https://github.com/huggingface/trl
# Copyright 2024 The AllenAI Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file is copied from https://github.com/OpenRLHF/OpenRLHF"""
from datetime import timedelta
import queue
import threading
import time
from typing import Any, Optional, Union
import ray
import torch
import torch.distributed
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from torch.distributed.distributed_c10d import (
Backend,
PrefixStore,
Store,
_new_process_group_helper,
_world,
default_pg_timeout,
rendezvous,
)
# monkey patch for olmo2 32B
from vllm.model_executor.models.olmo2 import AttentionMetadata, Olmo2Attention
from vllm.worker.worker import Worker
from vllm.utils import Counter
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
PoolingRequestOutput, RequestOutput,
ScoringRequestOutput)
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
def custom_forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
# q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
Olmo2Attention.forward = custom_forward
# Copy from pytorch to allow creating multiple main groups.
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
def init_process_group(
backend: Union[str, Backend] = None,
init_method: Optional[str] = None,
timeout: Optional[timedelta] = None,
world_size: int = -1,
rank: int = -1,
store: Optional[Store] = None,
group_name: str = None,
pg_options: Optional[Any] = None,
):
assert (store is None) or (init_method is None), "Cannot specify both init_method and store."
if store is not None:
assert world_size > 0, "world_size must be positive if using store"
assert rank >= 0, "rank must be non-negative if using store"
elif init_method is None:
init_method = "env://"
if backend:
backend = Backend(backend)
else:
backend = Backend("undefined")
if timeout is None:
timeout = default_pg_timeout
# backward compatible API
if store is None:
rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
store, rank, world_size = next(rendezvous_iterator)
store.set_timeout(timeout)
# Use a PrefixStore to avoid accidental overrides of keys used by
# different systems (e.g. RPC) in case the store is multi-tenant.
store = PrefixStore(group_name, store)
pg, _ = _new_process_group_helper(
world_size,
rank,
[],
backend,
store,
group_name=group_name,
pg_options=pg_options,
timeout=timeout,
)
_world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
return pg
class WorkerWrap(Worker):
def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend="nccl"):
"""Init torch process group for model weights update"""
assert torch.distributed.is_initialized(), "default torch process group must be initialized"
assert group_name != "", "group name must not be empty"
rank = torch.distributed.get_rank() + rank_offset
self._model_update_group = init_process_group(
backend=backend,
init_method=f"tcp://{master_address}:{master_port}",
world_size=world_size,
rank=rank,
group_name=group_name,
)
print(
f"init_process_group: master_address={master_address}, master_port={master_port}, ",
f"rank={rank}, world_size={world_size}, group_name={group_name}",
)
def update_weight(self, name, dtype, shape, empty_cache=False):
"""Broadcast weight to all vllm workers from source rank 0 (actor model)"""
# print(f"update_weight: {name}, dtype: {dtype}, shape: {shape}, rank: {torch.distributed.get_rank()}, world_size: {torch.distributed.get_world_size()}")
# if torch.distributed.get_rank() == 0:
# print(f"update weight: {name}, dtype: {dtype}, shape: {shape}")
assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}"
weight = torch.empty(shape, dtype=dtype, device="cuda")
torch.distributed.broadcast(weight, 0, group=self._model_update_group)
self.model_runner.model.load_weights(weights=[(name, weight)])
del weight
# TODO: should we empty cache if all weights have updated?
# if empty_cache:
# torch.cuda.empty_cache()
@ray.remote
class LLMRayActor:
def __init__(self, *args, **kwargs):
import vllm
self.__version__ = vllm.__version__
assert self.__version__ >= "0.4.1", "OpenRLHF only supports vLLM >= 0.4.1"
self.use_gpu_executor = kwargs["tensor_parallel_size"] == 1
# See https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py
if self.use_gpu_executor:
vllm.worker.worker.Worker = WorkerWrap
else:
# RayGPUExecutor
# See the patch https://github.com/vllm-project/vllm/commit/479d69fad0538f04cb22bf13e76ff91cfeb8a4e5
kwargs["worker_use_ray"] = True
if vllm.__version__ > "0.4.1":
RayWorkerWrapperPath = vllm.executor.ray_utils
else:
RayWorkerWrapperPath = vllm.engine.ray_utils
# patch for newer vllm from openrlhf:
# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/trainer/ray/vllm_engine.py#L40
if vllm.__version__ > "0.6.4.post1":
# https://github.com/vllm-project/vllm/pull/10555
kwargs["worker_cls"] = "open_instruct.vllm_utils2.WorkerWrap"
else:
RayWorkerWrapperPath = vllm.executor.ray_utils
class RayWorkerWrapper(RayWorkerWrapperPath.RayWorkerWrapper):
def __init__(self, *args, **kwargs) -> None:
kwargs["worker_module_name"] = "open_instruct.vllm_utils2"
kwargs["worker_class_name"] = "WorkerWrap"
super().__init__(*args, **kwargs)
RayWorkerWrapperPath.RayWorkerWrapper = RayWorkerWrapper
self.llm = vllm.LLM(*args, **kwargs)
self.llm_engine = self.llm.llm_engine
self.request_counter = Counter()
self.input_queue = queue.Queue()
self.output_queue = queue.Queue()
self.busy_generate_th = threading.Thread(target=self.busy_generate)
self.busy_generate_th.start()
def ping(self):
return "pong"
def generate(self, *args, **kwargs):
return self.llm.generate(*args, **kwargs)
def output_queue_get(self, *args, **kwargs):
return self.output_queue.get(*args, **kwargs)
def input_queue_put(self, *args, **kwargs):
return self.input_queue.put(*args, **kwargs)
def get_num_unfinished_requests(self):
return self.llm_engine.get_num_unfinished_requests()
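# Background loop: drain any newly queued (request_id, prompt_token_ids) pairs into the
# LLMEngine via add_request, advance decoding one step() at a time, and push finished
# RequestOutputs to output_queue as they complete, so the caller can stream results
# instead of waiting for the whole batch.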
def busy_generate(self):
while True:
outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
try:
cur_prompts, sampling_params = self.input_queue.get_nowait()
for prompt in cur_prompts:
request_id, prompt = prompt
self.llm_engine.add_request(
request_id=request_id,
prompt=TokensPrompt(prompt_token_ids=prompt),
params=sampling_params
)
except queue.Empty:
pass
if self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step()
for output in step_outputs:
if output.finished:
outputs.append(output)
if len(outputs) > 0:
self.output_queue.put(outputs)
else:
time.sleep(0.01) # to avoid busy-waiting
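# Driver-side sketch (request id and variable names are illustrative): enqueue
# (request_id, prompt_token_ids) pairs with their SamplingParams and drain finished
# RequestOutputs through the Ray actor handle:
#   engine.input_queue_put.remote(([("req-0", prompt_token_ids)], sampling_params))
#   finished_outputs = ray.get(engine.output_queue_get.remote())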
def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend):
if self.use_gpu_executor:
return self.llm.llm_engine.model_executor.driver_worker.init_process_group(
master_address, master_port, rank_offset, world_size, group_name, backend
)
else:
return self.llm.llm_engine.model_executor._run_workers(
"init_process_group", master_address, master_port, rank_offset, world_size, group_name, backend
)
def update_weight(self, name, dtype, shape, empty_cache=False):
self.stop_remote_worker_execution_loop()
if self.use_gpu_executor:
return self.llm.llm_engine.model_executor.driver_worker.update_weight(name, dtype, shape, empty_cache)
else:
return self.llm.llm_engine.model_executor._run_workers("update_weight", name, dtype, shape, empty_cache)
def stop_remote_worker_execution_loop(self):
# Fix error when using 2 communication groups
# https://github.com/vllm-project/vllm/commit/eb6d3c264d0cd8e44dec16bca7947fbe96415ce9#diff-e1ad69e38e033accddfa5480ec808c4740eb39244d1ef51cc3407e20dde8cfd4
if self.__version__ > "0.4.2":
self.llm.llm_engine.model_executor.stop_remote_worker_execution_loop()
def create_vllm_engines(
num_engines: int,
tensor_parallel_size: int,
enforce_eager: bool,
pretrain: str,
revision: str,
seed: int,
enable_prefix_caching: bool,
max_model_len: int,
vllm_gpu_memory_utilization: float = 0.9,
single_gpu_mode: bool = False,
pg: Optional[ray.util.placement_group] = None,
):
vllm_engines = []
for i in range(num_engines):
# When tensor_parallel_size=1, vLLM initializes the model in the LLMEngine process directly, so assign 1 GPU to it.
num_gpus = int(tensor_parallel_size == 1)
scheduling_strategy = None
if pg is not None:
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=pg, placement_group_capture_child_tasks=True
)
elif tensor_parallel_size > 1:
bundles = [{"GPU": 1, "CPU": 4}] * tensor_parallel_size
pg = placement_group(bundles)
ray.get(pg.ready())
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=pg, placement_group_capture_child_tasks=True, placement_group_bundle_index=0
)
print(f"vllm: {num_gpus=}, {num_engines=}")
vllm_engines.append(
LLMRayActor.options(
num_cpus=4,
num_gpus=0.48 if single_gpu_mode else num_gpus,
scheduling_strategy=scheduling_strategy,
).remote(
pretrain,
revision=revision,
tokenizer_revision=revision,
trust_remote_code=True,
tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager,
dtype="bfloat16",
seed=seed + i,
enable_prefix_caching=enable_prefix_caching,
max_model_len=max_model_len,
gpu_memory_utilization=vllm_gpu_memory_utilization,
)
)
return vllm_engines
if __name__ == "__main__":
llm = LLMRayActor.remote("Qwen/Qwen2.5-7B", tensor_parallel_size=2)
output = ray.get(llm.generate.remote("San Francisco is a"))
print(f"output: {output}")