Instantly share code, notes, and snippets.
Created
May 14, 2021 16:59
-
Star
(2)
2
You must be signed in to star a gist -
Fork
(1)
1
You must be signed in to fork a gist
-
Save emilyevans-bn/5e79e7c97e3d04ef2419bde7fba56313 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class BQDatasetAuthorizer(ContextDecorator):
    """Collect BigQuery dataset access entries for a dbt project and apply them.

    Entries are accumulated per dataset name in ``auth_dict``:

    * Google groups declared under ``datasets`` in ``models/**/schema.yml``
      (see :meth:`add_google_groups_to_auth_dict`), and
    * views that read from a different dataset than the one they live in,
      discovered from the dbt manifest (see :meth:`add_views_to_auth_dict`).

    :meth:`authorize_in_bq` then pushes the entries that are not yet present
    to BigQuery.

    ``self.client`` and ``self.project_id`` are only set in ``__enter__``, so
    the instance must be entered (``with BQDatasetAuthorizer(...) as auth:``)
    before calling :meth:`add_views_to_auth_dict` or :meth:`authorize_in_bq`.
    """

    def __init__(self, target, profile_file):
        """Store the client settings and start with an empty authorization dict.

        Args:
            target: Forwarded to ``create_bq_client_from_profile`` on enter.
            profile_file: Forwarded to ``create_bq_client_from_profile`` on enter.
        """
        self.target = target
        self.profile_file = profile_file
        # this dict maps a dataset to a list of BigQuery Access entries defined by code.
        # We expect groupByEmail and View.
        # https://googleapis.dev/python/bigquery/latest/generated/google.cloud.bigquery.dataset.AccessEntry.html
        self.auth_dict = defaultdict(list)

    def __enter__(self):
        """Create the BigQuery client, remember its project and return ``self``."""
        self.client = create_bq_client_from_profile(self.profile_file, self.target)
        self.project_id = self.client.project
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the BigQuery client; exceptions from the body are not suppressed."""
        # bq leaves the sockets open so you have to close them
        close_bq_client(self.client)

    def get_auth_dict(self):
        """Return the dataset -> list-of-access-entries mapping (not a copy)."""
        return self.auth_dict

    def _generate_access_entry_for_group(self, group_email):
        """Build a READER ``groupByEmail`` access entry for ``group_email``."""
        return bigquery.AccessEntry(
            role="READER", entity_type="groupByEmail", entity_id=group_email
        )

    def _generate_access_entry_for_view(self, view_dataset, view_tablename):
        """Build a ``view`` access entry for ``<project_id>.<view_dataset>.<view_tablename>``.

        Requires ``self.project_id``, which is set in ``__enter__``.
        """
        dataset_ref = bigquery.dataset.DatasetReference.from_string(
            f"{self.project_id}.{view_dataset}"
        )
        view = bigquery.Table(dataset_ref.table(view_tablename))
        # bigquery requires access entries have role=None
        return bigquery.AccessEntry(
            role=None, entity_type="view", entity_id=view.reference.to_api_repr()
        )

    # parsing helper functions
    def _get_all_group_auth(self, dbt_base_dir):
        """Collect the ``datasets`` entries of every ``models/**/schema.yml``.

        Args:
            dbt_base_dir: Object exposing ``glob`` (e.g. a ``pathlib.Path``)
                pointing at the dbt project directory.

        Returns:
            List of dataset dicts with the accesses to be granted. Files that
            are empty or have no (or an empty) ``datasets`` key contribute
            nothing.
        """
        group_auth = []
        for schema_path in dbt_base_dir.glob("models/**/schema.yml"):
            with open(schema_path, encoding="utf-8") as f:
                # safe_load returns None for an empty document; treat it as {}
                # so an empty schema.yml does not raise AttributeError below.
                content = yaml.safe_load(f) or {}
            datasets = content.get("datasets")
            if datasets:
                group_auth.extend(datasets)
        return group_auth

    def add_views_to_auth_dict(self, manifest_file):
        """Add access entries for views that need authorizing, based on the manifest.

        For every parent -> child edge in the manifest's child map, when the
        child is a model materialized as a view and lives in a different
        schema than the parent, an access entry for that view is appended to
        the parent's dataset in the authorization dictionary.

        Args:
            manifest_file: dbt-generated manifest, passed to ``parse_manifest``.
        """
        # NOTE(review): assumes parse_manifest returns (nodes by name, sources
        # by name, parent name -> iterable of child names) - confirm.
        node_dict, source_dict, child_map_dict = parse_manifest(manifest_file)

        def _get_field_from_node(node_name, field):
            # Look the name up among nodes first, then sources; None if unknown.
            node = node_dict.get(node_name)
            if node is not None:
                return deep_dict_lookup(node, field)
            source = source_dict.get(node_name)
            if source is not None:
                return deep_dict_lookup(source, field)
            return None

        def _node_is_view(node):
            return (
                _get_field_from_node(node, "resource_type") == "model"
                and _get_field_from_node(node, "config.materialized") == "view"
            )

        for parent_node, children_nodes in child_map_dict.items():
            # Tests are present in the child map and don't want to include them
            if _get_field_from_node(parent_node, "resource_type") == "test":
                continue
            src_dataset = _get_field_from_node(parent_node, "schema")
            access_entries = []
            for child_node in children_nodes:
                dest_dataset = _get_field_from_node(child_node, "schema")
                # only have to authorize views in different datasets
                # todo handle ephemeral models
                if _node_is_view(child_node) and src_dataset != dest_dataset:
                    access_entries.append(
                        self._generate_access_entry_for_view(
                            view_dataset=dest_dataset,
                            view_tablename=_get_field_from_node(child_node, "name"),
                        )
                    )
            # Only touch auth_dict when there is something to add, so the
            # defaultdict does not grow empty keys.
            if access_entries:
                self.auth_dict[src_dataset].extend(access_entries)

    def add_google_groups_to_auth_dict(self, dbt_dir):
        """Add access entries for Google groups to the authorization dictionary.

        Each dataset dict found by :meth:`_get_all_group_auth` is expected to
        carry a ``name`` and an ``access`` list of group e-mail addresses.
        Datasets whose ``access`` is missing or empty are skipped.

        Args:
            dbt_dir: dbt project directory, passed to :meth:`_get_all_group_auth`.
        """
        for dataset in self._get_all_group_auth(dbt_dir):
            group_emails = dataset.get("access")
            if not group_emails:
                continue
            self.auth_dict[dataset["name"]].extend(
                self._generate_access_entry_for_group(email) for email in group_emails
            )

    def authorize_in_bq(self):
        """Update each BigQuery dataset in the authorization dict with its new entries.

        (Aka permissions the newly added groups and authorizes the new views)

        1. Gets the current access entries for the dataset,
        2. Identifies new access entries in the authorization dictionary
        3. Updates BigQuery to enable these accesses

        Failures to fetch or update a dataset are logged and the loop moves on
        to the next dataset; nothing is raised to the caller.
        """
        for dataset, access_list in self.auth_dict.items():
            try:
                cur_dataset = self.client.get_dataset(dataset)
            except Exception as e:
                logging.error("dataset %s cannot be gotten, exception: %s", dataset, e)
                continue

            access_entries = cur_dataset.access_entries
            # bigquery will error if duplicates in access entries, so keep only
            # entries that are neither already granted nor repeated in our own
            # list. List membership (==) is used on purpose: view entries carry
            # a dict id, so hashing them into a set is not safe.
            new_access_entries = []
            for access in access_list:
                if access not in access_entries and access not in new_access_entries:
                    new_access_entries.append(access)
            if not new_access_entries:
                continue

            logging.info(
                "adding %s new access_entries for dataset: %s",
                len(new_access_entries),
                dataset,
            )
            access_entries.extend(new_access_entries)
            cur_dataset.access_entries = access_entries
            try:
                self.client.update_dataset(cur_dataset, ["access_entries"])
            except Exception as e:
                logging.error("issues updating dataset %s, exception: %s", dataset, e)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment