$$
% Define your custom commands here
\newcommand{\bmat}[1]{\begin{bmatrix}#1\end{bmatrix}}
\newcommand{\E}{\mathbb{E}}
\newcommand{\P}{\mathbb{P}}
\newcommand{\S}{\mathbb{S}}
\newcommand{\R}{\mathbb{R}}
\newcommand{\norm}[2]{\|{#1}\|_{{}_{#2}}}
\newcommand{\pd}[2]{\frac{\partial #1}{\partial #2}}
\newcommand{\pdd}[2]{\frac{\partial^2 #1}{\partial #2^2}}
\newcommand{\vectornorm}[1]{\left|\left|#1\right|\right|}
\newcommand{\abs}[1]{\left|{#1}\right|}
\newcommand{\mbf}[1]{\mathbf{#1}}
\newcommand{\mc}[1]{\mathcal{#1}}
\newcommand{\bm}[1]{\boldsymbol{#1}}
\newcommand{\nicefrac}[2]{{}^{#1}\!/_{\!#2}}
\newcommand{\argmin}{\operatorname*{arg\,min}}
\newcommand{\argmax}{\operatorname*{arg\,max}}
\newcommand{\dd}{\operatorname{d}\!}
$$

Graph Neural Networks

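Graphs present three primary challenges for deep learning:

  • Variable topology: graphs differ in their number of nodes and in how the nodes are connected, so networks must be expressive enough to exploit this structure while coping with its variation.
  • Size: graphs can be enormous; a social network, for example, may contain millions of nodes.
  • A single graph: sometimes only one monolithic graph is available, so the usual protocol of training on many examples and testing on new ones does not apply directly.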

Image credits: Understanding Deep Learning by Simon J. D. Prince, [CC BY 4.0]

What is a Graph?

A graph is a versatile mathematical structure consisting of nodes (vertices) and edges (links). While a graph theoretically permits connections between any pair of nodes, real-world graphs are typically sparse, meaning only a small fraction of possible edges are present.

Naturally Occurring Graphs

Many systems are inherently graphical in nature:

  • Road Networks: Nodes represent physical intersections or locations, while edges denote the roads connecting them (see Figure 1 a).
  • Molecules: Atoms act as nodes, and chemical bonds form the edges (see Figure 1 b).
  • Circuits: Components and junctions are represented as nodes, with electrical connections serving as edges (see Figure 1 c).

Data Represented as Graphs

Graphs are also powerful abstractions for diverse datasets:

  • Social Networks: People (nodes) connected by friendships or interactions (edges).
  • Scientific Literature: Papers (nodes) linked by citations (edges).
  • Knowledge Bases: Articles or concepts (nodes) connected by hyperlinks or semantic relationships (edges).
  • Software Systems: Syntax tokens or functions (nodes) linked by control flow or data dependencies (edges).
  • Point Clouds: 3D points (nodes) connected based on spatial proximity (edges).
  • Biological Interactions: Proteins (nodes) and their biochemical interactions (edges).

Furthermore, even simple structures can be viewed through a graph-theoretic lens. A set (unordered list) can be treated as a graph where every member connects to every other. Similarly, an image can be viewed as a graph with a regular topology, where pixels are nodes connected to their adjacent neighbors.

Types of Graphs

Graphs can be categorized in several ways based on the nature of their nodes and edges:

  • Undirected Graphs: Connections are symmetric; each pair of connected individuals or objects has a mutual relationship with no inherent directionality (see Figure 2 a).
  • Directed Graphs: Relationships are one-way; for example, in a citation network, one publication cites another, making the link asymmetric (see Figure 2 b).
  • Knowledge Graphs: These are often directed heterogeneous multigraphs (see Figure 2 c).
    • Heterogeneous: Nodes can represent different types of entities (e.g., people, countries, companies).
    • Multigraph: Multiple edges representing different relation types can exist between the same two nodes.
  • Geometric Graphs: These are formed by connecting points in space based on proximity (e.g., a 3D point cloud of an airplane where each point is connected to its \(K\)-nearest neighbors, as seen in Figure 2 d).
  • Hierarchical Graphs: These represent nested topologies in which nodes themselves represent sub-graphs (e.g., a scene in which a room, a table, and a light are each represented by a graph, and these graphs form the nodes of a larger graph describing which objects are adjacent; see Figure 2 e).

While deep learning can be applied to all these structures, this chapter focuses primarily on undirected graphs.

In addition to the graph structure itself, information is typically associated with each node and edge:

  • Node Embeddings: In a social network, each person (node) might be represented by a fixed-length vector representing their interests.
  • Edge Embeddings: In a road network, each connection (edge) might be characterized by its length, speed limit, and historical accident frequency.

Formally, a graph consists of a set of \(N\) nodes connected by a set of \(E\) edges. This can be encoded by three matrices: \(\mathbf{A}\) (structure/adjacency), \(\mathbf{X}\) (concatenated node embeddings), and \(\mathbf{E}\) (edge embeddings) (see Figure 3).

Graph Matrices

The graph structure is defined by the adjacency matrix \(\mathbf{A}\), an \(N \times N\) binary matrix where entry \((m,n)\) is set to one if an edge exists between nodes \(m\) and \(n\), and zero otherwise. For undirected graphs, this matrix is symmetric. For large, sparse graphs, it is often more memory-efficient to store \(\mathbf{A}\) as a list of connections \((m, n)\).

Each node \(n\) has an associated embedding \(\mathbf{x}^{(n)}\) of length \(D\); these embeddings are concatenated column-wise into the \(D \times N\) node data matrix \(\mathbf{X}\). Similarly, the edge embeddings \(\mathbf{e}^{(e)}\) of length \(D_E\) are collected into the \(D_E \times E\) matrix \(\mathbf{E}\). Initially, we focus on node embeddings; we will return to edge embeddings in section 13.9.
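A minimal sketch of these matrices, assuming NumPy and a hypothetical four-node path graph. The helper `adjacency_from_edges` and the random embeddings are illustrative only; nodes index the columns of \(\mathbf{X}\) and edges index the columns of \(\mathbf{E}\).

```python
import numpy as np

def adjacency_from_edges(num_nodes, edges):
    """Build the N x N binary adjacency matrix of an undirected graph
    from a list of (m, n) node-index pairs."""
    A = np.zeros((num_nodes, num_nodes), dtype=int)
    for m, n in edges:
        A[m, n] = 1
        A[n, m] = 1  # undirected: the matrix is symmetric
    return A

# Hypothetical toy graph: the path 0-1-2-3.
edges = [(0, 1), (1, 2), (2, 3)]
N, D, D_E = 4, 3, 2

A = adjacency_from_edges(N, edges)        # N x N structure
rng = np.random.default_rng(0)
X = rng.normal(size=(D, N))               # column n is x^(n)
E = rng.normal(size=(D_E, len(edges)))    # column e is e^(e)
```

For a large, sparse graph the `edges` list is itself the memory-efficient representation mentioned above; the dense `A` is only practical for small examples.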

Properties of the Adjacency Matrix

The adjacency matrix can be used to analyze graph neighbors using linear algebra. Consider encoding the \(n^{th}\) node’s position as a one-hot column vector (where all entries are zero except at position \(n\)). Pre-multiplying this vector by the adjacency matrix \(\mathbf{A}\) extracts its \(n^{th}\) column, resulting in a vector with ones at the neighbor positions—corresponding to a walk of length one from the starting node.

Repeating this operation (i.e., pre-multiplying by \(\mathbf{A}\) again) results in a vector containing the number of walks of length two from node \(n\) to every other node (see Figure 4 d-f).

In general, raising the adjacency matrix to the power of \(L\) yields the number of unique walks of length \(L\) from node \(m\) to node \(n\), stored in the \((m, n)\) entry of \(\mathbf{A}^L\) (see Figure 4 a-c). Unlike paths, walks can revisit nodes. However, \(\mathbf{A}^L\) still provides critical connectivity data: a non-zero entry at \((m, n)\) indicates that the distance from \(m\) to \(n\) is at most \(L\).
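The walk-counting properties above can be checked numerically. This sketch assumes NumPy and uses a hypothetical path graph 0-1-2-3:

```python
import numpy as np

# Hypothetical example: the path graph 0-1-2-3.
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]])

n = 1
e_n = np.zeros(4, dtype=int)
e_n[n] = 1                           # one-hot column vector for node n

one_step = A @ e_n                   # n-th column of A: the neighbors of node n
two_step = A @ one_step              # number of length-2 walks from node n to each node
A3 = np.linalg.matrix_power(A, 3)    # entry (m, n): number of length-3 walks from m to n
```

Here `two_step` is `[0, 2, 0, 1]`: two walks return to node 1 (via node 0 or node 2) and one reaches node 3. Note also that `A3[0, 2]` is zero even though node 2 is only two hops from node 0, because every walk on this graph alternates between even- and odd-numbered nodes; a zero entry of \(\mathbf{A}^L\) therefore does not by itself imply that the distance exceeds \(L\).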

Permutation of node indices

Node indexing in graphs is arbitrary; permuting the node indices results in a permutation of the columns of the node data matrix \(\mathbf{X}\) and a permutation of both the rows and columns of the adjacency matrix \(\mathbf{A}\). However, notice that the underlying graph topology remains unchanged (see Figure 5). This is in distinct contrast to images, where permuting pixels results in a completely different image, or text, where permuting words alters the sentence’s meaning.

The operation of exchanging node indices can be expressed mathematically via a permutation matrix \(\mathbf{P}\). This is a matrix where exactly one entry in each row and column is one, and all others are zero. If \(P_{mn}=1\), node \(m\) maps to node \(n\) after permutation. We can express the transformation between different indexings as:

\[ \begin{aligned} \mathbf{X}' &= \mathbf{X}\mathbf{P} \\ \mathbf{A}' &= \mathbf{P}^T\mathbf{A}\mathbf{P} \end{aligned} \tag{1}\]
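As a sketch (assuming NumPy and a random hypothetical graph), Equation 1 can be applied directly and the relabelling verified:

```python
import numpy as np

rng = np.random.default_rng(1)
N, D = 5, 3
X = rng.normal(size=(D, N))                       # node data, one column per node
A = np.triu(rng.integers(0, 2, size=(N, N)), k=1)
A = A + A.T                                       # random undirected graph, no self-loops

perm = rng.permutation(N)                         # node m is relabelled perm[m]
P = np.zeros((N, N))
P[np.arange(N), perm] = 1                         # P[m, perm[m]] = 1

X_perm = X @ P                                    # Equation 1: permute the columns of X
A_perm = P.T @ A @ P                              # Equation 1: permute rows and columns of A

# The embedding of node m now sits in column perm[m], and the edge (m, n)
# becomes the edge (perm[m], perm[n]); the topology itself is unchanged.
```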

Crucially, any processing or neural architecture applied to a graph must not depend on this arbitrary choice of node indexing; otherwise, the model’s output would change whenever the nodes were renumbered.

Supervised graph problems generally fall into one of three primary categories based on the scale of the prediction (see Figure 6).

Graph-level tasks

In graph-level tasks, the network assigns a label or estimates a continuous value from the entire graph, leveraging both its structure and node attributes. Examples include predicting the boiling point of a molecule (regression) or determining its toxicity to humans (classification).

For these tasks, the output node embeddings \(\mathbf{h}_K^{(n)}\) are combined, most commonly by mean pooling, which averages them over all nodes. The resulting vector is then mapped through a linear layer or a small neural network to a fixed-size output. For binary classification, the probability that a graph belongs to a certain class might be given by:

\[ Pr(y = 1|\mathbf{X}, \mathbf{A}) = \text{sig}\left[\beta_K + \boldsymbol{\omega}_K \mathbf{H}_K \mathbf{1}/N\right] \tag{2}\]

Here, \(\mathbf{H}_K\) is the matrix of output embeddings, and post-multiplying by the column vector \(\mathbf{1}\) (containing ones) and dividing by \(N\) (number of nodes) computes the average.
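A minimal sketch of Equation 2, assuming NumPy, a \(D \times N\) matrix `H_K` with one column per node, and hypothetical random parameters. The last line recomputes the output after shuffling the node order:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def graph_probability(H_K, omega_K, beta_K):
    """Equation 2. H_K is D x N (one column per node), omega_K is a
    length-D weight vector and beta_K a scalar bias."""
    N = H_K.shape[1]
    mean_embedding = H_K @ np.ones(N) / N    # H_K 1 / N: average over nodes
    return sigmoid(beta_K + omega_K @ mean_embedding)

rng = np.random.default_rng(0)
H_K = rng.normal(size=(4, 6))                # hypothetical output embeddings
omega_K, beta_K = rng.normal(size=4), 0.1
p = graph_probability(H_K, omega_K, beta_K)
p_shuffled = graph_probability(H_K[:, rng.permutation(6)], omega_K, beta_K)
```

Because averaging ignores the order of the columns, `p` and `p_shuffled` agree.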

Node-level tasks

In node-level tasks, the network assigns a label or value to each individual node. For instance, in a 3D point cloud of an airplane, the goal might be to classify each point as belonging to the “wings” or the “fuselage”. Predictions are made independently for each node \(n\) based on its final-layer embedding:

\[ Pr(y^{(n)} = 1|\mathbf{X}, \mathbf{A}) = \text{sig}\left[\beta_K + \boldsymbol{\omega}_K \mathbf{h}_K^{(n)}\right] \tag{3}\]
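The same head can be applied to every node at once. A NumPy sketch with hypothetical random values:

```python
import numpy as np

def node_probabilities(H_K, omega_K, beta_K):
    """Equation 3 for all nodes at once: entry n uses column h_K^(n) of H_K."""
    logits = beta_K + omega_K @ H_K          # length-N vector of logits
    return 1.0 / (1.0 + np.exp(-logits))

rng = np.random.default_rng(0)
H_K = rng.normal(size=(4, 5))                # hypothetical D = 4, N = 5
omega_K, beta_K = rng.normal(size=4), -0.2
probs = node_probabilities(H_K, omega_K, beta_K)
```

Since one shared head is applied column by column, reordering the nodes simply reorders `probs`.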

Edge prediction tasks

Edge prediction involves estimating the likelihood of a connection between any two nodes \(n\) and \(m\). In a social network context, this might be used to suggest new friendships. This is essentially a binary classification problem where a pair of node embeddings is mapped to a single probability. One common approach is to take the dot product of the embeddings:

\[ Pr(y^{(mn)} = 1|\mathbf{X}, \mathbf{A}) = \text{sig}\left[\mathbf{h}^{(m)T} \mathbf{h}^{(n)}\right] \tag{4}\]
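A NumPy sketch of Equation 4 for all node pairs at once, using hypothetical random embeddings:

```python
import numpy as np

def edge_probabilities(H):
    """Equation 4 for every pair of nodes: entry (m, n) is sig[h^(m)T h^(n)].
    H is D x N with one column per node."""
    logits = H.T @ H                         # N x N matrix of dot products
    return 1.0 / (1.0 + np.exp(-logits))

rng = np.random.default_rng(0)
H = rng.normal(size=(8, 5))                  # hypothetical final embeddings
P_edge = edge_probabilities(H)
p_23 = P_edge[2, 3]                          # probability of an edge between nodes 2 and 3
```

Because the dot product is symmetric, so is `P_edge`, which suits undirected graphs; the diagonal (a node paired with itself) would normally be ignored.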

While there are many GNN variants, we focus on spatial-based convolutional graph neural networks, or GCNs. These are “convolutional” in the sense that they update each node by aggregating information from its spatial neighbors, inducing a relational inductive bias that prioritizes local structure. This contrasts with spectral-based methods, which operate in the graph’s Fourier domain.

A GCN layer is modeled as a function \(\mathbf{F}[\bullet]\) with parameters \(\phi\) that maps node embeddings from one stage to the next:

\[ \begin{aligned} \mathbf{H}_1 &= \mathbf{F}[\mathbf{X}, \mathbf{A}, \phi_0] \\ \mathbf{H}_2 &= \mathbf{F}[\mathbf{H}_1, \mathbf{A}, \phi_1] \\ &\vdots \\ \mathbf{H}_K &= \mathbf{F}[\mathbf{H}_{K-1}, \mathbf{A}, \phi_{K-1}] \end{aligned} \tag{5}\]

where \(\mathbf{X}\) is the original input and \(\mathbf{H}_k\) represents the embeddings at layer \(k\).

Equivariance and invariance

Because node indexing is arbitrary, it is imperative that each GCN layer be equivariant to permutations. If a permutation matrix \(\mathbf{P}\) is applied to the nodes, the output embeddings must be permuted in the exact same way:

\[ \mathbf{H}_{k+1}\mathbf{P} = \mathbf{F}[\mathbf{H}_k \mathbf{P}, \mathbf{P}^T \mathbf{A} \mathbf{P}, \phi_k] \tag{6}\]

Individual node and edge predictions inherit this equivariance. However, for graph-level tasks, the final output must be invariant to node order, as the global property of the graph does not depend on how nodes are numbered.
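These properties can be checked numerically. The sketch below assumes NumPy and uses the parameter-free neighbor sum \(\mathbf{H}\mathbf{A}\) as a stand-in for a layer. The last check uses a hypothetical layer with a separate weight for each node position, which is tied to one particular ordering:

```python
import numpy as np

rng = np.random.default_rng(2)
N, D = 6, 4
H = rng.normal(size=(D, N))
A = np.triu(rng.integers(0, 2, size=(N, N)), k=1)
A = A + A.T
P = np.eye(N)[[1, 2, 3, 4, 5, 0]]        # a fixed (cyclic) permutation matrix

def neighbor_sum(H, A):
    """Parameter-free aggregation: column n of H @ A sums node n's neighbors."""
    return H @ A

def mean_readout(H):
    """Graph-level summary: the average over nodes (columns)."""
    return H.mean(axis=1)

# Equivariance, in the form of Equation 6: permuting the inputs permutes the outputs.
equivariant = np.allclose(neighbor_sum(H, A) @ P, neighbor_sum(H @ P, P.T @ A @ P))

# Invariance: the graph-level summary ignores the node order.
invariant = np.allclose(mean_readout(H), mean_readout(H @ P))

# Counterexample: a hypothetical N x N weight matrix W, one weight per node position.
W = rng.normal(size=(N, N))
position_specific = np.allclose((H @ W) @ P, (H @ P) @ W)
```

Here `equivariant` and `invariant` are true, while `position_specific` is false.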

Parameter sharing

To handle variable graph topologies and save memory, GCNs employ parameter sharing. Instead of learning separate weights for every node position (which would require fixed graph sizes), the same shared parameters are used across all nodes. This forces the model to treat every part of the graph identically, drawing a direct parallel to how standard CNNs share filters across image pixels.

In a GCN, neighbors send “messages” that are aggregated at each node. While images have a fixed number of neighbors (e.g., a \(3 \times 3\) grid), graph nodes have variable degrees, meaning the aggregation operation must be able to handle any number of inputs.

Example GCN layer

A simple GCN layer (see Figure 7) aggregates information from neighboring nodes by summing their embeddings:

\[ \text{agg}[n, k] = \sum_{m \in \text{ne}[n]} \mathbf{h}_k^{(m)} \tag{7}\]

where \(\text{ne}[n]\) is the set of neighbors of node \(n\). We then apply the same linear transformation \(\mathbf{\Omega}_k\) to both the current node’s embedding and the aggregate, add a bias \(\beta_k\), and pass the result through a non-linearity \(\mathbf{a}[\cdot]\) (such as a ReLU):

\[ \mathbf{h}_{k+1}^{(n)} = \mathbf{a}\left[\beta_k + \mathbf{\Omega}_k \mathbf{h}_k^{(n)} + \mathbf{\Omega}_k \text{agg}[n, k]\right] \tag{8}\]

We can express this more succinctly in matrix form by noting that \(\mathbf{H}_k \mathbf{A}\) returns the sum of neighbor embeddings for all nodes. The complete layer update becomes:

\[ \begin{aligned} \mathbf{H}_{k+1} &= \mathbf{a}\left[\beta_k \mathbf{1}^T + \mathbf{\Omega}_k \mathbf{H}_k + \mathbf{\Omega}_k \mathbf{H}_k \mathbf{A} \right] \\ &= \mathbf{a}\left[\beta_k \mathbf{1}^T + \mathbf{\Omega}_k \mathbf{H}_k (\mathbf{A} + \mathbf{I})\right] \end{aligned} \tag{9}\]

where \(\mathbf{1}\) is a vector of ones and \(\mathbf{I}\) is the identity matrix. This simple architecture satisfies all design requirements: it is equivariant, handles variable neighbors, exploits graph structure, and shares parameters.
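A minimal NumPy sketch of this layer, assuming a ReLU for \(\mathbf{a}[\cdot]\) and hypothetical random parameters. It implements both the per-node form (Equations 7 and 8) and the matrix form (Equation 9) so they can be compared:

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def gcn_layer(H, A, Omega, beta):
    """Matrix form, Equation 9: a[beta 1^T + Omega H (A + I)].
    H: D x N embeddings, A: N x N adjacency, Omega: D' x D, beta: length D'."""
    N = A.shape[0]
    return relu(beta[:, None] + Omega @ H @ (A + np.eye(N)))

def gcn_layer_per_node(H, A, Omega, beta):
    """Equations 7 and 8, evaluated node by node for comparison."""
    N = A.shape[0]
    out = np.zeros((Omega.shape[0], N))
    for n in range(N):
        agg = H[:, A[:, n] > 0].sum(axis=1)                       # Equation 7
        out[:, n] = relu(beta + Omega @ H[:, n] + Omega @ agg)    # Equation 8
    return out

rng = np.random.default_rng(3)
N, D, D_out = 5, 3, 4
A = np.triu(rng.integers(0, 2, size=(N, N)), k=1)
A = A + A.T
H = rng.normal(size=(D, N))
Omega, beta = rng.normal(size=(D_out, D)), rng.normal(size=D_out)
H_next = gcn_layer(H, A, Omega, beta)
```

The two functions agree, and permuting the inputs as in Equation 6 permutes the output columns in the same way.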

Example: graph classification

We now combine these ideas into a network that classifies molecules as toxic or harmless. The network inputs are the molecular adjacency matrix \(\mathbf{A}\) and a node embedding matrix \(\mathbf{X}\) consisting of one-hot vectors representing the 118 elements of the periodic table. These initial embeddings are transformed into a hidden dimension \(D\) by a learnable weight matrix \(\mathbf{\Omega}_0\).

The network architecture consists of a stack of GCN layers followed by mean pooling and a sigmoid output:

\[ \begin{aligned} \mathbf{H}_1 &= \mathbf{a}\left[\beta_0 \mathbf{1}^T + \mathbf{\Omega}_0 \mathbf{X}(\mathbf{A} + \mathbf{I})\right] \\ \mathbf{H}_2 &= \mathbf{a}\left[\beta_1 \mathbf{1}^T + \mathbf{\Omega}_1 \mathbf{H}_1(\mathbf{A} + \mathbf{I})\right] \\ &\vdots \\ \mathbf{H}_K &= \mathbf{a}\left[\beta_{K-1} \mathbf{1}^T + \mathbf{\Omega}_{K-1} \mathbf{H}_{K-1}(\mathbf{A} + \mathbf{I})\right] \\ f[\mathbf{X}, \mathbf{A}, \mathbf{\Phi}] &= \text{sig}\left[\beta_K + \mathbf{\omega}_K \mathbf{H}_K \mathbf{1}/N\right] \end{aligned} \tag{10}\]

The final output \(f[\mathbf{X}, \mathbf{A}, \mathbf{\Phi}]\) expresses the probability that the molecule is toxic.
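The whole of Equation 10 fits in a few lines of NumPy. This is a sketch under stated assumptions: ReLU activations, atoms given by atomic number (1 to 118) for the one-hot encoding, and untrained random weights from a hypothetical `init_params`, so the resulting probability carries no chemical meaning:

```python
import numpy as np

NUM_ELEMENTS = 118   # length of the one-hot element encoding

def relu(z):
    return np.maximum(z, 0.0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def init_params(D, K, rng):
    """Hypothetical random initialisation: K GCN layers of width D and a
    linear output head (omega_K, beta_K)."""
    dims = [NUM_ELEMENTS] + [D] * K
    layers = [(rng.normal(scale=0.1, size=(dims[k + 1], dims[k])), np.zeros(dims[k + 1]))
              for k in range(K)]
    return layers, rng.normal(scale=0.1, size=D), 0.0

def toxicity_probability(atomic_numbers, A, params):
    """Equation 10: K GCN layers, mean pooling over atoms, then a sigmoid."""
    layers, omega_K, beta_K = params
    N = len(atomic_numbers)
    X = np.zeros((NUM_ELEMENTS, N))
    X[np.asarray(atomic_numbers) - 1, np.arange(N)] = 1.0   # one-hot columns
    A_hat = A + np.eye(N)
    H = X
    for Omega, beta in layers:
        H = relu(beta[:, None] + Omega @ H @ A_hat)
    return sigmoid(beta_K + omega_K @ H.mean(axis=1))

# Water: an oxygen atom (8) bonded to two hydrogen atoms (1).
params = init_params(D=16, K=3, rng=np.random.default_rng(0))
A_water = np.array([[0, 1, 1],
                    [1, 0, 0],
                    [1, 0, 0]])
p = toxicity_probability([8, 1, 1], A_water, params)
```

Listing the atoms in a different order (with the adjacency matrix relabelled to match) gives the same probability, as required for a graph-level task.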

Training with batches

Deep learning models typically leverage the parallelism of modern hardware by processing batches of training examples \(\{ \mathbf{X}_i, \mathbf{A}_i \}\) concurrently. However, since each graph may have a different number of nodes, their adjacency and node matrices differ in size, preventing standard concatenation into fixed-size 3D tensors.

A common “trick” to enable batching is to treat all \(I\) graphs in a batch as a single large graph made up of \(I\) disconnected components. The resulting block-diagonal adjacency matrix encodes the full structure, allowing the network equations to be run as a single instance. Mean pooling is then performed separately over each component to produce one representation per graph for the loss function.
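A sketch of this batching trick, assuming NumPy. The helpers `batch_graphs` and `per_graph_mean` are hypothetical names, and a ReLU version of Equation 9 serves as the layer. Because the merged adjacency matrix is block-diagonal, running the layer once on the merged graph should give the same pooled vectors as running it on each graph separately:

```python
import numpy as np

def batch_graphs(Xs, As):
    """Merge I graphs into one graph with I disconnected components.
    Node matrices (D x N_i) are concatenated column-wise and the adjacency
    matrices are placed on the diagonal of one block-diagonal matrix."""
    sizes = [A_i.shape[0] for A_i in As]
    A = np.zeros((sum(sizes), sum(sizes)))
    start = 0
    for A_i, n_i in zip(As, sizes):
        A[start:start + n_i, start:start + n_i] = A_i
        start += n_i
    graph_id = np.repeat(np.arange(len(As)), sizes)   # component of each node
    return np.concatenate(Xs, axis=1), A, graph_id

def gcn_layer(H, A, Omega, beta):
    """Equation 9 with a ReLU activation."""
    return np.maximum(beta[:, None] + Omega @ H @ (A + np.eye(A.shape[0])), 0.0)

def per_graph_mean(H, graph_id, num_graphs):
    """Mean pooling over each component separately; returns D x I."""
    return np.stack([H[:, graph_id == i].mean(axis=1) for i in range(num_graphs)], axis=1)

rng = np.random.default_rng(0)
As = [np.array([[0, 1], [1, 0]]),
      np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]])]
Xs = [rng.normal(size=(3, 2)), rng.normal(size=(3, 3))]
Omega, beta = rng.normal(size=(4, 3)), rng.normal(size=4)

X_b, A_b, graph_id = batch_graphs(Xs, As)
pooled = per_graph_mean(gcn_layer(X_b, A_b, Omega, beta), graph_id, len(As))
```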

Inductive vs. transductive models

Until this point, all of the models in this book have been inductive: we exploit a training set of labeled data to learn the relation between the inputs and outputs. Then we apply this to new test data. One way to think of this is that we are learning the rule that maps inputs to outputs and then applying it elsewhere.

By contrast, a transductive model considers both the labeled and unlabeled data at the same time. It does not produce a rule but merely a labeling for the unknown outputs. This is sometimes termed semi-supervised learning. It has the advantage that it can use patterns in the unlabeled data to help make its decisions. However, it has the disadvantage that the model needs to be retrained when extra unlabeled data are added.

Both problem types are commonly encountered for graphs (see Figure 8). Sometimes, we have many labeled graphs and learn a mapping between the graph and the labels. For example, we might have many molecules, each labeled according to whether it is toxic to humans. We learn the rule that maps the graph to the toxic/non-toxic label and then apply this rule to new molecules. However, sometimes there is a single monolithic graph. In the graph of scientific paper citations, we might have labels indicating the field (physics, biology, etc.) for some nodes and wish to label the remaining nodes. Here, the training and test data are irrevocably connected.

Graph-level tasks only occur in the inductive setting where there are training and test graphs. However, node-level tasks and edge prediction tasks can occur in either setting. In the transductive case, the loss function minimizes the mismatch between the model output and the ground truth where this is known. New predictions are computed by running the forward pass and retrieving the results where the ground truth is unknown.

Example: node classification

As a second example, consider a binary node classification task in a transductive setting. We start with a commercial-sized graph with millions of nodes. Some nodes have ground truth binary labels, and the goal is to label the remaining unlabeled nodes. The body of the network will be the same as in the previous example (Equation 10), but with a different final layer that produces an output vector of size \(1 \times N\):

\[ \mathbf{f}[\mathbf{X}, \mathbf{A}, \mathbf{\Phi}] = \text{sig} \left[ \beta_K \mathbf{1}^T + \mathbf{\omega}_K \mathbf{H}_K \right] \tag{11}\]

where the function \(\text{sig}[\bullet]\) applies the sigmoid function independently to every element of the row vector input. As usual, we use the binary cross-entropy loss, but now only at nodes where the ground truth label \(y\) is known. Note that Equation 11 is just a vectorized version of the node classification model from Equation 3.
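A NumPy sketch of the masked loss. It assumes the loss is averaged over the labeled nodes (summing would work equally well); the `labeled` mask and the name `masked_bce` are illustrative:

```python
import numpy as np

def masked_bce(probs, labels, labeled, eps=1e-12):
    """Binary cross-entropy over the labeled nodes only.
    probs:   length-N output of Equation 11
    labels:  length-N array of 0/1 (entries at unlabeled nodes are ignored)
    labeled: length-N boolean mask, True where the ground truth is known"""
    p = np.clip(probs[labeled], eps, 1.0 - eps)   # avoid log(0)
    y = labels[labeled]
    return -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))

probs = np.array([0.9, 0.2, 0.5, 0.7])
labels = np.array([1, 0, 0, 0])           # the last two values are placeholders
labeled = np.array([True, True, False, False])
loss = masked_bce(probs, labels, labeled)
```

Only the first two nodes contribute, so changing the placeholder labels leaves `loss` unchanged.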

Training this network raises two problems. First, it is logistically difficult to train a graph neural network of this size. Consider that we must store the node embeddings at every network layer in the forward pass. This will involve both storing and processing a structure several times the size of the entire graph, and this may not be practical. Second, we have only a single graph, so it’s not obvious how to perform stochastic gradient descent. How can we form a batch if there is only a single object?

Choosing batches

One way to form a batch is to choose a random subset of labeled nodes at each training step. Each node depends on its neighbors in the previous layer. These, in turn, depend on their neighbors in the layer before, so (similarly to convolutional networks) each node has a receptive field (see Figure 9). The receptive field region is termed the k-hop neighborhood. We can hence perform a gradient descent step using the graph that forms the union of the k-hop neighborhoods of the batch nodes; the remaining inputs do not contribute.
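For a stack of layers like Equation 9, which includes each node itself through \(\mathbf{A} + \mathbf{I}\), the outputs after \(k\) layers depend only on inputs within \(k\) hops. A sketch of finding the union of the \(k\)-hop neighborhoods of a batch with a multi-source breadth-first search, assuming the graph is stored as a Python dict of neighbor lists (`k_hop_neighborhood` is a hypothetical helper):

```python
from collections import deque

def k_hop_neighborhood(adj_list, seeds, k):
    """Return every node within k hops of at least one seed node."""
    dist = {s: 0 for s in seeds}
    queue = deque(seeds)
    while queue:
        u = queue.popleft()
        if dist[u] == k:
            continue                      # do not expand beyond k hops
        for v in adj_list[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return set(dist)

# Hypothetical path graph 0-1-...-9.
path = {i: [j for j in (i - 1, i + 1) if 0 <= j < 10] for i in range(10)}
receptive_field = k_hop_neighborhood(path, [0, 9], 1)
```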

Unfortunately, if there are many layers and the graph is densely connected, every input node may be in the receptive field of every output, and this may not reduce the graph size at all. This is known as the graph expansion problem. Two approaches that tackle this problem are neighborhood sampling and graph partitioning.

Neighborhood sampling: The full graph that feeds into the batch of nodes is sampled, thereby reducing the connections at each network layer (see Figure 10). For example, we might start with the batch nodes and randomly sample a fixed number of their neighbors in the previous layer. Then, we randomly sample a fixed number of their neighbors in the layer before, and so on. The graph still increases in size with each layer but in a much more controlled way. This is done anew for each batch, so the contributing neighbors differ even if the same batch is drawn twice. This is also reminiscent of dropout (section 9.3.3) and adds some regularization.
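A sketch of layer-wise neighborhood sampling in plain Python. The `fanouts` parameter (the maximum number of neighbors sampled per node at each step back) and the function name are hypothetical; practical implementations differ in the details:

```python
import random

def sample_computation_graph(adj_list, batch_nodes, fanouts, seed=None):
    """Layer-wise neighborhood sampling, working backwards from the batch.
    fanouts[j] is the maximum number of neighbors sampled for each node
    when stepping back one more layer.
    Returns the node set needed at each layer and the sampled edges per step."""
    rng = random.Random(seed)
    frontier = set(batch_nodes)
    layers, edges_per_layer = [frontier], []
    for fanout in fanouts:
        next_frontier, layer_edges = set(frontier), []   # each node also needs itself
        for u in sorted(frontier):
            nbrs = adj_list[u]
            chosen = nbrs if len(nbrs) <= fanout else rng.sample(nbrs, fanout)
            for v in chosen:
                layer_edges.append((v, u))               # message from v to u
                next_frontier.add(v)
        layers.append(next_frontier)
        edges_per_layer.append(layer_edges)
        frontier = next_frontier
    return layers, edges_per_layer

# Hypothetical dense graph: every one of 50 nodes is connected to every other.
adj = {u: [v for v in range(50) if v != u] for u in range(50)}
layers, edges_per_layer = sample_computation_graph(adj, [0, 1], fanouts=[3, 3], seed=0)
```

Without sampling, two layers on this dense graph would touch all 50 nodes; with a fan-out of three, at most 8 and then 32 nodes are needed. Calling the function again with a different seed draws a different set of neighbors.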

Graph partitioning: A second approach is to cluster the original graph into disjoint subsets of nodes (i.e., smaller graphs that are not connected to one another) before processing (see Figure 11). There are standard algorithms to choose these subsets to maximize the number of internal links. These smaller graphs can each be treated as batches, or a random subset of them can be combined to form a batch (reinstating any edges between them from the original graph).
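Computing a good partition is the job of a dedicated clustering algorithm and is not shown here; this sketch assumes the clusters are already given. It shows, in plain Python, how a batch could be formed from a random subset of clusters with the edges between them reinstated (`batch_from_clusters` is a hypothetical helper):

```python
import random

def batch_from_clusters(edges, clusters, num_clusters, seed=None):
    """Form one batch from a precomputed partition of the nodes.
    clusters: list of disjoint node sets.
    Picks num_clusters clusters at random and keeps every original edge whose
    endpoints both lie in the chosen clusters, which reinstates the edges
    running between the chosen clusters."""
    rng = random.Random(seed)
    chosen = rng.sample(range(len(clusters)), num_clusters)
    nodes = set().union(*(clusters[c] for c in chosen))
    sub_edges = [(m, n) for m, n in edges if m in nodes and n in nodes]
    return nodes, sub_edges

# Hypothetical ring of 8 nodes split into two clusters of four.
ring = [(i, (i + 1) % 8) for i in range(8)]
clusters = [{0, 1, 2, 3}, {4, 5, 6, 7}]
nodes_one, edges_one = batch_from_clusters(ring, clusters, 1, seed=0)
nodes_two, edges_two = batch_from_clusters(ring, clusters, 2, seed=0)
```

A single cluster keeps only its three internal edges; combining both clusters restores the two edges that cross between them.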

Given one of the above methods to form batches, we can now train the network parameters in the same way as for the inductive setting, dividing the labeled nodes into train, test, and validation sets as desired; we have effectively converted a transductive problem to an inductive one. To perform inference, we compute predictions for the unknown nodes based on their k-hop neighborhood. Unlike training, this does not require storing the intermediate representations, so it is much more memory efficient.

In the previous examples, we combined messages from adjacent nodes by summing them together with the transformed current node. This was accomplished by post-multiplying the node embedding matrix \(\mathbf{H}\) by the adjacency matrix plus the identity \(\mathbf{A} + \mathbf{I}\). We now consider different approaches to both (i) the combination of the current embedding with the aggregated neighbors and (ii) the aggregation process itself.

Combining current node and aggregated neighbors

In the example GCN layer above, we combined the aggregated neighbors \(\mathbf{HA}\) with the current nodes \(\mathbf{H}\) by just summing them:

\[ \mathbf{H}_{k+1} = \mathbf{a}\left[\beta_k \mathbf{1}^T + \mathbf{\Omega}_k \mathbf{H}_k (\mathbf{A} + \mathbf{I})\right] \tag{12}\]

In another variation, the current node is multiplied by a factor of \((1 + \epsilon_k)\) before contributing to the sum, where \(\epsilon_k\) is a learned scalar that is different for each layer:

\[ \mathbf{H}_{k+1} = \mathbf{a}\left[\beta_k \mathbf{1}^T + \mathbf{\Omega}_k \mathbf{H}_k (\mathbf{A} + (1 + \epsilon_k)\mathbf{I})\right] \tag{13}\]

This is known as diagonal enhancement. A related variation applies a different linear transform \(\mathbf{\Psi}_k\) to the current node:

\[ \begin{aligned} \mathbf{H}_{k+1} &= \mathbf{a}\left[\beta_k \mathbf{1}^T + \mathbf{\Omega}_k \mathbf{H}_k \mathbf{A} + \mathbf{\Psi}_k \mathbf{H}_k\right] \\ &= \mathbf{a}\left[\beta_k \mathbf{1}^T + [\mathbf{\Omega}_k \,\, \mathbf{\Psi}_k] \begin{bmatrix} \mathbf{H}_k \mathbf{A} \\ \mathbf{H}_k \end{bmatrix}\right] \\ &= \mathbf{a}\left[\beta_k \mathbf{1}^T + \mathbf{\Omega}'_k \begin{bmatrix} \mathbf{H}_k \mathbf{A} \\ \mathbf{H}_k \end{bmatrix}\right] \end{aligned} \tag{14}\]

where we have defined \(\mathbf{\Omega}'_k = [\mathbf{\Omega}_k \,\, \mathbf{\Psi}_k]\) in the third line.
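A NumPy sketch comparing Equations 13 and 14, assuming ReLU activations and hypothetical random parameters. The value of \(\epsilon_k\) is arbitrary here, whereas in practice it would be learned:

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def diagonal_enhanced_layer(H, A, Omega, beta, eps):
    """Equation 13: the current node is weighted by (1 + eps) in the sum."""
    N = A.shape[0]
    return relu(beta[:, None] + Omega @ H @ (A + (1.0 + eps) * np.eye(N)))

def separate_transform_layer(H, A, Omega, Psi, beta):
    """Equation 14, first line: Omega for the neighbors, Psi for the current node."""
    return relu(beta[:, None] + Omega @ H @ A + Psi @ H)

def stacked_layer(H, A, Omega, Psi, beta):
    """Equation 14, third line: Omega' = [Omega Psi] applied to [H A; H]."""
    Omega_prime = np.concatenate([Omega, Psi], axis=1)
    stacked = np.concatenate([H @ A, H], axis=0)
    return relu(beta[:, None] + Omega_prime @ stacked)

rng = np.random.default_rng(4)
N, D, D_out = 5, 3, 4
A = np.triu(rng.integers(0, 2, size=(N, N)), k=1)
A = A + A.T
H = rng.normal(size=(D, N))
Omega, Psi = rng.normal(size=(D_out, D)), rng.normal(size=(D_out, D))
beta, eps = rng.normal(size=D_out), 0.3
```

The two lines of Equation 14 give identical outputs, and diagonal enhancement turns out to be the special case \(\mathbf{\Psi}_k = (1 + \epsilon_k)\mathbf{\Omega}_k\).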

Residual connections

With residual connections, the aggregated representation from the neighbors is transformed and passed through the activation function before summation or concatenation with the current node. For the latter case, the associated network equations are:

\[ \mathbf{H}_{k+1} = \begin{bmatrix} \mathbf{a}\left[\beta_k \mathbf{1}^T + \mathbf{\Omega}_k \mathbf{H}_k \mathbf{A}\right] \\ \mathbf{H}_k \end{bmatrix} \tag{15}\]
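A NumPy sketch of Equation 15, assuming a ReLU activation and hypothetical random parameters:

```python
import numpy as np

def residual_concat_layer(H, A, Omega, beta):
    """Equation 15: transform and activate the aggregated neighbors, then stack
    the result on top of the unchanged current embeddings."""
    neighbors = np.maximum(beta[:, None] + Omega @ H @ A, 0.0)   # ReLU activation
    return np.concatenate([neighbors, H], axis=0)                # (D' + D) x N

rng = np.random.default_rng(5)
N, D, D_out = 4, 3, 2
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 1],
              [0, 1, 0, 0],
              [0, 1, 0, 0]])
H = rng.normal(size=(D, N))
H_next = residual_concat_layer(H, A, rng.normal(size=(D_out, D)), rng.normal(size=D_out))
```

Because the current embeddings are stacked underneath, the embedding grows by the output size of \(\mathbf{\Omega}_k\) at every layer, and the next layer’s weight matrix must accept the larger input.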

Mean aggregation

The above methods aggregate the neighbors by summing their node embeddings. However, the embeddings can be combined in other ways. Sometimes it is better to take the average of the neighbors rather than the sum: the magnitude of the neighborhood contribution then does not depend on the number of neighbors, which can be advantageous when the information in the embeddings matters more than the structural information:

\[ \mathbf{agg}[n] = \frac{1}{|\text{ne}[n]|} \sum_{m \in \text{ne}[n]} \mathbf{h}_m \tag{16}\]

where, as before, \(\text{ne}[n]\) denotes the set of indices of the neighbors of the \(n^{th}\) node. Equation 16 can be computed neatly in matrix form by introducing the diagonal \(N \times N\) degree matrix \(\mathbf{D}\). Each diagonal element of this matrix contains the number of neighbors (the degree) of the associated node. It follows that each diagonal element of the inverse matrix \(\mathbf{D}^{-1}\) contains the denominator needed to compute the average. The new GCN layer can be written as:

\[ \mathbf{H}_{k+1} = \mathbf{a}\left[\beta_k \mathbf{1}^T + \mathbf{\Omega}_k \mathbf{H}_k (\mathbf{AD}^{-1} + \mathbf{I})\right] \tag{17}\]
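A NumPy sketch of mean aggregation (Equations 16 and 17), assuming a ReLU activation and a graph with no isolated nodes so that \(\mathbf{D}\) is invertible. The last lines contrast sum and mean aggregation when every node carries the same embedding:

```python
import numpy as np

def mean_aggregate(H, A):
    """Column n of H A D^{-1} is the mean of node n's neighbor embeddings
    (Equation 16). Assumes no isolated nodes, so every degree is positive."""
    degree = A.sum(axis=0)                 # diagonal of the degree matrix D
    return H @ A @ np.diag(1.0 / degree)

def mean_aggregation_layer(H, A, Omega, beta):
    """Equation 17, a[beta 1^T + Omega H (A D^{-1} + I)], with a ReLU activation."""
    return np.maximum(beta[:, None] + Omega @ (mean_aggregate(H, A) + H), 0.0)

A = np.array([[0, 1, 1, 1],
              [1, 0, 0, 0],
              [1, 0, 0, 1],
              [1, 0, 1, 0]])               # degrees 3, 1, 2, 2
rng = np.random.default_rng(0)
H = rng.normal(size=(3, 4))
Omega, beta = rng.normal(size=(2, 3)), rng.normal(size=2)
H_next = mean_aggregation_layer(H, A, Omega, beta)

H_const = np.ones((2, 4))
summed = H_const @ A                       # grows with the degree
averaged = mean_aggregate(H_const, A)      # identical for every node
```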

Kipf normalization

There are many variations of graph neural networks based on mean aggregation. Sometimes the current node is included with its neighbors in the mean computation rather than treated separately. In Kipf normalization, the sum of the node representations is normalized as:

\[ \mathbf{agg}[n] = \sum_{m \in \text{ne}[n]} \frac{\mathbf{h}_m}{\sqrt{|\text{ne}[n]| |\text{ne}[m]|}} \tag{18}\]

with the logic that information coming from nodes with a very large number of neighbors should be down-weighted since there are many connections and they provide less unique information. This can also be expressed in matrix form using the degree matrix:

\[ \mathbf{H}_{k+1} = \mathbf{a}\left[\beta_k \mathbf{1}^T + \mathbf{\Omega}_k \mathbf{H}_k (\mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2} + \mathbf{I})\right] \tag{19}\]
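The matrix form can be checked against Equation 18 node by node. This NumPy sketch assumes a ReLU activation and a graph with no isolated nodes:

```python
import numpy as np

def kipf_aggregate(H, A):
    """Matrix form of Equation 18: H D^{-1/2} A D^{-1/2}.
    Assumes no isolated nodes, so every degree is positive."""
    d_inv_sqrt = np.diag(1.0 / np.sqrt(A.sum(axis=0)))
    return H @ d_inv_sqrt @ A @ d_inv_sqrt

def kipf_layer(H, A, Omega, beta):
    """Equation 19, a[beta 1^T + Omega H (D^{-1/2} A D^{-1/2} + I)], with a ReLU."""
    return np.maximum(beta[:, None] + Omega @ (kipf_aggregate(H, A) + H), 0.0)

A = np.array([[0, 1, 1, 1],
              [1, 0, 0, 0],
              [1, 0, 0, 1],
              [1, 0, 1, 0]])
deg = A.sum(axis=0)                        # |ne[n]| for each node
rng = np.random.default_rng(6)
H = rng.normal(size=(3, 4))
agg = kipf_aggregate(H, A)
H_next = kipf_layer(H, A, rng.normal(size=(2, 3)), rng.normal(size=2))
```

Each neighbor \(m\) of node \(n\) contributes \(\mathbf{h}_m / \sqrt{|\text{ne}[n]|\,|\text{ne}[m]|}\), so the high-degree node 0 contributes less to each of its neighbors than it would under plain summation.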

Max pooling aggregation

An alternative operation that is also invariant to permutation is computing the maximum of a set of objects. The max pooling aggregation operator is:

\[ \mathbf{agg}[n] = \max_{m \in \text{ne}[n]} [\mathbf{h}_m] \tag{20}\]

where the operator \(\max\) returns the element-wise maximum of the vectors \(\mathbf{h}_m\) that are neighbors to the current node \(n\).
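A vectorized NumPy sketch of Equation 20 that masks non-neighbors with \(-\infty\) before taking the maximum. It assumes every node has at least one neighbor; an isolated node would receive \(-\infty\):

```python
import numpy as np

def max_aggregate(H, A):
    """Equation 20: column n holds the element-wise maximum of the
    embeddings of node n's neighbors."""
    # masked[d, m, n] = H[d, m] if m is a neighbor of n, else -inf
    masked = np.where(A[None, :, :] > 0, H[:, :, None], -np.inf)
    return masked.max(axis=1)              # maximise over m

A = np.array([[0, 1, 1, 1],
              [1, 0, 0, 0],
              [1, 0, 0, 1],
              [1, 0, 1, 0]])
rng = np.random.default_rng(7)
H = rng.normal(size=(3, 4))
agg = max_aggregate(H, A)
```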

Aggregation by attention

The aggregation methods discussed so far either weight the contribution of the neighbors equally or in a way that depends on the graph topology. Conversely, in graph attention layers, the weights depend on the data at the nodes. A linear transform is applied to the current node embeddings so that:

\[ \mathbf{H}'_k = \beta_k \mathbf{1}^T + \mathbf{\Omega}_k \mathbf{H}_k \tag{21}\]

Then the similarity \(s_{mn}\) of each transformed node embedding \(\mathbf{h}'_m\) to the transformed node embedding \(\mathbf{h}'_n\) is computed by concatenating the pairs, taking a dot product with a column vector \(\mathbf{\phi}_k\) of learned parameters, and applying an activation function:

\[ s_{mn} = \mathbf{a} \left[ \mathbf{\phi}_k^T \begin{bmatrix} \mathbf{h}'_m \\ \mathbf{h}'_n \end{bmatrix} \right] \tag{22}\]

These values are stored in an \(N \times N\) matrix \(\mathbf{S}\) containing the similarity between every pair of nodes. As in dot-product self-attention, the attention weights contributing to each output embedding are normalized to be positive and sum to one using the softmax operation. However, only those values corresponding to the current node and its neighbors should contribute. The attention weights are applied to the transformed embeddings:

\[ \mathbf{H}_{k+1} = \mathbf{a} \left[ \mathbf{H}'_k \cdot \text{Softmask}[\mathbf{S}, \mathbf{A} + \mathbf{I}] \right] \tag{23}\]

where \(\mathbf{a}[\bullet]\) is a second activation function. The function \(\text{Softmask}[\bullet, \bullet]\) computes the attention values by applying the softmax operation separately to each column of its first argument \(\mathbf{S}\), but only after setting the values where the second argument \(\mathbf{A} + \mathbf{I}\) is zero to negative infinity, so that they do not contribute. This ensures that the attention paid to non-neighboring nodes is zero.
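A NumPy sketch of Equations 21 to 23. The text leaves both activations unspecified; this sketch assumes a leaky ReLU with slope 0.2 for the similarities (the choice made in the original graph attention network paper) and a ReLU for the output. Because \(\mathbf{\phi}_k^T [\mathbf{h}'_m; \mathbf{h}'_n]\) splits into one term that depends on \(m\) and one that depends on \(n\), the whole matrix \(\mathbf{S}\) can be built by broadcasting:

```python
import numpy as np

def leaky_relu(z, slope=0.2):
    return np.where(z > 0, z, slope * z)

def softmask(S, mask):
    """Softmax over each column of S after setting entries where mask == 0 to -inf."""
    S = np.where(mask > 0, S, -np.inf)
    S = S - S.max(axis=0, keepdims=True)      # subtract the column max for stability
    W = np.exp(S)                             # exp(-inf) = 0 for masked entries
    return W / W.sum(axis=0, keepdims=True)

def graph_attention_layer(H, A, Omega, beta, phi):
    """Equations 21-23; returns the new embeddings and the attention weights."""
    N = A.shape[0]
    H_prime = beta[:, None] + Omega @ H                 # Equation 21, D' x N
    D_out = H_prime.shape[0]
    from_m = phi[:D_out] @ H_prime                      # term depending on m
    to_n = phi[D_out:] @ H_prime                        # term depending on n
    S = leaky_relu(from_m[:, None] + to_n[None, :])     # Equation 22: S[m, n] = s_mn
    weights = softmask(S, A + np.eye(N))                # each column sums to one
    return np.maximum(H_prime @ weights, 0.0), weights  # Equation 23

A = np.array([[0, 1, 1, 1],
              [1, 0, 0, 0],
              [1, 0, 0, 1],
              [1, 0, 1, 0]])
rng = np.random.default_rng(8)
D, D_out = 3, 5
H = rng.normal(size=(D, 4))
Omega, beta = rng.normal(size=(D_out, D)), rng.normal(size=D_out)
phi = rng.normal(size=2 * D_out)
H_next, weights = graph_attention_layer(H, A, Omega, beta, phi)
```

Column \(n\) of `weights` distributes node \(n\)’s attention over itself and its neighbors only, and permuting the nodes permutes the output columns in the same way.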

This is very similar to the dot-product self-attention computation in transformers (see Figure 12), except that (i) the keys, queries, and values are all the same, (ii) the measure of similarity is different, and (iii) the attentions are masked so that each node only attends to itself and its neighbors. As in transformers, this system can be extended to use multiple heads that are run in parallel and recombined.