#!/usr/bin/env ts-node

/**
 * The input JSON file for this script is created with this Hugging Face space:
 * https://huggingface.co/spaces/Xenova/whisper-speaker-diarization
 *
 * https://huggingface.co/onnx-community/whisper-base_timestamped
 * https://huggingface.co/onnx-community/pyannote-segmentation-3.0
 * https://huggingface.co/pyannote/segmentation-3.0
 */
import * as fs from 'fs';
import * as path from 'path';

/**
 * Converts a time in seconds (possibly fractional) to an SRT timestamp of the
 * form `HH:MM:SS,mmm`.
 *
 * The value is rounded to the nearest whole millisecond first and every field
 * is derived from that integer. This avoids floating-point artefacts such as
 * 1.005 being rendered as `,004`, and makes a rounded-up millisecond carry
 * correctly into seconds, minutes and hours.
 *
 * Example: 3661.235 -> "01:01:01,235"
 *
 * @param seconds Offset from the start of the media, in seconds. Negative or
 *   non-finite values (NaN, Infinity, or a missing value passed through an
 *   untyped caller) are treated as 0, since SRT cannot represent them.
 * @returns The formatted timestamp. Hours are padded to at least two digits
 *   and are not wrapped, so values of 100 hours or more use three digits.
 */
function formatTimestamp(seconds: number): string {
  // Work in integer milliseconds so no field is computed from a lossy fraction.
  const totalMs = Number.isFinite(seconds) ? Math.max(0, Math.round(seconds * 1000)) : 0;
  const totalSeconds = Math.floor(totalMs / 1000);

  const h = Math.floor(totalSeconds / 3600);
  const m = Math.floor(totalSeconds / 60) % 60;
  const s = totalSeconds % 60;
  const ms = totalMs % 1000;

  const pad = (value: number, width: number): string => String(value).padStart(width, '0');
  return `${pad(h, 2)}:${pad(m, 2)}:${pad(s, 2)},${pad(ms, 3)}`;
}

/**
 * Joins the `text` of every chunk into one string.
 *
 * Each chunk's text is trimmed; chunks that are empty after trimming are
 * skipped. Pieces are separated by a single space, except that a piece
 * beginning with one of `, . ! ? ; :` is attached directly to the preceding
 * text so punctuation does not end up with a space in front of it.
 *
 * @param chunks Objects carrying a string `text` property.
 * @returns The joined text, or an empty string when nothing remains.
 */
function mergeChunksSimple(chunks: any[]): string {
  const startsWithPunctuation = /^[,\.!?;:]/;

  return chunks.reduce<string>((joined, chunk) => {
    const piece = chunk.text.trim();
    if (!piece) {
      return joined;
    }
    if (joined === "") {
      return piece;
    }
    // Punctuation hugs the previous word; everything else gets a separator.
    const separator = startsWithPunctuation.test(piece) ? "" : " ";
    return joined + separator + piece;
  }, "");
}

/**
 * Labels chunks with a speaker taken from the diarization segments.
 *
 * `segments` is sorted in place by `start`. Every chunk that has no speaker
 * yet (missing, empty, or "NO_SPEAKER") and carries a `timestamp` with at
 * least two entries receives the `label` of the first segment, in start
 * order, that fully contains it: chunk start >= segment start and chunk
 * end <= segment end. Chunks that no segment fully contains are left as is.
 *
 * @param chunks Transcript chunks; `speaker` is written on matching chunks.
 * @param segments Segments with `start`, `end` and `label`; reordered in place.
 */
function assignSpeakers(chunks: any[], segments: any[]): void {
  segments.sort((left, right) => left.start - right.start);

  chunks.forEach((chunk) => {
    // Keep a speaker that has already been set to something meaningful.
    if (chunk.speaker && chunk.speaker !== "NO_SPEAKER") {
      return;
    }
    // Without a start/end pair there is nothing to compare against.
    if (!chunk.timestamp || !(chunk.timestamp.length >= 2)) {
      return;
    }
    const enclosing = segments.find(
      (seg) => chunk.timestamp[0] >= seg.start && chunk.timestamp[1] <= seg.end
    );
    if (enclosing !== undefined) {
      // Use the segment's label as the speaker name.
      chunk.speaker = enclosing.label;
    }
  });
}

/**
 * Groups all chunks into SRT blocks.
 *
 * Chunks are processed in order of start time (a sorted copy is used, so the
 * caller's array is not reordered). A chunk joins the current group when all
 * of the following hold:
 * - The gap between its start and the current group's end does not exceed
 *   gapThreshold.
 * - Its speaker does not conflict with the group's speaker. A chunk or group
 *   without a speaker (missing or "NO_SPEAKER") never counts as a conflict.
 * - The span from the group's start to the chunk's end does not exceed
 *   maxBlockDuration.
 *
 * When any condition fails, the current group is closed and a new one is
 * started with that chunk. A group that began without a speaker adopts the
 * speaker of the first labelled chunk added to it.
 *
 * @param chunks Chunks with `timestamp` ([start, end] in seconds), and
 *   optionally `speaker`.
 * @param maxBlockDuration Maximum span of one group, in seconds.
 * @param gapThreshold Largest silence, in seconds, allowed inside a group.
 * @returns Groups of the form `{ speaker, start, end, chunks }` in start
 *   order; an empty array when `chunks` is empty.
 */
function groupAllChunks(chunks: any[], maxBlockDuration: number = 3.0, gapThreshold: number = 0.5): any[] {
  const groups: any[] = [];
  if (chunks.length === 0) return groups;

  // Sort a copy: reordering the caller's array would be a hidden side effect.
  const ordered = chunks.slice().sort((a: any, b: any) => a.timestamp[0] - b.timestamp[0]);

  // Builds a fresh group that contains only the given chunk.
  const openGroup = (chunk: any) => ({
    speaker: chunk.speaker || "NO_SPEAKER",
    start: chunk.timestamp[0],
    end: chunk.timestamp[1],
    chunks: [chunk]
  });

  let currentGroup = openGroup(ordered[0]);

  for (let i = 1; i < ordered.length; i++) {
    const chunk = ordered[i];
    const gap = chunk.timestamp[0] - currentGroup.end;
    // Only two distinct, known speakers count as a mismatch.
    const speakerMismatch =
      chunk.speaker &&
      chunk.speaker !== currentGroup.speaker &&
      chunk.speaker !== "NO_SPEAKER" &&
      currentGroup.speaker !== "NO_SPEAKER";
    const tooLong = (chunk.timestamp[1] - currentGroup.start) > maxBlockDuration;

    if (speakerMismatch || gap > gapThreshold || tooLong) {
      groups.push(currentGroup);
      currentGroup = openGroup(chunk);
      continue;
    }

    currentGroup.chunks.push(chunk);
    currentGroup.end = chunk.timestamp[1];
    // An unlabelled group takes on the first valid speaker it encounters.
    if (currentGroup.speaker === "NO_SPEAKER" && chunk.speaker && chunk.speaker !== "NO_SPEAKER") {
      currentGroup.speaker = chunk.speaker;
    }
  }

  groups.push(currentGroup);
  return groups;
}

/**
 * Folds groups labelled "NO_SPEAKER" into a neighbouring group that has a
 * real speaker.
 *
 * For each unlabelled group, the most recently kept group is tried first: if
 * it has a speaker and the silence between the two is at most gapThreshold,
 * the unlabelled group's chunks are appended to it and its end is extended.
 * Otherwise the following group is tried under the same conditions; on
 * success the chunks are prepended to it and its start is moved back.
 * Groups that cannot be attached to either side are kept unchanged.
 *
 * The group objects are modified in place; only the returned list is new.
 *
 * @param groups Groups of the form `{ speaker, start, end, chunks }`.
 * @param gapThreshold Largest gap, in seconds, across which a merge happens.
 * @returns The surviving groups in their original relative order.
 */
function mergeNoSpeakerGroups(groups: any[], gapThreshold: number = 0.5): any[] {
  const kept: any[] = [];

  for (let idx = 0; idx < groups.length; idx += 1) {
    const group = groups[idx];

    // Labelled groups pass straight through.
    if (group.speaker !== "NO_SPEAKER") {
      kept.push(group);
      continue;
    }

    // First choice: attach to the end of the last kept group.
    if (kept.length > 0) {
      const before = kept[kept.length - 1];
      const attachesToBefore =
        before.speaker !== "NO_SPEAKER" && (group.start - before.end) <= gapThreshold;
      if (attachesToBefore) {
        before.end = group.end;
        before.chunks = before.chunks.concat(group.chunks);
        continue;
      }
    }

    // Second choice: attach to the front of the group that follows.
    if (idx < groups.length - 1) {
      const after = groups[idx + 1];
      if (after.speaker !== "NO_SPEAKER" && (after.start - group.end) <= gapThreshold) {
        after.chunks = group.chunks.concat(after.chunks);
        after.start = group.start;
        continue;
      }
    }

    kept.push(group);
  }

  return kept;
}

/**
 * Turns a transcript object into the list of groups that become SRT blocks.
 *
 * Steps, in order:
 * 1. When `data.segments` is an array, it is used to assign speaker labels to
 *    the chunks.
 * 2. The chunks are grouped by time gap, speaker and block duration.
 * 3. Groups still labelled NO_SPEAKER are merged into adjacent groups.
 *
 * @param data Parsed transcript with a `chunks` array and, optionally, a
 *   `segments` array.
 * @param maxBlockDuration Passed through to the grouping step.
 * @param gapThreshold Passed through to the grouping and merging steps.
 * @returns The resulting groups.
 * @throws Error when `data.chunks` is not an array.
 */
function mergeSegmentsAndChunks(data: any, maxBlockDuration: number = 3.0, gapThreshold: number = 0.5): any[] {
  const segments = data.segments;
  const chunks = data.chunks;

  // Nothing can be produced without chunks, with or without segment info.
  if (!Array.isArray(chunks)) {
    throw new Error('Invalid JSON format: missing "chunks" array.');
  }

  // Segment info is optional; use it for speaker labels when present.
  if (Array.isArray(segments)) {
    assignSpeakers(chunks, segments);
  }

  const grouped = groupAllChunks(chunks, maxBlockDuration, gapThreshold);
  // Merge groups with NO_SPEAKER into the nearest valid speaker group.
  return mergeNoSpeakerGroups(grouped, gapThreshold);
}

/**
 * Renders a parsed JSON transcript as SRT text.
 *
 * The transcript is first reduced to groups (segments and chunks merged),
 * which are ordered by start time. Each group becomes one SRT block: a
 * 1-based index line, a `start --> end` timing line, and a text line that is
 * prefixed with `<speaker>: ` when the group has a speaker other than
 * "NO_SPEAKER". Every block is followed by a blank line.
 *
 * @param data Parsed transcript object.
 * @returns The complete SRT content, or an empty string when there are no
 *   groups.
 */
function convertToSrt(data: any): string {
  const blocks = mergeSegmentsAndChunks(data);
  // Order blocks by start time.
  blocks.sort((left, right) => left.start - right.start);

  const rendered = blocks.map((block, position) => {
    const startTime = formatTimestamp(block.start);
    const endTime = formatTimestamp(block.end);
    // Only known speakers are shown in front of the text.
    const hasSpeaker = block.speaker && block.speaker !== "NO_SPEAKER";
    const speakerText = hasSpeaker ? `${block.speaker}: ` : '';
    const body = mergeChunksSimple(block.chunks).trim();

    return `${position + 1}\n${startTime} --> ${endTime}\n${speakerText}${body}\n\n`;
  });

  return rendered.join('');
}

/**
 * Command-line entry point.
 *
 * Arguments are read directly from process.argv, without an argument-parsing
 * library: the first is the input JSON file, the optional second is the
 * output file. When no output file is given, the input's base name with its
 * extension replaced by `.srt` is used (the directory part is dropped).
 *
 * Any failure while reading, converting or writing is reported on stderr and
 * the process exits with status 1.
 */

/** Logs `message` together with the underlying error, then exits with status 1. */
function exitWithError(message: string, cause: unknown): never {
  console.error(message, cause);
  process.exit(1);
}

if (process.argv.length < 3) {
  console.error('Usage: ts-node convert.ts <input.json> [output.srt]');
  process.exit(1);
}

const inputFilePath = process.argv[2];
const outputFilePath = process.argv[3]
  ? process.argv[3]
  : path.basename(inputFilePath, path.extname(inputFilePath)) + '.srt';

// Step 1: load and parse the transcript.
let jsonData: any;
try {
  jsonData = JSON.parse(fs.readFileSync(inputFilePath, 'utf8'));
} catch (readError) {
  exitWithError('Error reading or parsing the input JSON file:', readError);
}

// Step 2: build the SRT text.
let srtContent: string;
try {
  srtContent = convertToSrt(jsonData);
} catch (convertError) {
  exitWithError('Error converting JSON to SRT:', convertError);
}

// Step 3: write the result and confirm where it went.
try {
  fs.writeFileSync(outputFilePath, srtContent, 'utf8');
  console.log(`Conversion complete! Output written to ${outputFilePath}`);
} catch (writeError) {
  exitWithError('Error writing the SRT file:', writeError);
}