Created
June 8, 2025 06:30
-
-
Save ShichengChen/3d0c817404d0d063cd246020f3822538 to your computer and use it in GitHub Desktop.
Windows ML (WinML) sample
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import numpy as np | |
import urllib.request | |
import os | |
from PIL import Image | |
import winrt.windows.ai.machinelearning as winml | |
import winrt.windows.storage as storage | |
import winrt.windows.media as media | |
import winrt.windows.graphics.imaging as imaging | |
def _download_model(url, path):
    """Download *url* to *path* so that *path* only ever holds a complete file.

    The payload is first written to ``path + ".part"`` and renamed onto *path*
    with ``os.replace`` only after ``urlretrieve`` returns. An interrupted or
    failed download therefore cannot leave a truncated model at *path*, which
    the ``os.path.exists`` check in ``main`` would otherwise accept forever.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises (e.g. ``URLError``);
        the partial ``.part`` file is removed before the exception propagates.
    """
    tmp_path = path + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # After a successful os.replace the temp file no longer exists, so this
        # only cleans up after a failure.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def _top_prediction(scores):
    """Return ``(index, score, probability)`` for the highest entry in *scores*.

    *scores* is treated as a vector of unnormalized logits. It is turned into
    probabilities with a numerically stable softmax: the maximum is subtracted
    before ``exp`` so large logits do not overflow.

    NOTE(review): this assumes the model emits raw logits, as the ONNX Model
    Zoo MobileNetV2 does. Confirm this before reusing it with other models.
    """
    logits = np.asarray(scores, dtype=np.float64)
    shifted = np.exp(logits - logits.max())
    probs = shifted / shifted.sum()
    best = int(np.argmax(logits))
    return best, float(logits[best]), float(probs[best])


async def main():
    """Download MobileNetV2, load it with Windows ML and classify a random image.

    Flow:
      1. Fetch the ONNX model into the working directory (only if it is missing).
      2. Load it via ``StorageFile`` / ``LearningModel`` and print its metadata.
      3. Create a session on the default Windows ML device.
      4. Bind a synthetic 224x224 RGB image to the model's first input.
      5. Evaluate and print the first output's shape, leading values and top class.
    """
    # MobileNetV2 (opset 7) from the ONNX Model Zoo. The zoo moved its models
    # under `validated/`; the old `main/vision/...` path no longer resolves.
    model_url = (
        "https://github.com/onnx/models/raw/main/validated/vision/"
        "classification/mobilenet/model/mobilenetv2-7.onnx"
    )
    model_path = "mobilenetv2.onnx"
    if not os.path.exists(model_path):
        print("Downloading model...")
        _download_model(model_url, model_path)
        print("Model downloaded!")

    # Load the ONNX model. get_file_from_path_async is given an absolute path.
    print("Loading model...")
    model_file = await storage.StorageFile.get_file_from_path_async(os.path.abspath(model_path))
    model = await winml.LearningModel.load_from_storage_file_async(model_file)
    print(f"Model loaded: {model.name}")
    print(f"Model description: {model.description}")
    print(f"Input features: {[f.name for f in model.input_features]}")
    print(f"Output features: {[f.name for f in model.output_features]}")

    # DEFAULT lets Windows ML pick the device to run on.
    device = winml.LearningModelDevice(winml.LearningModelDeviceKind.DEFAULT)
    session = winml.LearningModelSession(model, device)

    # Synthetic 224x224 RGB input. randint's upper bound is exclusive, so 256
    # is needed for the full 0..255 byte range.
    print("Creating sample input...")
    sample_image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
    pil_image = Image.fromarray(sample_image)
    # Note: this is simplified preprocessing; real use needs the model's own
    # resize/normalization pipeline.
    input_tensor = create_tensor_from_image(pil_image)

    # Bind to the first declared input (the single image input for this model).
    binding = winml.LearningModelBinding(session)
    input_feature = model.input_features[0]
    binding.bind(input_feature.name, input_tensor)

    print("Running inference...")
    results = await session.evaluate_async(binding, "inference_id")

    output_feature = model.output_features[0]
    output_tensor = results.outputs[output_feature.name]
    print("Inference completed!")
    # `shape` is a WinRT vector; materialize it so the values are printed
    # instead of the object's repr.
    print(f"Output shape: {list(output_tensor.shape)}")

    if hasattr(output_tensor, "get_as_vector_view"):
        output_values = list(output_tensor.get_as_vector_view())
        print(f"First 10 output values: {output_values[:10]}")
        # Guard against an empty output vector, on which argmax would raise.
        if output_values:
            max_index, score, prob = _top_prediction(output_values)
            print(
                f"Top prediction index: {max_index} with score: {score:.4f} "
                f"(softmax probability: {prob:.4f})"
            )
def create_tensor_from_image(pil_image, *, mean=None, std=None):
    """Convert a PIL image into a Windows ML ``TensorFloat`` in NCHW layout.

    Steps:
      1. Convert the image to 3-channel RGB.
      2. Scale pixel values from [0, 255] to [0, 1].
      3. Optionally subtract a per-channel mean and divide by a per-channel std.
      4. Reorder HWC to CHW and prepend a batch dimension of 1.

    Args:
        pil_image: A PIL image of any mode. Grayscale, palette and RGBA images
            are converted to RGB so the tensor always has 3 channels.
        mean: Optional per-channel mean (length 3), subtracted after scaling to
            [0, 1]. ``None`` (default) skips this step.
        std: Optional per-channel standard deviation (length 3), divided out
            after the mean step. ``None`` (default) skips this step.

    Returns:
        A ``winml.TensorFloat`` of shape ``[1, 3, H, W]``, where H and W are
        the image's height and width.
    """
    # np.array() of an "L" image is 2-D and of an "RGBA" image has 4 channels.
    # Either would break the (2, 0, 1) transpose or the declared channel count.
    rgb_image = pil_image.convert("RGB")
    img_array = np.asarray(rgb_image, dtype=np.float32) / 255.0

    if mean is not None:
        img_array = img_array - np.asarray(mean, dtype=np.float32)
    if std is not None:
        img_array = img_array / np.asarray(std, dtype=np.float32)

    # HWC -> CHW, then add the batch axis: (1, 3, H, W).
    nchw = np.expand_dims(np.transpose(img_array, (2, 0, 1)), axis=0)

    # Derive the shape from the data instead of hard-coding 224x224, so the
    # declared shape always matches the number of values supplied.
    return winml.TensorFloat.create_from_array(list(nchw.shape), nchw.flatten().tolist())
if __name__ == "__main__":
    import sys
    import traceback

    # NOTE: the module-level `winrt` imports at the top of the file run before
    # this check. winrt is a Windows-only package, so on other platforms an
    # ImportError is likely raised before this message can be printed.
    if sys.platform != "win32":
        print("This example requires Windows 10/11 with Windows ML support")
        sys.exit(1)
    try:
        asyncio.run(main())
    except Exception as e:
        # Top-level boundary: keep the traceback for diagnosis and report
        # failure through the exit status instead of exiting 0.
        traceback.print_exc()
        print(f"Error: {e}", file=sys.stderr)
        print(
            "Make sure you have Windows ML properly installed and you're running on Windows 10/11",
            file=sys.stderr,
        )
        sys.exit(1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.