# Source: GitHub gist marl0ny/5059782ce8e9bace8971e8fab2a54279
# (last active February 13, 2024). The gist contains two files: the
# simulation script below and the `visualization` module it imports.
"""
This script numerically solves the linear and nonlinear Schrodinger equation
using the split-operator method, and then renders a matplotlib animation of
the results.

References:

Split-operator method:
    James Schloss. The Split Operator Method - Arcane Algorithm Archive.
    https://www.algorithm-archive.org/contents/split-operator_method/split-operator_method.html

Nonlinear Schrodinger equation:
    Xavier Antoine, Weizhu Bao, Christophe Besse.
    Computational methods for the dynamics of the nonlinear
    Schrodinger/Gross-Pitaevskii equations.
    https://arxiv.org/abs/1305.1093
"""
import numpy as np | |
# from scipy.fft import dstn, idstn | |
from visualization import animate | |
NX, NY = 1024, 1024  # Grid dimensions (number of sample points along x, y)
LX, LY = 1024.0, 1024.0  # Spatial dimensions (side lengths of the box)

# Periodic grid: NX (NY) evenly spaced points covering [0, LX) ([0, LY)),
# with the right endpoint excluded because it coincides with x = 0.
_x_axis = np.linspace(0.0, LX*(1.0 - 1.0/NX), NX)
_y_axis = np.linspace(0.0, LY*(1.0 - 1.0/NY), NY)
X, Y = np.meshgrid(_x_axis, _y_axis)  # both shaped (NY, NX)
DX, DY = X[0, 1] - X[0, 0], Y[1, 0] - Y[0, 0]

# Energies for a free particle with periodic boundary conditions.
# NX*fftfreq(NX) yields the integer mode numbers k, so pi*k/LX is half of the
# angular wavenumber 2*pi*k/LX, and 2*(pi*k/LX)**2 equals (2*pi*k/LX)**2/2.
_kx_half = np.pi*NX*np.fft.fftfreq(NX)/LX
_ky_half = np.pi*NY*np.fft.fftfreq(NY)/LY
_ex, _ey = np.meshgrid(2.0*_kx_half**2, 2.0*_ky_half**2)
E = _ex + _ey
del _x_axis, _y_axis, _kx_half, _ky_half, _ex, _ey

# |t| beyond which the nonlinear term is switched on; a large value such as
# 100000.0 keeps the evolution linear for the first part of the run.
NL_TIME = 0.0
def nonlinear(psi, t, strength=4000.0, onset=None):
    """Return the cubic nonlinear potential strength*|psi|**2.

    This is the nonlinear Schrodinger (Gross-Pitaevskii) interaction term
    that step() adds to the external potential. It is switched off (all
    zeros) until abs(t) exceeds the onset time, so a run can start linear.

    Args:
        psi: wave function values (array or scalar, may be complex).
        t: current time; may be complex, only its magnitude is compared.
        strength: coupling constant multiplying the density.
        onset: magnitude of t after which the term is active; defaults to
            the module-level NL_TIME, read at call time.

    Returns:
        A real-valued array with the shape of psi.
    """
    if onset is None:
        onset = NL_TIME
    if np.abs(t) > onset:
        # re^2 + im^2 keeps the density real instead of complex with a
        # zero imaginary part.
        return strength*(np.real(psi)**2 + np.imag(psi)**2)
    # Linear regime: no need to compute the density at all.
    return np.zeros(np.shape(psi))
def u(psi, t):
    """Free space with periodic boundary condition propagator for psi.

    Transforms psi to Fourier space, multiplies every mode by the phase
    exp(-i*E*t) built from the free-particle energies E, and transforms
    back. A complex t makes the factor non-unitary (it also damps modes).
    """
    psi_k = np.fft.fftn(psi)
    phase = np.exp(-1.0j*E*t)
    return np.fft.ifftn(phase*psi_k)
def normalize(psi):
    """Rescale psi so that sum(|psi|^2)*DX*DY over the grid equals one."""
    cell_area = DX*DY
    probability = cell_area*np.sum(np.abs(psi)**2)
    return psi/np.sqrt(probability)
def step(psi, t, dt, phi, should_normalize=False):
    """Advance psi by a single time step dt (Strang split-operator scheme).

    The potential (external phi plus the nonlinear term) is applied as two
    half-step phase kicks around one full free-propagation step u().

    Args:
        psi: wave function on the grid at time t.
        t: time at the start of the step (may be complex).
        dt: time step; a complex dt makes the evolution non-unitary.
        phi: external potential, broadcastable against psi.
        should_normalize: if True, rescale the result with normalize().

    Returns:
        The wave function at time t + dt.
    """
    # Opening half-kick: density and time at the start of the step.
    psi1 = np.exp(-1.0j*(phi + nonlinear(psi, t))*dt/2.0)*psi
    psi2 = u(psi1, dt)
    # Closing half-kick belongs to the end of the step, so the
    # time-dependent nonlinearity is evaluated at t + dt.
    psi3 = np.exp(-1.0j*(phi + nonlinear(psi2, t + dt))*dt/2.0)*psi2
    return normalize(psi3) if should_normalize else psi3
def make_heart_potential(height, size, edge_sharpness, x_off, y_off):
    """Return a potential that is ~0 inside a heart-shaped region, ~height outside.

    Coordinates are taken relative to (x_off, y_off) in units of the box
    size (X/LX, Y/LY). The boundary radius depends on the polar angle, and
    the tanh step's steepness is set by edge_sharpness.
    """
    rel_x = X/LX - x_off
    rel_y = Y/LY - y_off
    radius = 5.0*np.sqrt(rel_x**2 + rel_y**2)
    # Angle of (rel_y + i*rel_x), mapped from (-pi, pi] onto [0, 2*pi).
    theta = np.angle(1.0j*rel_x + rel_y)
    theta = np.where(theta < 0.0, theta + 2.0*np.pi, theta)
    boundary = size*(np.abs(np.sin(theta))
                     + 2.0*np.exp(-1.2*np.abs(theta-np.pi)))
    smooth_step = np.tanh(edge_sharpness*(radius - boundary))/2.0 + 0.5
    return height*smooth_step
# Parameters of the Gaussian wave packet used by the commented-out initial
# state below (kept so it is easy to switch back to it).
nx, ny = 20.0, 20.0
sigma_x, sigma_y = 0.07, 0.07
r0x, r0y = 0.35, 0.5
# psi0 = normalize(np.exp(-0.5*((X/LX - r0x)/sigma_x)**2
#                         -0.5*((Y/LY - r0y)/sigma_y)**2
#                         )*np.exp(2.0j*np.pi*(nx*X/LX + ny*Y/LY)))

# Uniform amplitude with an independent random phase at every grid point.
# Shaped like X, i.e. (NY, NX) as produced by np.meshgrid, so this also
# works for non-square grids.
psi0 = normalize(np.exp(2.0j*np.pi*np.random.rand(*X.shape)))

animate(wave_function=psi0, x=X, y=Y,
        potential=make_heart_potential(0.5, 1.8, 8.0, 0.5, 0.75),
        steps_per_frame=1,
        normalize_after_each_step=True,
        step_function=step, t=0.0,
        # The complex time step damps higher-energy components, which is
        # why the state is renormalized after every step.
        # dt=1.0,
        # dt=10.0*(1.0 - 0.2j)
        dt_func=lambda t: 10.0*(1.0 - 0.2j),
        # dt_func=lambda t: 5.0 if np.abs(t) < NL_TIME
        #                   else 5.0*np.exp(-1.0j*np.pi*0.1),
        )
# ---------------------------------------------------------------------------
# visualization.py (the module imported above via
# `from visualization import animate`)
# ---------------------------------------------------------------------------
import numpy as np | |
import matplotlib.pyplot as plt | |
import matplotlib.animation as animation | |
# Module-level state shared by animate(), init_func() and func(): animate()
# stores its keyword arguments here, and the callbacks read them and store
# the axes, plots, current wave function, time and frame counter.
data = dict()
def init_func():
    """Create the axes and the two image artists used by the animation.

    Reads 'x', 'y', 'wave_function' and 'fig' from the shared `data` dict.
    Stores the new axes in data['ax'] and the images in data['plots'] as
    (phase_image, background_image):

    * phase_image shows arg(psi) with the 'hsv' colormap. Its opacity is
      5*|psi|^2/max(|psi|^2) clipped at 1, the same mapping func() applies
      on every frame.
    * background_image is an all-zero image rendered with the 'gray'
      colormap, forming a uniform backdrop beneath the phase image.
    """
    x, y = data['x'], data['y']
    psi = np.flip(data['wave_function'], axis=0)
    fig = data['fig']
    ax = fig.add_subplot(1, 1, 1)
    data['ax'] = ax
    # First/last coordinates of the meshgrid along each axis.
    extent = (x[0, 0], x[0, -1], y[0, 0], y[-1, 0])
    density = np.abs(psi)**2
    alpha = np.minimum(5.0*density/np.amax(density), 1.0)
    # Explicit zorder keeps the opaque background beneath the phase image in
    # full (non-blitted) redraws, e.g. when the animation is saved.
    im = ax.imshow(np.angle(psi),
                   alpha=alpha,
                   extent=extent,
                   interpolation='nearest',
                   cmap='hsv',
                   zorder=1,
                   )
    im2 = ax.imshow(np.zeros(psi.shape),
                    extent=extent,
                    interpolation='nearest', cmap='gray',
                    zorder=0)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    data['plots'] = im, im2
def func(*args):
    """Advance the simulation by one animation frame and refresh the image.

    Called by FuncAnimation once per frame; the frame argument is ignored.
    Performs data['steps_per_frame'] calls of data['step_function'], each
    using the current time data['t'] and the step data['dt_func'](t). It then
    updates the phase image and its density-based opacity.

    Returns:
        The (background, phase) image artists, as required for blitting.
    """
    steps_per_frame = data['steps_per_frame']
    step = data['step_function']
    dt_func = data['dt_func']
    potential = data['potential']
    im, im2 = data['plots']
    should_normalize = data.get('normalize_after_each_step', False)
    for _ in range(steps_per_frame):
        # Re-read the time every sub-step; otherwise all sub-steps of a
        # frame would be evaluated at the frame's starting time.
        t = data['t']
        dt = dt_func(t)
        data['wave_function'] = step(data['wave_function'], t, dt,
                                     potential,
                                     should_normalize=should_normalize)
        data['t'] = t + dt
    psi_view = np.flip(data['wave_function'], axis=0)
    im.set_data(np.angle(psi_view))
    abs_wavefunc2 = np.abs(psi_view)**2
    # Opacity saturates at 1 once the density reaches 1/5 of its peak.
    alpha_map = 5.0*abs_wavefunc2/np.amax(abs_wavefunc2)
    im.set_alpha(np.where(alpha_map > 1.0, 1.0, alpha_map))
    data['frames'] += 1
    if data['frames'] % 60 == 0:
        print('frames: ', data['frames'], '\n', 'time: ', data['t'])
    return im2, im,
def animate(*, filename='animation.mp4', n_frames=60, fps=30, bitrate=1800,
            **kw):
    """Run the simulation as a matplotlib animation and save it to a video.

    Every extra keyword argument is stored in the shared `data` dict.
    init_func() reads 'x', 'y' and 'wave_function'. func() reads
    'steps_per_frame', 'step_function', 't', 'dt_func' and 'potential', and
    optionally 'normalize_after_each_step'.

    Args:
        filename: output video path (written with the 'ffmpeg' writer).
        n_frames: number of animation frames to render and save.
        fps: frame rate of the saved video.
        bitrate: bitrate passed to the video writer.
        **kw: simulation state and callbacks, as listed above.
    """
    data['frames'] = 0
    fig = plt.figure()
    data['fig'] = fig
    data.update(kw)
    init_func()
    data['animation'] = animation.FuncAnimation(fig, func,
                                                frames=n_frames,
                                                blit=True,
                                                interval=1000.0/60.0,
                                                )
    data['animation'].save(filename, writer='ffmpeg', fps=fps,
                           bitrate=bitrate)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment