Skip to content

Instantly share code, notes, and snippets.

@tdhopper
Created March 18, 2026 18:10
Show Gist options
  • Select an option

  • Save tdhopper/803f541b9e789664c543a8525beb1531 to your computer and use it in GitHub Desktop.

Select an option

Save tdhopper/803f541b9e789664c543a8525beb1531 to your computer and use it in GitHub Desktop.
Straive Batch 2: BQ → JSONL generation script (regenerate_batch_jsonl_v0v1.py)
#!/usr/bin/env python3
"""Generate straive batch 2 JSONL with V0/V1 caption strategy.
For ALL track/prompt pairs from BigQuery:
- Deterministically pick 2 of {raw, v0, v1} per track/prompt pair
- Generate 4 seeds per pair → 8 rows per pair
- Record chosen versions in metadata
Outputs two files:
- {output}_ready.jsonl: rows where target_text is known
(all raw rows + v0/v1 rows for styles with spotbot captions)
- {output}_pending.jsonl: rows where target_text is pending
(v0/v1 rows for styles without spotbot captions yet)
Usage:
python -m diffusify.scripts.regenerate_batch_jsonl_v0v1 \
--bq-table fan-audio-2.dpo_samples.straive_batch_2_track_prompt_pairs \
--sample 8000 \
--output straive_batch_2_generation
"""
from __future__ import annotations
import argparse
import hashlib
import json
import logging
import random
import sys
from typing import Any
from google.cloud import bigquery
from diffusify.track_sets.spotbot_playlists import (
SPOTBOT_PLAYLISTS,
SPOTBOT_STYLE_CAPTION_VARIANTS,
)
# Timestamped INFO-level logging for the whole script.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# Default source table of track/prompt pairs for straive batch 2.
DEFAULT_BQ_TABLE = "fan-audio-2.dpo_samples.straive_batch_2_track_prompt_pairs"
# The three caption strategies; two of these are picked per track/prompt pair.
PROMPT_VERSIONS = ["raw", "v0", "v1"]
def _build_normalized_lookup() -> dict[str, str]:
    """Map normalized (lowercase, hyphenated) style names to SPOTBOT_PLAYLISTS keys.

    Bridges naming mismatches such as 'hard-rock' (batch) vs 'Hard Rock' (spotbot).
    """
    return {key.lower().replace(" ", "-"): key for key in SPOTBOT_PLAYLISTS}
# Built once at import time; maps e.g. "hard-rock" -> "Hard Rock" for O(1) lookups.
_NORMALIZED_SPOTBOT_KEYS = _build_normalized_lookup()
def _resolve_spotbot_key(style: str) -> str | None:
    """Resolve a batch-style name to a SPOTBOT_PLAYLISTS key, handling naming differences.

    Tries an exact match first, then falls back to the normalized lookup;
    returns None when the style has no spotbot playlist at all.
    """
    if style in SPOTBOT_PLAYLISTS:
        return style
    return _NORMALIZED_SPOTBOT_KEYS.get(style.lower().replace(" ", "-"))
def get_captions(style: str) -> dict[str, str | None]:
    """Get raw, V0, and V1 captions for a style. V0/V1 are None if not available.

    'raw' is always the style string itself; 'v0' comes from the playlist's own
    caption and 'v1' from the 'accessible' caption variants, when present.
    """
    key = _resolve_spotbot_key(style)
    if key is None:
        return {"raw": style, "v0": None, "v1": None}
    accessible_variants = SPOTBOT_STYLE_CAPTION_VARIANTS.get("accessible", {})
    return {
        "raw": style,
        "v0": SPOTBOT_PLAYLISTS[key].caption or None,
        "v1": accessible_variants.get(key),
    }
def pick_two_versions(track_uri: str, prompt: str) -> list[str]:
"""Deterministically pick 2 of 3 prompt versions for a track/prompt pair."""
h = hashlib.sha256(f"{track_uri}:{prompt}:version_pick".encode()).hexdigest()
rng = random.Random(int(h[:8], 16))
return sorted(rng.sample(PROMPT_VERSIONS, 2))
def generate_deterministic_seeds(track_uri: str, prompt: str, n: int) -> list[int]:
    """Generate N deterministic seeds from a track_uri + prompt pair.

    Each seed is the first 32 bits of sha256("{track_uri}:{prompt}:{i}"),
    so the same pair always yields the same seed list.
    """

    def _seed_at(index: int) -> int:
        digest = hashlib.sha256(f"{track_uri}:{prompt}:{index}".encode()).hexdigest()
        return int(digest[:8], 16)

    return [_seed_at(i) for i in range(n)]
def fetch_data_from_bigquery(table: str, project: str = "fan-audio-2") -> list[dict[str, Any]]:
    """Fetch all rows from a BigQuery table as plain dicts.

    Args:
        table: Fully qualified table name (project.dataset.table).
        project: GCP project the client bills queries to. Defaults to
            "fan-audio-2" (the previously hard-coded value) so existing
            callers are unaffected.

    Returns:
        One dict per row, keyed by column name.
    """
    client = bigquery.Client(project=project)
    query = f"SELECT * FROM `{table}`"
    logger.info(f"Fetching data from BigQuery table: {table}")
    # Materialize the iterator so we can log the row count before returning.
    results = list(client.query(query).result())
    logger.info(f"Fetched {len(results)} rows")
    return [dict(row) for row in results]
def _write_jsonl(path: str, rows: list[dict[str, Any]]) -> None:
    """Write rows to path as JSON Lines (one compact JSON object per line)."""
    # Explicit UTF-8 so output is identical regardless of platform default encoding.
    with open(path, "w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row) + "\n")


def main() -> int:
    """Build {output}_ready.jsonl / {output}_pending.jsonl from the BQ pair table.

    Ready rows have a known target_text; pending rows are v0/v1 rows for styles
    whose spotbot captions do not exist yet.

    Returns:
        Process exit code (0 on success).
    """
    parser = argparse.ArgumentParser(description="Generate batch JSONL with V0/V1 captions")
    parser.add_argument(
        "--output",
        "-o",
        required=True,
        help="Output prefix (produces {output}_ready.jsonl and {output}_pending.jsonl)",
    )
    parser.add_argument(
        "--bq-table", default=DEFAULT_BQ_TABLE, help=f"BigQuery table (default: {DEFAULT_BQ_TABLE})"
    )
    parser.add_argument("--sample", type=int, help="Randomly sample N rows from BigQuery results")
    parser.add_argument("--sample-seed", type=int, default=42, help="Seed for sampling (default: 42)")
    parser.add_argument("--num-seeds", type=int, default=4, help="Seeds per pair (default: 4)")
    parser.add_argument("--batch", default="straive_batch_2", help="Batch name in metadata")
    args = parser.parse_args()

    # Fetch every track/prompt pair from BQ.
    bq_rows = fetch_data_from_bigquery(args.bq_table)

    # Optional deterministic down-sample for smaller batches.
    if args.sample and args.sample < len(bq_rows):
        rng = random.Random(args.sample_seed)
        bq_rows = rng.sample(bq_rows, args.sample)
        logger.info(f"Sampled {len(bq_rows)} rows (seed={args.sample_seed})")

    # Coverage report: a style is "covered" when a spotbot V0 caption exists.
    # Call get_captions once per style (the original called it twice).
    unique_styles = sorted({row["prompt"] for row in bq_rows})
    covered: list[str] = []
    uncovered: list[str] = []
    for style in unique_styles:
        (covered if get_captions(style)["v0"] is not None else uncovered).append(style)
    logger.info(
        f"{len(unique_styles)} unique styles: {len(covered)} covered, {len(uncovered)} uncovered"
    )

    # How often each prompt version was chosen (counted per pair, not per seed).
    version_counts: dict[str, int] = {"raw": 0, "v0": 0, "v1": 0}
    ready_rows: list[dict[str, Any]] = []
    pending_rows: list[dict[str, Any]] = []

    for row in bq_rows:
        raw_prompt = row["prompt"]
        track_uri = row["track_uri"]
        captions = get_captions(raw_prompt)
        versions = pick_two_versions(track_uri, raw_prompt)
        seeds = generate_deterministic_seeds(track_uri, raw_prompt, args.num_seeds)
        base_metadata = {
            "artist_name": row.get("artist_name", ""),
            "artist_genre": row.get("artist_genre", ""),
            "prompt_type": row.get("prompt_type", ""),
            "prompt_category": row.get("prompt_category", ""),
            "raw_prompt": raw_prompt,
            "batch": args.batch,
            "chosen_versions": versions,
        }
        for version in versions:
            version_counts[version] += 1
            target_text = captions[version]
            has_caption = target_text is not None
            for seed in seeds:
                metadata = {
                    **base_metadata,
                    "prompt_version": version,
                    "has_caption": has_caption,
                }
                # Only expanded (v0/v1) prompts record the caption they expand to;
                # for "raw", target_text is already the raw prompt itself.
                if version != "raw" and target_text is not None:
                    metadata["expanded_prompt"] = target_text
                jsonl_row = {
                    "source_track_uri": track_uri,
                    "target_text": target_text,
                    "manual_seed": seed,
                    "metadata": metadata,
                }
                (ready_rows if has_caption else pending_rows).append(jsonl_row)

    # Split output: rows with a known target_text vs rows awaiting captions.
    ready_path = f"{args.output}_ready.jsonl"
    pending_path = f"{args.output}_pending.jsonl"
    _write_jsonl(ready_path, ready_rows)
    _write_jsonl(pending_path, pending_rows)
    logger.info(f"Wrote {len(ready_rows)} rows to {ready_path}")
    logger.info(f"Wrote {len(pending_rows)} rows to {pending_path}")
    logger.info(
        f"Version distribution across pairs: "
        f"raw={version_counts['raw']}, v0={version_counts['v0']}, v1={version_counts['v1']}"
    )
    return 0


if __name__ == "__main__":
    sys.exit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment