Skip to content

Instantly share code, notes, and snippets.

@altescy
Last active September 25, 2025 03:22
Show Gist options
  • Save altescy/15df45e5739894cc4cd9ae05cea78bb3 to your computer and use it in GitHub Desktop.
Save altescy/15df45e5739894cc4cd9ae05cea78bb3 to your computer and use it in GitHub Desktop.
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "numpy>=2.3.3",
# "pygame>=2.6.1",
# "pygame-widgets>=1.2.2",
# ]
# ///
import dataclasses
import math
from collections.abc import Sequence
from typing import NewType, Self
import numpy
import pygame
import pygame_widgets
from pygame_widgets.slider import Slider
@dataclasses.dataclass
class Vector2D:
x: float
y: float
def __add__(self, other: Self) -> Self:
return dataclasses.replace(self, x=self.x + other.x, y=self.y + other.y)
def __sub__(self, other: Self) -> Self:
return dataclasses.replace(self, x=self.x - other.x, y=self.y - other.y)
def __mul__(self, scalar: float) -> Self:
return dataclasses.replace(self, x=self.x * scalar, y=self.y * scalar)
def __rmul__(self, scalar: float) -> Self:
return dataclasses.replace(self, x=self.x * scalar, y=self.y * scalar)
def __truediv__(self, scalar: float) -> Self:
return dataclasses.replace(self, x=self.x / scalar, y=self.y / scalar)
def __rtruediv__(self, scalar: float) -> Self:
return dataclasses.replace(self, x=scalar / self.x, y=scalar / self.y)
def __floordiv__(self, scalar: float) -> Self:
return dataclasses.replace(self, x=self.x // scalar, y=self.y // scalar)
def __rfloordiv__(self, scalar: float) -> Self:
return dataclasses.replace(self, x=scalar // self.x, y=scalar // self.y)
def __neg__(self) -> Self:
return dataclasses.replace(self, x=-self.x, y=-self.y)
def __pos__(self) -> Self:
return dataclasses.replace(self, x=+self.x, y=+self.y)
def __abs__(self) -> float:
return (self.x**2 + self.y**2) ** 0.5
def __pow__(self, power: float) -> Self:
return dataclasses.replace(self, x=self.x**power, y=self.y**power)
def __rpow__(self, power: float) -> Self:
return dataclasses.replace(self, x=power**self.x, y=power**self.y)
def rotate(self, angle: float, center: "Vector2D | None" = None) -> Self:
if center is None:
center = Vector2D(0, 0)
cos_angle = math.cos(angle)
sin_angle = math.sin(angle)
translated_x = self.x - center.x
translated_y = self.y - center.y
rotated_x = translated_x * cos_angle - translated_y * sin_angle
rotated_y = translated_x * sin_angle + translated_y * cos_angle
return self.__class__(rotated_x + center.x, rotated_y + center.y)
# Semantic aliases used throughout the simulation. ``typing.NewType`` performs
# no runtime conversion or validation: calling an alias returns its argument
# unchanged, so these only document intent for readers and type checkers.

# Scalar quantities.
Seconds = NewType("Seconds", float)
Radial = NewType("Radial", float)  # an angle in radians (fed to math.sin/math.cos)
Coordinate = NewType("Coordinate", float)
AngularVelocity = NewType("AngularVelocity", float)
AngularAcceleration = NewType("AngularAcceleration", float)
Velocity = NewType("Velocity", float)
Acceleration = NewType("Acceleration", float)

# Planar (vector-valued) quantities.
Coordinate2D = NewType("Coordinate2D", Vector2D)
Velocity2D = NewType("Velocity2D", Vector2D)
Acceleration2D = NewType("Acceleration2D", Vector2D)
class OmniSolver:
    """Forward kinematics for an omni-wheeled body.

    Each wheel contributes one row ``[-sin(a) / r, cos(a) / r, R / r]`` to a
    matrix mapping the body state ``(vx, vy, omega)`` onto wheel angular
    velocities, where ``a`` is the wheel's mounting angle, ``r`` its radius and
    ``R`` the body radius. The Moore-Penrose pseudo-inverse of that matrix
    recovers the body state from the wheel speeds.
    """

    def _wheel_angular_velocity_transform_matrix(
        self,
        wheels: Sequence["Wheel"],
        radius: Coordinate,
    ) -> numpy.ndarray:
        """Build the matrix (one row per wheel) mapping body state to wheel angular velocities."""
        rows = []
        for wheel in wheels:
            mount_angle, wheel_radius = wheel.angle, wheel.radius
            rows.append(
                [
                    -math.sin(mount_angle) / wheel_radius,
                    math.cos(mount_angle) / wheel_radius,
                    radius / wheel_radius,
                ]
            )
        return numpy.array(rows)

    def _wheel_velocity_transform_matrix(
        self,
        wheels: Sequence["Wheel"],
        radius: Coordinate,
    ) -> numpy.ndarray:
        """Return the pseudo-inverse mapping wheel angular velocities back to body state."""
        forward = self._wheel_angular_velocity_transform_matrix(wheels, radius)
        return numpy.linalg.pinv(forward)

    def compute_body_velocities(self, car: "Car") -> tuple[Velocity2D, AngularVelocity]:
        """Solve for ``car``'s linear and angular velocity from its wheel speeds.

        The linear velocity is solved in the body frame and then rotated by
        ``car.angle`` before being returned.
        """
        wheel_speeds = numpy.array([w.velocity for w in car.wheels])
        inverse = self._wheel_velocity_transform_matrix(car.wheels, car.radius)
        vx, vy, spin = inverse @ wheel_speeds
        body_frame_velocity = Velocity2D(Vector2D(x=vx, y=vy))
        return body_frame_velocity.rotate(car.angle), AngularVelocity(spin)
class OmniDrawer:
    """Renders a car and its motion indicators onto a pygame surface."""

    @staticmethod
    def draw(car: "Car", *, on: pygame.Surface) -> None:
        """Draw ``car`` on the surface ``on``.

        Layers drawn, in order:
        - the body outline (green circle);
        - the linear velocity (blue arrow from the centre);
        - the angular velocity (red circular arrow, sweep = value in radians);
        - each wheel's rim velocity (yellow arrow at the wheel's mount point).
        """
        hub_x, hub_y = car.position.x, car.position.y
        hub = Vector2D(x=hub_x, y=hub_y)

        # Body outline, 2 px wide.
        pygame.draw.circle(on, (0, 255, 0), (int(hub_x), int(hub_y)), int(car.radius), 2)

        # Linear velocity, drawn from the body centre.
        linear = car.velocity
        pygame_draw_arrow(
            on,
            color=(100, 100, 255),
            start=hub,
            end=Vector2D(x=hub_x + linear.x, y=hub_y + linear.y),
        )

        # Angular velocity, drawn as a small arc around the centre.
        pygame_draw_rot_arrow(
            on,
            color=(255, 100, 100),
            center=hub,
            radius=8,
            angle=Radial(car.angular_velocity),
            arrowhead_length=4,
        )

        # Wheel rim velocities: tangent to the rim at each mount point, then
        # rotated into the world frame by the body heading.
        for wheel in car.wheels:
            cos_a = math.cos(wheel.angle)
            sin_a = math.sin(wheel.angle)
            rim_speed = wheel.velocity * wheel.radius
            mount = Vector2D(
                x=hub_x + car.radius * cos_a,
                y=hub_y + car.radius * sin_a,
            ).rotate(car.angle, center=hub)
            tangent = Vector2D(x=-rim_speed * sin_a, y=rim_speed * cos_a).rotate(car.angle)
            pygame_draw_arrow(on, color=(255, 255, 0), start=mount, end=mount + tangent)
@dataclasses.dataclass
class Wheel:
    """One omni wheel of a car."""

    # Mounting angle around the body centre, in the body frame (radians).
    angle: Radial
    # Wheel radius; angular velocity times radius gives the rim speed.
    radius: Coordinate
    # Angular velocity of the wheel.
    velocity: AngularVelocity


@dataclasses.dataclass
class Car:
    """An omni-wheeled body whose motion is derived entirely from its wheel speeds.

    ``velocity`` and ``angular_velocity`` are computed on demand by an
    ``OmniSolver`` from the current wheel speeds and heading; ``step``
    integrates them into ``position`` and ``angle``.
    """

    # Wheels mounted around the body.
    wheels: Sequence[Wheel]
    # Distance from the body centre to the wheels.
    radius: Coordinate
    # Position of the body centre in world (screen) coordinates.
    position: Coordinate2D
    # Heading in radians; rotates body-frame velocity into the world frame.
    angle: Radial = Radial(0.0)
    _solver: "OmniSolver" = dataclasses.field(default_factory=OmniSolver, init=False, repr=False)
    _drawer: "OmniDrawer" = dataclasses.field(default_factory=OmniDrawer, init=False, repr=False)

    def step(self, dt: Seconds) -> None:
        """Advance the pose by ``dt`` seconds with an explicit (forward) Euler step.

        Both velocities are evaluated at the pose at the start of the step.
        They come from a single kinematics solve: reading the ``velocity`` and
        ``angular_velocity`` properties separately would compute the same
        pseudo-inverse twice.
        """
        velocity, angular_velocity = self._solver.compute_body_velocities(self)
        self.position = self.position + Coordinate2D(dt * velocity)
        self.angle = Radial(self.angle + angular_velocity * dt)

    @property
    def velocity(self) -> Velocity2D:
        """Linear velocity of the body centre in the world frame."""
        velocity, _ = self._solver.compute_body_velocities(self)
        return velocity

    @property
    def angular_velocity(self) -> AngularVelocity:
        """Angular velocity of the body, in radians per second."""
        _, angular_velocity = self._solver.compute_body_velocities(self)
        return angular_velocity

    def draw(self, screen: pygame.Surface) -> None:
        """Render the car onto ``screen``."""
        self._drawer.draw(self, on=screen)
def pygame_draw_arrow(
    screen: pygame.Surface,
    color: tuple[int, int, int],
    start: Vector2D,
    end: Vector2D,
    arrowhead_length: float = 8.0,
    arrowhead_angle: Radial = Radial(math.pi / 6),
) -> None:
    """Draw an anti-aliased straight arrow from ``start`` to ``end``.

    Args:
        screen: Surface to draw on.
        color: RGB colour of the shaft and the arrowhead.
        start: Tail of the arrow, in screen coordinates.
        end: Tip of the arrow, in screen coordinates.
        arrowhead_length: Length of each arrowhead barb, in pixels.
        arrowhead_angle: Angle between each barb and the shaft, in radians.

    A zero-length arrow has no direction, so only the (degenerate) shaft is
    drawn and the arrowhead is skipped.
    """
    # pygame.draw.aaline has no width parameter: its fifth positional argument
    # is the deprecated ``blend`` flag, so none is passed. Float endpoints are
    # kept (not truncated to int) so anti-aliasing can use sub-pixel positions.
    pygame.draw.aaline(screen, color, (start.x, start.y), (end.x, end.y))

    shaft = end - start
    length = abs(shaft)
    if length == 0:
        return
    direction = shaft / length

    # Each barb is the shaft direction rotated by +/- arrowhead_angle, drawn
    # backwards from the tip.
    for barb in (direction.rotate(arrowhead_angle), direction.rotate(-arrowhead_angle)):
        barb_end = end - barb * arrowhead_length
        pygame.draw.aaline(screen, color, (end.x, end.y), (barb_end.x, barb_end.y))
def pygame_draw_rot_arrow(
    screen: pygame.Surface,
    color: tuple[int, int, int],
    center: Vector2D,
    radius: float,
    angle: Radial,
    arrowhead_length: float = 8.0,
    arrowhead_angle: Radial = Radial(math.pi / 6),
) -> None:
    """Draw a circular arrow sweeping ``angle`` radians around ``center``.

    The arc starts on the positive x axis and ends at screen angle ``angle``
    (screen y points down, so positive values sweep clockwise on screen and
    negative values counter-clockwise). The arrowhead sits at the end of the
    sweep and points along the direction of travel.

    Args:
        screen: Surface to draw on.
        color: RGB colour of the arc and arrowhead.
        center: Centre of the arc, in screen coordinates.
        radius: Radius of the arc, in pixels.
        angle: Signed sweep of the arc, in radians.
        arrowhead_length: Length of each arrowhead barb, in pixels.
        arrowhead_angle: Angle between each barb and the arc tangent, in radians.
    """
    arc_rect = (center.x - radius, center.y - radius, radius * 2, radius * 2)
    end = radius * Vector2D(x=math.cos(angle), y=math.sin(angle)) + center

    # Unit tangent at ``end`` in the direction of travel. For a negative sweep
    # the arc is traversed towards decreasing angle, so the tangent is flipped;
    # otherwise the arrowhead would point backwards.
    sweep_sign = 1.0 if angle >= 0 else -1.0
    direction = Vector2D(x=-math.sin(angle), y=math.cos(angle)) * sweep_sign

    # pygame.draw.arc measures angles counter-clockwise on screen, i.e. the
    # negation of our screen angle, and draws counter-clockwise from
    # start_angle to stop_angle; if start_angle > stop_angle it adds 2*pi to
    # stop_angle, which would draw the complementary arc for a negative sweep.
    # Ordering the bounds keeps positive sweeps exactly as (2*pi - angle, 2*pi)
    # and makes negative sweeps draw (2*pi, 2*pi + |angle|).
    arc_end = 2 * math.pi - angle
    pygame.draw.arc(
        screen,
        color,
        start_angle=min(arc_end, 2 * math.pi),
        stop_angle=max(arc_end, 2 * math.pi),
        rect=arc_rect,
    )

    # aaline has no width parameter (its 5th positional arg is the deprecated
    # ``blend`` flag), so none is passed.
    for barb in (direction.rotate(arrowhead_angle), direction.rotate(-arrowhead_angle)):
        barb_end = end - barb * arrowhead_length
        pygame.draw.aaline(screen, color, (end.x, end.y), (barb_end.x, barb_end.y))
def main():
    """Run the interactive omni-wheel simulation until the window is closed.

    Sliders in the lower-right corner control the integration step ``dt`` and
    the angular velocity of each wheel; the car is stepped and redrawn every
    frame at up to 60 FPS.
    """
    # The wheel sliders span [0, 40], so wheel velocities are shown shifted by
    # this offset: slider value = wheel velocity + offset (range [-20, 20]).
    velocity_offset = 20
    white = (255, 255, 255)

    pygame.init()
    screen_width, screen_height = 800, 600
    screen = pygame.display.set_mode((screen_width, screen_height))
    pygame.display.set_caption("Omniwheel Robot Simulation")
    clock = pygame.time.Clock()

    # Three wheels spaced 120 degrees apart.
    car = Car(
        wheels=[
            Wheel(angle=Radial(0), radius=Coordinate(10), velocity=AngularVelocity(10)),
            Wheel(angle=Radial(2 * math.pi / 3), radius=Coordinate(10), velocity=AngularVelocity(10)),
            Wheel(angle=Radial(4 * math.pi / 3), radius=Coordinate(10), velocity=AngularVelocity(10)),
        ],
        radius=Coordinate(30),
        position=Coordinate2D(Vector2D(x=400, y=300)),
    )
    font = pygame.font.Font(None, 18)

    # Widgets are created in the same order as before: dt first, then wheels.
    dt_slider = Slider(
        screen,
        x=660,
        y=500,
        width=100,
        height=8,
        min=0.001,
        max=0.03,
        step=0.001,
        initial=0.01,
    )
    # One slider per wheel, stacked 20 px apart below the dt slider.
    wheel_sliders = [
        Slider(
            screen,
            x=660,
            y=520 + 20 * index,
            width=100,
            height=8,
            min=0,
            max=40,
            step=0.1,
            initial=wheel.velocity + velocity_offset,
        )
        for index, wheel in enumerate(car.wheels)
    ]

    while True:
        events = pygame.event.get()
        for event in events:
            if event.type == pygame.QUIT:
                pygame.quit()
                return
        screen.fill((0, 0, 0))

        dt = Seconds(dt_slider.getValue())
        dt_label = font.render(f"dt: {dt:.3f}", True, white, None)
        screen.blit(dt_label, (590, 500))

        # Read each wheel's slider, label it, and apply the velocity.
        for number, (wheel, slider) in enumerate(zip(car.wheels, wheel_sliders), start=1):
            wheel_velocity = AngularVelocity(slider.getValue() - velocity_offset)
            label = font.render(f"Wheel {number} Velocity: {wheel_velocity:.1f}", True, white, None)
            screen.blit(label, (510, 500 + 20 * number))
            wheel.velocity = wheel_velocity

        car.step(dt)
        car.draw(screen)
        pygame_widgets.update(events)
        pygame.display.update()
        clock.tick(60)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment