Skip to content

Instantly share code, notes, and snippets.

@wongcyrus
Created July 3, 2025 02:29
Show Gist options
  • Save wongcyrus/3fb6059609283b5e4149992d6ce60d84 to your computer and use it in GitHub Desktop.
Save wongcyrus/3fb6059609283b5e4149992d6ce60d84 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Virtual Try-On\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup and Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import base64\n",
"\n",
"\n",
"def load_image_as_base64(image_path):\n",
"    \"\"\"Read the file at *image_path* and return its bytes Base64-encoded.\n",
"\n",
"    Nova Canvas request payloads carry images as Base64 text, so the file is\n",
"    read in binary mode and the Base64 bytes are decoded to a UTF-8 ``str``\n",
"    that can be embedded directly in a ``json.dumps`` payload.\n",
"    \"\"\"\n",
"    with open(image_path, \"rb\") as image_file:\n",
"        return base64.b64encode(image_file.read()).decode(\"utf-8\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n",
"# SPDX-License-Identifier: Apache-2.0\n",
"\n",
"import base64\n",
"import io\n",
"import json\n",
"import logging\n",
"import os\n",
"\n",
"import boto3\n",
"from PIL import Image\n",
"from botocore.config import Config\n",
"from botocore.exceptions import ClientError\n",
"import glob\n",
"\n",
"\n",
"class ImageError(Exception):\n",
"    \"\"\"Custom exception for errors returned by Amazon Nova Canvas.\n",
"\n",
"    The message is stored on ``self.message`` (used when logging) and is also\n",
"    passed to ``Exception.__init__`` so that ``str(err)`` and ``err.args``\n",
"    carry it when the exception is printed or re-raised.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, message):\n",
"        # Without this call str(err) is an empty string.\n",
"        super().__init__(message)\n",
"        self.message = message\n",
"\n",
"\n",
"logger = logging.getLogger(__name__)\n",
"logging.basicConfig(level=logging.INFO)\n",
"\n",
"# Create the Bedrock Runtime client with an extended read timeout (300 s),\n",
"# presumably because image-generation calls can be slow. The region is pinned\n",
"# to us-east-1, the same region the virtual try-on cell uses, so both stages\n",
"# call the same Bedrock endpoint regardless of the local default region.\n",
"bedrock = boto3.client(\n",
"    service_name='bedrock-runtime',\n",
"    region_name='us-east-1',\n",
"    config=Config(read_timeout=300),\n",
")\n",
"\n",
"\n",
"def load_image_as_base64(image_path):\n",
"    \"\"\"Read the file at *image_path* and return its bytes Base64-encoded.\n",
"\n",
"    The result is a UTF-8 ``str`` so it can be placed directly into the JSON\n",
"    request body sent to Nova Canvas.\n",
"    \"\"\"\n",
"    with open(image_path, \"rb\") as image_file:\n",
"        return base64.b64encode(image_file.read()).decode(\"utf-8\")\n",
"\n",
"\n",
"def _prepare_for_format(image, file_extension):\n",
"    \"\"\"Return *image* in a mode suitable for saving with *file_extension*.\n",
"\n",
"    JPEG has no alpha channel, so transparent pixels are flattened onto a\n",
"    white background. For PNG, modes other than RGBA/RGB/P are converted to\n",
"    RGBA so transparency survives. Other extensions are returned unchanged.\n",
"    \"\"\"\n",
"    if file_extension in ('.jpg', '.jpeg'):\n",
"        # Palette images can carry transparency as well; normalise them to\n",
"        # RGBA so they go through the same alpha-flattening path below.\n",
"        if image.mode == 'P' and 'transparency' in image.info:\n",
"            image = image.convert('RGBA')\n",
"        if image.mode in ('RGBA', 'LA'):\n",
"            background = Image.new('RGB', image.size, (255, 255, 255))\n",
"            # Use the alpha band as the paste mask so transparent areas stay white.\n",
"            background.paste(image, mask=image.getchannel('A'))\n",
"            return background\n",
"        if image.mode != 'RGB':\n",
"            return image.convert('RGB')\n",
"        return image\n",
"    if file_extension == '.png' and image.mode not in ('RGBA', 'RGB', 'P'):\n",
"        return image.convert('RGBA')\n",
"    return image\n",
"\n",
"\n",
"def remove_background_with_nova_canvas(image_path, output_image_path):\n",
"    \"\"\"\n",
"    Uses Amazon Bedrock Nova Canvas to remove the background from an image.\n",
"\n",
"    Args:\n",
"        image_path (str): Path to the input image\n",
"        output_image_path (str): Path where the output image will be saved.\n",
"            Its extension selects the handling: .jpg/.jpeg output is\n",
"            flattened onto white, .png output keeps transparency.\n",
"\n",
"    Returns:\n",
"        str: Path to the saved output image\n",
"\n",
"    Raises:\n",
"        ImageError: If Nova Canvas reports an error or returns no image\n",
"        ClientError: If there's a client error with AWS Bedrock\n",
"    \"\"\"\n",
"    logger.info(\"Removing background from image: %s\", image_path)\n",
"\n",
"    try:\n",
"        # Ensure output directory exists (exist_ok makes a pre-check redundant)\n",
"        output_dir = os.path.dirname(output_image_path)\n",
"        if output_dir:\n",
"            os.makedirs(output_dir, exist_ok=True)\n",
"\n",
"        # Prepare the inference parameters for background removal\n",
"        inference_params = {\n",
"            \"taskType\": \"BACKGROUND_REMOVAL\",\n",
"            \"backgroundRemovalParams\": {\n",
"                \"image\": load_image_as_base64(image_path)\n",
"            }\n",
"        }\n",
"        body_json = json.dumps(inference_params, indent=2)\n",
"\n",
"        # Invoke Nova Canvas\n",
"        response = bedrock.invoke_model(\n",
"            body=body_json,\n",
"            modelId=\"amazon.nova-canvas-v1:0\",\n",
"            accept=\"application/json\",\n",
"            contentType=\"application/json\"\n",
"        )\n",
"\n",
"        # Extract the response\n",
"        response_body = json.loads(response.get(\"body\").read())\n",
"\n",
"        # Check for errors in the response\n",
"        error = response_body.get(\"error\")\n",
"        if error is not None:\n",
"            raise ImageError(f\"Background removal error: {error}\")\n",
"\n",
"        # Extract the images from the response\n",
"        images = response_body.get(\"images\", [])\n",
"        if not images:\n",
"            raise ImageError(\"No image returned from Nova Canvas.\")\n",
"\n",
"        # Decode the first Base64 image and save it to the requested path,\n",
"        # converting its mode to suit the output file format.\n",
"        image_bytes = base64.b64decode(images[0])\n",
"        file_extension = os.path.splitext(output_image_path)[1].lower()\n",
"        with Image.open(io.BytesIO(image_bytes)) as source_image:\n",
"            prepared_image = _prepare_for_format(source_image, file_extension)\n",
"            prepared_image.save(output_image_path)\n",
"\n",
"        logger.info(\"Successfully removed background and saved to: %s\", output_image_path)\n",
"        return output_image_path\n",
"\n",
"    except ClientError as err:\n",
"        # Fall back to the full error text if the usual fields are missing.\n",
"        message = err.response.get(\"Error\", {}).get(\"Message\", str(err))\n",
"        logger.error(\"A client error occurred: %s\", message)\n",
"        raise\n",
"    except ImageError as err:\n",
"        logger.error(err.message)\n",
"        raise\n",
"    except Exception as err:\n",
"        # logger.exception also records the traceback of unexpected failures.\n",
"        logger.exception(\"An unexpected error occurred: %s\", err)\n",
"        raise\n",
"\n",
"\n",
"# Batch-process every JPEG in the input folder. Outputs are written as PNG\n",
"# so that transparency from background removal is preserved.\n",
"input_folder = \"cloths\"\n",
"output_folder = \"remove_background\"\n",
"input_extensions = (\".jpg\", \".jpeg\")\n",
"\n",
"# Ensure output folder exists\n",
"os.makedirs(output_folder, exist_ok=True)\n",
"\n",
"# Match extensions case-insensitively (e.g. IMG_001.JPG) and process the\n",
"# files in a stable, sorted order.\n",
"image_paths = sorted(\n",
"    path for path in glob.glob(os.path.join(input_folder, \"*\"))\n",
"    if os.path.isfile(path) and path.lower().endswith(input_extensions)\n",
")\n",
"\n",
"for image_path in image_paths:\n",
"    filename = os.path.basename(image_path)\n",
"    output_image_path = os.path.join(output_folder, os.path.splitext(filename)[0] + \".png\")\n",
"    try:\n",
"        result = remove_background_with_nova_canvas(image_path, output_image_path)\n",
"        print(f\"Background removal completed. Output saved to: {result}\")\n",
"    except Exception as e:\n",
"        print(f\"Error processing {image_path}: {e}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from PIL import Image\n",
"import os\n",
"import glob\n",
"\n",
"# Rotate every background-removed image 90 degrees clockwise, in place.\n",
"# The background-removal step saves its results as .png files, so rotate\n",
"# those files themselves rather than a sibling .jpg path it never creates.\n",
"# NOTE: not idempotent - re-running this cell rotates the images again.\n",
"png_files = sorted(glob.glob(os.path.join(output_folder, \"*.png\")))\n",
"\n",
"for png_file in png_files:\n",
"    with Image.open(png_file) as img:\n",
"        # rotate() returns a new image; expand=True grows the canvas to fit.\n",
"        rotated_img = img.rotate(-90, expand=True)\n",
"    # Save after the source file handle is closed, since we overwrite it.\n",
"    rotated_img.save(png_file)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import base64\n",
"import io\n",
"import json\n",
"import os\n",
"\n",
"import boto3\n",
"from botocore.config import Config\n",
"from botocore.exceptions import ClientError\n",
"from PIL import Image\n",
"\n",
"# Create the Bedrock Runtime client with the same extended read timeout as\n",
"# the background-removal client, since each call generates an image.\n",
"bedrock = boto3.client(\n",
"    service_name=\"bedrock-runtime\",\n",
"    region_name=\"us-east-1\",\n",
"    config=Config(read_timeout=300),\n",
")\n",
"\n",
"cloths_dir = \"remove_background/\"\n",
"# The background-removal step writes .png files; .jpg is still accepted so\n",
"# older outputs keep working. Sorted for a deterministic processing order.\n",
"cloth_extensions = (\".png\", \".jpg\")\n",
"cloths_files = sorted(\n",
"    os.path.join(cloths_dir, fname)\n",
"    for fname in os.listdir(cloths_dir)\n",
"    if fname.lower().endswith(cloth_extensions)\n",
"    and os.path.isfile(os.path.join(cloths_dir, fname))\n",
")\n",
"print(cloths_files)\n",
"\n",
"# The person photo is the same for every garment, so encode it only once.\n",
"person_image_base64 = load_image_as_base64(\"person.jpg\")\n",
"\n",
"for cloth in cloths_files:\n",
"    cloth_basename = os.path.splitext(os.path.basename(cloth))[0]\n",
"\n",
"    # Prepare the inference parameters.\n",
"    inference_params = {\n",
"        \"taskType\": \"VIRTUAL_TRY_ON\",\n",
"        \"virtualTryOnParams\": {\n",
"            \"sourceImage\": person_image_base64,\n",
"            \"referenceImage\": load_image_as_base64(cloth),\n",
"            \"maskType\": \"GARMENT\",\n",
"            \"garmentBasedMask\": {\"garmentClass\": \"UPPER_BODY\"}\n",
"        }\n",
"    }\n",
"    # Prepare the invocation payload.\n",
"    body_json = json.dumps(inference_params, indent=2)\n",
"\n",
"    # Invoke Nova Canvas; report a failure for this garment and continue\n",
"    # with the next one instead of aborting the whole batch.\n",
"    try:\n",
"        response = bedrock.invoke_model(\n",
"            body=body_json,\n",
"            modelId=\"amazon.nova-canvas-v1:0\",\n",
"            accept=\"application/json\",\n",
"            contentType=\"application/json\"\n",
"        )\n",
"    except ClientError as err:\n",
"        print(f\"Error processing {cloth}: {err}\")\n",
"        continue\n",
"\n",
"    response_body_json = json.loads(response.get(\"body\").read())\n",
"\n",
"    # Check for errors.\n",
"    if response_body_json.get(\"error\"):\n",
"        print(response_body_json.get(\"error\"))\n",
"\n",
"    # Decode each image from Base64 and save as a PNG file.\n",
"    images = response_body_json.get(\"images\", [])\n",
"    for index, image_base64 in enumerate(images):\n",
"        with Image.open(io.BytesIO(base64.b64decode(image_base64))) as image:\n",
"            image.save(f\"image_{cloth_basename}_{index}.png\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "virtual-try-on-project",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment