Commit ac42535

Tutorial 6: Supporting flexible mask shape
1 parent a2fe2d3 commit ac42535

File tree: 2 files changed, +49 -5 lines changed


docs/tutorial_notebooks/JAX/tutorial6/Transformers_and_MHAttention.ipynb (+25 -3)

@@ -344,6 +344,26 @@
     "With this in mind, we can implement the Multi-Head Attention module below."
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Helper function to support different mask shapes.\n",
+    "# Output shape supports (batch_size, number of heads, seq length, seq length)\n",
+    "# If 2D: broadcasted over batch size and number of heads\n",
+    "# If 3D: broadcasted over number of heads\n",
+    "# If 4D: leave as is\n",
+    "def expand_mask(mask):\n",
+    "    assert mask.ndim >= 2, \"Mask must be at least 2-dimensional with seq_length x seq_length\"\n",
+    "    if mask.ndim == 3:\n",
+    "        mask = mask.unsqueeze(1)\n",
+    "    while mask.ndim < 4:\n",
+    "        mask = mask.unsqueeze(0)\n",
+    "    return mask"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 5,
@@ -367,6 +387,8 @@
     "\n",
     "    def __call__(self, x, mask=None):\n",
     "        batch_size, seq_length, embed_dim = x.shape\n",
+    "        if mask is not None:\n",
+    "            mask = expand_mask(mask)\n",
     "        qkv = self.qkv_proj(x)\n",
     "        \n",
     "        # Separate Q, K, V from linear output\n",
@@ -526,7 +548,7 @@
     "encblock = EncoderBlock(input_dim=128, num_heads=4, dim_feedforward=512, dropout_prob=0.1)\n",
     "# Initialize parameters of encoder block with random key and inputs\n",
     "main_rng, init_rng, dropout_init_rng = random.split(main_rng, 3)\n",
-    "params = encblock.init({'params': init_rng, 'dropout': dropout_init_rng}, x, True)['params']\n",
+    "params = encblock.init({'params': init_rng, 'dropout': dropout_init_rng}, x, train=True)['params']\n",
     "# Apply encoder block with parameters on the inputs\n",
     "# Since dropout is stochastic, we need to pass a rng to the forward\n",
     "main_rng, dropout_apply_rng = random.split(main_rng)\n",
@@ -20341,7 +20363,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
@@ -20355,7 +20377,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.7.9"
+   "version": "3.10.4"
  }
 },
 "nbformat": 4,

docs/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.ipynb (+24 -2)

@@ -326,6 +326,26 @@
     "How are we applying a Multi-Head Attention layer in a neural network, where we don't have an arbitrary query, key, and value vector as input? Looking at the computation graph above, a simple but effective implementation is to set the current feature map in a NN, $X\\in\\mathbb{R}^{B\\times T\\times d_{\\text{model}}}$, as $Q$, $K$ and $V$ ($B$ being the batch size, $T$ the sequence length, $d_{\\text{model}}$ the hidden dimensionality of $X$). The consecutive weight matrices $W^{Q}$, $W^{K}$, and $W^{V}$ can transform $X$ to the corresponding feature vectors that represent the queries, keys, and values of the input. Using this approach, we can implement the Multi-Head Attention module below."
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Helper function to support different mask shapes.\n",
+    "# Output shape supports (batch_size, number of heads, seq length, seq length)\n",
+    "# If 2D: broadcasted over batch size and number of heads\n",
+    "# If 3D: broadcasted over number of heads\n",
+    "# If 4D: leave as is\n",
+    "def expand_mask(mask):\n",
+    "    assert mask.ndim >= 2, \"Mask must be at least 2-dimensional with seq_length x seq_length\"\n",
+    "    if mask.ndim == 3:\n",
+    "        mask = mask.unsqueeze(1)\n",
+    "    while mask.ndim < 4:\n",
+    "        mask = mask.unsqueeze(0)\n",
+    "    return mask"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 5,
@@ -358,6 +378,8 @@
     "\n",
     "    def forward(self, x, mask=None, return_attention=False):\n",
     "        batch_size, seq_length, _ = x.size()\n",
+    "        if mask is not None:\n",
+    "            mask = expand_mask(mask)\n",
     "        qkv = self.qkv_proj(x)\n",
     "        \n",
     "        # Separate Q, K, V from linear output\n",
@@ -20264,7 +20286,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
@@ -20278,7 +20300,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.7.9"
+   "version": "3.10.4"
  }
 },
 "nbformat": 4,
