diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000000..15034b079b --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,11 @@ +{ + "configurations": [ + { + "name": "Python: File", + "type": "python", + "request": "launch", + "program": "${file}", + "justMyCode": true + } + ] +} \ No newline at end of file diff --git a/improved_diffusion.egg-info/PKG-INFO b/improved_diffusion.egg-info/PKG-INFO new file mode 100644 index 0000000000..3a81a27e04 --- /dev/null +++ b/improved_diffusion.egg-info/PKG-INFO @@ -0,0 +1,10 @@ +Metadata-Version: 1.0 +Name: improved-diffusion +Version: 0.0.0 +Summary: UNKNOWN +Home-page: UNKNOWN +Author: UNKNOWN +Author-email: UNKNOWN +License: UNKNOWN +Description: UNKNOWN +Platform: UNKNOWN diff --git a/improved_diffusion.egg-info/SOURCES.txt b/improved_diffusion.egg-info/SOURCES.txt new file mode 100644 index 0000000000..805ac15adb --- /dev/null +++ b/improved_diffusion.egg-info/SOURCES.txt @@ -0,0 +1,7 @@ +README.md +setup.py +improved_diffusion.egg-info/PKG-INFO +improved_diffusion.egg-info/SOURCES.txt +improved_diffusion.egg-info/dependency_links.txt +improved_diffusion.egg-info/requires.txt +improved_diffusion.egg-info/top_level.txt \ No newline at end of file diff --git a/improved_diffusion.egg-info/dependency_links.txt b/improved_diffusion.egg-info/dependency_links.txt new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/improved_diffusion.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/improved_diffusion.egg-info/requires.txt b/improved_diffusion.egg-info/requires.txt new file mode 100644 index 0000000000..e251f6c85a --- /dev/null +++ b/improved_diffusion.egg-info/requires.txt @@ -0,0 +1,3 @@ +blobfile>=1.0.5 +torch +tqdm diff --git a/improved_diffusion.egg-info/top_level.txt b/improved_diffusion.egg-info/top_level.txt new file mode 100644 index 0000000000..880277ccbd --- /dev/null +++ b/improved_diffusion.egg-info/top_level.txt @@ -0,0 +1 @@ +improved_diffusion diff --git a/improved_diffusion/test_nn sun/nn.ipynb b/improved_diffusion/test_nn sun/nn.ipynb new file mode 100644 index 0000000000..b62944f145 --- /dev/null +++ b/improved_diffusion/test_nn sun/nn.ipynb @@ -0,0 +1,1097 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import torch.nn.functional as F\n", + "import torchvision.datasets as datasets\n", + "from torch.utils.data import DataLoader, Dataset\n", + "import torchvision.transforms as transforms\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([64])\n", + "torch.Size([32])\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[9], line 55\u001b[0m\n\u001b[1;32m 53\u001b[0m optimizer\u001b[39m.\u001b[39mzero_grad()\n\u001b[1;32m 54\u001b[0m loss\u001b[39m.\u001b[39mbackward()\n\u001b[0;32m---> 55\u001b[0m optimizer\u001b[39m.\u001b[39;49mstep()\n\u001b[1;32m 57\u001b[0m \u001b[39m#Check Accuracy\u001b[39;00m\n\u001b[1;32m 59\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mcheck_accuracy\u001b[39m(loader,model):\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torch/optim/optimizer.py:269\u001b[0m, in \u001b[0;36mOptimizer.profile_hook_step..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 267\u001b[0m \u001b[39mself\u001b[39m, \u001b[39m*\u001b[39m_ \u001b[39m=\u001b[39m args\n\u001b[1;32m 268\u001b[0m profile_name \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mOptimizer.step#\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m.step\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__class__\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m)\n\u001b[0;32m--> 269\u001b[0m \u001b[39mwith\u001b[39;00m torch\u001b[39m.\u001b[39mautograd\u001b[39m.\u001b[39mprofiler\u001b[39m.\u001b[39mrecord_function(profile_name):\n\u001b[1;32m 270\u001b[0m \u001b[39m# call optimizer step pre hooks\u001b[39;00m\n\u001b[1;32m 271\u001b[0m \u001b[39mfor\u001b[39;00m pre_hook \u001b[39min\u001b[39;00m chain(_global_optimizer_pre_hooks\u001b[39m.\u001b[39mvalues(), \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_optimizer_step_pre_hooks\u001b[39m.\u001b[39mvalues()):\n\u001b[1;32m 272\u001b[0m result \u001b[39m=\u001b[39m pre_hook(\u001b[39mself\u001b[39m, args, kwargs)\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torch/autograd/profiler.py:492\u001b[0m, in \u001b[0;36mrecord_function.__enter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 491\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__enter__\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[0;32m--> 492\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mrecord \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39;49mops\u001b[39m.\u001b[39;49mprofiler\u001b[39m.\u001b[39;49m_record_function_enter_new(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mname, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49margs)\n\u001b[1;32m 493\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torch/_ops.py:502\u001b[0m, in \u001b[0;36mOpOverloadPacket.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 497\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 498\u001b[0m \u001b[39m# overloading __call__ to ensure torch.ops.foo.bar()\u001b[39;00m\n\u001b[1;32m 499\u001b[0m \u001b[39m# is still callable from JIT\u001b[39;00m\n\u001b[1;32m 500\u001b[0m \u001b[39m# We save the function ptr as the `op` attribute on\u001b[39;00m\n\u001b[1;32m 501\u001b[0m \u001b[39m# OpOverloadPacket to access it here.\u001b[39;00m\n\u001b[0;32m--> 502\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_op(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs \u001b[39mor\u001b[39;49;00m {})\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "# Create Fully Connected NN\n", + "class NN(nn.Module):\n", + " def __init__(self, input_size, num_classes):\n", + " super(NN,self).__init__()\n", + " self.fc1 = nn.Linear(input_size,100)\n", + " self.fc2 = nn.Linear(100,num_classes)\n", + "\n", + " def forward(self,x):\n", + " x = F.relu(self.fc1(x))\n", + " x = self.fc2(x)\n", + " return x\n", + " \n", + "\n", + "# Check Device\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "\n", + "#Hyperparameter\n", + "input_size = 784\n", + "num_classes = 10\n", + "learning_rate = 1e-3\n", + "batch_sizes = 64\n", + "num_epochs = 100\n", + "\n", + "#Load Data\n", + "train_dataset = datasets.MNIST(root = 'nndataset/',train = True, transform = transforms.ToTensor(),download = True)\n", + "train_loader = DataLoader(dataset = train_dataset,batch_size = batch_sizes, shuffle = True)\n", + "\n", + "test_dataset = datasets.MNIST(root = 'nndataset/',train = False, transform = transforms.ToTensor(),download = True)\n", + "test_loader = DataLoader(dataset = test_dataset,batch_size = batch_sizes, shuffle = True)\n", + "\n", + "\n", + "#Init Network\n", + "\n", + "model = NN(input_size=input_size, num_classes=num_classes).to(device)\n", + "\n", + "#Loss and Optimizer\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = optim.Adam(model.parameters(),lr = learning_rate)\n", + "\n", + "#Training Network\n", + "\n", + "for epochs in range(num_epochs):\n", + " for batch_idx, (data,targets) in enumerate(train_loader):\n", + " data = data.to(device = device)\n", + " targets = targets.to(device = device)\n", + "\n", + " data = data.reshape(data.shape[0],-1)\n", + "\n", + " scores = model(data)\n", + " loss = criterion(scores, targets)\n", + "\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + "#Check Accuracy\n", + "\n", + "def check_accuracy(loader,model):\n", + " if loader.dataset.train:\n", + " print(\"checking accuracy on training data\")\n", + " else:\n", + " print(\"checking accuracy on test data\")\n", + "\n", + " num_correct = 0\n", + " num_sample = 0\n", + " model.eval()\n", + "\n", + " with torch.no_grad():\n", + " for x , y in loader:\n", + " x = x.to(device = device)\n", + " y = y.to(device = device)\n", + " x = x.reshape(x.shape[0],-1)\n", + "\n", + " score = model(x)\n", + " _, prediction = score.max(1)\n", + "\n", + " num_correct += (prediction == y).sum()\n", + " num_sample +=prediction.size(0)\n", + "\n", + " print(\n", + " f\"Got {num_correct} / {num_sample} with accuracy\"\n", + " f\" {float(num_correct) / float(num_sample) * 100:.2f}\"\n", + " )\n", + "\n", + " model.train()\n", + "\n", + "\n", + "check_accuracy(train_loader,model)\n", + "check_accuracy(test_loader,model)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/improved_diffusion/test_nn sun/nndataset/MNIST/raw/t10k-images-idx3-ubyte b/improved_diffusion/test_nn sun/nndataset/MNIST/raw/t10k-images-idx3-ubyte new file mode 100644 index 0000000000..1170b2cae9 Binary files /dev/null and b/improved_diffusion/test_nn sun/nndataset/MNIST/raw/t10k-images-idx3-ubyte differ diff --git a/improved_diffusion/test_nn sun/nndataset/MNIST/raw/t10k-images-idx3-ubyte.gz b/improved_diffusion/test_nn sun/nndataset/MNIST/raw/t10k-images-idx3-ubyte.gz new file mode 100644 index 0000000000..5ace8ea93f Binary files /dev/null and b/improved_diffusion/test_nn sun/nndataset/MNIST/raw/t10k-images-idx3-ubyte.gz differ diff --git a/improved_diffusion/test_nn sun/nndataset/MNIST/raw/t10k-labels-idx1-ubyte b/improved_diffusion/test_nn sun/nndataset/MNIST/raw/t10k-labels-idx1-ubyte new file mode 100644 index 0000000000..d1c3a97061 Binary files /dev/null and b/improved_diffusion/test_nn sun/nndataset/MNIST/raw/t10k-labels-idx1-ubyte differ diff --git a/improved_diffusion/test_nn sun/nndataset/MNIST/raw/t10k-labels-idx1-ubyte.gz b/improved_diffusion/test_nn sun/nndataset/MNIST/raw/t10k-labels-idx1-ubyte.gz new file mode 100644 index 0000000000..a7e141541c Binary files /dev/null and b/improved_diffusion/test_nn sun/nndataset/MNIST/raw/t10k-labels-idx1-ubyte.gz differ diff --git a/improved_diffusion/test_nn sun/nndataset/MNIST/raw/train-images-idx3-ubyte b/improved_diffusion/test_nn sun/nndataset/MNIST/raw/train-images-idx3-ubyte new file mode 100644 index 0000000000..bbce27659e Binary files /dev/null and b/improved_diffusion/test_nn sun/nndataset/MNIST/raw/train-images-idx3-ubyte differ diff --git a/improved_diffusion/test_nn sun/nndataset/MNIST/raw/train-images-idx3-ubyte.gz b/improved_diffusion/test_nn sun/nndataset/MNIST/raw/train-images-idx3-ubyte.gz new file mode 100644 index 0000000000..b50e4b6bcc Binary files /dev/null and b/improved_diffusion/test_nn sun/nndataset/MNIST/raw/train-images-idx3-ubyte.gz differ diff --git a/improved_diffusion/test_nn sun/nndataset/MNIST/raw/train-labels-idx1-ubyte b/improved_diffusion/test_nn sun/nndataset/MNIST/raw/train-labels-idx1-ubyte new file mode 100644 index 0000000000..d6b4c5db3b Binary files /dev/null and b/improved_diffusion/test_nn sun/nndataset/MNIST/raw/train-labels-idx1-ubyte differ diff --git a/improved_diffusion/test_nn sun/nndataset/MNIST/raw/train-labels-idx1-ubyte.gz b/improved_diffusion/test_nn sun/nndataset/MNIST/raw/train-labels-idx1-ubyte.gz new file mode 100644 index 0000000000..707a576bb5 Binary files /dev/null and b/improved_diffusion/test_nn sun/nndataset/MNIST/raw/train-labels-idx1-ubyte.gz differ