{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"0272ba7f31a2441ab1cb5b8f77dbaacb": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_d1bb171ddebd4f4bbeb4ed5d4b8b7076",
"IPY_MODEL_33b4fc55703746778511265e28160837",
"IPY_MODEL_7548c151f8764276ad7951e2ac80d981"
],
"layout": "IPY_MODEL_d972c72fef7c45998469550318661e71"
}
},
"d1bb171ddebd4f4bbeb4ed5d4b8b7076": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_2811b7c68a7b4c95b91bd5690cf06577",
"placeholder": "​",
"style": "IPY_MODEL_a33ccfdb735948e98a19d901d8091319",
"value": "Loading checkpoint shards: 100%"
}
},
"33b4fc55703746778511265e28160837": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_c1103244cec74a299265729e630faffd",
"max": 4,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_340941cfc49e4ab983b73fb48c30dfe8",
"value": 4
}
},
"7548c151f8764276ad7951e2ac80d981": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_8bb42aa84f4b4a9ab6417aed92132063",
"placeholder": "​",
"style": "IPY_MODEL_b0cf428afc21468caeb664428627aaf6",
"value": " 4/4 [00:11<00:00,  2.57s/it]"
}
},
"d972c72fef7c45998469550318661e71": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"2811b7c68a7b4c95b91bd5690cf06577": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"a33ccfdb735948e98a19d901d8091319": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"c1103244cec74a299265729e630faffd": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"340941cfc49e4ab983b73fb48c30dfe8": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"8bb42aa84f4b4a9ab6417aed92132063": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"b0cf428afc21468caeb664428627aaf6": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
}
}
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/0wwafa/dd034bbb7248e53d39a07c0233bea293/aya.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"!pip install -U bitsandbytes transformers peft accelerate trl datasets sentencepiece wandb\n",
"!pip install flash-attn --no-build-isolation"
],
"metadata": {
"id": "tg1moVggj5sk",
"collapsed": true,
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "418229fd-c995-496e-b6e3-4c0a7ac57ff5"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting bitsandbytes\n",
" Downloading bitsandbytes-0.43.1-py3-none-manylinux_2_24_x86_64.whl (119.8 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m119.8/119.8 MB\u001b[0m \u001b[31m4.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.41.0)\n",
"Collecting transformers\n",
" Downloading transformers-4.41.1-py3-none-any.whl (9.1 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m9.1/9.1 MB\u001b[0m \u001b[31m46.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting peft\n",
" Downloading peft-0.11.1-py3-none-any.whl (251 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m251.6/251.6 kB\u001b[0m \u001b[31m21.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting accelerate\n",
" Downloading accelerate-0.30.1-py3-none-any.whl (302 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m302.6/302.6 kB\u001b[0m \u001b[31m19.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting trl\n",
" Downloading trl-0.8.6-py3-none-any.whl (245 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m245.2/245.2 kB\u001b[0m \u001b[31m15.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting datasets\n",
" Downloading datasets-2.19.1-py3-none-any.whl (542 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m542.0/542.0 kB\u001b[0m \u001b[31m22.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (0.1.99)\n",
"Collecting sentencepiece\n",
" Downloading sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m20.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting wandb\n",
" Downloading wandb-0.17.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.7 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.7/6.7 MB\u001b[0m \u001b[31m22.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from bitsandbytes) (2.3.0+cu121)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from bitsandbytes) (1.25.2)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.14.0)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.23.1)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (24.0)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.12.25)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n",
"Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.1)\n",
"Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.3)\n",
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.4)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft) (5.9.5)\n",
"Collecting tyro>=0.5.11 (from trl)\n",
" Downloading tyro-0.8.4-py3-none-any.whl (102 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m102.4/102.4 kB\u001b[0m \u001b[31m8.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: pyarrow>=12.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (14.0.2)\n",
"Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets) (0.6)\n",
"Collecting dill<0.3.9,>=0.3.0 (from datasets)\n",
" Downloading dill-0.3.8-py3-none-any.whl (116 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m12.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.0.3)\n",
"Collecting xxhash (from datasets)\n",
" Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m10.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting multiprocess (from datasets)\n",
" Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m9.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: fsspec[http]<=2024.3.1,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.9.5)\n",
"Requirement already satisfied: click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb) (8.1.7)\n",
"Collecting docker-pycreds>=0.4.0 (from wandb)\n",
" Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\n",
"Collecting gitpython!=3.1.29,>=1.0.0 (from wandb)\n",
" Downloading GitPython-3.1.43-py3-none-any.whl (207 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.3/207.3 kB\u001b[0m \u001b[31m10.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: platformdirs in /usr/local/lib/python3.10/dist-packages (from wandb) (4.2.2)\n",
"Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (3.20.3)\n",
"Collecting sentry-sdk>=1.0.0 (from wandb)\n",
" Downloading sentry_sdk-2.3.1-py2.py3-none-any.whl (289 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m289.0/289.0 kB\u001b[0m \u001b[31m20.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting setproctitle (from wandb)\n",
" Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb) (67.7.2)\n",
"Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n",
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.2.0)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.5)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n",
"Collecting gitdb<5,>=4.0.1 (from gitpython!=3.1.29,>=1.0.0->wandb)\n",
" Downloading gitdb-4.0.11-py3-none-any.whl (62 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.7/62.7 kB\u001b[0m \u001b[31m7.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.0->transformers) (4.11.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.7)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.2.2)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (1.12)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (3.3)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (3.1.4)\n",
"Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->bitsandbytes)\n",
" Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n",
"Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->bitsandbytes)\n",
" Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n",
"Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->bitsandbytes)\n",
" Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n",
"Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch->bitsandbytes)\n",
" Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)\n",
"Collecting nvidia-cublas-cu12==12.1.3.1 (from torch->bitsandbytes)\n",
" Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n",
"Collecting nvidia-cufft-cu12==11.0.2.54 (from torch->bitsandbytes)\n",
" Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n",
"Collecting nvidia-curand-cu12==10.3.2.106 (from torch->bitsandbytes)\n",
" Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n",
"Collecting nvidia-cusolver-cu12==11.4.5.107 (from torch->bitsandbytes)\n",
" Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n",
"Collecting nvidia-cusparse-cu12==12.1.0.106 (from torch->bitsandbytes)\n",
" Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"MODEL_NAME = \"CohereForAI/aya-23-8b\"\n",
"\n",
"# you may want to change the following parameters depending on your GPU configuration\n",
"\n",
"# free T4 instance\n",
"QUANTIZE_4BIT = True\n",
"USE_GRAD_CHECKPOINTING = True\n",
"TRAIN_BATCH_SIZE = 2\n",
"TRAIN_MAX_SEQ_LENGTH = 512\n",
"USE_FLASH_ATTENTION = False\n",
"GRAD_ACC_STEPS = 16\n",
"\n",
"# equivalent A100 setting\n",
"# QUANTIZE_4BIT = True\n",
"# USE_GRAD_CHECKPOINTING = True\n",
"# TRAIN_BATCH_SIZE = 16\n",
"# TRAIN_MAX_SEQ_LENGTH = 512\n",
"# USE_FLASH_ATTENTION = True\n",
"# GRAD_ACC_STEPS = 2"
],
"metadata": {
"id": "Izn6BYEYw4um"
},
"execution_count": null,
"outputs": []
},
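{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both presets keep the effective batch size the same: the free T4 pairs a small per-device batch with heavy gradient accumulation, while the A100 flips that trade-off. A minimal sanity check using the constants above:"
]
},
{
"cell_type": "code",
"source": [
"# Effective batch size = per-device batch size * gradient accumulation steps.\n",
"# The T4 preset (2 * 16) and the A100 preset (16 * 2) both give 32.\n",
"print(\"Effective batch size:\", TRAIN_BATCH_SIZE * GRAD_ACC_STEPS)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},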
{
"cell_type": "code",
"source": [
"from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,HfArgumentParser,TrainingArguments,pipeline, logging\n",
"from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model\n",
"import os,torch\n",
"import bitsandbytes as bnb\n",
"from datasets import load_dataset\n",
"from trl import SFTTrainer\n",
"from datasets import Dataset\n",
"import pyarrow as pa\n",
"import pyarrow.dataset as ds\n",
"import pandas as pd\n",
"import re\n",
"import wandb"
],
"metadata": {
"id": "wMs9uNDMHL6R"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Load Model\n",
"quantization_config = None\n",
"if QUANTIZE_4BIT:\n",
" quantization_config = BitsAndBytesConfig(\n",
" load_in_4bit=True,\n",
" bnb_4bit_quant_type=\"nf4\",\n",
" bnb_4bit_use_double_quant=True,\n",
" bnb_4bit_compute_dtype=torch.bfloat16,\n",
" )\n",
"\n",
"attn_implementation = None\n",
"if USE_FLASH_ATTENTION:\n",
" attn_implementation=\"flash_attention_2\"\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" MODEL_NAME,\n",
" quantization_config=quantization_config,\n",
" attn_implementation=attn_implementation,\n",
" torch_dtype=torch.bfloat16,\n",
" device_map=\"auto\",\n",
" )"
],
"metadata": {
"id": "d9a23_jiC-qG"
},
"execution_count": null,
"outputs": []
},
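{
"cell_type": "markdown",
"metadata": {},
"source": [
"With 4-bit NF4 quantization the 8B model should fit in a free T4's 16 GB of VRAM. A quick footprint check (a minimal sketch using the `get_memory_footprint` helper from transformers):"
]
},
{
"cell_type": "code",
"source": [
"# Report the parameter memory footprint in GB; with QUANTIZE_4BIT=True this\n",
"# should be roughly a quarter of the bf16 size.\n",
"print(f\"Model footprint: {model.get_memory_footprint() / 1e9:.2f} GB\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},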
{
"cell_type": "code",
"source": [
"# Load tokenizer\n",
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)"
],
"metadata": {
"id": "YuqAA8GhYSdO"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def get_message_format(prompts):\n",
" messages = []\n",
"\n",
" for p in prompts:\n",
" messages.append(\n",
" [{\"role\": \"user\", \"content\": p}]\n",
" )\n",
"\n",
" return messages\n",
"\n",
"def generate_aya_23(\n",
" prompts,\n",
" model,\n",
" temperature=0.3,\n",
" top_p=0.75,\n",
" top_k=0,\n",
" max_new_tokens=1024\n",
" ):\n",
"\n",
" messages = get_message_format(prompts)\n",
"\n",
" input_ids = tokenizer.apply_chat_template(\n",
" messages,\n",
" tokenize=True,\n",
" add_generation_prompt=True,\n",
" padding=True,\n",
" return_tensors=\"pt\",\n",
" )\n",
" input_ids = input_ids.to(model.device)\n",
" prompt_padded_len = len(input_ids[0])\n",
"\n",
" gen_tokens = model.generate(\n",
" input_ids,\n",
" temperature=temperature,\n",
" top_p=top_p,\n",
" top_k=top_k,\n",
" max_new_tokens=max_new_tokens,\n",
" do_sample=True,\n",
" )\n",
"\n",
" # get only generated tokens\n",
" gen_tokens = [\n",
" gt[prompt_padded_len:] for gt in gen_tokens\n",
" ]\n",
"\n",
" gen_text = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)\n",
" return gen_text"
],
"metadata": {
"id": "s75a8Vda-eqx"
},
"execution_count": null,
"outputs": []
},
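{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see the raw prompt format that `apply_chat_template` produces (the same `<|START_OF_TURN_TOKEN|>`/`<|END_OF_TURN_TOKEN|>` markers that the dataset formatting function below writes by hand), render one conversation as text instead of token IDs. A minimal sketch:"
]
},
{
"cell_type": "code",
"source": [
"# Render a single-turn conversation as plain text (tokenize=False) to inspect\n",
"# the special turn markers used by the Aya 23 chat template.\n",
"example = get_message_format([\"Hello!\"])[0]\n",
"print(tokenizer.apply_chat_template(example, tokenize=False, add_generation_prompt=True))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},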
{
"cell_type": "code",
"source": [
"# Test generations on langauges in Aya 23 set\n",
"prompts = [\n",
" \"Write a list of three fruits and tell me about each of them\", # English\n",
" \"Viết danh sách ba loại trái cây và kể cho tôi nghe về từng loại trái cây đó\", # Vietnamese\n",
" \"3 つの果物のリストを書いて、それぞれについて教えてください\", # Japanese\n",
" \"Üç meyveden oluşan bir liste yazın ve bana her birini anlatın\", # Turkish\n",
" \"Scrivi una lista di 3 frutti e parlami di ognuno di essi.\" # Italian\n",
"]\n",
"\n",
"generations = generate_aya_23(prompts, model)\n",
"\n",
"for p, g in zip(prompts, generations):\n",
" print(\n",
" \"PROMPT\", p ,\"RESPONSE\", g, \"\\n\", sep=\"\\n\"\n",
" )"
],
"metadata": {
"id": "4l12EC7q-h3I"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Load MEISD dataset from ZeroWv\n",
"dataset = load_dataset(\"ZeroWw/MEISD\")\n",
"dataset = dataset.filter(lambda example: example['TV Series']=='HIMYM')\n",
"\n",
"def formatting_prompts_func(example):\n",
" output_texts = []\n",
" for i in range(len(example['Utterances'])):\n",
" text = f\"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{example['Utterances'][i]}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{example['emotion'][i]}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{example['emotion2'][i]}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{example['intensity'][i]}\"\n",
" output_texts.append(text)\n",
" return output_texts\n",
"\n",
"def formatting_prompts_func_old(example):\n",
" output_texts = []\n",
" for i in range(len(example['Utterances'])):\n",
" text = f\"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{example['Utterances'][i]}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{example['emotion'][i]}\"\n",
" output_texts.append(text)\n",
" return output_texts"
],
"metadata": {
"id": "CHXm3Io5zCrk"
},
"execution_count": 38,
"outputs": []
},
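{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before wiring up the trainer it is worth eyeballing one formatted training string. A minimal sketch (slicing a Dataset returns the batched dict-of-lists that the formatting function expects):"
]
},
{
"cell_type": "code",
"source": [
"# formatting_prompts_func expects a batched example (a dict of lists),\n",
"# which is exactly what slicing a Dataset produces.\n",
"print(formatting_prompts_func(dataset[\"train\"][:1])[0])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},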
{
"cell_type": "code",
"source": [
"# Training Arguments\n",
"training_arguments = TrainingArguments(\n",
" output_dir=\"results\",\n",
" num_train_epochs=2, # Keep this low for quick training\n",
" per_device_train_batch_size=TRAIN_BATCH_SIZE, # Lower batch size reduces memory usage\n",
" gradient_accumulation_steps=GRAD_ACC_STEPS, # Lower steps reduces memory usage\n",
" gradient_checkpointing=USE_GRAD_CHECKPOINTING, # Disable to reduce memory usage\n",
" optim=\"paged_adamw_32bit\", # Use a more memory-efficient optimizer\n",
" save_steps=100, # Save less frequently to reduce I/O operations\n",
" logging_steps=20, # Log less frequently to reduce I/O operations\n",
" learning_rate=1e-3, # This is usually a good starting point\n",
" weight_decay=0.0, # Disable weight decay to speed up training\n",
" fp16=False, # Enable mixed precision training to reduce memory usage and speed up training\n",
" bf16=True, # Disable bfloat16 unless you're on specific hardware that supports it\n",
" warmup_ratio=0.1, # A higher warmup ratio can sometimes help with training stability\n",
" group_by_length=True, # This can speed up training by reducing the amount of padding needed\n",
" lr_scheduler_type=\"linear\", # A linear scheduler usually works well in practice\n",
" report_to=\"none\" # Disable reporting to save resources\n",
")\n",
"training_arguments_old = TrainingArguments(\n",
" output_dir=\"results\",\n",
" num_train_epochs=3,\n",
" per_device_train_batch_size=TRAIN_BATCH_SIZE,\n",
" gradient_accumulation_steps=GRAD_ACC_STEPS,\n",
" gradient_checkpointing=USE_GRAD_CHECKPOINTING,\n",
" optim=\"paged_adamw_32bit\",\n",
" save_steps=50,\n",
" logging_steps=10,\n",
" learning_rate=1e-2,\n",
" weight_decay=0.01,\n",
" fp16=False,\n",
" bf16=True,\n",
" warmup_ratio=0.05,\n",
" group_by_length=True,\n",
" lr_scheduler_type=\"constant\",\n",
" report_to=\"none\"\n",
")\n",
"\n",
"peft_config = LoraConfig(\n",
" lora_alpha=32,\n",
" r=32,\n",
" bias=\"none\",\n",
" task_type=\"CAUSAL_LM\",\n",
" target_modules=[\"q_proj\", \"v_proj\", \"k_proj\", \"o_proj\"]\n",
")\n",
"\n",
"trainer = SFTTrainer(\n",
" model=model,\n",
" train_dataset=dataset[\"train\"],\n",
" peft_config=peft_config,\n",
" max_seq_length=TRAIN_MAX_SEQ_LENGTH,\n",
" tokenizer=tokenizer,\n",
" args=training_arguments,\n",
" formatting_func=formatting_prompts_func\n",
")"
],
"metadata": {
"id": "A9OdyDDEy7rM",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "59025a17-08c9-4f84-8fc3-6aca606f72fe"
},
"execution_count": 43,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:318: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.\n",
" warnings.warn(\n"
]
}
]
},
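{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `padding_side` warning above comes from the tokenizer's default left padding. It can be silenced by following the warning's own suggestion before re-running the trainer cell; note that left padding remains the safer choice for the batched `generate_aya_23` helper. A sketch:"
]
},
{
"cell_type": "code",
"source": [
"# Optional: pad on the right during training, as the SFTTrainer warning suggests.\n",
"# Set this before constructing SFTTrainer for it to take effect.\n",
"tokenizer.padding_side = \"right\"\n",
"# For batched generation, switch back: tokenizer.padding_side = \"left\""
],
"metadata": {},
"execution_count": null,
"outputs": []
},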
{
"cell_type": "code",
"source": [
"trainer.train()"
],
"metadata": {
"id": "9BvK-3eYiwhx",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 147
},
"outputId": "76ac8ea6-a418-4752-ef81-ba30c6aae267"
},
"execution_count": 1,
"outputs": [
{
"output_type": "error",
"ename": "NameError",
"evalue": "name 'trainer' is not defined",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-1-3435b262f1ae>\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mNameError\u001b[0m: name 'trainer' is not defined"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Save the model to disk\n",
"trainer.model.save_pretrained(save_directory='aya-qlora')\n",
"model.config.use_cache = True\n",
"model.eval()"
],
"metadata": {
"id": "X3Lqfwo-8CCG"
},
"execution_count": null,
"outputs": []
},
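{
"cell_type": "markdown",
"metadata": {},
"source": [
"For standalone deployment the LoRA adapter can optionally be merged into a full-precision copy of the base weights. A sketch, assuming enough memory for the unquantized 8B model; the output directory name is illustrative:"
]
},
{
"cell_type": "code",
"source": [
"# Reload the base model in bf16 (merging into 4-bit quantized weights is lossy),\n",
"# apply the saved adapter, then fold the LoRA deltas into the base weights.\n",
"base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16, device_map=\"cpu\")\n",
"merged = PeftModel.from_pretrained(base, \"aya-qlora\").merge_and_unload()\n",
"merged.save_pretrained(\"aya-qlora-merged\")  # hypothetical output path"
],
"metadata": {},
"execution_count": null,
"outputs": []
},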
{
"cell_type": "code",
"source": [
"# Test Bengali inference on loaded fine-tuned model\n",
"\n",
"# Load Model and LoRA Adapter\n",
"quantization_config = None\n",
"if QUANTIZE_4BIT:\n",
" quantization_config = BitsAndBytesConfig(\n",
" load_in_4bit=True,\n",
" bnb_4bit_quant_type=\"nf4\",\n",
" bnb_4bit_use_double_quant=True,\n",
" bnb_4bit_compute_dtype=torch.bfloat16,\n",
" )\n",
"\n",
"attn_implementation = None\n",
"if USE_FLASH_ATTENTION:\n",
" attn_implementation=\"flash_attention_2\"\n",
"\n",
"loaded_model = AutoModelForCausalLM.from_pretrained(\n",
" MODEL_NAME,\n",
" quantization_config=quantization_config,\n",
" attn_implementation=attn_implementation,\n",
" torch_dtype=torch.bfloat16,\n",
" device_map=\"auto\",\n",
" )\n",
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
"loaded_model.load_adapter(\"aya-qlora\")\n",
"\n",
"\n",
"prompts = [\n",
" 'Translate from English to Bengali: \"Rates are competitive, almost always the best in the market\"'\n",
"]\n",
"\n",
"generations = generate_aya_23(prompts, loaded_model)\n",
"\n",
"for p, g in zip(prompts, generations):\n",
" print(\n",
" \"PROMPT\", p ,\"RESPONSE\", g, \"\\n\", sep=\"\\n\"\n",
" )"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 174,
"referenced_widgets": [
"0272ba7f31a2441ab1cb5b8f77dbaacb",
"d1bb171ddebd4f4bbeb4ed5d4b8b7076",
"33b4fc55703746778511265e28160837",
"7548c151f8764276ad7951e2ac80d981",
"d972c72fef7c45998469550318661e71",
"2811b7c68a7b4c95b91bd5690cf06577",
"a33ccfdb735948e98a19d901d8091319",
"c1103244cec74a299265729e630faffd",
"340941cfc49e4ab983b73fb48c30dfe8",
"8bb42aa84f4b4a9ab6417aed92132063",
"b0cf428afc21468caeb664428627aaf6"
]
},
"id": "w5HGIJtRJN-y",
"outputId": "441193fe-89fa-40ad-8585-d1f2dcf124e5"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "0272ba7f31a2441ab1cb5b8f77dbaacb"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"PROMPT\n",
"Translate from English to Bengali: \"Rates are competitive, almost always the best in the market\"\n",
"RESPONSE\n",
"\"দরগুলি প্রতিযোগিতামূলক, প্রায় সবসময় বাজারে সেরা\"\n",
"\n",
"\n"
]
}
]
}
]
}