6.6 Out-of-sample embedding

6.6 Out-of-sample embedding

import matplotlib

# Extra figure format written next to the PNG copy when figures are saved.
mode = "svg"

# Font settings registered with matplotlib's global rc configuration.
font = dict(family="Dejavu Sans", weight="normal", size=20)
matplotlib.rc("font", **font)

import matplotlib
from matplotlib import pyplot as plt
import numpy as np
from graspologic.simulations import sbm

# In-sample nodes: n in total, nk in each of the two communities.
n = 100
nk = 50
# Out-of-sample nodes added to the first and second community, respectively.
np1, np2 = 1, 2
# Block probability matrix for the two communities.
B = np.array([[0.6, 0.2], [0.2, 0.4]])
# Seed NumPy's global RNG, then sample a network whose community sizes
# include both the in-sample and the out-of-sample nodes.
np.random.seed(0)
community_sizes = [nk + np1, nk + np2]
A, zs = sbm(community_sizes, B, return_labels=True)
from graspologic.utils import remove_vertices

# Indices of the out-of-sample nodes: the last np1 nodes of the first
# community block and the last np2 nodes of the second community block.
# They are derived from nk/np1/np2 instead of being written out by hand, so
# they stay correct if the number of out-of-sample nodes is changed.
# NOTE: this relies on the sampled nodes being ordered community by
# community, exactly as the original hand-written indices did.
first_block_end = nk + np1
oos_idx = [
    *range(nk, first_block_end),
    *range(first_block_end + nk, first_block_end + nk + np2),
]
# Split A into the in-sample adjacency matrix and the adjacency
# vectors A prime of the removed (out-of-sample) nodes.
Ain, Aoos = remove_vertices(A, indices=oos_idx, return_removed=True)
from graspologic.embed import AdjacencySpectralEmbed as ase

# Adjacency spectral embedding, constructed with its default arguments.
oos_embedder = ase()
# estimate latent positions for the in-sample nodes
# using the subnetwork induced by the in-sample nodes
Xhat_in = oos_embedder.fit_transform(Ain)
# Reuse the embedder fitted above to estimate latent positions for the
# out-of-sample nodes from their adjacency vectors; this must run after
# the fit on Ain.
Xhat_oos = oos_embedder.transform(Aoos)
# Report the shape of the out-of-sample estimates.
print(Xhat_oos.shape)
(3, 2)
from graphbook_code import heatmap, lpm_heatmap, plot_latents
import os

# Community labels of the in-sample nodes, shifted by one for display.
zin = np.delete(zs, oos_idx) + 1
fig, axs = plt.subplots(
    1, 3, figsize=(18, 6), gridspec_kw={"width_ratios": [2, 1, 2]}
)
ax_adj, ax_oos, ax_lat = axs

# Panel (A): adjacency matrix of the in-sample subnetwork.
heatmap(
    Ain.astype(int),
    title="",
    xtitle="In-sample node",
    ax=ax_adj,
    inner_hier_labels=zin,
    xticks=[0.5, 49.5, 99.5],
    xticklabels=[1, 50, 100],
    yticks=[0.5, 49.5, 99.5],
    yticklabels=[1, 50, 100],
    cbar=False,
)
ax_adj.set_title("(A) Adjacency matrix", pad=50, loc="left")

# Panel (B): adjacency vectors of the out-of-sample nodes, transposed so
# that each out-of-sample node is one column.
lpm_heatmap(
    Aoos.T.astype(int),
    title="",
    xtitle="Out-of-sample node",
    ytitle="In-sample node",
    xticks=[0.5, 1.5, 2.5],
    xticklabels=[1, 2, 3],
    yticks=[0.5, 49.5, 99.5],
    yticklabels=[1, 50, 100],
    ax=ax_oos,
)
ax_oos.set_title("(B) $A'^\\top$", loc="left", pad=25)

# Panel (C): in-sample estimates drawn faintly, with the out-of-sample
# estimates overlaid as larger markers.
plot_latents(
    Xhat_in,
    labels=zin,
    ax=ax_lat,
    s=50,
    alpha=0.3,
    title="(C) Estimated latent positions",
    xtitle="Dimension 1",
    ytitle="Dimension 2",
)
ax_lat.set_title("(C) Estimated latent positions", loc="left", pad=25)
plot_latents(Xhat_oos, ax=ax_lat, labels=zs[oos_idx] + 1, s=100)

# Keep only the first two legend entries so the overlay does not
# add duplicates to the legend.
handles, labels = ax_lat.get_legend_handles_labels()
ax_lat.legend(handles=handles[:2], labels=labels[:2], title="Community")

# Label each out-of-sample point with its 1-based position in Xhat_oos.
for i, position in enumerate(Xhat_oos):
    ax_lat.annotate(f"OOS Node {i+1:d}", xy=(position[0], position[1]))
fig.tight_layout()

# Save the figure: first in the format named by `mode` (unless that is
# already PNG), then always as PNG.
os.makedirs("Figures", exist_ok=True)
fname = "oos_ex"
output_formats = ["png"] if mode == "png" else [mode, "png"]
for fmt in output_formats:
    os.makedirs(f"Figures/{fmt:s}", exist_ok=True)
    fig.savefig(f"Figures/{fmt:s}/{fname:s}.{fmt:s}")
../../_images/1564d719592d18654851ceaac1c060b8c7a022220bcabb0f4d58c4fd5691d047.png