Merge-reads several sorted .csv files stored on Google Cloud Storage.
import csv
import logging

# The App Engine GCS client library (appengine-gcs-client), which provides
# the gcs.open(), gcs.listbucket(), and gcs.RetryParams calls used below.
import cloudstorage as gcs


class SortedGcsCsvShardFileMergeReader(object):
    """Merges several sorted .csv files stored on GCS.

    This class is both an iterator and a context manager.

    Let's say there are 2 .csv files stored on GCS, with contents like:

    /bucket/file_1.csv:
        [0, "Matt"],
        [0, "Sam"],
        [2, "Dude"],

    /bucket/file_2.csv:
        [0, "Bakery"],
        [2, "Francisco"],
        [3, "Matt"],

    These files are already sorted by our key_columns = [0]. We want to read
    all of the rows with the same key_column value at once, regardless of
    which file those rows reside in. This class does that by opening handles
    to all of the files and picking off the top rows from each of the files
    as long as they share the same key.

    For example:

        merge_reader = SortedGcsCsvShardFileMergeReader(
            "/bucket", "file_", [0])
        with merge_reader:
            for row in merge_reader:
                # Returns rows in totally-sorted order, like:
                # [0, "Matt"]       (from file_1.csv)
                # [0, "Sam"]        (from file_1.csv)
                # [0, "Bakery"]     (from file_2.csv)
                # [2, "Dude"]       (from file_1.csv)
                # [2, "Francisco"]  (from file_2.csv)
                # [3, "Matt"]       (from file_2.csv)

    The merge columns must be comparable, so that this class can return the
    results in totally-sorted order.

    NOTE: All shards must have at least one row.

    To do this, we build up a somewhat complex instance object to keep track
    of the shards and their current statuses. self.shard_files has this
    format:

        {
            "shard_file_path_1": {
                "gcs_file": handle to the GCS file stream,
                "csv_reader": csv reader reading from gcs_file,
                "head": the row most recently read from the csv_reader,
                "head_key": the key tuple of the head,
                "rows_returned": a running count of the rows returned,
            },
            ...
            "shard_file_path_n": {},
        }

    The self.shard_files dict is pruned as the shards are exhausted.
    """

    def __init__(self, input_bucket, input_pattern, merge_columns):
        """Constructor.

        Arguments:
            input_bucket - The bucket to read from.
            input_pattern - The filename prefix to match within the bucket
                (passed to get_shard_files()).
            merge_columns - The indices of the columns used for determining
                row merging.
        """
        shard_paths = get_shard_files(
            input_bucket, input_pattern, full_path=True)
        if len(shard_paths) == 0:
            raise ValueError("Could not find any shard files.")

        logging.info("Merge-reading: %s", shard_paths)

        self.shard_files = {
            shard_path: {} for shard_path in shard_paths
        }
        self.merge_columns = merge_columns
        self.current_key = None
        self.current_shard_path = None

    def __enter__(self):
        """Open handles to all of the shard files and read the first row."""
        retry_params = gcs.RetryParams(urlfetch_timeout=60,
                                       max_retry_period=60 * 60.0)

        for shard_path in self.shard_files.keys():
            gcs_file = gcs.open(shard_path, "r", retry_params=retry_params)
            csv_reader = csv.reader(gcs_file)
            head = csv_reader.next()  # Assumes there is at least 1 row

            self.shard_files[shard_path] = {
                "gcs_file": gcs_file,
                "csv_reader": csv_reader,
                "head": head,
                "head_key": self._get_key(head),
                "rows_returned": 0,
            }

    def __iter__(self):
        return self

    def _get_key(self, row):
        """Returns the tuple of this row's values in the merge columns."""
        return tuple(v for i, v in enumerate(row) if i in self.merge_columns)

    def _advance_shard(self, shard_path):
        """Update the shard's head values and return the current head."""
        # Save the head, to return later
        metadata = self.shard_files[shard_path]
        row = metadata["head"]
        metadata["rows_returned"] += 1  # Count the row we are about to return

        try:
            new_head = metadata["csv_reader"].next()
            metadata["head"] = new_head
            metadata["head_key"] = self._get_key(new_head)
        except StopIteration:
            # This shard is exhausted; close it and force next() to pick a
            # new current shard.
            self._close_shard(shard_path)
            self.current_shard_path = None

        return row

    def _find_next_key(self):
        """Find the next key to start merge-reading.

        We must always choose the lowest remaining key value to be the next
        key to read. Not all shards have all keys, so we must do this to
        ensure that we do not mis-order the rows in the final output.
        """
        lowest_key_value = None
        lowest_shard_path = None

        for path, metadata in self.shard_files.iteritems():
            if (lowest_key_value is None
                    or metadata["head_key"] < lowest_key_value):
                lowest_key_value = metadata["head_key"]
                lowest_shard_path = path

        return lowest_key_value, lowest_shard_path

    def next(self):
        # We've exhausted all rows from all shards
        if len(self.shard_files) == 0:
            raise StopIteration

        # This happens at the very beginning, or after exhausting a shard
        if self.current_shard_path is None:
            self.current_shard_path = self.shard_files.keys()[0]

        # This happens at the very beginning, or after exhausting a key
        if self.current_key is None:
            self.current_key, self.current_shard_path = self._find_next_key()

        # If the current shard has more rows for the current key, just
        # return from it
        if (self.shard_files[self.current_shard_path]["head_key"]
                == self.current_key):
            return self._advance_shard(self.current_shard_path)

        # Otherwise, look for any other shard whose head matches current_key
        for path, metadata in self.shard_files.iteritems():
            if metadata["head_key"] == self.current_key:
                self.current_shard_path = path
                return self._advance_shard(path)

        # We didn't find any rows for current_key, so start over with the
        # next-lowest key
        self.current_key = None
        return self.next()

    def _close_shard(self, shard_path):
        """Close the shard and remove it from the shard_files collection."""
        if shard_path not in self.shard_files:
            return

        metadata = self.shard_files[shard_path]
        logging.info(
            "Closing shard after reading %d rows. %d shards remain. %s",
            metadata["rows_returned"], len(self.shard_files) - 1, shard_path)

        try:
            metadata["gcs_file"].close()
        except Exception:
            logging.exception("Ignoring exception from %s", shard_path)

        del self.shard_files[shard_path]

    def __exit__(self, exception_type, exception_value, exception_traceback):
        """Closes all shards."""
        shard_paths = self.shard_files.keys()
        for path in shard_paths:
            self._close_shard(path)
        # TODO(mattfaus): Re-raise any exception passed here?


def get_shard_files(bucket, filename_prefix, full_path=False):
    """Find files in a bucket, matching a filename prefix."""
    if not bucket.startswith("/"):
        bucket = "/%s" % bucket
    if bucket.endswith("/"):
        bucket = bucket[:-1]

    retry_params = gcs.RetryParams(urlfetch_timeout=60,
                                   max_retry_period=60 * 30.0)
    filename_prefix = bucket + "/" + filename_prefix

    shard_files = []
    for file_stat in gcs.listbucket(filename_prefix,
                                    retry_params=retry_params):
        path = file_stat.filename
        if not full_path:
            # Remove the "/<bucket>" + "/" prefix
            path = path[len(bucket) + 1:]
        shard_files.append(path)

    # Sort for deterministic ordering
    shard_files.sort()
    return shard_files
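
A minimal end-to-end sketch of how the reader might be used (not part of the
original gist; the bucket name, file names, and rows below are hypothetical,
and the App Engine cloudstorage client is assumed, as above):

import csv

import cloudstorage as gcs


def write_example_shards():
    # Write two already-sorted shard files under a hypothetical bucket.
    shard_rows = {
        "/my-bucket/events_shard_0": [[0, "Matt"], [2, "Dude"]],
        "/my-bucket/events_shard_1": [[0, "Bakery"], [3, "Matt"]],
    }
    for path, rows in shard_rows.items():
        with gcs.open(path, "w") as gcs_file:
            writer = csv.writer(gcs_file)
            for row in rows:
                writer.writerow(row)


def read_example_shards():
    # Merge-read every row across both shards, grouped by the first column.
    merge_reader = SortedGcsCsvShardFileMergeReader(
        "/my-bucket", "events_shard_", [0])
    with merge_reader:
        for row in merge_reader:
            # csv.reader yields strings, so rows come back as, e.g.,
            # ['0', 'Matt'], ['0', 'Bakery'], ['2', 'Dude'], ['3', 'Matt']
            print row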


# ---------------------------------------------------------------------------
# Tests (a separate file in the gist). The local import paths below are
# assumptions based on the module names referenced in the test code.
# ---------------------------------------------------------------------------
import csv

import cloudstorage as gcs

import gae_model  # Test base class providing the App Engine testbed
import pipelines_util  # Module containing SortedGcsCsvShardFileMergeReader


class GcsTest(gae_model.GAEModelTestCase):

    def setUp(self):
        super(GcsTest, self).setUp()
        # Initialize stubs necessary for testing the GCS client
        self.testbed.init_app_identity_stub()
        self.testbed.init_urlfetch_stub()
        self.testbed.init_blobstore_stub()


class MergeReaderTest(GcsTest):
    """Verifies SortedGcsCsvShardFileMergeReader functionality."""

    data = [
        [0, "Matt"],
        [0, "Ben"],
        [1, "Sam"],
        [2, "Bam"],
        [2, "Bam"],
        [3, "Matt"],
    ]

    bucket = "/bucket"
    path_pattern = "/bucket/csv_shard_%d"
    filename_prefix = "csv_shard_"

    def _write_shards(self, num_shards, num_rows=None):
        """Writes test shards to GCS.

        Arguments:
            num_shards - The number of shards to create.
            num_rows - An array of the number of rows that each shard should
                contain. If None, all shards will contain all rows. Max
                value is len(self.data). You can use this to make some
                shards have less data than others.
        """
        if num_rows is None:
            num_rows = [len(self.data)] * num_shards

        for shard in range(num_shards):
            path = self.path_pattern % shard
            with gcs.open(path, 'w') as gcs_file:
                csv_writer = csv.writer(gcs_file)
                for row in self.data[:num_rows[shard]]:
                    csv_writer.writerow(row)

    def _get_key(self, row, key_columns):
        return tuple(v for i, v in enumerate(row) if i in key_columns)

    def _verify_merge_reader(self, merge_reader, key_columns):
        """Asserts that rows come out grouped by key (no key seen twice)."""
        seen_keys = set()
        prev_key = None

        with merge_reader:
            for row in merge_reader:
                cur_key = self._get_key(row, key_columns)
                if cur_key != prev_key:
                    self.assertNotIn(cur_key, seen_keys)
                    seen_keys.add(cur_key)
                prev_key = cur_key

    def test_same_length_shards(self):
        self._write_shards(3)

        key_columns = [0]
        merge_reader = pipelines_util.SortedGcsCsvShardFileMergeReader(
            self.bucket, self.filename_prefix, key_columns)

        self._verify_merge_reader(merge_reader, key_columns)

    def test_different_length_shards(self):
        self._write_shards(3, num_rows=[
            len(self.data),
            len(self.data) - 1,
            len(self.data) - 2])

        key_columns = [0]
        merge_reader = pipelines_util.SortedGcsCsvShardFileMergeReader(
            self.bucket, self.filename_prefix, key_columns)

        self._verify_merge_reader(merge_reader, key_columns)

    def test_complex_key(self):
        self._write_shards(3, num_rows=[
            len(self.data),
            len(self.data) - 1,
            len(self.data) - 2])

        key_columns = [0, 1]
        merge_reader = pipelines_util.SortedGcsCsvShardFileMergeReader(
            self.bucket, self.filename_prefix, key_columns)

        self._verify_merge_reader(merge_reader, key_columns)

    def test_skip_sorting(self):
        # Files being sorted internally does not mean that each file
        # contains all keys. The merge reader should be robust to this and
        # skip keys until it finds one that matches.
        shards = [
            [
                [5, 'hi'],
                [5, 'hi'],
                [6, 'hi'],
            ],
            [
                [0, 'bye'],
                [2, 'bye'],
            ],
            [
                [2, 'fie'],
                [3, 'fie'],
            ],
            [
                [6, 'fie'],
                [7, 'fie'],
            ],
            [
                [1, 'fie'],
                [7, 'fie'],
            ],
        ]

        for idx, shard in enumerate(shards):
            path = self.path_pattern % idx
            with gcs.open(path, 'w') as gcs_file:
                csv_writer = csv.writer(gcs_file)
                for row in shard:
                    csv_writer.writerow(row)

        key_columns = [0]
        merge_reader = pipelines_util.SortedGcsCsvShardFileMergeReader(
            self.bucket, self.filename_prefix, key_columns)

        self._verify_merge_reader(merge_reader, key_columns)
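
As a design note (not from the original gist): the same totally-sorted stream
could also be produced with heapq.merge. The class above instead tracks
per-shard state explicitly, which lets it close exhausted shards eagerly and
log per-shard row counts. A rough sketch of the heapq approach, under the same
assumptions (App Engine cloudstorage client, shards already sorted by the
merge columns), might look like:

import csv
import heapq

import cloudstorage as gcs


def merge_sorted_shards(shard_paths, merge_columns):
    """Yield rows from already-sorted csv shards in totally-sorted order."""
    readers = [csv.reader(gcs.open(path, "r")) for path in shard_paths]

    def keyed(reader):
        # Pair each row with its key tuple; heapq.merge has no key= argument
        # in Python 2, so we merge (key, row) pairs and drop the key after.
        for row in reader:
            key = tuple(v for i, v in enumerate(row) if i in merge_columns)
            yield key, row

    # Note: file handles are left open here; closing them is up to the caller.
    for _key, row in heapq.merge(*[keyed(r) for r in readers]):
        yield row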