@yamini
Last active January 10, 2025 22:54
navigator-fine-tuning-intro-s3.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "7a3ALTpLPe5S"
},
"source": [
"\n",
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/yamini/2d5bd1a708a085b7847ea16a4fdf7480/navigator-fine-tuning-intro-s3.ipynb) <br>\n",
"\n",
"<br>\n",
"\n",
"<center><img src=\"https://gretel-public-website.s3.us-west-2.amazonaws.com/assets/brand/gretel_brand_wordmark.svg\" alt=\"Gretel\" width=\"350\"/></center>\n",
"\n",
"<br>\n",
"\n",
"## 👋 Welcome to the **Navigator Fine Tuning** Intro Notebook!\n",
"\n",
"This notebook demonstrates how to use Gretel's SDK to train [**Navigator Fine Tuning**](https://docs.gretel.ai/create-synthetic-data/models/synthetics/gretel-navigator-fine-tuning) to generate high-quality synthetic data. To keep things simple, this tutorial covers only basic usage of the model: generating tabular data with _independent_ records.\n",
"\n",
"<br>\n",
"\n",
"## ✅ Set up your Gretel account\n",
"\n",
"To get started, you will need a [free Gretel account](https://console.gretel.ai/).\n",
"\n",
"If this is your first time using the Gretel SDK, we recommend starting with our [Gretel SDK Blueprints](https://docs.gretel.ai/gretel-basics/getting-started/blueprints).\n",
"\n",
"\n",
"<br>\n",
"\n",
"#### Ready? Let's go 🚀"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RDiYHg6WQ_kt"
},
"source": [
"## 💾 Install `gretel-client` and its dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Nz1NRizhPOKX"
},
"outputs": [],
"source": [
"%%capture\n",
"!pip install gretel-client"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tVHl9ax1RAfO"
},
"source": [
"## 🛜 Configure your Gretel session\n",
"\n",
"- [The Gretel object](https://docs.gretel.ai/create-synthetic-data/gretel-sdk/the-gretel-object) provides a high-level interface for streamlining interactions with Gretel's APIs.\n",
"\n",
"- Retrieve your Gretel API key [here](https://console.gretel.ai/users/me/key)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SQtmS1WWaE1M"
},
"outputs": [],
"source": [
"from gretel_client import Gretel\n",
"\n",
"gretel = Gretel(\n",
" project_name=\"tab-ft-intro\",\n",
" api_key=\"prompt\",\n",
" endpoint=\"https://api.gretel.cloud\",\n",
" validate=True,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qQNQXmptH_VS"
},
"source": [
"## 📊 Tabular Data\n",
"\n",
"Generating tabular data is the most straightforward application of Navigator Fine Tuning. In this case, the model's [default configuration](https://github.com/gretelai/gretel-blueprints/tree/main/config_templates/gretel/synthetics/navigator-ft.yml) parameters are an excellent place to start."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xvbvZCsdaaEm"
},
"outputs": [],
"source": [
"# @title Pick a tabular dataset 👇 { run: \"auto\" }\n",
"dataset_path_dict = {\n",
 \"adult income in the USA (14000 records, 15 fields)\": \"https://raw.githubusercontent.com/gretelai/gretel-blueprints/main/sample_data/us-adult-income.csv\",\n",
 \"hospital length of stay (9999 records, 18 fields)\": \"https://raw.githubusercontent.com/gretelai/gretel-blueprints/main/sample_data/sample-synthetic-healthcare.csv\",\n",
 \"customer churn (7032 records, 21 fields)\": \"https://raw.githubusercontent.com/gretelai/gretel-blueprints/main/sample_data/monthly-customer-payments.csv\"\n",
"}\n",
"\n",
"data_source = \"hospital length of stay (9999 records, 18 fields)\" # @param [\"adult income in the USA (14000 records, 15 fields)\", \"hospital length of stay (9999 records, 18 fields)\", \"customer churn (7032 records, 21 fields)\"]\n",
"data_source = dataset_path_dict[data_source]\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "X55i5QUbiHam"
},
"source": [
"## 🏋️‍♂️ Train a generative model\n",
"\n",
"- The `navigator-ft` base config tells Gretel we want to train with **Navigator Fine Tuning** using its default parameters.\n",
"\n",
"- **Navigator Fine Tuning** is an LLM under the hood. Before training begins, information about how the input data was tokenized and assembled into examples will be logged in the cell output (as well as in Gretel's Console).\n",
"\n",
"- Generation of a dataset for evaluation will begin immediately after the model completes training. The rate at which the model produces valid records will be logged to help assess how well the model is performing."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tBf8FoaJGurT"
},
"outputs": [],
"source": [
"trained = gretel.submit_train(\"navigator-ft\", data_source=data_source)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3-trgRWqz9Gi"
},
"outputs": [],
"source": [
"# view the quality scores\n",
"trained.report"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Wj-XGm_8ky91"
},
"outputs": [],
"source": [
"# display the full report within this notebook\n",
"trained.report.display_in_notebook()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sULbdBh4k71u"
},
"outputs": [],
"source": [
"# inspect the synthetic data used to create the report\n",
"df_synth_report = trained.fetch_report_synthetic_data()\n",
"df_synth_report.head()"
]
},
{
"cell_type": "markdown",
"source": [
"## 🪣 [Optional] Upload to S3 Bucket"
],
"metadata": {
"id": "K56sOsUy9vqH"
}
},
{
"cell_type": "code",
"source": [
"%%capture\n",
"!pip install boto3"
],
"metadata": {
"id": "G8cidhuw9vRp"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import boto3\n",
"import uuid\n",
"from datetime import datetime\n",
"from getpass import getpass\n",
"\n",
"# Get AWS credentials (getpass keeps the secret key out of the cell output)\n",
"aws_access_key = input(\"Enter your AWS access key: \")\n",
"aws_secret_key = getpass(\"Enter your AWS secret key: \")\n",
"bucket_name = input(\"Enter your S3 bucket name: \")\n",
"\n",
"# Generate a unique filename from a timestamp and a short random suffix\n",
"timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
"short_hash = str(uuid.uuid4())[:8]\n",
"filename = f\"navigator_ft_data_{timestamp}_{short_hash}.csv\"\n",
"\n",
"# Initialize S3 client\n",
"s3_client = boto3.client(\n",
" 's3',\n",
" aws_access_key_id=aws_access_key,\n",
" aws_secret_access_key=aws_secret_key\n",
")\n",
"\n",
"# Save DataFrame to CSV in memory and upload\n",
"try:\n",
" csv_buffer = df_synth_report.to_csv(index=False).encode()\n",
" s3_client.put_object(Bucket=bucket_name, Key=filename, Body=csv_buffer)\n",
" print(f\"Successfully uploaded data to s3://{bucket_name}/{filename}\")\n",
"except Exception as e:\n",
" print(f\"Error uploading to S3: {str(e)}\")"
],
"metadata": {
"id": "O5BRgDnj9qUv"
},
"execution_count": null,
"outputs": []
}
],
"metadata": {
"colab": {
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
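
## 🧪 [Optional] Checking the S3 object key and CSV payload offline

The upload cell in the notebook builds a timestamped object key and serializes the report data to CSV bytes before calling `put_object`. The sketch below isolates those two steps so they can be checked without AWS credentials. It is a minimal illustration, not part of the original notebook: the helper names `build_object_key` and `rows_to_csv_bytes` are hypothetical, the column names are made up, and the standard-library `csv` module stands in for `DataFrame.to_csv`.

```python
import csv
import io
import uuid
from datetime import datetime


def build_object_key(prefix="navigator_ft_data", now=None):
    """Hypothetical helper: '<prefix>_<YYYYmmdd_HHMMSS>_<8 hex chars>.csv'."""
    now = now or datetime.now()
    timestamp = now.strftime("%Y%m%d_%H%M%S")
    # The first 8 characters of a UUID4 string are hex digits (no dashes yet).
    short_hash = str(uuid.uuid4())[:8]
    return f"{prefix}_{timestamp}_{short_hash}.csv"


def rows_to_csv_bytes(header, rows):
    """Hypothetical helper: serialize a header and rows to UTF-8 CSV bytes in memory."""
    buffer = io.StringIO()
    writer = csv.writer(buffer, lineterminator="\n")
    writer.writerow(header)
    writer.writerows(rows)
    return buffer.getvalue().encode("utf-8")


# Illustrative values only; the real notebook uses the fetched report DataFrame.
key = build_object_key(now=datetime(2025, 1, 10, 22, 54, 0))
body = rows_to_csv_bytes(["age", "stay_days"], [[42, 3], [67, 9]])

print(key)
print(body.decode("utf-8"))
```

Passing `now` explicitly makes the timestamp part of the key reproducible, while the random suffix still keeps two uploads made in the same second from overwriting each other.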