|
848 | 848 | "**NOTE to self**: (Planned test on CPU/GPU for 25th Nov "
|
849 | 849 | ]
|
850 | 850 | },
|
| 851 | + { |
| 852 | + "cell_type": "code", |
| 853 | + "execution_count": 42, |
| 854 | + "metadata": { |
| 855 | + "collapsed": true |
| 856 | + }, |
| 857 | + "outputs": [], |
| 858 | + "source": [ |
| 859 | + "class Model(object):\n", |
| 860 | + " def __init__(self):\n", |
| 861 | + " self.layer = tf.layers.Dense(1)\n", |
| 862 | + " \n", |
| 863 | + " def predict(self, inputs):\n", |
| 864 | + " return self.layer(inputs)" |
| 865 | + ] |
| 866 | + }, |
851 | 867 | {
|
852 | 868 | "cell_type": "markdown",
|
853 | 869 | "metadata": {},
|
854 |
| - "source": [] |
| 870 | + "source": [ |
| 871 | + "#### Note: What the tf.layers API does\n", |
| 872 | + "\n", |
| 873 | + "Next, you'll see how the tf.layers API makes it easy to define sophisticated models. We're already used to the beauty of Keras, so let's see how it plays out in TensorFlow 1.4.x." |
| 874 | + ] |
| 875 | + }, |
| 876 | + { |
| 877 | + "cell_type": "code", |
| 878 | + "execution_count": null, |
| 879 | + "metadata": {}, |
| 880 | + "outputs": [ |
| 881 | + { |
| 882 | + "name": "stdout", |
| 883 | + "output_type": "stream", |
| 884 | + "text": [ |
| 885 | + "Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.\n", |
| 886 | + "Extracting ./mnist_data/train-images-idx3-ubyte.gz\n", |
| 887 | + "Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.\n", |
| 888 | + "Extracting ./mnist_data/train-labels-idx1-ubyte.gz\n", |
| 889 | + "Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.\n", |
| 890 | + "Extracting ./mnist_data/t10k-images-idx3-ubyte.gz\n", |
| 891 | + "Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.\n", |
| 892 | + "Extracting ./mnist_data/t10k-labels-idx1-ubyte.gz\n", |
| 893 | + "WARNING:tensorflow:From <ipython-input-43-9ff01676a66a>:35: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n", |
| 894 | + "Instructions for updating:\n", |
| 895 | + "\n", |
| 896 | + "Future major versions of TensorFlow will allow gradients to flow\n", |
| 897 | + "into the labels input on backprop by default.\n", |
| 898 | + "\n", |
| 899 | + "See tf.nn.softmax_cross_entropy_with_logits_v2.\n", |
| 900 | + "\n", |
| 901 | + "Step 0: Loss on training set : 2.256627\n", |
| 902 | + "Step 100: Loss on training set : 0.399705\n", |
| 903 | + "Step 200: Loss on training set : 0.214346\n", |
| 904 | + "Step 300: Loss on training set : 0.211841\n", |
| 905 | + "Step 400: Loss on training set : 0.119034\n", |
| 906 | + "Step 500: Loss on training set : 0.085323\n", |
| 907 | + "Step 600: Loss on training set : 0.231022\n", |
| 908 | + "Step 700: Loss on training set : 0.061058\n", |
| 909 | + "Step 800: Loss on training set : 0.196906\n" |
| 910 | + ] |
| 911 | + } |
| 912 | + ], |
| 913 | + "source": [ |
| 914 | + "class MNISTModel(object):\n", |
| 915 | + " def __init__(self, data_format):\n", |
| 916 | + " # 'channels_first' is typically faster on GPUs\n", |
| 917 | + " # while 'channels_last' is typically faster on CPUs.\n", |
| 918 | + " # See: https://www.tensorflow.org/performance/performance_guide#data_formats\n", |
| 919 | + " if data_format == 'channels_first':\n", |
| 920 | + " self._input_shape = [-1, 1, 28, 28]\n", |
| 921 | + " else:\n", |
| 922 | + " self._input_shape = [-1, 28, 28, 1]\n", |
| 923 | + " self.conv1 = tf.layers.Conv2D(32, 5,\n", |
| 924 | + " padding='same',\n", |
| 925 | + " activation=tf.nn.relu,\n", |
| 926 | + " data_format=data_format)\n", |
| 927 | + " self.max_pool2d = tf.layers.MaxPooling2D(\n", |
| 928 | + " (2, 2), (2, 2), padding='same', data_format=data_format)\n", |
| 929 | + " self.conv2 = tf.layers.Conv2D(64, 5,\n", |
| 930 | + " padding='same',\n", |
| 931 | + " activation=tf.nn.relu,\n", |
| 932 | + " data_format=data_format)\n", |
| 933 | + " self.dense1 = tf.layers.Dense(1024, activation=tf.nn.relu)\n", |
| 934 | + " self.dropout = tf.layers.Dropout(0.5)\n", |
| 935 | + " self.dense2 = tf.layers.Dense(10)\n", |
| 936 | + "\n", |
| 937 | + " def predict(self, inputs):\n", |
| 938 | + " x = tf.reshape(inputs, self._input_shape)\n", |
| 939 | + " x = self.max_pool2d(self.conv1(x))\n", |
| 940 | + " x = self.max_pool2d(self.conv2(x))\n", |
| 941 | + " x = tf.layers.flatten(x)\n", |
| 942 | + " x = self.dropout(self.dense1(x))\n", |
| 943 | + " return self.dense2(x)\n", |
| 944 | + "\n", |
| 945 | + "def loss(model, inputs, targets):\n", |
| 946 | + " return tf.reduce_mean(\n", |
| 947 | + " tf.nn.softmax_cross_entropy_with_logits(\n", |
| 948 | + " logits=model.predict(inputs), labels=targets))\n", |
| 949 | + "\n", |
| 950 | + "\n", |
| 951 | + "# Load the training and validation data\n", |
| 952 | + "from tensorflow.examples.tutorials.mnist import input_data\n", |
| 953 | + "data = input_data.read_data_sets(\"./mnist_data\", one_hot=True)\n", |
| 954 | + "\n", |
| 955 | + "# Train\n", |
| 956 | + "device = \"gpu:0\" if tfe.num_gpus() else \"cpu:0\"\n", |
| 957 | + "model = MNISTModel('channels_first' if tfe.num_gpus() else 'channels_last')\n", |
| 958 | + "optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)\n", |
| 959 | + "grad = tfe.implicit_gradients(loss)\n", |
| 960 | + "for i in range(20001):\n", |
| 961 | + " with tf.device(device):\n", |
| 962 | + " (inputs, targets) = data.train.next_batch(50)\n", |
| 963 | + " optimizer.apply_gradients(grad(model, inputs, targets))\n", |
| 964 | + " if i % 100 == 0:\n", |
| 965 | + " print(\"Step %d: Loss on training set : %f\" %\n", |
| 966 | + " (i, loss(model, inputs, targets).numpy()))\n", |
| 967 | + "print(\"Loss on test set: %f\" % loss(model, data.test.images, data.test.labels).numpy())" |
| 968 | + ] |
855 | 969 | },
|
856 | 970 | {
|
857 | 971 | "cell_type": "code",
|
|
860 | 974 | "collapsed": true
|
861 | 975 | },
|
862 | 976 | "outputs": [],
|
863 |
| - "source": [] |
| 977 | + "source": [ |
| 978 | + "# Running with tf.nn.softmax_cross_entropy_with_logits_v2\n" |
| 979 | + ] |
864 | 980 | }
|
865 | 981 | ],
|
866 | 982 | "metadata": {
|
|
0 commit comments