diff --git a/src/llama_recipes/configs/datasets.py b/src/llama_recipes/configs/datasets.py
index 62230e4..3e4c76f 100644
--- a/src/llama_recipes/configs/datasets.py
+++ b/src/llama_recipes/configs/datasets.py
@@ -15,8 +15,8 @@ class samsum_dataset:
 
 
 @dataclass
 class grammar_dataset:
     dataset: str = "grammar_dataset"
-    train_split: str = "src/llama_recipes/datasets/grammar_dataset/gtrain_10k.csv"
-    test_split: str = "src/llama_recipes/datasets/grammar_dataset/grammar_validation.csv"
+    train_split: str = "src/llama_recipes/datasets2/grammar_dataset/gtrain_10k.csv"
+    test_split: str = "src/llama_recipes/datasets2/grammar_dataset/grammar_validation.csv"
     input_length: int = 2048
@@ -25,7 +25,7 @@ class alpaca_dataset:
     dataset: str = "alpaca_dataset"
     train_split: str = "train"
     test_split: str = "val"
-    data_path: str = "src/llama_recipes/datasets/alpaca_data.json"
+    data_path: str = "src/llama_recipes/datasets2/alpaca_data.json"
 
 
 @dataclass
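
Note: the two hunks above repoint the config defaults from src/llama_recipes/datasets/... to src/llama_recipes/datasets2/...; the hunks below delete the old package. The patch itself gives no motivation, but a plausible one is that a local package named datasets can shadow the Hugging Face datasets library when Python resolves imports from the source tree. A minimal sketch of that failure mode (hypothetical repro, assuming the interpreter starts in a directory containing the old datasets/ package):

# Hypothetical repro of the name clash the rename would avoid: with a local
# ./datasets package earlier on sys.path, "import datasets" binds to the local
# package, which defines no load_dataset, so the check below prints False.
import datasets
print(hasattr(datasets, "load_dataset"))
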
diff --git a/src/llama_recipes/datasets/__init__.py b/src/llama_recipes/datasets/__init__.py
deleted file mode 100644
index 57d2376..0000000
--- a/src/llama_recipes/datasets/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
-from llama_recipes.datasets.grammar_dataset.grammar_dataset import get_dataset as get_grammar_dataset
-from llama_recipes.datasets.alpaca_dataset import InstructionDataset as get_alpaca_dataset
-from llama_recipes.datasets.samsum_dataset import get_preprocessed_samsum as get_samsum_dataset
\ No newline at end of file
diff --git a/src/llama_recipes/datasets/alpaca_dataset.py b/src/llama_recipes/datasets/alpaca_dataset.py
deleted file mode 100644
index 091aef9..0000000
--- a/src/llama_recipes/datasets/alpaca_dataset.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
-# For dataset details visit: https://crfm.stanford.edu/2023/03/13/alpaca.html
-
-import copy
-import json
-
-import torch
-from torch.utils.data import Dataset
-
-
-PROMPT_DICT = {
-    "prompt_input": (
-        "Below is an instruction that describes a task, paired with an input that provides further context. "
-        "Write a response that appropriately completes the request.\n\n"
-        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
-    ),
-    "prompt_no_input": (
-        "Below is an instruction that describes a task. "
-        "Write a response that appropriately completes the request.\n\n"
-        "### Instruction:\n{instruction}\n\n### Response:"
-    ),
-}
-
-class InstructionDataset(Dataset):
-    def __init__(self, dataset_config, tokenizer, partition="train", max_words=30):
-        self.ann = json.load(open(dataset_config.data_path))
-        if partition == "train":
-            self.ann = self.ann
-        else:
-            self.ann = self.ann[:200]
-
-        self.max_words = max_words
-        # tokenizer = Tokenizer(model_path=model_path + "./tokenizer.model")
-        self.tokenizer = tokenizer
-        # self.tokenizer1 = tokenizer
-
-    def __len__(self):
-        return len(self.ann)
-
-    def __getitem__(self, index):
-        IGNORE_INDEX = -100  # The default setting in CrossEntropyLoss
-
-
-        ann = self.ann[index]
-        if ann.get("input", "") == "":
-            prompt = PROMPT_DICT["prompt_no_input"].format_map(ann)
-        else:
-            prompt = PROMPT_DICT["prompt_input"].format_map(ann)
-        example = prompt + ann["output"]
-        prompt = torch.tensor(
-            self.tokenizer.encode(prompt), dtype=torch.int64
-        )
-        example = self.tokenizer.encode(example)
-        example.append(self.tokenizer.eos_token_id)
-        example = torch.tensor(
-            example, dtype=torch.int64
-        )
-        padding = self.max_words - example.shape[0]
-        if padding > 0:
-            example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1))
-        elif padding < 0:
-            example = example[: self.max_words]
-        labels = copy.deepcopy(example)
-        labels[: len(prompt)] = -1
-        example_mask = example.ge(0)
-        label_mask = labels.ge(0)
-        example[~example_mask] = 0
-        labels[~label_mask] = IGNORE_INDEX
-        example_mask = example_mask.float()
-        label_mask = label_mask.float()
-
-        return {
-            "input_ids": example,
-            "labels": labels,
-            "attention_mask":example_mask,
-        }
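
Note: the InstructionDataset removed above pads or truncates each example to max_words tokens and masks the prompt portion out of the loss with -100, the ignore_index that torch.nn.CrossEntropyLoss skips by default. A minimal sketch of that masking step, using a made-up six-token example whose first four tokens are the prompt:

import torch

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss
example = torch.tensor([5, 7, 9, 2, 4, 1])  # hypothetical prompt + response token ids
labels = example.clone()
labels[:4] = IGNORE_INDEX  # prompt tokens contribute no loss
print(labels)  # tensor([-100, -100, -100, -100,    4,    1])
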
diff --git a/src/llama_recipes/datasets/grammar_dataset/__init__.py b/src/llama_recipes/datasets/grammar_dataset/__init__.py
deleted file mode 100644
index b193f67..0000000
--- a/src/llama_recipes/datasets/grammar_dataset/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
diff --git a/src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py b/src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py
deleted file mode 100644
index 47383c4..0000000
--- a/src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
-# For dataset details visit: https://huggingface.co/datasets/jfleg
-# For download and preparation see: recipes/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb
-
-
-from datasets import load_dataset
-from pathlib import Path
-
-from torch.utils.data import Dataset
-
-from llama_recipes.datasets.utils import ConcatDataset
-
-
-class grammar(Dataset):
-    def __init__(
-        self,
-        tokenizer,
-        csv_name=None,
-    ):
-
-        try:
-            self.dataset = load_dataset(
-                "csv",
-                data_files={"train": [csv_name]},  # "eval": "grammar_validation.csv"},
-                delimiter=",",
-            )
-        except Exception as e:
-            print("Loading of grammar dataset failed! Please see recipes/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb for details on how to download the dataset.")
-            raise e
-
-        # self.dataset = load_dataset("wikihow", "all", data_dir="data/", split=type_path)
-        # if num_samples:
-        #     self.dataset = self.dataset.select(list(range(0, num_samples)))
-        self.tokenizer = tokenizer
-        self.print_text = False  # print_text
-
-    def __len__(self):
-        return self.dataset["train"].shape[0]
-
-    def convert_to_features(self, example_batch):
-
-        # Create prompt and tokenize contexts and questions
-
-        if self.print_text:
-            print("Input Text: ", self.clean_text(example_batch["text"]))
-
-        input_ = example_batch["input"]
-        target_ = example_batch["target"]
-
-        prompt = f"Correct this to standard English: {input_}\n---\nCorrected: {target_}"
-        sample = self.tokenizer(prompt)
-
-        return sample
-
-    def __getitem__(self, index):
-        sample = self.convert_to_features(self.dataset["train"][index])
-        source_ids = sample["input_ids"]
-
-        src_mask = sample["attention_mask"]
-
-        return {
-            "input_ids": source_ids,
-            "attention_mask": src_mask,
-            "labels": source_ids.copy(),
-        }
-
-
-def get_dataset(
-    dataset_config, tokenizer, csv_name=None
-):
-    """cover function for handling loading the working dataset"""
-    """dataset loading"""
-    if csv_name is None:
-        currPath = Path.cwd() / "datasets_grammar" / "grammar_train.csv"
-        print(f"Loading dataset {currPath}")
-        csv_name = str(currPath)
-    dataset = grammar(
-        tokenizer=tokenizer,
-        csv_name=csv_name,
-    )
-
-    return ConcatDataset(dataset, chunk_size=dataset_config.input_length)
-
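
Note: the deleted loader builds one prompt string per CSV row before tokenizing, then packs rows into fixed-length chunks with ConcatDataset. A standalone sketch of the prompt template, with hypothetical values for the CSV's input and target columns:

input_ = "I has a apple."  # hypothetical "input" column value
target_ = "I have an apple."  # hypothetical "target" column value
prompt = f"Correct this to standard English: {input_}\n---\nCorrected: {target_}"
print(prompt)
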
diff --git a/src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb b/src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb
deleted file mode 100644
index ccbddca..0000000
--- a/src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb
+++ /dev/null
@@ -1,463 +0,0 @@
-{
- "cells": [
-  {
-   "attachments": {},
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Copyright (c) Meta Platforms, Inc. and affiliates.\n",
-    "This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.\n",
-    "\n",
-    "Use this notebook to pull in datasets and apply pre-processing. Most grammar datasets unfortunately require preprocessing before being usable in training. (example - jfleg has 4 targets per input, so we have to rematch as 1:1 pairings) "
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [],
-
-   "source": [
-    "import csv\n",
-    "from datasets import load_metric, load_dataset\n",
-    "from pathlib import Path"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "list_replacements = [\n",
-    "  (\" .\", \".\"), \n",
-    "  (\" ,\", \",\"),\n",
-    "  (\" '\", \"'\"),\n",
-    "  (\" ?\", \"?\"),\n",
-    "  (\" !\", \"!\"),\n",
-    "  (\" :\", \"!\"),\n",
-    "  (\" ;\", \"!\"),\n",
-    "  (\" n't\", \"n't\"),\n",
-    "  (\" v\", \"n't\"),\n",
-    "  (\"2 0 0 6\", \"2006\"),\n",
-    "  (\"5 5\", \"55\"),\n",
-    "  (\"4 0 0\", \"400\"),\n",
-    "  (\"1 7-5 0\", \"1750\"),\n",
-    "  (\"2 0 %\", \"20%\"),\n",
-    "  (\"5 0\", \"50\"),\n",
-    "  (\"1 2\", \"12\"),\n",
-    "  (\"1 0\", \"10\"),\n",
-    "  ('\" ballast water', '\"ballast water')\n",
-    "  ]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def correct_spacing(item):\n",
-    "    \"\"\" we iterate through the list of all replacements per each item in dataset\"\"\"\n",
-    "    for fix in list_replacements:\n",
-    "        item = item.replace(fix[0], fix[1])\n",
-    "    return item\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def generate_csv(csv_path, dataset):\n",
-    "    \"\"\" apply spacing corrections and save out matched pairs to csv file as dataset\"\"\"\n",
-    "    with open(csv_path, 'w', newline='') as csvfile:\n",
-    "        writer = csv.writer(csvfile)\n",
-    "        writer.writerow([\"input\", \"target\"])\n",
-    "        for case in dataset:\n",
-    "        \t # Adding the t5 task indication prefix to input \n",
-
-    "            input_text = case[\"sentence\"]\n",
-
-    "            input_text = correct_spacing(input_text)\n",
-    "\n",
-    "            for correction in case[\"corrections\"]:\n",
-    "                correction = correct_spacing(correction)\n",
-    "                # a few of the cases contain blank strings. \n",
-    "                if input_text and correction:\n",
-    "                    writer.writerow([input_text, correction])"
-   ]
-  },
-  {
-   "attachments": {},
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "In Jfleg - validation will be used as 'train', test will be 'validation'"
-   ]
-  },
-  {
-   "cell_type": "code",
-
-   "execution_count": 5,
-
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-
-      "Found cached dataset jfleg (/data/home/mreso/.cache/huggingface/datasets/jfleg/default/1.0.0/ed4ab2367351fe31949f48849ae6732b164f0d5ea6bb5d4357ff4293ac89511b)\n",
-      "Found cached dataset jfleg (/data/home/mreso/.cache/huggingface/datasets/jfleg/default/1.0.0/ed4ab2367351fe31949f48849ae6732b164f0d5ea6bb5d4357ff4293ac89511b)\n"
-
-     ]
-    }
-   ],
-   "source": [
-    "train_dataset = load_dataset(\"jfleg\", split='validation[:]') \n",
-    "eval_dataset = load_dataset(\"jfleg\", split='test[:]')\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-
-   "execution_count": 6,
-
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Dataset({\n",
-      "    features: ['sentence', 'corrections'],\n",
-      "    num_rows: 755\n",
-      "})\n",
-      "Dataset({\n",
-      "    features: ['sentence', 'corrections'],\n",
-      "    num_rows: 748\n",
-      "})\n"
-     ]
-    }
-   ],
-   "source": [
-    "print(train_dataset)\n",
-    "print(eval_dataset)\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-
-   "execution_count": 7,
-
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Students can focus on only a few subjects they are intwerested in and they will become an experts in those areas . \n",
-      "['Students can focus on only a few subjects they are interested in and they will become experts in those areas . ', 'Students can focus on only a few subjects they are interested in and they will become experts in those areas . ', 'Students can focus on only a few subjects they are interested in and they will become an expert in those areas . ', 'Students can focus on only a few subjects they are interested in and they will become an expert in those areas . ']\n"
-     ]
-    }
-   ],
-   "source": [
-    "print(train_dataset['sentence'][22])\n",
-    "print(train_dataset['corrections'][22])"
-   ]
-  },
-  {
-   "cell_type": "code",
-
-   "execution_count": 8,
-
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "'Students can focus on only a few subjects they are intwerested in and they will become an experts in those areas. '"
-      ]
-     },
-
-     "execution_count": 8,
-
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "clean22 = correct_spacing(train_dataset['sentence'][22])\n",
-    "clean22"
-   ]
-  },
-  {
-   "cell_type": "code",
-
-   "execution_count": 9,
-
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "jfleg_dir = Path.cwd()/'jfleg_dataset' # if you only use 'jfleg', hf will try and use that and complain\n",
-    "jfleg_dir.mkdir(parents=True,exist_ok=True)\n",
-    "c4_dir = Path.cwd()/'c4_dataset'\n",
-    "c4_dir.mkdir(parents=True,exist_ok=True)"
-   ]
-  },
-  {
-   "attachments": {},
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Process Jfleg data "
-   ]
-  },
-  {
-   "cell_type": "code",
-
-   "execution_count": 10,
-
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "j_train_file = jfleg_dir/'jtrain.csv'\n",
-    "j_eval_file = jfleg_dir/'jeval.csv'"
-   ]
-  },
-  {
-   "cell_type": "code",
-
-   "execution_count": 11,
-
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "generate_csv(j_train_file, train_dataset)"
-   ]
-  },
-  {
-   "cell_type": "code",
-
-   "execution_count": 12,
-
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "generate_csv(j_eval_file, eval_dataset)"
-   ]
-  },
-  {
-   "attachments": {},
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Process C4_200M (!) - we'll pull 10K to start"
-   ]
-  },
-  {
-   "cell_type": "code",
-
-   "execution_count": 13,
-
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "c4_dataset = load_dataset(\"liweili/c4_200m\", streaming = True)"
-   ]
-  },
-  {
-   "cell_type": "code",
-
-   "execution_count": 14,
-
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "iterator = iter(c4_dataset['train'])"
-   ]
-  },
-  {
-   "cell_type": "code",
-
-   "execution_count": 15,
-
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def c4_generate_csv(csv_path, iterator, num_examples):\n",
-    "    with open(csv_path, 'w', newline='') as csvfile:\n",
-    "        writer = csv.writer(csvfile)\n",
-    "        writer.writerow([\"input\", \"target\"])\n",
-    "        for i in range(0,num_examples):\n",
-    "            data = next(iterator)\n",
-
-    "            input_text = data[\"input\"]\n",
-
-    "            input_text = correct_spacing(input_text)\n",
-    "            correction = correct_spacing(data[\"output\"])\n",
-    "            if input_text and correction:\n",
-    "                writer.writerow([input_text, correction])"
-   ]
-  },
-  {
-   "cell_type": "code",
-
-   "execution_count": 16,
-
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "c4_dir = Path.cwd()/'c4_dataset'\n",
-    "c4_dir.mkdir(parents=True,exist_ok=True)"
-   ]
-  },
-  {
-   "attachments": {},
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "You can modify the following to make the csv file with desired number of instances, here we go for 10k to make a quick test"
-   ]
-  },
-  {
-   "cell_type": "code",
-
-   "execution_count": 17,
-
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "c4_filename = c4_dir/'c4train_10k.csv'"
-   ]
-  },
-  {
-   "cell_type": "code",
-
-   "execution_count": 18,
-
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "c4_generate_csv(c4_filename, iterator, num_examples=10000)"
-   ]
-  },
-  {
-   "attachments": {},
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Create a single training file by combining jtrain and c4train"
-   ]
-  },
-  {
-   "cell_type": "code",
-
-   "execution_count": 19,
-
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "merge_list = [j_train_file, c4_filename, ]"
-   ]
-  },
-  {
-   "cell_type": "code",
-
-   "execution_count": 20,
-
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import pandas as pd"
-   ]
-  },
-  {
-   "cell_type": "code",
-
-   "execution_count": 21,
-
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "combined_csv = pd.concat([pd.read_csv(fn) for fn in merge_list])\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-
-   "execution_count": 22,
-
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "merged_name = \"gtrain_10k.csv\""
-   ]
-  },
-  {
-   "cell_type": "code",
-
-   "execution_count": 23,
-
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "combined_csv.to_csv(merged_name, index=False, encoding = 'utf-8-sig', )"
-   ]
-  },
-  {
-   "cell_type": "code",
-
-   "execution_count": 24,
-
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "eval_name = \"grammar_validation.csv\""
-   ]
-
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 25,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "eval_csv = pd.read_csv(j_eval_file)\n",
-    "eval_csv.to_csv(eval_name, index=False, encoding = 'utf-8-sig', )"
-   ]
-
-  }
- ],
- "metadata": {
-  "interpreter": {
-   "hash": "5b2c14c5f2a3b21e6c2412c8196f5145870350e81c0b737cae3e5c60eb1e1eac"
-  },
-  "kernelspec": {
-
-   "display_name": "Python 3 (ipykernel)",
-
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.10.11"
-
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-
-}
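
Note: the core transform in the deleted notebook is flattening jfleg's one-to-many sentence/corrections records into 1:1 CSV rows, skipping blank targets. A self-contained sketch of that pairing logic, using an inline record that mirrors jfleg's schema instead of the real dataset:

import csv

case = {  # hypothetical jfleg-style record
    "sentence": "They was happy .",
    "corrections": ["They were happy.", "", "They were happy ."],
}
with open("pairs.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["input", "target"])
    for correction in case["corrections"]:
        if case["sentence"] and correction:  # skip blank strings, as the notebook does
            writer.writerow([case["sentence"], correction])
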
diff --git a/src/llama_recipes/datasets/samsum_dataset.py b/src/llama_recipes/datasets/samsum_dataset.py
deleted file mode 100644
index fd91782..0000000
--- a/src/llama_recipes/datasets/samsum_dataset.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
-# For dataset details visit: https://huggingface.co/datasets/samsum
-
-import datasets
-
-from llama_recipes.datasets.utils import Concatenator
-
-def get_preprocessed_samsum(dataset_config, tokenizer, split):
-    dataset = datasets.load_dataset("samsum", split=split)
-
-    prompt = (
-        f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n{{summary}}{{eos_token}}"
-    )
-
-    def apply_prompt_template(sample):
-        return {
-            "text": prompt.format(
-                dialog=sample["dialogue"],
-                summary=sample["summary"],
-                eos_token=tokenizer.eos_token,
-            )
-        }
-
-    dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
-
-    dataset = dataset.map(
-        lambda sample: tokenizer(sample["text"]),
-        batched=True,
-        remove_columns=list(dataset.features),
-    ).map(Concatenator(), batched=True)
-    return dataset
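
Note: the deleted samsum loader renders every dialogue/summary pair through a single template and appends the tokenizer's EOS token, so each packed example ends cleanly. A sketch of the rendered text, with made-up values and "</s>" standing in for the EOS token:

prompt = "Summarize this dialog:\n{dialog}\n---\nSummary:\n{summary}{eos_token}"
text = prompt.format(
    dialog="A: Hi!\nB: Hi, how are you?",  # made-up dialogue
    summary="A and B greet each other.",  # made-up summary
    eos_token="</s>",  # stand-in EOS token
)
print(text)
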
diff --git a/src/llama_recipes/datasets/utils.py b/src/llama_recipes/datasets/utils.py
deleted file mode 100644
index 0a11d8c..0000000
--- a/src/llama_recipes/datasets/utils.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
-from tqdm import tqdm
-from itertools import chain
-
-from torch.utils.data import Dataset
-
-class Concatenator(object):
-    def __init__(self, chunk_size=2048):
-        self.chunk_size=chunk_size
-        self.residual = {"input_ids": [], "attention_mask": []}
-
-    def __call__(self, batch):
-        concatenated_samples = {
-            k: v + list(chain(*batch[k])) for k, v in self.residual.items()
-        }
-
-        total_length = len(concatenated_samples[list(concatenated_samples.keys())[0]])
-
-        if total_length >= self.chunk_size:
-            chunk_num = total_length // self.chunk_size
-            result = {
-                k: [
-                    v[i : i + self.chunk_size]
-                    for i in range(0, chunk_num * self.chunk_size, self.chunk_size)
-                ]
-                for k, v in concatenated_samples.items()
-            }
-            self.residual = {
-                k: v[(chunk_num * self.chunk_size) :]
-                for k, v in concatenated_samples.items()
-            }
-        else:
-            result = concatenated_samples
-            self.residual = {k: [] for k in concatenated_samples.keys()}
-
-        result["labels"] = result["input_ids"].copy()
-
-        return result
-
-class ConcatDataset(Dataset):
-    def __init__(self, dataset, chunk_size=4096):
-        self.dataset = dataset
-        self.chunk_size = chunk_size
-
-        self.samples = []
-
-        buffer = {
-            "input_ids": [],
-            "attention_mask": [],
-            "labels": [],
-        }
-
-        for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
-            buffer = {k: v + sample[k] for k,v in buffer.items()}
-
-            while len(next(iter(buffer.values()))) > self.chunk_size:
-                self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()})
-                buffer = {k: v[self.chunk_size:] for k,v in buffer.items()}
-
-    def __getitem__(self, idx):
-        return self.samples[idx]
-
-    def __len__(self):
-        return len(self.samples)
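
Note: Concatenator (removed above) packs variable-length tokenized samples into fixed-size chunks, carrying the remainder over between batched map calls; ConcatDataset does the same eagerly over a whole dataset. A minimal sketch of the chunk-and-carry logic on plain lists, with chunk_size=4:

from itertools import chain

chunk_size = 4
residual = []
for batch in ([[1, 2, 3], [4, 5]], [[6, 7, 8]]):  # two toy batches of token lists
    stream = residual + list(chain(*batch))  # prepend leftover from the last call
    n = (len(stream) // chunk_size) * chunk_size
    print([stream[i:i + chunk_size] for i in range(0, n, chunk_size)])
    residual = stream[n:]  # carried into the next batch
# prints [[1, 2, 3, 4]] then [[5, 6, 7, 8]]
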
diff --git a/src/llama_recipes/utils/dataset_utils.py b/src/llama_recipes/utils/dataset_utils.py
index 6d5f02c..18955b6 100644
--- a/src/llama_recipes/utils/dataset_utils.py
+++ b/src/llama_recipes/utils/dataset_utils.py
@@ -7,7 +7,7 @@ from pathlib import Path
 
 import torch
 
-from llama_recipes.datasets import (
+from llama_recipes.datasets2 import (
     get_grammar_dataset,
     get_alpaca_dataset,
     get_samsum_dataset,
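
Note: after this patch every dataset import resolves through llama_recipes.datasets2. The diff only shows the old files being deleted, so it assumes the same modules were re-added under datasets2/ (those additions are not part of this excerpt). A quick smoke test of the patched layout, assuming the renamed package exports the same names:

# Should import cleanly on the patched tree; on the old layout it raises
# ModuleNotFoundError because llama_recipes.datasets2 does not exist.
from llama_recipes.datasets2 import (
    get_grammar_dataset,
    get_alpaca_dataset,
    get_samsum_dataset,
)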