OLMo patch for llm_generate.py: adds support for allenai/OLMo-1B (loaded via the ai2-olmo package).
diff --git a/llm_generate.py b/llm_generate.py
index 061cc99..73fb208 100755
--- a/llm_generate.py
+++ b/llm_generate.py
@@ -9,6 +9,8 @@
 # "protobuf>=6.30.1",
 # "spacy>=3.8.4",
 # "ftfy>=6.3.1",
+# "ai2-olmo",
+# "datasets",
 # ]
 # ///
 
@@ -63,15 +65,20 @@ import numpy as np
 from transformers import AutoTokenizer
 import torch.nn.functional as F
 
-exec(f"from transformers import {model_class}")
-model_class = eval(model_class)
-
 # Special case for GPT1 (with AutoTokenizer, surprisal calculations
 # wouldn't work):
 if model=="openai-community/openai-gpt":
+    exec(f"from transformers import {model_class}")
+    model_class = eval(model_class)
     from transformers import OpenAIGPTTokenizer
     tokenizer = OpenAIGPTTokenizer.from_pretrained(model)
+elif model=="allenai/OLMo-1B":
+    from hf_olmo import OLMoForCausalLM, OLMoTokenizerFast
+    model_class = eval(model_class)
+    tokenizer = OLMoTokenizerFast.from_pretrained(model)
 else:
+    exec(f"from transformers import {model_class}")
+    model_class = eval(model_class)
     tokenizer = AutoTokenizer.from_pretrained(model)
 
 model = model_class.from_pretrained(model)
diff --git a/models.py b/models.py
index c185d41..2220990 100644
--- a/models.py
+++ b/models.py
@@ -13,4 +13,5 @@ models = {
     "xglm-2.9B": ("facebook/xglm-2.9B", "XGLMForCausalLM"),
     "german-gpt2": ("dbmdz/german-gpt2", "AutoModelForCausalLM"),
     "german-gpt2-larger": ("stefan-it/german-gpt2-larger", "AutoModelForCausalLM"),
+    "OLMo-1B": ("allenai/OLMo-1B", "OLMoForCausalLM"),
 }
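For context, a minimal standalone sketch of what the patched branch enables: loading allenai/OLMo-1B through the hf_olmo module (installed by the ai2-olmo dependency added above) and computing per-token surprisals. The loading calls mirror the diff; the example sentence and the surprisal computation are illustrative assumptions about what llm_generate.py does with the model, not code copied from it.

import math
import torch
import torch.nn.functional as F
from hf_olmo import OLMoForCausalLM, OLMoTokenizerFast  # provided by ai2-olmo

model_name = "allenai/OLMo-1B"
tokenizer = OLMoTokenizerFast.from_pretrained(model_name)
model = OLMoForCausalLM.from_pretrained(model_name)
model.eval()

text = "The horse raced past the barn fell."  # illustrative sentence, not from the script
enc = tokenizer(text, return_tensors="pt")
input_ids = enc["input_ids"][0]

with torch.no_grad():
    logits = model(**enc).logits[0]        # (seq_len, vocab_size)

log_probs = F.log_softmax(logits, dim=-1)  # natural-log probabilities

# Surprisal of token t given its left context, here in bits;
# the first token has no left context and is skipped.
for t in range(1, input_ids.size(0)):
    s = -log_probs[t - 1, input_ids[t]].item() / math.log(2)
    print(f"{tokenizer.decode(input_ids[t])!r}\t{s:.2f}")

Note that hf_olmo is only needed for the original OLMo checkpoints; the later allenai/OLMo-*-hf releases load with transformers' AutoModelForCausalLM directly, which is presumably why only allenai/OLMo-1B is special-cased in the patch.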