Skip to content

Instantly share code, notes, and snippets.

@sssemil
Created November 15, 2023 11:26
Show Gist options
  • Save sssemil/74a466d622097e800d7b3d0a860d533b to your computer and use it in GitHub Desktop.
Save sssemil/74a466d622097e800d7b3d0a860d533b to your computer and use it in GitHub Desktop.
Simple monocular-camera depth estimation demo
import numpy as np
import torch
import cv2
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
from transformers import DPTForDepthEstimation, DPTImageProcessor
def _model_device(model):
    """Return the device *model* lives on, falling back to the CPU."""
    device = getattr(model, "device", None)
    if device is not None:
        return device
    parameters = getattr(model, "parameters", None)
    if callable(parameters):
        first = next(iter(parameters()), None)
        if first is not None:
            return first.device
    return torch.device("cpu")


def depth_estimation(model, feature_extractor, image):
    """Estimate depth for one image and render it as an 8-bit BGR picture.

    Args:
        model: Depth model, called as ``model(**inputs)``; its result must
            expose ``predicted_depth`` (assumed to be ``(1, H, W)`` for a
            single input image -- TODO confirm for other model families).
        feature_extractor: Called as
            ``feature_extractor(images=image, return_tensors="pt")``; the
            result must support ``.to(device)`` and ``**`` unpacking.
        image: Object exposing ``size`` as ``(width, height)`` (a PIL image
            in the surrounding script).

    Returns:
        A ``uint8`` array of shape ``(height, width, 3)``. The three channels
        are identical; brightness is the predicted value scaled so that the
        largest prediction maps to 255.
    """
    # Follow the model's device instead of assuming CUDA, so inputs and
    # weights always end up on the same device.
    inputs = feature_extractor(images=image, return_tensors="pt").to(_model_device(model))

    with torch.no_grad():
        predicted_depth = model(**inputs).predicted_depth
        # image.size is (width, height) while interpolate expects
        # (height, width), hence the reversal.
        prediction = torch.nn.functional.interpolate(
            predicted_depth.unsqueeze(1),
            size=image.size[::-1],
            mode="bicubic",
            align_corners=False,
        ).float()

    # Index the single batch/channel entry explicitly; squeeze() would also
    # drop a genuine size-1 height or width.
    depth = prediction[0, 0].cpu().numpy()

    peak = float(depth.max()) if depth.size else 0.0
    if not np.isfinite(peak) or peak <= 0.0:
        # Nothing to scale by (all-zero, negative or non-finite map): return
        # a black frame instead of dividing by zero.
        gray = np.zeros(depth.shape, dtype=np.uint8)
    else:
        # Bicubic resampling can undershoot below zero; clip first so the
        # uint8 cast cannot wrap negative values around to bright pixels.
        gray = (np.clip(depth, 0.0, peak) * 255 / peak).astype(np.uint8)

    # Same result as a GRAY2BGR conversion: replicate the channel three times.
    return np.repeat(gray[:, :, np.newaxis], 3, axis=2)
# Device the models are loaded onto.
DEVICE = "cuda"
WINDOW_TITLE = 'Camera Input, GLPN Depth, DPT Depth'


def main():
    """Show the camera feed next to GLPN and DPT depth maps until 'q' is pressed.

    Loads both pretrained models, opens camera 0 and, for every captured
    frame, displays the raw frame, the inverted GLPN map and the DPT map side
    by side in one window.

    Raises:
        SystemExit: With status 1 when the camera cannot be opened.
    """
    # Load models and feature extractors
    glpn_model = AutoModelForDepthEstimation.from_pretrained("vinvino02/glpn-nyu").to(DEVICE)
    glpn_feature_extractor = AutoImageProcessor.from_pretrained("vinvino02/glpn-nyu")
    dpt_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas", low_cpu_mem_usage=True).to(DEVICE)
    dpt_feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")

    # Start video capture
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("Error: Camera device is not accessible.")
        # exit() is only injected by the site module and reported success;
        # SystemExit(1) is always available and signals the failure.
        raise SystemExit(1)

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("Error: Unable to capture video.")
                break

            # OpenCV delivers BGR frames; the image passed on is RGB.
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

            # Depth estimation for both models
            glpn_depth_map = depth_estimation(glpn_model, glpn_feature_extractor, image)
            dpt_depth_map = depth_estimation(dpt_model, dpt_feature_extractor, image)

            # Invert colors of the GLPN depth map (presumably so both panels
            # share the same near/far brightness convention -- TODO confirm).
            glpn_depth_map = 255 - glpn_depth_map

            combined = np.hstack((frame, glpn_depth_map, dpt_depth_map))
            cv2.imshow(WINDOW_TITLE, combined)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the camera and close the window even when inference or
        # display raises, so the device is not left locked.
        cap.release()
        cv2.destroyAllWindows()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment