from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import matplotlib.pyplot as plt
import os
# One figure with two side-by-side panels: (A) the 2-simplex, (B) the 3-simplex.
fig = plt.figure(figsize=(14, 6))
# 2D plot: the 2-probability simplex is the line segment pi_1 + pi_2 = 1
# joining the unit vectors (1, 0) and (0, 1).
ax = fig.add_subplot(1, 2, 1)
x = [1, 0]
y = [0, 1]
ax.plot(x, y, color='black')
ax.set_xticks([0, .5, 1])
ax.set_yticks([0, .5, 1])
# Raw strings keep the backslash for mathtext; "\p" is an invalid escape
# sequence in a normal string literal (SyntaxWarning on Python 3.12+).
ax.set_xlabel(r"$\pi_1$")
ax.set_ylabel(r"$\pi_2$")
ax.set_title("(A) 2-probability simplex")
# Ensure the axes limits are [0, 1] and the unit square is drawn square.
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_aspect("equal")
# 3D plot: the 3-probability simplex pi_1 + pi_2 + pi_3 = 1 (pi_i >= 0),
# drawn as a translucent triangle whose vertices are the three unit vectors.
ax = fig.add_subplot(1, 2, 2, projection='3d')
x = [1, 0, 0]
y = [0, 1, 0]
z = [0, 0, 1]
verts = [list(zip(x, y, z))]
ax.add_collection3d(Poly3DCollection(verts, alpha=.6, facecolors='black', edgecolors='black'))
ax.view_init(elev=20, azim=10)
# Make axes clearer
ax.set_xticks([0, .5, 1])
ax.set_yticks([0, .5, 1])
ax.set_zticks([0, .5, 1])
# Raw strings: "\p" is not a valid Python escape sequence.
ax.set_xlabel(r"$\pi_1$", labelpad=10)
ax.set_ylabel(r"$\pi_2$", labelpad=10)
# mplot3d re-rotates axis labels longer than 4 characters at draw time,
# which would override rotation=0; disable that so the z label stays upright.
ax.zaxis.set_rotate_label(False)
ax.set_zlabel(r"$\pi_3$", labelpad=10, rotation=0)
# Add axis lines from the origin along each coordinate direction
ax.plot([0, 1], [0, 0], [0, 0], color='black')  # x-axis
ax.plot([0, 0], [0, 1], [0, 0], color='black')  # y-axis
ax.plot([0, 0], [0, 0], [0, 1], color='black')  # z-axis
ax.set_title("(B) 3-probability simplex")
# Remove panes and grid for cleaner look
ax.xaxis.pane.fill = False
ax.yaxis.pane.fill = False
ax.zaxis.pane.fill = False
ax.xaxis.pane.set_edgecolor('w')
ax.yaxis.pane.set_edgecolor('w')
ax.zaxis.pane.set_edgecolor('w')
ax.grid(False)
plt.tight_layout()
# Output settings. A PNG copy is always written to Figures/png/; if `mode`
# names a different format (e.g. "pdf", "svg"), an extra copy is written to
# Figures/<mode>/ as well.
mode = "pdf"  # NOTE(review): default format chosen here — adjust as needed
fname = "simplex"
# Save before plt.show(): with GUI backends show() blocks until the window is
# closed, so saving first writes the files independently of the display step.
# os.makedirs creates missing parent directories (including "Figures").
if mode != "png":
    os.makedirs(f"Figures/{mode:s}", exist_ok=True)
    fig.savefig(f"Figures/{mode:s}/{fname:s}.{mode:s}")
os.makedirs("Figures/png", exist_ok=True)
fig.savefig(f"Figures/png/{fname:s}.png")
plt.show()