Last active
August 6, 2024 19:26
-
-
Save PageotD/0f677fe99ebd382927de918762968fa2 to your computer and use it in GitHub Desktop.
Eikonal
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import heapq as hq | |
import numpy as np | |
import numpy.ma as ma | |
class Eikonal:
    """
    Fast Marching Eikonal solver on a regular 2D grid.

    Computes first-arrival traveltimes from one or more point sources by
    finalising grid points in order of increasing traveltime (narrow band
    kept in a binary heap) with a first-order upwind update at each point.

    Parameters
    ----------
    model : numpy.ndarray
        2D velocity model indexed as ``model[ix, iz]``; every value must be
        strictly positive because the update divides by the velocity.
    dh : float
        Grid spacing, shared by both axes.
    src : numpy.ndarray
        2D array of source positions in physical coordinates, one ``(x, z)``
        row per source.
    rec : numpy.ndarray
        2D array of receiver positions in physical coordinates, one ``(x, z)``
        row per receiver.

    Raises
    ------
    AssertionError
        If any input fails validation. The checks raise explicitly rather than
        using ``assert`` statements so they still run under ``python -O``.

    References
    ----------
    [1] R. Kimmel and J.A. Sethian (1996). Fast Marching Methods for Computing Distance Maps and
        Shortest Paths.
    """

    def __init__(self, model, dh, src, rec):
        # Order matters: the position checks read self.nx, self.nz and self.dh.
        self.model, self.nx, self.nz = self.validate_model(model)
        self.dh = self.validate_stepsize(dh)
        self.src = self.validate_sources(src)
        self.rec = self.validate_receivers(rec)

    def validate_model(self, model):
        """
        Check that ``model`` is a 2D array of strictly positive velocities.

        Returns
        -------
        tuple
            ``(model, nx, nz)`` where ``(nx, nz) == model.shape``.
        """
        if model.ndim != 2:
            raise AssertionError("model must be a 2D array")
        nx, nz = model.shape
        # A zero velocity would make ``dh / vel`` blow up in updateTime; the
        # negated comparison also rejects NaN.
        if not np.amin(model) > 0.:
            raise AssertionError("model must have positive values")
        return model, nx, nz

    def validate_stepsize(self, dh):
        """Check that the grid spacing ``dh`` is strictly positive and return it."""
        if not dh > 0:
            raise AssertionError("dh must be a positive value")
        return dh

    def _validate_positions(self, positions, name, label):
        """
        Shared check for source and receiver arrays.

        ``positions`` must be a 2D numpy array whose rows start with ``(x, z)``
        physical coordinates inside ``[0, (nx-1)*dh] x [0, (nz-1)*dh]``.
        ``name`` and ``label`` are only used in the error messages.
        """
        if not (isinstance(positions, np.ndarray) and positions.ndim == 2):
            raise AssertionError("{} must be a 2D numpy array".format(name))
        xmax = float((self.nx - 1) * self.dh)
        zmax = float((self.nz - 1) * self.dh)
        for row in positions:
            x, z = row[0], row[1]
            if not (0 <= x <= xmax and 0 <= z <= zmax):
                raise AssertionError(
                    "{} coordinates ({}, {}) are outside the model.".format(label, x, z))
        return positions

    def validate_sources(self, src):
        """Check that ``src`` is a 2D array of (x, z) positions inside the model."""
        return self._validate_positions(src, "src", "source")

    def validate_receivers(self, rec):
        """Check that ``rec`` is a 2D array of (x, z) positions inside the model."""
        return self._validate_positions(rec, "rec", "receiver")

    def initTimeMap(self):
        """
        Initialize the calculated time map.

        All points start as Far Away points (value = infinity).
        """
        return np.ones_like(self.model) * np.inf

    def initTimeMask(self, timeMap):
        """
        Wrap ``timeMap`` in a masked array with every point masked.

        A masked point has not been finalised (it is not Alive).
        """
        return ma.masked_array(timeMap, mask=True)

    def updateTime(self, ttx, ttz, vel):
        """
        First-order upwind traveltime update at a single grid point.

        Parameters
        ----------
        ttx, ttz : float
            Smallest neighbouring traveltime along x and along z; ``inf``
            when no neighbour along that axis has a time yet.
        vel : float
            Velocity at the point being updated.

        Returns
        -------
        float
            The new traveltime estimate.
        """
        cell_time = self.dh / vel  # time to cross one cell at this velocity
        if abs(ttx - ttz) < cell_time:
            # Both axes contribute: larger root of (T-ttx)^2 + (T-ttz)^2 = cell_time^2.
            S = 2. * cell_time**2 - (ttx - ttz)**2
            return (ttx + ttz + np.sqrt(S)) / 2.
        # One axis dominates: 1D update from the smaller neighbour.
        return np.min((ttz, ttx)) + cell_time

    def initFront(self):
        """
        Initialize the front points list (narrow band) as an empty heap.
        """
        narrowBand = []
        hq.heapify(narrowBand)
        return narrowBand

    def updateFront(self, narrowBand, value, coords):
        """
        Insert ``coords`` with ``value`` into the narrow band, replacing any
        existing entry for the same point, and return the (heap-ordered) list.

        ``run`` does not call this method (it uses lazy deletion instead); it is
        kept for callers that maintain their own narrow band.
        """
        for k, (_, point) in enumerate(narrowBand):
            if point == coords:
                narrowBand[k] = (value, coords)
                # In-place replacement can break the heap invariant.
                hq.heapify(narrowBand)
                break
        else:
            hq.heappush(narrowBand, (value, coords))
        return narrowBand

    def initSource(self, timeMap, timeMask, narrowBand):
        """
        Seed the time map, mask and narrow band with the source positions.

        Each source is snapped to a grid node by truncating ``coord / dh``,
        given traveltime 0, unmasked (Alive) and pushed onto the narrow band.
        """
        for row in self.src:
            node = (int(row[0] / self.dh), int(row[1] / self.dh))
            timeMap[node] = 0.
            # Unmask via ``.mask``: assigning to the masked array itself would
            # also write a value into the shared data.
            timeMask.mask[node] = False
            hq.heappush(narrowBand, (0., node))
        return timeMap, timeMask, narrowBand

    def run(self):
        """
        Run the Fast Marching loop and return the traveltime map.

        Returns
        -------
        numpy.ndarray
            Traveltimes with the same shape as ``self.model``.
        """
        ttCalc = self.initTimeMap()
        ttMask = self.initTimeMask(ttCalc)
        narrowBand = self.initFront()
        ttCalc, ttMask, narrowBand = self.initSource(ttCalc, ttMask, narrowBand)

        # Points already popped. Kept separately from ttMask, which marks the
        # sources as Alive before they are popped.
        done = np.zeros(ttCalc.shape, dtype=bool)

        while narrowBand:
            # Take the trial point with the smallest traveltime.
            _, (inbx, inbz) = hq.heappop(narrowBand)
            # Lazy deletion: a point is re-pushed whenever its time improves;
            # its smallest entry pops first and later entries are stale.
            if done[inbx, inbz]:
                continue
            done[inbx, inbz] = True
            # Move the point to Alive status.
            ttMask.mask[inbx, inbz] = False

            for ixn, izn in ((inbx - 1, inbz), (inbx + 1, inbz),
                             (inbx, inbz - 1), (inbx, inbz + 1)):
                # Skip neighbours outside the grid or already Alive.
                if not (0 <= ixn < self.nx and 0 <= izn < self.nz):
                    continue
                if not ttMask.mask[ixn, izn]:
                    continue
                # Smallest neighbouring time along each axis (inf beyond the edges).
                ttx = min(ttCalc[ixn - 1, izn] if ixn > 0 else np.inf,
                          ttCalc[ixn + 1, izn] if ixn < self.nx - 1 else np.inf)
                ttz = min(ttCalc[ixn, izn - 1] if izn > 0 else np.inf,
                          ttCalc[ixn, izn + 1] if izn < self.nz - 1 else np.inf)
                # The update uses the velocity of the point being updated.
                newTime = self.updateTime(ttx, ttz, self.model[ixn, izn])
                if newTime < ttCalc[ixn, izn]:
                    ttCalc[ixn, izn] = newTime
                    # Push the stored value (it may be rounded to the array dtype).
                    hq.heappush(narrowBand, (ttCalc[ixn, izn], (ixn, izn)))
        return ttCalc
def compute_gradient(ttCalc, dh):
    """
    Return the partial derivatives of a 2D traveltime map.

    Both axes use the same grid spacing ``dh``; the result is the pair
    ``(dT/dx, dT/dz)`` along axis 0 and axis 1 respectively.
    """
    spacing = (dh, dh)
    grad_pair = np.gradient(ttCalc, *spacing)
    return grad_pair[0], grad_pair[1]
def trace_raypath(ttCalc, start, gradient_x, gradient_z, dh, max_steps=1000, step=1.0):
    """
    Trace a raypath by steepest descent on the traveltime map.

    Starting from ``start``, the path repeatedly moves against the local
    traveltime gradient. Positions are in grid-index units (multiply by ``dh``
    for physical coordinates); the gradient is sampled at the grid cell found
    by flooring the current position.

    Parameters
    ----------
    ttCalc : numpy.ndarray
        2D traveltime map; only its shape is used, to detect leaving the grid.
    start : sequence of two numbers
        Starting position ``(ix, iz)`` in grid-index units.
    gradient_x, gradient_z : numpy.ndarray
        Traveltime gradient components, same shape as ``ttCalc``.
    dh : float
        Grid spacing. Kept for API compatibility: the step is expressed in
        grid cells, so one cell already corresponds to ``dh`` physical units.
    max_steps : int, optional
        Maximum number of steps taken.
    step : float, optional
        Step length in grid-index units (default: one cell).

    Returns
    -------
    numpy.ndarray
        ``(npoints, 2)`` array of positions in grid-index units. Tracing stops
        when the path leaves the grid (the first outside point is included),
        when the gradient is zero or not finite, or after ``max_steps`` steps.
    """
    nx, nz = ttCalc.shape
    path = [start]
    current_pos = np.asarray(start, dtype=float)
    for _ in range(max_steps):
        # floor (not int()) so positions in (-1, 0) count as outside the grid.
        ix, iz = int(np.floor(current_pos[0])), int(np.floor(current_pos[1]))
        if not (0 <= ix < nx and 0 <= iz < nz):
            break  # Exit if we go out of bounds
        grad = np.array([gradient_x[ix, iz], gradient_z[ix, iz]], dtype=float)
        grad_norm = np.hypot(grad[0], grad[1])
        # Zero gradient: local extremum. Non-finite gradient (e.g. from inf
        # traveltimes) would otherwise produce NaN positions and crash int().
        if grad_norm == 0 or not np.isfinite(grad_norm):
            break
        current_pos = current_pos - step * grad / grad_norm
        path.append(current_pos.tolist())
    return np.array(path)
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    # Create a model: a 2D sine product (amplitude 10) clipped to [-1, 1],
    # then mapped to velocities 500 +/- 200 m/s.
    nx = 201
    nz = 201
    dh = 1.0
    amp = 10.
    h_wave = np.sin(2. * np.arange(nx) / float(nx) * 2. * np.pi)
    v_wave = np.sin(2. * np.arange(nz) / float(nz) * 2. * np.pi)
    pattern = np.clip(np.outer(h_wave, v_wave) * amp, -1., 1.)
    vmodel = (pattern * 200. + 500.).astype(np.float32)

    # Source positions (REAL COORDINATES)
    src = np.array([
        [50., 50.],
        [150., 150.],
    ])

    # Receiver positions (REAL COORDINATES): every 10 m from 10 to 90 m, at 1 m depth
    rec = np.array([[x, 1.] for x in range(10, 100, 10)], dtype=float)

    # Initialize and run the FM Eikonal solver
    eikorun = Eikonal(vmodel, dh, src, rec)
    ttCalc = eikorun.run()

    # Extract first arrival time at receiver positions (grid node by truncation)
    arrival = [ttCalc[int(x / dh), int(z / dh)] for x, z in rec]
    for (x, z), t in zip(rec, arrival):
        print("receiver ({:.1f}, {:.1f}) m: first arrival {:.4f} s".format(x, z, t))

    # Physical coordinates of the grid nodes, shared by images and contours
    xcoords = np.arange(nx) * dh
    zcoords = np.arange(nz) * dh
    extent = [0, float(nx - 1) * dh, float(nz - 1) * dh, 0]

    def plot_panel(position, image, cmap, cbar_label, contour_color):
        """Draw ``image`` in physical coordinates with traveltime contours and sources."""
        plt.subplot(position)
        plt.xlabel("Distance (m)")
        plt.ylabel("Depth (m)")
        plt.imshow(image.T, interpolation='bilinear', cmap=cmap, aspect="auto", extent=extent)
        cbar = plt.colorbar()
        cbar.set_label(cbar_label)
        # Explicit coordinates keep the contours aligned with the image for any dh.
        contour = plt.contour(xcoords, zcoords, ttCalc.T, 10, colors=contour_color, linewidths=0.5)
        plt.clabel(contour, inline=1, fontsize=6)
        plt.scatter(src[:, 0], src[:, 1], c="red", marker='*')

    # Figure 1: time map and velocity model, both with traveltime contours
    plot_panel(121, ttCalc, 'viridis_r', "Traveltime (s)", 'k')
    plot_panel(122, vmodel, 'gray', "P-wave velocity (m/s)", 'red')
    plt.show()

    # Figure 2: same panels, with raypaths traced back from each receiver
    gradient_x, gradient_z = compute_gradient(ttCalc, dh)
    plot_panel(121, ttCalc, 'viridis_r', "Traveltime (s)", 'k')
    for x, z in rec:
        receiver_path = trace_raypath(ttCalc, [int(x / dh), int(z / dh)], gradient_x, gradient_z, dh)
        # Raypaths are in grid-index units; convert to real coordinates and plot
        plt.plot(receiver_path[:, 0] * dh, receiver_path[:, 1] * dh, 'r-', lw=1)
    plot_panel(122, vmodel, 'gray', "P-wave velocity (m/s)", 'red')
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment