Skip to content

Instantly share code, notes, and snippets.

@ehermes
Last active October 21, 2016 20:37
Show Gist options
  • Save ehermes/748395f53e036e3363adc6369e0a38a8 to your computer and use it in GitHub Desktop.
Save ehermes/748395f53e036e3363adc6369e0a38a8 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
import numpy as np
class Cell(np.ndarray):
    """A 3x3 float array describing a cell, built from several input forms.

    Accepted constructor arguments:

    * ``Cell(a)`` with a real scalar ``a``: cubic cell ``a * I``.
    * ``Cell(seq)`` with a list/tuple/array holding 1 or 3 numbers (or a
      0-d array): cubic (1 value) or orthorhombic (3 values) cell with the
      value(s) on the diagonal.
    * ``Cell(m)`` with a 3x3 list/tuple/array: copied and converted to float.
    * ``Cell(a, b, c)`` with three real scalars: orthorhombic ``diag(a, b, c)``.
    * ``Cell(x0, ..., x8)`` with nine real scalars: reshaped row-major to 3x3.

    Raises:
        ValueError: if the arguments match none of the forms above.
    """

    # Scalar types accepted as cell parameters.  NumPy scalar types are listed
    # explicitly because e.g. ``np.int64`` / ``np.float32`` are not subclasses
    # of the Python ``int`` / ``float`` builtins.
    _SCALAR_TYPES = (int, float, np.integer, np.floating)

    def __new__(cls, *args):
        if len(args) == 1:
            array = cls._parse_single(args[0])
        elif len(args) in (3, 9) and all(
            isinstance(value, cls._SCALAR_TYPES) for value in args
        ):
            values = np.array(args, dtype=float)
            if len(args) == 9:
                array = values.reshape((3, 3))
            else:
                # Three scalars: orthorhombic cell with them on the diagonal.
                array = values * np.eye(3, dtype=float)
        else:
            raise ValueError("Could not parse {} as a Cell!".format(args))
        return array.view(cls)

    @classmethod
    def _parse_single(cls, arg):
        """Return a new 3x3 float array parsed from a single constructor argument.

        Raises:
            ValueError: if ``arg`` is not a real scalar or a numeric
                list/tuple/array of shape (), (1,), (3,) or (3, 3).
        """
        if isinstance(arg, cls._SCALAR_TYPES):
            # A single scalar: cubic cell.
            return arg * np.eye(3, dtype=float)
        if isinstance(arg, (list, tuple, np.ndarray)):
            values = np.asarray(arg)
            # Only bool/int/uint/float data is accepted; strings, None and
            # object arrays fall through to the ValueError below instead of
            # failing later with an unrelated error.
            if values.dtype.kind in "biuf":
                if values.shape in ((), (1,), (3,)):
                    # Broadcasting a length-1 or length-3 vector against the
                    # identity places the value(s) on the diagonal.
                    return values.reshape(-1) * np.eye(3, dtype=float)
                if values.shape == (3, 3):
                    # astype copies, so the Cell never aliases caller data, and
                    # the result is float like every other construction path.
                    return values.astype(float)
        raise ValueError("Could not parse {} as a Cell!".format(arg))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment