Skip to content

Instantly share code, notes, and snippets.

Last active April 15, 2019 03:12
Show Gist options
  • Save randvoorhies/cf8dcc5e7c8180f7fbe71e9cf07162b8 to your computer and use it in GitHub Desktop.
Save randvoorhies/cf8dcc5e7c8180f7fbe71e9cf07162b8 to your computer and use it in GitHub Desktop.
Simple 3D lattice generator
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import pyplot as plt
# Print arrays with three decimals and without scientific notation.
np.set_printoptions(suppress=True, precision=3)

# Lattice angles, written in degrees and stored in radians.
#   alpha: in-plane angle between the a direction (the x axis) and the b direction.
#   beta:  angle between the a direction and the projection of c onto the a/b plane.
#   gamma: angle between the a/b plane and the c direction.
alpha, beta, gamma = (np.deg2rad(degrees) for degrees in (60, 70, 80))

# Distance between neighbouring nodes along each lattice direction.
a = b = c = 1

# Number of nodes along each lattice direction.
n_x = n_y = n_z = 5

# One (x, y, z) row per node; the rows are filled in by main().
points = np.zeros((n_z * n_y * n_x, 3))
def get_index(x, y, z):
    """Return the row in ``points`` that holds the node at grid position (x, y, z).

    Nodes are laid out with x varying fastest, then y, then z, using the
    module-level grid sizes ``n_x`` and ``n_y``. No bounds checking is done.
    """
    row_offset = y * n_x
    layer_offset = z * n_x * n_y
    return x + row_offset + layer_offset
def _place_nodes():
    """Fill the module-level ``points`` array with the position of every lattice node."""
    # The step between neighbouring nodes is the same everywhere, so build
    # the three step vectors once instead of once per node.
    step_a = np.array([a, 0, 0])
    step_b = np.array([b * np.cos(alpha), b * np.sin(alpha), 0])

    # Length of the projection of the c step onto the a/b plane.
    #
    # Note on the angle convention: here beta is the angle between A and the
    # "shadow" of C on the A/B plane, and gamma is the angle between the A/B
    # plane and C. This is NOT the formulation in which beta and gamma are
    # the angles between A/C and B/C respectively; converting between the
    # two is left to the caller.
    l_p = np.cos(gamma) * c
    step_c = np.array([l_p * np.cos(beta), l_p * np.sin(beta), c * np.sin(gamma)])

    for z in range(n_z):
        # Root node of this z-layer: one c step from the root of the layer
        # below. The root of layer 0 keeps its initial value in ``points``.
        if z > 0:
            points[get_index(0, 0, z)] = points[get_index(0, 0, z - 1)] + step_c

        # The (x, 0) nodes of this layer: walk along a from the layer root.
        for x in range(1, n_x):
            points[get_index(x, 0, z)] = points[get_index(x - 1, 0, z)] + step_a

        # The rest of the layer: walk along b from each (x, 0) node.
        for x in range(n_x):
            for y in range(1, n_y):
                points[get_index(x, y, z)] = points[get_index(x, y - 1, z)] + step_b


def _draw_lattice():
    """Draw the nodes in ``points`` and the edges between grid neighbours on a new 3D figure."""
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # Nodes, coloured by their height (z coordinate).
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], c=points[:, 2])

    # Edges: connect every node to its +x, +y and +z neighbour when that
    # neighbour exists. y must range over n_y and x over n_x so that grids
    # with n_x != n_y are indexed correctly.
    for z in range(n_z):
        for y in range(n_y):
            for x in range(n_x):
                node = points[get_index(x, y, z)]
                candidates = (
                    (x + 1 < n_x, (x + 1, y, z)),
                    (y + 1 < n_y, (x, y + 1, z)),
                    (z + 1 < n_z, (x, y, z + 1)),
                )
                for in_grid, neighbour in candidates:
                    if not in_grid:
                        continue
                    other = points[get_index(*neighbour)]
                    # Lowercase ``linewidth``: mixed-case property names are
                    # rejected by newer matplotlib releases.
                    ax.plot([node[0], other[0]],
                            [node[1], other[1]],
                            [node[2], other[2]],
                            'k', linewidth=.25)


def main():
    """Compute the lattice node positions and plot the lattice in 3D.

    Reads the module-level lattice parameters (``alpha``, ``beta``, ``gamma``,
    ``a``, ``b``, ``c``, ``n_x``, ``n_y``, ``n_z``), writes the node positions
    into the module-level ``points`` array, and creates a matplotlib figure
    with the nodes and their connecting edges. Returns nothing and does not
    itself display the figure.
    """
    _place_nodes()
    _draw_lattice()
if __name__ == '__main__':
    # Build and draw the lattice, then hand control to matplotlib so the
    # figure is actually displayed when the file is run as a script.
    main()
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment