Last active
November 22, 2020 12:40
-
-
Save dermesser/1ca570d500092e3ccf45a68219ea2029 to your computer and use it in GitHub Desktop.
This event-based simulation simulates a 1-dimensional gas: https://borgac.net/~lbo/8particles_1000.svg
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
""" | |
Created on Sat Nov 21 20:04:54 2020 | |
@author: lbo | |
""" | |
from recordtype import recordtype | |
from queue import PriorityQueue | |
import time | |
import numpy as np | |
import matplotlib.pyplot as plt | |
# Simulation parameters: `len` is the length of the segment the balls move in,
# `num` the number of balls. `positions` may be None (balls are then spaced
# evenly); `masses`, `v0` and `rad` may each be a single number shared by all
# balls or one value per ball.
Config = recordtype('Config', ['len', 'positions', 'num', 'masses', 'v0', 'rad'])
# Mutable simulation state: per-ball positions, speeds and collision counters,
# plus the current simulation time.
State = recordtype('State', ['pos', 'spd', 'ncoll', 'time'])
class Calculate:
    """Event-driven simulation of balls moving along the segment [0, cfg.len].

    Only neighbouring balls (indices differing by one) can collide. In event
    tuples and in timetohit()/bounce(), index -1 stands for the left wall at
    position 0 and index len(state.pos) for the right wall at cfg.len.
    Events are stored in a priority queue as
    (absolute time, left index, right index, summed collision count); the
    collision count lets stale events be recognised and skipped.
    """

    # Class-level defaults so instances also work when __init__ is bypassed.
    verbose = True   # print debug traces (the historical behaviour)
    _primed = False  # becomes True once predict_all() has seeded the queue

    def __init__(self, cfg, verbose=True):
        """Build the initial simulation state from *cfg*.

        Scalar ``cfg.v0``, ``cfg.masses`` and ``cfg.rad`` are repeated for all
        ``cfg.num`` balls; ``cfg.positions=None`` spaces the balls evenly over
        the segment. ``cfg.rad`` and ``cfg.masses`` are overwritten with the
        resulting per-ball lists. Positions and speeds are copied, so running
        the simulation never modifies the lists stored in *cfg*.

        :param cfg: a Config record.
        :param verbose: if False, suppress the debug output.
        :raises AssertionError: if a per-ball list does not have cfg.num entries.
        """
        self.verbose = verbose
        num = cfg.num
        if cfg.positions is None and num is not None:
            pos = [(2 * i + 1) * float(cfg.len) / (num * 2) for i in range(num)]
        else:
            assert len(cfg.positions) == num
            # Copy: update() mutates positions in place.
            pos = list(cfg.positions)
        # Copy: bounce() mutates speeds in place.
        v0 = list(self._broadcast(cfg.v0, num))
        masses = self._broadcast(cfg.masses, num)
        rad = self._broadcast(cfg.rad, num)
        # State and event queue. Event queue contains
        # (time, ball1, ball2, ncoll) tuples.
        self.state = State(pos=pos, spd=v0, time=0., ncoll=[0] * num)
        self.events = PriorityQueue(0)
        cfg.rad = rad
        cfg.masses = masses
        self.cfg = cfg
        assert len(self.state.pos) == len(self.state.spd)

    @staticmethod
    def _broadcast(value, num):
        """Return scalar *value* repeated *num* times; sequences are length-checked and returned as-is."""
        if isinstance(value, (int, float)) and num is not None:
            return [value] * num
        assert len(value) == num
        return value

    def _log(self, msg):
        """Print *msg* when verbose output is enabled."""
        if self.verbose:
            print(msg)

    def timetohit(self, ballix1, ballix2):
        """Return the time until objects *ballix1* and *ballix2* touch, or None.

        Indices -1 and len(state.pos) denote the left and right wall. Returns
        None for identical or non-adjacent indices and for pairs that are not
        approaching each other. An approaching pair that already overlaps
        (e.g. through floating-point drift) yields 0.0, so that scheduled
        events never lie in the past.
        """
        self._log('calculating {} {}'.format(ballix1, ballix2))
        if ballix1 == ballix2:
            return None
        if abs(ballix1 - ballix2) > 1:
            return None
        if ballix1 > ballix2:
            ballix1, ballix2 = ballix2, ballix1
        # Defaults describe the walls: fixed, zero radius.
        pos1, pos2 = 0, self.cfg.len
        spd1, spd2 = 0, 0
        rad1, rad2 = 0, 0
        if ballix1 >= 0:
            pos1 = self.state.pos[ballix1]
            spd1 = self.state.spd[ballix1]
            rad1 = self.cfg.rad[ballix1]
        if ballix2 < len(self.state.pos):
            pos2 = self.state.pos[ballix2]
            spd2 = self.state.spd[ballix2]
            rad2 = self.cfg.rad[ballix2]
        reldist = pos2 - pos1 - rad1 - rad2
        relspeed = spd2 - spd1
        if relspeed >= 0:
            # Moving apart or keeping their distance: never collide.
            return None
        tth = -reldist / relspeed
        if tth > 0:
            self._log('TIMETOHIT: {} vs {} at {}, rad {} vs {}'.format(
                ballix1, ballix2, self.state.time + tth, rad1, rad2))
        # Negative means "already overlapping": collide immediately instead
        # of scheduling an event in the past.
        return max(tth, 0.0)

    def bounce(self, ballix1, ballix2):
        """Resolve an elastic collision between objects *ballix1* and *ballix2*.

        A wall simply reverses the ball's speed; two balls exchange momentum
        (v' = 2 * v_cm - v). Collision counters of the involved balls are
        incremented, which invalidates their previously scheduled events.
        """
        if ballix1 > ballix2:
            ballix1, ballix2 = ballix2, ballix1
        # Special case: wall.
        if ballix1 < 0:
            self.state.spd[ballix2] *= -1
            self.state.ncoll[ballix2] += 1
            self._log('BOUNCE: {} vs left wall'.format(ballix2))
            return
        if ballix2 >= len(self.state.pos):
            self.state.spd[ballix1] *= -1
            self.state.ncoll[ballix1] += 1
            self._log('BOUNCE: {} vs right wall'.format(ballix1))
            return
        v1, v2 = self.state.spd[ballix1], self.state.spd[ballix2]
        m1, m2 = self.cfg.masses[ballix1], self.cfg.masses[ballix2]
        ctr = (m1 * v1 + m2 * v2) / (m1 + m2)  # centre-of-mass velocity
        newv1 = 2 * ctr - v1
        newv2 = 2 * ctr - v2
        self._log('BOUNCE: old {} vs {}; new {} vs {}'.format(v1, v2, newv1, newv2))
        self.state.spd[ballix1] = newv1
        self.state.spd[ballix2] = newv2
        self.state.ncoll[ballix1] += 1
        self.state.ncoll[ballix2] += 1

    def predict(self, ballix):
        """Schedule ball *ballix*'s next hit with its left neighbour (or the
        left wall), and with the right wall if it is the last ball.

        Out-of-range indices are ignored.
        """
        nballs = len(self.state.pos)
        if not 0 <= ballix < nballs:
            return
        tth = self.timetohit(ballix, ballix - 1)
        if tth is not None:
            ncoll = self.state.ncoll[ballix]
            if ballix > 0:
                ncoll += self.state.ncoll[ballix - 1]
            self.events.put_nowait((self.state.time + tth, ballix - 1, ballix, ncoll))
        if ballix == nballs - 1:
            tth = self.timetohit(ballix, nballs)
            if tth is not None:
                self.events.put_nowait(
                    (self.state.time + tth, ballix, nballs, self.state.ncoll[ballix]))

    def predict_all(self):
        """Schedule the next collision of every adjacent pair and wall."""
        for i in range(len(self.state.pos)):
            self.predict(i)

    def update(self, ballix, dt):
        """Advance ball *ballix* by *dt* at its current speed.

        :raises Exception: if the ball ends up more than 2 % of cfg.len
            outside the segment (sanity check).
        """
        self.state.pos[ballix] += dt * self.state.spd[ballix]
        if not (self.state.pos[ballix] > -0.02 * self.cfg.len and self.state.pos[ballix] < 1.02 * self.cfg.len):
            raise Exception('ERROR: ball {} out of bounds, t = {} at {} speed {}'.format(ballix, self.state.time, self.state.pos[ballix], self.state.spd[ballix]))

    def update_all(self, dt):
        """Advance every ball by *dt*."""
        for i in range(len(self.state.pos)):
            self.update(i, dt)

    def one_step(self):
        """Advance to the next valid collision and resolve it.

        :returns: True if a collision was processed, False if none is pending.
        """
        if not self._primed:
            # Seed the queue once; afterwards only affected pairs are
            # re-predicted, keeping the queue from growing every step.
            self.predict_all()
            self._primed = True
        nballs = len(self.state.pos)
        while not self.events.empty():
            when, ballix1, ballix2, ncoll = self.events.get_nowait()
            coll1 = self.state.ncoll[ballix1] if ballix1 >= 0 else 0
            coll2 = self.state.ncoll[ballix2] if ballix2 < nballs else 0
            # Counters only grow, so an unchanged sum means neither object
            # has collided since this event was predicted.
            if coll1 + coll2 != ncoll:
                continue
            self._log('EVENT: [{}] this event at {} (now+{}): {} vs {}'.format(
                self.state.time, when, when - self.state.time, ballix1, ballix2))
            self.update_all(when - self.state.time)
            self.bounce(ballix1, ballix2)
            self.state.time = when
            # Only pairs involving ballix1 or ballix2 changed: (ballix1-1,
            # ballix1), (ballix1, ballix2), (ballix2, ballix2+1) and walls.
            # Every other scheduled event is still valid.
            for ix in (ballix1, ballix2, ballix2 + 1):
                self.predict(ix)
            return True
        return False

    def run(self, steps=1000):
        """Run up to *steps* collisions, sampling the state before each one.

        :returns: (times, positions, speeds) NumPy arrays trimmed to the number
            of samples taken; positions and speeds have one column per ball.
        """
        times = np.zeros(steps)
        speeds = np.zeros((steps, self.cfg.num))
        positions = np.zeros((steps, self.cfg.num))
        taken = 0
        while taken < steps:
            speeds[taken, :] = self.state.spd
            positions[taken, :] = self.state.pos
            times[taken] = self.state.time
            taken += 1
            if not self.one_step():
                break
        return times[:taken], positions[:taken], speeds[:taken]
def plot_result(t, x, v, name=None):
    """Save a plot of every ball's position (horizontal) against time (vertical).

    :param t: sample times, as returned by Calculate.run().
    :param x: positions, one column per ball.
    :param v: speeds; not used by this plot.
    :param name: output file; defaults to '<unix timestamp>.svg'.
    """
    figure = plt.figure(figsize=(10, 20))
    plt.tight_layout()
    axes = figure.add_subplot(111)
    axes.plot(x, t)
    if not name:
        name = '{}.svg'.format(time.time())
    figure.savefig(name, bbox_inches='tight')
def plot_phasespace(t, x, v, name=None):
    """Save a phase-space plot: speed against position, one labelled curve per ball.

    :param t: sample times; not used by this plot.
    :param x: positions, one column per ball.
    :param v: speeds, one column per ball.
    :param name: output file; defaults to '<unix timestamp>_phasespace.svg'.
    """
    figure = plt.figure(figsize=(10, 10))
    plt.tight_layout()
    axes = figure.add_subplot(111)
    for ball, track in enumerate(x.T):
        axes.plot(track, v[:, ball], label=str(ball))
    axes.legend()
    if not name:
        name = '{}_phasespace.svg'.format(time.time())
    figure.savefig(name, bbox_inches='tight')
# Example configurations.
manydifferentballs = Config(
    len=10,
    num=8,
    positions=[1, 1.5, 2.1, 3.2, 3.9, 4.5, 6, 9],
    rad=0.2,
    masses=1,
    v0=[0.5, -2/3, 1/4, 0.5, 1/7, 1.1, 4/5, 1/4],
)
threeballs = Config(
    len=10,
    num=3,
    positions=[2, 6, 9],
    rad=0.2,
    masses=1,
    v0=[-1, 1, -1],
)
tenballs = Config(
    len=10,
    num=10,
    positions=None,   # Calculate spaces the balls evenly
    rad=0.1,
    masses=1,
    v0=[-1, 1] * 5,   # alternating directions
)
# Simulation built from the ten-ball configuration.
sim = Calculate(tenballs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment