Created
June 26, 2025 07:37
-
-
Save 1987yama3/ac72f26931ee047f94787beff9ac14bc to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
""" | |
BigQuery Schema Retriever | |
BigQueryのテーブルスキーマ(列定義)を取得し、各列のサンプル値も含めてJSON形式で出力するプログラム。 | |
ネストされたフィールド(STRUCT型)にも対応。 | |
取得したJSONファイルを、生成AIのコンテストとして活用すれば、クエリ実装の精度が上がるはず(どの列にどのような値が入っているのかを、正しく理解することができため)。 | |
""" | |
import json | |
import sys | |
import argparse | |
from typing import Dict, List, Any, Optional, Union | |
from google.cloud import bigquery | |
from google.cloud.bigquery import SchemaField | |
from collections import Counter | |
class BigQuerySchemaRetriever: | |
def __init__(self, project_id: str = None): | |
""" | |
BigQueryクライアントを初期化 | |
Args: | |
project_id: プロジェクトID(Noneの場合は環境変数から取得) | |
""" | |
self.client = bigquery.Client(project=project_id) | |
def get_table_reference(self, project_id: str, dataset_id: str, table_id: str) -> bigquery.TableReference: | |
""" | |
テーブル参照を作成 | |
Args: | |
project_id: プロジェクトID | |
dataset_id: データセットID | |
table_id: テーブルID | |
Returns: | |
BigQueryテーブル参照 | |
""" | |
dataset_ref = self.client.dataset(dataset_id, project=project_id) | |
return dataset_ref.table(table_id) | |
def get_sample_values(self, project_id: str, dataset_id: str, table_id: str, | |
field_path: str, max_samples: int = 30) -> List[Any]: | |
""" | |
指定されたフィールドのサンプル値を取得 | |
Args: | |
project_id: プロジェクトID | |
dataset_id: データセットID | |
table_id: テーブルID | |
field_path: フィールドパス(ネストされたフィールドの場合はドット記法) | |
max_samples: 最大サンプル数 | |
Returns: | |
サンプル値のリスト | |
""" | |
# フィールド名が単一ならバッククォート、ドット区切りならそのまま | |
if '.' in field_path: | |
sql_field = field_path | |
else: | |
sql_field = f'`{field_path}`' | |
# サンプル値(NULL以外) | |
query = f""" | |
SELECT {sql_field} as field_value, COUNT(*) as count | |
FROM `{project_id}.{dataset_id}.{table_id}` | |
WHERE {sql_field} IS NOT NULL | |
GROUP BY {sql_field} | |
ORDER BY count DESC | |
LIMIT {max_samples} | |
""" | |
# NULL値の存在チェック | |
null_query = f""" | |
SELECT COUNT(*) as null_count | |
FROM `{project_id}.{dataset_id}.{table_id}` | |
WHERE {sql_field} IS NULL | |
""" | |
try: | |
# サンプル値取得 | |
query_job = self.client.query(query) | |
results = query_job.result() | |
sample_values = [] | |
for row in results: | |
sample_values.append(self._convert_sample_value(row.field_value)) | |
# NULL存在チェック | |
null_job = self.client.query(null_query) | |
null_count = list(null_job.result())[0].null_count | |
if null_count > 0: | |
sample_values = [None] + sample_values | |
print(f"[DEBUG] {field_path} sample_values: {sample_values}", file=sys.stderr) | |
return sample_values[:max_samples] | |
except Exception as e: | |
print(f"Warning: Could not get sample values for {field_path}: {e}", file=sys.stderr) | |
return [] | |
def _convert_sample_value(self, value: str) -> Any: | |
""" | |
文字列値を適切なデータ型に変換 | |
Args: | |
value: 変換する値 | |
Returns: | |
変換された値 | |
""" | |
if value is None or value == "NULL": | |
return None | |
# 数値の変換を試みる | |
try: | |
# 整数の場合 | |
if '.' not in str(value) and 'e' not in str(value).lower(): | |
return int(value) | |
# 浮動小数点数の場合 | |
else: | |
return float(value) | |
except (ValueError, TypeError): | |
pass | |
# ブール値の変換を試みる | |
if str(value).lower() in ('true', 'false'): | |
return str(value).lower() == 'true' | |
# その他は文字列として返す | |
return str(value) | |
def get_nested_sample_values(self, project_id: str, dataset_id: str, table_id: str, | |
field_path: str, nested_field: str, max_samples: int = 30, is_repeated: bool = False) -> List[Any]: | |
""" | |
ネストされたフィールドのサンプル値を取得 | |
Args: | |
project_id: プロジェクトID | |
dataset_id: データセットID | |
table_id: テーブルID | |
field_path: 親フィールドパス | |
nested_field: ネストされたフィールド名 | |
max_samples: 最大サンプル数 | |
Returns: | |
サンプル値のリスト | |
""" | |
full_field_path = f"{field_path}.{nested_field}" | |
if is_repeated: | |
# REPEATED STRUCTの場合はUNNEST | |
query = f""" | |
SELECT item.{nested_field} as field_value, COUNT(*) as count | |
FROM `{project_id}.{dataset_id}.{table_id}` | |
CROSS JOIN UNNEST({field_path}) AS item | |
WHERE item.{nested_field} IS NOT NULL | |
GROUP BY item.{nested_field} | |
ORDER BY count DESC | |
LIMIT {max_samples} | |
""" | |
null_query = f""" | |
SELECT COUNT(*) as null_count | |
FROM `{project_id}.{dataset_id}.{table_id}` | |
CROSS JOIN UNNEST({field_path}) AS item | |
WHERE item.{nested_field} IS NULL | |
""" | |
else: | |
# 通常のネスト | |
query = f""" | |
SELECT {full_field_path} as field_value, COUNT(*) as count | |
FROM `{project_id}.{dataset_id}.{table_id}` | |
WHERE {full_field_path} IS NOT NULL | |
GROUP BY {full_field_path} | |
ORDER BY count DESC | |
LIMIT {max_samples} | |
""" | |
null_query = f""" | |
SELECT COUNT(*) as null_count | |
FROM `{project_id}.{dataset_id}.{table_id}` | |
WHERE {full_field_path} IS NULL | |
""" | |
try: | |
query_job = self.client.query(query) | |
results = query_job.result() | |
sample_values = [] | |
for row in results: | |
sample_values.append(self._convert_sample_value(row.field_value)) | |
null_job = self.client.query(null_query) | |
null_count = list(null_job.result())[0].null_count | |
if null_count > 0: | |
sample_values = [None] + sample_values | |
print(f"[DEBUG] {full_field_path} sample_values: {sample_values}", file=sys.stderr) | |
return sample_values[:max_samples] | |
except Exception as e: | |
print(f"Warning: Could not get sample values for nested field {full_field_path}: {e}", file=sys.stderr) | |
return [] | |
def process_schema_field(self, field: SchemaField, project_id: str, dataset_id: str, | |
table_id: str, parent_path: str = "", max_samples: int = 30) -> Dict[str, Any]: | |
""" | |
スキーマフィールドを処理してJSONオブジェクトに変換 | |
Args: | |
field: BigQueryスキーマフィールド | |
project_id: プロジェクトID | |
dataset_id: データセットID | |
table_id: テーブルID | |
parent_path: 親フィールドのパス | |
max_samples: サンプル値の最大数 | |
Returns: | |
フィールド情報のDict | |
""" | |
field_path = f"{parent_path}.{field.name}" if parent_path else field.name | |
field_info = { | |
"name": field.name, | |
"type": field.field_type, | |
"mode": field.mode, | |
"description": field.description or "" | |
} | |
# STRUCT型の場合は子フィールドも処理 | |
if field.field_type == "RECORD": | |
field_info["type"] = "STRUCT" | |
field_info["fields"] = [] | |
is_repeated = (field.mode == "REPEATED") | |
for sub_field in field.fields: | |
sub_field_info = self.process_schema_field( | |
sub_field, project_id, dataset_id, table_id, field_path, max_samples | |
) | |
# サンプル値を取得(REPEATEDの場合のみis_repeated=True、そうでなければFalse) | |
if is_repeated: | |
sub_field_info["sample_values"] = self.get_nested_sample_values( | |
project_id, dataset_id, table_id, field_path, sub_field.name, max_samples, is_repeated=True | |
) | |
else: | |
sub_field_info["sample_values"] = self.get_nested_sample_values( | |
project_id, dataset_id, table_id, field_path, sub_field.name, max_samples, is_repeated=False | |
) | |
field_info["fields"].append(sub_field_info) | |
else: | |
# サンプル値を取得 | |
if parent_path: | |
# ネストされたフィールドの場合 | |
sample_values = self.get_nested_sample_values( | |
project_id, dataset_id, table_id, parent_path, field.name, max_samples | |
) | |
else: | |
# トップレベルのフィールドの場合 | |
sample_values = self.get_sample_values( | |
project_id, dataset_id, table_id, field_path, max_samples | |
) | |
field_info["sample_values"] = sample_values | |
return field_info | |
def get_table_schema_with_samples(self, project_id: str, dataset_id: str, table_id: str, | |
max_samples: int = 30) -> Dict[str, Any]: | |
""" | |
テーブルのスキーマとサンプル値を取得 | |
Args: | |
project_id: プロジェクトID | |
dataset_id: データセットID | |
table_id: テーブルID | |
max_samples: 各フィールドから取得するサンプル値の最大数 | |
Returns: | |
スキーマ情報とサンプル値を含むDict | |
""" | |
# テーブル情報を取得 | |
table_ref = self.get_table_reference(project_id, dataset_id, table_id) | |
table = self.client.get_table(table_ref) | |
# スキーマ情報を処理 | |
fields = [] | |
for field in table.schema: | |
field_info = self.process_schema_field(field, project_id, dataset_id, table_id, | |
max_samples=max_samples) | |
fields.append(field_info) | |
return { | |
"fields": fields | |
} | |
def main(): | |
""" | |
メイン関数:コマンドライン引数から入力を受け取り、スキーマ情報を出力 | |
""" | |
# コマンドライン引数のパーサーを設定 | |
parser = argparse.ArgumentParser( | |
description='BigQueryのテーブルスキーマとサンプル値を取得してJSON形式で出力します。', | |
formatter_class=argparse.RawDescriptionHelpFormatter, | |
epilog=''' | |
使用例: | |
%(prog)s my-project my-dataset my-table | |
%(prog)s --project my-project --dataset my-dataset --table my-table | |
%(prog)s -p my-project -d my-dataset -t my-table --max-samples 50 | |
''' | |
) | |
# 位置引数(従来の方式) | |
parser.add_argument('project_id', nargs='?', | |
help='BigQueryプロジェクトID') | |
parser.add_argument('dataset_id', nargs='?', | |
help='BigQueryデータセットID') | |
parser.add_argument('table_id', nargs='?', | |
help='BigQueryテーブルID') | |
# オプション引数(より柔軟な指定方法) | |
parser.add_argument('-p', '--project', dest='project_option', | |
help='BigQueryプロジェクトID(位置引数の代替)') | |
parser.add_argument('-d', '--dataset', dest='dataset_option', | |
help='BigQueryデータセットID(位置引数の代替)') | |
parser.add_argument('-t', '--table', dest='table_option', | |
help='BigQueryテーブルID(位置引数の代替)') | |
# その他のオプション | |
parser.add_argument('--max-samples', type=int, default=30, | |
help='各フィールドから取得するサンプル値の最大数(デフォルト: 30)') | |
parser.add_argument('--output', '-o', | |
help='出力ファイルパス(指定しない場合は標準出力)') | |
parser.add_argument('--pretty', action='store_true', | |
help='JSON出力を整形して表示') | |
args = parser.parse_args() | |
try: | |
# 引数の優先順位: オプション引数 > 位置引数 | |
project_id = args.project_option or args.project_id | |
dataset_id = args.dataset_option or args.dataset_id | |
table_id = args.table_option or args.table_id | |
# 必須引数のチェック | |
if not project_id: | |
parser.error("プロジェクトIDが指定されていません。位置引数または --project オプションで指定してください。") | |
if not dataset_id: | |
parser.error("データセットIDが指定されていません。位置引数または --dataset オプションで指定してください。") | |
if not table_id: | |
parser.error("テーブルIDが指定されていません。位置引数または --table オプションで指定してください。") | |
# バリデーション | |
if args.max_samples <= 0: | |
parser.error("--max-samples は正の整数である必要があります。") | |
print(f"スキーマを取得中: {project_id}.{dataset_id}.{table_id}", file=sys.stderr) | |
# スキーマ取得処理を実行 | |
retriever = BigQuerySchemaRetriever() | |
# max_samplesを設定するためにメソッドを少し変更する必要があります | |
schema_info = retriever.get_table_schema_with_samples( | |
project_id, dataset_id, table_id, max_samples=args.max_samples | |
) | |
# JSON出力の設定 | |
json_kwargs = { | |
'ensure_ascii': False, | |
'separators': (',', ':') if not args.pretty else None, | |
'indent': 2 if args.pretty else None | |
} | |
# 出力処理 | |
json_output = json.dumps(schema_info, **json_kwargs) | |
if args.output: | |
# ファイルに出力 | |
with open(args.output, 'w', encoding='utf-8') as f: | |
f.write(json_output) | |
print(f"結果を {args.output} に保存しました。", file=sys.stderr) | |
else: | |
# 標準出力に出力 | |
print(json_output) | |
except KeyboardInterrupt: | |
print("\n処理が中断されました。", file=sys.stderr) | |
sys.exit(1) | |
except Exception as e: | |
print(f"エラーが発生しました: {e}", file=sys.stderr) | |
sys.exit(1) | |
if __name__ == "__main__": | |
main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment