|
906 | 906 | "Step 500: Loss on training set : 0.085323\n",
|
907 | 907 | "Step 600: Loss on training set : 0.231022\n",
|
908 | 908 | "Step 700: Loss on training set : 0.061058\n",
|
909 |
| - "Step 800: Loss on training set : 0.196906\n" |
| 909 | + "Step 800: Loss on training set : 0.196906\n", |
| 910 | + "Step 900: Loss on training set : 0.156831\n", |
| 911 | + "Step 1000: Loss on training set : 0.091805\n", |
| 912 | + "Step 1100: Loss on training set : 0.024188\n", |
| 913 | + "Step 1200: Loss on training set : 0.092624\n", |
| 914 | + "Step 1300: Loss on training set : 0.019488\n", |
| 915 | + "Step 1400: Loss on training set : 0.064160\n", |
| 916 | + "Step 1500: Loss on training set : 0.044069\n", |
| 917 | + "Step 1600: Loss on training set : 0.088605\n", |
| 918 | + "Step 1700: Loss on training set : 0.004956\n", |
| 919 | + "Step 1800: Loss on training set : 0.044108\n", |
| 920 | + "Step 1900: Loss on training set : 0.050574\n", |
| 921 | + "Step 2000: Loss on training set : 0.013534\n", |
| 922 | + "Step 2100: Loss on training set : 0.068764\n", |
| 923 | + "Step 2200: Loss on training set : 0.061247\n", |
| 924 | + "Step 2300: Loss on training set : 0.134102\n", |
| 925 | + "Step 2400: Loss on training set : 0.002189\n", |
| 926 | + "Step 2500: Loss on training set : 0.002621\n", |
| 927 | + "Step 2600: Loss on training set : 0.084751\n", |
| 928 | + "Step 2700: Loss on training set : 0.073403\n", |
| 929 | + "Step 2800: Loss on training set : 0.034124\n", |
| 930 | + "Step 2900: Loss on training set : 0.068016\n", |
| 931 | + "Step 3000: Loss on training set : 0.026844\n", |
| 932 | + "Step 3100: Loss on training set : 0.008452\n", |
| 933 | + "Step 3200: Loss on training set : 0.052670\n", |
| 934 | + "Step 3300: Loss on training set : 0.095155\n", |
| 935 | + "Step 3400: Loss on training set : 0.019506\n", |
| 936 | + "Step 3500: Loss on training set : 0.015484\n", |
| 937 | + "Step 3600: Loss on training set : 0.007086\n", |
| 938 | + "Step 3700: Loss on training set : 0.045831\n", |
| 939 | + "Step 3800: Loss on training set : 0.058367\n" |
910 | 940 | ]
|
911 | 941 | }
|
912 | 942 | ],
|
|
975 | 1005 | },
|
976 | 1006 | "outputs": [],
|
977 | 1007 | "source": [
|
978 |
| - "# Running with tf.nn.softmax_cross_entropy_with_logits_v2\n" |
| 1008 | + "help(tf.nn.softmax_cross_entropy_with_logits_v2)" |
| 1009 | + ] |
| 1010 | + }, |
| 1011 | + { |
| 1012 | + "cell_type": "code", |
| 1013 | + "execution_count": null, |
| 1014 | + "metadata": { |
| 1015 | + "collapsed": true |
| 1016 | + }, |
| 1017 | + "outputs": [], |
| 1018 | + "source": [ |
| 1019 | + "# Running with tf.nn.softmax_cross_entropy_with_logits_v2\n", |
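| | + "# (the _v2 op computes the same forward result as the original op; the\n", |
| | + "# difference is that it also lets gradients flow into `labels`, which\n", |
| | + "# does not matter here since the labels are constant inputs)\n", |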
| 1020 | + "class MNISTModel(object):\n", |
| 1021 | + " def __init__(self, data_format):\n", |
| 1022 | + " # 'channels_first' is typically faster on GPUs\n", |
| 1023 | + " # while 'channels_last' is typically faster on CPUs.\n", |
| 1024 | + " # See: https://www.tensorflow.org/performance/performance_guide#data_formats\n", |
| 1025 | + " if data_format == 'channels_first':\n", |
| 1026 | + " self._input_shape = [-1, 1, 28, 28]\n", |
| 1027 | + " else:\n", |
| 1028 | + " self._input_shape = [-1, 28, 28, 1]\n", |
| 1029 | + " self.conv1 = tf.layers.Conv2D(32, 5,\n", |
| 1030 | + " padding='same',\n", |
| 1031 | + " activation=tf.nn.relu,\n", |
| 1032 | + " data_format=data_format)\n", |
| 1033 | + " self.max_pool2d = tf.layers.MaxPooling2D(\n", |
| 1034 | + " (2, 2), (2, 2), padding='same', data_format=data_format)\n", |
| 1035 | + " self.conv2 = tf.layers.Conv2D(64, 5,\n", |
| 1036 | + " padding='same',\n", |
| 1037 | + " activation=tf.nn.relu,\n", |
| 1038 | + " data_format=data_format)\n", |
| 1039 | + " self.dense1 = tf.layers.Dense(1024, activation=tf.nn.relu)\n", |
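| | + "        # Note: tf.layers.Dropout only drops units when called with\n", |
| | + "        # training=True; predict() below never passes that flag, so\n", |
| | + "        # dropout is effectively a no-op here by default.\n", |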
| 1040 | + " self.dropout = tf.layers.Dropout(0.5)\n", |
| 1041 | + " self.dense2 = tf.layers.Dense(10)\n", |
| 1042 | + "\n", |
| 1043 | + " def predict(self, inputs):\n", |
| 1044 | + " x = tf.reshape(inputs, self._input_shape)\n", |
| 1045 | + " x = self.max_pool2d(self.conv1(x))\n", |
| 1046 | + " x = self.max_pool2d(self.conv2(x))\n", |
| 1047 | + " x = tf.layers.flatten(x)\n", |
| 1048 | + " x = self.dropout(self.dense1(x))\n", |
| 1049 | + " return self.dense2(x)\n", |
| 1050 | + "\n", |
| 1051 | + "def loss(model, inputs, targets):\n", |
| 1052 | + " return tf.reduce_mean(\n", |
| 1053 | + "        tf.nn.softmax_cross_entropy_with_logits_v2(\n", |
| 1054 | + " logits=model.predict(inputs), labels=targets))\n", |
| 1055 | + "\n", |
| 1056 | + "\n", |
| 1057 | + "# Load the training and validation data\n", |
| 1058 | + "from tensorflow.examples.tutorials.mnist import input_data\n", |
| 1059 | + "data = input_data.read_data_sets(\"./mnist_data\", one_hot=True)\n", |
| 1060 | + "\n", |
| 1061 | + "# Train\n", |
| 1062 | + "device = \"gpu:0\" if tfe.num_gpus() else \"cpu:0\"\n", |
| 1063 | + "model = MNISTModel('channels_first' if tfe.num_gpus() else 'channels_last')\n", |
| 1064 | + "optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)\n", |
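| | + "# implicit_gradients(loss) wraps loss() into a function with the same\n", |
| | + "# arguments that returns (gradient, variable) pairs for every variable\n", |
| | + "# the loss computation touches.\n", |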
| 1065 | + "grad = tfe.implicit_gradients(loss)\n", |
| 1066 | + "for i in range(20001):\n", |
| 1067 | + " with tf.device(device):\n", |
| 1068 | + " (inputs, targets) = data.train.next_batch(50)\n", |
| 1069 | + " optimizer.apply_gradients(grad(model, inputs, targets))\n", |
| 1070 | + " if i % 100 == 0:\n", |
| 1071 | + " print(\"Step %d: Loss on training set : %f\" %\n", |
| 1072 | + " (i, loss(model, inputs, targets).numpy()))\n", |
| 1073 | + "print(\"Loss on test set: %f\" % loss(model, data.test.images, data.test.labels).numpy())" |
| 1074 | + ] |
| 1075 | + }, |
| 1076 | + { |
| 1077 | + "cell_type": "markdown", |
| 1078 | + "metadata": {}, |
| 1079 | + "source": [ |
| 1080 | + "## Checkpointing trained variables\n", |
| 1081 | + "\n", |
| 1082 | + "In eager execution, `tfe.Variable` objects represent the shared, persistent state of your model. The `tfe.Saver` class, a thin wrapper over the `tf.train.Saver` class, provides the means to save and restore variables to and from checkpoints.\n", |
| 1083 | + "\n", |
| 1084 | + "As an example:" |
979 | 1085 | ]
|
980 | 1086 | }
|
981 | 1087 | ],
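The example cell promised by the markdown above falls outside this hunk. A minimal sketch of what saving and restoring with `tfe.Saver` can look like (the variable name `x` and the checkpoint prefix `/tmp/ckpt` are illustrative assumptions, not taken from the notebook):

    import tensorflow as tf
    import tensorflow.contrib.eager as tfe

    tfe.enable_eager_execution()

    # A single variable standing in for real model state.
    x = tfe.Variable(10.0, name='x')  # illustrative name, not from the notebook

    # tfe.Saver is constructed with the list of variables to checkpoint.
    saver = tfe.Saver([x])
    save_path = saver.save('/tmp/ckpt')  # illustrative path

    # Overwrite the value, then restore the checkpointed one.
    x.assign(2.0)
    saver.restore(save_path)
    print(x.numpy())  # -> 10.0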
|
|