8.3 Building coherent signal subnetworks

8.3 Building coherent signal subnetworks

# file format of the extra (non-PNG) copy written for each figure;
# a PNG copy is written as well when the figures are saved
mode = "svg"

import matplotlib

# global matplotlib font settings, applied to every figure drawn afterwards
font = dict(family='Dejavu Sans', weight='normal', size=20)

matplotlib.rc('font', **font)

import matplotlib
from matplotlib import pyplot as plt
import numpy as np
from graspologic.simulations import sample_edges

# display names of the five nodes (brain areas); used as heatmap tick labels
nodenames = [
    "SI", "L", "H/E", 
    "T/M", "BS"
]

# generate probability matrices
n = 5  # the number of nodes
# earthling class: every entry of the probability matrix is 0.3
P_earthling = 0.3*np.ones((n, n))
# boolean mask of the true signal subnetwork: the first row and the first
# column without the (0, 0) entry, i.e. all entries involving node 0
signal_subnetwork = np.zeros((n, n), dtype=bool)
signal_subnetwork[1:n, 0] = True
signal_subnetwork[0, 1:n] = True
# astronaut class: identical to the earthling class except on the mask.
# The 8 masked entries are filled in row-major order with two copies of
# 4 evenly spaced values from 0.4 to 0.9, so entry (0, j) equals entry
# (j, 0) and the matrix stays symmetric.
P_astronaut = np.copy(P_earthling)
P_astronaut[signal_subnetwork] = np.tile(np.linspace(0.4, 0.9, num=4), 2)

# sample the classes of each sample
M = 200  # the number of training and testing samples
pi_astronaut = 0.45
pi_earthling = 0.55
# fix the seed so the sampled class labels are reproducible
np.random.seed(0)
# label 1 = earthling, label 2 = astronaut (matches the order of `p`)
ytrain = np.random.choice([1,2], p=[pi_earthling, pi_astronaut], size=M)
ytest = np.random.choice([1,2], p=[pi_earthling, pi_astronaut], size=M)

# sample network realizations given the class of each sample
# Ps is indexed with y-1, so label 1 -> P_earthling and label 2 -> P_astronaut
Ps = [P_earthling, P_astronaut]
# the seed is reset before the networks are drawn
np.random.seed(0)
# one sampled network per label, stacked along a new last axis
# (presumably each sample is an n x n adjacency matrix, giving shape
# (n, n, M) -- confirm against graspologic's sample_edges)
Atrain = np.stack([sample_edges(Ps[y-1]) for y in ytrain], axis=2)
Atest = np.stack([sample_edges(Ps[y-1]) for y in ytest], axis=2)
from graspologic.subgraph import SignalSubgraph
K = 8  # the number of signal edges
V = 1  # the number of signal nodes

# the incoherent signal subnetwork estimator
# (constraints is a single number: only the edge budget K is given;
# the labels are shifted from {1, 2} to {0, 1} before fitting)
ssn_est_inco = SignalSubgraph()
ssn_est_inco.fit_transform(Atrain, labels=ytrain-1, constraints=K)

# the coherent signal subnetwork estimator
# (constraints is a pair: the edge budget K and the node budget V)
ssn_est_coherent = SignalSubgraph()
ssn_est_coherent.fit_transform(Atrain, labels=ytrain-1, constraints=[K, V])
(array([0, 4, 0, 2, 0, 3, 0, 1]), array([4, 0, 2, 0, 3, 0, 1, 0]))
# n x n indicator matrices of the estimated signal subnetworks:
# 1 at every index selected in the fitted `sigsub_`, 0 elsewhere

# incoherent estimate
ssn_incoherent = np.zeros((n, n))
ssn_incoherent[ssn_est_inco.sigsub_] = 1

# coherent estimate
ssn_coherent = np.zeros((n, n))
ssn_coherent[ssn_est_coherent.sigsub_] = 1
import os
from graphbook_code import heatmap

fig, axs = plt.subplots(1, 3, figsize=(18, 6), gridspec_kw={"width_ratios": [1.25, 1, 1.25]})

# axis titles and tick labels shared by all three panels
shared_kwargs = dict(xtitle="Brain area", ytitle="Brain area",
                     xticklabels=nodenames, yticklabels=nodenames)

