2.4 Select and Train

2.4 Select and Train

mode = "svg"  # file format (extension) used when saving the figures below

import matplotlib

# global font settings applied to every subsequent matplotlib figure
font = dict(family='Dejavu Sans', weight='normal', size=20)
matplotlib.rc('font', **font)

import matplotlib
from matplotlib import pyplot as plt
import os
import urllib
import boto3
from botocore import UNSIGNED
from botocore.client import Config
from graspologic.utils import import_edgelist
import numpy as np
import glob
from tqdm import tqdm

# S3 bucket holding the connectome data
BUCKET_ROOT = "open-neurodata"
# brain parcellation whose connectomes are downloaded
parcellation = "Schaefer400"
# S3 key prefix under which the chosen parcellation's connectomes live
FMRI_PREFIX = (
    "m2g/Functional/BNU1-11-12-20-m2g-func/Connectomes/"
    f"{parcellation}_space-MNI152NLin6_res-2x2x2.nii.gz/"
)
# local directory the files are downloaded into
FMRI_PATH = os.path.join("datasets", "fmri")
# S3 objects whose key contains this substring are NOT downloaded
DS_KEY = "abs_edgelist"

def fetch_fmri_data(bucket=BUCKET_ROOT, fmri_prefix=FMRI_PREFIX,
                    output=FMRI_PATH, name=DS_KEY):
    """
    Download fMRI connectome files from AWS S3 into a local directory.

    Every object under ``fmri_prefix`` whose key does NOT contain ``name`` is
    saved into ``output`` under its base filename. Files that already exist
    locally are not downloaded again.

    Parameters
    ----------
    bucket : str
        S3 bucket to read from. Access is unsigned, so no credentials are used.
    fmri_prefix : str
        Key prefix selecting the objects to consider.
    output : str
        Local directory to download into; created if missing.
    name : str
        Substring marking objects to skip.
    """
    # previously the global FMRI_PATH was used here, silently ignoring `output`
    os.makedirs(output, exist_ok=True)
    # start an anonymous (unsigned) S3 client
    s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
    # one list call returns at most 1000 keys, so walk every result page;
    # a page with no matching keys has no "Contents" entry at all
    paginator = s3.get_paginator('list_objects_v2')
    bucket_conts = [
        obj
        for page in paginator.paginate(Bucket=bucket, Prefix=fmri_prefix)
        for obj in page.get("Contents", [])
    ]
    for s3_key in tqdm(bucket_conts):
        s3_object = s3_key['Key']
        # skip the excluded variant of each file
        if name in s3_object:
            continue
        fname = s3_object.split('/')[-1]
        # keys ending in '/' are folder placeholders with no file to fetch
        if not fname:
            continue
        op_fname = os.path.join(output, fname)
        if not os.path.exists(op_fname):
            s3.download_file(bucket, s3_object, op_fname)

def read_fmri_data(path=FMRI_PATH):
    """
    Load every ``*.csv`` edgelist in ``path`` as an adjacency matrix.

    Files are read in sorted filename order, so the stacking order is
    deterministic.

    Parameters
    ----------
    path : str
        Directory containing the edgelist CSV files.

    Returns
    -------
    numpy.ndarray
        The matrices returned by graspologic's ``import_edgelist``, stacked
        along a new leading axis, so ``result[i]`` belongs to the i-th file in
        sorted order. ``np.stack`` requires every matrix to have the same shape.

    Raises
    ------
    ValueError
        If ``path`` contains no ``*.csv`` files.
    """
    fnames = sorted(glob.glob(os.path.join(path, "*.csv")))
    if not fnames:
        # np.stack([]) would fail with an opaque message; explain the cause
        raise ValueError(
            f"No edgelist CSV files found in {path!r}; "
            "download them first (e.g. with fetch_fmri_data())."
        )
    # import each edgelist file with graspologic
    networks = [import_edgelist(fname) for fname in tqdm(fnames)]
    return np.stack(networks, axis=0)
# download any missing connectome files, then load them all into one array
fetch_fmri_data()
As = read_fmri_data()
# continue the analysis with the first network only
A = As[0, :, :]
  0%|          | 0/212 [00:00<?, ?it/s]
100%|██████████| 212/212 [00:00<00:00, 198170.82it/s]

  0%|          | 0/106 [00:00<?, ?it/s]
  1%|          | 1/106 [00:00<00:30,  3.39it/s]
  2%|▏         | 2/106 [00:00<00:30,  3.40it/s]
  3%|▎         | 3/106 [00:00<00:30,  3.41it/s]
  4%|▍         | 4/106 [00:01<00:29,  3.41it/s]
  5%|▍         | 5/106 [00:01<00:29,  3.42it/s]
  6%|▌         | 6/106 [00:01<00:29,  3.41it/s]
  7%|▋         | 7/106 [00:02<00:29,  3.41it/s]
  8%|▊         | 8/106 [00:02<00:28,  3.43it/s]
  8%|▊         | 9/106 [00:02<00:28,  3.44it/s]
  9%|▉         | 10/106 [00:02<00:28,  3.40it/s]
 10%|█         | 11/106 [00:03<00:28,  3.35it/s]
 11%|█▏        | 12/106 [00:03<00:28,  3.33it/s]
 12%|█▏        | 13/106 [00:03<00:27,  3.35it/s]
 13%|█▎        | 14/106 [00:04<00:27,  3.38it/s]
 14%|█▍        | 15/106 [00:04<00:26,  3.40it/s]
 15%|█▌        | 16/106 [00:04<00:26,  3.43it/s]
 16%|█▌        | 17/106 [00:05<00:26,  3.41it/s]
 17%|█▋        | 18/106 [00:05<00:25,  3.45it/s]
 18%|█▊        | 19/106 [00:05<00:25,  3.46it/s]
 19%|█▉        | 20/106 [00:05<00:25,  3.44it/s]
 20%|█▉        | 21/106 [00:06<00:24,  3.44it/s]
 21%|██        | 22/106 [00:06<00:24,  3.45it/s]
 22%|██▏       | 23/106 [00:06<00:24,  3.44it/s]
 23%|██▎       | 24/106 [00:07<00:23,  3.42it/s]
 24%|██▎       | 25/106 [00:07<00:23,  3.43it/s]
 25%|██▍       | 26/106 [00:07<00:23,  3.45it/s]
 25%|██▌       | 27/106 [00:07<00:23,  3.42it/s]
 26%|██▋       | 28/106 [00:08<00:22,  3.45it/s]
 27%|██▋       | 29/106 [00:08<00:22,  3.47it/s]
 28%|██▊       | 30/106 [00:08<00:21,  3.47it/s]
 29%|██▉       | 31/106 [00:09<00:21,  3.44it/s]
 30%|███       | 32/106 [00:09<00:21,  3.44it/s]
 31%|███       | 33/106 [00:09<00:21,  3.45it/s]
 32%|███▏      | 34/106 [00:09<00:20,  3.43it/s]
 33%|███▎      | 35/106 [00:10<00:20,  3.44it/s]
 34%|███▍      | 36/106 [00:10<00:20,  3.45it/s]
 35%|███▍      | 37/106 [00:10<00:20,  3.45it/s]
 36%|███▌      | 38/106 [00:11<00:19,  3.43it/s]
 37%|███▋      | 39/106 [00:11<00:19,  3.43it/s]
 38%|███▊      | 40/106 [00:11<00:19,  3.44it/s]
 39%|███▊      | 41/106 [00:11<00:19,  3.41it/s]
 40%|███▉      | 42/106 [00:12<00:18,  3.44it/s]
 41%|████      | 43/106 [00:12<00:18,  3.44it/s]
 42%|████▏     | 44/106 [00:12<00:18,  3.44it/s]
 42%|████▏     | 45/106 [00:13<00:17,  3.43it/s]
 43%|████▎     | 46/106 [00:13<00:17,  3.43it/s]
 44%|████▍     | 47/106 [00:13<00:17,  3.42it/s]
 45%|████▌     | 48/106 [00:14<00:17,  3.40it/s]
 46%|████▌     | 49/106 [00:14<00:16,  3.41it/s]
 47%|████▋     | 50/106 [00:14<00:16,  3.42it/s]
 48%|████▊     | 51/106 [00:14<00:16,  3.43it/s]
 49%|████▉     | 52/106 [00:15<00:15,  3.41it/s]
 50%|█████     | 53/106 [00:15<00:15,  3.43it/s]
 51%|█████     | 54/106 [00:15<00:15,  3.42it/s]
 52%|█████▏    | 55/106 [00:16<00:15,  3.39it/s]
 53%|█████▎    | 56/106 [00:16<00:14,  3.41it/s]
 54%|█████▍    | 57/106 [00:16<00:14,  3.42it/s]
 55%|█████▍    | 58/106 [00:16<00:14,  3.43it/s]
 56%|█████▌    | 59/106 [00:17<00:13,  3.39it/s]
 57%|█████▋    | 60/106 [00:18<00:30,  1.49it/s]
 58%|█████▊    | 61/106 [00:19<00:24,  1.80it/s]
 58%|█████▊    | 62/106 [00:19<00:20,  2.12it/s]
 59%|█████▉    | 63/106 [00:19<00:17,  2.41it/s]
 60%|██████    | 64/106 [00:19<00:15,  2.66it/s]
 61%|██████▏   | 65/106 [00:20<00:14,  2.85it/s]
 62%|██████▏   | 66/106 [00:20<00:13,  3.00it/s]
 63%|██████▎   | 67/106 [00:20<00:12,  3.12it/s]
 64%|██████▍   | 68/106 [00:21<00:11,  3.21it/s]
 65%|██████▌   | 69/106 [00:21<00:11,  3.29it/s]
 66%|██████▌   | 70/106 [00:21<00:10,  3.33it/s]
 67%|██████▋   | 71/106 [00:21<00:10,  3.32it/s]
 68%|██████▊   | 72/106 [00:22<00:10,  3.35it/s]
 69%|██████▉   | 73/106 [00:22<00:09,  3.39it/s]
 70%|██████▉   | 74/106 [00:22<00:09,  3.42it/s]
 71%|███████   | 75/106 [00:23<00:08,  3.45it/s]
 72%|███████▏  | 76/106 [00:23<00:08,  3.47it/s]
 73%|███████▎  | 77/106 [00:23<00:08,  3.49it/s]
 74%|███████▎  | 78/106 [00:23<00:07,  3.50it/s]
 75%|███████▍  | 79/106 [00:24<00:07,  3.46it/s]
 75%|███████▌  | 80/106 [00:24<00:07,  3.46it/s]
 76%|███████▋  | 81/106 [00:24<00:07,  3.46it/s]
 77%|███████▋  | 82/106 [00:25<00:06,  3.47it/s]
 78%|███████▊  | 83/106 [00:25<00:06,  3.45it/s]
 79%|███████▉  | 84/106 [00:25<00:06,  3.45it/s]
 80%|████████  | 85/106 [00:26<00:06,  3.46it/s]
 81%|████████  | 86/106 [00:26<00:05,  3.42it/s]
 82%|████████▏ | 87/106 [00:26<00:05,  3.44it/s]
 83%|████████▎ | 88/106 [00:26<00:05,  3.46it/s]
 84%|████████▍ | 89/106 [00:27<00:04,  3.48it/s]
 85%|████████▍ | 90/106 [00:27<00:04,  3.47it/s]
 86%|████████▌ | 91/106 [00:27<00:04,  3.44it/s]
 87%|████████▋ | 92/106 [00:28<00:04,  3.46it/s]
 88%|████████▊ | 93/106 [00:28<00:03,  3.43it/s]
 89%|████████▊ | 94/106 [00:28<00:03,  3.46it/s]
 90%|████████▉ | 95/106 [00:28<00:03,  3.48it/s]
 91%|█████████ | 96/106 [00:29<00:02,  3.50it/s]
 92%|█████████▏| 97/106 [00:29<00:02,  3.47it/s]
 92%|█████████▏| 98/106 [00:29<00:02,  3.48it/s]
 93%|█████████▎| 99/106 [00:30<00:02,  3.49it/s]
 94%|█████████▍| 100/106 [00:30<00:01,  3.47it/s]
 95%|█████████▌| 101/106 [00:30<00:01,  3.50it/s]
 96%|█████████▌| 102/106 [00:30<00:01,  3.51it/s]
 97%|█████████▋| 103/106 [00:31<00:00,  3.51it/s]
 98%|█████████▊| 104/106 [00:31<00:00,  3.51it/s]
 99%|█████████▉| 105/106 [00:31<00:00,  3.48it/s]
100%|██████████| 106/106 [00:32<00:00,  3.49it/s]
100%|██████████| 106/106 [00:32<00:00,  3.31it/s]

from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.pipeline import Pipeline
from graspologic.utils import pass_to_ranks

def remove_isolates(A):
    """
    Remove isolated nodes from the adjacency matrix A.

    A node counts as isolated when it has no nonzero entry in either its row
    or its column. Only these entries are checked; their signs and magnitudes
    are ignored. This way a node whose signed edge weights happen to sum to
    zero, or a node in a directed graph with only outgoing edges, is kept.

    Parameters
    ----------
    A : numpy.ndarray
        Square adjacency matrix.

    Returns
    -------
    numpy.ndarray
        A copy of ``A`` restricted to the non-isolated nodes.
    """
    nonzero = A != 0
    # connected if the node has any incoming or any outgoing edge
    connected = nonzero.any(axis=0) | nonzero.any(axis=1)
    n_isolated = int((~connected).sum())
    print("Purging {:d} nodes...".format(n_isolated))
    # keep the same subset of rows and columns
    return A[np.ix_(connected, connected)]

class CleanData(BaseEstimator, TransformerMixin):
    """
    Pipeline step that removes isolated nodes and takes absolute edge weights.

    ``fit`` learns nothing; all of the work happens in ``transform``.
    """

    def fit(self, X, y=None):
        """
        Return ``self`` unchanged.

        ``y`` is accepted (and ignored) to follow the scikit-learn estimator
        API, so the step also works when a Pipeline forwards a target.
        """
        return self

    def transform(self, X):
        """
        Remove isolated nodes from adjacency matrix ``X``, then take ``|X|``.

        The result is also stored on ``self.A_``.
        """
        print("Cleaning data...")
        Acleaned = remove_isolates(X)
        A_abs_cl = np.abs(Acleaned)
        self.A_ = A_abs_cl
        return self.A_

class FeatureScaler(BaseEstimator, TransformerMixin):
    """Pipeline step that rescales edge weights with graspologic's ``pass_to_ranks``."""

    def fit(self, X, y=None):
        """
        Return ``self`` unchanged.

        ``y`` is accepted (and ignored) to follow the scikit-learn estimator
        API, so the step also works when a Pipeline forwards a target.
        """
        return self

    def transform(self, X):
        """Return ``pass_to_ranks(X)``, the rank-based rescaling of ``X``'s edge weights."""
        print("Scaling edge-weights...")
        A_scaled = pass_to_ranks(X)
        return A_scaled

# preprocessing chain: clean the matrix first, then rescale its edge weights
num_pipeline = Pipeline(steps=[
    ("cleaner", CleanData()),
    ("scaler", FeatureScaler()),
])

# apply both steps to the adjacency matrix A
A_xfm = num_pipeline.fit_transform(A)
Cleaning data...
Purging 0 nodes...
Scaling edge-weights...
from graspologic.embed import AdjacencySpectralEmbed

# embed the preprocessed matrix into 3 dimensions; svd_seed fixes the randomness
ase = AdjacencySpectralEmbed(n_components=3, svd_seed=0)
embedding = ase.fit_transform(A_xfm)
from graspologic.plot import pairplot
from matplotlib import pyplot as plt

fig = pairplot(embedding, title="(A) Spectral Embedding for connectome")
fig.tight_layout()

fname = "pairplots0"
# save in `mode` format (unless that is PNG) and then always as PNG,
# creating each output folder on demand
formats = ["png"] if mode == "png" else [mode, "png"]
for ext in formats:
    os.makedirs(f"Figures/{ext:s}", exist_ok=True)
    fig.savefig(f"Figures/{ext:s}/{fname:s}.{ext:s}")
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/seaborn/axisgrid.py:1513: UserWarning: Ignoring `palette` because no `hue` variable has been assigned.
  func(x=vector, **plot_kwargs)
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/seaborn/axisgrid.py:1513: UserWarning: Ignoring `palette` because no `hue` variable has been assigned.
  func(x=vector, **plot_kwargs)
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/seaborn/axisgrid.py:1513: UserWarning: Ignoring `palette` because no `hue` variable has been assigned.
  func(x=vector, **plot_kwargs)
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/seaborn/axisgrid.py:1615: UserWarning: Ignoring `palette` because no `hue` variable has been assigned.
  func(x=x, y=y, **kwargs)
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/seaborn/axisgrid.py:1615: UserWarning: Ignoring `palette` because no `hue` variable has been assigned.
  func(x=x, y=y, **kwargs)
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/seaborn/axisgrid.py:1615: UserWarning: Ignoring `palette` because no `hue` variable has been assigned.
  func(x=x, y=y, **kwargs)
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/seaborn/axisgrid.py:1615: UserWarning: Ignoring `palette` because no `hue` variable has been assigned.
  func(x=x, y=y, **kwargs)
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/seaborn/axisgrid.py:1615: UserWarning: Ignoring `palette` because no `hue` variable has been assigned.
  func(x=x, y=y, **kwargs)
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/seaborn/axisgrid.py:1615: UserWarning: Ignoring `palette` because no `hue` variable has been assigned.
  func(x=x, y=y, **kwargs)
../../_images/20e83e8ba59d594ca6eb00441f03395de966957ba907326b09238c88f0921662.png
from sklearn.cluster import KMeans

# split the embedded nodes into two clusters (seeded k-means)
labels = KMeans(n_clusters=2, random_state=0).fit_predict(embedding)
# legend label fixed from the typo "Predicter Clusters" to match the other plots
fig = pairplot(embedding, labels=labels, legend_name="Predicted Clusters",
               title="(B) KMeans clustering")

fig.tight_layout()

fname = "pairplots1"
# assumes the Figures/<mode> and Figures/png folders were created by an earlier cell
if mode != "png":
    fig.savefig(f"Figures/{mode:s}/{fname:s}.{mode:s}")

fig.savefig(f"Figures/png/{fname:s}.png")
../../_images/3a74c211737f8be993a10bb04c0dd15cb8668f9c4b55144a62cd891a6c0a7b7d.png
import matplotlib.image as mpimg

# show the two saved pair plots side by side; the right panel gets more width
# (ratio 1.39), presumably because its legend makes that image wider
fig, axs = plt.subplots(1, 2, figsize=(20, 10), gridspec_kw={"width_ratios": [1, 1.39]})

axs[0].imshow(mpimg.imread('Figures/png/pairplots0.png'))
axs[1].imshow(mpimg.imread('Figures/png/pairplots1.png'))
# turn off x and y axes; a plain loop, since this is done only for its side effect
for ax in axs.ravel():
    ax.set_axis_off()
fig.tight_layout()

fname = "pairplots"
if mode != "png":
    fig.savefig(f"Figures/{mode:s}/{fname:s}.{mode:s}")

fig.savefig(f"Figures/png/{fname:s}.png")
../../_images/af804ed872dbe9cff17b99c3efb602e9170ddb1b3286711b1849b6ec5858599f.png
from graspologic.cluster import KMeansCluster

# let graspologic choose the number of k-means clusters (at most 10)
auto_kmeans = KMeansCluster(max_clusters=10, random_state=0)
labels = auto_kmeans.fit_predict(embedding)
fig = pairplot(
    embedding,
    labels=labels,
    title="(A) KMeans clustering, automatic selection",
    legend_name="Predicted Clusters",
)
fig.tight_layout()

fname = "pairplot_impute0"
if mode != "png":
    fig.savefig(f"Figures/{mode:s}/{fname:s}.{mode:s}")
fig.savefig(f"Figures/png/{fname:s}.png")
../../_images/42b01bc3175a1916686d965633e0a37386aebd4e208f01930ca9270e174ef488.png
from graspologic.cluster import AutoGMMCluster

# let graspologic's AutoGMM choose the number of mixture components (at most 10)
auto_gmm = AutoGMMCluster(max_components=10, random_state=0)
labels = auto_gmm.fit_predict(embedding)
fig = pairplot(
    embedding,
    labels=labels,
    title="(B) AutoGMM Clustering, automatic selection",
    legend_name="Predicted Clusters",
)
fig.tight_layout()

fname = "pairplot_impute1"
if mode != "png":
    fig.savefig(f"Figures/{mode:s}/{fname:s}.{mode:s}")
fig.savefig(f"Figures/png/{fname:s}.png")
../../_images/41ff84cc52160404826eb4ef6f1f8a4ff0a5d5552d1afaf622496b57a4512dd9.png
# show the two automatic-selection pair plots side by side at equal widths
fig, axs = plt.subplots(1, 2, figsize=(20, 10), gridspec_kw={"width_ratios": [1, 1]})

axs[0].imshow(mpimg.imread('Figures/png/pairplot_impute0.png'))
axs[1].imshow(mpimg.imread('Figures/png/pairplot_impute1.png'))
# turn off x and y axes; a plain loop, since this is done only for its side effect
for ax in axs.ravel():
    ax.set_axis_off()
fig.tight_layout()

fname = "pairplot_impute"
if mode != "png":
    fig.savefig(f"Figures/{mode:s}/{fname:s}.{mode:s}")

fig.savefig(f"Figures/png/{fname:s}.png")
../../_images/ebfc340dccf0221af4da1eeaebc2b69b33983d636f166bd7f4ac85d3a8aba119.png