from transformers import FuyuForCausalLM, AutoTokenizer, FuyuProcessor, FuyuImageProcessor
from PIL import Image
import numpy as np
import torch
import argparse
parser = argparse.ArgumentParser()
# parser.add_argument("--image_path", type=str, default="amazon_screenshot.png")
parser.add_argument("--text_prompt", type=str, default="Generate a coco-style caption.\n")
args = parser.parse_args()
text_prompt = args.text_prompt
assert text_prompt is not None
# load model, tokenizer, and processor
pretrained_path = "adept/fuyu-8b"
tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
image_processor = FuyuImageProcessor()
processor = FuyuProcessor(image_processor=image_processor, tokenizer=tokenizer)
model = FuyuForCausalLM.from_pretrained(pretrained_path, device_map="cuda:0", torch_dtype=torch.float16)
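# Note (a sketch based on the model card on the Hugging Face Hub): the
# tokenizer/image-processor pair above can also be loaded in a single call:
# processor = FuyuProcessor.from_pretrained(pretrained_path)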
# test inference
# text_prompt = "Generate a coco-style caption.\n"
image_path = "amazon_screenshot (1).jpg" # https://huggingface.co/adept-hf-collab/fuyu-8b/blob/main/bus.png
image_pil = Image.open(image_path)
# Ensure the image is in RGB format
if image_pil.mode != "RGB":
    image_pil = image_pil.convert("RGB")
# Convert the image to a NumPy array
image_np = np.array(image_pil)
# Transpose HWC -> CHW
if image_np.shape[2] == 3:  # check if it's HWC
    image_np = np.transpose(image_np, (2, 0, 1))
# Convert back to a PIL image for processing; note this transposes back to
# the original HWC layout, so the two transposes round-trip to the same array
image_pil = Image.fromarray(image_np.transpose(1, 2, 0))
model_inputs = processor(text=text_prompt, images=[image_pil], return_tensors="pt")
# Move every input tensor (input_ids, image patches, etc.) to the GPU
for k, v in model_inputs.items():
    model_inputs[k] = v.to("cuda:0")
max_new_tokens = 1024
generation_output = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
# Decode the tail of the sequence; note this slice assumes the model produced
# exactly max_new_tokens, otherwise prompt tokens leak into the decoded output
generation_text = processor.batch_decode(generation_output[:, -max_new_tokens:], skip_special_tokens=True)
print(generation_text)
# assert generation_text == ['A bus parked on the side of a road.']
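# A more robust decoding sketch (not part of the original gist): slice off the
# prompt length instead of assuming exactly max_new_tokens were generated, so
# only the newly generated tokens are decoded even when generation stops early.
prompt_length = model_inputs["input_ids"].shape[1]
new_tokens = generation_output[:, prompt_length:]
print(processor.batch_decode(new_tokens, skip_special_tokens=True))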