Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save 1987yama3/ac72f26931ee047f94787beff9ac14bc to your computer and use it in GitHub Desktop.
Save 1987yama3/ac72f26931ee047f94787beff9ac14bc to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
"""
BigQuery Schema Retriever
BigQueryのテーブルスキーマ(列定義)を取得し、各列のサンプル値も含めてJSON形式で出力するプログラム。
ネストされたフィールド(STRUCT型)にも対応。
取得したJSONファイルを、生成AIのコンテストとして活用すれば、クエリ実装の精度が上がるはず(どの列にどのような値が入っているのかを、正しく理解することができため)。
"""
import json
import sys
import argparse
from typing import Dict, List, Any, Optional, Union
from google.cloud import bigquery
from google.cloud.bigquery import SchemaField
from collections import Counter
class BigQuerySchemaRetriever:
def __init__(self, project_id: str = None):
"""
BigQueryクライアントを初期化
Args:
project_id: プロジェクトID(Noneの場合は環境変数から取得)
"""
self.client = bigquery.Client(project=project_id)
def get_table_reference(self, project_id: str, dataset_id: str, table_id: str) -> bigquery.TableReference:
"""
テーブル参照を作成
Args:
project_id: プロジェクトID
dataset_id: データセットID
table_id: テーブルID
Returns:
BigQueryテーブル参照
"""
dataset_ref = self.client.dataset(dataset_id, project=project_id)
return dataset_ref.table(table_id)
def get_sample_values(self, project_id: str, dataset_id: str, table_id: str,
field_path: str, max_samples: int = 30) -> List[Any]:
"""
指定されたフィールドのサンプル値を取得
Args:
project_id: プロジェクトID
dataset_id: データセットID
table_id: テーブルID
field_path: フィールドパス(ネストされたフィールドの場合はドット記法)
max_samples: 最大サンプル数
Returns:
サンプル値のリスト
"""
# フィールド名が単一ならバッククォート、ドット区切りならそのまま
if '.' in field_path:
sql_field = field_path
else:
sql_field = f'`{field_path}`'
# サンプル値(NULL以外)
query = f"""
SELECT {sql_field} as field_value, COUNT(*) as count
FROM `{project_id}.{dataset_id}.{table_id}`
WHERE {sql_field} IS NOT NULL
GROUP BY {sql_field}
ORDER BY count DESC
LIMIT {max_samples}
"""
# NULL値の存在チェック
null_query = f"""
SELECT COUNT(*) as null_count
FROM `{project_id}.{dataset_id}.{table_id}`
WHERE {sql_field} IS NULL
"""
try:
# サンプル値取得
query_job = self.client.query(query)
results = query_job.result()
sample_values = []
for row in results:
sample_values.append(self._convert_sample_value(row.field_value))
# NULL存在チェック
null_job = self.client.query(null_query)
null_count = list(null_job.result())[0].null_count
if null_count > 0:
sample_values = [None] + sample_values
print(f"[DEBUG] {field_path} sample_values: {sample_values}", file=sys.stderr)
return sample_values[:max_samples]
except Exception as e:
print(f"Warning: Could not get sample values for {field_path}: {e}", file=sys.stderr)
return []
def _convert_sample_value(self, value: str) -> Any:
"""
文字列値を適切なデータ型に変換
Args:
value: 変換する値
Returns:
変換された値
"""
if value is None or value == "NULL":
return None
# 数値の変換を試みる
try:
# 整数の場合
if '.' not in str(value) and 'e' not in str(value).lower():
return int(value)
# 浮動小数点数の場合
else:
return float(value)
except (ValueError, TypeError):
pass
# ブール値の変換を試みる
if str(value).lower() in ('true', 'false'):
return str(value).lower() == 'true'
# その他は文字列として返す
return str(value)
def get_nested_sample_values(self, project_id: str, dataset_id: str, table_id: str,
field_path: str, nested_field: str, max_samples: int = 30, is_repeated: bool = False) -> List[Any]:
"""
ネストされたフィールドのサンプル値を取得
Args:
project_id: プロジェクトID
dataset_id: データセットID
table_id: テーブルID
field_path: 親フィールドパス
nested_field: ネストされたフィールド名
max_samples: 最大サンプル数
Returns:
サンプル値のリスト
"""
full_field_path = f"{field_path}.{nested_field}"
if is_repeated:
# REPEATED STRUCTの場合はUNNEST
query = f"""
SELECT item.{nested_field} as field_value, COUNT(*) as count
FROM `{project_id}.{dataset_id}.{table_id}`
CROSS JOIN UNNEST({field_path}) AS item
WHERE item.{nested_field} IS NOT NULL
GROUP BY item.{nested_field}
ORDER BY count DESC
LIMIT {max_samples}
"""
null_query = f"""
SELECT COUNT(*) as null_count
FROM `{project_id}.{dataset_id}.{table_id}`
CROSS JOIN UNNEST({field_path}) AS item
WHERE item.{nested_field} IS NULL
"""
else:
# 通常のネスト
query = f"""
SELECT {full_field_path} as field_value, COUNT(*) as count
FROM `{project_id}.{dataset_id}.{table_id}`
WHERE {full_field_path} IS NOT NULL
GROUP BY {full_field_path}
ORDER BY count DESC
LIMIT {max_samples}
"""
null_query = f"""
SELECT COUNT(*) as null_count
FROM `{project_id}.{dataset_id}.{table_id}`
WHERE {full_field_path} IS NULL
"""
try:
query_job = self.client.query(query)
results = query_job.result()
sample_values = []
for row in results:
sample_values.append(self._convert_sample_value(row.field_value))
null_job = self.client.query(null_query)
null_count = list(null_job.result())[0].null_count
if null_count > 0:
sample_values = [None] + sample_values
print(f"[DEBUG] {full_field_path} sample_values: {sample_values}", file=sys.stderr)
return sample_values[:max_samples]
except Exception as e:
print(f"Warning: Could not get sample values for nested field {full_field_path}: {e}", file=sys.stderr)
return []
def process_schema_field(self, field: SchemaField, project_id: str, dataset_id: str,
table_id: str, parent_path: str = "", max_samples: int = 30) -> Dict[str, Any]:
"""
スキーマフィールドを処理してJSONオブジェクトに変換
Args:
field: BigQueryスキーマフィールド
project_id: プロジェクトID
dataset_id: データセットID
table_id: テーブルID
parent_path: 親フィールドのパス
max_samples: サンプル値の最大数
Returns:
フィールド情報のDict
"""
field_path = f"{parent_path}.{field.name}" if parent_path else field.name
field_info = {
"name": field.name,
"type": field.field_type,
"mode": field.mode,
"description": field.description or ""
}
# STRUCT型の場合は子フィールドも処理
if field.field_type == "RECORD":
field_info["type"] = "STRUCT"
field_info["fields"] = []
is_repeated = (field.mode == "REPEATED")
for sub_field in field.fields:
sub_field_info = self.process_schema_field(
sub_field, project_id, dataset_id, table_id, field_path, max_samples
)
# サンプル値を取得(REPEATEDの場合のみis_repeated=True、そうでなければFalse)
if is_repeated:
sub_field_info["sample_values"] = self.get_nested_sample_values(
project_id, dataset_id, table_id, field_path, sub_field.name, max_samples, is_repeated=True
)
else:
sub_field_info["sample_values"] = self.get_nested_sample_values(
project_id, dataset_id, table_id, field_path, sub_field.name, max_samples, is_repeated=False
)
field_info["fields"].append(sub_field_info)
else:
# サンプル値を取得
if parent_path:
# ネストされたフィールドの場合
sample_values = self.get_nested_sample_values(
project_id, dataset_id, table_id, parent_path, field.name, max_samples
)
else:
# トップレベルのフィールドの場合
sample_values = self.get_sample_values(
project_id, dataset_id, table_id, field_path, max_samples
)
field_info["sample_values"] = sample_values
return field_info
def get_table_schema_with_samples(self, project_id: str, dataset_id: str, table_id: str,
max_samples: int = 30) -> Dict[str, Any]:
"""
テーブルのスキーマとサンプル値を取得
Args:
project_id: プロジェクトID
dataset_id: データセットID
table_id: テーブルID
max_samples: 各フィールドから取得するサンプル値の最大数
Returns:
スキーマ情報とサンプル値を含むDict
"""
# テーブル情報を取得
table_ref = self.get_table_reference(project_id, dataset_id, table_id)
table = self.client.get_table(table_ref)
# スキーマ情報を処理
fields = []
for field in table.schema:
field_info = self.process_schema_field(field, project_id, dataset_id, table_id,
max_samples=max_samples)
fields.append(field_info)
return {
"fields": fields
}
def main():
"""
メイン関数:コマンドライン引数から入力を受け取り、スキーマ情報を出力
"""
# コマンドライン引数のパーサーを設定
parser = argparse.ArgumentParser(
description='BigQueryのテーブルスキーマとサンプル値を取得してJSON形式で出力します。',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog='''
使用例:
%(prog)s my-project my-dataset my-table
%(prog)s --project my-project --dataset my-dataset --table my-table
%(prog)s -p my-project -d my-dataset -t my-table --max-samples 50
'''
)
# 位置引数(従来の方式)
parser.add_argument('project_id', nargs='?',
help='BigQueryプロジェクトID')
parser.add_argument('dataset_id', nargs='?',
help='BigQueryデータセットID')
parser.add_argument('table_id', nargs='?',
help='BigQueryテーブルID')
# オプション引数(より柔軟な指定方法)
parser.add_argument('-p', '--project', dest='project_option',
help='BigQueryプロジェクトID(位置引数の代替)')
parser.add_argument('-d', '--dataset', dest='dataset_option',
help='BigQueryデータセットID(位置引数の代替)')
parser.add_argument('-t', '--table', dest='table_option',
help='BigQueryテーブルID(位置引数の代替)')
# その他のオプション
parser.add_argument('--max-samples', type=int, default=30,
help='各フィールドから取得するサンプル値の最大数(デフォルト: 30)')
parser.add_argument('--output', '-o',
help='出力ファイルパス(指定しない場合は標準出力)')
parser.add_argument('--pretty', action='store_true',
help='JSON出力を整形して表示')
args = parser.parse_args()
try:
# 引数の優先順位: オプション引数 > 位置引数
project_id = args.project_option or args.project_id
dataset_id = args.dataset_option or args.dataset_id
table_id = args.table_option or args.table_id
# 必須引数のチェック
if not project_id:
parser.error("プロジェクトIDが指定されていません。位置引数または --project オプションで指定してください。")
if not dataset_id:
parser.error("データセットIDが指定されていません。位置引数または --dataset オプションで指定してください。")
if not table_id:
parser.error("テーブルIDが指定されていません。位置引数または --table オプションで指定してください。")
# バリデーション
if args.max_samples <= 0:
parser.error("--max-samples は正の整数である必要があります。")
print(f"スキーマを取得中: {project_id}.{dataset_id}.{table_id}", file=sys.stderr)
# スキーマ取得処理を実行
retriever = BigQuerySchemaRetriever()
# max_samplesを設定するためにメソッドを少し変更する必要があります
schema_info = retriever.get_table_schema_with_samples(
project_id, dataset_id, table_id, max_samples=args.max_samples
)
# JSON出力の設定
json_kwargs = {
'ensure_ascii': False,
'separators': (',', ':') if not args.pretty else None,
'indent': 2 if args.pretty else None
}
# 出力処理
json_output = json.dumps(schema_info, **json_kwargs)
if args.output:
# ファイルに出力
with open(args.output, 'w', encoding='utf-8') as f:
f.write(json_output)
print(f"結果を {args.output} に保存しました。", file=sys.stderr)
else:
# 標準出力に出力
print(json_output)
except KeyboardInterrupt:
print("\n処理が中断されました。", file=sys.stderr)
sys.exit(1)
except Exception as e:
print(f"エラーが発生しました: {e}", file=sys.stderr)
sys.exit(1)
if __name__ == "__main__":
main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment