From a01f788aaf33bf770e418132c8017955a9f7831b Mon Sep 17 00:00:00 2001 From: Mat Leonard Date: Sun, 7 Jan 2018 09:00:15 -0800 Subject: [PATCH 1/2] Convert network and training to PyTorch --- intro-to-rnns/Anna_KaRNNa_Solution.ipynb | 500 +++++------------------ 1 file changed, 113 insertions(+), 387 deletions(-) diff --git a/intro-to-rnns/Anna_KaRNNa_Solution.ipynb b/intro-to-rnns/Anna_KaRNNa_Solution.ipynb index 9a09d2cdd9..7f5f0ff4e8 100644 --- a/intro-to-rnns/Anna_KaRNNa_Solution.ipynb +++ b/intro-to-rnns/Anna_KaRNNa_Solution.ipynb @@ -20,10 +20,8 @@ "outputs": [], "source": [ "import time\n", - "from collections import namedtuple\n", "\n", - "import numpy as np\n", - "import tensorflow as tf" + "import numpy as np" ] }, { @@ -119,6 +117,26 @@ "The way I like to do this window is use `range` to take steps of size `n_steps` from $0$ to `arr.shape[1]`, the total number of steps in each sequence. That way, the integers you get from `range` always point to the start of a batch, and each window is `n_steps` wide." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def one_hot_encode(arr, n_labels):\n", + " \n", + " # Initialize the the encoded array\n", + " one_hot = np.zeros((np.multiply(*arr.shape), n_labels), dtype=np.float32)\n", + " \n", + " # Fill the appropriate elements with ones\n", + " one_hot[np.arange(one_hot.shape[0]), arr.flatten()] = 1.\n", + " \n", + " # Finally reshape it to get back to the original array\n", + " one_hot = one_hot.reshape((*arr.shape, n_labels))\n", + " \n", + " return one_hot" + ] + }, { "cell_type": "code", "execution_count": null, @@ -229,12 +247,7 @@ "\n", "Below is where you'll build the network. We'll break it up into parts so it's easier to reason about each bit. Then we can connect them up into the whole network.\n", "\n", - "\n", - "\n", - "\n", - "### Inputs\n", - "\n", - "First off we'll create our input placeholders. As usual we need placeholders for the training data and the targets. We'll also create a placeholder for dropout layers called `keep_prob`." + "\n" ] }, { @@ -243,71 +256,10 @@ "metadata": {}, "outputs": [], "source": [ - "def build_inputs(batch_size, num_steps):\n", - " ''' Define placeholders for inputs, targets, and dropout \n", - " \n", - " Arguments\n", - " ---------\n", - " batch_size: Batch size, number of sequences per batch\n", - " num_steps: Number of sequence steps in a batch\n", - " \n", - " '''\n", - " # Declare placeholders we'll feed into the graph\n", - " inputs = tf.placeholder(tf.int32, [batch_size, num_steps], name='inputs')\n", - " targets = tf.placeholder(tf.int32, [batch_size, num_steps], name='targets')\n", - " \n", - " # Keep probability placeholder for drop out layers\n", - " keep_prob = tf.placeholder(tf.float32, name='keep_prob')\n", - " \n", - " return inputs, targets, keep_prob" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### LSTM Cell\n", - "\n", - "Here we will create the LSTM cell we'll use in the hidden layer. We'll use this cell as a building block for the RNN. So we aren't actually defining the RNN here, just the type of cell we'll use in the hidden layer.\n", - "\n", - "We first create a basic LSTM cell with\n", - "\n", - "```python\n", - "lstm = tf.contrib.rnn.BasicLSTMCell(num_units)\n", - "```\n", - "\n", - "where `num_units` is the number of units in the hidden layers in the cell. 
Then we can add dropout by wrapping it with \n", - "\n", - "```python\n", - "tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=keep_prob)\n", - "```\n", - "You pass in a cell and it will automatically add dropout to the inputs or outputs. Finally, we can stack up the LSTM cells into layers with [`tf.contrib.rnn.MultiRNNCell`](https://www.tensorflow.org/versions/r1.0/api_docs/python/tf/contrib/rnn/MultiRNNCell). With this, you pass in a list of cells and it will send the output of one cell into the next cell. Previously with TensorFlow 1.0, you could do this\n", - "\n", - "```python\n", - "tf.contrib.rnn.MultiRNNCell([cell]*num_layers)\n", - "```\n", - "\n", - "This might look a little weird if you know Python well because this will create a list of the same `cell` object. However, TensorFlow 1.0 will create different weight matrices for all `cell` objects. But, starting with TensorFlow 1.1 you actually need to create new cell objects in the list. To get it to work in TensorFlow 1.1, it should look like\n", - "\n", - "```python\n", - "def build_cell(num_units, keep_prob):\n", - " lstm = tf.contrib.rnn.BasicLSTMCell(num_units)\n", - " drop = tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=keep_prob)\n", - " \n", - " return drop\n", - " \n", - "tf.contrib.rnn.MultiRNNCell([build_cell(num_units, keep_prob) for _ in range(num_layers)])\n", - "```\n", - "\n", - "Even though this is actually multiple LSTM cells stacked on each other, you can treat the multiple layers as one cell.\n", - "\n", - "We also need to create an initial cell state of all zeros. This can be done like so\n", - "\n", - "```python\n", - "initial_state = cell.zero_state(batch_size, tf.float32)\n", - "```\n", - "\n", - "Below, we implement the `build_lstm` function to create these LSTM cells and the initial state." + "import torch\n", + "from torch import nn, optim\n", + "from torch.autograd import Variable\n", + "import torch.nn.functional as F" ] }, { @@ -316,167 +268,40 @@ "metadata": {}, "outputs": [], "source": [ - "def build_lstm(lstm_size, num_layers, batch_size, keep_prob):\n", - " ''' Build LSTM cell.\n", + "class CharRNN(nn.Module):\n", " \n", - " Arguments\n", - " ---------\n", - " keep_prob: Scalar tensor (tf.placeholder) for the dropout keep probability\n", - " lstm_size: Size of the hidden layers in the LSTM cells\n", - " num_layers: Number of LSTM layers\n", - " batch_size: Batch size\n", - "\n", - " '''\n", - " ### Build the LSTM Cell\n", - " \n", - " def build_cell(lstm_size, keep_prob):\n", - " # Use a basic LSTM cell\n", - " lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)\n", + " def __init__(self, n_tokens, n_steps=50, n_layers=2, \n", + " n_hidden=256, drop_prob=0.5):\n", " \n", - " # Add dropout to the cell\n", - " drop = tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=keep_prob)\n", - " return drop\n", - " \n", - " \n", - " # Stack up multiple LSTM layers, for deep learning\n", - " cell = tf.contrib.rnn.MultiRNNCell([build_cell(lstm_size, keep_prob) for _ in range(num_layers)])\n", - " initial_state = cell.zero_state(batch_size, tf.float32)\n", - " \n", - " return cell, initial_state" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### RNN Output\n", - "\n", - "Here we'll create the output layer. We need to connect the output of the RNN cells to a full connected layer with a softmax output. 
The softmax output gives us a probability distribution we can use to predict the next character.\n", - "\n", - "If our input has batch size $N$, number of steps $M$, and the hidden layer has $L$ hidden units, then the output is a 3D tensor with size $N \\times M \\times L$. The output of each LSTM cell has size $L$, we have $M$ of them, one for each sequence step, and we have $N$ sequences. So the total size is $N \\times M \\times L$.\n", - "\n", - "We are using the same fully connected layer, the same weights, for each of the outputs. Then, to make things easier, we should reshape the outputs into a 2D tensor with shape $(M * N) \\times L$. That is, one row for each sequence and step, where the values of each row are the output from the LSTM cells.\n", - "\n", - "One we have the outputs reshaped, we can do the matrix multiplication with the weights. We need to wrap the weight and bias variables in a variable scope with `tf.variable_scope(scope_name)` because there are weights being created in the LSTM cells. TensorFlow will throw an error if the weights created here have the same names as the weights created in the LSTM cells, which they will be default. To avoid this, we wrap the variables in a variable scope so we can give them unique names." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def build_output(lstm_output, in_size, out_size):\n", - " ''' Build a softmax layer, return the softmax output and logits.\n", - " \n", - " Arguments\n", - " ---------\n", + " super().__init__()\n", " \n", - " x: Input tensor\n", - " in_size: Size of the input tensor, for example, size of the LSTM cells\n", - " out_size: Size of this softmax layer\n", - " \n", - " '''\n", - "\n", - " # Reshape output so it's a bunch of rows, one row for each step for each sequence.\n", - " # That is, the shape should be batch_size*num_steps rows by lstm_size columns\n", - " seq_output = tf.concat(lstm_output, axis=1)\n", - " x = tf.reshape(seq_output, [-1, in_size])\n", - " \n", - " # Connect the RNN outputs to a softmax layer\n", - " with tf.variable_scope('softmax'):\n", - " softmax_w = tf.Variable(tf.truncated_normal((in_size, out_size), stddev=0.1))\n", - " softmax_b = tf.Variable(tf.zeros(out_size))\n", - " \n", - " # Since output is a bunch of rows of RNN cell outputs, logits will be a bunch\n", - " # of rows of logit outputs, one for each step and sequence\n", - " logits = tf.matmul(x, softmax_w) + softmax_b\n", - " \n", - " # Use softmax to get the probabilities for predicted characters\n", - " out = tf.nn.softmax(logits, name='predictions')\n", - " \n", - " return out, logits" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Training loss\n", - "\n", - "Next up is the training loss. We get the logits and targets and calculate the softmax cross-entropy loss. First we need to one-hot encode the targets, we're getting them as encoded characters. Then, reshape the one-hot targets so it's a 2D tensor with size $(M*N) \\times C$ where $C$ is the number of classes/characters we have. Remember that we reshaped the LSTM outputs and ran them through a fully connected layer with $C$ units. So our logits will also have size $(M*N) \\times C$.\n", - "\n", - "Then we run the logits and targets through `tf.nn.softmax_cross_entropy_with_logits` and find the mean to get the loss." 
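On the PyTorch side of this conversion the loss step is simpler: `nn.CrossEntropyLoss` takes the raw logits with shape $(M*N) \times C$ together with the integer-encoded targets flattened to length $M*N$, so the targets never need to be one-hot encoded (the second commit switches to the equivalent `log_softmax` + `nn.NLLLoss` pairing). A minimal sketch with made-up shapes, assuming a recent PyTorch release where plain tensors (no `Variable` wrapping) are fine:

```python
import torch
from torch import nn

# Made-up sizes, just to show the shapes CrossEntropyLoss expects
batch_size, n_steps, n_classes = 4, 5, 83

# Logits straight from the fully connected layer: one row per character position
logits = torch.randn(batch_size * n_steps, n_classes)
# Integer-encoded targets, flattened the same way (no one-hot needed)
targets = torch.randint(0, n_classes, (batch_size, n_steps)).view(batch_size * n_steps)

criterion = nn.CrossEntropyLoss()   # log-softmax + negative log likelihood in one step
loss = criterion(logits, targets)
print(loss.item())
```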
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def build_loss(logits, targets, lstm_size, num_classes):\n", - " ''' Calculate the loss from the logits and the targets.\n", - " \n", - " Arguments\n", - " ---------\n", - " logits: Logits from final fully connected layer\n", - " targets: Targets for supervised learning\n", - " lstm_size: Number of LSTM hidden units\n", - " num_classes: Number of classes in targets\n", + " # Store parameters\n", + " self.chars = n_tokens\n", + " self.drop_prob = drop_prob\n", + " self.n_layers = n_layers\n", + " self.n_hidden = n_hidden\n", " \n", - " '''\n", - " \n", - " # One-hot encode targets and reshape to match logits, one row per batch_size per step\n", - " y_one_hot = tf.one_hot(targets, num_classes)\n", - " y_reshaped = tf.reshape(y_one_hot, logits.get_shape())\n", - " \n", - " # Softmax cross entropy loss\n", - " loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_reshaped)\n", - " loss = tf.reduce_mean(loss)\n", - " return loss" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Optimizer\n", - "\n", - "Here we build the optimizer. Normal RNNs have have issues gradients exploding and disappearing. LSTMs fix the disappearance problem, but the gradients can still grow without bound. To fix this, we can clip the gradients above some threshold. That is, if a gradient is larger than that threshold, we set it to the threshold. This will ensure the gradients never grow overly large. Then we use an AdamOptimizer for the learning step." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def build_optimizer(loss, learning_rate, grad_clip):\n", - " ''' Build optmizer for training, using gradient clipping.\n", - " \n", - " Arguments:\n", - " loss: Network loss\n", - " learning_rate: Learning rate for optimizer\n", - " \n", - " '''\n", - " \n", - " # Optimizer for training, using gradient clipping to control exploding gradients\n", - " tvars = tf.trainable_variables()\n", - " grads, _ = tf.clip_by_global_norm(tf.gradients(loss, tvars), grad_clip)\n", - " train_op = tf.train.AdamOptimizer(learning_rate)\n", - " optimizer = train_op.apply_gradients(zip(grads, tvars))\n", - " \n", - " return optimizer" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Build the network\n", - "\n", - "Now we can put all the pieces together and build a class for the network. To actually run data through the LSTM cells, we will use [`tf.nn.dynamic_rnn`](https://www.tensorflow.org/versions/r1.0/api_docs/python/tf/nn/dynamic_rnn). This function will pass the hidden and cell states across LSTM cells appropriately for us. It returns the outputs for each LSTM cell at each step for each sequence in the mini-batch. It also gives us the final LSTM state. We want to save this state as `final_state` so we can pass it to the first LSTM cell in the the next mini-batch run. For `tf.nn.dynamic_rnn`, we pass in the cell and initial state we get from `build_lstm`, as well as our input sequences. Also, we need to one-hot encode the inputs before going into the RNN. 
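In the PyTorch version that replaces this, `nn.LSTM` plays the role of `tf.nn.dynamic_rnn`: you pass it the one-hot encoded batch plus a `(hidden, cell)` state tuple and it returns the outputs for every step along with the final state, which can then be carried into the next mini-batch. A rough sketch of that flow, with made-up sizes and assuming a recent PyTorch where plain tensors (no `Variable`) work; a real run would feed the output of the notebook's `one_hot_encode` instead of the placeholder batch:

```python
import torch
from torch import nn

# Made-up sizes for illustration
n_tokens, n_hidden, n_layers, batch_size, n_steps = 83, 256, 2, 4, 5

lstm = nn.LSTM(n_tokens, n_hidden, n_layers, batch_first=True)

# Zero initial hidden and cell states: (n_layers, batch_size, n_hidden)
h = torch.zeros(n_layers, batch_size, n_hidden)
c = torch.zeros(n_layers, batch_size, n_hidden)

# Stand-in for a one-hot encoded character batch: (batch_size, n_steps, n_tokens)
x = torch.zeros(batch_size, n_steps, n_tokens)
x[:, :, 0] = 1.

# Outputs for every step plus the final state, like tf.nn.dynamic_rnn
out, (h, c) = lstm(x, (h, c))       # out: (batch_size, n_steps, n_hidden)

# Detach the state between mini-batches so backprop stops at the batch boundary
h, c = h.detach(), c.detach()
```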
" + " # Define layers\n", + " self.dropout = nn.Dropout(drop_prob)\n", + " self.lstm = nn.LSTM(self.chars, n_hidden, n_layers, \n", + " dropout=drop_prob, batch_first=True)\n", + " self.fc = nn.Linear(n_hidden, self.chars)\n", + " \n", + " def forward(self, x, hc):\n", + " \n", + " # x = input, h = hidden state, c = cell state\n", + " x, (h, c) = self.lstm(x, hc)\n", + " \n", + " x = self.dropout(x)\n", + " \n", + " # Stack up LSTM outputs \n", + " batch_size = x.size()[0]\n", + " n_steps = x.size()[1]\n", + " x = x.view(batch_size * n_steps, self.n_hidden)\n", + " \n", + " x = self.fc(x)\n", + " \n", + " return x, (h, c)" ] }, { @@ -485,41 +310,13 @@ "metadata": {}, "outputs": [], "source": [ - "class CharRNN:\n", - " \n", - " def __init__(self, num_classes, batch_size=64, num_steps=50, \n", - " lstm_size=128, num_layers=2, learning_rate=0.001, \n", - " grad_clip=5, sampling=False):\n", - " \n", - " # When we're using this network for sampling later, we'll be passing in\n", - " # one character at a time, so providing an option for that\n", - " if sampling == True:\n", - " batch_size, num_steps = 1, 1\n", - " else:\n", - " batch_size, num_steps = batch_size, num_steps\n", - "\n", - " tf.reset_default_graph()\n", - " \n", - " # Build the input placeholder tensors\n", - " self.inputs, self.targets, self.keep_prob = build_inputs(batch_size, num_steps)\n", - "\n", - " # Build the LSTM cell\n", - " cell, self.initial_state = build_lstm(lstm_size, num_layers, batch_size, self.keep_prob)\n", - "\n", - " ### Run the data through the RNN layers\n", - " # First, one-hot encode the input tokens\n", - " x_one_hot = tf.one_hot(self.inputs, num_classes)\n", - " \n", - " # Run each sequence step through the RNN and collect the outputs\n", - " outputs, state = tf.nn.dynamic_rnn(cell, x_one_hot, initial_state=self.initial_state)\n", - " self.final_state = state\n", - " \n", - " # Get softmax predictions and logits\n", - " self.prediction, self.logits = build_output(outputs, lstm_size, num_classes)\n", - " \n", - " # Loss and optimizer (with gradient clipping)\n", - " self.loss = build_loss(self.logits, self.targets, lstm_size, num_classes)\n", - " self.optimizer = build_optimizer(self.loss, learning_rate, grad_clip)" + "def init_hidden(net, batch_size):\n", + " ''' Initializes hidden state '''\n", + " # Create two new tensors with sizes n_layers x batch_size x n_hidden,\n", + " # initialized to zero, for hidden state and cell state of LSTM\n", + " weight = next(net.parameters()).data\n", + " return (Variable(weight.new(net.n_layers, batch_size, net.n_hidden).zero_()),\n", + " Variable(weight.new(net.n_layers, batch_size, net.n_hidden).zero_()))" ] }, { @@ -575,24 +372,19 @@ "outputs": [], "source": [ "batch_size = 100 # Sequences per batch\n", - "num_steps = 100 # Number of sequence steps per batch\n", - "lstm_size = 512 # Size of hidden layers in LSTMs\n", - "num_layers = 2 # Number of LSTM layers\n", + "n_steps = 100 # Number of sequence steps per batch\n", + "lstm_size = 256 # Size of hidden layers in LSTMs\n", + "n_layers = 2 # Number of LSTM layers\n", "learning_rate = 0.001 # Learning rate\n", - "keep_prob = 0.5 # Dropout keep probability" + "drop_prob = 0.5 # Dropout drop probability\n", + "clip = 5 # Gradien clipping" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Time for training\n", - "\n", - "This is typical training code, passing inputs and targets into the network, then running the optimizer. Here we also get back the final LSTM state for the mini-batch. 
Then, we pass that state back into the network so the next batch can continue the state from the previous batch. And every so often (set by `save_every_n`) I save a checkpoint.\n", - "\n", - "Here I'm saving checkpoints with the format\n", - "\n", - "`i{iteration number}_l{# hidden layer units}.ckpt`" + "## Time for training" ] }, { @@ -603,68 +395,55 @@ }, "outputs": [], "source": [ - "epochs = 20\n", - "# Print losses every N interations\n", - "print_every_n = 50\n", + "epochs = 1\n", + "print_every = 10\n", + "cuda = False\n", "\n", - "# Save every N iterations\n", - "save_every_n = 200\n", + "net = CharRNN(len(vocab), n_steps=n_steps, n_layers=n_layers, \n", + " n_hidden=lstm_size, drop_prob=drop_prob)\n", "\n", - "model = CharRNN(len(vocab), batch_size=batch_size, num_steps=num_steps,\n", - " lstm_size=lstm_size, num_layers=num_layers, \n", - " learning_rate=learning_rate)\n", + "opt = optim.Adam(net.parameters(), lr=learning_rate)\n", + "criterion = nn.CrossEntropyLoss()\n", "\n", - "saver = tf.train.Saver(max_to_keep=100)\n", - "with tf.Session() as sess:\n", - " sess.run(tf.global_variables_initializer())\n", - " \n", - " # Use the line below to load a checkpoint and resume training\n", - " #saver.restore(sess, 'checkpoints/______.ckpt')\n", - " counter = 0\n", - " for e in range(epochs):\n", - " # Train network\n", - " new_state = sess.run(model.initial_state)\n", - " loss = 0\n", - " for x, y in get_batches(encoded, batch_size, num_steps):\n", - " counter += 1\n", - " start = time.time()\n", - " feed = {model.inputs: x,\n", - " model.targets: y,\n", - " model.keep_prob: keep_prob,\n", - " model.initial_state: new_state}\n", - " batch_loss, new_state, _ = sess.run([model.loss, \n", - " model.final_state, \n", - " model.optimizer], \n", - " feed_dict=feed)\n", - " if (counter % print_every_n == 0):\n", - " end = time.time()\n", - " print('Epoch: {}/{}... '.format(e+1, epochs),\n", - " 'Training Step: {}... '.format(counter),\n", - " 'Training loss: {:.4f}... 
'.format(batch_loss),\n", - " '{:.4f} sec/batch'.format((end-start)))\n", + "if cuda:\n", + " net.cuda()\n", + "\n", + "counter = 0\n", + "n_chars = len(vocab)\n", + "for e in range(epochs):\n", + " # init hc a tuple of (hidden, cell) states\n", + " hc = init_hidden(net, batch_size)\n", + " for x, y in get_batches(encoded, batch_size, n_steps):\n", + " counter += 1\n", + "\n", + " # One-hot encode our data and make them Torch tensors\n", + " x = one_hot_encode(x, n_chars)\n", + " x, y = torch.from_numpy(x), torch.from_numpy(y)\n", + "\n", + " inputs, targets = Variable(x), Variable(y.long())\n", + " if cuda:\n", + " inputs, targets = inputs.cuda(), targets.cuda()\n", + "\n", + " # Creating new variables for the hidden/cell state, otherwise\n", + " # we'd backprop through the entire training history\n", + " hc = tuple([Variable(each.data) for each in hc])\n", + "\n", + " net.zero_grad()\n", + "\n", + " output, hc = net.forward(inputs, hc)\n", + " loss = criterion(output, targets.view(batch_size * n_steps))\n", " \n", - " if (counter % save_every_n == 0):\n", - " saver.save(sess, \"checkpoints/i{}_l{}.ckpt\".format(counter, lstm_size))\n", - " \n", - " saver.save(sess, \"checkpoints/i{}_l{}.ckpt\".format(counter, lstm_size))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Saved checkpoints\n", + " loss.backward()\n", "\n", - "Read up on saving and loading checkpoints here: https://www.tensorflow.org/programmers_guide/variables" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "tf.train.get_checkpoint_state('checkpoints')" + " # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.\n", + " nn.utils.clip_grad_norm(net.parameters(), clip)\n", + "\n", + " opt.step()\n", + "\n", + " if counter % print_every == 0:\n", + " print(\"Epoch: {}/{}...\".format(e+1, epochs),\n", + " \"Step: {}...\".format(counter),\n", + " \"Loss: {:.4f}...\".format(loss.data[0]))" ] }, { @@ -738,59 +517,6 @@ "source": [ "Here, pass in the path to a checkpoint and sample from the network." 
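The PyTorch version of this checkpoint round-trip isn't shown explicitly: the second commit saves the weights with `torch.save(net.state_dict(), 'anna_rnn.net')` but never reloads them before sampling. A possible loading-and-sampling sketch, assuming the notebook's `CharRNN`, `sample`, `vocab`, and hyperparameters are already defined in the session:

```python
import torch

# Rebuild the network with the same hyperparameters used for training,
# then load the state dict saved earlier as 'anna_rnn.net'
net = CharRNN(len(vocab), n_steps=n_steps, n_layers=n_layers,
              n_hidden=lstm_size, drop_prob=drop_prob)
net.load_state_dict(torch.load('anna_rnn.net'))

# sample returns a list of characters, so join it into a string
print(''.join(sample(net, 1000, prime="Far", top_k=5)))
```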
] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "tf.train.latest_checkpoint('checkpoints')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "checkpoint = tf.train.latest_checkpoint('checkpoints')\n", - "samp = sample(checkpoint, 2000, lstm_size, len(vocab), prime=\"Far\")\n", - "print(samp)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "checkpoint = 'checkpoints/i200_l512.ckpt'\n", - "samp = sample(checkpoint, 1000, lstm_size, len(vocab), prime=\"Far\")\n", - "print(samp)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "checkpoint = 'checkpoints/i600_l512.ckpt'\n", - "samp = sample(checkpoint, 1000, lstm_size, len(vocab), prime=\"Far\")\n", - "print(samp)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "checkpoint = 'checkpoints/i1200_l512.ckpt'\n", - "samp = sample(checkpoint, 1000, lstm_size, len(vocab), prime=\"Far\")\n", - "print(samp)" - ] } ], "metadata": { From eb22abd2945cf9759167851b989020d3a7ed5baa Mon Sep 17 00:00:00 2001 From: Mat Leonard Date: Sun, 7 Jan 2018 22:18:42 -0800 Subject: [PATCH 2/2] Convert code for sampling text --- intro-to-rnns/Anna_KaRNNa_Solution.ipynb | 708 ++++++++++++++++++++--- 1 file changed, 635 insertions(+), 73 deletions(-) diff --git a/intro-to-rnns/Anna_KaRNNa_Solution.ipynb b/intro-to-rnns/Anna_KaRNNa_Solution.ipynb index 7f5f0ff4e8..3a7b52dd36 100644 --- a/intro-to-rnns/Anna_KaRNNa_Solution.ipynb +++ b/intro-to-rnns/Anna_KaRNNa_Solution.ipynb @@ -15,7 +15,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -33,7 +33,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -54,9 +54,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "'Chapter 1\\n\\n\\nHappy families are all alike; every unhappy family is unhappy in its own\\nway.\\n\\nEverythin'" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "text[:100]" ] @@ -70,9 +81,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "array([31, 64, 57, 72, 76, 61, 74, 1, 16, 0, 0, 0, 36, 57, 72, 72, 81,\n", + " 1, 62, 57, 69, 65, 68, 65, 61, 75, 1, 57, 74, 61, 1, 57, 68, 68,\n", + " 1, 57, 68, 65, 67, 61, 26, 1, 61, 78, 61, 74, 81, 1, 77, 70, 64,\n", + " 57, 72, 72, 81, 1, 62, 57, 69, 65, 68, 81, 1, 65, 75, 1, 77, 70,\n", + " 64, 57, 72, 72, 81, 1, 65, 70, 1, 65, 76, 75, 1, 71, 79, 70, 0,\n", + " 79, 57, 81, 13, 0, 0, 33, 78, 61, 74, 81, 76, 64, 65, 70], dtype=int32)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "encoded[:100]" ] @@ -86,9 +113,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "83" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "len(vocab)" ] @@ -119,7 +157,7 @@ }, { "cell_type": "code", - 
"execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -139,7 +177,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -188,7 +226,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -198,9 +236,39 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "x\n", + " [[31 64 57 72 76 61 74 1 16 0]\n", + " [ 1 57 69 1 70 71 76 1 63 71]\n", + " [78 65 70 13 0 0 3 53 61 75]\n", + " [70 1 60 77 74 65 70 63 1 64]\n", + " [ 1 65 76 1 65 75 11 1 75 65]\n", + " [ 1 37 76 1 79 57 75 0 71 70]\n", + " [64 61 70 1 59 71 69 61 1 62]\n", + " [26 1 58 77 76 1 70 71 79 1]\n", + " [76 1 65 75 70 7 76 13 1 48]\n", + " [ 1 75 57 65 60 1 76 71 1 64]]\n", + "\n", + "y\n", + " [[64 57 72 76 61 74 1 16 0 0]\n", + " [57 69 1 70 71 76 1 63 71 65]\n", + " [65 70 13 0 0 3 53 61 75 11]\n", + " [ 1 60 77 74 65 70 63 1 64 65]\n", + " [65 76 1 65 75 11 1 75 65 74]\n", + " [37 76 1 79 57 75 0 71 70 68]\n", + " [61 70 1 59 71 69 61 1 62 71]\n", + " [ 1 58 77 76 1 70 71 79 1 75]\n", + " [ 1 65 75 70 7 76 13 1 48 64]\n", + " [75 57 65 60 1 76 71 1 64 61]]\n" + ] + } + ], "source": [ "print('x\\n', x[:10, :10])\n", "print('\\ny\\n', y[:10, :10])" @@ -252,7 +320,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -264,7 +332,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -276,16 +344,16 @@ " super().__init__()\n", " \n", " # Store parameters\n", - " self.chars = n_tokens\n", + " self.n_tokens = n_tokens\n", " self.drop_prob = drop_prob\n", " self.n_layers = n_layers\n", " self.n_hidden = n_hidden\n", " \n", " # Define layers\n", " self.dropout = nn.Dropout(drop_prob)\n", - " self.lstm = nn.LSTM(self.chars, n_hidden, n_layers, \n", + " self.lstm = nn.LSTM(self.n_tokens, n_hidden, n_layers, \n", " dropout=drop_prob, batch_first=True)\n", - " self.fc = nn.Linear(n_hidden, self.chars)\n", + " self.fc = nn.Linear(n_hidden, self.n_tokens)\n", " \n", " def forward(self, x, hc):\n", " \n", @@ -299,14 +367,14 @@ " n_steps = x.size()[1]\n", " x = x.view(batch_size * n_steps, self.n_hidden)\n", " \n", - " x = self.fc(x)\n", + " x = F.log_softmax(self.fc(x), dim=1)\n", " \n", " return x, (h, c)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -367,17 +435,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 167, "metadata": {}, "outputs": [], "source": [ "batch_size = 100 # Sequences per batch\n", "n_steps = 100 # Number of sequence steps per batch\n", - "lstm_size = 256 # Size of hidden layers in LSTMs\n", + "lstm_size = 512 # Size of hidden layers in LSTMs\n", "n_layers = 2 # Number of LSTM layers\n", - "learning_rate = 0.001 # Learning rate\n", - "drop_prob = 0.5 # Dropout drop probability\n", - "clip = 5 # Gradien clipping" + "learning_rate = 0.005 # Learning rate\n", + "drop_prob = 0.2 # Dropout drop probability\n", + "clip = 5 # Gradient clipping" ] }, { @@ -389,21 +457,434 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": true - }, - "outputs": [], + "execution_count": 168, + "metadata": {}, + "outputs": [ + { + 
"name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 1/20... Step: 10... Loss: 3.1692...\n", + "Epoch: 1/20... Step: 20... Loss: 3.1058...\n", + "Epoch: 1/20... Step: 30... Loss: 3.1249...\n", + "Epoch: 1/20... Step: 40... Loss: 3.0916...\n", + "Epoch: 1/20... Step: 50... Loss: 3.0507...\n", + "Epoch: 1/20... Step: 60... Loss: 2.9645...\n", + "Epoch: 1/20... Step: 70... Loss: 2.8561...\n", + "Epoch: 1/20... Step: 80... Loss: 2.7512...\n", + "Epoch: 1/20... Step: 90... Loss: 2.6269...\n", + "Epoch: 1/20... Step: 100... Loss: 2.5104...\n", + "Epoch: 1/20... Step: 110... Loss: 2.4128...\n", + "Epoch: 1/20... Step: 120... Loss: 2.3846...\n", + "Epoch: 1/20... Step: 130... Loss: 2.3226...\n", + "Epoch: 1/20... Step: 140... Loss: 2.2583...\n", + "Epoch: 1/20... Step: 150... Loss: 2.2081...\n", + "Epoch: 1/20... Step: 160... Loss: 2.1868...\n", + "Epoch: 1/20... Step: 170... Loss: 2.1093...\n", + "Epoch: 1/20... Step: 180... Loss: 2.0707...\n", + "Epoch: 1/20... Step: 190... Loss: 2.0576...\n", + "Epoch: 2/20... Step: 200... Loss: 2.2204...\n", + "Epoch: 2/20... Step: 210... Loss: 2.0383...\n", + "Epoch: 2/20... Step: 220... Loss: 2.0298...\n", + "Epoch: 2/20... Step: 230... Loss: 1.9313...\n", + "Epoch: 2/20... Step: 240... Loss: 1.9093...\n", + "Epoch: 2/20... Step: 250... Loss: 1.9194...\n", + "Epoch: 2/20... Step: 260... Loss: 1.8927...\n", + "Epoch: 2/20... Step: 270... Loss: 1.8406...\n", + "Epoch: 2/20... Step: 280... Loss: 1.8210...\n", + "Epoch: 2/20... Step: 290... Loss: 1.7735...\n", + "Epoch: 2/20... Step: 300... Loss: 1.7411...\n", + "Epoch: 2/20... Step: 310... Loss: 1.7121...\n", + "Epoch: 2/20... Step: 320... Loss: 1.6770...\n", + "Epoch: 2/20... Step: 330... Loss: 1.6775...\n", + "Epoch: 2/20... Step: 340... Loss: 1.6657...\n", + "Epoch: 2/20... Step: 350... Loss: 1.6849...\n", + "Epoch: 2/20... Step: 360... Loss: 1.6440...\n", + "Epoch: 2/20... Step: 370... Loss: 1.6214...\n", + "Epoch: 2/20... Step: 380... Loss: 1.6167...\n", + "Epoch: 2/20... Step: 390... Loss: 1.5718...\n", + "Epoch: 3/20... Step: 400... Loss: 1.5755...\n", + "Epoch: 3/20... Step: 410... Loss: 1.5963...\n", + "Epoch: 3/20... Step: 420... Loss: 1.5416...\n", + "Epoch: 3/20... Step: 430... Loss: 1.5538...\n", + "Epoch: 3/20... Step: 440... Loss: 1.4846...\n", + "Epoch: 3/20... Step: 450... Loss: 1.5144...\n", + "Epoch: 3/20... Step: 460... Loss: 1.5081...\n", + "Epoch: 3/20... Step: 470... Loss: 1.4719...\n", + "Epoch: 3/20... Step: 480... Loss: 1.4791...\n", + "Epoch: 3/20... Step: 490... Loss: 1.4432...\n", + "Epoch: 3/20... Step: 500... Loss: 1.4348...\n", + "Epoch: 3/20... Step: 510... Loss: 1.4380...\n", + "Epoch: 3/20... Step: 520... Loss: 1.4433...\n", + "Epoch: 3/20... Step: 530... Loss: 1.4281...\n", + "Epoch: 3/20... Step: 540... Loss: 1.4504...\n", + "Epoch: 3/20... Step: 550... Loss: 1.4127...\n", + "Epoch: 3/20... Step: 560... Loss: 1.4049...\n", + "Epoch: 3/20... Step: 570... Loss: 1.4241...\n", + "Epoch: 3/20... Step: 580... Loss: 1.3958...\n", + "Epoch: 3/20... Step: 590... Loss: 1.3745...\n", + "Epoch: 4/20... Step: 600... Loss: 1.3359...\n", + "Epoch: 4/20... Step: 610... Loss: 1.3501...\n", + "Epoch: 4/20... Step: 620... Loss: 1.3370...\n", + "Epoch: 4/20... Step: 630... Loss: 1.3559...\n", + "Epoch: 4/20... Step: 640... Loss: 1.3218...\n", + "Epoch: 4/20... Step: 650... Loss: 1.3531...\n", + "Epoch: 4/20... Step: 660... Loss: 1.3546...\n", + "Epoch: 4/20... Step: 670... Loss: 1.3557...\n", + "Epoch: 4/20... Step: 680... Loss: 1.3250...\n", + "Epoch: 4/20... Step: 690... 
Loss: 1.3389...\n", + "Epoch: 4/20... Step: 700... Loss: 1.3143...\n", + "Epoch: 4/20... Step: 710... Loss: 1.2854...\n", + "Epoch: 4/20... Step: 720... Loss: 1.2696...\n", + "Epoch: 4/20... Step: 730... Loss: 1.3260...\n", + "Epoch: 4/20... Step: 740... Loss: 1.3268...\n", + "Epoch: 4/20... Step: 750... Loss: 1.2980...\n", + "Epoch: 4/20... Step: 760... Loss: 1.2930...\n", + "Epoch: 4/20... Step: 770... Loss: 1.2784...\n", + "Epoch: 4/20... Step: 780... Loss: 1.2795...\n", + "Epoch: 4/20... Step: 790... Loss: 1.2977...\n", + "Epoch: 5/20... Step: 800... Loss: 1.2804...\n", + "Epoch: 5/20... Step: 810... Loss: 1.2957...\n", + "Epoch: 5/20... Step: 820... Loss: 1.3009...\n", + "Epoch: 5/20... Step: 830... Loss: 1.2251...\n", + "Epoch: 5/20... Step: 840... Loss: 1.2477...\n", + "Epoch: 5/20... Step: 850... Loss: 1.2555...\n", + "Epoch: 5/20... Step: 860... Loss: 1.2471...\n", + "Epoch: 5/20... Step: 870... Loss: 1.2591...\n", + "Epoch: 5/20... Step: 880... Loss: 1.2427...\n", + "Epoch: 5/20... Step: 890... Loss: 1.2130...\n", + "Epoch: 5/20... Step: 900... Loss: 1.2422...\n", + "Epoch: 5/20... Step: 910... Loss: 1.2516...\n", + "Epoch: 5/20... Step: 920... Loss: 1.2352...\n", + "Epoch: 5/20... Step: 930... Loss: 1.2548...\n", + "Epoch: 5/20... Step: 940... Loss: 1.2842...\n", + "Epoch: 5/20... Step: 950... Loss: 1.2185...\n", + "Epoch: 5/20... Step: 960... Loss: 1.3003...\n", + "Epoch: 5/20... Step: 970... Loss: 1.2644...\n", + "Epoch: 5/20... Step: 980... Loss: 1.2129...\n", + "Epoch: 5/20... Step: 990... Loss: 1.2945...\n", + "Epoch: 6/20... Step: 1000... Loss: 1.2140...\n", + "Epoch: 6/20... Step: 1010... Loss: 1.2499...\n", + "Epoch: 6/20... Step: 1020... Loss: 1.2308...\n", + "Epoch: 6/20... Step: 1030... Loss: 1.2026...\n", + "Epoch: 6/20... Step: 1040... Loss: 1.2146...\n", + "Epoch: 6/20... Step: 1050... Loss: 1.2441...\n", + "Epoch: 6/20... Step: 1060... Loss: 1.2010...\n", + "Epoch: 6/20... Step: 1070... Loss: 1.2013...\n", + "Epoch: 6/20... Step: 1080... Loss: 1.2019...\n", + "Epoch: 6/20... Step: 1090... Loss: 1.1829...\n", + "Epoch: 6/20... Step: 1100... Loss: 1.2022...\n", + "Epoch: 6/20... Step: 1110... Loss: 1.1880...\n", + "Epoch: 6/20... Step: 1120... Loss: 1.1462...\n", + "Epoch: 6/20... Step: 1130... Loss: 1.2097...\n", + "Epoch: 6/20... Step: 1140... Loss: 1.1853...\n", + "Epoch: 6/20... Step: 1150... Loss: 1.2118...\n", + "Epoch: 6/20... Step: 1160... Loss: 1.1953...\n", + "Epoch: 6/20... Step: 1170... Loss: 1.1623...\n", + "Epoch: 6/20... Step: 1180... Loss: 1.1799...\n", + "Epoch: 7/20... Step: 1190... Loss: 1.1995...\n", + "Epoch: 7/20... Step: 1200... Loss: 1.1771...\n", + "Epoch: 7/20... Step: 1210... Loss: 1.1963...\n", + "Epoch: 7/20... Step: 1220... Loss: 1.1507...\n", + "Epoch: 7/20... Step: 1230... Loss: 1.2124...\n", + "Epoch: 7/20... Step: 1240... Loss: 1.1978...\n", + "Epoch: 7/20... Step: 1250... Loss: 1.2087...\n", + "Epoch: 7/20... Step: 1260... Loss: 1.1959...\n", + "Epoch: 7/20... Step: 1270... Loss: 1.1614...\n", + "Epoch: 7/20... Step: 1280... Loss: 1.1602...\n", + "Epoch: 7/20... Step: 1290... Loss: 1.1511...\n", + "Epoch: 7/20... Step: 1300... Loss: 1.1543...\n", + "Epoch: 7/20... Step: 1310... Loss: 1.1284...\n", + "Epoch: 7/20... Step: 1320... Loss: 1.1423...\n", + "Epoch: 7/20... Step: 1330... Loss: 1.1699...\n", + "Epoch: 7/20... Step: 1340... Loss: 1.1290...\n", + "Epoch: 7/20... Step: 1350... Loss: 1.1633...\n", + "Epoch: 7/20... Step: 1360... Loss: 1.1406...\n", + "Epoch: 7/20... Step: 1370... Loss: 1.1640...\n", + "Epoch: 7/20... 
Step: 1380... Loss: 1.1205...\n", + "Epoch: 8/20... Step: 1390... Loss: 1.1685...\n", + "Epoch: 8/20... Step: 1400... Loss: 1.1541...\n", + "Epoch: 8/20... Step: 1410... Loss: 1.1510...\n", + "Epoch: 8/20... Step: 1420... Loss: 1.1437...\n", + "Epoch: 8/20... Step: 1430... Loss: 1.1136...\n", + "Epoch: 8/20... Step: 1440... Loss: 1.1488...\n", + "Epoch: 8/20... Step: 1450... Loss: 1.1541...\n", + "Epoch: 8/20... Step: 1460... Loss: 1.1170...\n", + "Epoch: 8/20... Step: 1470... Loss: 1.1352...\n", + "Epoch: 8/20... Step: 1480... Loss: 1.1041...\n", + "Epoch: 8/20... Step: 1490... Loss: 1.1107...\n", + "Epoch: 8/20... Step: 1500... Loss: 1.1200...\n", + "Epoch: 8/20... Step: 1510... Loss: 1.1180...\n", + "Epoch: 8/20... Step: 1520... Loss: 1.1236...\n", + "Epoch: 8/20... Step: 1530... Loss: 1.1509...\n", + "Epoch: 8/20... Step: 1540... Loss: 1.1254...\n", + "Epoch: 8/20... Step: 1550... Loss: 1.1242...\n", + "Epoch: 8/20... Step: 1560... Loss: 1.1287...\n", + "Epoch: 8/20... Step: 1570... Loss: 1.1241...\n", + "Epoch: 8/20... Step: 1580... Loss: 1.1030...\n", + "Epoch: 9/20... Step: 1590... Loss: 1.0936...\n", + "Epoch: 9/20... Step: 1600... Loss: 1.0958...\n", + "Epoch: 9/20... Step: 1610... Loss: 1.0834...\n", + "Epoch: 9/20... Step: 1620... Loss: 1.1247...\n", + "Epoch: 9/20... Step: 1630... Loss: 1.0970...\n", + "Epoch: 9/20... Step: 1640... Loss: 1.1177...\n", + "Epoch: 9/20... Step: 1650... Loss: 1.1147...\n", + "Epoch: 9/20... Step: 1660... Loss: 1.1276...\n", + "Epoch: 9/20... Step: 1670... Loss: 1.1178...\n", + "Epoch: 9/20... Step: 1680... Loss: 1.1198...\n", + "Epoch: 9/20... Step: 1690... Loss: 1.0939...\n", + "Epoch: 9/20... Step: 1700... Loss: 1.0703...\n", + "Epoch: 9/20... Step: 1710... Loss: 1.0671...\n", + "Epoch: 9/20... Step: 1720... Loss: 1.1006...\n", + "Epoch: 9/20... Step: 1730... Loss: 1.1253...\n", + "Epoch: 9/20... Step: 1740... Loss: 1.0978...\n", + "Epoch: 9/20... Step: 1750... Loss: 1.0984...\n", + "Epoch: 9/20... Step: 1760... Loss: 1.0889...\n", + "Epoch: 9/20... Step: 1770... Loss: 1.0783...\n", + "Epoch: 9/20... Step: 1780... Loss: 1.1042...\n", + "Epoch: 10/20... Step: 1790... Loss: 1.0870...\n", + "Epoch: 10/20... Step: 1800... Loss: 1.1159...\n", + "Epoch: 10/20... Step: 1810... Loss: 1.1249...\n", + "Epoch: 10/20... Step: 1820... Loss: 1.0647...\n", + "Epoch: 10/20... Step: 1830... Loss: 1.0700...\n", + "Epoch: 10/20... Step: 1840... Loss: 1.0790...\n", + "Epoch: 10/20... Step: 1850... Loss: 1.0845...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 10/20... Step: 1860... Loss: 1.0794...\n", + "Epoch: 10/20... Step: 1870... Loss: 1.0722...\n", + "Epoch: 10/20... Step: 1880... Loss: 1.0575...\n", + "Epoch: 10/20... Step: 1890... Loss: 1.0870...\n", + "Epoch: 10/20... Step: 1900... Loss: 1.0957...\n", + "Epoch: 10/20... Step: 1910... Loss: 1.0845...\n", + "Epoch: 10/20... Step: 1920... Loss: 1.1070...\n", + "Epoch: 10/20... Step: 1930... Loss: 1.1184...\n", + "Epoch: 10/20... Step: 1940... Loss: 1.0567...\n", + "Epoch: 10/20... Step: 1950... Loss: 1.1360...\n", + "Epoch: 10/20... Step: 1960... Loss: 1.1225...\n", + "Epoch: 10/20... Step: 1970... Loss: 1.0764...\n", + "Epoch: 10/20... Step: 1980... Loss: 1.1680...\n", + "Epoch: 11/20... Step: 1990... Loss: 1.0801...\n", + "Epoch: 11/20... Step: 2000... Loss: 1.1032...\n", + "Epoch: 11/20... Step: 2010... Loss: 1.1046...\n", + "Epoch: 11/20... Step: 2020... Loss: 1.0832...\n", + "Epoch: 11/20... Step: 2030... Loss: 1.0842...\n", + "Epoch: 11/20... Step: 2040... 
Loss: 1.1037...\n", + "Epoch: 11/20... Step: 2050... Loss: 1.0748...\n", + "Epoch: 11/20... Step: 2060... Loss: 1.0640...\n", + "Epoch: 11/20... Step: 2070... Loss: 1.0818...\n", + "Epoch: 11/20... Step: 2080... Loss: 1.0496...\n", + "Epoch: 11/20... Step: 2090... Loss: 1.0796...\n", + "Epoch: 11/20... Step: 2100... Loss: 1.0579...\n", + "Epoch: 11/20... Step: 2110... Loss: 1.0459...\n", + "Epoch: 11/20... Step: 2120... Loss: 1.0746...\n", + "Epoch: 11/20... Step: 2130... Loss: 1.0608...\n", + "Epoch: 11/20... Step: 2140... Loss: 1.0809...\n", + "Epoch: 11/20... Step: 2150... Loss: 1.0836...\n", + "Epoch: 11/20... Step: 2160... Loss: 1.0566...\n", + "Epoch: 11/20... Step: 2170... Loss: 1.0722...\n", + "Epoch: 12/20... Step: 2180... Loss: 1.0836...\n", + "Epoch: 12/20... Step: 2190... Loss: 1.0646...\n", + "Epoch: 12/20... Step: 2200... Loss: 1.0784...\n", + "Epoch: 12/20... Step: 2210... Loss: 1.0442...\n", + "Epoch: 12/20... Step: 2220... Loss: 1.0989...\n", + "Epoch: 12/20... Step: 2230... Loss: 1.0796...\n", + "Epoch: 12/20... Step: 2240... Loss: 1.0965...\n", + "Epoch: 12/20... Step: 2250... Loss: 1.0681...\n", + "Epoch: 12/20... Step: 2260... Loss: 1.0713...\n", + "Epoch: 12/20... Step: 2270... Loss: 1.0511...\n", + "Epoch: 12/20... Step: 2280... Loss: 1.0368...\n", + "Epoch: 12/20... Step: 2290... Loss: 1.0529...\n", + "Epoch: 12/20... Step: 2300... Loss: 1.0194...\n", + "Epoch: 12/20... Step: 2310... Loss: 1.0438...\n", + "Epoch: 12/20... Step: 2320... Loss: 1.0654...\n", + "Epoch: 12/20... Step: 2330... Loss: 1.0309...\n", + "Epoch: 12/20... Step: 2340... Loss: 1.0638...\n", + "Epoch: 12/20... Step: 2350... Loss: 1.0321...\n", + "Epoch: 12/20... Step: 2360... Loss: 1.0654...\n", + "Epoch: 12/20... Step: 2370... Loss: 1.0242...\n", + "Epoch: 13/20... Step: 2380... Loss: 1.0763...\n", + "Epoch: 13/20... Step: 2390... Loss: 1.0489...\n", + "Epoch: 13/20... Step: 2400... Loss: 1.0598...\n", + "Epoch: 13/20... Step: 2410... Loss: 1.0549...\n", + "Epoch: 13/20... Step: 2420... Loss: 1.0193...\n", + "Epoch: 13/20... Step: 2430... Loss: 1.0558...\n", + "Epoch: 13/20... Step: 2440... Loss: 1.0538...\n", + "Epoch: 13/20... Step: 2450... Loss: 1.0264...\n", + "Epoch: 13/20... Step: 2460... Loss: 1.0411...\n", + "Epoch: 13/20... Step: 2470... Loss: 1.0098...\n", + "Epoch: 13/20... Step: 2480... Loss: 1.0260...\n", + "Epoch: 13/20... Step: 2490... Loss: 1.0278...\n", + "Epoch: 13/20... Step: 2500... Loss: 1.0343...\n", + "Epoch: 13/20... Step: 2510... Loss: 1.0421...\n", + "Epoch: 13/20... Step: 2520... Loss: 1.0796...\n", + "Epoch: 13/20... Step: 2530... Loss: 1.0343...\n", + "Epoch: 13/20... Step: 2540... Loss: 1.0376...\n", + "Epoch: 13/20... Step: 2550... Loss: 1.0472...\n", + "Epoch: 13/20... Step: 2560... Loss: 1.0306...\n", + "Epoch: 13/20... Step: 2570... Loss: 1.0272...\n", + "Epoch: 14/20... Step: 2580... Loss: 1.0038...\n", + "Epoch: 14/20... Step: 2590... Loss: 1.0151...\n", + "Epoch: 14/20... Step: 2600... Loss: 0.9968...\n", + "Epoch: 14/20... Step: 2610... Loss: 1.0411...\n", + "Epoch: 14/20... Step: 2620... Loss: 1.0187...\n", + "Epoch: 14/20... Step: 2630... Loss: 1.0459...\n", + "Epoch: 14/20... Step: 2640... Loss: 1.0391...\n", + "Epoch: 14/20... Step: 2650... Loss: 1.0387...\n", + "Epoch: 14/20... Step: 2660... Loss: 1.0258...\n", + "Epoch: 14/20... Step: 2670... Loss: 1.0293...\n", + "Epoch: 14/20... Step: 2680... Loss: 1.0243...\n", + "Epoch: 14/20... Step: 2690... Loss: 0.9998...\n", + "Epoch: 14/20... Step: 2700... Loss: 0.9905...\n", + "Epoch: 14/20... Step: 2710... 
Loss: 1.0253...\n", + "Epoch: 14/20... Step: 2720... Loss: 1.0492...\n", + "Epoch: 14/20... Step: 2730... Loss: 1.0228...\n", + "Epoch: 14/20... Step: 2740... Loss: 1.0302...\n", + "Epoch: 14/20... Step: 2750... Loss: 1.0038...\n", + "Epoch: 14/20... Step: 2760... Loss: 1.0061...\n", + "Epoch: 14/20... Step: 2770... Loss: 1.0302...\n", + "Epoch: 15/20... Step: 2780... Loss: 1.0255...\n", + "Epoch: 15/20... Step: 2790... Loss: 1.0541...\n", + "Epoch: 15/20... Step: 2800... Loss: 1.0511...\n", + "Epoch: 15/20... Step: 2810... Loss: 0.9877...\n", + "Epoch: 15/20... Step: 2820... Loss: 1.0131...\n", + "Epoch: 15/20... Step: 2830... Loss: 1.0119...\n", + "Epoch: 15/20... Step: 2840... Loss: 1.0179...\n", + "Epoch: 15/20... Step: 2850... Loss: 1.0230...\n", + "Epoch: 15/20... Step: 2860... Loss: 1.0077...\n", + "Epoch: 15/20... Step: 2870... Loss: 0.9867...\n", + "Epoch: 15/20... Step: 2880... Loss: 1.0120...\n", + "Epoch: 15/20... Step: 2890... Loss: 1.0217...\n", + "Epoch: 15/20... Step: 2900... Loss: 1.0175...\n", + "Epoch: 15/20... Step: 2910... Loss: 1.0351...\n", + "Epoch: 15/20... Step: 2920... Loss: 1.0332...\n", + "Epoch: 15/20... Step: 2930... Loss: 0.9929...\n", + "Epoch: 15/20... Step: 2940... Loss: 1.0611...\n", + "Epoch: 15/20... Step: 2950... Loss: 1.0488...\n", + "Epoch: 15/20... Step: 2960... Loss: 1.0075...\n", + "Epoch: 15/20... Step: 2970... Loss: 1.1087...\n", + "Epoch: 16/20... Step: 2980... Loss: 1.0061...\n", + "Epoch: 16/20... Step: 2990... Loss: 1.0419...\n", + "Epoch: 16/20... Step: 3000... Loss: 1.0333...\n", + "Epoch: 16/20... Step: 3010... Loss: 1.0156...\n", + "Epoch: 16/20... Step: 3020... Loss: 1.0261...\n", + "Epoch: 16/20... Step: 3030... Loss: 1.0368...\n", + "Epoch: 16/20... Step: 3040... Loss: 1.0194...\n", + "Epoch: 16/20... Step: 3050... Loss: 1.0033...\n", + "Epoch: 16/20... Step: 3060... Loss: 1.0089...\n", + "Epoch: 16/20... Step: 3070... Loss: 0.9822...\n", + "Epoch: 16/20... Step: 3080... Loss: 1.0177...\n", + "Epoch: 16/20... Step: 3090... Loss: 1.0126...\n", + "Epoch: 16/20... Step: 3100... Loss: 0.9754...\n", + "Epoch: 16/20... Step: 3110... Loss: 1.0221...\n", + "Epoch: 16/20... Step: 3120... Loss: 1.0034...\n", + "Epoch: 16/20... Step: 3130... Loss: 1.0244...\n", + "Epoch: 16/20... Step: 3140... Loss: 1.0278...\n", + "Epoch: 16/20... Step: 3150... Loss: 0.9978...\n", + "Epoch: 16/20... Step: 3160... Loss: 1.0082...\n", + "Epoch: 17/20... Step: 3170... Loss: 1.0218...\n", + "Epoch: 17/20... Step: 3180... Loss: 1.0085...\n", + "Epoch: 17/20... Step: 3190... Loss: 1.0321...\n", + "Epoch: 17/20... Step: 3200... Loss: 1.0041...\n", + "Epoch: 17/20... Step: 3210... Loss: 1.0501...\n", + "Epoch: 17/20... Step: 3220... Loss: 1.0353...\n", + "Epoch: 17/20... Step: 3230... Loss: 1.0436...\n", + "Epoch: 17/20... Step: 3240... Loss: 1.0216...\n", + "Epoch: 17/20... Step: 3250... Loss: 1.0059...\n", + "Epoch: 17/20... Step: 3260... Loss: 0.9994...\n", + "Epoch: 17/20... Step: 3270... Loss: 0.9930...\n", + "Epoch: 17/20... Step: 3280... Loss: 0.9944...\n", + "Epoch: 17/20... Step: 3290... Loss: 0.9750...\n", + "Epoch: 17/20... Step: 3300... Loss: 0.9910...\n", + "Epoch: 17/20... Step: 3310... Loss: 1.0102...\n", + "Epoch: 17/20... Step: 3320... Loss: 0.9833...\n", + "Epoch: 17/20... Step: 3330... Loss: 1.0043...\n", + "Epoch: 17/20... Step: 3340... Loss: 0.9951...\n", + "Epoch: 17/20... Step: 3350... Loss: 1.0025...\n", + "Epoch: 17/20... Step: 3360... Loss: 0.9809...\n", + "Epoch: 18/20... Step: 3370... Loss: 1.0241...\n", + "Epoch: 18/20... Step: 3380... 
Loss: 0.9989...\n", + "Epoch: 18/20... Step: 3390... Loss: 1.0106...\n", + "Epoch: 18/20... Step: 3400... Loss: 1.0071...\n", + "Epoch: 18/20... Step: 3410... Loss: 0.9717...\n", + "Epoch: 18/20... Step: 3420... Loss: 1.0024...\n", + "Epoch: 18/20... Step: 3430... Loss: 1.0024...\n", + "Epoch: 18/20... Step: 3440... Loss: 0.9846...\n", + "Epoch: 18/20... Step: 3450... Loss: 0.9949...\n", + "Epoch: 18/20... Step: 3460... Loss: 0.9706...\n", + "Epoch: 18/20... Step: 3470... Loss: 0.9760...\n", + "Epoch: 18/20... Step: 3480... Loss: 0.9871...\n", + "Epoch: 18/20... Step: 3490... Loss: 0.9915...\n", + "Epoch: 18/20... Step: 3500... Loss: 0.9913...\n", + "Epoch: 18/20... Step: 3510... Loss: 1.0283...\n", + "Epoch: 18/20... Step: 3520... Loss: 0.9822...\n", + "Epoch: 18/20... Step: 3530... Loss: 0.9881...\n", + "Epoch: 18/20... Step: 3540... Loss: 0.9994...\n", + "Epoch: 18/20... Step: 3550... Loss: 0.9906...\n", + "Epoch: 18/20... Step: 3560... Loss: 0.9822...\n", + "Epoch: 19/20... Step: 3570... Loss: 0.9615...\n", + "Epoch: 19/20... Step: 3580... Loss: 0.9698...\n", + "Epoch: 19/20... Step: 3590... Loss: 0.9628...\n", + "Epoch: 19/20... Step: 3600... Loss: 0.9946...\n", + "Epoch: 19/20... Step: 3610... Loss: 0.9680...\n", + "Epoch: 19/20... Step: 3620... Loss: 0.9973...\n", + "Epoch: 19/20... Step: 3630... Loss: 0.9957...\n", + "Epoch: 19/20... Step: 3640... Loss: 0.9965...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 19/20... Step: 3650... Loss: 0.9875...\n", + "Epoch: 19/20... Step: 3660... Loss: 0.9950...\n", + "Epoch: 19/20... Step: 3670... Loss: 0.9830...\n", + "Epoch: 19/20... Step: 3680... Loss: 0.9618...\n", + "Epoch: 19/20... Step: 3690... Loss: 0.9593...\n", + "Epoch: 19/20... Step: 3700... Loss: 0.9912...\n", + "Epoch: 19/20... Step: 3710... Loss: 1.0064...\n", + "Epoch: 19/20... Step: 3720... Loss: 0.9769...\n", + "Epoch: 19/20... Step: 3730... Loss: 0.9775...\n", + "Epoch: 19/20... Step: 3740... Loss: 0.9659...\n", + "Epoch: 19/20... Step: 3750... Loss: 0.9666...\n", + "Epoch: 19/20... Step: 3760... Loss: 0.9797...\n", + "Epoch: 20/20... Step: 3770... Loss: 0.9856...\n", + "Epoch: 20/20... Step: 3780... Loss: 1.0034...\n", + "Epoch: 20/20... Step: 3790... Loss: 1.0050...\n", + "Epoch: 20/20... Step: 3800... Loss: 0.9559...\n", + "Epoch: 20/20... Step: 3810... Loss: 0.9629...\n", + "Epoch: 20/20... Step: 3820... Loss: 0.9684...\n", + "Epoch: 20/20... Step: 3830... Loss: 0.9766...\n", + "Epoch: 20/20... Step: 3840... Loss: 0.9917...\n", + "Epoch: 20/20... Step: 3850... Loss: 0.9802...\n", + "Epoch: 20/20... Step: 3860... Loss: 0.9610...\n", + "Epoch: 20/20... Step: 3870... Loss: 0.9842...\n", + "Epoch: 20/20... Step: 3880... Loss: 0.9794...\n", + "Epoch: 20/20... Step: 3890... Loss: 0.9848...\n", + "Epoch: 20/20... Step: 3900... Loss: 1.0022...\n", + "Epoch: 20/20... Step: 3910... Loss: 0.9965...\n", + "Epoch: 20/20... Step: 3920... Loss: 0.9527...\n", + "Epoch: 20/20... Step: 3930... Loss: 1.0170...\n", + "Epoch: 20/20... Step: 3940... Loss: 1.0181...\n", + "Epoch: 20/20... Step: 3950... Loss: 0.9724...\n", + "Epoch: 20/20... Step: 3960... 
Loss: 1.0671...\n" + ] + } + ], "source": [ - "epochs = 1\n", + "epochs = 20\n", "print_every = 10\n", - "cuda = False\n", + "cuda = True\n", "\n", "net = CharRNN(len(vocab), n_steps=n_steps, n_layers=n_layers, \n", " n_hidden=lstm_size, drop_prob=drop_prob)\n", "\n", "opt = optim.Adam(net.parameters(), lr=learning_rate)\n", - "criterion = nn.CrossEntropyLoss()\n", + "criterion = nn.NLLLoss()\n", "\n", "if cuda:\n", " net.cuda()\n", @@ -446,6 +927,15 @@ " \"Loss: {:.4f}...\".format(loss.data[0]))" ] }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "torch.save(net.state_dict(), 'anna_rnn.net')" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -454,61 +944,133 @@ "\n", "Now that the network is trained, we'll can use it to generate new text. The idea is that we pass in a character, then the network will predict the next character. We can use the new one, to predict the next one. And we keep doing this to generate all new text. I also included some functionality to prime the network with some text by passing in a string and building up a state from that.\n", "\n", - "The network gives us predictions for each character. To reduce noise and make things a little less random, I'm going to only choose a new character from the top N most likely characters.\n", - "\n" + "The network gives us predictions for each character. To reduce noise and make things a little less random, I'm going to only choose a new character from the top N most likely characters." + ] + }, + { + "cell_type": "code", + "execution_count": 129, + "metadata": {}, + "outputs": [], + "source": [ + "def predict(net, char, hc=None, cuda=False):\n", + " \n", + " if hc is None:\n", + " hc = init_hidden(net, 1)\n", + " \n", + " x = one_hot_encode(np.array([[char]]), net.n_tokens)\n", + " \n", + " # Make sure our variables are volatile so we don't save the history\n", + " # since we're in inference mode here\n", + " inputs = Variable(torch.from_numpy(x), volatile=True)\n", + " hc = tuple([Variable(each.data, volatile=True) for each in hc])\n", + " \n", + " if cuda:\n", + " inputs = inputs.cuda()\n", + "\n", + " x, hc = net.forward(inputs, hc)\n", + " \n", + " return x, hc" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 157, "metadata": {}, "outputs": [], "source": [ - "def pick_top_n(preds, vocab_size, top_n=5):\n", - " p = np.squeeze(preds)\n", - " p[np.argsort(p)[:-top_n]] = 0\n", - " p = p / np.sum(p)\n", - " c = np.random.choice(vocab_size, 1, p=p)[0]\n", - " return c" + "def choose_char(x, top_k=None):\n", + " if top_k is None:\n", + " ps, out = torch.exp(x).max(dim=0)\n", + " out = out[0].data.numpy()[0]\n", + " else:\n", + " probs, idx = torch.exp(x).topk(top_k)\n", + " probs, idx = probs.data.numpy().squeeze(), idx.data.numpy().squeeze()\n", + " out = np.random.choice(idx, p=probs/probs.sum())\n", + " return out" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 158, "metadata": {}, "outputs": [], "source": [ - "def sample(checkpoint, n_samples, lstm_size, vocab_size, prime=\"The \"):\n", - " samples = [c for c in prime]\n", - " model = CharRNN(len(vocab), lstm_size=lstm_size, sampling=True)\n", - " saver = tf.train.Saver()\n", - " with tf.Session() as sess:\n", - " saver.restore(sess, checkpoint)\n", - " new_state = sess.run(model.initial_state)\n", - " for c in prime:\n", - " x = np.zeros((1, 1))\n", - " x[0,0] = vocab_to_int[c]\n", - " feed = {model.inputs: x,\n", - " model.keep_prob: 1.,\n", - " 
model.initial_state: new_state}\n", - " preds, new_state = sess.run([model.prediction, model.final_state], \n", - " feed_dict=feed)\n", - "\n", - " c = pick_top_n(preds, len(vocab))\n", - " samples.append(int_to_vocab[c])\n", - "\n", - " for i in range(n_samples):\n", - " x[0,0] = c\n", - " feed = {model.inputs: x,\n", - " model.keep_prob: 1.,\n", - " model.initial_state: new_state}\n", - " preds, new_state = sess.run([model.prediction, model.final_state], \n", - " feed_dict=feed)\n", - "\n", - " c = pick_top_n(preds, len(vocab))\n", - " samples.append(int_to_vocab[c])\n", + "def sample(net, n_samples, prime=\"The\", cuda=False, top_k=None):\n", + " ''' Sample from a trained network.\n", + " '''\n", + " # First make sure the network is in inference mode\n", + " net.eval()\n", + " if cuda:\n", + " net.cuda()\n", + " else:\n", + " net.cpu()\n", + " \n", + " # Initialize hidden state\n", + " hc = init_hidden(net, 1)\n", " \n", - " return ''.join(samples)" + " # Build up the hidden state from the priming text\n", + " sample = list(prime)\n", + " for char in sample:\n", + " x, hc = predict(net, vocab_to_int[char], hc=hc, cuda=cuda)\n", + " \n", + " # Get the first new character\n", + " if cuda:\n", + " x = x.cpu()\n", + " char_int = choose_char(x)\n", + " sample.append(int_to_vocab[char_int])\n", + " \n", + " for ii in range(n_samples):\n", + " x, hc = predict(net, char_int, hc=hc, cuda=cuda)\n", + " if cuda:\n", + " x = x.cpu()\n", + " char_int = choose_char(x, top_k=top_k)\n", + " sample.append(int_to_vocab[char_int])\n", + " \n", + " return sample" + ] + }, + { + "cell_type": "code", + "execution_count": 166, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The\n", + "woman struck and drew a strung of all the carriage\n", + "what he was not friendly of him, and would not criming to see him and so much\n", + "frightening what had\n", + "been seem, when\n", + "they had so straight out.\n", + "\n", + "\"Yes, I don't know about her,\" he addressed home, sat\n", + "down and say that the porter sat\n", + "down to him. \"Though I shall not stand at the same\n", + "force over the\n", + "sense of the praceity of the same to you to tell me to drive or. It seems\n", + "or to be servants\n", + "with me.\"\n", + "\n", + "\"And what does he had not say?\" he asked was\n", + "that her\n", + "shoulders, so that they could not be anything, and so as a minute\n", + "that she was simply talking at her there.\"\n", + "\n", + "\"Yes, there's none that things should have sumplecely attended? And I have told her husband what I works alone in, that had been\n", + "bowed to my dear, and though I won't consider the porter's side.\n", + "\n", + "\"Why didn't you to give it.\"\n", + "\n", + "He set off till the stream of late and carelessness of this. Having had terrible on his face. She did anything to her a long while. He saw that she could\n", + "not love\n" + ] + } + ], + "source": [ + "print(''.join(sample(net, 1000, top_k=5)))" ] }, {