Created
December 2, 2025 02:55
-
-
Save Owen-Liuyuxuan/55da9fa9061522fc35841b345795fd3d to your computer and use it in GitHub Desktop.
Shape estimation simulator illustrating the logic in https://github.com/autowarefoundation/autoware_universe/pull/11634
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import pygame | |
| import numpy as np | |
| from dataclasses import dataclass | |
| from typing import Tuple, List, Optional | |
| # ========================================== | |
| # PART 1: ALGORITHM & LOGIC (Math Core) | |
| # ========================================== | |
@dataclass
class CorrectionBBParameters:
    """Thresholds and fallback sizes consumed by ``correct_with_default_value``.

    Every value is compared against (or substituted for) a full edge-to-edge
    box dimension. The range checks in the algorithm are strict, so a
    dimension exactly equal to a bound does not count as inside the range.
    """

    # Lower / upper bound for a dimension to be accepted as a plausible width.
    min_width: float
    max_width: float
    # Width the box is grown to when the width side has to be corrected.
    default_width: float
    # Lower / upper bound for a dimension to be accepted as a plausible length.
    min_length: float
    max_length: float
    # Length the box is grown to when the length side has to be corrected.
    default_length: float
@dataclass
class Pose:
    """Planar pose of a box: centre position plus heading."""

    # Centre position in the world frame.
    x: float
    y: float
    # Heading angle in radians (it is fed straight into np.cos / np.sin).
    yaw: float
@dataclass
class Shape:
    """Full edge-to-edge dimensions of a box in its own frame."""

    # Extent along the box's local x axis (the heading direction).
    x: float
    # Extent along the box's local y axis.
    y: float
def get_rotation_matrix(yaw: float) -> np.ndarray:
    """Return the 2x2 matrix that rotates a column vector by ``yaw`` radians.

    A positive ``yaw`` rotates counter-clockwise: the unit x vector is mapped
    to ``(cos(yaw), sin(yaw))``.
    """
    cos_yaw = np.cos(yaw)
    sin_yaw = np.sin(yaw)
    return np.array([
        [cos_yaw, -sin_yaw],
        [sin_yaw, cos_yaw],
    ])
def apply_correction_vector_logic(correction_vec: np.ndarray, default_size: float) -> np.ndarray:
    """Turn an axis-aligned edge-centre offset into the shift that reaches a default half-size.

    Exactly one component of ``correction_vec`` is expected to be (close to)
    zero; the other one is the signed half-dimension of the box along that
    axis. The non-zero component is replaced by the *additional* signed
    distance needed for its magnitude to reach ``default_size / 2``. If it is
    already at least that large, the result is 0 and nothing needs to move.

    Args:
        correction_vec: 2-element vector in the box frame. Not modified.
        default_size: Full default dimension along the axis being corrected.

    Returns:
        A new float array holding the correction delta. If neither component
        is close to zero, the input values are returned unchanged.
    """
    # Always work on a float copy: an integer-dtype input would otherwise
    # silently truncate the fractional correction on assignment, and a plain
    # list/tuple has no ``.copy()`` that yields an array.
    result = np.array(correction_vec, dtype=float)

    # The x == 0 test comes first, so an all-zero vector is corrected along y.
    if np.isclose(result[0], 0.0):
        axis = 1
    elif np.isclose(result[1], 0.0):
        axis = 0
    else:
        return result

    current = result[axis]
    target = max(abs(current), default_size / 2.0)
    direction = -1.0 if current < 0.0 else 1.0
    result[axis] = target * direction - current
    return result
def correct_with_default_value(
    param: CorrectionBBParameters, shape: Shape, pose: Pose
) -> Tuple[bool, Shape, Pose, int, List[int]]:
    """Grow a detected box to a default width or length using a rule table.

    The four edge centres of the box are ranked by their distance from the
    world origin. Depending on whether the two farthest edges are opposite or
    adjacent, and on which dimensions fall inside the width / length limits
    of ``param``, one edge is selected. The box is then extended on that
    edge's side until the dimension reaches the chosen default size, while
    the opposite edge stays where it was. If the dimension is already at
    least the default size, nothing moves.

    The inputs are never modified; copies are returned.

    Returns:
        (success, new_shape, new_pose, case_id, sorted_indices)

        ``case_id`` is 0 when no rule matched (``success`` is False and the
        returned shape / pose equal the inputs), otherwise 1-7 identifying the
        rule. ``sorted_indices`` lists the edge indices (0:Front, 1:Left,
        2:Back, 3:Right) from farthest to nearest.
    """
    corrected_shape = Shape(shape.x, shape.y)
    corrected_pose = Pose(pose.x, pose.y, pose.yaw)

    rotation = get_rotation_matrix(pose.yaw)
    origin_offset = np.array([pose.x, pose.y])

    # Edge centres in the box frame, indexed 0:Front, 1:Left, 2:Back, 3:Right.
    edge_centers = [
        np.array([corrected_shape.x / 2.0, 0.0]),
        np.array([0.0, corrected_shape.y / 2.0]),
        np.array([-corrected_shape.x / 2.0, 0.0]),
        np.array([0.0, -corrected_shape.y / 2.0]),
    ]

    # Rank edges by distance of their world-frame centre from the origin,
    # farthest first. The sort is stable, so ties keep ascending index order.
    ranges = [np.linalg.norm((rotation @ center) + origin_offset) for center in edge_centers]
    sorted_indices = sorted(range(len(edge_centers)), key=lambda idx: ranges[idx], reverse=True)
    first, second, third = sorted_indices[0], sorted_indices[1], sorted_indices[2]

    def full_extent(edge_idx):
        # An edge centre sits half a dimension from the box centre.
        return np.linalg.norm(edge_centers[edge_idx] * 2.0)

    def looks_like_width(extent):
        return param.min_width < extent < param.max_width

    def looks_like_length(extent):
        return param.min_length < extent < param.max_length

    first_extent = full_extent(first)
    second_extent = full_extent(second)
    third_extent = full_extent(third)

    # Each rule: (condition, case id, edge to push outward, default size).
    # Rules are evaluated in order and the first satisfied one wins.
    if (first - second) % 2 == 0:
        # The two farthest edges face each other (0 vs 2, or 1 vs 3).
        rules = [
            (
                looks_like_width(first_extent) and third_extent < param.max_length,
                1, third, param.default_length,
            ),
            (
                looks_like_length(first_extent) and third_extent < param.max_width,
                2, third, param.default_width,
            ),
        ]
    else:
        # The two farthest edges are adjacent.
        first_is_width = looks_like_width(first_extent)
        second_is_width = looks_like_width(second_extent)
        rules = [
            (first_is_width and second_is_width, 3, first, param.default_length),
            (first_is_width, 4, second, param.default_length),
            (second_is_width, 5, first, param.default_length),
            (
                looks_like_length(first_extent) and second_extent < param.max_width,
                6, second, param.default_width,
            ),
            (
                looks_like_length(second_extent) and first_extent < param.max_width,
                7, first, param.default_width,
            ),
        ]

    matched = next((rule for rule in rules if rule[0]), None)
    if matched is None:
        return False, corrected_shape, corrected_pose, 0, sorted_indices

    _, case_id, anchor_idx, default_size = matched
    shift_local = apply_correction_vector_logic(edge_centers[anchor_idx].copy(), default_size)

    # Growing by twice the shift while moving the centre by the shift keeps
    # the edge opposite to the anchor fixed.
    corrected_shape.x += abs(shift_local[0]) * 2.0
    corrected_shape.y += abs(shift_local[1]) * 2.0
    shift_world = rotation @ shift_local
    corrected_pose.x += shift_world[0]
    corrected_pose.y += shift_world[1]

    # Keep x as the longer side: rotate the frame by 90 degrees, wrap the yaw
    # into [-pi, pi) and swap the dimensions.
    # NOTE(review): the source's indentation was lost; this normalisation is
    # assumed to run only after a successful correction - confirm.
    if corrected_shape.x < corrected_shape.y:
        rotated_yaw = corrected_pose.yaw + np.pi / 2.0
        corrected_pose.yaw = (rotated_yaw + np.pi) % (2 * np.pi) - np.pi
        corrected_shape.x, corrected_shape.y = corrected_shape.y, corrected_shape.x

    return True, corrected_shape, corrected_pose, case_id, sorted_indices
# ==========================================
# PART 2: PYGAME ENGINE & VISUALIZATION
# ==========================================

# Window size in pixels.
SCREEN_WIDTH, SCREEN_HEIGHT = 1200, 800

# Base palette (RGB).
BG_COLOR = (30, 30, 30)
GRID_COLOR = (50, 50, 50)
TEXT_COLOR = (220, 220, 220)

# Object colours (RGB).
COL_EGO = (0, 255, 0)         # green: ego marker at the world origin
COL_ORIG = (100, 100, 255)    # blue: detected box
COL_CORR = (255, 100, 100)    # red: corrected box

# Edge-rank marker colours (RGB), rank 1 being the farthest edge.
COL_RANK1 = (255, 50, 50)     # red dot
COL_RANK2 = (255, 165, 0)     # orange dot
COL_RANK3 = (255, 255, 0)     # yellow dot
COL_RANK4 = (200, 200, 200)   # gray dot

# World-to-screen scale, and the pixel at which the world origin is drawn.
PIXELS_PER_METER = 50.0
CENTER_X = SCREEN_WIDTH // 2
CENTER_Y = SCREEN_HEIGHT // 2
def world_to_screen(wx, wy):
    """Map a world position (meters, +Y up) to a pixel position (+Y down).

    The world origin lands on (CENTER_X, CENTER_Y). The result is a pair of
    ints obtained with ``int()``, i.e. truncated toward zero.
    """
    pixel_x = CENTER_X + wx * PIXELS_PER_METER
    # Screen Y grows downward, so the world Y offset is subtracted.
    pixel_y = CENTER_Y - wy * PIXELS_PER_METER
    return int(pixel_x), int(pixel_y)
def draw_grid(surface):
    """Draw the background grid and the two world axes onto ``surface``.

    Grid lines are placed every 5 meters at world coordinates -20 through 15
    on both axes; each line spans the full window.
    """
    axis_color = (100, 100, 100)

    # Vertical grid lines, one per x tick.
    for x_m in range(-20, 20, 5):
        col_px, _ = world_to_screen(x_m, 0)
        pygame.draw.line(surface, GRID_COLOR, (col_px, 0), (col_px, SCREEN_HEIGHT))

    # Horizontal grid lines, one per y tick.
    for y_m in range(-20, 20, 5):
        _, row_px = world_to_screen(0, y_m)
        pygame.draw.line(surface, GRID_COLOR, (0, row_px), (SCREEN_WIDTH, row_px))

    # World axes through the origin, drawn thicker and brighter on top.
    origin_px, origin_py = world_to_screen(0, 0)
    pygame.draw.line(surface, axis_color, (origin_px, 0), (origin_px, SCREEN_HEIGHT), 2)
    pygame.draw.line(surface, axis_color, (0, origin_py), (SCREEN_WIDTH, origin_py), 2)
def draw_rotated_box(surface, pose: Pose, shape: Shape, color, width=2, is_dashed=False):
    """Draw an oriented box outline and a heading line onto ``surface``.

    Args:
        surface: pygame surface to draw on.
        pose: Box centre and yaw in the world frame.
        shape: Full box dimensions along its local x / y axes.
        color: RGB colour used for every primitive.
        width: Outline thickness in pixels; ignored when ``is_dashed`` is set.
        is_dashed: Draw the alternative style instead: a 1-pixel outline with
            a filled circle on each corner.
    """
    center_px = world_to_screen(pose.x, pose.y)
    center_world = np.array([pose.x, pose.y])
    half_x = shape.x / 2.0
    half_y = shape.y / 2.0
    rotation = get_rotation_matrix(pose.yaw)

    # Corners in the box frame, then rotated / translated into the world
    # frame and projected to pixels.
    local_corners = (
        (half_x, half_y),
        (-half_x, half_y),
        (-half_x, -half_y),
        (half_x, -half_y),
    )
    outline = [
        world_to_screen(*((rotation @ np.array(corner)) + center_world))
        for corner in local_corners
    ]

    if is_dashed:
        # pygame has no native dashed lines; the alternative style marks the
        # corners with circles and uses a thin outline instead.
        for corner_px in outline:
            pygame.draw.circle(surface, color, corner_px, 4)
        pygame.draw.lines(surface, color, True, outline, width=1)
    else:
        pygame.draw.lines(surface, color, True, outline, width)

    # Heading indicator: centre to the middle of the front (+x) edge.
    # NOTE(review): the source's indentation was lost; the heading line is
    # assumed to be drawn for both outline styles - confirm.
    front_world = (rotation @ np.array([half_x, 0])) + center_world
    front_px = world_to_screen(front_world[0], front_world[1])
    pygame.draw.line(surface, color, center_px, front_px, 2)
def draw_edge_ranks(surface, pose, shape, sorted_indices, font):
    """Mark each edge centre of the box with a numbered dot coloured by rank.

    Args:
        surface: pygame surface to draw on.
        pose: Box centre and yaw in the world frame.
        shape: Full box dimensions along its local x / y axes.
        sorted_indices: Edge indices (0:Front, 1:Left, 2:Back, 3:Right)
            ordered by rank; position 0 is shown as rank 1.
        font: pygame font used for the rank number.
    """
    rotation = get_rotation_matrix(pose.yaw)
    center_world = np.array([pose.x, pose.y])
    half_x = shape.x / 2.0
    half_y = shape.y / 2.0

    # Edge centres in the box frame, in edge-index order.
    edge_centers = [
        np.array([half_x, 0.0]),
        np.array([0.0, half_y]),
        np.array([-half_x, 0.0]),
        np.array([0.0, -half_y]),
    ]

    # Invert the ranking: edge index -> zero-based rank.
    rank_of_edge = {edge_idx: place for place, edge_idx in enumerate(sorted_indices)}
    palette = [COL_RANK1, COL_RANK2, COL_RANK3, COL_RANK4]
    black = (0, 0, 0)

    for edge_idx, local_center in enumerate(edge_centers):
        world_center = (rotation @ local_center) + center_world
        dot_px = world_to_screen(world_center[0], world_center[1])
        place = rank_of_edge[edge_idx]

        # Filled dot with a thin black outline.
        pygame.draw.circle(surface, palette[place], dot_px, 8)
        pygame.draw.circle(surface, black, dot_px, 9, 1)

        # One-based rank number placed over the dot.
        label = font.render(str(place + 1), True, black)
        surface.blit(label, (dot_px[0] - 4, dot_px[1] - 8))
def main():
    """Run the interactive shape-correction simulator until the window is closed.

    Each frame reads the input state, updates the detected box (size with the
    arrow keys, yaw with Q / E, position by left-button drag or WASD), runs
    ``correct_with_default_value`` on it, and draws the detected box, its
    edge ranks, the corrected box (when a correction applied) and a text
    overlay. The loop is capped at 60 frames per second.
    """
    pygame.init()
    # Shut pygame down even if something inside the loop raises.
    try:
        screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
        pygame.display.set_caption("Shape Correction Simulator")
        clock = pygame.time.Clock()
        font = pygame.font.SysFont("Arial", 16, bold=True)

        # --- System State ---
        params = CorrectionBBParameters(
            min_width=1.2, max_width=2.5, default_width=1.7,
            min_length=3.0, max_length=5.8, default_length=4.4
        )

        # Human-readable label per case id; constant, so built once up front.
        case_descriptions = {
            0: "No Correction Needed / Failure",
            1: "Opposite: Fix Length (Third pt)",
            2: "Opposite: Fix Width (Third pt)",
            3: "Adj: Fix Length (Both fit Width)",
            4: "Adj: Fix Length (1st fits Width)",
            5: "Adj: Fix Length (2nd fits Width)",
            6: "Adj: Fix Width (1st fits Length)",
            7: "Adj: Fix Width (2nd fits Length)"
        }

        # Initial detected box (meters / radians).
        curr_shape = Shape(2.0, 2.0)
        curr_pose = Pose(5.0, 5.0, np.radians(45))
        mouse_dragging = False
        running = True

        while running:
            # 1. Event Handling
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                elif event.type == pygame.MOUSEBUTTONDOWN:
                    if event.button == 1:
                        mouse_dragging = True
                elif event.type == pygame.MOUSEBUTTONUP:
                    if event.button == 1:
                        mouse_dragging = False

            # 2. Input Logic (Continuous)
            keys = pygame.key.get_pressed()

            # Shape control; each dimension is kept at 0.3 or above.
            if keys[pygame.K_UP]:
                curr_shape.x += 0.05
            if keys[pygame.K_DOWN]:
                curr_shape.x = max(0.3, curr_shape.x - 0.05)
            if keys[pygame.K_RIGHT]:
                curr_shape.y += 0.05
            if keys[pygame.K_LEFT]:
                curr_shape.y = max(0.3, curr_shape.y - 0.05)

            # Rotation
            if keys[pygame.K_q]:
                curr_pose.yaw += 0.05
            if keys[pygame.K_e]:
                curr_pose.yaw -= 0.05

            # Position: dragging snaps the box to the cursor, otherwise WASD.
            if mouse_dragging:
                mx, my = pygame.mouse.get_pos()
                # Inverse of world_to_screen (without the integer truncation).
                curr_pose.x = (mx - CENTER_X) / PIXELS_PER_METER
                curr_pose.y = -(my - CENTER_Y) / PIXELS_PER_METER
            else:
                if keys[pygame.K_w]:
                    curr_pose.y += 0.1
                if keys[pygame.K_s]:
                    curr_pose.y -= 0.1
                if keys[pygame.K_a]:
                    curr_pose.x -= 0.1
                if keys[pygame.K_d]:
                    curr_pose.x += 0.1

            # 3. Run the algorithm on the current detection.
            success, corr_shape, corr_pose, case_id, ranks = correct_with_default_value(
                params, curr_shape, curr_pose
            )

            # 4. Rendering
            screen.fill(BG_COLOR)
            draw_grid(screen)

            # Ego marker at the world origin.
            cx, cy = world_to_screen(0, 0)
            pygame.draw.circle(screen, COL_EGO, (cx, cy), 15)
            pygame.draw.line(screen, COL_EGO, (cx - 20, cy), (cx + 20, cy), 2)
            pygame.draw.line(screen, COL_EGO, (cx, cy - 20), (cx, cy + 20), 2)

            # Detected box with its edge ranks.
            draw_rotated_box(screen, curr_pose, curr_shape, COL_ORIG, width=3)
            draw_edge_ranks(screen, curr_pose, curr_shape, ranks, font)

            # Corrected box, only when a rule applied.
            if success:
                draw_rotated_box(screen, corr_pose, corr_shape, COL_CORR, width=2, is_dashed=True)

            # 5. UI Overlays
            # Info panel, one text row every 20 pixels.
            ui_y = 10
            lines = [
                f"CASE {case_id}: {case_descriptions.get(case_id, 'Unknown')}",
                f"Correction Applied: {'YES' if success else 'NO'}",
                "",
                "--- DETECTED (Blue) ---",
                f"Dim X (Len): {curr_shape.x:.2f}m",
                f"Dim Y (Wid): {curr_shape.y:.2f}m",
                f"Dist from Ego: {np.linalg.norm([curr_pose.x, curr_pose.y]):.2f}m",
                "",
                "--- CORRECTED (Red) ---",
                f"Dim X: {corr_shape.x:.2f}m" if success else "Dim X: -",
                f"Dim Y: {corr_shape.y:.2f}m" if success else "Dim Y: -",
                "",
                "--- CONTROLS ---",
                "Mouse Drag: Move Position",
                "WASD: Fine Position",
                "Q / E: Rotate",
                "Arrows: Resize Box",
            ]
            for line in lines:
                # Highlight the case headline when a correction was applied.
                col = COL_CORR if "CASE" in line and success else TEXT_COLOR
                s_txt = font.render(line, True, col)
                screen.blit(s_txt, (10, ui_y))
                ui_y += 20

            # Rank legend in the bottom-right corner.
            # NOTE(review): rank-4 (gray) dots are drawn on the box but are
            # not listed here - confirm whether that is intentional.
            lx, ly = SCREEN_WIDTH - 150, SCREEN_HEIGHT - 120
            legend_items = [
                ("Rank 1 (Furthest)", COL_RANK1),
                ("Rank 2", COL_RANK2),
                ("Rank 3", COL_RANK3),
            ]
            for text, col in legend_items:
                pygame.draw.circle(screen, col, (lx, ly), 6)
                s_leg = font.render(text, True, TEXT_COLOR)
                screen.blit(s_leg, (lx + 15, ly - 8))
                ly += 25

            pygame.display.flip()
            clock.tick(60)
    finally:
        pygame.quit()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment