Last active: October 21, 2016 20:37
Save ehermes/748395f53e036e3363adc6369e0a38a8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import numpy as np | |
class Cell(np.ndarray):
    """A 3x3 ``float`` ndarray holding a cell matrix.

    ``Cell`` accepts several shorthand forms and always produces a 3x3 float
    array viewed as ``Cell``:

    * ``Cell(a)`` with a single scalar ``a``: cubic cell ``a * I``.
    * ``Cell([a])`` or ``Cell([a, b, c])`` (list, tuple or ndarray): cubic or
      orthorhombic cell with the value(s) on the diagonal.
    * ``Cell(m)`` with a 3x3 list, tuple or ndarray: a float copy of ``m``.
    * ``Cell(a, b, c)``: orthorhombic cell ``diag(a, b, c)``.
    * ``Cell(x1, ..., x9)``: the nine values reshaped row-major into 3x3
      (the first three values form row 0).

    Accepted scalars are Python ``int``/``float`` (``bool`` included, as a
    subclass of ``int``) and NumPy integer/floating scalars.

    Raises:
        ValueError: if the arguments match none of the forms above.
    """

    # np.int64, np.float32, ... are not subclasses of int/float, so the NumPy
    # scalar base classes are listed explicitly; otherwise Cell(np.int64(5))
    # would be rejected.
    _SCALAR_TYPES = (int, float, np.integer, np.floating)

    def __new__(cls, *args):
        if len(args) == 1:
            arg = args[0]
            if isinstance(arg, cls._SCALAR_TYPES):
                # Single scalar: cubic cell.
                array = arg * np.eye(3, dtype=float)
            elif isinstance(arg, (list, tuple, np.ndarray)):
                arg = np.array(arg)
                if arg.shape in ((1,), (3,)):
                    # 1 or 3 lengths: broadcasting against the identity puts
                    # them on the diagonal (cubic or orthorhombic cell).
                    array = arg * np.eye(3, dtype=float)
                elif arg.shape == (3, 3):
                    # Full matrix: astype returns a new float array, so the
                    # result never aliases the caller's data and its dtype
                    # matches the other branches even for integer input.
                    array = arg.astype(float)
                else:
                    raise ValueError("Could not parse {} as a Cell!".format(arg))
            else:
                # Anything else (str, None, dict, ...) is rejected explicitly
                # rather than falling through with `array` unbound.
                raise ValueError("Could not parse {} as a Cell!".format(arg))
        elif len(args) in (3, 9):
            for value in args:
                if not isinstance(value, cls._SCALAR_TYPES):
                    raise ValueError("Could not parse {} as a Cell!".format(args))
            # Force float so very large ints cannot produce an object array.
            values = np.array(args, dtype=float)
            if len(args) == 9:
                array = values.reshape((3, 3))
            else:
                array = values * np.eye(3, dtype=float)
        else:
            raise ValueError("Could not parse {} as a Cell!".format(args))
        return array.view(cls)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.