# one (matrix, panel-specific keyword arguments) entry per panel, left to right
panels = [
    (signal_subnetwork,
     dict(title="(A) Signal subnetwork",
          legend_title="Part of $\\mathcal{S}$?", shrink=0.5)),
    (ssn_incoherent,
     dict(title="(B) Incoherent signal subnetwork",
          legend_title="Edge importance", cbar=False)),
    (ssn_coherent,
     dict(title="(C) Coherent signal subnetwork",
          legend_title="Part of $\\hat{\\mathcal{S}}$?", shrink=0.5)),
]
for ax, (matrix, panel_kwargs) in zip(axs, panels):
    heatmap(matrix.astype(int), ax=ax, **shared_kwargs, **panel_kwargs)
fig.tight_layout()

# save the figure: a copy in the format named by `mode` (unless that is
# already png), followed by a png copy
fname = "ssn_co"
os.makedirs("Figures", exist_ok=True)
out_formats = ["png"] if mode == "png" else [mode, "png"]
for ext in out_formats:
    os.makedirs(f"Figures/{ext:s}", exist_ok=True)
    fig.savefig(f"Figures/{ext:s}/{fname:s}.{ext:s}")
../../_images/baf3c7b75c6ce76fd2c45f5fd3a7665d5efa1f0e38b41a80430a28a8b875f60a.png
from sklearn.naive_bayes import BernoulliNB

def train_and_eval_coherent_ssn(Atrain, ytrain, Atest, ytest, K, V):
    """
    Train and test a coherent signal subnetwork classifier with K signal
    edges and V signal nodes.

    A SignalSubgraph estimator is fit on the training networks with the
    constraints [K, V]. The edges it selects are then used as the features
    of a Bernoulli naive Bayes classifier, which is scored on the testing
    networks.

    Parameters
    ----------
    Atrain : array
        Training networks, indexed as [row, column, sample].
    ytrain : array
        Class of each training network. It is shifted down by one before
        being handed to the estimator (assumes classes are coded 1 and 2).
    Atest : array
        Testing networks, indexed as [row, column, sample].
    ytest : array
        Class of each testing network, coded like ytrain.
    K : number
        The number of signal edges (cast to int).
    V : number
        The number of signal nodes (cast to int).

    Returns
    -------
    tuple
        (accuracy, fitted SignalSubgraph estimator, fitted BernoulliNB
        classifier), where accuracy is the fraction of testing networks
        whose class is predicted correctly.
    """
    ssn_mod = SignalSubgraph()
    ssn_mod.fit_transform(Atrain, labels=ytrain-1, constraints=[int(K), int(V)])

    # row and column indices of the edges in the estimated signal subnetwork
    rows, cols = ssn_mod.sigsub_[0], ssn_mod.sigsub_[1]

    # one row per sample, one column per selected edge
    Dtrain = Atrain[rows, cols, :].T
    classifier = BernoulliNB()
    # fit the classifier using the vector of classes for each sample
    classifier.fit(Dtrain, ytrain)

    # compute testing data on the estimated signal subnetwork
    Dtest = Atest[rows, cols, :].T
    yhat_test = classifier.predict(Dtest)

    # classifier accuracy is the fraction of predictions that are correct
    return (np.mean(yhat_test == ytest), ssn_mod, classifier)
from sklearn.model_selection import KFold
import pandas as pd
from tqdm import tqdm

# 20-fold cross validation over the training samples only
kf = KFold(n_splits=20, shuffle=True, random_state=0)
xv_res = []
for l, (train_index, test_index) in tqdm(enumerate(kf.split(range(0, M)))):
    # split the training networks/classes into this fold's train and held-out parts
    A_train, A_test = Atrain[:,:,train_index], Atrain[:,:,test_index]
    y_train, y_test = ytrain[train_index], ytrain[test_index]
    # held-out fold size; later used to weight this fold's accuracy
    nl = len(test_index)

    # grid: even edge counts from 2 up to (excluding) n*(n-1), and 1..n signal nodes
    for k in np.arange(2, n*(n-1), step=2):
        for v in range(1, n+1):
            try:
                acc_kl, _, _ = train_and_eval_coherent_ssn(A_train, y_train, A_test, y_test, k, v)
            except Exception:
                # a (k, v) combination that cannot be fit is recorded as a
                # missing accuracy; only ordinary errors are caught, so
                # KeyboardInterrupt/SystemExit still stop the loop
                acc_kl = np.nan
            xv_res.append({"Fold": l, "k": k, "nl": nl, "v": v, "Accuracy": acc_kl})
