class IndexRecommender:
    """Derive candidate index definitions from aggregation pipelines.

    Each aggregation is a dict with the keys ``name``, ``collection`` and
    ``pipeline``; the pipeline is a list of single-key stage dicts such as
    ``{"$match": {...}}``.  Only ``$match``, ``$sort``, ``$group`` and
    ``$lookup`` stages can produce a recommendation.

    Every recommendation is a dict of the form::

        {"name": ..., "key": {field: 1 | -1, ...}, "collection": ...,
         "operator": ..., "order": <stage position>}
    """

    LOGIC_OPS = ["$or", "$and", "$nor", "$not"]
    COMPARISON_OPS = ['$cmp', '$eq', '$gt', '$gte', '$lt', '$lte', '$ne']

    def __init__(self):
        pass

    # ------------------------------------------------------------------
    # Entry points
    # ------------------------------------------------------------------

    def generate(self, aggregations):
        """Return ``{aggregation name: [index, ...]}`` for all aggregations."""
        return {
            aggregation['name']: self.generate_indexes(aggregation)
            for aggregation in aggregations
        }

    def generate_indexes(self, aggregation):
        """Return the list of index recommendations for one aggregation.

        Stages are visited in pipeline order; a stage contributes an entry
        only when its handler returns a non-empty recommendation.
        """
        pipeline = aggregation['pipeline']
        operator_seq = [self.get_aggregation_stage_operator(stage) for stage in pipeline]
        base_collection = aggregation['collection']

        indexes = []
        for step_num, (operator, stage) in enumerate(zip(operator_seq, pipeline)):
            name = f"{aggregation['name']}_{base_collection}_{operator.replace('$', '')}_{step_num}"
            index = self.get_stage_index(operator, stage, step_num, operator_seq, aggregation, name)
            if index:
                indexes.append(index)
        return indexes

    def get_stage_index(self, operator, stage, order, sequence, aggregation, name):
        """Dispatch a stage to its operator-specific handler.

        Returns the handler's recommendation, or None for operators that
        have no handler.
        """
        handlers = {
            '$match': self.get_match_index,
            '$sort': self.get_sort_index,
            '$group': self.get_group_index,
            '$lookup': self.get_lookup_index,
        }
        handler = handlers.get(operator)
        if handler is None:
            return None
        return handler(operator, stage, order, sequence, aggregation, name)

    # ------------------------------------------------------------------
    # Per-operator handlers
    # ------------------------------------------------------------------

    def get_match_index(self, operator, stage, order, sequence, aggregation, name):
        """Recommend an index for a ``$match`` stage.

        Returns None when any earlier stage is outside the allowed list
        below, or when no indexable field remains after alias resolution.
        """
        valid_stages = ("$project", "$match", "$unset", "$set", "$sort", "$addFields")
        if any(previous not in valid_stages for previous in sequence[:order]):
            return None

        index_key = {}
        for key, value in stage[operator].items():
            if not key.startswith('$'):
                # Plain field condition; conditions that look like array
                # comparisons are left out of the key.
                if not self.is_field_array(value):
                    index_key[key] = 1
            elif key in self.LOGIC_OPS:
                index_key.update(self.process_logical_operator(value))
            elif key == '$expr':
                index_key.update(self.process_expr_operator(value))

        # Translate fields introduced by earlier stages back to the fields
        # they were derived from.
        added_fields = self.get_added_fields(aggregation['pipeline'][:order])
        index_key = self.apply_aliases(added_fields, index_key)
        if not index_key:
            return None
        return self._build_index(name, index_key, aggregation['collection'], "$match", order)

    def get_sort_index(self, operator, stage, order, sequence, aggregation, name):
        """Recommend an index for a ``$sort`` stage.

        Returns None when a ``$project``, ``$unwind`` or ``$group`` stage
        precedes the sort, or when no field with direction 1/-1 remains.
        """
        invalid_stages = ("$project", "$unwind", "$group")
        if any(previous in invalid_stages for previous in sequence[:order]):
            return None

        index_key = {
            key: value
            for key, value in stage[operator].items()
            if not key.startswith('$') and (value == 1 or value == -1)
        }

        added_fields = self.get_added_fields(aggregation['pipeline'][:order])
        index_key = self.apply_aliases(added_fields, index_key)
        if not index_key:
            return None
        return self._build_index(name, index_key, aggregation['collection'], "$sort", order)

    def get_group_index(self, operator, stage, order, sequence, aggregation, name):
        """Recommend an index for a ``$group`` stage.

        A recommendation is produced only when:

        * no ``$group``, ``$project`` or ``$unwind`` stage precedes it;
        * every accumulator is ``$first`` applied to a plain field;
        * ``_id`` is a field reference (``"$a"``) or a dict whose values
          are all field references;
        * without a preceding ``$sort``: ``_id`` names exactly one field;
        * with a preceding ``$sort``: the grouping fields match the leading
          fields of the most recent ``$sort``, position by position.

        In every other case (including ``_id: None`` or expression-valued
        ``_id``) the method returns None instead of raising.
        """
        preceding = sequence[:order]
        invalid_stages = ("$group", "$project", "$unwind")
        if any(previous in invalid_stages for previous in preceding):
            return None

        element_dict = stage[operator]

        # Collect the fields read by $first accumulators; anything else
        # disqualifies the stage.
        first_op_fields = {}
        for key, element in element_dict.items():
            if key == "_id":
                continue
            if not isinstance(element, dict):
                return None
            for group_op, field in element.items():
                field_name = self._field_path(field)
                if group_op != "$first" or field_name is None:
                    return None
                first_op_fields[field_name] = 1

        # Normalise _id to a list of field names.
        group_id = element_dict.get("_id")
        if isinstance(group_id, dict):
            id_refs = list(group_id.values())
        elif isinstance(group_id, str):
            id_refs = [group_id]
        else:
            # None / constant _id groups everything together: nothing to index.
            return None
        id_fields = [self._field_path(ref) for ref in id_refs]
        if not id_fields or None in id_fields:
            return None

        last_sort = max(
            (position for position, previous in enumerate(preceding) if previous == "$sort"),
            default=None,
        )

        if last_sort is None:
            if len(id_fields) != 1:
                return None
            index_key = {id_fields[0]: 1}
        else:
            sort_spec = aggregation['pipeline'][last_sort]["$sort"]
            sort_fields = list(sort_spec)
            index_key = {}
            for position, id_field in enumerate(id_fields):
                if position >= len(sort_fields):
                    # Grouping field beyond the sort specification.
                    index_key[id_field] = 1
                elif sort_fields[position] == id_field:
                    direction = sort_spec[id_field]
                    if direction not in (1, -1):
                        return None
                    index_key[id_field] = direction
                else:
                    return None

        index_key.update(first_op_fields)
        added_fields = self.get_added_fields(aggregation['pipeline'][:order])
        index_key = self.apply_aliases(added_fields, index_key)
        if not index_key:
            return None
        return self._build_index(name, index_key, aggregation['collection'], "$group", order)

    def get_lookup_index(self, operator, stage, order, sequence, aggregation, name):
        """Recommend an index on the foreign collection of a ``$lookup``.

        A non-empty ``pipeline`` selects the pipeline form; otherwise the
        ``foreignField`` form is used.
        """
        if stage[operator].get("pipeline"):
            return self.get_pipeline_lookup_index(operator, stage, order, sequence, aggregation, name)
        return self.get_basic_lookup_index(operator, stage, order, sequence, aggregation, name)

    def get_basic_lookup_index(self, operator, stage, order, sequence, aggregation, name):
        """Index ``foreignField`` on the ``from`` collection.

        Returns None when the foreign field is ``_id`` or when
        ``foreignField`` / ``from`` is missing.
        """
        lookup = stage[operator]
        foreign_field = lookup.get("foreignField")
        foreign_collection = lookup.get("from")
        if not foreign_field or foreign_field == "_id" or not foreign_collection:
            return None
        return self._build_index(name, {foreign_field: 1}, foreign_collection, "$lookup", order)

    def get_pipeline_lookup_index(self, operator, stage, order, sequence, aggregation, name):
        """Merge the keys of all ``$match`` stages in a lookup sub-pipeline.

        Aliases are resolved against the sub-pipeline itself, not the
        enclosing pipeline, because that is where the preceding stages of
        each sub-``$match`` live.
        """
        lookup = stage[operator]
        foreign_collection = lookup.get("from")
        pipeline = lookup.get("pipeline") or []
        if not foreign_collection:
            return None

        sub_sequence = [self.get_aggregation_stage_operator(sub_stage) for sub_stage in pipeline]
        sub_aggregation = dict(aggregation, pipeline=pipeline, collection=foreign_collection)

        index_key = {}
        for sub_order, (sub_operator, sub_stage) in enumerate(zip(sub_sequence, pipeline)):
            if sub_operator != "$match":
                continue
            sub_index = self.get_match_index(
                sub_operator, sub_stage, sub_order, sub_sequence, sub_aggregation, name
            )
            if sub_index:
                index_key.update(sub_index["key"])

        if not index_key:
            return None
        return self._build_index(name, index_key, foreign_collection, "$lookup", order)

    # ------------------------------------------------------------------
    # Expression helpers
    # ------------------------------------------------------------------

    def get_aggregation_stage_operator(self, stage):
        """Return the first key of a stage dict (its operator)."""
        return list(stage)[0]

    def is_field_array(self, comparison):
        """Return True when a match condition looks like an array comparison.

        Lists and ``$size`` conditions count as array comparisons; for
        ``$eq`` / ``$ne`` the operand decides.
        """
        if isinstance(comparison, list):
            return True
        if isinstance(comparison, dict):
            for key, value in comparison.items():
                if key in ('$eq', '$ne'):
                    return self.is_field_array(value)
                if key == '$size':
                    return True
        return False

    def process_logical_operator(self, params):
        """Collect ``{field: 1}`` for every field used by a logical operator.

        ``params`` is normally the list of clauses; a single dict clause is
        accepted as well, and non-dict clauses are skipped.
        """
        if isinstance(params, dict):
            params = [params]
        elif not isinstance(params, (list, tuple)):
            return {}

        logic_fields = {}
        for logic_param in params:
            if not isinstance(logic_param, dict):
                continue
            for param, value in logic_param.items():
                if not param.startswith('$'):
                    logic_fields[param] = 1
                elif param in self.LOGIC_OPS:
                    logic_fields.update(self.process_logical_operator(value))
                elif param in self.COMPARISON_OPS:
                    logic_fields.update(self.process_comp_operator(value))
                elif param == '$expr':
                    logic_fields.update(self.process_expr_operator(value))
        return logic_fields

    def process_comp_operator(self, comparison_fields):
        """Return ``{field: 1}`` for the first field reference among operands.

        ``$$name`` operands are pipeline variables, not document fields,
        and are skipped.  Returns ``{}`` when no operand is a field.
        """
        if not isinstance(comparison_fields, (list, tuple)):
            # A scalar operand must not be iterated character by character.
            comparison_fields = [comparison_fields]
        for operand in comparison_fields:
            field = self._field_path(operand)
            if field is not None:
                return {field: 1}
        return {}

    def process_expr_operator(self, element_expr):
        """Collect fields from an ``$expr`` body.

        Only the first operator of the expression is inspected; comparison
        and logical operators are supported, anything else yields ``{}``.
        """
        if not isinstance(element_expr, dict) or not element_expr:
            return {}
        expr_op = next(iter(element_expr))
        if expr_op in self.COMPARISON_OPS:
            return self.process_comp_operator(element_expr[expr_op])
        if expr_op in self.LOGIC_OPS:
            return self.process_logical_operator(element_expr[expr_op])
        return {}

    # ------------------------------------------------------------------
    # Alias tracking
    # ------------------------------------------------------------------

    def get_added_fields(self, stages):
        """Map fields introduced by ``stages`` to the field they alias.

        ``$addFields`` / ``$set`` entries and ``$project`` entries whose
        value is not 0/1 are recorded.  The mapped value is the source
        field name, or None when the new field is computed.
        """
        added_fields = {}
        for stage in stages:
            operator = self.get_aggregation_stage_operator(stage)
            if operator in ('$addFields', '$set'):
                for field, value in stage[operator].items():
                    self.parse_alias(field, value, added_fields)
            elif operator == '$project':
                for field, value in stage[operator].items():
                    if value not in (0, 1):
                        self.parse_alias(field, value, added_fields)
        return added_fields

    def parse_alias(self, alias, reference, previous):
        """Record in ``previous`` what ``alias`` refers to (mutates ``previous``).

        A ``"$field"`` reference is resolved through already-known aliases;
        any other value marks the alias as computed (None).
        """
        if isinstance(reference, str) and reference.startswith('$'):
            field = reference[1:]
            previous[alias] = previous.get(field, field)
        else:
            previous[alias] = None

    def apply_aliases(self, added_fields, index_key):
        """Rewrite ``index_key`` so it only names original fields.

        Keys equal to an alias, or nested under it (``alias.sub``), are
        renamed to the aliased field; keys under a computed alias are
        dropped.  Returns a new dict; ``index_key`` is not modified.
        """
        for alias, field in added_fields.items():
            rewritten = {}
            for key, direction in index_key.items():
                if self._alias_applies(key, alias):
                    if not field:
                        continue  # computed field: cannot be indexed
                    key = field + key[len(alias):]
                rewritten[key] = direction
            index_key = rewritten
        return index_key

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _alias_applies(key, alias):
        """Return True when ``key`` is ``alias`` or a dotted path below it."""
        return key == alias or key.startswith(alias + ".")

    @staticmethod
    def _field_path(reference):
        """Return the field name of a ``"$field"`` reference, else None.

        Non-strings, strings without a leading ``$``, a bare ``"$"`` and
        ``$$variable`` references all yield None.
        """
        if not isinstance(reference, str):
            return None
        if not reference.startswith('$') or reference.startswith('$$'):
            return None
        return reference[1:] or None

    @staticmethod
    def _build_index(name, key, collection, operator, order):
        """Assemble the recommendation dict shared by all handlers."""
        return {
            "name": name,
            "key": key,
            "collection": collection,
            "operator": operator,
            "order": order,
        }