Last active
June 11, 2025 05:58
-
-
Save hamees-sayed/08e9439c74a29d6654a9fefad9bcdc59 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
from datasets.arrow_writer import ArrowWriter | |
from tqdm import tqdm | |
import datasets | |
import json | |
import re | |
import os | |
import csv | |
import string | |
import jieba | |
jieba.initialize()  # warm up jieba's dictionary once so the first cut() call is fast

# Pipeline paths.
duration_path = "data/duration.json"  # output: per-sample durations as JSON
csv_path = "/home/hamees/asr_benchmark/combined_data_langs_hi_en.csv"  # input: source metadata CSV
arrow_file_path = "data/raw.arrow"  # output: filtered dataset in Arrow format
metadata_output = "data/metadata.csv"  # output: pipe-delimited audio_path|text file
def convert_char_to_pinyin(text_list, language_list=None, suffix_languages=None, polyphone=True):
    """Split each text into a per-character token list, optionally tagging
    characters with a language suffix.

    Args:
        text_list: list of input strings.
        language_list: optional list of language codes, parallel to text_list.
        suffix_languages: languages whose characters get a ``_<lang>`` suffix;
            defaults to ["fr", "de", "pl", "es", "it", "nl"].
        polyphone: unused; kept for interface compatibility.

    Returns:
        A list parallel to text_list; each element is a list of single-character
        tokens (possibly suffixed, e.g. "a_fr").
    """
    if suffix_languages is None:
        suffix_languages = ["fr", "de", "pl", "es", "it", "nl"]

    # Normalize semicolons and curly quotes once; shared by both branches.
    # (The original built this table twice, once per branch.)
    custom_trans = str.maketrans(
        {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
    )

    # BUGFIX: this helper was previously defined only inside the
    # `language_list is not None` branch but was also used in the else branch,
    # so calling without a language_list raised NameError. Define it once here.
    def is_punctuation_or_space(c):
        return c in string.punctuation or c.isspace()

    def _char_tokens(text, lang):
        # Segment with jieba, then emit one token per character. Characters of
        # a language in suffix_languages get a "_<lang>" suffix; punctuation
        # and whitespace are never suffixed.
        char_list = []
        for seg in jieba.cut(text.translate(custom_trans)):
            for c in seg:
                if is_punctuation_or_space(c):
                    char_list.append(c)
                elif lang is not None and lang in suffix_languages:
                    char_list.append(f"{c}_{lang}")
                else:
                    char_list.append(c)
        return char_list

    if language_list is not None:
        final_text_list = [
            _char_tokens(text, lang) for text, lang in zip(text_list, language_list)
        ]
        # Hindi/English: fold the Marathi-specific "ळ" into "ल".
        # NOTE(review): as in the original, this applies to *every* sublist when
        # any language in the batch is hi/en, and it runs before the mr pass
        # below — confirm this ordering is intended for mixed-language batches.
        if 'hi' in language_list or 'en' in language_list:
            for sublist in final_text_list:
                for i, ch in enumerate(sublist):
                    if ch == "ळ":
                        sublist[i] = "ल"
        # Marathi: tag characters pronounced differently than in Hindi.
        if 'mr' in language_list:
            mr_tagged = {"ळ": "ळ_mr", "ं": "ं_mr", "च": "च_mr", "ज": "ज_mr"}
            for sublist in final_text_list:
                for i, ch in enumerate(sublist):
                    if ch in mr_tagged:
                        sublist[i] = mr_tagged[ch]
        return final_text_list

    # No language information provided: never add suffixes.
    return [_char_tokens(text, None) for text in text_list]
def clear_text(text):
    """Normalize a transcript: flatten newlines into spaces, collapse runs of
    whitespace, and trim surrounding double quotes, single quotes, and edge
    whitespace (in that order)."""
    flattened = text.replace("\n", " ")
    collapsed = re.sub(r"\s+", " ", flattened)
    return collapsed.strip('"').strip("'").strip()
# Load the CSV file
# NOTE(review): duration is parsed as float16 (~3 significant digits); it is
# later compared against 0.3/30 and written into the Arrow file — confirm this
# precision loss is acceptable.
df = pd.read_csv(csv_path, dtype={
    "audio_path": "string",
    "duration": "float16",
    "text": "string",
    "language": "string"
})
# Preprocess the text
tqdm.pandas()  # enables progress_apply (progress-bar variant of apply)
# First normalize whitespace/quotes, then expand each transcript into a
# per-character token list (language-aware suffixing in convert_char_to_pinyin).
df["text"] = df["text"].progress_apply(lambda x: clear_text(x))
df["text"] = df.progress_apply(lambda x: convert_char_to_pinyin([x['text']], [x['language']])[0], axis=1)
print("Text preprocessing done!")
def convert_metadata(input_csv, output_csv):
    """Convert the source metadata CSV into a pipe-delimited file with one
    ``audio_path|text`` row per sample.

    Args:
        input_csv: path to a CSV with at least ``audio_path`` and ``text``
            columns; missing columns fall back to empty strings.
        output_csv: destination path, written with '|' as the delimiter
            (csv.writer quotes any field that itself contains '|').
    """
    with open(input_csv, 'r', newline='', encoding='utf-8') as infile, \
         open(output_csv, 'w', newline='', encoding='utf-8') as outfile:
        reader = csv.DictReader(infile)
        writer = csv.writer(outfile, delimiter='|')
        for row in tqdm(reader):
            audio_path = row.get('audio_path', '').strip()
            # Embedded newlines would break the one-row-per-sample format.
            text = row.get('text', '').replace('\n', ' ').strip()
            writer.writerow([audio_path, text])
    print(f"Metadata file created: {output_csv}")
# Write to Arrow file with a progress bar
with ArrowWriter(path=arrow_file_path, writer_batch_size=45) as writer:
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Writing to Arrow"):
        # Keep only samples with non-empty text, a duration in [0.3, 30] seconds,
        # and an audio file that actually exists on disk.
        if ''.join(row['text']) and row['duration'] >= 0.3 and row['duration'] <= 30 and os.path.exists(row['audio_path']):
            writer.write({
                "audio_path": row['audio_path'],
                "text": row['text'],
                "duration": row['duration'],
            })
        else:
            # NOTE(review): this also fires for rows merely filtered out by the
            # duration/text checks — the "Error" label is misleading for those.
            print(f"Error: {row['audio_path']}")
# Re-open the freshly written Arrow file and dump every sample's duration to
# JSON (presumably consumed downstream, e.g. for bucketing — verify).
ds = datasets.Dataset.from_file(arrow_file_path)
duration_list = []
for i in tqdm(ds):
    duration_list.append(i['duration'])
with open(duration_path, "w") as f:
    json.dump({"duration": duration_list}, f, ensure_ascii=False)
convert_metadata(csv_path, metadata_output)
print(f"Arrow file successfully written to {arrow_file_path}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment