Code used for experimenting with Adept's Fuyu model, as described in https://joanfihu.wordpress.com/2023/10/19/evaluating-adepts-fuyu-model-for-ui-navigation/
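To reproduce the run, save the script and pass a prompt on the command line, e.g. `python fuyu_inference.py --text_prompt $'Generate a coco-style caption.\n'` (bash's `$'...'` quoting produces the trailing newline that the default prompt uses; the filename `fuyu_inference.py` is hypothetical, as the gist does not name the file). The UI-navigation prompts explored in the blog post can be passed the same way through `--text_prompt`.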
from transformers import FuyuForCausalLM, AutoTokenizer, FuyuProcessor, FuyuImageProcessor
from PIL import Image
import torch
import argparse

parser = argparse.ArgumentParser()
# parser.add_argument("--image_path", type=str, default="amazon_screenshot.png")
parser.add_argument("--text_prompt", type=str, default="Generate a coco-style caption.\n")
args = parser.parse_args()

text_prompt = args.text_prompt
assert text_prompt is not None
# Load model, tokenizer, and processor.
pretrained_path = "adept/fuyu-8b"
tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
image_processor = FuyuImageProcessor()
processor = FuyuProcessor(image_processor=image_processor, tokenizer=tokenizer)
model = FuyuForCausalLM.from_pretrained(pretrained_path, device_map="cuda:0", torch_dtype=torch.float16)
# Test inference.
# text_prompt = "Generate a coco-style caption.\n"
image_path = "amazon_screenshot (1).jpg"  # original example image: https://huggingface.co/adept-hf-collab/fuyu-8b/blob/main/bus.png
image_pil = Image.open(image_path)
# The processor expects an RGB PIL image in HWC layout; no manual transpose is needed.
if image_pil.mode != "RGB":
    image_pil = image_pil.convert("RGB")
model_inputs = processor(text=text_prompt, images=[image_pil], device="cuda:0")
for k, v in model_inputs.items():
    model_inputs[k] = v.to("cuda:0")

max_new_tokens = 1024
generation_output = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
generation_text = processor.batch_decode(generation_output[:, -max_new_tokens:], skip_special_tokens=True)
print(generation_text)
# assert generation_text == ['A bus parked on the side of a road.']
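
# --- Editorial sketch, not part of the original gist ---
# The [:, -max_new_tokens:] slice above assumes generation used the full
# token budget; if the model stops early at an EOS token, that slice also
# pulls in trailing prompt tokens. A more robust variant slices off the
# prompt length instead ("input_ids" is among the tensors FuyuProcessor
# returns):
prompt_len = model_inputs["input_ids"].shape[1]
generation_text_robust = processor.batch_decode(
    generation_output[:, prompt_len:], skip_special_tokens=True
)
print(generation_text_robust)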