OLMo patch for llm_generate.py: adds support for the allenai/OLMo-1B model via the ai2-olmo (hf_olmo) package.
diff --git a/llm_generate.py b/llm_generate.py
index 061cc99..73fb208 100755
--- a/llm_generate.py
+++ b/llm_generate.py
@@ -9,6 +9,8 @@
 # "protobuf>=6.30.1",
 # "spacy>=3.8.4",
 # "ftfy>=6.3.1",
+# "ai2-olmo",
+# "datasets",
 # ]
 # ///
@@ -63,15 +65,20 @@ import numpy as np
 from transformers import AutoTokenizer
 import torch.nn.functional as F
-exec(f"from transformers import {model_class}")
-model_class = eval(model_class)
-
 # Special case for GPT1 (with AutoTokenizer, surprisal calculations
 # wouldn't work):
 if model=="openai-community/openai-gpt":
+    exec(f"from transformers import {model_class}")
+    model_class = eval(model_class)
     from transformers import OpenAIGPTTokenizer
     tokenizer = OpenAIGPTTokenizer.from_pretrained(model)
+elif model=="allenai/OLMo-1B":
+    from hf_olmo import OLMoForCausalLM, OLMoTokenizerFast
+    model_class = eval(model_class)
+    tokenizer = OLMoTokenizerFast.from_pretrained(model)
 else:
+    exec(f"from transformers import {model_class}")
+    model_class = eval(model_class)
     tokenizer = AutoTokenizer.from_pretrained(model)
 model = model_class.from_pretrained(model)
diff --git a/models.py b/models.py
index c185d41..2220990 100644
--- a/models.py
+++ b/models.py
@@ -13,4 +13,5 @@ models = {
     "xglm-2.9B": ("facebook/xglm-2.9B", "XGLMForCausalLM"),
     "german-gpt2": ("dbmdz/german-gpt2", "AutoModelForCausalLM"),
     "german-gpt2-larger": ("stefan-it/german-gpt2-larger", "AutoModelForCausalLM"),
+    "OLMo-1B": ("allenai/OLMo-1B", "OLMoForCausalLM"),
 }
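
For reference, the OLMo-specific branch added above boils down to the following. This is a minimal sketch, assuming the ai2-olmo package is installed (it provides the hf_olmo module imported in the patch); the example sentence and the logits call are only illustrative, not lifted from llm_generate.py:

import torch
from hf_olmo import OLMoForCausalLM, OLMoTokenizerFast

model_name = "allenai/OLMo-1B"

# The patch loads OLMo through hf_olmo's dedicated classes rather than
# AutoTokenizer, mirroring the elif branch added to llm_generate.py.
tokenizer = OLMoTokenizerFast.from_pretrained(model_name)
model = OLMoForCausalLM.from_pretrained(model_name)

# Next-token logits for a sample sentence, the kind of output a
# surprisal calculation would consume.
inputs = tokenizer("The quick brown fox jumps over the lazy dog", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

With the models.py entry in place, selecting "OLMo-1B" presumably resolves to the pair ("allenai/OLMo-1B", "OLMoForCausalLM"), which llm_generate.py then turns into the actual class via eval(model_class) before calling from_pretrained.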