# Source: GitHub Gist by @Owen-Liuyuxuan
# Gist: Owen-Liuyuxuan/55da9fa9061522fc35841b345795fd3d
# Created: December 2, 2025 02:55
# (GitHub page navigation text that preceded the code has been dropped.)
import pygame
import numpy as np
from dataclasses import dataclass
from typing import Tuple, List, Optional
# ==========================================
# PART 1: ALGORITHM & LOGIC (Math Core)
# ==========================================
@dataclass
class CorrectionBBParameters:
    """Size limits and fallback sizes consumed by ``correct_with_default_value``.

    An edge-to-edge extent ``d`` is treated as a plausible width when
    ``min_width < d < max_width`` and as a plausible length when
    ``min_length < d < max_length`` (both comparisons are strict).
    Units are presumably meters, matching the simulator's UI labels.
    """

    # Exclusive lower / upper bound for an extent to count as a width.
    min_width: float
    max_width: float
    # Width a box is grown to when its observed width is smaller than this.
    default_width: float
    # Exclusive lower / upper bound for an extent to count as a length.
    min_length: float
    max_length: float
    # Length a box is grown to when its observed length is smaller than this.
    default_length: float
@dataclass
class Pose:
    """Planar pose of a box: center position plus heading.

    Instances are intentionally mutable; the simulator edits them in place.
    """

    # Center position in the world frame.
    x: float
    y: float
    # Heading angle in radians (fed directly to ``np.cos`` / ``np.sin``).
    yaw: float
@dataclass
class Shape:
    """Full side lengths of a box, expressed in the box's own frame.

    Instances are intentionally mutable; the simulator edits them in place.
    """

    # Extent along the box's local X axis (front/back direction).
    x: float
    # Extent along the box's local Y axis (left/right direction).
    y: float
def get_rotation_matrix(yaw: float) -> np.ndarray:
    """Return the 2x2 matrix that rotates a 2-D vector by ``yaw`` radians.

    Positive ``yaw`` rotates counter-clockwise: the first column (image of the
    X unit vector) is ``(cos yaw, sin yaw)``.
    """
    cos_yaw = np.cos(yaw)
    sin_yaw = np.sin(yaw)
    top_row = [cos_yaw, -sin_yaw]
    bottom_row = [sin_yaw, cos_yaw]
    return np.array([top_row, bottom_row])
def apply_correction_vector_logic(correction_vec: np.ndarray, default_size: float) -> np.ndarray:
    """Turn an axis-aligned edge-center vector into the shift that enforces a minimum size.

    ``correction_vec`` is expected to lie on one local axis (one component ~0).
    The non-zero component is pushed outward, keeping its sign, until its
    magnitude is at least ``default_size / 2``; the returned vector holds the
    *difference* between that target and the current value (zero when the
    component is already large enough).

    Args:
        correction_vec: 2-element vector; it is never modified.
        default_size: Full size to enforce along the active axis.

    Returns:
        A new float array. If neither component is close to zero the input
        values are returned unchanged (as floats).
    """
    # Always work on a float copy: with an integer input array the fractional
    # part of the computed shift would otherwise be silently truncated.
    res = np.array(correction_vec, dtype=float)

    # X ~ 0 means the vector lies on the Y axis (checked first, so an all-zero
    # vector is treated as a Y correction); Y ~ 0 means it lies on the X axis.
    if np.isclose(res[0], 0.0):
        axis = 1
    elif np.isclose(res[1], 0.0):
        axis = 0
    else:
        return res

    current = res[axis]
    target = max(abs(current), default_size / 2.0)
    # Zero counts as positive so a degenerate edge still grows in +axis.
    sign = -1.0 if current < 0.0 else 1.0
    res[axis] = (target * sign) - current
    return res
def correct_with_default_value(param: CorrectionBBParameters, shape: Shape, pose: Pose):
    """Grow a partially observed box toward default dimensions using fixed rules.

    The four edge centers of the box are ranked by their distance from the
    world origin (farthest first). Depending on whether the two farthest
    edges are opposite or adjacent, and on which width/length ranges their
    extents fall into, one edge is pushed outward so the box reaches the
    default width or length. The inputs are never modified.

    Returns: (success, new_shape, new_pose, case_id, sorted_indices)
        success        -- True when one of the seven rules matched.
        new_shape      -- corrected copy of ``shape`` (plain copy on failure).
        new_pose       -- corrected copy of ``pose`` (plain copy on failure).
        case_id        -- 1..7 for the matching rule, 0 when none matched.
        sorted_indices -- edge indices (0:Front, 1:Left, 2:Back, 3:Right)
                          ordered from farthest to nearest.
    """
    corrected_shape = Shape(shape.x, shape.y)
    corrected_pose = Pose(pose.x, pose.y, pose.yaw)
    rotation = get_rotation_matrix(pose.yaw)
    box_center = np.array([pose.x, pose.y])

    # Edge centers in the box frame, indexed 0:Front, 1:Left, 2:Back, 3:Right.
    half_x = corrected_shape.x / 2.0
    half_y = corrected_shape.y / 2.0
    edge_centers = [
        np.array([half_x, 0.0]),
        np.array([0.0, half_y]),
        np.array([-half_x, 0.0]),
        np.array([0.0, -half_y]),
    ]

    # Rank edges by world-frame distance to the origin, farthest first.
    # The sort is stable, so ties keep ascending index order.
    ranges = [np.linalg.norm((rotation @ center) + box_center) for center in edge_centers]
    sorted_indices = sorted(range(len(edge_centers)), key=ranges.__getitem__, reverse=True)
    first_idx, second_idx, third_idx = sorted_indices[:3]

    # Full box dimension along the axis each ranked edge sits on.
    first_dim, second_dim, third_dim = (
        np.linalg.norm(edge_centers[idx] * 2.0) for idx in (first_idx, second_idx, third_idx)
    )

    first_is_width = param.min_width < first_dim < param.max_width
    first_is_length = param.min_length < first_dim < param.max_length
    second_is_width = param.min_width < second_dim < param.max_width
    second_is_length = param.min_length < second_dim < param.max_length

    # Ordered rules: (condition, case id, edge to push, default size to reach).
    # Only the first rule whose condition holds is applied.
    if abs(first_idx - second_idx) % 2 == 0:
        # The two farthest edges face each other (0 vs 2, or 1 vs 3).
        rules = (
            (first_is_width and third_dim < param.max_length, 1, third_idx, param.default_length),
            (first_is_length and third_dim < param.max_width, 2, third_idx, param.default_width),
        )
    else:
        # The two farthest edges share a corner.
        rules = (
            (first_is_width and second_is_width, 3, first_idx, param.default_length),
            (first_is_width, 4, second_idx, param.default_length),
            (second_is_width, 5, first_idx, param.default_length),
            (first_is_length and second_dim < param.max_width, 6, second_idx, param.default_width),
            (second_is_length and first_dim < param.max_width, 7, first_idx, param.default_width),
        )

    matched = next((rule for rule in rules if rule[0]), None)
    if matched is None:
        return False, corrected_shape, corrected_pose, 0, sorted_indices

    _, case_id, edge_idx, default_size = matched
    local_shift = apply_correction_vector_logic(edge_centers[edge_idx].copy(), default_size)

    # The box grows by twice the shift while its center moves by the shift.
    corrected_shape.x += abs(local_shift[0]) * 2.0
    corrected_shape.y += abs(local_shift[1]) * 2.0
    world_shift = rotation @ local_shift
    corrected_pose.x += world_shift[0]
    corrected_pose.y += world_shift[1]

    # Keep X as the longer side: turn the heading by 90 degrees, wrap it into
    # [-pi, pi), and swap the two dimensions.
    if corrected_shape.x < corrected_shape.y:
        corrected_pose.yaw += np.pi / 2.0
        corrected_pose.yaw = (corrected_pose.yaw + np.pi) % (2 * np.pi) - np.pi
        corrected_shape.x, corrected_shape.y = corrected_shape.y, corrected_shape.x

    return True, corrected_shape, corrected_pose, case_id, sorted_indices
# ==========================================
# PART 2: PYGAME ENGINE & VISUALIZATION
# ==========================================
# --- Window geometry (pixels) ---
SCREEN_WIDTH, SCREEN_HEIGHT = 1200, 800

# --- Base palette (RGB) ---
BG_COLOR = (30, 30, 30)
GRID_COLOR = (50, 50, 50)
TEXT_COLOR = (220, 220, 220)

# --- Object colors (RGB) ---
COL_EGO = (0, 255, 0)        # green: ego marker at the world origin
COL_ORIG = (100, 100, 255)   # blue: detected (input) box
COL_CORR = (255, 100, 100)   # red: corrected box

# --- Edge-rank dot colors, rank 1 = farthest edge ---
COL_RANK1 = (255, 50, 50)    # red
COL_RANK2 = (255, 165, 0)    # orange
COL_RANK3 = (255, 255, 0)    # yellow
COL_RANK4 = (200, 200, 200)  # gray

# --- World <-> screen mapping ---
PIXELS_PER_METER = 50.0
# Screen pixel that corresponds to the world origin.
CENTER_X, CENTER_Y = SCREEN_WIDTH // 2, SCREEN_HEIGHT // 2
def world_to_screen(wx, wy):
    """Map a world point (meters, +Y up) to screen pixels (+Y down).

    The world origin lands on (CENTER_X, CENTER_Y). Results are converted
    with ``int()``, i.e. truncated toward zero rather than rounded.
    """
    offset_x = wx * PIXELS_PER_METER
    offset_y = wy * PIXELS_PER_METER
    # Screen Y grows downward, so the vertical offset is subtracted.
    return int(CENTER_X + offset_x), int(CENTER_Y - offset_y)
def draw_grid(surface):
    """Paint the background grid and the two world axes onto ``surface``.

    Grid lines sit every 5 m at world coordinates -20, -15, ..., 15 and span
    the whole window; the axes through the world origin are drawn last, in a
    lighter gray and 2 px wide, so they stay on top.
    """
    axis_color = (100, 100, 100)
    grid_ticks = range(-20, 20, 5)

    # Vertical grid lines (constant world X).
    for meters in grid_ticks:
        column, _ = world_to_screen(meters, 0)
        pygame.draw.line(surface, GRID_COLOR, (column, 0), (column, SCREEN_HEIGHT))

    # Horizontal grid lines (constant world Y).
    for meters in grid_ticks:
        _, row = world_to_screen(0, meters)
        pygame.draw.line(surface, GRID_COLOR, (0, row), (SCREEN_WIDTH, row))

    # World axes.
    origin_col, origin_row = world_to_screen(0, 0)
    pygame.draw.line(surface, axis_color, (origin_col, 0), (origin_col, SCREEN_HEIGHT), 2)
    pygame.draw.line(surface, axis_color, (0, origin_row), (SCREEN_WIDTH, origin_row), 2)
def draw_rotated_box(surface, pose: Pose, shape: Shape, color, width=2, is_dashed=False):
    """Draw an oriented rectangle plus a heading line on ``surface``.

    Args:
        surface: pygame surface to draw on.
        pose: Center and yaw of the box in world coordinates.
        shape: Full side lengths of the box in its own frame.
        color: RGB color used for every primitive.
        width: Outline thickness in pixels; ignored when ``is_dashed`` is set.
        is_dashed: When True the box is rendered as a 1 px outline with a
            filled marker on each corner (no real dashes are drawn).
    """
    center_px = world_to_screen(pose.x, pose.y)
    half_x = shape.x / 2.0
    half_y = shape.y / 2.0
    rotation = get_rotation_matrix(pose.yaw)
    center_world = np.array([pose.x, pose.y])

    def to_pixels(local_xy):
        # Box frame -> world frame -> screen pixels.
        world_xy = (rotation @ np.array(local_xy)) + center_world
        return world_to_screen(world_xy[0], world_xy[1])

    corner_offsets = (
        (half_x, half_y),
        (-half_x, half_y),
        (-half_x, -half_y),
        (half_x, -half_y),
    )
    outline = [to_pixels(offset) for offset in corner_offsets]

    if is_dashed:
        # pygame has no dashed-line primitive; mark the corners instead.
        for vertex in outline:
            pygame.draw.circle(surface, color, vertex, 4)
        pygame.draw.lines(surface, color, True, outline, width=1)
    else:
        pygame.draw.lines(surface, color, True, outline, width)

    # Heading indicator: from the box center to the middle of its front edge.
    front_px = to_pixels((half_x, 0))
    pygame.draw.line(surface, color, center_px, front_px, 2)
def draw_edge_ranks(surface, pose, shape, sorted_indices, font):
    """Mark each edge center of the box with a dot colored and numbered by rank.

    Args:
        surface: pygame surface to draw on.
        pose: Center and yaw of the box in world coordinates.
        shape: Full side lengths of the box.
        sorted_indices: Edge indices (0:Front, 1:Left, 2:Back, 3:Right) in
            rank order, first entry = rank 1.
        font: pygame font used for the rank number.
    """
    rotation = get_rotation_matrix(pose.yaw)
    center_world = np.array([pose.x, pose.y])
    half_x, half_y = shape.x / 2.0, shape.y / 2.0

    # Edge centers in the box frame, in edge-index order.
    edge_centers = (
        np.array([half_x, 0.0]),
        np.array([0.0, half_y]),
        np.array([-half_x, 0.0]),
        np.array([0.0, -half_y]),
    )
    palette = (COL_RANK1, COL_RANK2, COL_RANK3, COL_RANK4)
    # Invert the ranking: edge index -> zero-based rank.
    rank_of_edge = {edge: position for position, edge in enumerate(sorted_indices)}

    for edge_idx, local_center in enumerate(edge_centers):
        world_center = (rotation @ local_center) + center_world
        dot_px = world_to_screen(world_center[0], world_center[1])
        rank = rank_of_edge[edge_idx]

        # Filled dot with a thin black outline.
        pygame.draw.circle(surface, palette[rank], dot_px, 8)
        pygame.draw.circle(surface, (0, 0, 0), dot_px, 9, 1)

        # One-based rank number, nudged so it sits roughly centered on the dot.
        label = font.render(str(rank + 1), True, (0, 0, 0))
        surface.blit(label, (dot_px[0] - 4, dot_px[1] - 8))
def main():
    """Run the interactive shape-correction simulator until the window closes.

    Each frame: read input (arrow keys resize, Q/E rotate, mouse drag or WASD
    move), run ``correct_with_default_value`` on the current box, then draw
    the grid, the ego marker, the detected box with ranked edge dots, the
    corrected box (only when a rule matched) and the text overlays.
    The loop is capped at 60 frames per second.
    """
    pygame.init()
    # Everything after init runs under try/finally so pygame is always shut
    # down, even if an exception escapes the frame loop.
    try:
        screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
        pygame.display.set_caption("Shape Correction Simulator")
        clock = pygame.time.Clock()
        font = pygame.font.SysFont("Arial", 16, bold=True)

        # --- System state ---
        params = CorrectionBBParameters(
            min_width=1.2, max_width=2.5, default_width=1.7,
            min_length=3.0, max_length=5.8, default_length=4.4,
        )
        # Initial box (meters / radians).
        curr_shape = Shape(2.0, 2.0)
        curr_pose = Pose(5.0, 5.0, np.radians(45))
        mouse_dragging = False

        # --- Frame-invariant UI data, built once instead of every frame ---
        case_descriptions = {
            0: "No Correction Needed / Failure",
            1: "Opposite: Fix Length (Third pt)",
            2: "Opposite: Fix Width (Third pt)",
            3: "Adj: Fix Length (Both fit Width)",
            4: "Adj: Fix Length (1st fits Width)",
            5: "Adj: Fix Length (2nd fits Width)",
            6: "Adj: Fix Width (1st fits Length)",
            7: "Adj: Fix Width (2nd fits Length)",
        }
        controls_help = (
            "--- CONTROLS ---",
            "Mouse Drag: Move Position",
            "WASD: Fine Position",
            "Q / E: Rotate",
            "Arrows: Resize Box",
        )
        legend_items = (
            ("Rank 1 (Furthest)", COL_RANK1),
            ("Rank 2", COL_RANK2),
            ("Rank 3", COL_RANK3),
        )
        legend_x, legend_top = SCREEN_WIDTH - 150, SCREEN_HEIGHT - 120

        running = True
        while running:
            # 1. Discrete events: window close and left-button drag state.
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                elif event.type == pygame.MOUSEBUTTONDOWN:
                    if event.button == 1:
                        mouse_dragging = True
                elif event.type == pygame.MOUSEBUTTONUP:
                    if event.button == 1:
                        mouse_dragging = False

            # 2. Continuous input: applied every frame while a key is held.
            keys = pygame.key.get_pressed()

            # Resize; each dimension is clamped to at least 0.3 when shrinking.
            if keys[pygame.K_UP]:
                curr_shape.x += 0.05
            if keys[pygame.K_DOWN]:
                curr_shape.x = max(0.3, curr_shape.x - 0.05)
            if keys[pygame.K_RIGHT]:
                curr_shape.y += 0.05
            if keys[pygame.K_LEFT]:
                curr_shape.y = max(0.3, curr_shape.y - 0.05)

            # Rotate.
            if keys[pygame.K_q]:
                curr_pose.yaw += 0.05
            if keys[pygame.K_e]:
                curr_pose.yaw -= 0.05

            # Move: while dragging, the box center follows the cursor and WASD
            # is ignored.
            if mouse_dragging:
                mx, my = pygame.mouse.get_pos()
                # Inverse of world_to_screen (without the int truncation).
                curr_pose.x = (mx - CENTER_X) / PIXELS_PER_METER
                curr_pose.y = -(my - CENTER_Y) / PIXELS_PER_METER
            else:
                if keys[pygame.K_w]:
                    curr_pose.y += 0.1
                if keys[pygame.K_s]:
                    curr_pose.y -= 0.1
                if keys[pygame.K_a]:
                    curr_pose.x -= 0.1
                if keys[pygame.K_d]:
                    curr_pose.x += 0.1

            # 3. Run the correction on the current box.
            success, corr_shape, corr_pose, case_id, ranks = correct_with_default_value(
                params, curr_shape, curr_pose
            )

            # 4. Scene rendering.
            screen.fill(BG_COLOR)
            draw_grid(screen)

            # Ego marker: filled circle with a crosshair at the world origin.
            cx, cy = world_to_screen(0, 0)
            pygame.draw.circle(screen, COL_EGO, (cx, cy), 15)
            pygame.draw.line(screen, COL_EGO, (cx - 20, cy), (cx + 20, cy), 2)
            pygame.draw.line(screen, COL_EGO, (cx, cy - 20), (cx, cy + 20), 2)

            # Detected box with its ranked edge dots.
            draw_rotated_box(screen, curr_pose, curr_shape, COL_ORIG, width=3)
            draw_edge_ranks(screen, curr_pose, curr_shape, ranks, font)

            # Corrected box, only when a rule matched.
            if success:
                draw_rotated_box(screen, corr_pose, corr_shape, COL_CORR, width=2, is_dashed=True)

            # 5. Text overlays: info panel in the top-left corner.
            lines = [
                f"CASE {case_id}: {case_descriptions.get(case_id, 'Unknown')}",
                f"Correction Applied: {'YES' if success else 'NO'}",
                "",
                "--- DETECTED (Blue) ---",
                f"Dim X (Len): {curr_shape.x:.2f}m",
                f"Dim Y (Wid): {curr_shape.y:.2f}m",
                f"Dist from Ego: {np.linalg.norm([curr_pose.x, curr_pose.y]):.2f}m",
                "",
                "--- CORRECTED (Red) ---",
                f"Dim X: {corr_shape.x:.2f}m" if success else "Dim X: -",
                f"Dim Y: {corr_shape.y:.2f}m" if success else "Dim Y: -",
                "",
                *controls_help,
            ]
            for row, line in enumerate(lines):
                # Highlight the case headline whenever a correction was applied.
                col = COL_CORR if "CASE" in line and success else TEXT_COLOR
                screen.blit(font.render(line, True, col), (10, 10 + 20 * row))

            # Rank legend in the bottom-right corner.
            for row, (text, col) in enumerate(legend_items):
                ly = legend_top + 25 * row
                pygame.draw.circle(screen, col, (legend_x, ly), 6)
                screen.blit(font.render(text, True, TEXT_COLOR), (legend_x + 15, ly - 8))

            pygame.display.flip()
            clock.tick(60)
    finally:
        pygame.quit()
# Start the simulator only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
# End of gist source (GitHub comment-thread footer omitted).