Skip to content

Instantly share code, notes, and snippets.

@mutukrish
Last active October 16, 2024 05:11
Show Gist options
  • Save mutukrish/63099d34969393cfbbbf246b3e284f79 to your computer and use it in GitHub Desktop.
Save mutukrish/63099d34969393cfbbbf246b3e284f79 to your computer and use it in GitHub Desktop.
Index recommendation using pure Aggregation pipeline stage

How It Works:

  • Aggregation Stages: The system looks through each stage of an aggregation pipeline ($match, $sort, $group, $lookup) and generates appropriate indexes.
  • Index Key Creation: The indexes are created based on the fields used in each stage.
  • Logical and Comparison Operators: It handles $and, $or, $eq, $gt, and similar logical and comparison operators to determine which fields should be indexed.
  • Alias Handling: The code applies any field aliases created by $project or $addFields stages and adjusts the index accordingly.
class IndexRecommender:
    """Suggest MongoDB indexes by statically inspecting aggregation pipelines.

    Each aggregation is a dict with the keys ``name``, ``collection`` and
    ``pipeline`` (a list of single-key stage dicts such as
    ``{"$match": {...}}``).  Only ``$match``, ``$sort``, ``$group`` and
    ``$lookup`` stages produce recommendations; every recommendation is a
    dict with the keys ``name``, ``key``, ``collection``, ``operator`` and
    ``order``.
    """

    LOGIC_OPS = ["$or", "$and", "$nor", "$not"]
    COMPARISON_OPS = ['$cmp', '$eq', '$gt', '$gte', '$lt', '$lte', '$ne']

    def generate(self, aggregations):
        """Return ``{aggregation name: [index recommendation, ...]}``."""
        return {
            aggregation['name']: self.generate_indexes(aggregation)
            for aggregation in aggregations
        }

    def generate_indexes(self, aggregation):
        """Return the recommendations for every stage of one aggregation.

        Stages that yield no recommendation are skipped; the remaining ones
        keep pipeline order.
        """
        pipeline = aggregation['pipeline']
        operator_seq = [self.get_aggregation_stage_operator(stage) for stage in pipeline]
        base_collection = aggregation['collection']
        indexes = []
        for step_num, (operator, stage) in enumerate(zip(operator_seq, pipeline)):
            name = f"{aggregation['name']}_{base_collection}_{operator.replace('$', '')}_{step_num}"
            index = self.get_stage_index(operator, stage, step_num, operator_seq, aggregation, name)
            if index:
                indexes.append(index)
        return indexes

    def get_stage_index(self, operator, stage, order, sequence, aggregation, name):
        """Dispatch to the handler for ``operator``; ``None`` if unsupported."""
        handlers = {
            '$match': self.get_match_index,
            '$sort': self.get_sort_index,
            '$group': self.get_group_index,
            '$lookup': self.get_lookup_index,
        }
        handler = handlers.get(operator)
        if handler is None:
            return None
        return handler(operator, stage, order, sequence, aggregation, name)

    def get_match_index(self, operator, stage, order, sequence, aggregation, name):
        """Build an index recommendation for a ``$match`` stage.

        Returns ``None`` when any earlier stage is outside the allowed list
        below, or when no indexable field is found.
        """
        valid_stages = ("$project", "$match", "$unset", "$set", "$sort", "$addFields")
        if any(previous not in valid_stages for previous in sequence[:order]):
            return None

        index_key = {}
        for key, value in stage[operator].items():
            if not key.startswith('$'):
                # Plain field condition; fields compared as arrays are skipped.
                if not self.is_field_array(value):
                    index_key[key] = 1
            elif key in self.LOGIC_OPS:
                index_key.update(self.process_logical_operator(value))
            elif key == '$expr':
                index_key.update(self.process_expr_operator(value))

        added_fields = self.get_added_fields(aggregation['pipeline'][:order])
        index_key = self.apply_aliases(added_fields, index_key)
        if index_key:
            return {"name": name, "key": index_key, "collection": aggregation['collection'], "operator": "$match", "order": order}
        return None

    def get_sort_index(self, operator, stage, order, sequence, aggregation, name):
        """Build an index recommendation for a ``$sort`` stage.

        Returns ``None`` when a ``$project``, ``$unwind`` or ``$group`` stage
        precedes the sort, or when no field with direction ``1``/``-1``
        remains after alias resolution.
        """
        invalid_stages = ("$project", "$unwind", "$group")
        if any(previous in invalid_stages for previous in sequence[:order]):
            return None

        index_key = {
            key: value
            for key, value in stage[operator].items()
            if not key.startswith('$') and value in (1, -1)
        }
        added_fields = self.get_added_fields(aggregation['pipeline'][:order])
        index_key = self.apply_aliases(added_fields, index_key)
        if index_key:
            return {"name": name, "key": index_key, "collection": aggregation['collection'], "operator": "$sort", "order": order}
        return None

    def get_group_index(self, operator, stage, order, sequence, aggregation, name):
        """Build an index recommendation for a ``$group`` stage.

        A recommendation is produced only when:

        * no ``$group``, ``$project`` or ``$unwind`` stage precedes it;
        * every accumulator is ``$first`` applied to a plain field path;
        * ``_id`` is a field path, or a document whose values are all field
          paths;
        * without a preceding ``$sort`` there is exactly one ``_id`` field;
          with one, the ``_id`` fields match the leading fields of the most
          recent ``$sort`` (whose directions are reused).

        Returns ``None`` in every other case.
        """
        preceding = sequence[:order]
        if any(previous in ("$group", "$project", "$unwind") for previous in preceding):
            return None

        group_spec = stage[operator]

        first_op_fields = {}
        for key, element in group_spec.items():
            if key == "_id":
                continue
            if not isinstance(element, dict):
                return None
            for group_op, field in element.items():
                if group_op != "$first":
                    return None
                field_name = self._field_name(field)
                if field_name is None:
                    # $first over an expression/constant cannot be indexed.
                    return None
                first_op_fields[field_name] = 1

        # ``_id`` may be a single "$field" string or a document of them.
        group_id = group_spec.get("_id")
        raw_ids = list(group_id.values()) if isinstance(group_id, dict) else [group_id]
        id_fields = [self._field_name(value) for value in raw_ids]
        if not id_fields or None in id_fields:
            return None

        sort_positions = [pos for pos, previous in enumerate(preceding) if previous == "$sort"]
        if not sort_positions:
            if len(id_fields) >= 2:
                return None
            index_key = {id_fields[0]: 1}
        else:
            # Use the last $sort before this stage.
            previous_sort = aggregation['pipeline'][sort_positions[-1]]["$sort"]
            sort_fields = list(previous_sort)
            index_key = {}
            for position, id_field in enumerate(id_fields):
                if position < len(sort_fields):
                    if id_field != sort_fields[position]:
                        return None
                    index_key[id_field] = previous_sort[id_field]
                else:
                    index_key[id_field] = 1
        index_key.update(first_op_fields)

        added_fields = self.get_added_fields(aggregation['pipeline'][:order])
        index_key = self.apply_aliases(added_fields, index_key)
        if index_key:
            return {"name": name, "key": index_key, "collection": aggregation['collection'], "operator": "$group", "order": order}
        return None

    def get_lookup_index(self, operator, stage, order, sequence, aggregation, name):
        """Route a ``$lookup`` stage to the pipeline or the basic handler.

        A missing or empty ``pipeline`` selects the basic handler.
        """
        if stage[operator].get("pipeline"):
            return self.get_pipeline_lookup_index(operator, stage, order, sequence, aggregation, name)
        return self.get_basic_lookup_index(operator, stage, order, sequence, aggregation, name)

    def get_basic_lookup_index(self, operator, stage, order, sequence, aggregation, name):
        """Recommend an index on ``foreignField`` of the ``from`` collection.

        Returns ``None`` when ``foreignField`` or ``from`` is absent, or when
        ``foreignField`` is ``_id``.
        """
        lookup_spec = stage[operator]
        foreign_field = lookup_spec.get("foreignField")
        foreign_collection = lookup_spec.get("from")
        if not foreign_field or foreign_field == "_id" or not foreign_collection:
            return None
        return {"name": name, "key": {foreign_field: 1}, "collection": foreign_collection, "operator": "$lookup", "order": order}

    def get_pipeline_lookup_index(self, operator, stage, order, sequence, aggregation, name):
        """Recommend an index from the ``$match`` stages of a lookup sub-pipeline.

        The keys of all eligible sub-pipeline ``$match`` stages are merged
        into one index on the ``from`` collection.  Returns ``None`` when
        ``from`` is absent or no key is found.
        """
        lookup_spec = stage[operator]
        foreign_collection = lookup_spec.get("from")
        if not foreign_collection:
            return None
        pipeline = lookup_spec["pipeline"]
        sub_sequence = [self.get_aggregation_stage_operator(sub_stage) for sub_stage in pipeline]
        # Evaluate sub-stages against the sub-pipeline itself so that stage
        # offsets and field aliases refer to the foreign collection's pipeline
        # rather than to the outer aggregation.
        sub_aggregation = {"name": name, "collection": foreign_collection, "pipeline": pipeline}

        index_key = {}
        for sub_order, (sub_operator, sub_stage) in enumerate(zip(sub_sequence, pipeline)):
            if sub_operator != "$match":
                continue
            sub_index = self.get_match_index(sub_operator, sub_stage, sub_order, sub_sequence, sub_aggregation, name)
            if sub_index:
                index_key.update(sub_index["key"])
        if index_key:
            return {"name": name, "key": index_key, "collection": foreign_collection, "operator": "$lookup", "order": order}
        return None

    def get_aggregation_stage_operator(self, stage):
        """Return the first key of a stage dict (raises ``IndexError`` if empty)."""
        return list(stage)[0]

    def is_field_array(self, comparison):
        """Return ``True`` when ``comparison`` treats the field as an array.

        That is the case for a list value, a ``$size`` condition, or an
        ``$eq``/``$ne`` whose operand itself satisfies this check.
        """
        if isinstance(comparison, list):
            return True
        if isinstance(comparison, dict):
            for key, value in comparison.items():
                if key == '$size':
                    return True
                if key in ('$eq', '$ne') and self.is_field_array(value):
                    return True
        return False

    def process_logical_operator(self, params):
        """Collect field names referenced by the operands of a logical operator.

        ``params`` is normally a list of condition dicts; a single dict is
        treated as a one-element list and non-dict operands are ignored.
        Returns ``{field: 1, ...}``.
        """
        if isinstance(params, dict):
            params = [params]
        elif not isinstance(params, (list, tuple)):
            return {}

        logic_fields = {}
        for logic_param in params:
            if not isinstance(logic_param, dict):
                continue
            for param, value in logic_param.items():
                if not param.startswith('$'):
                    logic_fields[param] = 1
                elif param in self.LOGIC_OPS:
                    logic_fields.update(self.process_logical_operator(value))
                elif param in self.COMPARISON_OPS:
                    logic_fields.update(self.process_comp_operator(value))
                elif param == '$expr':
                    logic_fields.update(self.process_expr_operator(value))
        return logic_fields

    def process_comp_operator(self, comparison_fields):
        """Return ``{field: 1}`` for the first field path among the operands.

        ``$$variable`` references are not fields and are skipped.  Returns an
        empty dict when the operands are not a list or contain no field path.
        """
        if not isinstance(comparison_fields, (list, tuple)):
            return {}
        for comp_field in comparison_fields:
            field_name = self._field_name(comp_field)
            if field_name is not None:
                return {field_name: 1}
        return {}

    def process_expr_operator(self, element_expr):
        """Collect field names from the first operator of an ``$expr`` document.

        Only comparison and logical operators are inspected; anything else
        (including a non-dict or empty expression) yields an empty dict.
        """
        if not isinstance(element_expr, dict) or not element_expr:
            return {}
        expr_op = next(iter(element_expr))
        operand = element_expr[expr_op]
        if expr_op in self.COMPARISON_OPS:
            return self.process_comp_operator(operand)
        if expr_op in self.LOGIC_OPS:
            return self.process_logical_operator(operand)
        return {}

    def get_added_fields(self, stages):
        """Map fields introduced by earlier stages to their source field.

        Inspects ``$addFields``, ``$set`` and ``$project`` stages.  The
        result maps each new name to the field it aliases, or to ``None``
        when its value is computed.  ``$project`` entries equal to ``0`` or
        ``1`` are plain inclusions/exclusions and are ignored.
        """
        added_fields = {}
        for stage in stages:
            operator = self.get_aggregation_stage_operator(stage)
            if operator in ('$addFields', '$set'):
                for field, value in stage[operator].items():
                    self.parse_alias(field, value, added_fields)
            elif operator == '$project':
                for field, value in stage[operator].items():
                    if value not in [0, 1]:
                        self.parse_alias(field, value, added_fields)
        return added_fields

    def parse_alias(self, alias, reference, previous):
        """Record in ``previous`` what ``alias`` refers to.

        A ``"$field"`` reference is stored as ``field`` (resolved one step
        further when ``field`` is itself a known alias); any other value,
        including ``$$variable`` references, is stored as ``None``.
        """
        field = self._field_name(reference)
        if field is None:
            previous[alias] = None
        elif field in previous:
            previous[alias] = previous[field]
        else:
            previous[alias] = field

    def apply_aliases(self, added_fields, index_key):
        """Rewrite index keys so they name source fields instead of aliases.

        A key matches an alias when it equals it or is a dotted sub-path of
        it.  Matching keys are renamed to the aliased field, or dropped when
        the alias is computed.  Each key is rewritten at most once, using the
        longest matching alias, and key order is preserved.
        """
        resolved = {}
        for key, direction in index_key.items():
            alias = self._matching_alias(key, added_fields)
            if alias is None:
                resolved[key] = direction
                continue
            field = added_fields[alias]
            if not field:
                # Computed field: no stored field to index.
                continue
            resolved[field + key[len(alias):]] = direction
        return resolved

    @staticmethod
    def _field_name(value):
        """Return the path of a ``"$field"`` string, else ``None``."""
        if isinstance(value, str) and value.startswith('$') and not value.startswith('$$'):
            return value[1:]
        return None

    @staticmethod
    def _matching_alias(key, added_fields):
        """Return the longest alias that ``key`` equals or is nested under."""
        best = None
        for alias in added_fields:
            if key == alias or key.startswith(alias + '.'):
                if best is None or len(alias) > len(best):
                    best = alias
        return best
def main():
    """Run the recommender on a sample ``orders`` aggregation and print the result.

    For each aggregation name one ``Aggregation: ...`` line is printed,
    followed by one ``Suggested index: ...`` line per recommendation.
    """
    sample_pipeline = [
        {"$match": {"status": "processing", "orderDate": {"$gte": "2023-01-01"}}},
        {"$sort": {"orderDate": -1}},
        {"$group": {"_id": "$customerId", "total": {"$sum": "$amount"}}},
    ]
    sample_aggregation = {
        "name": "example_aggregation",
        "collection": "orders",
        "pipeline": sample_pipeline,
    }

    recommendations = IndexRecommender().generate([sample_aggregation])
    for agg_name in recommendations:
        print(f"Aggregation: {agg_name}")
        for suggestion in recommendations[agg_name]:
            print(f"Suggested index: {suggestion}")


# Only run the demo when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment