|
344 | 344 | "With this in mind, we can implement the Multi-Head Attention module below."
|
345 | 345 | ]
|
346 | 346 | },
|
| 347 | + { |
| 348 | + "cell_type": "code", |
| 349 | + "execution_count": null, |
| 350 | + "metadata": {}, |
| 351 | + "outputs": [], |
| 352 | + "source": [ |
| 353 | + "# Helper function to support different mask shapes.\n", |
| 354 | + "# Output shape supports (batch_size, number of heads, seq length, seq length)\n", |
| 355 | + "# If 2D: broadcasted over batch size and number of heads\n", |
| 356 | + "# If 3D: broadcasted over number of heads\n", |
| 357 | + "# If 4D: leave as is\n", |
| 358 | + "def expand_mask(mask):\n", |
| 359 | + " assert mask.ndim > 2, \"Mask must be at least 2-dimensional with seq_length x seq_length\"\n", |
| 360 | + " if mask.ndim == 3:\n", |
| 361 | + " mask = mask.unsqueeze(1)\n", |
| 362 | + " while mask.ndim < 4:\n", |
| 363 | + " mask = mask.unsqueeze(0)\n", |
| 364 | + " return mask" |
| 365 | + ] |
| 366 | + }, |
347 | 367 | {
|
348 | 368 | "cell_type": "code",
|
349 | 369 | "execution_count": 5,
|
|
367 | 387 | "\n",
|
368 | 388 | " def __call__(self, x, mask=None):\n",
|
369 | 389 | " batch_size, seq_length, embed_dim = x.shape\n",
|
| 390 | + " if mask is not None:\n", |
| 391 | + " mask = expand_mask(mask)\n", |
370 | 392 | " qkv = self.qkv_proj(x)\n",
|
371 | 393 | " \n",
|
372 | 394 | " # Separate Q, K, V from linear output\n",
|
|
526 | 548 | "encblock = EncoderBlock(input_dim=128, num_heads=4, dim_feedforward=512, dropout_prob=0.1)\n",
|
527 | 549 | "# Initialize parameters of encoder block with random key and inputs\n",
|
528 | 550 | "main_rng, init_rng, dropout_init_rng = random.split(main_rng, 3)\n",
|
529 |
| - "params = encblock.init({'params': init_rng, 'dropout': dropout_init_rng}, x, True)['params']\n", |
| 551 | + "params = encblock.init({'params': init_rng, 'dropout': dropout_init_rng}, x, train=True)['params']\n", |
530 | 552 | "# Apply encoder block with parameters on the inputs\n",
|
531 | 553 | "# Since dropout is stochastic, we need to pass a rng to the forward\n",
|
532 | 554 | "main_rng, dropout_apply_rng = random.split(main_rng)\n",
|
|
20341 | 20363 | ],
|
20342 | 20364 | "metadata": {
|
20343 | 20365 | "kernelspec": {
|
20344 |
| - "display_name": "Python 3", |
| 20366 | + "display_name": "Python 3 (ipykernel)", |
20345 | 20367 | "language": "python",
|
20346 | 20368 | "name": "python3"
|
20347 | 20369 | },
|
|
20355 | 20377 | "name": "python",
|
20356 | 20378 | "nbconvert_exporter": "python",
|
20357 | 20379 | "pygments_lexer": "ipython3",
|
20358 |
| - "version": "3.7.9" |
| 20380 | + "version": "3.10.4" |
20359 | 20381 | }
|
20360 | 20382 | },
|
20361 | 20383 | "nbformat": 4,
|
|
0 commit comments