# long-format table: one row per (fold, k, v) combination
xv_data = pd.DataFrame(xv_res)

def weighted_avg(group):
    """
    Return the average of the 'Accuracy' column of `group`, weighted by
    its 'nl' column.

    With pandas' default NaN-skipping sum, a missing accuracy adds nothing
    to the numerator while its weight still counts in the denominator.
    """
    weights = group['nl']
    weighted_total = group['Accuracy'].mul(weights).sum()
    return weighted_total / weights.sum()

# fold-size-weighted accuracy for every (k, v) combination; the two columns
# weighted_avg reads are selected explicitly so that apply() does not operate
# on the grouping columns (which pandas deprecates)
xv_acc = (
    xv_data.groupby(["k", "v"])[["Accuracy", "nl"]]
    .apply(weighted_avg)
    .reset_index(name='Accuracy')
)
# convert the pandas dataframe (long format) to a data matrix (wide format)
df_hm = xv_acc.pivot(index="k", columns="v", values="Accuracy")
0it [00:00, ?it/s]
1it [00:00,  1.17it/s]
2it [00:01,  1.07it/s]
3it [00:02,  1.04it/s]
4it [00:03,  1.07it/s]
5it [00:04,  1.02it/s]
6it [00:05,  1.00s/it]
7it [00:06,  1.01it/s]
8it [00:07,  1.09it/s]
9it [00:08,  1.09it/s]
10it [00:09,  1.09it/s]
11it [00:10,  1.09it/s]
12it [00:11,  1.10it/s]
13it [00:12,  1.13it/s]
14it [00:12,  1.15it/s]
15it [00:13,  1.11it/s]
16it [00:14,  1.11it/s]
17it [00:15,  1.11it/s]
18it [00:16,  1.09it/s]
19it [00:17,  1.06it/s]
20it [00:18,  1.11it/s]
20it [00:18,  1.09it/s]
/tmp/ipykernel_2786/2829969769.py:26: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.
  xv_acc = xv_data.groupby(["k", "v"]).apply(weighted_avg).reset_index(name='Accuracy')
# the coherent signal subnetwork estimator, using the parameters from xv
# (6 signal edges, 1 signal node), refit on all of the training networks
ssn_est_coherent_xv = SignalSubgraph()
ssn_est_coherent_xv.fit_transform(Atrain, labels=ytrain-1, constraints=[6, 1])

# n x n indicator matrix: 1 on the estimated signal edges, 0 elsewhere
xv_rows, xv_cols = ssn_est_coherent_xv.sigsub_
ssn_coherent_xv = np.zeros((n, n))
ssn_coherent_xv[xv_rows, xv_cols] = 1
from graphbook_code import lpm_heatmap

import os

fig, axs = plt.subplots(1, 2, figsize=(15, 6))

# (A) cross-validated accuracy for every (signal edges, signal nodes) pair
lpm_heatmap(df_hm, legend_title="Accuracy", ax=axs[0], vmin=.75,
            xtitle="Number of signal nodes", ytitle="Number of signal edges",
            xticklabels=list(xv_acc["v"].unique()), yticklabels=list(xv_acc["k"].unique()),
            title="(A) Cross validation results", shrink=0.8)
# (B) the signal subnetwork estimated with the cross-validated parameters
heatmap(ssn_coherent_xv.astype(int), ax = axs[1], title="(B) Coherent estimate, with XV",
        xtitle="Brain area", ytitle="Brain area", shrink=0.5, xticklabels=nodenames, yticklabels=nodenames,
        legend_title="Part of $\\hat{\\mathcal{S}}$?")

fig.tight_layout()

fname = "ssn_co_acc"
# make sure the output directories exist, so this save also works when no
# earlier figure has been written yet
if mode != "png":
    os.makedirs(f"Figures/{mode:s}", exist_ok=True)
    fig.savefig(f"Figures/{mode:s}/{fname:s}.{mode:s}")

os.makedirs("Figures/png", exist_ok=True)
fig.savefig(f"Figures/png/{fname:s}.png")
../../_images/411edf2c88e7b46d9da24406676b181a465f10ca877d1af3c25c9bfa25d79753.png