Last active
June 25, 2022 11:16
-
-
Save dongkwan-kim/f9cba350c7df138a0f0a7848baff31d5 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Tuple, List | |
try: | |
import matplotlib.pyplot as plt | |
from matplotlib import cm | |
from matplotlib.ticker import LinearLocator | |
from mpl_toolkits.mplot3d import Axes3D | |
except ImportError: | |
pass | |
import numpy as np | |
import pandas as pd | |
def create_fake_data(crossed_z=True):
    """
    Build a synthetic pair of surfaces on a fixed 5x5 grid.

    The grid uses x in [0.1, 0.3, 0.5, 0.7, 0.9] and y in [1, 2, 4, 8, 16].
    The first surface is (x - 0.9)^2 + y / 10.

    :param crossed_z: if True, the second surface is x^2 + y / 10. The two
        surfaces are then equal along x = 0.45, so they cross. If False, the
        second surface is the first shifted down by 0.4.
    :return: (X, Y, Z1, Z2), each a (5, 5) array indexed as [y_index, x_index].
    """
    x_values = [0.1, 0.3, 0.5, 0.7, 0.9]
    y_values = [1.0, 2.0, 4.0, 8.0, 16.0]
    grid_x, grid_y = np.meshgrid(x_values, y_values)

    y_term = grid_y / 10
    z_first = (grid_x - 0.9) ** 2 + y_term
    z_second = grid_x ** 2 + y_term if crossed_z else z_first - 0.4
    return grid_x, grid_y, z_first, z_second
def format_custom_data(X, Y, Z_list: List) -> Tuple:
    """
    Mesh 1-D X/Y coordinates and bundle them with matching Z matrices.

    :param X: 1-D x coordinates, e.g., [0.1, 0.3, 0.5, 0.7, 0.9]
    :param Y: 1-D y coordinates, e.g., [1.0, 2.0, 4.0, 8.0, 16.0]
    :param Z_list: sequence of Z matrices (any array-likes). Each must have
        shape (N_Y, N_X): row i corresponds to Y[i] and column j to X[j].
        For the example inputs, meshed X is,
            [[0.1 0.3 0.5 0.7 0.9]
             [0.1 0.3 0.5 0.7 0.9]
             [0.1 0.3 0.5 0.7 0.9]
             [0.1 0.3 0.5 0.7 0.9]
             [0.1 0.3 0.5 0.7 0.9]]
        and meshed Y is,
            [[ 1.  1.  1.  1.  1.]
             [ 2.  2.  2.  2.  2.]
             [ 4.  4.  4.  4.  4.]
             [ 8.  8.  8.  8.  8.]
             [16. 16. 16. 16. 16.]]
    :return: (X_mesh, Y_mesh, *Z_arrays), all numpy arrays of shape (N_Y, N_X).
        With an empty Z_list this is just (X_mesh, Y_mesh).
    :raises ValueError: if any Z matrix does not have shape (N_Y, N_X).
    """
    X, Y = np.meshgrid(X, Y)
    Z_list = [np.asarray(z) for z in Z_list]
    # Validate every Z, not just the first. Raise instead of assert so the
    # check still runs under `python -O`.
    for i, z in enumerate(Z_list):
        if z.shape != X.shape:
            raise ValueError(
                f"Z_list[{i}] has shape {z.shape}, expected {X.shape} (N_Y, N_X)"
            )
    return (X, Y, *Z_list)
def table_to_custom_xyz1z2_data(table: List[List[float]], fill_value: float = -1.0) -> Tuple:
    """
    Convert a long-format [X, Y, Z1, Z2] table into meshed grids for plotting.

    Each row describes one grid point. The distinct X and Y values, sorted
    ascending, become the grid axes. Z1 and Z2 are scattered into
    (N_Y, N_X) matrices.

    :param table: rows of [x, y, z1, z2].
    :param fill_value: value written to grid cells that no row covers.
        Defaults to -1.0, the sentinel this function has always used.
        A finite sentinel is drawn like real data by plot_two_surfaces_3d,
        which passes Z straight to plot_surface. NaN would make that
        function's z.min()/z.max() colour normalization NaN.
    :return: (X_mesh, Y_mesh, Z1, Z2) as returned by format_custom_data.
    """
    df = pd.DataFrame(table, columns=["X", "Y", "Z1", "Z2"])
    # If the same (X, Y) point appears more than once, the last row wins.
    # This matches the previous row-by-row overwrite, but is now explicit.
    df = df.drop_duplicates(subset=["X", "Y"], keep="last")

    # np.unique returns sorted distinct values. searchsorted then maps each
    # row's exact X / Y value to its column / row index.
    xs = np.unique(df["X"].to_numpy())
    ys = np.unique(df["Y"].to_numpy())
    ix = np.searchsorted(xs, df["X"].to_numpy())
    iy = np.searchsorted(ys, df["Y"].to_numpy())

    shape = (len(ys), len(xs))
    Z1 = np.full(shape, fill_value, dtype=float)
    Z2 = np.full(shape, fill_value, dtype=float)
    Z1[iy, ix] = df["Z1"].to_numpy()
    Z2[iy, ix] = df["Z2"].to_numpy()
    return format_custom_data(xs, ys, Z_list=[Z1, Z2])
def plot_two_surfaces_3d(X, Y, Z1, Z2, path=None, show: bool = True) -> Tuple:
    """
    Draw two translucent surfaces on a single 3-D axis.

    :param X: meshed x coordinates, e.g. as returned by format_custom_data.
    :param Y: meshed y coordinates, same shape as X.
    :param Z1: first surface. Colours come from cm.spring, and the faces are
        overridden with RGBA (1, 0.9, 0.9, 0.3).
    :param Z2: second surface. Colours come from cm.winter, and the faces are
        overridden with RGBA (0.9, 0.9, 1, 0.3).
    :param path: if not None, the figure is saved to this path before showing.
    :param show: if True (the default, matching previous behaviour), call
        plt.show(). Pass False for batch/headless runs and close the returned
        figure yourself.
    :return: (fig, ax), so callers can further customize or close the figure.
    """
    fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

    def set_wireframe_and_surface(x, y, z, wireframe_cm, surface_rgb):
        # Colour each point by z, normalized over this surface's own range.
        colors = wireframe_cm(plt.Normalize(z.min(), z.max())(z))
        surf = ax.plot_surface(x, y, z,
                               facecolors=colors,
                               rstride=1, cstride=1, shade=True, linewidth=1)
        # Replace the face colours with one translucent RGBA colour.
        # NOTE(review): per the helper's name, the colormap colours are
        # presumably kept on the grid lines; confirm against plot_surface.
        surf.set_facecolor(surface_rgb)

    set_wireframe_and_surface(X, Y, Z1, cm.spring, (1, 0.9, 0.9, 0.3))
    set_wireframe_and_surface(X, Y, Z2, cm.winter, (0.9, 0.9, 1, 0.3))

    # Customize the z axis: 10 evenly spaced ticks with two decimals.
    ax.zaxis.set_major_locator(LinearLocator(10))
    # A format string is wrapped in a StrMethodFormatter automatically.
    ax.zaxis.set_major_formatter('{x:.02f}')

    fig.tight_layout()
    if path is not None:
        fig.savefig(path)
    if show:
        plt.show()
    return fig, ax
if __name__ == '__main__':
    # Which demo to run: "MESH" or "TABLE". Any other value runs nothing.
    FROM = "TABLE"

    def _mesh_demo():
        """Build two surfaces from an explicit (N_Y, N_X) Z matrix and its half."""
        z = np.asarray([[1.1, 1.3, 1.5, 1.7, 1.9],
                        [2.1, 2.3, 2.5, 2.7, 2.9],
                        [4.1, 4.3, 4.5, 4.7, 4.9],
                        [8.1, 8.3, 8.5, 8.7, 8.9],
                        [16.1, 16.3, 16.5, 16.7, 16.9]])
        return format_custom_data(
            X=[0.1, 0.3, 0.5, 0.7, 0.9],
            Y=[1.0, 2.0, 4.0, 8.0, 16.0],
            Z_list=[z, z / 2],
        )

    def _table_demo():
        """Build two surfaces from a long-format table of points."""
        # Columns:
        #   X    Y    Z1   Z2
        rows = [[0.1, 1.0, 0.0, 0.1],
                [0.1, 2.0, 0.3, 0.4],
                [0.5, 1.0, 0.6, 0.7],
                [0.5, 2.0, 0.9, 1.0],
                [0.7, 1.0, 0.6, 0.7],
                [0.7, 2.0, 0.9, 1.0]]
        return table_to_custom_xyz1z2_data(rows)

    # Each demo name maps to its (data builder, output file) pair.
    demos = {
        "MESH": (_mesh_demo, "./3d_mesh.pdf"),
        "TABLE": (_table_demo, "./3d_table.pdf"),
    }
    if FROM in demos:
        build_data, out_path = demos[FROM]
        plot_two_surfaces_3d(*build_data(), path=out_path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment