Skip to content

Instantly share code, notes, and snippets.

@bosko
Last active August 16, 2024 12:08
Show Gist options
Save bosko/dd9232c561b64557bbe90becb720884a to your computer and use it in GitHub Desktop.
Swin example

Swin experiments

# Notebook setup: install the dependencies, then make EXLA the default Nx backend.
#
# Bumblebee is taken from a local checkout rather than Hex (presumably a
# development branch with Swin support — TODO confirm). Set the BUMBLEBEE_PATH
# environment variable to point at your own checkout; when it is unset, the
# original author's path is used, so existing behavior is unchanged.
Mix.install([
  {:bumblebee, path: System.get_env("BUMBLEBEE_PATH", "/Users/bosko/Code/elixir/bumblebee")},
  {:nx, "~> 0.7.2"},
  {:exla, "~> 0.7.2"},
  {:kino, "~> 0.12.3"},
  {:stb_image, "~> 0.6.8"},
  {:req, "~> 0.5.0"}
])

# Route Nx tensor operations through EXLA by default.
Nx.global_default_backend(EXLA.Backend)

Section

# Load the Swin checkpoint and its matching featurizer (preprocessor) from the
# Hugging Face Hub. Both come from the same repo, so it is bound once.
repo = {:hf, "microsoft/swin-base-patch4-window12-384"}
{:ok, swin} = Bumblebee.load_model(repo)
{:ok, swin_featurizer} = Bumblebee.load_featurizer(repo)

# Kino input widget where the user pastes the URL of the image to classify.
url_input = Kino.Input.url("Image url")

url = Kino.Input.read(url_input)

# Download the image. Matching on status 200 fails loudly on HTTP errors
# instead of passing an error page on to the image decoder.
%Req.Response{status: 200, body: image_binary} = Req.get!(url)

# Decode with `channels: 3` so RGBA or grayscale inputs are converted to RGB.
# Without it, such images would crash the 3-channel reshape below.
raw_image =
  image_binary
  |> StbImage.read_binary!(channels: 3)
  |> StbImage.resize(384, 384)

# The tensor is already {384, 384, 3} here; the reshape acts as a shape check.
image =
  raw_image
  |> StbImage.to_nx()
  |> Nx.reshape({384, 384, 3})

serving = Bumblebee.Vision.image_classification(swin, swin_featurizer)
Nx.Serving.run(serving, image)
# Reference run of the same Swin checkpoint with Hugging Face transformers
# (PyTorch), useful for comparing against the Bumblebee output.
import io

import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, SwinForImageClassification

MODEL_ID = "microsoft/swin-base-patch4-window12-384"
url = "http://images.cocodataset.org/val2017/000000039769.jpg"

# Download the whole body with a timeout, and raise on HTTP errors instead of
# handing an error page to PIL. Decoding from `content` rather than the raw
# socket stream lets requests undo any Content-Encoding first. Converting to
# RGB turns grayscale/RGBA/palette images into 3-channel input.
response = requests.get(url, timeout=30)
response.raise_for_status()
image = Image.open(io.BytesIO(response.content)).convert("RGB")

# AutoImageProcessor supersedes the deprecated AutoFeatureExtractor for
# vision models.
image_processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = SwinForImageClassification.from_pretrained(MODEL_ID)

inputs = image_processor(images=image, return_tensors="pt")

# Inference only: disable autograd so no gradient graph is built.
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits

# Take the highest-scoring class; id2label maps its index to a readable label.
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment