Created
March 16, 2025 02:57
-
-
Save yoi-hibino/77fb8a65e375d4aa06522cf6cea8ff25 to your computer and use it in GitHub Desktop.
Real-time LiDAR Localization in CAD Environment
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import open3d as o3d | |
import transforms3d as t3d | |
import trimesh | |
import scipy.spatial | |
from sklearn.neighbors import KDTree | |
import time | |
import threading | |
import queue | |
import concurrent.futures | |
from collections import deque | |
import matplotlib.pyplot as plt | |
class RealtimeSensorLocalizer:
    """Estimate the pose of a planar LiDAR sensor inside a 3D CAD model.

    Cross-sections ("slices") of the CAD model are computed for a set of
    roll/pitch tilts. Incoming LiDAR scans are rotated with the IMU
    orientation and registered against the closest slice with a short ICP
    run on a background thread. Poses are plain lists
    ``[x, y, z, roll, pitch, yaw]`` with angles in radians.
    """

    # Pose used when neither an estimate nor an initial guess is available.
    DEFAULT_POSE = (0.0, 0.0, 1.0, 0.0, 0.0, 0.0)
    # Largest roll/pitch tilt (radians) for which slices are generated (±30°).
    MAX_TILT = np.pi / 6
    # Upper bounds on point counts fed to the matcher.
    MAX_SLICE_POINTS = 1000
    MAX_SCAN_POINTS = 500
    # Scans with fewer points than this are not matched at all.
    MIN_SCAN_POINTS = 10
    # Exponential smoothing factor (0.0-1.0, lower = more smoothing).
    SMOOTHING_ALPHA = 0.3
    # Number of processing-time samples kept for the statistics.
    MAX_TIMING_SAMPLES = 100

    def __init__(self, cad_model_path, scale_factor=1.0, precompute_slices=True, num_slices=12):
        """
        Initialize the localizer with a 3D CAD model

        Args:
            cad_model_path: Path to the 3D CAD model file (OBJ, STL, etc.)
            scale_factor: Scale factor to convert CAD model units to real-world meters
            precompute_slices: Whether to precompute slices or generate them on demand
            num_slices: Approximate number of orientations to precompute; the
                tilt grid has int(sqrt(num_slices)) values per axis
        """
        # force='mesh' collapses multi-object files (which would otherwise
        # load as a Scene) into a single mesh that supports `section`.
        self.cad_model = trimesh.load(cad_model_path, force='mesh')
        self.scale_factor = scale_factor

        # Apply scale factor to the model if needed
        if scale_factor != 1.0:
            self.cad_model.apply_scale(scale_factor)

        # Parameters for real-time operation
        self.pose_history = deque(maxlen=5)  # Keep last 5 poses for smoothing
        self.running = False
        self.processing_thread = None
        self.scan_queue = queue.Queue(maxsize=2)  # Limit queue size to prevent lag
        self.latest_pose = None
        # Re-entrant so that a helper taking the lock can never deadlock a
        # caller that already holds it.
        self.pose_lock = threading.RLock()

        # Performance tracking
        self.timing_stats = {
            'slice_generation': [],
            'matching': [],
            'total_processing': []
        }

        # Always defined so that later lookups never hit a missing attribute.
        self.reference_slices = {}
        self.slice_kdtrees = {}

        # Precompute slices if requested
        if precompute_slices:
            print("Precomputing CAD model slices...")
            start_time = time.time()
            self.reference_slices = self._extract_optimized_slices(num_slices)
            end_time = time.time()
            print(f"Precomputed {len(self.reference_slices)} slices in {end_time - start_time:.2f} seconds")

            # Pre-build KD-trees for each slice
            self.slice_kdtrees = {
                orientation: KDTree(points)
                for orientation, points in self.reference_slices.items()
            }

    @staticmethod
    def _wrap_angle(angle):
        """Return *angle* (radians) wrapped into the interval (-pi, pi]."""
        wrapped = angle % (2 * np.pi)
        if wrapped > np.pi:
            wrapped -= 2 * np.pi
        return wrapped

    @staticmethod
    def _downsample(points, max_points):
        """Return every n-th element of *points* so at most *max_points* remain."""
        count = len(points)
        if count <= max_points:
            return points
        step = -(-count // max_points)  # ceiling division guarantees the cap
        return points[::step]

    @staticmethod
    def _candidate_orientations(num_orientations):
        """Return the (roll, pitch) tilts for which slices are generated.

        The horizontal orientation (0, 0) always comes first, followed by a
        square grid of tilts within ±MAX_TILT. Only the grid point that
        duplicates the horizontal orientation is skipped.
        """
        per_axis = int(np.sqrt(num_orientations))
        max_tilt = RealtimeSensorLocalizer.MAX_TILT
        tilt_values = np.linspace(-max_tilt, max_tilt, per_axis)

        orientations = [(0.0, 0.0)]
        for roll in tilt_values:
            for pitch in tilt_values:
                is_horizontal = (np.isclose(roll, 0, atol=0.01)
                                 and np.isclose(pitch, 0, atol=0.01))
                if not is_horizontal:
                    orientations.append((float(roll), float(pitch)))
        return orientations

    def _extract_optimized_slices(self, num_orientations=12):
        """
        Extract slices from the CAD model using an optimized approach

        Args:
            num_orientations: Approximate number of orientations to extract

        Returns:
            Dictionary mapping (roll, pitch, center) keys to (N, 3) point arrays
        """
        reference_slices = {}

        # Center of the model bounding box is the origin of every slicing plane
        center = self.cad_model.bounds.mean(axis=0)
        center_key = tuple(center)

        for roll, pitch in self._candidate_orientations(num_orientations):
            # Normal vector of the slicing plane for this tilt
            R = t3d.euler.euler2mat(roll, pitch, 0, 'sxyz')
            normal = R @ np.array([0, 0, 1])

            try:
                slice_path = self.cad_model.section(
                    plane_origin=center,
                    plane_normal=normal
                )
                if slice_path is None or len(slice_path.entities) == 0:
                    continue

                # Each entity stores indices into the path's vertex array
                indices = np.concatenate([
                    np.asarray(entity.points, dtype=int)
                    for entity in slice_path.entities
                ])
                if len(indices) == 0:
                    continue

                points_array = np.array(slice_path.vertices, dtype=float)[indices]
                reference_slices[(roll, pitch, center_key)] = self._downsample(
                    points_array, self.MAX_SLICE_POINTS
                )
            except Exception as e:
                # Best effort: a failed slice must not abort the other orientations
                print(f"Skipping slice at R={np.degrees(roll):.1f}°, P={np.degrees(pitch):.1f}°: {e}")

        return reference_slices

    def _closest_slice_key(self, roll, pitch):
        """Return the reference-slice key whose tilt is nearest to (roll, pitch).

        Returns None when no slices are available.
        """
        if not self.reference_slices:
            return None
        return min(
            self.reference_slices,
            key=lambda key: np.hypot(key[0] - roll, key[1] - pitch)
        )

    def start_realtime_localization(self, initial_pose=None):
        """
        Start real-time localization thread

        Args:
            initial_pose: Initial [x, y, z, roll, pitch, yaw] guess

        Raises:
            ValueError: If initial_pose does not have exactly six components
        """
        if self.running:
            print("Real-time localization already running")
            return

        # Work on a private copy so later updates never touch the caller's object
        if initial_pose is None:
            pose = list(self.DEFAULT_POSE)
        else:
            pose = [float(value) for value in initial_pose]
            if len(pose) != 6:
                raise ValueError("initial_pose must be [x, y, z, roll, pitch, yaw]")

        self.running = True

        with self.pose_lock:
            self.latest_pose = list(pose)
            self.pose_history.clear()
            self.pose_history.append(pose)

        # Start processing thread
        self.processing_thread = threading.Thread(
            target=self._localization_thread,
            daemon=True
        )
        self.processing_thread.start()
        print("Real-time localization started")

    def stop_realtime_localization(self):
        """Stop the real-time localization thread"""
        self.running = False
        if self.processing_thread:
            self.processing_thread.join(timeout=1.0)
            self.processing_thread = None
        print("Real-time localization stopped")

    def process_scan(self, ranges, angles, imu_roll, imu_pitch, imu_yaw):
        """
        Process a new LiDAR scan and add it to the processing queue

        Args:
            ranges: Array of range measurements
            angles: Array of corresponding angles
            imu_roll, imu_pitch, imu_yaw: Orientation from IMU

        Returns:
            True if scan was queued, False if the queue was full or the
            localizer is not running
        """
        if not self.running:
            print("Warning: Real-time localization not running")
            return False

        # Convert scan to points
        scan_points = self._scan_to_points(ranges, angles, imu_roll, imu_pitch, imu_yaw)

        try:
            # Never block the acquisition loop: drop the scan if the queue is full
            self.scan_queue.put_nowait((scan_points, (imu_roll, imu_pitch, imu_yaw)))
        except queue.Full:
            return False
        return True

    def get_current_pose(self):
        """Get a copy of the latest estimated pose, or None if there is none yet"""
        with self.pose_lock:
            if self.latest_pose is None:
                return None
            return list(self.latest_pose)

    def _scan_to_points(self, ranges, angles, imu_roll, imu_pitch, imu_yaw):
        """Convert LiDAR scan to 3D points using IMU orientation

        Beams with a non-finite range (no return) are discarded.
        """
        ranges = np.asarray(ranges, dtype=float)
        angles = np.asarray(angles, dtype=float)

        valid = np.isfinite(ranges)
        ranges = ranges[valid]
        angles = angles[valid]

        # Polar coordinates to 3D Cartesian in the LiDAR frame (scan plane is z=0)
        points = np.column_stack((
            ranges * np.cos(angles),
            ranges * np.sin(angles),
            np.zeros_like(ranges),
        ))

        # Apply IMU-based rotation
        R = t3d.euler.euler2mat(imu_roll, imu_pitch, imu_yaw, 'sxyz')
        return points @ R.T

    def _localization_thread(self):
        """Background thread for processing scans"""
        print("Localization thread started")

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            while self.running:
                try:
                    # Get next scan from queue with timeout
                    scan_data = self.scan_queue.get(timeout=0.1)
                except queue.Empty:
                    # No new scan, check the running flag again
                    continue

                try:
                    future = executor.submit(self._process_scan_data, *scan_data)
                    # Wait for this scan before fetching the next one; otherwise
                    # the executor's unbounded internal queue would accumulate a
                    # backlog and defeat the bounded scan queue.
                    concurrent.futures.wait([future])
                    self._update_pose(future)
                except Exception as e:
                    print(f"Error in localization thread: {e}")

        print("Localization thread stopped")

    def _process_scan_data(self, scan_points, imu_data):
        """
        Process a single scan (called from the worker thread)

        Args:
            scan_points: 3D points from processed LiDAR
            imu_data: (roll, pitch, yaw) from IMU

        Returns:
            Estimated pose and error; the error is infinite when the scan has
            too few points to be matched
        """
        start_time = time.time()

        # Motion-predicted pose; always a fresh list, so editing it below
        # cannot corrupt the stored pose history.
        initial_guess = self._predict_next_pose()

        # Use IMU data for roll and pitch, keep the estimated yaw
        imu_roll, imu_pitch, _ = imu_data
        initial_guess[3] = imu_roll
        initial_guess[4] = imu_pitch

        # Skip processing if too few points
        if len(scan_points) < self.MIN_SCAN_POINTS:
            return initial_guess, float('inf')

        # Downsample scan for real-time performance
        scan_points = self._downsample(scan_points, self.MAX_SCAN_POINTS)

        # Fast matching using current pose and nearby orientations
        pose, error = self._fast_match(scan_points, initial_guess)

        # Record timing stats, keeping only the most recent samples
        samples = self.timing_stats['total_processing']
        samples.append(time.time() - start_time)
        if len(samples) > self.MAX_TIMING_SAMPLES:
            del samples[0]

        return pose, error

    def _fast_match(self, scan_points, initial_guess):
        """
        Fast matching algorithm optimized for real-time performance

        Args:
            scan_points: 3D points from LiDAR
            initial_guess: Initial [x, y, z, roll, pitch, yaw] guess

        Returns:
            Estimated pose and error (ICP inlier RMSE, or infinity when no
            slice is available or ICP found no correspondences)
        """
        x, y, z, roll, pitch, yaw = initial_guess

        closest_orientation = self._closest_slice_key(roll, pitch)
        if closest_orientation is None:
            return initial_guess, float('inf')
        slice_points = self.reference_slices[closest_orientation]

        # Convert both point sets to Open3D format
        scan_pcd = o3d.geometry.PointCloud()
        scan_pcd.points = o3d.utility.Vector3dVector(scan_points)
        slice_pcd = o3d.geometry.PointCloud()
        slice_pcd.points = o3d.utility.Vector3dVector(slice_points)

        # Create transformation from initial guess
        # NOTE(review): scan_points were already rotated by the IMU orientation
        # in _scan_to_points, and this matrix applies roll/pitch/yaw again;
        # confirm that the double rotation is intended.
        T_init = np.eye(4)
        T_init[:3, :3] = t3d.euler.euler2mat(roll, pitch, yaw, 'sxyz')
        T_init[:3, 3] = [x, y, z]

        # Fast ICP with limited iterations for real-time performance
        registration = o3d.pipelines.registration
        result = registration.registration_icp(
            scan_pcd, slice_pcd,
            max_correspondence_distance=0.2,
            init=T_init,
            estimation_method=registration.TransformationEstimationPointToPoint(),
            criteria=registration.ICPConvergenceCriteria(
                max_iteration=10,  # Limit iterations for speed
                relative_fitness=1e-6,
                relative_rmse=1e-6
            )
        )

        # Without any correspondence the reported RMSE is 0, which would look
        # like a perfect match; report an infinite error instead.
        error = result.inlier_rmse if result.fitness > 0 else float('inf')

        # Convert the resulting transform back to a pose
        T_result = np.asarray(result.transformation)
        translation = T_result[:3, 3]
        euler_angles = t3d.euler.mat2euler(T_result[:3, :3], 'sxyz')
        pose = [float(value) for value in (*translation, *euler_angles)]

        return pose, error

    def _update_pose(self, future):
        """Update the latest pose from a finished processing future"""
        try:
            pose, error = future.result()

            with self.pose_lock:
                # Apply smoothing with previous poses
                if len(self.pose_history) > 0:
                    smoothed_pose = self._smooth_pose(pose)
                else:
                    smoothed_pose = list(pose)

                self.latest_pose = smoothed_pose
                self.pose_history.append(smoothed_pose)
        except Exception as e:
            print(f"Error updating pose: {e}")

    def _smooth_pose(self, new_pose):
        """
        Apply smoothing to reduce jitter in pose estimates

        Args:
            new_pose: Latest pose estimate

        Returns:
            Smoothed pose; angles are blended along the shortest arc and
            wrapped into (-pi, pi]
        """
        # Simple exponential smoothing with previous pose
        if not self.pose_history:
            return new_pose

        last_pose = self.pose_history[-1]
        alpha = self.SMOOTHING_ALPHA

        smoothed = []
        for i in range(6):
            if i >= 3:  # roll, pitch, yaw
                last_angle = self._wrap_angle(last_pose[i])
                # Shortest signed angular distance from the last to the new angle
                diff = self._wrap_angle(new_pose[i] - last_angle)
                smoothed_val = self._wrap_angle(last_angle + alpha * diff)
            else:
                # Regular smoothing for position
                smoothed_val = last_pose[i] + alpha * (new_pose[i] - last_pose[i])
            smoothed.append(smoothed_val)

        return smoothed

    def _predict_next_pose(self):
        """
        Predict next pose based on velocity (simple motion model)

        Returns:
            Predicted next pose as a new list. With fewer than two stored
            poses this is a copy of the last pose, or the default pose when
            the history is empty.
        """
        # Snapshot under the lock; the arithmetic below then works on copies
        with self.pose_lock:
            history = [list(pose) for pose in self.pose_history]

        if not history:
            return list(self.DEFAULT_POSE)
        if len(history) < 2:
            return history[-1]

        last_pose, prev_pose = history[-1], history[-2]

        predicted = []
        for i in range(6):
            velocity = last_pose[i] - prev_pose[i]
            if i >= 3:
                # Angles: use the shortest arc so crossing ±pi does not look
                # like a jump of nearly a full turn
                velocity = self._wrap_angle(velocity)
                predicted.append(self._wrap_angle(last_pose[i] + velocity))
            else:
                predicted.append(last_pose[i] + velocity)

        return predicted

    def get_performance_stats(self):
        """Get performance statistics (empty dict when nothing was processed yet)"""
        stats = {}

        if self.timing_stats['total_processing']:
            avg_time = np.mean(self.timing_stats['total_processing'])
            stats['avg_processing_time'] = avg_time
            stats['estimated_fps'] = 1.0 / avg_time if avg_time > 0 else 0

        return stats

    def visualize_current_pose(self, scan_points_3d=None):
        """
        Visualize the current estimated pose

        Args:
            scan_points_3d: Recent scan points to include (optional)
        """
        current_pose = self.get_current_pose()
        if current_pose is None:
            print("No pose estimate available")
            return

        x, y, z, roll, pitch, yaw = current_pose

        # Find the closest slice to the current orientation
        closest_orientation = self._closest_slice_key(roll, pitch)
        if closest_orientation is None:
            print("No matching slice found")
            return
        slice_points = self.reference_slices[closest_orientation]

        # Rotation of the estimated pose
        R = t3d.euler.euler2mat(roll, pitch, yaw, 'sxyz')

        # Transform scan points into the model frame if available
        transformed_points = None
        if scan_points_3d is not None:
            transformed_points = np.asarray(scan_points_3d) @ R.T + np.array([x, y, z])

        # Create plot
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection='3d')

        # Plot slice points
        ax.scatter(slice_points[:, 0], slice_points[:, 1], slice_points[:, 2],
                   color='blue', s=1, alpha=0.5, label='CAD Model Slice')

        # Plot transformed scan if available
        if transformed_points is not None:
            ax.scatter(transformed_points[:, 0], transformed_points[:, 1], transformed_points[:, 2],
                       color='red', s=1, label='LiDAR Scan')

        # Plot sensor position
        ax.scatter([x], [y], [z], color='green', s=100, marker='*', label='Sensor Position')

        # Draw the sensor's coordinate axes
        axis_length = 0.5
        colors = ['g', 'b', 'r']  # x=green, y=blue, z=red
        for i, color in enumerate(colors):
            axis = np.zeros(3)
            axis[i] = axis_length
            rotated_axis = R @ axis
            ax.quiver(x, y, z, rotated_axis[0], rotated_axis[1], rotated_axis[2],
                      color=color, linewidth=2)

        # Set labels and title
        ax.set_xlabel('X (m)')
        ax.set_ylabel('Y (m)')
        ax.set_zlabel('Z (m)')

        # Get performance stats
        stats = self.get_performance_stats()
        fps_text = f"FPS: {stats.get('estimated_fps', 0):.1f}" if stats else "FPS: N/A"

        ax.set_title(f'Real-time Localization\n'
                     f'Position: ({x:.2f}, {y:.2f}, {z:.2f})\n'
                     f'Orientation: ({np.degrees(roll):.1f}°, {np.degrees(pitch):.1f}°, {np.degrees(yaw):.1f}°)\n'
                     f'{fps_text}')

        ax.legend()
        plt.tight_layout()
        plt.show()
# Example usage in a real-time application
if __name__ == "__main__":
    import time
    import math
    from pathlib import Path

    # Path to your CAD model
    cad_model_path = "room_model.stl"

    # For demonstration, create a simple box-shaped room if no model exists.
    # An explicit existence check replaces the former bare `except:`, which
    # also swallowed unrelated errors such as KeyboardInterrupt.
    if not Path(cad_model_path).is_file():
        print("Creating demo room model...")
        import trimesh
        room = trimesh.creation.box(extents=[10, 8, 3])
        room.export(cad_model_path)

    # Initialize the real-time localizer
    localizer = RealtimeSensorLocalizer(cad_model_path, precompute_slices=True, num_slices=12)

    # Start real-time localization
    localizer.start_realtime_localization()

    try:
        # The beam angles are the same for every simulated scan
        angles = np.linspace(0, 2*np.pi, 360, endpoint=False)

        # Simulate real-time data acquisition loop
        for i in range(100):
            # Simulate LiDAR data (a circle with some noise)
            base_range = 3.0 + 0.5 * np.sin(angles * 4)  # Room with wavy walls
            noise = np.random.normal(0, 0.03, angles.shape)  # Measurement noise
            ranges = base_range + noise

            # Simulate moving IMU orientation
            t = i * 0.1  # time variable
            imu_roll = 0.1 * np.sin(t * 0.5)  # Small roll oscillation
            imu_pitch = 0.05 * np.cos(t * 0.7)  # Small pitch oscillation
            imu_yaw = t * 0.1  # Slowly increasing yaw

            # Process the scan (dropped silently if the queue is full)
            localizer.process_scan(ranges, angles, imu_roll, imu_pitch, imu_yaw)

            # Print current position estimate periodically
            if i % 10 == 0:
                pose = localizer.get_current_pose()
                # Explicit None check: a pose is a sequence, not a flag
                if pose is not None:
                    print(f"Pose at t={t:.1f}s: ({pose[0]:.2f}, {pose[1]:.2f}, {pose[2]:.2f}), "
                          f"R={np.degrees(pose[3]):.1f}°, P={np.degrees(pose[4]):.1f}°, Y={np.degrees(pose[5]):.1f}°")

                # Show performance stats
                stats = localizer.get_performance_stats()
                if stats:
                    print(f"Processing time: {stats.get('avg_processing_time', 0)*1000:.1f}ms, "
                          f"Estimated FPS: {stats.get('estimated_fps', 0):.1f}")

            # Simulate sensor data rate
            time.sleep(0.03)  # ~30Hz

        # Visualize the final result
        localizer.visualize_current_pose()

    finally:
        # Stop the localization thread
        localizer.stop_realtime_localization()
Author
yoi-hibino
commented
Mar 16, 2025
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment