# Gist by @manifoldhiker (568485cc403e7f21b29b2e3c0fd0d7ed), last active January 26, 2024.
# Plotly helpers for drawing vectors, spheres and point clouds in 3-D.
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch
import torch.nn.functional as F
def ms(x, y, z, radius, resolution=20):
    """Build the surface mesh of a sphere centred at (x, y, z).

    The azimuth is sampled over [0, 2*pi] with ``2 * resolution`` samples and
    the polar angle over [0, pi] with ``resolution`` samples; the spherical
    grid is then converted to Cartesian coordinates.

    Args:
        x, y, z: Coordinates of the sphere centre.
        radius: Sphere radius.
        resolution: Number of polar samples (the azimuth gets twice as many).

    Returns:
        A tuple ``(X, Y, Z)`` of 2-D arrays, each of shape
        ``(2 * resolution, resolution)``.
    """
    # A complex "step" in np.mgrid is a sample count with both endpoints included.
    azimuth, polar = np.mgrid[
        0:2 * np.pi:resolution * 2j,
        0:np.pi:resolution * 1j,
    ]
    sin_polar = np.sin(polar)
    xs = radius * np.cos(azimuth) * sin_polar + x
    ys = radius * np.sin(azimuth) * sin_polar + y
    zs = radius * np.cos(polar) + z
    return (xs, ys, zs)
def make_vector(ray, color='red', name=None):
    """Create a Scatter3d trace joining the two 3-D points packed in ``ray``.

    Args:
        ray: Sequence whose first three entries are the start point and whose
            remaining entries are plotted as the end point's coordinates
            (they are NOT added to the start point as a direction).
        color: Colour applied to both the markers and the line.
        name: Optional trace name.

    Returns:
        A ``go.Scatter3d`` trace with two points.
    """
    start, end = ray[:3], ray[3:]
    marker_style = dict(size=5, color=color)
    line_style = dict(color=color, width=6)
    return go.Scatter3d(
        x=[start[0], end[0]],
        y=[start[1], end[1]],
        z=[start[2], end[2]],
        name=name,
        marker=marker_style,
        line=line_style,
    )
def scatter3d(xyz, color='red', size=1.):
    """Create a Scatter3d trace from a 2-D array of points.

    Args:
        xyz: Array indexable as ``xyz[:, k]``; columns 0, 1 and 2 are used as
            the x, y and z coordinates.
        color: Marker colour.
        size: Marker size.

    Returns:
        A ``go.Scatter3d`` trace in ``'markers+text'`` mode.
    """
    xs, ys, zs = xyz[:, 0], xyz[:, 1], xyz[:, 2]
    marker_style = {"size": size, 'color': color}
    return go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode='markers+text',
        marker=marker_style,
    )
def plane_to_3d(points_2d, plane_normal):
    """Lift 2-D in-plane coordinates to 3-D points on a plane.

    An in-plane basis ``(r, u)`` is derived from the normal ``f``:
    ``r`` is the unit vector along ``f x (0, 1, 0.1)`` and ``u = f x r``.
    Each 2-D point ``(a, b)`` is mapped to ``a * r + b * u + f``.

    Args:
        points_2d: ``(N, 2)`` array-like (list, NumPy array or tensor) of
            in-plane coordinates.
        plane_normal: Array-like with exactly 3 elements. It is used as given
            (not normalised): the plane passes through the point
            ``plane_normal`` and ``u`` has length ``|plane_normal|``, so the
            basis is orthonormal only for a unit-length normal.

    Returns:
        A float32 tensor of shape ``(N, 3)``.

    Note:
        If ``plane_normal`` is parallel to ``(0, 1, 0.1)`` the cross product
        vanishes, ``r`` and ``u`` are zero, and every point collapses onto
        ``plane_normal``.
    """
    # as_tensor accepts lists/arrays/tensors; detach() keeps the result out of
    # any autograd graph, as the original torch.tensor(...) copy did.
    f = torch.as_tensor(plane_normal, dtype=torch.float32).detach().reshape(3, 1)
    helper = torch.tensor(
        [0.0, 1.0, 0.1], dtype=f.dtype, device=f.device
    ).unsqueeze(1)

    # The vectors are (3, 1) columns, so the 3-D components live on dim 0.
    # Normalising along dim 1 (size 1) would divide each component by its own
    # magnitude and yield a +/-1 sign vector instead of a unit vector.
    r = F.normalize(torch.cross(f, helper, dim=0), dim=0)
    u = torch.cross(f, r, dim=0)

    points_2d = torch.as_tensor(
        points_2d, dtype=torch.float32, device=f.device
    ).detach()
    # (N,) * (3, 1) broadcasts to (3, N); transpose to return one point per row.
    points_3d = points_2d[:, 0] * r + points_2d[:, 1] * u + f
    return points_3d.T
def my_draw_pointclouds(*point_clouds, size=1., show_axes=True):
    """Show one or more point clouds in a single interactive 3-D Plotly figure.

    Args:
        *point_clouds: Point sets indexable as ``pc[:, k]`` for k in 0..2.
            Objects exposing ``detach`` (e.g. PyTorch tensors) are moved to
            CPU, detached and converted to NumPy; anything else is passed
            through ``np.asarray``.
        size: Marker size used for every cloud.
        show_axes: When true, also draw unit x/y/z axis segments from the
            origin (red/blue/green).

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    palette = ['blue', 'red', 'green', 'orange', 'yellow']

    traces = []
    # Index into the palette modulo its length: zipping against the palette
    # would silently drop every cloud after the fifth.
    for index, point_cloud in enumerate(point_clouds):
        if hasattr(point_cloud, 'detach'):
            points = point_cloud.cpu().detach().numpy()
        else:
            points = np.asarray(point_cloud)

        traces.append(
            go.Scatter3d(
                x=points[:, 0],
                y=points[:, 1],
                z=points[:, 2],
                mode='markers',
                marker=dict(
                    size=size,
                    color=palette[index % len(palette)],
                    opacity=0.8,
                ),
            )
        )

    layout = go.Layout(
        scene=dict(
            xaxis=dict(title='X'),
            yaxis=dict(title='Y'),
            zaxis=dict(title='Z'),
        ),
        width=800, height=400,
    )

    if show_axes:
        traces += [
            make_vector([0, 0, 0, 1, 0, 0], 'red', 'x'),
            make_vector([0, 0, 0, 0, 1, 0], 'blue', 'y'),
            make_vector([0, 0, 0, 0, 0, 1], 'green', 'z'),
        ]

    fig = go.Figure(data=traces, layout=layout)
    fig.show()
# Usage example posted as a gist comment by manifoldhiker (Nov 24, 2022):

# Unit sphere centred at the origin, drawn as a faint surface.
sphere_x, sphere_y, sphere_z = ms(0, 0, 0, 1)

# NOTE: xyz_plot and color_plot are not defined in this snippet; they must
# already exist in the calling session.
data = [
    # Reference frame: unit x / y / z axis segments from the origin.
    make_vector([0,0,0,1,0,0], 'red'),
    make_vector([0,0,0,0,1,0], 'blue'),
    make_vector([0,0,0,0,0,1], 'green'),
    go.Surface(x=sphere_x, y=sphere_y, z=sphere_z, opacity=0.1),
    scatter3d(xyz_plot, color_plot),
]

fig = go.Figure(data=data)
fig.show()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment