Last active
January 19, 2024 02:44
Efficient Batching v2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This function splits the records passed in between a sampled batch and a remainder.
# Each batch is packed 100% full of data; once no further sample can be extracted, an
# empty sample is returned together with the remainder (which the caller folds into a
# new iteration).
# Function to find the combination of record lengths that adds up exactly to the target sum.
def find_combination_to_sum(counts, target):
    """Find a multiset of values from `counts` summing exactly to `target`.

    Args:
        counts: mapping {value: available_count}; values are record lengths.
        target: exact sum to reach (the block size).

    Returns:
        A list of values (with repetition bounded by the available counts)
        that sums to `target`, or None when no such combination exists.
    """
    if target < 0:
        # A negative target is unreachable; the DP table below would be malformed.
        return None
    if target == 0:
        # Reachable by choosing nothing (matches the DP base case).
        return []

    # Expand counts into an explicit candidate list. Each value appears at
    # most `count` times and never more than target // val times (extra
    # copies could not fit in the target anyway).
    values = []
    for val, count in counts.items():
        if val <= 0:
            # Zero/negative lengths cannot contribute to a positive target;
            # skipping also avoids ZeroDivisionError on target // val.
            continue
        values.extend([val] * min(count, target // val))

    # 0/1-knapsack subset-sum DP:
    # dp[i][j] is True iff some subset of the first i values sums to j.
    n = len(values)
    dp = [[False] * (target + 1) for _ in range(n + 1)]
    for i in range(n + 1):
        dp[i][0] = True  # sum 0 is always achievable by taking nothing
    for i in range(1, n + 1):
        for j in range(1, target + 1):
            dp[i][j] = dp[i - 1][j]
            if values[i - 1] <= j:
                dp[i][j] |= dp[i - 1][j - values[i - 1]]

    if not dp[n][target]:
        return None

    # Trace back through the table to recover one concrete combination.
    result = []
    i, j = n, target
    while i > 0 and j > 0:
        if dp[i][j] != dp[i - 1][j]:
            # Reaching sum j at row i required taking value i-1.
            result.append(values[i - 1])
            j -= values[i - 1]
        i -= 1
    return result
def sample_and_remove(combination, records):
    """Randomly draw one record per requested length and split them out.

    Args:
        combination: list of record lengths to sample (one record per entry),
            or None/empty when no valid combination exists.
        records: list of sequences to draw from.

    Returns:
        (sampled_records, remaining_records). When `combination` is falsy the
        input list is returned untouched as ([], records).
    """
    if not combination:
        # Nothing to sample; avoid the pointless grouping pass.
        return [], records

    # Bucket records by length for O(1) lookup per requested length.
    grouped = defaultdict(list)
    for record in records:
        grouped[len(record)].append(record)

    sampled = []
    for length in combination:
        bucket = grouped.get(length)  # .get avoids creating empty buckets
        if bucket:
            # Pop a random index so exactly the chosen record is removed;
            # remove-by-value would scan for (and drop) the first equal record.
            idx = random.randrange(len(bucket))
            sampled.append(bucket.pop(idx))

    # Flatten the remaining buckets back into a single list.
    remaining = [rec for bucket in grouped.values() for rec in bucket]
    return sampled, remaining
def create_batches_v2(records, block_size, num_batches):
    """Draw `num_batches` samples, each exactly filling `block_size` tokens.

    Args:
        records: list of tokenized records to draw from (not mutated).
        block_size: exact token total each sampled batch must reach.
        num_batches: number of batches to extract.

    Returns:
        (samples, remaining_records) on success; ([], records) — the caller's
        original, untouched list — as soon as any batch cannot be filled.
    """
    samples = []
    remaining = records.copy()  # work on a copy so `records` survives failure
    for _ in range(num_batches):
        sample, remaining = retrieve_sample(remaining, block_size, num_batches)
        if not sample:
            # No exact-sum combination left: signal failure with the
            # original record list intact.
            return [], records
        samples.append(sample)
    # The loop either returned early or appended exactly num_batches samples,
    # so the former post-loop `len(samples) < num_batches` check was dead code.
    return samples, remaining
def retrieve_sample(records, block_size, num_batches):
    """Extract one batch of records whose lengths sum exactly to block_size.

    Args:
        records: list of sequences to draw from.
        block_size: exact total length the sampled batch must reach.
        num_batches: unused; kept for interface compatibility with callers.

    Returns:
        (sample, remaining_records); sample is [] when no exact-sum
        combination of record lengths exists.
    """
    from collections import Counter  # stdlib counting; no DataFrame needed

    # Histogram of record lengths — equivalent to the former
    # pd.DataFrame(lens, columns=['lens']).groupby('lens').size().to_dict()
    # without constructing a pandas DataFrame.
    counts_dict = Counter(len(record) for record in records)
    combination = find_combination_to_sum(counts_dict, block_size)
    return sample_and_remove(combination, records)
# Append the EOS token to every tokenized record, keeping only records that
# still fit within the block size afterwards.
eos = tokenizer.eos_token_id
max_len = args.block_size
train_tokenized = [rec + [eos] for rec in train_tokenized if len(rec) < max_len]
val_tokenized = [rec + [eos] for rec in val_tokenized if len(rec) < max_len]
print(max_len)
sampled_train, remainder = create_batches_v2(train_tokenized, max_len, args.batch_size)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment