Created
November 23, 2022 15:41
-
-
Save Narsil/4f5b088f4dd23200d16dd2cc575fdc16 to your computer and use it in GitHub Desktop.
A few ways of using datasets with pipelines.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime
import time

import torch
from datasets import load_dataset
from transformers import pipeline
pipe = pipeline(
    "automatic-speech-recognition",
    model="hf-internal-testing/tiny-random-wav2vec2",
    device=0,
)
dataset = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy",
    "clean",
    split="validation[:10]",
)
# Audio file paths, one per row; method 3 hands these straight to the pipeline.
filenames = [row["audio"]["path"] for row in dataset]
# Access the decoded audio column once before anything is timed.
# NOTE(review): the original author was unsure this helps ("Warmup the dataset ??");
# presumably it only warms OS file caches / decoder initialisation — confirm.
for item in dataset["audio"]:
    pass
def data():
    """Yield the decoded audio dicts of ``dataset["audio"]`` one at a time.

    Used by method 1 to feed the pipeline from a generator.
    NOTE(review): depending on the ``datasets`` version, indexing
    ``dataset["audio"]`` may decode the whole column up front when the
    generator is first advanced.
    """
    yield from dataset["audio"]
# Force CUDA initialisation before any timing starts, so one-off context setup
# is not charged to a benchmark. (pipeline(..., device=0) above has presumably
# done this already; allocating a tiny tensor on the GPU makes it explicit.)
torch.zeros((2, 2), device="cuda")
def method1():
    """Benchmark method 1: stream the audio dicts from ``data()`` into ``pipe``.

    The pipeline is given a generator and iterates over it itself.

    Returns:
        list[str]: the transcription text for each input.
    """
    # Building the list consumes the pipeline's output iterator, so all the
    # inference work happens inside this call (and inside the timed region).
    return [out["text"] for out in pipe(data())]
def predict(batch):
    """Transcribe one batch of rows; used as the ``Dataset.map`` function.

    Args:
        batch: a batched-map dict whose ``"audio"`` entry is a list of decoded
            audio dicts with ``"array"`` and ``"sampling_rate"`` keys.

    Returns:
        The same ``batch`` dict with a ``"predictions"`` list of texts added.
    """
    # Rename the `datasets` key "array" to "raw" for the pipeline's dict input.
    renamed = [
        {"raw": clip["array"], "sampling_rate": clip["sampling_rate"]}
        for clip in batch["audio"]
    ]
    texts = [result["text"] for result in pipe(renamed)]
    batch["predictions"] = texts
    return batch
def method2():
    """Benchmark method 2: let ``Dataset.map`` feed batches of 2 rows to ``predict``.

    ``load_from_cache_file=False`` forces the transform to actually run on
    every call. Otherwise ``datasets`` may reuse a cached result from an
    earlier run of this script, and the timing would measure a cache read
    instead of inference.

    Returns:
        datasets.Dataset: the mapped dataset, whose only column is ``"predictions"``.
    """
    return dataset.map(
        predict,
        batched=True,
        batch_size=2,
        # Drop every original column without hard-coding the schema, so only
        # the "predictions" column added by `predict` is written.
        remove_columns=dataset.column_names,
        load_from_cache_file=False,
    )
def method3():
    """Benchmark method 3: pass the audio file paths straight to ``pipe``.

    The pipeline is responsible for reading and decoding each file.

    Returns:
        list[str]: the transcription text for each file.
    """
    # Building the list consumes the pipeline's output iterator, so all the
    # inference work happens inside this call (and inside the timed region).
    return [out["text"] for out in pipe(filenames)]
# Untimed warm-up run, so one-off first-call costs are not charged to the
# timed run of method 1.
method1()

# time.perf_counter() is a monotonic, high-resolution clock intended for
# measuring intervals; datetime.datetime.now() follows the wall clock, which
# can be adjusted mid-run. Results are still printed as a timedelta, so the
# output format is unchanged.
for label, method in (
    ("Method 1 (pipe)", method1),
    ("Method 2 (dataset)", method2),
    ("Method 3 (raw file)", method3),
):
    start = time.perf_counter()
    method()
    print(label, datetime.timedelta(seconds=time.perf_counter() - start))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment