Last active
May 21, 2023 17:51
-
-
Save adivekar-utexas/b9db9b46c0649ae310889ae3e4e914d9 to your computer and use it in GitHub Desktop.
Download and split REALNEWS into multiple small parquet files
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
REALNEWS is a big dataset of several million news articles obtained from Common Crawl. | |
It was used to train the Grover news generation language model. | |
Details here: https://arxiv.org/abs/1905.12616 | |
In this script, we download it following instructions from https://github.com/rowanz/grover/tree/master/realnews | |
(please make sure to fill in the survey in the link above!) | |
After downloading, the file is a .tar.gz containing an enormous .jsonl file. | |
To split it into multiple small .parquet files, I've written the script below. | |
You should get | |
""" | |
## Required dependencies (pyarrow is needed by pandas' to_parquet):
## pip install tqdm orjsonl pandas numpy pyarrow gsutil
## Stdlib imports:
from typing import *
import functools
import gc
import math
import multiprocessing as mp
import random
import sys
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, wait as wait_future
from concurrent.futures._base import Future
from concurrent.futures.process import BrokenProcessPool
from concurrent.futures.thread import BrokenThreadPool
from datetime import datetime
from math import inf
from threading import Semaphore

## Third-party imports:
import numpy as np
import orjsonl as oj
import pandas as pd
from tqdm import tqdm
def concurrent(max_active_threads: int = 10, max_calls_per_second: float = inf):
    """
    Decorator which runs function calls concurrently via multithreading.

    When decorating an IO-bound function with @concurrent(MAX_THREADS), and then invoking the function
    N times in a loop, it will run min(MAX_THREADS, N) invocations of the function concurrently.
    For example, if your function calls another service, decorating with @concurrent(3) ensures that you
    only have 3 concurrent function-calls (and thus 3 concurrent requests) at a time. As this uses
    multi-threading and not multi-processing, it is suitable for IO-heavy functions, not CPU-heavy ones.

    Each call to the decorated function returns a Future. Calling .result() on that future will return
    the value (synchronously), so invoking the function N times in a loop and then resolving the futures
    preserves the order of items. If a function call throws, the exception is re-raised by .result() in
    the orchestrating code; add try-except logic inside the decorated function if you want to handle it
    there instead.

    A decorated function `a` can call another decorated function `b` without issues; it is up to `a`
    to decide whether to call .result() on the futures it gets from `b`, or return them to its own invoker.

    `max_calls_per_second` additionally rate-limits how often the function may be *submitted*. This
    matters for fast functions: with concurrency 5 and a 100ms call time, you would otherwise hit a
    downstream service 50 times per second.

    :param max_active_threads: the max number of threads which can be running the function at one time.
        This is thus the max concurrency factor.
    :param max_calls_per_second: controls the rate at which we can call the function.
    :return: N/A, this is a decorator.
    """
    ## Refs:
    ## 1. ThreadPoolExecutor: docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Executor.submit
    ## 2. Decorators: www.datacamp.com/community/tutorials/decorators-python
    ## 3. Semaphores: www.geeksforgeeks.org/synchronization-by-using-semaphore-in-python/
    ## 4. Overall code: https://gist.github.com/gregburek/1441055#gistcomment-1294264
    def decorator(function):
        ## Each decorated function gets its own executor and semaphore, defined at the function level:
        ## two decorated functions `say_hi` and `say_bye` each get a separate executor and semaphore,
        ## and all invocations of one function share that function's pair. The semaphore enforces that
        ## at most `max_active_threads` invocations run concurrently.
        executor = ThreadPoolExecutor(max_workers=max_active_threads)
        semaphore = Semaphore(max_active_threads)
        ## The minimum time between invocations (0.0 when max_calls_per_second is inf).
        min_time_interval_between_calls = 1 / max_calls_per_second
        ## Single mutable cell: a list is used so the nested wrapper can rebind the value via closure.
        time_last_called = [0.0]

        @functools.wraps(function)  ## Fix: preserve the wrapped function's __name__/__doc__/signature.
        def wrapper(*args, **kwargs) -> Future:
            semaphore.acquire()
            time_elapsed_since_last_called = time.time() - time_last_called[0]
            time_to_wait_before_next_call = max(0.0, min_time_interval_between_calls - time_elapsed_since_last_called)
            time.sleep(time_to_wait_before_next_call)

            def run_function(*args, **kwargs):
                try:
                    result = function(*args, **kwargs)
                finally:
                    semaphore.release()  ## If the function call throws an exception, still release the semaphore.
                return result

            time_last_called[0] = time.time()
            return executor.submit(run_function, *args, **kwargs)  ## Return a Future.
        return wrapper
    return decorator
def get_num_zeros_to_pad(max_i: int) -> int:
    """
    Return the number of digits needed to represent every integer from 0 to `max_i` inclusive.
    E.g. max_i=9 -> 1, max_i=10 -> 2, max_i=1_000_000 -> 7.

    :param max_i: the largest value that will be zero-padded; must be an int >= 1.
    :return: the digit-width to pad to.
    """
    assert isinstance(max_i, int) and max_i >= 1
    ## len(str(...)) is exact for arbitrarily large ints, unlike the previous
    ## math.ceil(math.log10(...)) approach, which relies on floating-point and needed a
    ## special-case correction for exact powers of 10.
    return len(str(max_i))
def pad_zeros(i: int, max_i: Optional[int] = None) -> str:
    """
    Left-pad integer `i` with zeros so it is as wide as `max_i` requires.
    When `max_i` is None, no padding is applied.

    :param i: the non-negative integer to pad.
    :param max_i: the largest value `i` can take; determines the padded width. Must be >= i.
    :return: `i` rendered as a zero-padded string.
    """
    assert isinstance(i, int) and i >= 0
    if max_i is None:
        return str(i)
    assert isinstance(max_i, int) and max_i >= i
    width: int = get_num_zeros_to_pad(max_i)
    return str(i).zfill(width)
print(""" | |
___ _ _ _ _ _ | |
| \ _____ __ ___ _ | |___ __ _ __| | __ _ _ _ __| | _ _ _ _ __(_)_ __ _ _ ___ __ _| |_ _ _____ __ _____ | |
| |) / _ \ V V / ' \| / _ \/ _` / _` | / _` | ' \/ _` | | || | ' \|_ / | '_ \ | '_/ -_) _` | | ' \/ -_) V V (_-< | |
|___/\___/\_/\_/|_||_|_\___/\__,_\__,_| \__,_|_||_\__,_| \_,_|_||_/__|_| .__/ |_| \___\__,_|_|_||_\___|\_/\_//__/ | |
|_| | |
""") | |
## Download and unzip realnews (IPython/Jupyter "!" shell commands — run in a notebook):
## Copy the ~46GB archive from the public Grover GCS bucket to local disk.
! gsutil cp gs://grover-models/realnews.tar.gz ~/data/realnews/
## Decompress with pigz (parallel gzip), pipe through pv for a progress bar, and untar.
! pigz -dc realnews.tar.gz | pv | tar xf -
## Create the output directory that will hold the split parquet files.
! mkdir -p ~/data/realnews/realnews/realnews-splits/
def get_fpath(file_num: int) -> str:
    """
    Return the output path for split-file number `file_num` (0-indexed).
    The part number in the filename is 1-indexed and zero-padded wide enough for up to 1,000,000 parts.
    """
    ## PEP 8: a named function is preferred over assigning a lambda to a name.
    return f'~/data/realnews/realnews/realnews-splits/realnews-part-{pad_zeros(file_num + 1, 1_000_000)}.parquet'
@concurrent(max_active_threads=64)  ## Up to 64 threads write to disk/remote storage in parallel.
def save_df(buf, fpath):
    """Write the buffered rows `buf` to `fpath` as a parquet file (runs on a worker thread)."""
    df = pd.DataFrame(buf)
    df.to_parquet(fpath)
    print(f'Saved file to: {fpath}')
realnews_path = '~/data/realnews/realnews/realnews.jsonl'
nrows = int(1e5)  ## Number of rows per output file.
row_buf = []  ## Rows accumulated for the current output file.
file_num = 0
futures = []  ## Futures returned by @concurrent save_df calls.
## Stream the jsonl file row-by-row (orjsonl does not load the whole file into memory).
for row in tqdm(oj.stream(realnews_path), smoothing=0.1, ncols=100, unit='rows'):
    row_buf.append(row)
    if len(row_buf) == nrows:  ## Buffer is full: hand it off for a concurrent write.
        futures.append(save_df(row_buf, get_fpath(file_num)))
        row_buf = []  ## Rebind (don't mutate) so the writer thread keeps the old list.
        file_num += 1
if row_buf:  ## Fix: only save the final partial file if it is non-empty
    ## (previously an empty parquet file was written when the row count was an exact multiple of nrows).
    futures.append(save_df(row_buf, get_fpath(file_num)))
## Fix: block until all writes finish; .result() re-raises any exception from a writer thread
## (previously the futures were discarded, so write failures passed silently).
for future in futures:
    future.result()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment