Commit f830c5d

Tutorial 7 (JAX): Fixing PyTorch mentioning
1 parent 6e642af commit f830c5d

File tree: 1 file changed, +1 −1 lines


docs/tutorial_notebooks/JAX/tutorial7/GNN_overview.ipynb

@@ -210,7 +210,7 @@
     "\n",
     "$W^{(l)}$ is the weight parameters with which we transform the input features into messages ($H^{(l)}W^{(l)}$). To the adjacency matrix $A$ we add the identity matrix so that each node sends its own message also to itself: $\\hat{A}=A+I$. Finally, to take the average instead of summing, we calculate the matrix $\\hat{D}$ which is a diagonal matrix with $D_{ii}$ denoting the number of neighbors node $i$ has. $\\sigma$ represents an arbitrary activation function, and not necessarily the sigmoid (usually a ReLU-based activation function is used in GNNs). \n",
     "\n",
-    "When implementing the GCN layer in PyTorch, we can take advantage of the flexible operations on tensors. Instead of defining a matrix $\\hat{D}$, we can simply divide the summed messages by the number of neighbors afterward. Additionally, we replace the weight matrix with a linear layer, which additionally allows us to add a bias. Written as a PyTorch module, the GCN layer is defined as follows:"
+    "When implementing the GCN layer in JAX/Flax, we can take advantage of the flexible operations on tensors. Instead of defining a matrix $\\hat{D}$, we can simply divide the summed messages by the number of neighbors afterward. Additionally, we replace the weight matrix with a linear layer, which additionally allows us to add a bias. Written as a Flax module, the GCN layer is defined as follows:"
     ]
   },
   {
