Skip to content

Instantly share code, notes, and snippets.

@cat-state
Created October 18, 2024 03:15
Show Gist options
  • Save cat-state/e663337d46e3a015e62209d7edf0ef01 to your computer and use it in GitHub Desktop.
Save cat-state/e663337d46e3a015e62209d7edf0ef01 to your computer and use it in GitHub Desktop.
import torch
# Material parameters. mu and lambda_ are derived from E and nu and are read
# as module-level globals by get_energy.
E = 5e6  # Young's modulus
nu = 0.4  # Poisson's ratio
mu = E / (2 * (1 + nu))  # Shear modulus
lambda_ = (E * nu) / ((1 + nu) * (1 - 2 * nu))  # Lame's first parameter

# NOTE(review): N (the particle count) was referenced below but never defined,
# so the script failed with NameError. 1024 is a placeholder -- confirm the
# intended value.
N = 1024

# Use the GPU when one is present; otherwise run the same code on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rest configuration: random directions scaled to unit length, then shifted so
# the mean position is the origin.
rest_x = torch.randn(N, 3, device=device)
rest_x = rest_x / rest_x.norm(dim=-1, keepdim=True)
rest_x = rest_x - rest_x.mean(dim=0)

# Optimized positions start as the rest configuration plus Gaussian noise.
x = torch.randn(N, 3, device=device) * 0.1 + rest_x
x.requires_grad_()
optim = torch.optim.Adam([x], lr=1e-2)

# One weight per particle (all ones here).
w = torch.ones(N, device=device)

# Each particle gets 8 neighbour indices: column b is an independent random
# permutation of 0..N-1, so row n holds the neighbours of particle n.
bonds = torch.stack([torch.randperm(N) for _ in range(8)], dim=1).to(device)

# A permutation can map a particle to itself; redirect those self-bonds to the
# next particle index (wrapping around) so no bond has zero rest length.
self_bond = bonds == torch.arange(N, device=device).unsqueeze(1)
bonds[self_bond] = (bonds[self_bond] + 1) % N

# Rest-state bond vectors (neighbour minus particle), shape (N, 8, 3), and the
# inverse of each particle's weighted sum of bond outer products, shape (N, 3, 3).
ref_bonds = rest_x[bonds] - rest_x.unsqueeze(1)
inv_rest = torch.linalg.inv(torch.einsum("n,nbx,nby->nxy", w, ref_bonds, ref_bonds))
import time
@torch.compile
def get_energy(x, bonds, ref_bonds, inv_rest, w):
    """Return the per-particle bond energy for the current positions.

    The shapes below follow from the einsum subscripts and indexing used in
    the body.

    Args:
        x: Current particle positions, indexed as ``x[bonds]`` (one row per
            particle).
        bonds: Integer neighbour indices, one row of neighbours per particle.
        ref_bonds: Rest-state bond vectors, indexed ``[particle, bond, axis]``.
        inv_rest: Per-particle matrices right-multiplied onto the weighted
            deformed/rest bond moment. The script builds them as the inverse
            of the weighted rest-bond moment.
        w: One weight per particle.

    Returns:
        A 1-D tensor with one energy value per particle.

    Reads the module-level ``mu`` and ``lambda_`` as the coefficients of the
    two energy terms.
    """
    # Rest length of every bond; both stretch ratios divide by it.
    rest_len = ref_bonds.norm(dim=-1)

    # Current bond vectors: neighbour position minus the particle's own.
    cur_bonds = x[bonds] - x.unsqueeze(1)

    # Per-particle linear map estimated from current vs. rest bonds.
    moment = torch.einsum("n,nbx,nby->nxy", w, cur_bonds, ref_bonds)
    def_grads = torch.matmul(moment, inv_rest)

    # Rest bonds pushed through that linear map.
    mapped_bonds = torch.einsum("nxy,nby->nbx", def_grads, ref_bonds)

    # Squared relative stretch of the mapped bonds and of the actual bonds.
    mapped_stretch = mapped_bonds.norm(dim=-1) / rest_len
    actual_stretch = cur_bonds.norm(dim=-1) / rest_len
    deviatoric_term = (mapped_stretch - 1) ** 2
    isotropic_term = (actual_stretch - 1) ** 2

    per_bond = mu * deviatoric_term + (lambda_ / 2) * isotropic_term

    # Sum over each particle's bonds, scaled by the particle weight.
    return torch.einsum("n,nb->n", w, per_bond)
import matplotlib.pyplot as plt
# Minimise the summed bond energy with respect to x using the optimizer built
# above, printing per-step timing and plotting the point cloud every 100 steps.
for i in range(10000):
    t0 = time.perf_counter()
    total_energy = get_energy(x, bonds, ref_bonds, inv_rest, w)
    # Reduce once and reuse the scalar for both backward() and reporting.
    loss = total_energy.sum()
    loss.backward()
    optim.step()
    optim.zero_grad()
    # Read the scalar on the host BEFORE stopping the timer. On CUDA the
    # preceding calls only enqueue work; copying the value back waits for it,
    # so the measured interval covers the real compute rather than just the
    # launch overhead.
    energy_value = loss.item()
    t1 = time.perf_counter()
    print(f"\rTime: {t1-t0}, Energy: {energy_value}", end="", flush=True)
    if i % 100 == 0:  # Plot every 100 iterations
        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(111, projection='3d')
        # Convert x to CPU and detach from computation graph
        x_plot = x.detach().cpu().numpy()
        # Plot the points
        ax.scatter(x_plot[:, 0], x_plot[:, 1], x_plot[:, 2])
        # Set labels and title
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_title(f'3D Plot of x at iteration {i}')
        # Show the plot, then release the figure so they do not accumulate
        plt.show()
        plt.close(fig)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment