Last active
April 17, 2024 06:28
-
-
Save andyfaff/5880370330da655291271d3d28cf2f32 to your computer and use it in GitHub Desktop.
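Patch enabling basic JAX usage with refnx's `Objective`, `ReflectModel`, and `Structure`. Caveats: it comments out the NaN check in `Objective.logl`; the patched `Structure` and `Slab` slab-building code imports `jax.numpy` at run time, making JAX a requirement on those paths; and `Slab` now reads the parameters' private `_value` attributes instead of `value`.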
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/refnx/analysis/objective.py b/refnx/analysis/objective.py | |
index 9ff56057..c77139ad 100644 | |
--- a/refnx/analysis/objective.py | |
+++ b/refnx/analysis/objective.py | |
@@ -674,8 +674,8 @@ class Objective(BaseObjective): | |
logl += (y - model) ** 2 / var_y | |
# nans play havoc | |
- if np.isnan(logl).any(): | |
- raise RuntimeError("Objective.logl encountered a NaN.") | |
+ # if np.isnan(logl).any(): | |
+ # raise RuntimeError("Objective.logl encountered a NaN.") | |
# add on extra 'potential' terms from the model. | |
extra_potential = self.model.logp() | |
diff --git a/refnx/reflect/reflect_model.py b/refnx/reflect/reflect_model.py | |
index 5cebba6c..34b06949 100644 | |
--- a/refnx/reflect/reflect_model.py | |
+++ b/refnx/reflect/reflect_model.py | |
@@ -514,9 +514,10 @@ class ReflectModel: | |
# fallback to what this object was constructed with | |
x_err = float(self.dq) | |
+ slabs = self.structure.slabs()[..., :4] | |
return reflectivity( | |
x, | |
- self.structure.slabs()[..., :4], | |
+ slabs, | |
scale=self.scale.value, | |
bkg=self.bkg.value, | |
dq=x_err, | |
diff --git a/refnx/reflect/structure.py b/refnx/reflect/structure.py | |
index d3ea2edb..65f78128 100644 | |
--- a/refnx/reflect/structure.py | |
+++ b/refnx/reflect/structure.py | |
@@ -320,12 +320,12 @@ class Structure(UserList): | |
# if all the interfaces are Gaussian, then simply concatenate | |
# the default slabs property of each component. | |
sl = [c.slabs(structure=self) for c in self.components] | |
- | |
+ import jax.numpy as jnp | |
try: | |
- slabs = np.concatenate(sl) | |
+ slabs = jnp.concatenate(sl) | |
except ValueError: | |
# some of slabs may be None. np can't concatenate arr and None | |
- slabs = np.concatenate([s for s in sl if s is not None]) | |
+ slabs = jnp.concatenate([s for s in sl if s is not None]) | |
else: | |
# there is a non-default interfacial roughness, create a microslab | |
# representation | |
@@ -912,9 +912,12 @@ class SLD(Scatterer): | |
return f"SLD([{self.real!r}, {self.imag!r}], name={self.name!r})" | |
def __complex__(self): | |
- sldc = complex(self.real.value, self.imag.value) | |
+ sldc = self.real.value + self.imag.value * 1j | |
return sldc | |
+ def complex(self): | |
+ return self.real.value + self.imag.value * 1j | |
+ | |
@property | |
def parameters(self): | |
""" | |
@@ -1289,22 +1292,22 @@ class Slab(Component): | |
Slab representation of this component. See :class:`Component.slabs` | |
""" | |
# speculative shortcut to prevent a number of attribute retrievals | |
+ import jax.numpy as jnp | |
if self.sld.dispersive: | |
sldc = self.sld.complex(getattr(structure, "wavelength", None)) | |
else: | |
- sldc = complex(self.sld) | |
+ sldc = self.sld.complex() | |
- return np.array( | |
+ return jnp.array( | |
[ | |
[ | |
- self.thick.value, | |
+ self.thick._value, | |
sldc.real, | |
sldc.imag, | |
- self.rough.value, | |
- self.vfsolv.value, | |
+ self.rough._value, | |
+ self.vfsolv._value, | |
] | |
- ], | |
- dtype=float, | |
+ ] | |
) | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment