Skip to content

Instantly share code, notes, and snippets.

@marcocamma
Last active June 17, 2020 12:33
Show Gist options
  • Save marcocamma/76053e9c516cef1b9a7c03b76add92b7 to your computer and use it in GitHub Desktop.
Save marcocamma/76053e9c516cef1b9a7c03b76add92b7 to your computer and use it in GitHub Desktop.
Simple Diffraction by slits using XRT
import matplotlib as mpl
mpl.use('Agg')
import argparse
import numpy as np
from matplotlib import pyplot as plt
import xrt.backends.raycing as raycing
import xrt.backends.raycing.sources as rsource
import xrt.backends.raycing.screens as rscreens
import xrt.backends.raycing.oes as roes
import xrt.backends.raycing.apertures as rslits
import xrt.backends.raycing.materials as rmats
import xrt.backends.raycing.waves as rw
import xrt.plotter as xrtp
import xrt.runner as xrtr
# Directory receiving the plots' persistent files and the saved field data.
SAVEFOLDER = "./data"

# Histogram bins per screen axis; x and z share the same count.
nbins = 512
zbins = nbins
xbins = nbins

# Pixels per bin for the plot images, derived from a 512 reference size
# and never allowed to drop below 1.
xppb = max(1, int(512 / xbins))
zppb = max(1, int(512 / zbins))

# Default symmetric meshes. prepare_mesh builds its own local meshes
# rather than using these.
xmesh = np.linspace(-0.3, 0.3, xbins)
zmesh = np.linspace(-0.3, 0.3, zbins)
def get_lims(bl, obj_or_dist):
    """Return the (width, height) of the window to sample at a position.

    Parameters
    ----------
    bl : beamline built by ``build_beamline``
        Must expose ``slitsize`` (a [width, height] pair) and ``slit1``
        (with a ``center`` whose index 1 is its longitudinal position).
    obj_or_dist : real number or object with a ``center`` attribute
        Either the longitudinal position itself, or an optical element whose
        ``center[1]`` gives that position.

    Returns
    -------
    tuple
        ``(sx, sz)``, the full width along x and full height along z.
        A position of exactly 0 yields the fixed window ``(0.2, 0.2)``.
        Otherwise the slit opening is widened by ``40e-6`` per unit of
        distance past ``slit1``, plus a constant ``0.1`` margin.
    """
    import numbers

    # numbers.Real also accepts numpy scalars such as np.int64, which are not
    # subclasses of int and were previously mistaken for optical elements.
    if isinstance(obj_or_dist, numbers.Real):
        dist = obj_or_dist
    else:
        dist = obj_or_dist.center[1]
    if dist == 0:
        return (0.2, 0.2)
    sx, sz = bl.slitsize
    # Linear broadening with distance from slit1. This is presumably a ~40 urad
    # divergence allowance with lengths in mm -- TODO confirm. The original
    # summation order (broadening + size + margin) is kept so results are
    # bit-identical.
    broadening = 40e-6 * (dist - bl.slit1.center[1])
    return broadening + sx + 0.1, broadening + sz + 0.1
def prepare_mesh(bl, obj_or_dist):
    """Build the sampling mesh for a screen (or a longitudinal position).

    The window size comes from ``get_lims`` and is centred on zero. The
    returned tuple is ``(xmesh, zmesh, M)``: ``xmesh`` has ``xbins`` points,
    ``zmesh`` has ``zbins`` points, and ``M`` is ``np.meshgrid(xmesh, zmesh)``.
    """
    width, height = get_lims(bl, obj_or_dist)
    x_points = np.linspace(-width / 2, width / 2, xbins)
    z_points = np.linspace(-height / 2, height / 2, zbins)
    return x_points, z_points, np.meshgrid(x_points, z_points)
def get_undulator(bl=None, nrays=100000, zero_emittance=False, zero_espread=False):
    """Create the xrt undulator source used by the simulation.

    Parameters
    ----------
    bl : raycing.BeamLine, optional
        Beamline to attach the source to. A new one is created when omitted.
    nrays : int
        Number of rays generated by the source.
    zero_emittance : bool
        If True, set the electron beam sizes (eSigmaX/eSigmaZ) and emittances
        (eEpsilonX/eEpsilonZ) to 0.
    zero_espread : bool
        If True, set the relative electron energy spread (eEspread) to 0.

    Returns
    -------
    rsource.Undulator
        Configured with ``filamentBeam=True`` and ``targetE=(8000, 1)``,
        presumably 8 keV on the first harmonic -- TODO confirm against the
        xrt docs.
    """
    beamline = raycing.BeamLine() if bl is None else bl
    params = {
        "bl": beamline,
        "targetE": (8000, 1),
        "eMin": 7999,
        "eMax": 8001,
        "nrays": nrays,
        "eE": 6,
        "eI": 0.2,
        "eSigmaX": 30.0,
        "eSigmaZ": 3.6,
        "eEpsilonX": 0.132,
        "eEpsilonZ": 0.005,
        "eEspread": 0.00094,
        "period": 20,
        "n": 125,
        "K": 2.1,
        "xPrimeMax": 1 / 30,
        "zPrimeMax": 1 / 30,
        "xPrimeMaxAutoReduce": False,
        "filamentBeam": True,
        "uniformRayDensity": False,
        "R0": 28e3,
        "distE": "BW",
    }
    if zero_emittance:
        params.update(eSigmaX=0, eSigmaZ=0, eEpsilonX=0, eEpsilonZ=0)
    if zero_espread:
        params["eEspread"] = 0
    return rsource.Undulator(**params)
def build_beamline(nrays=1e3,
                   slitsize=[0.04, 0.3],
                   zero_emittance=False,
                   zero_espread=False):
    """Build the undulator -> slit1 -> screens beamline.

    Parameters
    ----------
    nrays : int or float
        Rays per repeat. Converted to ``int`` before reaching the source, so
        the float default ``1e3`` is never passed through.
    slitsize : float or sequence of two floats
        Slit opening as [width (x), height (z)], in the slit-edge length
        unit (mm according to the command-line help). A scalar gives a
        square opening.
    zero_emittance, zero_espread : bool
        Forwarded to ``get_undulator``.

    Returns
    -------
    raycing.BeamLine
        The beamline with these extra attributes attached: ``source``,
        ``slit1``, ``slitsize``, ``fname_append_string`` (the suffix used in
        every output file name), ``screen_distances`` (metres) and
        ``screen_names``.
    """
    if isinstance(slitsize, (int, float)):
        slitsize = [slitsize, slitsize]
    else:
        # Copy so bl.slitsize never aliases the caller's list or the shared
        # mutable default of this function.
        slitsize = list(slitsize)

    bl = raycing.BeamLine()
    bl.source = get_undulator(bl, nrays=int(nrays),
                              zero_emittance=zero_emittance,
                              zero_espread=zero_espread)

    # Slit at y = 30000, i.e. 30 m given the metre -> mm conversion used for
    # the screens below.
    bl.slit1 = rslits.RectangularAperture(
        bl,
        "slit1",
        (0, 30000, 0),
        ("left", "right", "bottom", "top"),
        [-slitsize[0] / 2, slitsize[0] / 2, -slitsize[1] / 2, slitsize[1] / 2],
    )
    bl.slitsize = slitsize

    # Suffix identifying this configuration in every output file name. Each
    # flag is reported independently. Previously the espread field was filled
    # from zero_emittance, so runs with different espread settings collided.
    bl.fname_append_string = (
        "_slitsize_%.0fx%.0fum_isZeroEmitt_%s_isZeroESpread_%s"
        % (slitsize[0] * 1e3, slitsize[1] * 1e3, zero_emittance, zero_espread)
    )

    screen_distances = [33, 70, 100, 200]  # metres
    bl.screen_distances = screen_distances
    bl.screen_names = ["screen_at_%.0fm" % dist for dist in screen_distances]
    for dist, name in zip(screen_distances, bl.screen_names):
        # The return value is not kept. Presumably the Screen registers itself
        # in bl.screens, which run_process and define_plots iterate -- TODO
        # confirm in xrt.
        rscreens.Screen(bl, name, (0, dist * 1e3, 0))
    return bl
def run_process(bl):
    """Propagate the source through slit1 onto every screen.

    ``main`` installs this function as ``raycing.run.run_process``, so xrt's
    runner calls it (presumably once per repeat).

    Returns
    -------
    dict
        Keys are ``"bl"`` (the beamline), ``"wave_on_slit1"`` (the wave at the
        slit) and one entry per screen name (the diffracted wave on that
        screen). These keys match the ``beam`` names of the plots built in
        ``define_plots``.
    """
    slit_wave = bl.slit1.prepare_wave(bl.source, bl.source.nrays)

    screen_waves = {}
    for scr in bl.screens:
        x_pts, z_pts, _ = prepare_mesh(bl, scr)
        screen_waves[scr.name] = scr.prepare_wave(bl.slit1, x_pts, z_pts)

    # The beam returned by shine() is never used. Presumably shining with
    # wave= fills slit_wave in place before it is diffracted below.
    bl.source.shine(wave=slit_wave, fixedEnergy=bl.source.E1)

    for target in screen_waves.values():
        rw.diffract(slit_wave, target)

    return dict(bl=bl, wave_on_slit1=slit_wave, **screen_waves)
def define_plot(name, bl, dist=0, is_for_PCA=True):
    """Create an xrt XYCPlot for the beam called *name*.

    The axis limits come from ``get_lims(bl, dist)`` and are scaled by 1e3
    for the µm axis labels (so ``get_lims`` presumably returns mm). The plot
    persists its data under ``SAVEFOLDER`` with ``bl.fname_append_string``
    in the file name. When *is_for_PCA* is true the plot uses
    ``fluxKind="EsPCA"``.
    """
    width, height = get_lims(bl, dist)
    x_limits = [-width * 1e3 / 2, width * 1e3 / 2]
    z_limits = [-height * 1e3 / 2, height * 1e3 / 2]

    plot_kwargs = {
        "aspect": "auto",
        "xaxis": xrtp.XYCAxis(r"$x$", u"µm", limits=x_limits, bins=xbins, ppb=xppb),
        "yaxis": xrtp.XYCAxis(r"$z$", u"µm", limits=z_limits, bins=zbins, ppb=zppb),
        "title": name,
        "persistentName": SAVEFOLDER + "/" + name + bl.fname_append_string + ".npy",
    }
    if is_for_PCA:
        plot_kwargs["fluxKind"] = "EsPCA"
    return xrtp.XYCPlot(name, **plot_kwargs)
def define_plots(bl):
    """Return the plots to accumulate: one for slit1, then one per screen.

    Screen plots follow the order of ``bl.screen_names``. Each one takes its
    position from the screen at the same index in ``bl.screens``.
    """
    slit_plot = define_plot("wave_on_slit1", bl,
                            dist=bl.slit1.center[1], is_for_PCA=True)
    screen_plots = [
        define_plot(screen_name, bl,
                    dist=bl.screens[idx].center[1], is_for_PCA=True)
        for idx, screen_name in enumerate(bl.screen_names)
    ]
    return [slit_plot] + screen_plots
def main(repeats=100, nrays=1e3, slitsize=[0.04, 0.3], zero_espread=False,
         zero_emittance=False, restart=False):
    """Run the wave-propagation simulation and save the accumulated fields.

    Parameters
    ----------
    repeats : int
        Number of xrt repeats (``--nelectrons`` on the command line).
    nrays : int or float
        Rays per repeat.
    slitsize : float or [width, height]
        Slit opening, see ``build_beamline``.
    zero_espread, zero_emittance : bool
        Forwarded to ``build_beamline``.
    restart : bool
        If True, overwrite existing ``*_total4D.npy`` files. Otherwise the new
        data are appended to the existing ones along axis 0.

    Side effects
    ------------
    Creates ``SAVEFOLDER`` if needed. For every plot exposing ``total4D`` it
    writes ``<beam><suffix>_total4D.npy`` and ``<beam><suffix>_pars.npz``
    (run parameters plus the plot's bin edges).
    """
    import os

    # np.save below (and, presumably, xrt's persistentName files) need the
    # output folder to exist; on a fresh checkout it does not.
    os.makedirs(SAVEFOLDER, exist_ok=True)

    bl = build_beamline(nrays=nrays, slitsize=slitsize,
                        zero_espread=zero_espread, zero_emittance=zero_emittance)
    plots = define_plots(bl)
    # Hook our propagation routine into xrt's runner.
    raycing.run.run_process = run_process
    xrtr.run_ray_tracing(plots, repeats=repeats, beamLine=bl)

    for plot in plots:
        if not hasattr(plot, "total4D"):
            continue
        name = plot.beam
        data = plot.total4D
        stem = SAVEFOLDER + "/" + name + bl.fname_append_string
        fname = stem + "_total4D.npy"
        par_fname = stem + "_pars.npz"
        if not restart:
            try:
                olddata = np.load(fname)
                data = np.concatenate((olddata, data), axis=0)
            except FileNotFoundError:
                # First run for this configuration: nothing to append to.
                pass
            except Exception as e:
                # NOTE(review): on any other failure (e.g. a shape mismatch) the
                # old file is overwritten by the new data alone, as before.
                print("Failed to append data to previous one, error was", str(e))
        np.save(fname, data)
        # NOTE(review): when appending, `repeats` records only this run's
        # repeats, not the total stored in the _total4D file.
        np.savez(par_fname,
                 repeats=repeats,
                 nrays=nrays,
                 slitsize=slitsize,
                 zero_espread=zero_espread,
                 zero_emittance=zero_emittance,
                 restart=restart,
                 xbinedges=plot.xaxis.binEdges,
                 zbinedges=plot.yaxis.binEdges,
                 )
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Program to calculate wave propagation",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--nrays", type=float, default=1e6, help="number of rays to use"
    )
    parser.add_argument(
        "--nelectrons", type=float, default=100, help="number of electrons (repeats)"
    )
    parser.add_argument(
        "--zero_espread", action="store_true", help="Set rms electron energy spread to zero"
    )
    parser.add_argument(
        "--zero_emittance", action="store_true", help="Set rms electron emittance to zero (size and divergence)"
    )
    parser.add_argument(
        "--restart", action="store_true", help="Overwrite previous total4D files instead of appending"
    )
    parser.add_argument("--slitsize", type=str, default="0.04,0.3", help="slit gap (mm)")
    args = parser.parse_args()

    # --slitsize is either one value (square opening) or "width,height".
    # Malformed input or extra values used to crash with a traceback or be
    # silently ignored; report them as a usage error instead.
    try:
        slitsize = [float(g) for g in args.slitsize.split(",")]
    except ValueError:
        parser.error("--slitsize must be one or two comma-separated numbers, got %r"
                     % args.slitsize)
    if len(slitsize) not in (1, 2):
        parser.error("--slitsize must be one or two comma-separated numbers, got %r"
                     % args.slitsize)
    if len(slitsize) == 1:
        slitsize = slitsize[0]

    # Float parsing above lets users write e.g. 1e6; the simulation wants ints.
    main(repeats=int(args.nelectrons),
         slitsize=slitsize,
         nrays=int(args.nrays),
         zero_espread=args.zero_espread,
         zero_emittance=args.zero_emittance,
         restart=args.restart
         )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment