ComfyUI Transformers node for internlm-xcomposer2-4khd-7b
import os
import tempfile

import numpy as np
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

import folder_paths  # ComfyUI path helper; available to node code running inside ComfyUI

class InternLMXComposer2:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "prompt": ("STRING", {"default": "Illustrate the fine details present in the image"}),
                "hd_num": ("INT", {"default": 55, "min": 1, "max": 100}),
                "num_beams": ("INT", {"default": 3, "min": 1, "max": 10}),
            }
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "generate_description"
    CATEGORY = "image/text"

    def load_model(self):
        # Lazily load and cache the model/tokenizer on first use.
        if self.model is None:
            print("Loading InternLM-XComposer2 model...")
            self.model = AutoModel.from_pretrained(
                'internlm/internlm-xcomposer2-4khd-7b',
                torch_dtype=torch.bfloat16,
                trust_remote_code=True
            ).to(self.device).eval()
            self.tokenizer = AutoTokenizer.from_pretrained(
                'internlm/internlm-xcomposer2-4khd-7b',
                trust_remote_code=True
            )
            print("Model loaded successfully!")

    def generate_description(self, image, prompt, hd_num=55, num_beams=3):
        self.load_model()

        # The model expects the image position marked with an <ImageHere> token.
        formatted_prompt = f"<ImageHere>{prompt}"

        # ComfyUI passes images as float tensors, typically [B, H, W, C] in 0..1.
        image_np = image.cpu().numpy()

        # Handle different image formats
        if len(image_np.shape) == 4:
            image_np = image_np[0]  # Remove batch dimension if present

        # Convert grayscale to RGB (either channel layout)
        if image_np.shape[0] == 1:  # [1, H, W], channels first
            image_np = np.repeat(image_np, 3, axis=0)
        elif image_np.shape[-1] == 1:  # [H, W, 1], ComfyUI's channels-last layout
            image_np = np.repeat(image_np, 3, axis=-1)

        # Ensure channels-last ([H, W, C]) order for PIL
        if image_np.shape[0] in (3, 4):  # If channels first
            image_np = np.transpose(image_np, (1, 2, 0))

        # Drop the alpha channel if present
        if image_np.shape[-1] == 4:
            image_np = image_np[..., :3]

        # Scale values to the 0-255 range if needed
        if image_np.max() <= 1.0:
            image_np = (image_np * 255).astype(np.uint8)
        else:
            image_np = image_np.astype(np.uint8)

        # Convert to a PIL image and ensure RGB mode
        pil_image = Image.fromarray(image_np)
        if pil_image.mode != 'RGB':
            pil_image = pil_image.convert('RGB')

        # model.chat() takes an image path, so write the frame to a temp file.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
            pil_image.save(tmp_file.name)
            tmp_path = tmp_file.name

        try:
            with torch.cuda.amp.autocast():
                response, _ = self.model.chat(
                    self.tokenizer,
                    query=formatted_prompt,
                    image=tmp_path,
                    hd_num=hd_num,
                    history=[],
                    do_sample=False,
                    num_beams=num_beams
                )
        finally:
            # Clean up the temporary file
            os.unlink(tmp_path)

        return (response,)

# Node class registration
NODE_CLASS_MAPPINGS = {
    "InternLMXComposer2": InternLMXComposer2
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "InternLMXComposer2": "InternLM XComposer2"
}
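
# ---------------------------------------------------------------------------
# Minimal standalone smoke test: a sketch, assuming the file lives where
# ComfyUI can import it (e.g. under ComfyUI/custom_nodes/) so the
# `folder_paths` import above resolves, and that enough VRAM is available for
# the bf16 7B checkpoint, which Hugging Face downloads on first use. The
# random tensor below is a placeholder mimicking ComfyUI's [B, H, W, C]
# float image format in 0..1; shapes and prompt are illustrative only.
if __name__ == "__main__":
    node = InternLMXComposer2()
    dummy_image = torch.rand(1, 512, 512, 3)  # placeholder, not a real image
    (description,) = node.generate_description(dummy_image, "Describe this image")
    print(description)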