ArthurZucker / sam_split.py
Created October 5, 2024 12:10
Split the fused q, k, v projection in SAM's vision attention
from transformers.models.sam.modeling_sam import SamVisionAttention, SamModel, SamVisionLayer
from transformers import SamProcessor
import torch.nn as nn
import torch
from transformers.models.sam import modeling_sam
from PIL import Image
import requests
import matplotlib.pyplot as plt
import numpy as np
from transformers import pipeline
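
The preview stops at the imports, so here is a minimal sketch of the idea the description names: subclass SamVisionAttention, replace the fused qkv projection with three separate linear layers, and split pretrained weights in a load hook. The class name SamVisionAttentionSplit and the hook wiring are illustrative assumptions, not the gist's exact code (a full version would also override forward to use the new q/k/v).

class SamVisionAttentionSplit(SamVisionAttention):
    def __init__(self, config, window_size):
        super().__init__(config, window_size)
        # replace the fused 3*dim projection with three dim-sized ones
        del self.qkv
        self.q = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.k = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.v = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        # when loading a pretrained checkpoint, split its fused qkv weights
        self._register_load_state_dict_pre_hook(self.split_qkv_load_hook)

    def split_qkv_load_hook(self, state_dict, prefix, *args):
        for name in ("weight", "bias"):
            key = f"{prefix}qkv.{name}"
            if key in state_dict:
                q, k, v = state_dict.pop(key).chunk(3, dim=0)
                state_dict[f"{prefix}q.{name}"] = q
                state_dict[f"{prefix}k.{name}"] = k
                state_dict[f"{prefix}v.{name}"] = v

# patch the module so SamModel is built with the split attention, then load
modeling_sam.SamVisionAttention = SamVisionAttentionSplit
model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")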
ArthurZucker / whisper_compile.py
Created June 7, 2024 08:39
Whisper static cache
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, AutoProcessor, StaticCache
import torch
import torch._dynamo.config
import torch._inductor.config
import time
from tqdm import tqdm
import logging
torch._inductor.config.coordinate_descent_tuning = True
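
The preview again cuts off at the setup; a hedged sketch of the pattern such a script benchmarks follows, assuming openai/whisper-tiny, the dummy LibriSpeech split, and a transformers version where Whisper's generate supports cache_implementation="static". Not the gist's exact code.

processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").to("cuda")
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
input_features = processor(
    ds[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt"
).input_features.to("cuda")

# a fixed-shape ("static") kv cache lets the compiled decoder graph be reused
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

for _ in tqdm(range(3)):  # the first iterations pay the compilation cost
    start = time.perf_counter()
    generated_ids = model.generate(input_features=input_features)
    print(f"{time.perf_counter() - start:.2f}s")
print(processor.batch_decode(generated_ids, skip_special_tokens=True))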
ArthurZucker / mamba_peft.py
Created March 7, 2024 09:32
Mamba PEFT fine-tuning
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
)
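
A hedged continuation of the truncated preview, patterned on the mamba-hf model card's PEFT example (the LoRA targets, the "quote" text field, and the older trl SFTTrainer API that still accepts dataset_text_field are assumptions, not the gist's exact code):

lora_config = LoraConfig(
    r=8,
    target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none",
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)
trainer.train()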
ArthurZucker / generate_loop.py
Created March 1, 2024 03:03
I don't pass the positions so prompts have the same shape
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache
import torch
from typing import Optional
import time
import os
os.environ["TOKENIZERS_PARALLELISM"] = "1"
device = "cuda:1"
torch.set_float32_matmul_precision('high')
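
A hedged sketch of the loop such a script runs: prefill once, then step a compiled single-token decode function against a StaticCache, tracking positions only through cache_position so every step sees the same shapes. Model choice, lengths, greedy decoding, and the helper name are assumptions, not the gist's exact code, and a recent transformers version is assumed.

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Llama-2-7b-chat-hf", torch_dtype=torch.bfloat16
).to(device)

@torch.compile(mode="reduce-overhead", fullgraph=True)
def decode_one_token(cur_token, cache_position, past_key_values):
    logits = model(
        cur_token, cache_position=cache_position,
        past_key_values=past_key_values, use_cache=True, return_dict=False,
    )[0]
    return torch.argmax(logits[:, -1], dim=-1, keepdim=True)

prompt = tokenizer("The theory of relativity states", return_tensors="pt").to(device)
past_key_values = StaticCache(
    config=model.config, max_batch_size=1, max_cache_len=256,
    device=device, dtype=torch.bfloat16,
)
with torch.no_grad():
    # prefill: positions are inferred from the cache, not passed explicitly
    out = model(**prompt, past_key_values=past_key_values, use_cache=True)
    next_token = torch.argmax(out.logits[:, -1], dim=-1, keepdim=True)
    cache_position = torch.tensor([prompt.input_ids.shape[1]], device=device)
    for _ in range(32):
        next_token = decode_one_token(next_token, cache_position, past_key_values)
        cache_position += 1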
ArthurZucker / static_kv_cache.py
Last active October 21, 2024 02:08
simple static kv cache script
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache
import torch
from typing import Optional
device = "cuda"
# Copied from the gpt-fast repo
def multinomial_sample_one_no_sync(probs_sort):
    # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
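
The helper above usually travels with the other gpt-fast sampling utilities; a hedged reconstruction of those companions (lightly adapted for batched logits, not necessarily the gist's exact code):

def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    logits = logits / max(temperature, 1e-5)
    if top_k is not None:
        # keep only the top_k logits, mask the rest to -inf before softmax
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    return torch.nn.functional.softmax(logits, dim=-1)

def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    probs = logits_to_probs(logits[:, -1], temperature=temperature, top_k=top_k)
    next_token = multinomial_sample_one_no_sync(probs)
    return next_token, probs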
ArthurZucker / best_benchmark.py
Last active October 21, 2024 02:08
benchmark
FRANCE_ARTICLE = """<s>Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. \"One can hear cries of 'My God' in several languages,\" Par"""  # article truncated in the gist preview
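
A hedged sketch of how such a benchmark typically uses the article: tokenize the long prompt, time generation, and report throughput. The model choice and generation settings are assumptions, not the gist's exact code.

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Llama-2-7b-chat-hf", torch_dtype=torch.bfloat16
).to("cuda")
inputs = tokenizer(FRANCE_ARTICLE, return_tensors="pt").to("cuda")

torch.cuda.synchronize()
start = time.perf_counter()
out = model.generate(**inputs, max_new_tokens=128, do_sample=False)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
# rough throughput: the timing includes the prefill over the long article
print(f"{(out.shape[1] - inputs.input_ids.shape[1]) / elapsed:.1f} tokens/s")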
ArthurZucker / hf_compiled.py
Created February 8, 2024 06:18
Transformers with torch compile
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache, set_seed
torch.set_printoptions(linewidth=400)
attn_implementation = "sdpa"
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>")
model = AutoModelForCausalLM.from_pretrained("NousResearch/Llama-2-7b-chat-hf", torch_dtype=torch.bfloat16, attn_implementation=attn_implementation).to("cuda:1")
inputs = tokenizer(
    ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
)
ArthurZucker / bench_fa2.py
Created February 8, 2024 03:32
bench_fa2.py
import torch
import os
import argparse
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import seaborn as sns
def get_parser():
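    # A hedged completion of the truncated body: the flags one would expect a
    # flash-attention-2 benchmark to take; names and defaults are assumptions,
    # not the gist's exact options.
    parser = argparse.ArgumentParser(description="Benchmark attention implementations")
    parser.add_argument("--model", default="NousResearch/Llama-2-7b-chat-hf")
    parser.add_argument(
        "--attn-implementation", default="flash_attention_2",
        choices=["eager", "sdpa", "flash_attention_2"],
    )
    parser.add_argument("--max-new-tokens", type=int, default=128)
    return parser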
ArthurZucker / convert_marians.bash
Created November 10, 2023 15:05
Script to automatically convert and upload Marian models, checking new results against previous ones
#!/bin/bash
# conda create -n 4.29 python==3.9
# source activate 4.29
# pip install transformers==4.29.2
# pip install torch accelerate sentencepiece tokenizers colorama sacremoses googletrans
# conda create -n 4.34 python==3.9
# source activate 4.34
# pip install transformers==4.34
# pip install torch accelerate sentencepiece tokenizers colorama sacremoses googletrans
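# A hedged sketch of the convert-then-compare loop the script builds toward.
# Model names, paths, and the helper invocation are assumptions, not the
# script's exact code; run_marian.py stands in for whatever generates outputs.
for model in opus-mt-en-fr opus-mt-en-de; do
    python -m transformers.models.marian.convert_marian_to_pytorch \
        --src "models/$model" --dest "converted/$model"
    for env in 4.29 4.34; do
        source activate "$env"
        python run_marian.py --model "converted/$model" --out "results_$env/$model.txt"
    done
    # the new transformers version must reproduce the previous results exactly
    diff "results_4.29/$model.txt" "results_4.34/$model.txt" && echo "$model OK"
done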
ArthurZucker / nllb_moe_gist.py
Last active April 28, 2023 12:51
nllb_moe_gist
>>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-moe-54b")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-moe-54b", device_map="auto")
>>> article = "Previously, Ring's CEO, Jamie Siminoff, remarked the company started when his doorbell wasn't audible from his shop in his garage."
>>> inputs = tokenizer(article, return_tensors="pt")
>>> translated_tokens = model.generate(**inputs, forced_bos_token_id=tokenizer.lang_code_to_id["fra_Latn"], max_length=50)
>>> tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
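
The decoded string is the article translated into French; per the Transformers NLLB-MoE documentation example (not re-verified against the 54B checkpoint here), the expected output is:

"Auparavant, le PDG de Ring, Jamie Siminoff, a fait remarquer que la société avait commencé lorsque sa sonnette n'était pas audible depuis son magasin dans son garage."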