(ch2:select)=
# 2.4 Select and Train

In [None]:
mode = "svg"  # format of the saved figures; a PNG copy is always written as well

import matplotlib

# Global font settings applied to every matplotlib figure created afterwards.
font = dict(family='Dejavu Sans', weight='normal', size=20)
matplotlib.rc('font', **font)

import matplotlib
from matplotlib import pyplot as plt

In [None]:
import os
import urllib
import boto3
from botocore import UNSIGNED
from botocore.client import Config
from graspologic.utils import import_edgelist
import numpy as np
import glob
from tqdm import tqdm

# the AWS bucket the data is stored in
# Public S3 bucket that hosts the connectomes.
BUCKET_ROOT = "open-neurodata"
parcellation = "Schaefer400"
# Key prefix of the m2g functional connectomes for the chosen parcellation.
FMRI_PREFIX = (
    "m2g/Functional/BNU1-11-12-20-m2g-func/Connectomes/"
    f"{parcellation}_space-MNI152NLin6_res-2x2x2.nii.gz/"
)
# Local folder that the connectomes are downloaded into.
FMRI_PATH = os.path.join("datasets", "fmri")
# S3 keys containing this substring are skipped when downloading.
DS_KEY = "abs_edgelist"

def fetch_fmri_data(bucket=BUCKET_ROOT, fmri_prefix=FMRI_PREFIX,
                    output=FMRI_PATH, name=DS_KEY):
    """
    Download fMRI connectomes from an AWS S3 bucket.

    Parameters
    ----------
    bucket : str
        Name of the S3 bucket to read from (accessed without credentials).
    fmri_prefix : str
        Key prefix under which the connectome files are stored.
    output : str
        Local directory to download into; created if it does not exist.
    name : str
        Objects whose key contains this substring are skipped.

    Files that already exist in ``output`` are not downloaded again.
    """
    # use the ``output`` argument, not the module-level default path
    os.makedirs(output, exist_ok=True)
    # anonymous (unsigned) client: no AWS credentials are needed
    s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
    # a single listing call returns at most 1000 keys, so page through all of
    # them; pages without matches have no 'Contents' entry
    paginator = s3.get_paginator('list_objects_v2')
    s3_keys = [
        obj['Key']
        for page in paginator.paginate(Bucket=bucket, Prefix=fmri_prefix)
        for obj in page.get('Contents', [])
    ]
    for s3_key in tqdm(s3_keys):
        # skip the excluded files
        if name in s3_key:
            continue
        fname = s3_key.split('/')[-1]
        if not fname:
            # keys ending in '/' carry no file name to download into
            continue
        op_fname = os.path.join(output, fname)
        if not os.path.exists(op_fname):
            s3.download_file(bucket, s3_key, op_fname)

def read_fmri_data(path=FMRI_PATH):
    """
    Load every ``*.csv`` edgelist in ``path`` as an adjacency matrix.

    Parameters
    ----------
    path : str
        Directory containing the edgelist files.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_networks, n_nodes, n_nodes)``; networks are
        ordered by sorted file name.

    Raises
    ------
    FileNotFoundError
        If ``path`` contains no ``*.csv`` files.
    """
    fnames = sorted(glob.glob(os.path.join(path, "*.csv")))
    if not fnames:
        raise FileNotFoundError(f"No edgelist (*.csv) files found in {path!r}")
    # Importing all files in one call lets graspologic place every network on
    # a common vertex set (per its import_edgelist documentation), so the
    # matrices share one shape and node order even when a node is absent from
    # some edgelists. Importing files one by one could produce differently
    # sized matrices that np.stack cannot combine.
    networks = import_edgelist(fnames)
    # a single input may come back as a bare matrix instead of a list
    if isinstance(networks, np.ndarray):
        networks = [networks]
    return np.stack(networks, axis=0)

In [None]:
# download whatever is missing locally, then load every connectome
fetch_fmri_data()
As = read_fmri_data()
# the remainder of the section works with the first connectome only
A = As[0]

In [None]:
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.pipeline import Pipeline
from graspologic.utils import pass_to_ranks

def remove_isolates(A):
    """
    Remove isolated nodes from the adjacency matrix ``A``.

    A node is isolated when both its row and its column contain no non-zero
    entries, i.e. it has neither outgoing nor incoming edges. Testing for
    non-zero entries (rather than a zero weight sum) keeps nodes whose
    positive and negative edge weights happen to cancel out.

    Parameters
    ----------
    A : numpy.ndarray
        Square adjacency matrix.

    Returns
    -------
    numpy.ndarray
        ``A`` restricted to the rows and columns of the non-isolated nodes.
    """
    has_edge = A != 0
    # empty column -> no incoming edges; empty row -> no outgoing edges
    isolated = ~has_edge.any(axis=0) & ~has_edge.any(axis=1)
    keep = ~isolated
    print("Purging {:d} nodes...".format(int(isolated.sum())))
    return A[np.ix_(keep, keep)]

class CleanData(BaseEstimator, TransformerMixin):
    """
    Pipeline step that drops isolated nodes and takes absolute edge weights.

    Stateless: ``fit`` learns nothing; all work happens in ``transform``.
    """

    def fit(self, X, y=None):
        """Return ``self``; ``y`` is accepted and ignored, per the scikit-learn API."""
        return self

    def transform(self, X):
        """
        Remove isolated nodes from ``X`` and return the absolute edge weights.

        The cleaned matrix is also stored on ``self.A_``.
        """
        print("Cleaning data...")
        A_cleaned = remove_isolates(X)
        self.A_ = np.abs(A_cleaned)
        return self.A_

class FeatureScaler(BaseEstimator, TransformerMixin):
    """
    Pipeline step that rescales edge weights with graspologic's ``pass_to_ranks``.

    Stateless: ``fit`` learns nothing; all work happens in ``transform``.
    """

    def fit(self, X, y=None):
        """Return ``self``; ``y`` is accepted and ignored, per the scikit-learn API."""
        return self

    def transform(self, X):
        """Return ``X`` with its edge weights passed through ``pass_to_ranks``."""
        print("Scaling edge-weights...")
        return pass_to_ranks(X)

# preprocessing: clean (drop isolates, |weights|), then rank-scale the weights
num_pipeline = Pipeline(steps=[('cleaner', CleanData()), ('scaler', FeatureScaler())])

# apply both steps to the selected connectome
A_xfm = num_pipeline.fit_transform(A)

In [None]:
from graspologic.embed import AdjacencySpectralEmbed

# 3-dimensional adjacency spectral embedding; svd_seed=0 for reproducibility
embedding = AdjacencySpectralEmbed(
    n_components=3,
    svd_seed=0,
).fit_transform(A_xfm)

In [None]:
from graspologic.plot import pairplot
from matplotlib import pyplot as plt

fig = pairplot(embedding, title="(A) Spectral Embedding for connectome")
fig.tight_layout()

# Save in `mode` (unless that is already PNG), then always as PNG; the PNG
# copy is the one read back later to assemble the combined figure.
fname = "pairplots0"
for fmt in ([mode] if mode != "png" else []) + ["png"]:
    os.makedirs(f"Figures/{fmt:s}", exist_ok=True)
    fig.savefig(f"Figures/{fmt:s}/{fname:s}.{fmt:s}")

In [None]:
from sklearn.cluster import KMeans

# two-cluster KMeans on the spectral embedding (seeded for reproducibility)
labels = KMeans(n_clusters=2, random_state=0).fit_predict(embedding)
fig = pairplot(embedding, labels=labels, legend_name="Predicted Clusters",
               title="(B) KMeans clustering")

fig.tight_layout()

fname = "pairplots1"
if mode != "png":
    # create the folder here too, so this cell does not depend on another one
    os.makedirs(f"Figures/{mode:s}", exist_ok=True)
    fig.savefig(f"Figures/{mode:s}/{fname:s}.{mode:s}")

os.makedirs("Figures/png", exist_ok=True)
fig.savefig(f"Figures/png/{fname:s}.png")

In [None]:
import matplotlib.image as mpimg

# Place the two saved pair plots side by side; the second panel gets more
# width (presumably because it also contains the cluster legend).
fig, axs = plt.subplots(1, 2, figsize=(20, 10), gridspec_kw={"width_ratios": [1, 1.39]})

axs[0].imshow(mpimg.imread('Figures/png/pairplots0.png'))
axs[1].imshow(mpimg.imread('Figures/png/pairplots1.png'))
# the panels are images: hide the x and y axes
for ax in axs.ravel():
    ax.set_axis_off()
fig.tight_layout()

fname = "pairplots"
if mode != "png":
    # create the folder here too, so this cell does not depend on another one
    os.makedirs(f"Figures/{mode:s}", exist_ok=True)
    fig.savefig(f"Figures/{mode:s}/{fname:s}.{mode:s}")

fig.savefig(f"Figures/png/{fname:s}.png")

In [None]:
from graspologic.cluster import KMeansCluster

# KMeans with automatic choice of the number of clusters (at most 10)
labels = KMeansCluster(max_clusters=10, random_state=0).fit_predict(embedding)

fig = pairplot(
    embedding,
    labels=labels,
    title="(A) KMeans clustering, automatic selection",
    legend_name="Predicted Clusters",
)
fig.tight_layout()

# save in `mode` (unless that is already PNG), then always as PNG
fname = "pairplot_impute0"
for ext in ([mode] if mode != "png" else []) + ["png"]:
    fig.savefig(f"Figures/{ext:s}/{fname:s}.{ext:s}")

In [None]:
from graspologic.cluster import AutoGMMCluster

# Gaussian mixture with automatic choice of the number of components (at most 10)
labels = AutoGMMCluster(max_components=10, random_state=0).fit_predict(embedding)
# Pass the seaborn palette *name*: it needs no `seaborn` import in this cell
# (no `sns` is imported in this notebook section) and yields exactly one
# colorblind-safe color per predicted cluster, instead of a fixed 10-color
# list whose length need not match the number of clusters.
fig = pairplot(embedding, labels=labels, title="(B) AutoGMM Clustering, automatic selection",
               legend_name="Predicted Clusters", palette="colorblind")

fig.tight_layout()

fname = "pairplot_impute1"
if mode != "png":
    # create the folder here too, so this cell does not depend on another one
    os.makedirs(f"Figures/{mode:s}", exist_ok=True)
    fig.savefig(f"Figures/{mode:s}/{fname:s}.{mode:s}")

os.makedirs("Figures/png", exist_ok=True)
fig.savefig(f"Figures/png/{fname:s}.png")

In [None]:
# place the two saved automatic-selection pair plots side by side
fig, axs = plt.subplots(1, 2, figsize=(20, 10), gridspec_kw={"width_ratios": [1, 1]})

axs[0].imshow(mpimg.imread('Figures/png/pairplot_impute0.png'))
axs[1].imshow(mpimg.imread('Figures/png/pairplot_impute1.png'))
# the panels are images: hide the x and y axes (a plain loop, so no list of
# None values is echoed as cell output)
for ax in axs.ravel():
    ax.set_axis_off()
fig.tight_layout()

fname = "pairplot_impute"
if mode != "png":
    # create the folder here too, so this cell does not depend on another one
    os.makedirs(f"Figures/{mode:s}", exist_ok=True)
    fig.savefig(f"Figures/{mode:s}/{fname:s}.{mode:s}")

fig.savefig(f"Figures/png/{fname:s}.png")