Explaining and Correcting xCOMET error spans with xTower
{
"cells": [
{
"cell_type": "markdown",
"id": "4f1acd3f-b8fb-4c3d-bebb-31e86ee6c7a4",
"metadata": {},
"source": [
"# Explaining and Correcting xCOMET error spans with xTower"
]
},
{
"cell_type": "markdown",
"id": "1b93122b-9bec-400e-8fda-b70353a506ff",
"metadata": {},
"source": [
"### Load xCOMET"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "7738279e-071c-49f2-979d-c89b42651239",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "644c655ae81a4ca8bbe5b5f5242adf21",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Fetching 5 files: 0%| | 0/5 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Encoder model frozen.\n",
"/home/mtreviso/qenle-eval/env/lib/python3.10/site-packages/pytorch_lightning/core/saving.py:188: Found keys that are not in the model state dict but in the checkpoint: ['encoder.model.embeddings.position_ids']\n"
]
}
],
"source": [
"from comet import download_model, load_from_checkpoint\n",
"\n",
"# Choose your model from the Hugging Face Hub\n",
"model_path = download_model(\"Unbabel/XCOMET-XL\")\n",
"\n",
"# Load the model checkpoint:\n",
"model = load_from_checkpoint(model_path)"
]
},
{
"cell_type": "markdown",
"id": "d3147822-2f41-4b4c-ac13-95cc5a4d0fef",
"metadata": {},
"source": [
"### Get error spans"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "4d5df25f-be6a-409d-9e70-724b20b76055",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/mtreviso/qenle-eval/env/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python3 /home/mtreviso/qenle-eval/env/lib/python3.10/site-p ...\n",
"GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n",
"You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
"Predicting DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1.56it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.9822099208831787, 0.9599897861480713, 0.9070618152618408]\n",
"0.9497538407643636\n",
"[[{'text': 'my food', 'confidence': 0.4160953164100647, 'severity': 'minor', 'start': 13, 'end': 21}], [{'text': 'you send it for', 'confidence': 0.40004390478134155, 'severity': 'minor', 'start': 3, 'end': 19}], [{'text': 'trugen Lawinenschilder', 'confidence': 0.4170106053352356, 'severity': 'minor', 'start': 4, 'end': 27}]]\n"
]
}
],
"source": [
"# Data must be in the following format\n",
"# The \"ref\" key can be omitted for reference-free (quality estimation) experiments\n",
"data = [\n",
"    {\n",
"        \"src\": \"10 到 15 分钟可以送到吗\",\n",
"        \"mt\": \"Can I receive my food in 10 to 15 minutes?\",\n",
"        \"ref\": \"Can it be delivered between 10 to 15 minutes?\"\n",
"    },\n",
"    {\n",
"        \"src\": \"Pode ser entregue dentro de 10 a 15 minutos?\",\n",
"        \"mt\": \"Can you send it for 10 to 15 minutes?\",\n",
"        \"ref\": \"Can it be delivered between 10 to 15 minutes?\"\n",
"    },\n",
"    {\n",
"        \"src\": \"All were wearing avalanche beacons\",\n",
"        \"mt\": \"Alle trugen Lawinenschilder\",\n",
"        \"ref\": \"Alle trugen Lawinensuchgeräte\"\n",
"    }\n",
"]\n",
"# Call predict method:\n",
"model_output = model.predict(data, batch_size=8, gpus=1)\n",
"\n",
"print(model_output.scores)  # sentence-level scores\n",
"print(model_output.system_score)  # system-level score\n",
"print(model_output.metadata.error_spans)  # detected error spans"
]
},
{
"cell_type": "markdown",
"id": "bacd4377-ef67-4248-94b2-85baccb2124e",
"metadata": {},
"source": [
"### Prepare prompts for xTower"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8db3a12d-494a-4690-8a28-3126029b281d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<|im_start|>user\n",
"You are provided with a Source, Translation, Translation quality analysis, and Translation quality score (weak, moderate, good, excellent, best). The Translation quality analysis contain a translation with marked error spans with different levels of severity (minor or major). Additionally, we may provide a **reference translation**. Given this information, generate an explanation for each error and a fully correct translation.\n",
"\n",
"Portuguese source: Pode ser entregue dentro de 10 a 15 minutos?\n",
"English translation: Can you send it for 10 to 15 minutes?\n",
"English reference: Can it be delivered between 10 to 15 minutes?\n",
"Translation quality analysis: Can <error1 severity='minor'>you send it for</error1> 10 to 15 minutes?\n",
"Translation quality score: excellent<|im_end|>\n",
"<|im_start|>assistant\n",
"\n"
]
}
],
"source": [
"def get_discrete_quality_score(score):\n",
"    \"\"\"\n",
"    Discretizes a given quality score into categories.\n",
"\n",
"    Args:\n",
"        score (float): The quality score to be discretized.\n",
"\n",
"    Returns:\n",
"        str: The discrete quality category ('weak', 'moderate', 'good', 'excellent', 'best').\n",
"    \"\"\"\n",
"    if score < 0.6:\n",
"        return 'weak'\n",
"    elif score < 0.8:\n",
"        return 'moderate'\n",
"    elif score < 0.94:\n",
"        return 'good'\n",
"    elif score < 0.98:\n",
"        return 'excellent'\n",
"    else:\n",
"        return 'best'\n",
"\n",
"\n",
"def annotate_translation_with_error_spans(translation, error_spans):\n",
"    \"\"\"\n",
"    Annotates a translation string with error spans.\n",
"\n",
"    Args:\n",
"        translation (str): The translation text to be annotated.\n",
"        error_spans (list of dict): A list of error spans, where each span is a dictionary\n",
"            with 'start', 'end', and 'severity' keys.\n",
"\n",
"    Returns:\n",
"        str: The annotated translation text with error tags.\n",
"    \"\"\"\n",
"    annotated_translation = str(translation)\n",
"    error_spans = sorted(error_spans, key=lambda x: x['start'])\n",
"    # Iterate over the error spans in reverse order so that inserting tags\n",
"    # does not shift the character offsets of spans not yet processed\n",
"    for i, span in enumerate(error_spans[::-1]):\n",
"        error_id = len(error_spans) - i  # error IDs follow left-to-right order\n",
"        start, end, severity = span['start'], span['end'], span['severity'].lower()\n",
"        # Insert error tags around the specified span in the translation\n",
"        annotated_translation = (\n",
"            annotated_translation[:start].strip() +\n",
"            f\" <error{error_id} severity='{severity}'>\" +\n",
"            annotated_translation[start:end].strip() +\n",
"            f\"</error{error_id}> \" +\n",
"            annotated_translation[end:]\n",
"        )\n",
"    # Collapse the double spaces that tag insertion can leave around error tags\n",
"    return annotated_translation.replace('  <', ' <').replace('>  ', '> ').strip()\n",
"\n",
"\n",
"def create_prompt(sample, src_lang, mt_lang):\n",
"    \"\"\"\n",
"    Creates a prompt for translation quality assessment.\n",
"\n",
"    Args:\n",
"        sample (dict): The translation data, with 'src', 'mt', 'annotated_mt', and\n",
"            'discrete_score' keys, plus an optional 'ref' key.\n",
"        src_lang (str): The source language (e.g., \"English\").\n",
"        mt_lang (str): The target language of the machine translation (e.g., \"German\").\n",
"\n",
"    Returns:\n",
"        str: The generated prompt for translation quality assessment.\n",
"    \"\"\"\n",
"    # The instruction text is kept verbatim; rewording it may change xTower's outputs\n",
"    prompt = \"<|im_start|>user\"\n",
"    prompt += \"\\n\"\n",
"    prompt += \"You are provided with a Source, Translation, Translation quality analysis, and Translation quality score (weak, moderate, good, excellent, best). \"\n",
"    prompt += \"The Translation quality analysis contain a translation with marked error spans with different levels of severity (minor or major). \"\n",
"    prompt += \"Additionally, we may provide a **reference translation**. \"\n",
"    prompt += \"Given this information, generate an explanation for each error and a fully correct translation.\"\n",
"    prompt += \"\\n\\n\"\n",
"    prompt += f\"{src_lang} source: {sample['src']}\"\n",
"    prompt += \"\\n\"\n",
"    prompt += f\"{mt_lang} translation: {sample['mt']}\"\n",
"    prompt += \"\\n\"\n",
"    if 'ref' in sample:\n",
"        prompt += f\"{mt_lang} reference: {sample['ref']}\"\n",
"        prompt += \"\\n\"\n",
"    prompt += f\"Translation quality analysis: {sample['annotated_mt']}\"\n",
"    prompt += \"\\n\"\n",
"    prompt += f\"Translation quality score: {sample['discrete_score']}\"\n",
"    prompt += \"<|im_end|>\\n<|im_start|>assistant\\n\"\n",
"    return prompt\n",
"\n",
"\n",
"# Annotate the Portuguese-English sample (index 1) with its xCOMET error spans and discretized score\n",
"sample = data[1]\n",
"score = model_output.scores[1]\n",
"error_spans = model_output.metadata.error_spans[1]\n",
"sample['annotated_mt'] = annotate_translation_with_error_spans(sample['mt'], error_spans)\n",
"sample['discrete_score'] = get_discrete_quality_score(score)\n",
"\n",
"# Create the xTower prompt\n",
"prompt = create_prompt(sample, src_lang='Portuguese', mt_lang='English')\n",
"\n",
"print(prompt)"
]
},
{
"cell_type": "markdown",
"id": "cf489a97-48f2-4144-85f9-152b6ae95f44",
"metadata": {},
"source": [
"### Prompt xTower\n",
"\n",
"1. Using vLLM (recommended)\n",
"2. Using the Hugging Face `transformers` text-generation pipeline"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "77ea08b8-4ac6-4a28-9325-a7a7cde5f0f9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO 06-28 15:20:43 llm_engine.py:161] Initializing an LLM engine (v0.5.0.post1) with config: model='sardinelab/xTower13B', speculative_config=None, tokenizer='sardinelab/xTower13B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0, served_model_name=sardinelab/xTower13B)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO 06-28 15:20:44 weight_utils.py:218] Using model weights format ['*.safetensors']\n",
"INFO 06-28 15:20:49 model_runner.py:160] Loading model weights took 24.2869 GB\n",
"INFO 06-28 15:20:50 gpu_executor.py:83] # GPU blocks: 1473, # CPU blocks: 327\n"
]
}
],
"source": [
"from vllm import LLM\n",
"from transformers import pipeline\n",
"\n",
"def load_xtower(generate_lib='vllm'):\n",
"    \"\"\"\n",
"    Loads the xTower model using either vLLM (recommended) or a Hugging Face pipeline.\n",
"\n",
"    Args:\n",
"        generate_lib (str): The library to use for loading the model ('vllm' or 'huggingface').\n",
"            Any value other than 'vllm' falls back to the Hugging Face pipeline.\n",
"\n",
"    Returns:\n",
"        object: The loaded language model.\n",
"    \"\"\"\n",
"    if generate_lib == 'vllm':\n",
"        llm = LLM(\n",
"            model=\"sardinelab/xTower13B\",\n",
"            tensor_parallel_size=1,\n",
"            enforce_eager=True\n",
"        )\n",
"    else:\n",
"        llm = pipeline(\n",
"            \"text-generation\",\n",
"            model=\"sardinelab/xTower13B\",\n",
"            device_map=\"auto\"\n",
"        )\n",
"\n",
"    return llm\n",
"\n",
"xtower_llm = load_xtower(generate_lib='vllm')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "cc0cd1c9-0306-4884-ae20-7ff65e3aa07b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00, 5.66s/it, est. speed input: 34.49 toks/s, output: 24.23 toks/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Explanation for error1: The phrase \"you send it for\" is a direct translation that doesn't accurately convey the intended meaning of the original Portuguese sentence. The original sentence is asking if the delivery can be made within a time frame of 10 to 15 minutes, not asking someone to send something for a duration of 10 to 15 minutes. The English reference translation \"Can it be delivered between 10 to 15 minutes?\" is more accurate and clearly conveys the intended meaning of delivery within a specific time frame.\n",
"Translation correction: Can it be delivered within 10 to 15 minutes?\n",
"---\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"def prompt_xtower(prompts, llm):\n",
"    \"\"\"\n",
"    Generates outputs for the provided prompts with the loaded model.\n",
"\n",
"    Args:\n",
"        prompts (list of str): A list of prompts to generate responses for.\n",
"        llm (object): The loaded language model (a vLLM `LLM` or a Hugging Face pipeline).\n",
"\n",
"    Returns:\n",
"        list of str: The generated outputs for each prompt.\n",
"    \"\"\"\n",
"    if isinstance(llm, LLM):\n",
"        from vllm import SamplingParams\n",
"        # Greedy decoding (temperature=0) with up to 1024 new tokens\n",
"        sampling_params = SamplingParams(temperature=0, max_tokens=1024, stop=[\"</s>\"])\n",
"        responses = llm.generate(prompts, sampling_params)\n",
"        outputs = [response.outputs[0].text.strip() for response in responses]\n",
"    else:\n",
"        # The pipeline returns one list of dicts per prompt; keep only the generated\n",
"        # continuation (not the echoed prompt) so both branches return plain strings\n",
"        responses = llm(prompts, max_new_tokens=1024, do_sample=False, return_full_text=False)\n",
"        outputs = [response[0]['generated_text'].strip() for response in responses]\n",
"\n",
"    return outputs\n",
"\n",
"outputs = prompt_xtower(prompts=[prompt], llm=xtower_llm)\n",
"for output in outputs:\n",
" print(output)\n",
" print('---')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
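The following is a minimal sketch of the span-annotation step from the notebook above, written as a standalone function for experimentation. It keeps the same core idea as `annotate_translation_with_error_spans`: sort the spans, number them left to right, and insert tags from right to left so that offsets of unprocessed spans stay valid. One detail is an assumption of this sketch. It trims whitespace off each span's boundaries and leaves the surrounding spacing untouched, instead of forcing a space around every tag. The name `annotate_spans_preserving_spacing` is hypothetical, and it has not been checked whether xTower is sensitive to this spacing difference.

```python
def annotate_spans_preserving_spacing(translation: str, error_spans: list[dict]) -> str:
    """Wrap xCOMET error spans in <errorN severity='...'> tags (hypothetical helper).

    Spans are numbered left to right but inserted right to left, so the
    character offsets of spans that are not yet processed never shift.
    Assumes the spans do not overlap.
    """
    ordered = sorted(error_spans, key=lambda span: span["start"])
    text = translation
    for error_id in range(len(ordered), 0, -1):
        span = ordered[error_id - 1]
        start, end = span["start"], span["end"]
        # xCOMET spans may include surrounding spaces; shrink the span to the words.
        while start < end and text[start].isspace():
            start += 1
        while end > start and text[end - 1].isspace():
            end -= 1
        severity = span["severity"].lower()
        # Only text at or after `start` changes, so earlier offsets stay valid.
        text = (
            text[:start]
            + f"<error{error_id} severity='{severity}'>"
            + text[start:end]
            + f"</error{error_id}>"
            + text[end:]
        )
    return text


# The Portuguese-English span reported by xCOMET in the notebook.
mt = "Can you send it for 10 to 15 minutes?"
spans = [{"text": "you send it for", "severity": "minor", "start": 3, "end": 19}]
print(annotate_spans_preserving_spacing(mt, spans))
# Can <error1 severity='minor'>you send it for</error1> 10 to 15 minutes?
```

For this sample, the output matches the "Translation quality analysis" line printed by the notebook. A span that ends right before punctuation stays attached to it, with no space inserted before the punctuation mark.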
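The xTower response in the notebook above is plain text. It has one `Explanation for errorN:` line per tagged span, followed by a `Translation correction:` line. The sketch below splits such a response into a dictionary for downstream use. It assumes exactly that layout, which is inferred from the single example output rather than from any documented xTower output format, so check it against more responses before relying on it. The name `parse_xtower_output` is hypothetical.

```python
import re

# Layout assumed from the single example response shown in the notebook.
EXPLANATION_RE = re.compile(r"Explanation for error(\d+):\s*(.*)")
CORRECTION_PREFIX = "Translation correction:"


def parse_xtower_output(output: str) -> dict:
    """Split an xTower response into per-error explanations and a correction (hypothetical helper)."""
    explanations: dict[int, str] = {}
    correction = None
    current_id = None  # error ID whose explanation is currently being read
    for raw_line in output.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        match = EXPLANATION_RE.match(line)
        if match:
            current_id = int(match.group(1))
            explanations[current_id] = match.group(2).strip()
        elif line.startswith(CORRECTION_PREFIX):
            correction = line[len(CORRECTION_PREFIX):].strip()
            current_id = None
        elif current_id is not None:
            # Treat any other line as a continuation of the current explanation.
            explanations[current_id] += " " + line
    return {"explanations": explanations, "correction": correction}


# An abbreviated copy of the response printed in the notebook.
response = (
    'Explanation for error1: The phrase "you send it for" is a direct translation '
    "that doesn't accurately convey the intended meaning of the original Portuguese sentence.\n"
    "Translation correction: Can it be delivered within 10 to 15 minutes?"
)
parsed = parse_xtower_output(response)
print(parsed["correction"])  # Can it be delivered within 10 to 15 minutes?
```

Keying explanations by the numeric error ID lets you match each explanation back to the corresponding `<errorN>` tag in the annotated translation.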