|
801 | 801 | "source": [
|
802 | 802 | "## PyTorch Lightning\n",
|
803 | 803 | "\n",
|
804 |
| - "In this notebook and in many following ones, we will make use of the library [PyTorch Lightning](https://www.pytorchlightning.ai/). PyTorch Lightning is a framework that simplifies your code needed to train, evaluate, and test a model in PyTorch. It also handles logging into [TensorBoard](https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html), a visualization toolkit for ML experiments, and saving model checkpoints automatically with minimal code overhead from our side. This is extremely helpful for us as we want to focus on implementing different model architectures and spend little time on other code overhead. Note that at the time of writing/teaching, the framework has been released in version 1.6. Future versions might have a slightly changed interface and thus might not work perfectly with the code (we will try to keep it up-to-date as much as possible). \n", |
| 804 | + "In this notebook and in many following ones, we will make use of the library [PyTorch Lightning](https://www.pytorchlightning.ai/). PyTorch Lightning is a framework that simplifies your code needed to train, evaluate, and test a model in PyTorch. It also handles logging into [TensorBoard](https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html), a visualization toolkit for ML experiments, and saving model checkpoints automatically with minimal code overhead from our side. This is extremely helpful for us as we want to focus on implementing different model architectures and spend little time on other code overhead. Note that at the time of writing/teaching, the framework has been released in version 1.8. Future versions might have a slightly changed interface and thus might not work perfectly with the code (we will try to keep it up-to-date as much as possible). \n", |
805 | 805 | "\n",
|
806 | 806 | "Now, we will take the first step in PyTorch Lightning, and continue to explore the framework in our other tutorials. First, we import the library:"
|
807 | 807 | ]
|
|
862 | 862 | "4. Validation loop (`validation_step`) where similarly to the training, we only have to define what should happen per step\n",
|
863 | 863 | "5. Test loop (`test_step`) which is the same as validation, only on a test set.\n",
|
864 | 864 | "\n",
|
865 |
| - "Therefore, we don't abstract the PyTorch code, but rather organize it and define some default operations that are commonly used. If you need to change something else in your training/validation/test loop, there are many possible functions you can overwrite (see the [docs](https://pytorch-lightning.readthedocs.io/en/stable/lightning_module.html) for details).\n", |
| 865 | + "Therefore, we don't abstract the PyTorch code, but rather organize it and define some default operations that are commonly used. If you need to change something else in your training/validation/test loop, there are many possible functions you can overwrite (see the [docs](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html) for details).\n", |
866 | 866 | "\n",
|
867 | 867 | "Now we can look at an example of how a Lightning Module for training a CNN looks like:"
|
868 | 868 | ]
|
|
1008 | 1008 | "source": [
|
1009 | 1009 | "If we pass the classes or objects directly as an argument to the Lightning module, we couldn't take advantage of PyTorch Lightning's automatically hyperparameter saving and loading.\n",
|
1010 | 1010 | "\n",
|
1011 |
| - "Besides the Lightning module, the second most important module in PyTorch Lightning is the `Trainer`. The trainer is responsible to execute the training steps defined in the Lightning module and completes the framework. Similar to the Lightning module, you can override any key part that you don't want to be automated, but the default settings are often the best practice to do. For a full overview, see the [documentation](https://pytorch-lightning.readthedocs.io/en/stable/trainer.html). The most important functions we use below are:\n", |
| 1011 | + "Besides the Lightning module, the second most important module in PyTorch Lightning is the `Trainer`. The trainer is responsible to execute the training steps defined in the Lightning module and completes the framework. Similar to the Lightning module, you can override any key part that you don't want to be automated, but the default settings are often the best practice to do. For a full overview, see the [documentation](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html). The most important functions we use below are:\n", |
1012 | 1012 | "\n",
|
1013 | 1013 | "* `trainer.fit`: Takes as input a lightning module, a training dataset, and an (optional) validation dataset. This function trains the given module on the training dataset with occasional validation (default once per epoch, can be changed)\n",
|
1014 | 1014 | "* `trainer.test`: Takes as input a model and a dataset on which we want to test. It returns the test metric on the dataset.\n",
|
|
0 commit comments