Last active
September 25, 2025 03:22
-
-
Save altescy/15df45e5739894cc4cd9ae05cea78bb3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# /// script
# requires-python = ">=3.12"
# dependencies = [
#   "numpy>=2.3.3",
#   "pygame>=2.6.1",
#   "pygame-widgets>=1.2.2",
# ]
# ///
import dataclasses | |
import math | |
from collections.abc import Sequence | |
from typing import NewType, Self | |
import numpy | |
import pygame | |
import pygame_widgets | |
from pygame_widgets.slider import Slider | |
@dataclasses.dataclass | |
class Vector2D: | |
x: float | |
y: float | |
def __add__(self, other: Self) -> Self: | |
return dataclasses.replace(self, x=self.x + other.x, y=self.y + other.y) | |
def __sub__(self, other: Self) -> Self: | |
return dataclasses.replace(self, x=self.x - other.x, y=self.y - other.y) | |
def __mul__(self, scalar: float) -> Self: | |
return dataclasses.replace(self, x=self.x * scalar, y=self.y * scalar) | |
def __rmul__(self, scalar: float) -> Self: | |
return dataclasses.replace(self, x=self.x * scalar, y=self.y * scalar) | |
def __truediv__(self, scalar: float) -> Self: | |
return dataclasses.replace(self, x=self.x / scalar, y=self.y / scalar) | |
def __rtruediv__(self, scalar: float) -> Self: | |
return dataclasses.replace(self, x=scalar / self.x, y=scalar / self.y) | |
def __floordiv__(self, scalar: float) -> Self: | |
return dataclasses.replace(self, x=self.x // scalar, y=self.y // scalar) | |
def __rfloordiv__(self, scalar: float) -> Self: | |
return dataclasses.replace(self, x=scalar // self.x, y=scalar // self.y) | |
def __neg__(self) -> Self: | |
return dataclasses.replace(self, x=-self.x, y=-self.y) | |
def __pos__(self) -> Self: | |
return dataclasses.replace(self, x=+self.x, y=+self.y) | |
def __abs__(self) -> float: | |
return (self.x**2 + self.y**2) ** 0.5 | |
def __pow__(self, power: float) -> Self: | |
return dataclasses.replace(self, x=self.x**power, y=self.y**power) | |
def __rpow__(self, power: float) -> Self: | |
return dataclasses.replace(self, x=power**self.x, y=power**self.y) | |
def rotate(self, angle: float, center: "Vector2D | None" = None) -> Self: | |
if center is None: | |
center = Vector2D(0, 0) | |
cos_angle = math.cos(angle) | |
sin_angle = math.sin(angle) | |
translated_x = self.x - center.x | |
translated_y = self.y - center.y | |
rotated_x = translated_x * cos_angle - translated_y * sin_angle | |
rotated_y = translated_x * sin_angle + translated_y * cos_angle | |
return self.__class__(rotated_x + center.x, rotated_y + center.y) | |
# Physical-quantity aliases. ``typing.NewType`` only tags a type for static
# checkers; at runtime each alias simply returns its argument unchanged.

# --- Scalar quantities -------------------------------------------------------
Seconds = NewType("Seconds", float)
Radial = NewType("Radial", float)  # an angle, passed to math.cos/math.sin
Coordinate = NewType("Coordinate", float)
Velocity = NewType("Velocity", float)
Acceleration = NewType("Acceleration", float)
AngularVelocity = NewType("AngularVelocity", float)
AngularAcceleration = NewType("AngularAcceleration", float)

# --- Planar quantities (wrapping Vector2D) -----------------------------------
Coordinate2D = NewType("Coordinate2D", Vector2D)
Velocity2D = NewType("Velocity2D", Vector2D)
Acceleration2D = NewType("Acceleration2D", Vector2D)
class OmniSolver:
    """Forward kinematics for a circular body driven by omni wheels.

    For a wheel mounted at body-frame angle ``a`` with drive radius ``r`` on a
    body of radius ``R``, the wheel's angular velocity satisfies::

        wheel.velocity * r == -sin(a) * vx + cos(a) * vy + R * omega

    where ``(vx, vy)`` is the body-frame linear velocity and ``omega`` the
    body's angular velocity. Stacking one such row per wheel gives a matrix
    that is (pseudo-)inverted to recover the body motion from the wheels.
    """

    def _wheel_angular_velocity_transform_matrix(
        self,
        wheels: Sequence["Wheel"],
        radius: Coordinate,
    ) -> numpy.ndarray:
        """Return the matrix mapping body-frame ``(vx, vy, omega)`` to wheel angular velocities.

        Row *i* is ``[-sin(a_i) / r_i, cos(a_i) / r_i, radius / r_i]``.
        """
        rows = []
        for wheel in wheels:
            mount_angle = wheel.angle
            wheel_radius = wheel.radius
            rows.append(
                [
                    -math.sin(mount_angle) / wheel_radius,
                    math.cos(mount_angle) / wheel_radius,
                    radius / wheel_radius,
                ]
            )
        return numpy.array(rows)

    def _wheel_velocity_transform_matrix(
        self,
        wheels: Sequence["Wheel"],
        radius: Coordinate,
    ) -> numpy.ndarray:
        """Return the Moore-Penrose pseudo-inverse of the forward matrix.

        Applied to the vector of wheel angular velocities it yields the
        body-frame ``(vx, vy, omega)`` (a least-squares fit when there are
        more than three wheels).
        """
        forward = self._wheel_angular_velocity_transform_matrix(wheels, radius)
        return numpy.linalg.pinv(forward)

    def compute_body_velocities(self, car: "Car") -> tuple[Velocity2D, AngularVelocity]:
        """Return ``(world-frame linear velocity, angular velocity)`` for *car*.

        The body-frame linear velocity is rotated by ``car.angle`` to express
        it in the world frame.
        """
        inverse = self._wheel_velocity_transform_matrix(car.wheels, car.radius)
        wheel_speeds = numpy.array([wheel.velocity for wheel in car.wheels])
        vx_body, vy_body, omega = inverse @ wheel_speeds
        world_velocity = Velocity2D(Vector2D(x=vx_body, y=vy_body)).rotate(car.angle)
        return world_velocity, AngularVelocity(omega)
class OmniDrawer:
    """Renders a car and its current kinematic state onto a pygame surface."""

    @staticmethod
    def draw(car: "Car", *, on: pygame.Surface) -> None:
        """Draw *car* onto the surface *on*.

        In order, this draws:

        * the body outline: a green circle of radius ``car.radius``, 2 px wide;
        * the world-frame linear velocity: a blue arrow from the body centre;
        * the angular velocity: a red arc arrow of radius 8 whose angle is the
          angular-velocity value;
        * for every wheel, a yellow arrow at its (heading-rotated) mounting
          point showing ``wheel.velocity * wheel.radius`` along the tangent
          direction at the wheel's mounting angle.
        """
        surface = on
        cx = car.position.x
        cy = car.position.y
        # Shared, never-mutated pivot used by every arrow below.
        centre = Vector2D(x=cx, y=cy)

        pygame.draw.circle(surface, (0, 255, 0), (int(cx), int(cy)), int(car.radius), 2)

        linear = car.velocity
        pygame_draw_arrow(
            surface,
            color=(100, 100, 255),
            start=centre,
            end=Vector2D(x=cx + linear.x, y=cy + linear.y),
        )

        spin = car.angular_velocity
        pygame_draw_rot_arrow(
            surface,
            color=(255, 100, 100),
            center=centre,
            radius=8,
            angle=Radial(spin),
            arrowhead_length=4,
        )

        for wheel in car.wheels:
            # Mounting point on the rim, then rotated by the body heading.
            mount = Vector2D(
                x=cx + car.radius * math.cos(wheel.angle),
                y=cy + car.radius * math.sin(wheel.angle),
            ).rotate(car.angle, center=centre)
            rim_speed = wheel.velocity * wheel.radius
            drive = Vector2D(
                x=-rim_speed * math.sin(wheel.angle),
                y=rim_speed * math.cos(wheel.angle),
            ).rotate(car.angle)
            pygame_draw_arrow(
                surface,
                color=(255, 255, 0),
                start=mount,
                end=mount + drive,
            )
@dataclasses.dataclass()
class Wheel:
    """One omni wheel mounted on the rim of a car body.

    The wheel sits at distance ``Car.radius`` from the body centre, at
    ``angle`` measured in the body frame. Its linear drive speed is
    ``velocity * radius``.
    """

    angle: Radial  # mounting angle on the body, in radians (body frame)
    radius: Coordinate  # wheel radius
    velocity: AngularVelocity  # wheel angular velocity; set each frame from the UI sliders
@dataclasses.dataclass
class Car:
    """A circular omni-wheel body with a planar pose.

    Body velocities are derived from the wheels' angular velocities by an
    ``OmniSolver``; rendering is delegated to an ``OmniDrawer``.
    """

    wheels: Sequence[Wheel]  # wheels mounted around the rim
    radius: Coordinate  # distance from the centre to each wheel; also the drawn body radius
    position: Coordinate2D  # world-frame centre position
    angle: Radial = Radial(0.0)  # heading; rotates body-frame vectors into the world frame
    _solver: "OmniSolver" = dataclasses.field(default_factory=OmniSolver, init=False, repr=False)
    _drawer: "OmniDrawer" = dataclasses.field(default_factory=OmniDrawer, init=False, repr=False)

    def step(self, dt: Seconds) -> None:
        """Advance position and heading by one explicit-Euler step of length *dt*.

        Both velocities are evaluated at the pose at the start of the step.
        """
        # Solve the wheel kinematics once. Reading the ``velocity`` and
        # ``angular_velocity`` properties would each recompute the same
        # pseudo-inverse from identical inputs.
        velocity, angular_velocity = self._solver.compute_body_velocities(self)
        self.position = self.position + Coordinate2D(dt * velocity)
        self.angle = Radial(self.angle + angular_velocity * dt)

    @property
    def velocity(self) -> Velocity2D:
        """World-frame linear velocity implied by the current wheel velocities."""
        velocity, _ = self._solver.compute_body_velocities(self)
        return velocity

    @property
    def angular_velocity(self) -> AngularVelocity:
        """Body angular velocity implied by the current wheel velocities."""
        _, angular_velocity = self._solver.compute_body_velocities(self)
        return angular_velocity

    def draw(self, screen: pygame.Surface) -> None:
        """Render the car onto *screen* via its drawer."""
        self._drawer.draw(self, on=screen)
def pygame_draw_arrow(
    screen: pygame.Surface,
    color: tuple[int, int, int],
    start: Vector2D,
    end: Vector2D,
    arrowhead_length: float = 8.0,
    arrowhead_angle: Radial = Radial(math.pi / 6),
) -> None:
    """Draw an anti-aliased straight arrow from *start* to *end*.

    Args:
        screen: Surface to draw on.
        color: RGB colour of the shaft and the head.
        start: Tail of the arrow, in surface coordinates.
        end: Tip of the arrow, in surface coordinates.
        arrowhead_length: Length of each of the two head barbs.
        arrowhead_angle: Angle between each barb and the shaft, in radians.

    When *start* and *end* coincide, only the (degenerate) shaft is drawn,
    because the head direction is undefined.
    """

    def pixel(point: Vector2D) -> tuple[int, int]:
        # Coordinates are truncated to integers before drawing.
        return int(point.x), int(point.y)

    # pygame.draw.aaline has no width parameter. Its optional fifth positional
    # argument is the deprecated ``blend`` flag (default 1), so no extra
    # argument is passed and the lines are 1 px anti-aliased lines.
    pygame.draw.aaline(screen, color, pixel(start), pixel(end))

    shaft = end - start
    length = abs(shaft)
    if length == 0:
        return
    unit = shaft / length

    cos_a = math.cos(arrowhead_angle)
    sin_a = math.sin(arrowhead_angle)
    # Shaft direction rotated by +arrowhead_angle and -arrowhead_angle; each
    # barb runs backwards from the tip along one of these directions.
    left = Vector2D(x=unit.x * cos_a - unit.y * sin_a, y=unit.x * sin_a + unit.y * cos_a)
    right = Vector2D(x=unit.x * cos_a + unit.y * sin_a, y=-unit.x * sin_a + unit.y * cos_a)
    for barb in (left, right):
        pygame.draw.aaline(screen, color, pixel(end), pixel(end - barb * arrowhead_length))
def pygame_draw_rot_arrow(
    screen: pygame.Surface,
    color: tuple[int, int, int],
    center: Vector2D,
    radius: float,
    angle: Radial,
    arrowhead_length: float = 8.0,
    arrowhead_angle: Radial = Radial(math.pi / 6),
) -> None:
    """Draw a circular arc arrow around *center* sweeping *angle* radians.

    The arc runs from the +x direction to the point at screen angle *angle*
    (the point ``center + radius * (cos angle, sin angle)``). In screen
    coordinates, where y grows downward, a positive angle sweeps clockwise and
    a negative one counter-clockwise. An arrowhead at the end of the arc points
    in the sweep direction. When *angle* is 0, no arc is drawn.

    Args:
        screen: Surface to draw on.
        color: RGB colour of the arc and the head.
        center: Centre of the arc.
        radius: Radius of the arc.
        angle: Signed sweep angle in radians.
        arrowhead_length: Length of each head barb.
        arrowhead_angle: Angle between each barb and the arc tangent, in radians.
    """
    arc_rect = (center.x - radius, center.y - radius, radius * 2, radius * 2)
    end = radius * Vector2D(x=math.cos(angle), y=math.sin(angle)) + center

    # Tangent at ``end`` pointing along the sweep. d/dθ (cos θ, sin θ) points
    # toward increasing θ, so it is flipped for negative sweeps; otherwise the
    # head would point back into its own arc.
    sweep = 1.0 if angle >= 0 else -1.0
    direction = Vector2D(x=-math.sin(angle) * sweep, y=math.cos(angle) * sweep)
    cos_a = math.cos(arrowhead_angle)
    sin_a = math.sin(arrowhead_angle)
    left = Vector2D(
        x=direction.x * cos_a - direction.y * sin_a,
        y=direction.x * sin_a + direction.y * cos_a,
    )
    right = Vector2D(
        x=direction.x * cos_a + direction.y * sin_a,
        y=-direction.x * sin_a + direction.y * cos_a,
    )
    left_end = end - left * arrowhead_length
    right_end = end - right * arrowhead_length

    # pygame.draw.arc measures angles counter-clockwise with y pointing up, so
    # screen angle θ corresponds to pygame angle -θ. The arc is drawn
    # counter-clockwise from start_angle to stop_angle, and when start_angle >
    # stop_angle pygame adds 2π to stop_angle. Keeping start <= stop selects
    # the short arc between 0 and θ for either sign. Passing (2π - θ, 2π) for a
    # negative θ would instead draw the complementary long arc.
    if angle >= 0:
        start_angle, stop_angle = 2 * math.pi - angle, 2 * math.pi
    else:
        start_angle, stop_angle = 0.0, -angle
    pygame.draw.arc(
        screen,
        color,
        start_angle=start_angle,
        stop_angle=stop_angle,
        rect=arc_rect,
    )

    # aaline's optional fifth argument is the deprecated ``blend`` flag (not a
    # width), so it is omitted; the default blend=1 is what was passed before.
    tip = (int(end.x), int(end.y))
    for barb_end in (left_end, right_end):
        pygame.draw.aaline(screen, color, tip, (int(barb_end.x), int(barb_end.y)))
def main():
    """Run the interactive omniwheel simulation until the window is closed.

    Opens an 800x600 window titled "Omniwheel Robot Simulation" with a
    three-wheel car. It shows one slider for the integration step ``dt`` and
    one slider per wheel for that wheel's angular velocity. The pygame
    subsystems are shut down when the loop exits, including on an exception.
    """
    pygame.init()
    try:
        screen_width, screen_height = 800, 600
        screen = pygame.display.set_mode((screen_width, screen_height))
        pygame.display.set_caption("Omniwheel Robot Simulation")
        clock = pygame.time.Clock()
        car = Car(
            wheels=[
                Wheel(angle=Radial(0), radius=Coordinate(10), velocity=AngularVelocity(10)),
                Wheel(angle=Radial(2 * math.pi / 3), radius=Coordinate(10), velocity=AngularVelocity(10)),
                Wheel(angle=Radial(4 * math.pi / 3), radius=Coordinate(10), velocity=AngularVelocity(10)),
            ],
            radius=Coordinate(30),
            position=Coordinate2D(Vector2D(x=400, y=300)),
        )
        font = pygame.font.Font(None, 18)
        text_color = (255, 255, 255)

        dt_slider = Slider(
            screen,
            x=660,
            y=500,
            width=100,
            height=8,
            min=0.001,
            max=0.03,
            step=0.001,
            initial=0.01,
        )

        # Wheel sliders span [0, 40]; subtracting this offset lets the user
        # command wheel velocities in [-20, 20].
        velocity_offset = 20
        first_wheel_row_y = 520
        row_spacing = 20
        wheel_sliders = [
            Slider(
                screen,
                x=660,
                y=first_wheel_row_y + row_spacing * index,
                width=100,
                height=8,
                min=0,
                max=40,
                step=0.1,
                initial=wheel.velocity + velocity_offset,
            )
            for index, wheel in enumerate(car.wheels)
        ]

        while True:
            events = pygame.event.get()
            if any(event.type == pygame.QUIT for event in events):
                return  # pygame.quit() runs in the ``finally`` below

            screen.fill((0, 0, 0))

            dt = Seconds(dt_slider.getValue())
            dt_label = font.render(f"dt: {dt:.3f}", True, text_color, None)
            screen.blit(dt_label, (590, 500))

            for index, (wheel, slider) in enumerate(zip(car.wheels, wheel_sliders)):
                wheel_velocity = AngularVelocity(slider.getValue() - velocity_offset)
                label = font.render(
                    f"Wheel {index + 1} Velocity: {wheel_velocity:.1f}", True, text_color, None
                )
                screen.blit(label, (510, first_wheel_row_y + row_spacing * index))
                wheel.velocity = wheel_velocity

            car.step(dt)
            car.draw(screen)
            pygame_widgets.update(events)
            pygame.display.update()
            clock.tick(60)
    finally:
        pygame.quit()
if __name__ == "__main__":
    # Start the interactive simulation only when executed as a script,
    # never as a side effect of importing this module.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment