9.1 Graph neural networks

9.1 Graph neural networks

# Extra (vector) format used when saving figures; PNG copies are saved as well.
mode = "svg"

import matplotlib

# Notebook-wide matplotlib font settings.
font = dict(family='Dejavu Sans', weight='normal', size=20)
matplotlib.rc('font', **font)

import matplotlib
from matplotlib import pyplot as plt
from torch_geometric.datasets import MoleculeNet

# Load (downloading on first use) the ClinTox molecule dataset and summarise it.
dataset = MoleculeNet(name='ClinTox', root='data/clintox')
n_graphs, n_classes = len(dataset), dataset.num_classes
print(f'Dataset: {dataset}\nNumber of molecules/graphs: {n_graphs}\nNumber of classes: {n_classes}')
Downloading https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/clintox.csv.gz
Extracting data/clintox/clintox/raw/clintox.csv.gz
Processing...
/opt/hostedtoolcache/Python/3.12.11/x64/lib/python3.12/site-packages/torch_geometric/datasets/molecule_net.py:213: UserWarning: Skipping molecule '[NH4][Pt]([NH4])(Cl)Cl' since it resulted in zero atoms
  warnings.warn(f"Skipping molecule '{smiles}' since it "
/opt/hostedtoolcache/Python/3.12.11/x64/lib/python3.12/site-packages/torch_geometric/datasets/molecule_net.py:213: UserWarning: Skipping molecule 'c1ccc(cc1)n2c(=O)c(c(=O)n2c3ccccc3)CCS(=O)c4ccccc4' since it resulted in zero atoms
  warnings.warn(f"Skipping molecule '{smiles}' since it "
/opt/hostedtoolcache/Python/3.12.11/x64/lib/python3.12/site-packages/torch_geometric/datasets/molecule_net.py:213: UserWarning: Skipping molecule 'CCCCc1c(=O)n(n(c1=O)c2ccc(cc2)O)c3ccccc3' since it resulted in zero atoms
  warnings.warn(f"Skipping molecule '{smiles}' since it "
/opt/hostedtoolcache/Python/3.12.11/x64/lib/python3.12/site-packages/torch_geometric/datasets/molecule_net.py:213: UserWarning: Skipping molecule 'CCCCc1c(=O)n(n(c1=O)c2ccccc2)c3ccccc3' since it resulted in zero atoms
  warnings.warn(f"Skipping molecule '{smiles}' since it "
Dataset: ClinTox(1480)
Number of molecules/graphs: 1480
Number of classes: 2
Done!
# Two example molecules (dataset entries 26 and 83) used throughout this section.
mols = tuple(dataset[idx] for idx in (26, 83))
# NOTE(review): `m` leaks as a notebook global; keep the name unless you confirm
# nothing downstream reads it.
for m in mols:
    print(m.smiles)
C([C@@H]1[C@H]([C@@H]([C@H](C(=O)O1)O)O)O)O
C1[C@@H]([C@H](O[C@H]1N2C=NC(=NC2=O)N)CO)O
from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D
from IPython.display import SVG

# Parse each example's SMILES string with RDKit.
# (Despite its name, `smiles` holds RDKit Mol objects; a later cell reuses it.)
smiles = list(map(Chem.MolFromSmiles, (mol.smiles for mol in mols)))

# Draw both molecules into a single SVG, labelling every atom with its index.
d2d = rdMolDraw2D.MolDraw2DSVG(600, 280, 300, 280)
atom_opts = d2d.drawOptions()
atom_opts.addAtomIndices = True
d2d.DrawMolecules(smiles)
d2d.FinishDrawing()
SVG(d2d.GetDrawingText())
../../_images/4f0a350b89d53647caec3f945679c6777d2125a43852d6d1fd485afa746f2428.svg
# Report the node-feature matrix size of each molecule: rows are atoms, columns are features.
for i, m in enumerate(mols):
    n_atoms, n_features = m.x.shape
    print(f'Molecule {i+1}: Number of atoms={n_atoms}, Features per atom={n_features}')
Molecule 1: Number of atoms=12, Features per atom=9
Molecule 2: Number of atoms=16, Features per atom=9
# Redraw the same molecules with a fresh drawer, this time labelling bond indices.
d2d = rdMolDraw2D.MolDraw2DSVG(600, 280, 300, 280)
bond_opts = d2d.drawOptions()
bond_opts.addBondIndices = True
d2d.DrawMolecules(smiles)
d2d.FinishDrawing()
SVG(d2d.GetDrawingText())
../../_images/4510e8f9952c81028233272af78e1dc6cb6a4ad5bbeba875cfa7e266106911fa.svg
import numpy as np

# Split a 2 x E edge-index array into its two rows: [source indices, target indices].
_process = lambda x: [e[0] for e in np.split(x, 2)]
def adj_from_edgelist(molecule):
    """
    Build a dense adjacency matrix from a molecule's edge list.

    Parameters
    ----------
    molecule : object with ``x`` and ``edge_index`` attributes
        ``molecule.x`` supplies the number of nodes (its number of rows).
        ``molecule.edge_index`` is a 2 x E tensor (it must support
        ``.numpy()``); row 0 holds source node indices, row 1 target indices.

    Returns
    -------
    numpy.ndarray
        Float array of shape (n, n) with ``A[i, j] = 1`` for every listed
        edge (i, j) and 0 elsewhere. Only the listed direction is set, so the
        result is symmetric only if ``edge_index`` lists both directions.
    """
    # the number of nodes is the number of atoms (rows of .x attribute)
    n = molecule.x.shape[0]
    A = np.zeros((n, n))
    # Use the edges of the `molecule` argument itself, never a global
    # variable that merely happens to share a name with the caller's loop.
    sources, targets = _process(molecule.edge_index.numpy())
    # Vectorised equivalent of setting A[i, j] = 1 for each edge; an empty
    # edge list leaves A all zeros.
    A[sources, targets] = 1
    return A
from graphbook_code import heatmap

# Plot the adjacency matrix of each example molecule as a heatmap.
# NOTE(review): `m_i` is unused. Before renaming `m`, confirm that
# adj_from_edgelist reads only its argument and not a global `m`.
for m_i, m in enumerate(mols):
    A = adj_from_edgelist(m)
    heatmap(A)
../../_images/78d2dc83fcd59cc5557bb59027ce5da87b8b628930aa345dc52f14b5147c8c2e.png ../../_images/ab6562071c9b700a5a34843f7d6a7602fa70c53a07f53fdbe6536aa2a542b6cd.png
from PIL import Image
import os

fig, axs = plt.subplots(2, 2, figsize=(13, 15), gridspec_kw={"width_ratios": [1, 1.2]})

# plot_titles[row][col]: row = molecule (A/B), col = panel (I structure, II adjacency),
# e.g. "(A.I) Molecular structure, Molecule 1".
plot_titles = [[f"({x:s}.{y:s}" for y in [f"I) Molecular structure, {z:s}", f"II) Adjacency matrix, {z:s}"]] for x, z in zip(["A", "B"], ["Molecule 1", "Molecule 2"])]
# NOTE(review): `title_left` is not used in this cell.
title_left = ["Molecule 1", "Molecule 2"]
for i, m in enumerate(mols):
    # Left column: RDKit rendering of the molecule (atom indices labelled) as PNG bytes.
    d2d = rdMolDraw2D.MolDraw2DCairo(300, 300)
    options = d2d.drawOptions()
    options.addAtomIndices = True
    options.minFontSize = 14
    options.annotationFontScale = 0.8

    d2d.DrawMolecule(Chem.MolFromSmiles(m.smiles))
    d2d.FinishDrawing()
    png_data = d2d.GetDrawingText()

    # save png to file, then load it back for matplotlib
    png_fname = f'mol{i:d}.png'
    with open(png_fname, 'wb') as png_file:
        png_file.write(png_data)
    # Copy the pixels into an array so the image file handle is closed right away.
    with Image.open(png_fname) as png_img:
        axs[i][0].imshow(np.asarray(png_img))
    axs[i][0].set_title(plot_titles[i][0], fontsize=18)
    axs[i][0].axis("off")

    # Right column: adjacency-matrix heatmap with a tick on every second atom.
    A = adj_from_edgelist(m)
    tick_range = range(0, np.ceil(A.shape[0] / 2).astype(int))
    # `k` rather than `i`, so the comprehensions don't appear to reuse the row index.
    xticks = yticks = [2 * k + 0.5 for k in tick_range]
    xticklabels = yticklabels = [f"{2 * k}" for k in tick_range]
    heatmap(A.astype(int), ax=axs[i][1], xticks=xticks, xticklabels=xticklabels,
            yticks=yticks, yticklabels=yticklabels, shrink=0.6,
            title=plot_titles[i][1], xtitle="Atom number", ytitle="Atom number")

fig.tight_layout()

# Save this figure under a single name. It is written in `mode` format when that
# is not PNG, and always as PNG. makedirs also creates the parent "Figures" directory.
fname = "molecule_ex"
if mode != "png":
    os.makedirs(f"Figures/{mode:s}", exist_ok=True)
    fig.savefig(f"Figures/{mode:s}/{fname:s}.{mode:s}")

os.makedirs("Figures/png", exist_ok=True)
fig.savefig(f"Figures/png/{fname:s}.png")
../../_images/90b6ef331271195722b7d2a223ce78bbff7d3c65dc43bb3c5c2dd6fa968a52ee.png
import torch
# Seed torch's RNG so the shuffle below (and hence the split) is reproducible.
torch.manual_seed(12345)

# NOTE(review): `dataset` is rebound to its shuffled copy, so re-running this cell
# reshuffles an already-shuffled dataset and produces a different split.
dataset = dataset.shuffle()

# The first 1216 shuffled graphs train the model; the remainder is held out for testing.
n_train = 1216
train_dataset, test_dataset = dataset[:n_train], dataset[n_train:]

print(f'Number of training networks: {len(train_dataset)}')
print(f'Number of test networks: {len(test_dataset)}')
Number of training networks: 1216
Number of test networks: 264
from torch_geometric.loader import DataLoader

# Mini-batches of 64 graphs; only the training order is reshuffled each epoch.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Peek at the first training mini-batch only, to show how 64 graphs are packed
# into a single DataBatch object.
step, data = 0, next(iter(train_loader))
print(f'Step {step + 1}:')
print(f'Number of networks in the current batch: {data.num_graphs}')
print(data)
Step 1:
Number of networks in the current batch: 64
DataBatch(x=[1414, 9], edge_index=[2, 3014], edge_attr=[3014, 3], smiles=[64], y=[64, 2], batch=[1414], ptr=[65])
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool

torch.manual_seed(12345)
class GCN(torch.nn.Module):
    """Graph-level classifier: three GCNConv layers, mean-pool readout, linear head.

    Parameters
    ----------
    hidden_channels : int
        Width of every GCNConv layer and of the pooled network embedding.
    in_channels : int, optional
        Number of input features per node. When omitted, uses
        ``dataset.num_node_features`` of the notebook-level ``dataset``.
    out_channels : int, optional
        Number of output logits per graph. When omitted, uses
        ``dataset.num_classes`` of the notebook-level ``dataset``.
    """
    def __init__(self, hidden_channels, in_channels=None, out_channels=None):
        super(GCN, self).__init__()
        # Fall back to the global `dataset` only when sizes are not given.
        # `GCN(hidden_channels=64)` therefore behaves exactly as before, while
        # the class can also be built without that global.
        if in_channels is None:
            in_channels = dataset.num_node_features
        if out_channels is None:
            out_channels = dataset.num_classes
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, out_channels, bias=False)

    def forward(self, x, edge_index, batch):
        """Return one row of logits per graph in the mini-batch.

        ``batch`` has one entry per node (row of ``x``). global_mean_pool
        averages each graph's node embeddings into a ``hidden_channels`` vector.
        """
        # 1. Obtain node embeddings via convolutional layers
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer to produce network embedding
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a prediction classifier to the network embedding
        x = self.lin(x)

        return x

model = GCN(hidden_channels=64)
print(model)
GCN(
  (conv1): GCNConv(9, 64)
  (conv2): GCNConv(64, 64)
  (conv3): GCNConv(64, 64)
  (lin): Linear(in_features=64, out_features=2, bias=False)
)
# Fresh model for the training run, cross-entropy over the logits, Adam at lr = 1e-4.
model = GCN(hidden_channels=64)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

def train():
    """Run one optimisation epoch of the global `model` over `train_loader`.

    Uses the global `criterion` and `optimizer`. The target for each graph is
    column 0 of the batch's `y`, cast to integer class labels.
    """
    model.train()
    for graphs in train_loader:
        logits = model(graphs.x.float(), graphs.edge_index, graphs.batch)
        # Handle a pyg bug where last element in batch may be all zeros and excluded in the model output.
        # https://github.com/pyg-team/pytorch_geometric/issues/1813
        # -> keep only as many labels as there are output rows.
        labels = graphs.y[:logits.shape[0], 0].long()
        criterion(logits, labels).backward()  # loss + gradients in one step
        optimizer.step()
        optimizer.zero_grad()

def test(loader):
    """Return the accuracy of the global `model` on every graph served by `loader`.

    A prediction is the argmax over the model's logits. It is compared with
    column 0 of each batch's `y`. The result is the number of correct
    predictions divided by ``len(loader.dataset)``. Leaves `model` in eval mode.
    """
    model.eval()
    correct = 0
    # Evaluation needs no gradients: skip building the autograd graph.
    with torch.no_grad():
        for data in loader:  # Iterate in batches over the training/test dataset.
            out = model(data.x.float(), data.edge_index, data.batch)
            pred = out.argmax(dim=1)  # Use the class with highest probability.
            # Same pyg workaround as in train(): only as many labels as output rows.
            num_batch = pred.shape[0]
            correct += int((pred == data.y[:num_batch, 0]).sum())  # Check against ground-truth labels.
    return correct / len(loader.dataset)  # Derive ratio of correct predictions.

R = 10  # number of epochs
# NOTE(review): in the recorded run both accuracies stay constant across all epochs
# (0.9342 train / 0.9470 test). This suggests the model may be predicting a single
# class; compare against the majority-class rate before treating them as learning.
for epoch in range(R):
    train()
    train_acc, test_acc = test(train_loader), test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
Epoch: 000, Train Acc: 0.9342, Test Acc: 0.9470
Epoch: 001, Train Acc: 0.9342, Test Acc: 0.9470
Epoch: 002, Train Acc: 0.9342, Test Acc: 0.9470
Epoch: 003, Train Acc: 0.9342, Test Acc: 0.9470
Epoch: 004, Train Acc: 0.9342, Test Acc: 0.9470
Epoch: 005, Train Acc: 0.9342, Test Acc: 0.9470
Epoch: 006, Train Acc: 0.9342, Test Acc: 0.9470
Epoch: 007, Train Acc: 0.9342, Test Acc: 0.9470
Epoch: 008, Train Acc: 0.9342, Test Acc: 0.9470
Epoch: 009, Train Acc: 0.9342, Test Acc: 0.9470