- https://www.youtube.com/watch?v=LLUxwHc7O4A
- when we aggregate, we can take in the features of our neighbours AND our neighbours’ neighbours
- and the feature vector of the node itself! notice how the feature vector for node 0 is fed into the three lowest convolution blocks
- “A 2-layer GNN generates embedding of node 0 using 2-hop neighbourhood structure and features”
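- to make this concrete, here’s a minimal numpy sketch (not the lecture’s code) of a 2-layer mean-aggregation GNN on a made-up 4-node graph: node 0’s final embedding depends on its own features, its neighbours (1, 2), and their neighbours (3); `W_self` / `W_neigh` are placeholder random weights

```python
# Toy 2-layer GNN: each layer mixes a node's OWN embedding with the mean of
# its neighbours' embeddings, so after 2 layers node 0 has seen its 2-hop
# neighbourhood plus its own feature vector. Graph and weights are made up.
import numpy as np

rng = np.random.default_rng(0)
adj = {0: [1, 2], 1: [0, 3], 2: [0, 3], 3: [1, 2]}   # toy undirected graph
X = rng.normal(size=(4, 8))                           # node features, dim 8
W_self = [rng.normal(size=(8, 8)) * 0.1 for _ in range(2)]
W_neigh = [rng.normal(size=(8, 8)) * 0.1 for _ in range(2)]

def relu(x):
    return np.maximum(x, 0)

H = X
for k in range(2):                                    # 2 layers -> 2-hop receptive field
    H_next = np.zeros_like(H)
    for v in adj:
        neigh_mean = np.mean([H[u] for u in adj[v]], axis=0)
        # the node's own previous embedding enters at every layer
        H_next[v] = relu(H[v] @ W_self[k] + neigh_mean @ W_neigh[k])
    H = H_next

print(H[0])   # embedding of node 0: a function of nodes 0, 1, 2 (hop 1) and 3 (hop 2)
```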
- this changed how we think about graph neural nets
- before, GNNs needed the entire graph to make a prediction about a single node
- older GNNs were also limited to making predictions on nodes that were present during training
- but now, batches of node neighbourhoods can be loaded onto the GPU, so we don’t need to care about the rest of the graph
- This gave us the ability to process:
    - much larger graphs, and
    - structures that weren’t there during training
- Here’s how we can perform stochastic gradient descent:
    - randomly sample M << N nodes to put into our minibatch
    - generate the computation graph for each of the M sampled nodes
    - compute the loss and get our gradients!
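- a hedged sketch of that recipe in plain Python (the toy graph, M, and the `k_hop_nodes` helper are illustrative; a real pipeline would move only these subgraphs and their features to the GPU and use an autograd library for the loss/gradient step):

```python
# Sketch of minibatch construction for node-level GNN training.
import random

adj = {0: [1, 2], 1: [0, 3], 2: [0, 3], 3: [1, 2, 4], 4: [3]}   # toy graph
M, num_hops = 2, 2                                              # M << N seed nodes, 2-layer GNN

def k_hop_nodes(adj, seed, k):
    """All nodes needed to compute `seed`'s embedding with a k-layer GNN."""
    frontier, needed = {seed}, {seed}
    for _ in range(k):
        frontier = {u for v in frontier for u in adj[v]}
        needed |= frontier
    return needed

minibatch = random.sample(list(adj), M)                              # 1. sample M nodes
comp_graphs = {v: k_hop_nodes(adj, v, num_hops) for v in minibatch}  # 2. build computation graphs
# 3. load only comp_graphs (+ their features) onto the GPU, run the GNN forward
#    pass, compute the loss on the M seed nodes, and backprop for the gradients.
print(comp_graphs)
```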
- but this is compute intensive because:
    - for each node, we need to build its own computation graph
    - and increasing the number of hops increases the number of nodes in that graph exponentially
    - if we hit a hub node (a node with very high degree), we’re screwed
- Solution: use neighbourhood sampling
    - we just sample at most H neighbours when doing the calculation (instead of every neighbour)
    - i.e. we cap the fan-out at H
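- a sketch of the capped expansion (made-up helper name, toy graph with a hub node):

```python
# Same k-hop expansion as before, but each node contributes at most H randomly
# sampled neighbours, so the per-layer fan-out is capped at H instead of the
# node's full degree (which saves us when we hit a hub node).
import random

adj = {0: [1, 2, 3, 4, 5], 1: [0], 2: [0], 3: [0], 4: [0], 5: [0]}   # node 0 is a hub

def sampled_k_hop_nodes(adj, seed, k, H):
    frontier, needed = {seed}, {seed}
    for _ in range(k):
        next_frontier = set()
        for v in frontier:
            sampled = random.sample(adj[v], min(H, len(adj[v])))     # cap fan-out at H
            next_frontier.update(sampled)
        frontier = next_frontier
        needed |= frontier
    return needed

print(sampled_k_hop_nodes(adj, seed=1, k=2, H=2))   # at most 1 + 2 + 4 nodes
```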
- 3 remarks:
    - choosing H is a tradeoff between accuracy and training time
    - H doesn’t change the fact that the fan-out is still exponential in the number of layers (roughly H^k nodes for k layers, just with a smaller base)
    - uniform random sampling of neighbours may not be optimal
- we can use random walks with restarts instead
    - we run a few short random walks that restart at the target node
    - each neighbouring node gets a score (e.g. how often the walks visit it)
    - then we just sample the neighbouring nodes with the highest scores
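- a possible implementation sketch (walk length, number of walks, and restart probability are illustrative; scoring neighbours by visit count is one common choice, e.g. PinSAGE-style):

```python
# Score each neighbour of `start` by how often short random walks with restarts
# visit it, then keep the top-H scorers instead of a uniform random sample.
import random
from collections import Counter

def rwr_top_neighbours(adj, start, H, num_walks=20, walk_len=10, restart_p=0.15):
    visits = Counter()
    for _ in range(num_walks):
        v = start
        for _ in range(walk_len):
            if random.random() < restart_p:
                v = start                      # restart: jump back to the start node
            else:
                v = random.choice(adj[v])      # otherwise step to a random neighbour
            if v in adj[start]:
                visits[v] += 1                 # score = visit count per neighbour
    return [u for u, _ in visits.most_common(H)]

adj = {0: [1, 2, 3], 1: [0, 2], 2: [0, 1, 3], 3: [0, 2]}   # toy graph
print(rwr_top_neighbours(adj, start=0, H=2))
```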
- https://youtu.be/JtDgmmQ60x8?si=7DxQdzoiINZia0xE&t=1430
- the 2 main modifications in GraphSAGE:
    - rather than taking the sum of the neighbours’ hidden states, we generalise this to an arbitrary AGG function (could be the mean)
        - note: AGG could also be max pooling, or an LSTM
        - problem: an LSTM gives a different output if you feed in the features in a different order
        - solution: we can run the LSTM on several random permutations of the input features (and, e.g., average) to get a stable result
    - rather than ADDING our own previous hidden state, we CONCATENATE it with the aggregated neighbour message (see the sketch below)
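- a minimal numpy sketch of a GraphSAGE-style layer with both modifications: a pluggable AGG (mean and max-pool shown; a permutation-averaged LSTM would fit the same interface) and CONCAT of our own previous hidden state with the aggregated neighbour message; the toy graph and random weights are placeholders, and the ℓ2 normalisation follows the GraphSAGE paper

```python
import numpy as np

rng = np.random.default_rng(0)

def mean_agg(neigh_embs):
    return np.mean(neigh_embs, axis=0)

def maxpool_agg(neigh_embs):
    return np.max(neigh_embs, axis=0)

def graphsage_layer(H, adj, W, agg=mean_agg):
    """H: (num_nodes, d) embeddings; W: (2d, d) weights; returns updated embeddings."""
    H_next = np.zeros_like(H)
    for v in adj:
        msg = agg(np.stack([H[u] for u in adj[v]]))    # aggregate neighbour states
        concat = np.concatenate([H[v], msg])           # CONCAT own state, don't add
        h = np.maximum(concat @ W, 0)                  # linear transform + ReLU
        H_next[v] = h / (np.linalg.norm(h) + 1e-8)     # L2-normalise (GraphSAGE paper)
    return H_next

adj = {0: [1, 2], 1: [0, 2], 2: [0, 1]}   # toy triangle graph
H = rng.normal(size=(3, 4))
W = rng.normal(size=(8, 4)) * 0.1
print(graphsage_layer(H, adj, W, agg=maxpool_agg)[0])
```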