Created
September 11, 2025 21:05
-
-
Save carsongee/4c68c32f93ece2e5325820f9faa0c236 to your computer and use it in GitHub Desktop.
Finds all <ENV>_DATAROBOT_ENDPOINT variables from a .env file, queries each endpoint's catalog to list its models, and identifies the models that are available on every endpoint
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env -S uv run --script | |
| # /// script | |
| # dependencies = [ | |
| # "httpx", | |
| # "python-dotenv", | |
| # ] | |
| # /// | |
| """ | |
| Script to find the intersection of models available across all DataRobot endpoints. | |
| Usage: ./catalog_intersection.py | |
| """ | |
| import os | |
| import json | |
| import httpx | |
| from dotenv import load_dotenv | |
def parse_env_file():
    """Collect DataRobot endpoint/token pairs from the environment.

    Calls ``load_dotenv()`` and then scans ``os.environ`` for variables named
    ``<PREFIX>_DATAROBOT_ENDPOINT`` that have a matching
    ``<PREFIX>_DATAROBOT_API_TOKEN`` variable.

    Returns:
        dict: Maps each prefix (e.g. ``'STAGING'``) to a dict with the keys
        ``'endpoint'`` and ``'token'``. Surrounding single/double quotes are
        stripped from both values. Endpoint variables without a matching
        token variable are skipped.
    """
    endpoint_suffix = '_DATAROBOT_ENDPOINT'

    load_dotenv()
    # Snapshot the environment so the scan and the token lookups agree.
    env_vars = dict(os.environ)

    endpoints = {}
    for key, value in env_vars.items():
        if not key.endswith(endpoint_suffix):
            continue
        # Slice off only the trailing suffix. str.replace() would also remove
        # the same text if it happened to appear inside the prefix, producing
        # a wrong prefix and therefore a wrong token variable name.
        prefix = key[:-len(endpoint_suffix)]
        token = env_vars.get(f"{prefix}_DATAROBOT_API_TOKEN")
        if token is None:
            continue
        endpoints[prefix] = {
            'endpoint': value.strip("'\""),
            'token': token.strip("'\""),
        }
    return endpoints
async def fetch_catalog(client, endpoint_name, endpoint_url, token):
    """Fetch the catalog from one DataRobot endpoint and collect model names.

    Args:
        client: Async HTTP client exposing an awaitable
            ``get(url, headers=...)``.
        endpoint_name: Label used only in the progress and error messages.
        endpoint_url: Base URL of the endpoint; trailing slashes are ignored.
        token: API token sent in a ``Bearer`` authorization header.

    Returns:
        set: Every string stored under a ``'model'`` key anywhere in the JSON
        response. Non-string ``'model'`` values are not collected (nested
        containers are still searched). An empty set is returned when the
        request fails or the body is not valid JSON; the error is printed
        rather than raised.
    """
    # Trim trailing slashes so a base URL such as ".../api/v2/" does not
    # produce a "//genai/..." path.
    base_url = endpoint_url.rstrip("/")
    url = f"{base_url}/genai/llmgw/catalog/"
    headers = {"Authorization": f"Bearer {token}"}

    try:
        print(f"Fetching catalog from {endpoint_name}: {url}")
        response = await client.get(url, headers=headers)
        response.raise_for_status()
        data = response.json()
    except httpx.HTTPError as e:
        print(f"Error fetching from {endpoint_name}: {e}")
        return set()
    except json.JSONDecodeError as e:
        print(f"Error parsing JSON from {endpoint_name}: {e}")
        return set()

    # Walk the payload with an explicit stack instead of recursion so deeply
    # nested responses cannot hit the interpreter's recursion limit.
    models = set()
    pending = [data]
    while pending:
        node = pending.pop()
        if isinstance(node, dict):
            candidate = node.get('model')
            # Only strings are model names. A dict/list here is unhashable
            # (set.add would raise TypeError) and is searched below anyway.
            if isinstance(candidate, str):
                models.add(candidate)
            pending.extend(node.values())
        elif isinstance(node, list):
            pending.extend(node)

    print(f"Found {len(models)} models in {endpoint_name}")
    return models
async def main():
    """Print each endpoint's models and the models shared by all endpoints.

    Reads the endpoint configuration with ``parse_env_file()``, fetches every
    endpoint's catalog one after another over a shared ``httpx.AsyncClient``,
    and prints a per-endpoint listing, the intersection, and a summary.
    Endpoints that returned no models are left out of the intersection.
    """
    print("DataRobot Catalog Intersection Tool")
    print("=" * 40)

    # Discover the configured endpoints; without any there is nothing to do.
    endpoints = parse_env_file()
    if not endpoints:
        print("No DataRobot endpoints found in .env file!")
        return

    print(f"Found {len(endpoints)} endpoints:")
    for endpoint_name in endpoints:
        print(f" - {endpoint_name}")
    print()

    # Query the endpoints sequentially, reusing one client for all requests.
    endpoint_models = {}
    async with httpx.AsyncClient(timeout=30.0) as client:
        for endpoint_name, settings in endpoints.items():
            endpoint_models[endpoint_name] = await fetch_catalog(
                client, endpoint_name, settings['endpoint'], settings['token']
            )

    print()
    print("Results:")
    print("-" * 20)

    # Per-endpoint listing (an empty catalog simply prints no model lines).
    for endpoint_name, catalog in endpoint_models.items():
        print(f"{endpoint_name}: {len(catalog)} models")
        for model in sorted(catalog):
            print(f" - {model}")
        print()

    # Intersect only the non-empty catalogs: a single empty set would
    # otherwise force the intersection to be empty.
    if endpoint_models:
        populated = [catalog for catalog in endpoint_models.values() if catalog]
        if not populated:
            print("No models found in any endpoint.")
        else:
            shared = set.intersection(*populated)
            print(f"Models available in ALL {len(populated)} endpoints:")
            print("=" * 50)
            if not shared:
                print("No models are available in all endpoints.")
            else:
                for model in sorted(shared):
                    print(f"✓ {model}")
                print(f"\nTotal: {len(shared)} models")

    print("\nSummary:")
    for endpoint_name, catalog in endpoint_models.items():
        mark = "✓" if catalog else "✗"
        print(f"{mark} {endpoint_name}: {len(catalog)} models")
# Script entry point: run the async main() to completion when this file is
# executed directly (not when it is imported).
if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment