Created
March 18, 2026 18:10
-
-
Save tdhopper/803f541b9e789664c543a8525beb1531 to your computer and use it in GitHub Desktop.
Straive Batch 2: BQ → JSONL generation script (regenerate_batch_jsonl_v0v1.py)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """Generate straive batch 2 JSONL with V0/V1 caption strategy. | |
| For ALL track/prompt pairs from BigQuery: | |
| - Deterministically pick 2 of {raw, v0, v1} per track/prompt pair | |
| - Generate 4 seeds per pair → 8 rows per pair | |
| - Record chosen versions in metadata | |
| Outputs two files: | |
| - {output}_ready.jsonl: rows where target_text is known | |
| (all raw rows + v0/v1 rows for styles with spotbot captions) | |
| - {output}_pending.jsonl: rows where target_text is pending | |
| (v0/v1 rows for styles without spotbot captions yet) | |
| Usage: | |
| python -m diffusify.scripts.regenerate_batch_jsonl_v0v1 \ | |
| --bq-table fan-audio-2.dpo_samples.straive_batch_2_track_prompt_pairs \ | |
| --sample 8000 \ | |
| --output straive_batch_2_generation | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import hashlib | |
| import json | |
| import logging | |
| import random | |
| import sys | |
| from typing import Any | |
| from google.cloud import bigquery | |
| from diffusify.track_sets.spotbot_playlists import ( | |
| SPOTBOT_PLAYLISTS, | |
| SPOTBOT_STYLE_CAPTION_VARIANTS, | |
| ) | |
# Root logging config: timestamped INFO-level progress messages for this script.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Default BigQuery table holding the batch-2 track/prompt pairs.
DEFAULT_BQ_TABLE = "fan-audio-2.dpo_samples.straive_batch_2_track_prompt_pairs"
# The three caption strategies; two of the three are chosen per track/prompt pair.
PROMPT_VERSIONS = ["raw", "v0", "v1"]
def _build_normalized_lookup() -> dict[str, str]:
    """Map normalized style names (lowercased, hyphen-separated) to SPOTBOT_PLAYLISTS keys.

    Handles mismatches like 'hard-rock' (batch) vs 'Hard Rock' (spotbot).
    """
    return {key.lower().replace(" ", "-"): key for key in SPOTBOT_PLAYLISTS}


# Built once at import time; used by _resolve_spotbot_key for fuzzy key resolution.
_NORMALIZED_SPOTBOT_KEYS = _build_normalized_lookup()
def _resolve_spotbot_key(style: str) -> str | None:
    """Resolve a batch-style name to a SPOTBOT_PLAYLISTS key, handling naming differences."""
    # Exact match wins; otherwise fall back to the normalized-name lookup.
    if style in SPOTBOT_PLAYLISTS:
        return style
    return _NORMALIZED_SPOTBOT_KEYS.get(style.lower().replace(" ", "-"))
def get_captions(style: str) -> dict[str, str | None]:
    """Get raw, V0, and V1 captions for a style. V0/V1 are None if not available."""
    key = _resolve_spotbot_key(style)
    v0: str | None = None
    v1: str | None = None
    if key is not None:
        # Empty/falsy playlist captions are normalized to None.
        v0 = SPOTBOT_PLAYLISTS[key].caption or None
        v1 = SPOTBOT_STYLE_CAPTION_VARIANTS.get("accessible", {}).get(key)
    return {"raw": style, "v0": v0, "v1": v1}
def pick_two_versions(track_uri: str, prompt: str) -> list[str]:
    """Deterministically pick 2 of 3 prompt versions for a track/prompt pair."""
    # Hash the pair so the same (track, prompt) always draws the same versions.
    digest = hashlib.sha256(f"{track_uri}:{prompt}:version_pick".encode()).hexdigest()
    chosen = random.Random(int(digest[:8], 16)).sample(PROMPT_VERSIONS, 2)
    return sorted(chosen)
def generate_deterministic_seeds(track_uri: str, prompt: str, n: int) -> list[int]:
    """Generate N deterministic seeds from a track_uri + prompt pair."""
    # Seed i is the first 32 bits of sha256("{track}:{prompt}:{i}").
    return [
        int(hashlib.sha256(f"{track_uri}:{prompt}:{i}".encode()).hexdigest()[:8], 16)
        for i in range(n)
    ]
def fetch_data_from_bigquery(table: str) -> list[dict[str, Any]]:
    """Fetch all rows from the BigQuery table."""
    client = bigquery.Client(project="fan-audio-2")
    logger.info(f"Fetching data from BigQuery table: {table}")
    job = client.query(f"SELECT * FROM `{table}`")
    records = [dict(row) for row in job.result()]
    logger.info(f"Fetched {len(records)} rows")
    return records
def _write_jsonl(path: str, rows: list[dict[str, Any]]) -> None:
    """Write rows to path as JSONL (one JSON object per line)."""
    with open(path, "w") as f:
        for row in rows:
            f.write(json.dumps(row) + "\n")


def main() -> int:
    """CLI entry point: fetch track/prompt pairs from BigQuery and emit JSONL batches.

    For each pair, deterministically selects 2 of the 3 prompt versions and
    generates --num-seeds seeds, splitting output rows into a "ready" file
    (caption known) and a "pending" file (caption not yet available).

    Returns:
        0 on success (suitable for sys.exit).
    """
    parser = argparse.ArgumentParser(description="Generate batch JSONL with V0/V1 captions")
    parser.add_argument(
        "--output",
        "-o",
        required=True,
        help="Output prefix (produces {output}_ready.jsonl and {output}_pending.jsonl)",
    )
    parser.add_argument(
        "--bq-table", default=DEFAULT_BQ_TABLE, help=f"BigQuery table (default: {DEFAULT_BQ_TABLE})"
    )
    parser.add_argument("--sample", type=int, help="Randomly sample N rows from BigQuery results")
    parser.add_argument("--sample-seed", type=int, default=42, help="Seed for sampling (default: 42)")
    parser.add_argument("--num-seeds", type=int, default=4, help="Seeds per pair (default: 4)")
    parser.add_argument("--batch", default="straive_batch_2", help="Batch name in metadata")
    args = parser.parse_args()

    # Fetch from BQ
    bq_rows = fetch_data_from_bigquery(args.bq_table)

    # Optional deterministic down-sampling (seeded so reruns are reproducible).
    if args.sample and args.sample < len(bq_rows):
        rng = random.Random(args.sample_seed)
        bq_rows = rng.sample(bq_rows, args.sample)
        logger.info(f"Sampled {len(bq_rows)} rows (seed={args.sample_seed})")

    # Resolve captions once per unique style (was previously looked up twice per
    # style for the coverage report and again per row in the main loop).
    unique_styles = sorted({row["prompt"] for row in bq_rows})
    captions_by_style = {style: get_captions(style) for style in unique_styles}
    covered = [s for s in unique_styles if captions_by_style[s]["v0"] is not None]
    uncovered = [s for s in unique_styles if captions_by_style[s]["v0"] is None]
    logger.info(
        f"{len(unique_styles)} unique styles: {len(covered)} covered, {len(uncovered)} uncovered"
    )

    # Track stats
    version_counts: dict[str, int] = {"raw": 0, "v0": 0, "v1": 0}
    ready_rows: list[dict[str, Any]] = []
    pending_rows: list[dict[str, Any]] = []

    for row in bq_rows:
        raw_prompt = row["prompt"]
        track_uri = row["track_uri"]
        captions = captions_by_style[raw_prompt]
        versions = pick_two_versions(track_uri, raw_prompt)
        seeds = generate_deterministic_seeds(track_uri, raw_prompt, args.num_seeds)
        base_metadata = {
            "artist_name": row.get("artist_name", ""),
            "artist_genre": row.get("artist_genre", ""),
            "prompt_type": row.get("prompt_type", ""),
            "prompt_category": row.get("prompt_category", ""),
            "raw_prompt": raw_prompt,
            "batch": args.batch,
            "chosen_versions": versions,
        }
        for version in versions:
            version_counts[version] += 1
            target_text = captions[version]
            has_caption = target_text is not None
            for seed in seeds:
                metadata = {
                    **base_metadata,
                    "prompt_version": version,
                    "has_caption": has_caption,
                }
                # Only v0/v1 rows carry an expanded prompt; raw rows use the style itself.
                if version != "raw" and target_text is not None:
                    metadata["expanded_prompt"] = target_text
                jsonl_row = {
                    "source_track_uri": track_uri,
                    "target_text": target_text,
                    "manual_seed": seed,
                    "metadata": metadata,
                }
                if has_caption:
                    ready_rows.append(jsonl_row)
                else:
                    pending_rows.append(jsonl_row)

    # Write split files
    ready_path = f"{args.output}_ready.jsonl"
    pending_path = f"{args.output}_pending.jsonl"
    _write_jsonl(ready_path, ready_rows)
    _write_jsonl(pending_path, pending_rows)
    logger.info(f"Wrote {len(ready_rows)} rows to {ready_path}")
    logger.info(f"Wrote {len(pending_rows)} rows to {pending_path}")
    logger.info(
        f"Version distribution across pairs: "
        f"raw={version_counts['raw']}, v0={version_counts['v0']}, v1={version_counts['v1']}"
    )
    return 0
if __name__ == "__main__":
    # Propagate main()'s return code as the process exit status.
    sys.exit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment