#!/usr/bin/env python3
"""
Universal LoRA Training Dataset Generator

Creates style-transfer LoRA training datasets from any collection of images.
Generates subject-only captions (no style references) for consistent style learning.

Supports any artistic style: anime, photorealistic, watercolor, oil painting, etc.
"""

import argparse
import base64
import json
import os
import shutil
import time
from pathlib import Path
from typing import Dict, List, Optional

import requests
from PIL import Image


class UniversalLoRAGenerator:
    def __init__(self,
                 source_dir: str,
                 output_dir: str = "lora_dataset",
                 lm_studio_url: str = "http://localhost:1234/v1/chat/completions",
                 model: str = "gemma-3n-e2b-it-mlx"):
        self.source_dir = Path(source_dir)
        self.output_dir = Path(output_dir)
        self.lm_studio_url = lm_studio_url
        self.model = model

        # Supported image formats
        self.image_extensions = {'.png', '.jpg', '.jpeg', '.webp', '.bmp', '.tiff'}

        # Create the output directory
        self.output_dir.mkdir(exist_ok=True)

        print("🎯 Universal LoRA Dataset Generator")
        print(f"📁 Source: {self.source_dir}")
        print(f"📁 Output: {self.output_dir}")

    def find_images(self) -> List[Path]:
        """Find all supported image files in the source directory."""
        images = []
        for ext in self.image_extensions:
            images.extend(self.source_dir.glob(f"*{ext}"))
            images.extend(self.source_dir.glob(f"*{ext.upper()}"))

        # Deduplicate (case-insensitive filesystems can match the same file
        # for both glob patterns) and sort for consistent processing
        images = sorted(set(images))

        print(f"🖼️ Found {len(images)} images in {self.source_dir}")
        return images
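
    # LM Studio serves an OpenAI-compatible chat completions API; the method
    # below sends the prompt plus the image inline as a base64 data URL.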

    def generate_subject_caption(self, image_path: Path) -> Optional[str]:
        """Generate a style-agnostic caption focusing only on subject matter."""
        try:
            # Encode the image as base64 with the correct MIME type for the
            # data URL (previously image/png was hardcoded for every format)
            suffix = image_path.suffix.lower().lstrip('.')
            mime_type = 'jpeg' if suffix == 'jpg' else suffix
            with open(image_path, 'rb') as f:
                image_b64 = base64.b64encode(f.read()).decode()

            prompt = """Describe what you see in this image. Focus ONLY on the subject matter, objects, people, actions, and setting.

DO NOT mention:
- Artistic style (drawing, painting, illustration, anime, cartoon, etc.)
- Visual techniques (shading, line art, colors, etc.)
- Art medium (watercolor, oil paint, digital art, etc.)
- Quality descriptions (detailed, realistic, stylized, etc.)

DO describe:
- What objects/people/animals are present
- What they are doing
- The setting/environment
- Clothing, poses, expressions
- Physical relationships between elements

Examples:
✅ Good: "A woman with long hair sitting at a wooden table with a coffee cup"
❌ Bad: "A beautiful anime-style illustration of a woman sitting at a table"

✅ Good: "Two cats playing in a garden with flowers and a wooden fence"
❌ Bad: "A watercolor painting showing two cats in a pastoral scene"

Keep it concise (1-2 sentences). Focus purely on content, not aesthetics."""

            response = requests.post(
                self.lm_studio_url,
                json={
                    "model": self.model,
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": prompt},
                                {
                                    "type": "image_url",
                                    "image_url": {"url": f"data:image/{mime_type};base64,{image_b64}"}
                                }
                            ]
                        }
                    ],
                    "temperature": 0.2,
                    "max_tokens": 150
                },
                timeout=45
            )

            if response.status_code == 200:
                result = response.json()
                caption = result['choices'][0]['message']['content'].strip()

                # Clean up style references that might slip through
                return self.clean_style_references(caption)
            else:
                print(f"❌ API error {response.status_code} for {image_path.name}")
                return None

        except Exception as e:
            print(f"❌ Error processing {image_path.name}: {e}")
            return None
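
    # Example of the cleaning below (illustrative): "A watercolor painting of
    # two cats. Two cats sit on a wooden fence." -> "Two cats sit on a wooden fence."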

    def clean_style_references(self, caption: str) -> str:
        """Remove any style references from a caption."""
        style_indicators = [
            # Art styles
            'illustration', 'drawing', 'painting', 'artwork', 'sketch',
            'anime', 'manga', 'cartoon', 'comic', 'watercolor', 'oil painting',
            'digital art', 'concept art', 'fan art',

            # Visual qualities
            'realistic', 'stylized', 'detailed', 'vibrant', 'colorful',
            'black and white', 'monochrome', 'vintage', 'modern',

            # Techniques
            'line art', 'shading', 'rendering', 'brush strokes',
            'pen and ink', 'pencil', 'charcoal',

            # Descriptive words often used for art
            'depicts', 'shows', 'portrays', 'rendered', 'artistic',
            'beautiful', 'stunning', 'gorgeous', 'lovely',

            # Medium references
            'canvas', 'paper', 'digital', 'traditional'
        ]

        # Split into sentences and drop any sentence containing a style indicator
        sentences = [s.strip() for s in caption.split('.') if s.strip()]
        clean_sentences = []

        for sentence in sentences:
            sentence_lower = sentence.lower()
            has_style_ref = any(indicator in sentence_lower for indicator in style_indicators)
            if not has_style_ref:
                clean_sentences.append(sentence)

        if clean_sentences:
            result = '. '.join(clean_sentences)
            if not result.endswith('.'):
                result += '.'
            return result

        # Fallback: every sentence was flagged, so filter word by word instead
        words = caption.split()
        clean_words = []
        skip_next = False

        for i, word in enumerate(words):
            if skip_next:
                skip_next = False
                continue

            word_lower = word.lower().strip('.,!?')

            # Skip style-related words
            if any(indicator in word_lower for indicator in style_indicators):
                # Also skip the next word if it is likely a connector
                if i < len(words) - 1 and words[i + 1].lower() in ['of', 'showing', 'with']:
                    skip_next = True
                continue

            clean_words.append(word)

        if clean_words:
            return ' '.join(clean_words)
        return caption  # Return the original as a last resort
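
    # Note (caveat): palette-mode (P) images with transparency fall through to
    # the RGB conversion branch below, so their alpha information is dropped.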

    def convert_and_copy_image(self, source_path: Path, target_name: str) -> bool:
        """Convert an image to PNG and copy it into the dataset directory."""
        try:
            target_path = self.output_dir / f"{target_name}.png"

            # If already PNG, just copy
            if source_path.suffix.lower() == '.png':
                shutil.copy2(source_path, target_path)
                return True

            # Convert to PNG
            with Image.open(source_path) as img:
                if img.mode in ('RGBA', 'LA'):
                    # Keep the alpha channel
                    img.save(target_path, 'PNG')
                elif img.mode != 'RGB':
                    # Normalize other modes (palette, CMYK, etc.) to RGB first
                    img = img.convert('RGB')
                    img.save(target_path, 'PNG')
                else:
                    img.save(target_path, 'PNG')

            return True

        except Exception as e:
            print(f"❌ Failed to convert {source_path.name}: {e}")
            return False

    def process_images(self,
                       max_images: Optional[int] = None,
                       skip_existing: bool = True,
                       delay: float = 1.0) -> Dict:
        """Process all images to create the LoRA dataset."""
        images = self.find_images()

        if max_images:
            images = images[:max_images]
            print(f"🎯 Processing first {max_images} images")

        if not images:
            print("❌ No images found!")
            return {}

        print("🚀 Starting caption generation...")
        print("=" * 60)

        results = {
            'successful': 0,
            'failed': 0,
            'skipped': 0,
            'processed_files': [],
            'failed_files': [],
            'config': {
                'source_dir': str(self.source_dir),
                'output_dir': str(self.output_dir),
                'total_images': len(images),
                'max_images': max_images,
                'delay': delay
            }
        }

        for i, image_path in enumerate(images, 1):
            # Create a clean filename for the dataset
            clean_name = self.get_clean_filename(image_path)

            caption_file = self.output_dir / f"{clean_name}.txt"
            image_file = self.output_dir / f"{clean_name}.png"

            print(f"[{i}/{len(images)}] Processing {image_path.name}...")

            # Skip if both files already exist
            if skip_existing and caption_file.exists() and image_file.exists():
                print("   ⏭️ Already exists, skipping")
                results['skipped'] += 1
                continue

            # Generate caption
            caption = self.generate_subject_caption(image_path)

            if caption:
                # Copy/convert the image
                if self.convert_and_copy_image(image_path, clean_name):
                    # Save the caption
                    with open(caption_file, 'w', encoding='utf-8') as f:
                        f.write(caption)

                    print(f"   ✅ Success: {caption[:50]}...")
                    results['successful'] += 1
                    results['processed_files'].append({
                        'original': str(image_path),
                        'dataset_name': clean_name,
                        'caption': caption
                    })
                else:
                    print("   ❌ Failed to copy image")
                    results['failed'] += 1
                    results['failed_files'].append(str(image_path))
            else:
                print("   ❌ Failed to generate caption")
                results['failed'] += 1
                results['failed_files'].append(str(image_path))

            # Rate limiting
            if delay > 0:
                time.sleep(delay)

        return results
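
    # Filename cleaning example (illustrative): "My Photo (12)!.jpg" -> "My_Photo_12"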

    def get_clean_filename(self, image_path: Path) -> str:
        """Generate a clean, filesystem-safe filename for the dataset."""
        # Remove the extension
        name = image_path.stem

        # Replace problematic characters
        clean_name = "".join(c if c.isalnum() or c in '-_' else '_' for c in name)

        # Collapse runs of underscores
        while '__' in clean_name:
            clean_name = clean_name.replace('__', '_')

        # Remove leading/trailing underscores
        return clean_name.strip('_')

    def create_training_config(self,
                               style_prompt: str,
                               style_name: str,
                               results: Dict) -> None:
        """Create configuration files for LoRA training."""
        # Create the style prompt file
        with open(self.output_dir / "style_prompt.txt", 'w') as f:
            f.write(style_prompt)

        # Create detailed metadata
        attempted = results['successful'] + results['failed']
        success_rate = f"{results['successful'] / attempted * 100:.1f}%" if attempted > 0 else "0%"
        metadata = {
            'dataset_info': {
                'name': f"{style_name} LoRA Training Dataset",
                'style': style_name,
                'style_prompt': style_prompt,
                'created': time.strftime('%Y-%m-%d %H:%M:%S'),
                'generator': 'Universal LoRA Dataset Generator v1.0'
            },
            'statistics': {
                'successful': results['successful'],
                'failed': results['failed'],
                'skipped': results['skipped'],
                'success_rate': success_rate
            },
            'training_setup': {
                'image_format': 'PNG',
                'caption_format': 'TXT (subject-only descriptions)',
                'style_learning': 'Apply style_prompt to any subject from captions'
            },
            'files': results['processed_files']
        }

        with open(self.output_dir / "metadata.json", 'w') as f:
            json.dump(metadata, f, indent=2, ensure_ascii=False)

        # Create README
        readme_content = f"""# {style_name} LoRA Training Dataset

## 📋 Dataset Overview
- **Style**: {style_name}
- **Images**: {results['successful']} PNG files
- **Captions**: {results['successful']} TXT files (subject-only)
- **Generated**: {time.strftime('%Y-%m-%d %H:%M:%S')}

## 🎯 Training Objective
Train a LoRA model to apply **{style_name}** style to any subject matter.

## 📁 File Structure
```
{self.output_dir.name}/
├── [filename].png     # Training images
├── [filename].txt     # Subject captions (no style refs)
├── style_prompt.txt   # Style prompt for training
├── metadata.json      # Detailed dataset info
└── README.md          # This file
```

## 🎨 Style Prompt
```
{style_prompt}
```

## 🔄 Training Method
1. **Input**: Subject description from .txt files
2. **Target**: Apply {style_name} style to that subject
3. **Result**: LoRA that can stylize any content

## 📊 Dataset Statistics
- ✅ Successful: {results['successful']}
- ❌ Failed: {results['failed']}
- ⏭️ Skipped: {results['skipped']}
- 📈 Success Rate: {metadata['statistics']['success_rate']}

## 💡 Usage
Use with your preferred LoRA training tool:
- Pair each .png with its matching .txt caption
- Use style_prompt.txt content for style conditioning
- Train for style transfer to any subject matter

Generated by Universal LoRA Dataset Generator
"""

        with open(self.output_dir / "README.md", 'w') as f:
            f.write(readme_content)

    def validate_captions(self) -> Dict:
        """Validate generated captions for lingering style references."""
        print("\n🔍 Validating captions for style references...")

        style_indicators = ['drawing', 'painting', 'illustration', 'anime', 'artwork', 'sketch']

        # All .txt files except the style prompt itself
        caption_files = [f for f in self.output_dir.glob("*.txt") if f.name != 'style_prompt.txt']

        issues = []
        clean_count = 0

        for caption_file in caption_files:
            with open(caption_file, 'r', encoding='utf-8') as f:
                caption = f.read().strip().lower()

            found_issues = [word for word in style_indicators if word in caption]

            if found_issues:
                issues.append({
                    'file': caption_file.name,
                    'caption': caption[:60] + "..." if len(caption) > 60 else caption,
                    'style_words': found_issues
                })
            else:
                clean_count += 1

        total = clean_count + len(issues)
        validation_results = {
            'clean_captions': clean_count,
            'problematic_captions': len(issues),
            'issues': issues,
            'success_rate': f"{clean_count / total * 100:.1f}%" if total > 0 else "0%"
        }

        if issues:
            print(f"⚠️ Found {len(issues)} captions with potential style references:")
            for issue in issues[:5]:  # Show the first 5
                print(f"  • {issue['file']}: {issue['style_words']}")
                print(f"    📝 {issue['caption']}")
        else:
            print("✅ All captions are clean - no style references found!")

        print(f"📊 Validation: {clean_count} clean, {len(issues)} problematic ({validation_results['success_rate']} success)")

        return validation_results


def main():
    parser = argparse.ArgumentParser(
        description="Universal LoRA Training Dataset Generator",
        epilog="""
Examples:
  %(prog)s anime_images/ --style "anime, manga style" --name "Anime"
  %(prog)s portraits/ --style "oil painting, renaissance portrait" --name "Renaissance" --max 50
  %(prog)s landscapes/ --style "watercolor painting, impressionist" --name "Impressionist"
""",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    parser.add_argument("source_dir", help="Directory containing source images")
    parser.add_argument("--output", "-o", default="lora_dataset", help="Output directory (default: lora_dataset)")
    parser.add_argument("--style", "-s", required=True, help="Style prompt for training (e.g. 'anime, manga style')")
    parser.add_argument("--name", "-n", required=True, help="Style name for documentation (e.g. 'Anime')")
    parser.add_argument("--max", "-m", type=int, help="Maximum number of images to process")
    parser.add_argument("--delay", "-d", type=float, default=1.0, help="Delay between API calls in seconds (default: 1.0)")
    parser.add_argument("--no-skip", action="store_true", help="Don't skip existing files")
    parser.add_argument("--lm-studio", default="http://localhost:1234/v1/chat/completions", help="LM Studio API URL")
    parser.add_argument("--model", default="gemma-3n-e2b-it-mlx", help="LLM model name")

    args = parser.parse_args()

    # Validate the source directory
    if not os.path.exists(args.source_dir):
        print(f"❌ Source directory not found: {args.source_dir}")
        return

    # Test the LM Studio connection via its /models endpoint
    print("🔧 Testing LM Studio connection...")
    try:
        test_response = requests.get(args.lm_studio.replace('/chat/completions', '/models'), timeout=5)
        if test_response.status_code != 200:
            print(f"❌ LM Studio not responding at {args.lm_studio}")
            print("💡 Please start LM Studio and load your model")
            return
        print("✅ LM Studio connected")
    except requests.exceptions.RequestException:
        print(f"❌ Cannot connect to LM Studio at {args.lm_studio}")
        print("💡 Please start LM Studio and load your model")
        return

    # Create the generator
    generator = UniversalLoRAGenerator(
        source_dir=args.source_dir,
        output_dir=args.output,
        lm_studio_url=args.lm_studio,
        model=args.model
    )

    # Process images
    results = generator.process_images(
        max_images=args.max,
        skip_existing=not args.no_skip,
        delay=args.delay
    )

    # Nothing to do if no images were found (process_images returned an empty dict)
    if not results:
        return

    # Create training configuration
    generator.create_training_config(
        style_prompt=args.style,
        style_name=args.name,
        results=results
    )

    # Validate captions
    validation = generator.validate_captions()

    # Final summary (guard against division by zero when every image was skipped)
    attempted = results['successful'] + results['failed']
    success_rate = f"{results['successful'] / attempted * 100:.1f}%" if attempted > 0 else "n/a"

    print("\n" + "=" * 60)
    print("🎉 Dataset generation complete!")
    print(f"📁 Dataset location: {args.output}/")
    print(f"🖼️ Images processed: {results['successful']}")
    print(f"📝 Captions generated: {results['successful']}")
    print(f"📈 Success rate: {success_rate}")
    print(f"🔍 Caption validation: {validation['success_rate']} clean")

    print("\n💡 Next steps:")
    print(f"1. Review generated dataset in {args.output}/")
    print("2. Manually check any problematic captions")
    print("3. Train LoRA using your preferred tool")
    print("4. Test style transfer on new subjects")


if __name__ == "__main__":
    main()