Deprecated: use https://gist.github.com/rom1504/0b846b2dc64c5e0604e1d532c09cbff6 with pyspark instead — it's better and faster.
Last active
July 20, 2021 19:19
-
-
Save rom1504/360aa910e9867776a3293cb18dde3fa3 to your computer and use it in GitHub Desktop.
cah_stats_dask
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
''' | |
Compute some stats on cah collection | |
First get the files with: | |
lynx -dump -hiddenlinks=listonly -nonumbers http://the-eye.eu/eleuther_staging/cah/ | grep cah | grep .csv > cah.csv | |
aria2c --dir=shards --auto-file-renaming=false --continue=true -i cah.csv -x 16 -s 16 -j 100 | |
Takes a few minutes to run | |
Then pip install pandas dask distributed pyarrow bokeh | |
Then run this file. It also takes a few minutes | |
''' | |
if __name__ == '__main__': | |
from dask.distributed import Client | |
import dask.dataframe as dd | |
# You can open http://localhost:8787/status to follow progress on the dask operations | |
client = Client() | |
import numpy as np | |
import io | |
import pandas as pd | |
schema = { 'SAMPLE_ID': int, | |
'PATH': str, | |
'URL': str, | |
'TEXT': str, | |
'HEIGHT': int, | |
'WIDTH': int, | |
'LICENSE': str, | |
'similarity': float, | |
'NSFW': str} | |
cols = ['SAMPLE_ID', 'PATH', 'URL', 'TEXT', 'HEIGHT', 'WIDTH', 'LICENSE', 'similarity', 'NSFW'] | |
def read_data(filename): | |
dtypes = schema | |
try: | |
df = pd.read_csv(filename, sep="|", dtype=dtypes, usecols=cols) | |
except: | |
try: | |
df = pd.read_csv(filename, sep="|", dtype=dtypes, usecols=['SAMPLE_ID', 'PATH', 'URL', 'TEXT', 'HEIGHT', 'WIDTH', 'LICENSE', 'similarity']) | |
except: | |
df = pd.read_csv(filename, sep="|", dtype=dtypes, usecols=['SAMPLE_ID', 'PATH', 'URL', 'TEXT', 'HEIGHT', 'WIDTH', 'LICENSE']) | |
if 'NSFS' not in df.columns: | |
df['NSFW'] = "unknown" | |
if 'similarity' not in df.columns: | |
df['similarity'] = 0.0 | |
df = df[cols] | |
return df | |
read_data = delayed(read_data) | |
from glob import glob | |
files = glob('shards/*.csv') | |
df = [read_data(file) for file in files] | |
df = dd.from_delayed(df, meta=schema) | |
df.to_parquet('parquet_shards_fulll', engine='pyarrow', schema="infer") | |
cah = dd.read_parquet("parquet_shards_fulll", engine='pyarrow') | |
# start computing stats | |
a[a.similarity > 0.4].shape[0].compute() | |
print("Size of collection", a.shape[0].compute()) | |
print("Number of uniques", a.drop_duplicates().shape[0].compute()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
if __name__ == '__main__': | |
from dask.distributed import Client | |
import dask.dataframe as dd | |
import dask.delayed as delayed | |
# You can open http://localhost:8787/status to follow progress on the dask operations | |
client = Client() | |
import numpy as np | |
import io | |
import pandas as pd | |
schema = { 'SAMPLE_ID': int, | |
'URL': str, | |
'TEXT': str, | |
} | |
cols = ['SAMPLE_ID', 'URL', 'TEXT'] | |
def read_data(filename): | |
try: | |
df = pd.read_csv(filename, sep="|", usecols=cols, dtype=schema) | |
except: | |
try: | |
df = pd.read_csv(filename, sep="|", usecols=cols, error_bad_lines=False, dtype=schema) | |
except: | |
df = pd.DataFrame(columns=cols) | |
df = df.astype(schema) | |
return df | |
read_data = delayed(read_data) | |
from glob import glob | |
files = glob('/media/hd/cah/drive/*.csv') + glob('/media/hd/cah/theeye/output/cah/**/*.csv') | |
df = [read_data(file) for file in files] | |
df = dd.from_delayed(df, meta=schema) | |
df = df.repartition(100) | |
df.to_parquet('the_good', engine='pyarrow', schema="infer") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment