Commit 66aba16

Harry Potter Book 8 - RNN Climax!
Test zone to try to create a new Harry Potter book with an RNN on a GPU (Tesla K80) on Google Cloud.
1 parent 62150c3 commit 66aba16

11 files changed: +211405 −9 lines
@@ -0,0 +1,363 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import codecs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading the data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Check that the books exist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "\n",
    "book_filenames = sorted(glob.glob(\"data/*txt\"))\n",
    "\n",
    "print(\"Found {} books\".format(len(book_filenames)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Joining the books into a string "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "potter_raw = u\"\"\n",
    "for filename in book_filenames:\n",
    "    with codecs.open(filename, 'r', 'utf-8') as book_file:\n",
    "        potter_raw += book_file.read()\n",
    "print(\"Potter is \", len(potter_raw), \" characters long\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Process Potter "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Create lookup tables "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def lookup_tables(text):\n",
    "    \"\"\"Map each unique word to an integer id and back.\"\"\"\n",
    "    vocab = set(text)\n",
    "    int_to_vocab = {key: word for key, word in enumerate(vocab)}\n",
    "    vocab_to_int = {word: key for key, word in enumerate(vocab)}\n",
    "    return vocab_to_int, int_to_vocab"
   ]
  },
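  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick, illustrative sanity check of `lookup_tables` on a toy word list (not the Potter corpus). The exact ids are arbitrary because the vocabulary comes from a `set`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: toy word list, not the real corpus\n",
    "toy_words = ['harry', 'ron', 'hermione', 'harry']\n",
    "toy_v2i, toy_i2v = lookup_tables(toy_words)\n",
    "print(toy_v2i)                                  # e.g. {'ron': 0, 'harry': 1, 'hermione': 2}\n",
    "print([toy_v2i[w] for w in toy_words])          # encode words to ids\n",
    "print([toy_i2v[i] for i in sorted(toy_i2v)])    # decode ids back to words"
   ]
  },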
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Tokenize punctuation "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def token_lookup():\n",
    "    \"\"\"\n",
    "    Generate a dict to map punctuation into a token\n",
    "    :return: dictionary mapping punctuation to token\n",
    "    \"\"\"\n",
    "    return {\n",
    "        '.': '||period||',\n",
    "        ',': '||comma||',\n",
    "        '\"': '||quotes||',\n",
    "        ';': '||semicolon||',\n",
    "        '!': '||exclamation-mark||',\n",
    "        '?': '||question-mark||',\n",
    "        '(': '||left-parentheses||',\n",
    "        ')': '||right-parentheses||',\n",
    "        '--': '||em-dash||',\n",
    "        '\\n': '||return||'\n",
    "    }"
   ]
  },
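  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A tiny, illustrative example (made-up sentence, not from the corpus) of what the punctuation tokens do before the text is lower-cased and split into words:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: apply the punctuation tokens to a made-up sentence\n",
    "example = 'Harry waved. Hello, Ron!'\n",
    "for token, replacement in token_lookup().items():\n",
    "    example = example.replace(token, ' {} '.format(replacement))\n",
    "print(example.lower().split())\n",
    "# ['harry', 'waved', '||period||', 'hello', '||comma||', 'ron', '||exclamation-mark||']"
   ]
  },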
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Process and save data "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "token_dict = token_lookup()\n",
    "for token, replacement in token_dict.items():\n",
    "    potter_raw = potter_raw.replace(token, ' {} '.format(replacement))\n",
    "\n",
    "# Lower-case the text and split it into a list of word tokens\n",
    "corpus_raw = potter_raw.lower()\n",
    "corpus_raw = corpus_raw.split()\n",
    "\n",
    "vocab_to_int, int_to_vocab = lookup_tables(corpus_raw)\n",
    "potter_int = [vocab_to_int[word] for word in corpus_raw]\n",
    "pickle.dump((potter_int, vocab_to_int, int_to_vocab, token_dict), open('preprocess.p', 'wb'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building the network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Batching the data "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_batches(int_text, batch_size, seq_length):\n",
    "    words_per_batch = batch_size * seq_length\n",
    "    num_batches = len(int_text) // words_per_batch\n",
    "    # Drop the leftover words that do not fill a whole batch\n",
    "    int_text = int_text[:num_batches * words_per_batch]\n",
    "    # Targets are the inputs shifted one word ahead (the first word wraps to the end)\n",
    "    y = np.array(int_text[1:] + [int_text[0]])\n",
    "    x = np.array(int_text)\n",
    "    \n",
    "    x_batches = np.split(x.reshape(batch_size, -1), num_batches, axis=1)\n",
    "    y_batches = np.split(y.reshape(batch_size, -1), num_batches, axis=1)\n",
    "    \n",
    "    batch_data = list(zip(x_batches, y_batches))\n",
    "    \n",
    "    return np.array(batch_data)"
   ]
  },
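  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick, illustrative sanity check of `get_batches` on a toy sequence (made-up numbers, not the real corpus), just to show the shape it produces: `(num_batches, 2, batch_size, seq_length)`, where the second axis holds the input and target batches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: 20 toy token ids, batch_size=2, seq_length=5\n",
    "toy_batches = get_batches(list(range(20)), batch_size=2, seq_length=5)\n",
    "print(toy_batches.shape)   # (2, 2, 2, 5): 2 batches of (x, y), each of shape (2, 5)\n",
    "print(toy_batches[0][0])   # first batch of inputs\n",
    "print(toy_batches[0][1])   # first batch of targets (inputs shifted by one word)"
   ]
  },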
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Set the hyperparameters "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "num_epochs = 10000\n",
    "batch_size = 512\n",
    "rnn_size = 512\n",
    "num_layers = 3\n",
    "keep_prob = 0.7\n",
    "embed_dim = 512\n",
    "seq_length = 30\n",
    "learning_rate = 0.001\n",
    "save_dir = './save'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Building the graph "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_graph = tf.Graph()\n",
    "with train_graph.as_default():\n",
    "    \n",
    "    # Initialize input placeholders\n",
    "    input_text = tf.placeholder(tf.int32, [None, None], name='input')\n",
    "    targets = tf.placeholder(tf.int32, [None, None], name='targets')\n",
    "    lr = tf.placeholder(tf.float32, name='learning_rate')\n",
    "    \n",
    "    # Calculate text attributes\n",
    "    vocab_size = len(int_to_vocab)\n",
    "    input_text_shape = tf.shape(input_text)\n",
    "    \n",
    "    # Build the RNN cell: a fresh LSTM wrapped in dropout for every layer\n",
    "    # (reusing one cell object across layers would make them share weights)\n",
    "    def build_cell():\n",
    "        lstm = tf.contrib.rnn.BasicLSTMCell(num_units=rnn_size)\n",
    "        return tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=keep_prob)\n",
    "    cell = tf.contrib.rnn.MultiRNNCell([build_cell() for _ in range(num_layers)])\n",
    "    \n",
    "    # Set the initial state\n",
    "    initial_state = cell.zero_state(input_text_shape[0], tf.float32)\n",
    "    initial_state = tf.identity(initial_state, name='initial_state')\n",
    "    \n",
    "    # Create word embedding as input to RNN\n",
    "    embed = tf.contrib.layers.embed_sequence(input_text, vocab_size, embed_dim)\n",
    "    \n",
    "    # Build RNN\n",
    "    outputs, final_state = tf.nn.dynamic_rnn(cell, embed, dtype=tf.float32)\n",
    "    final_state = tf.identity(final_state, name='final_state')\n",
    "    \n",
    "    # Take RNN output and make logits\n",
    "    logits = tf.contrib.layers.fully_connected(outputs, vocab_size, activation_fn=None)\n",
    "    \n",
    "    # Calculate the probability of generating each word\n",
    "    probs = tf.nn.softmax(logits, name='probs')\n",
    "    \n",
    "    # Define loss function\n",
    "    cost = tf.contrib.seq2seq.sequence_loss(\n",
    "        logits,\n",
    "        targets,\n",
    "        tf.ones([input_text_shape[0], input_text_shape[1]])\n",
    "    )\n",
    "    \n",
    "    # Optimizer driven by the learning-rate placeholder fed at train time\n",
    "    optimizer = tf.train.AdamOptimizer(lr)\n",
    "    \n",
    "    # Gradient clipping to avoid exploding gradients\n",
    "    gradients = optimizer.compute_gradients(cost)\n",
    "    capped_gradients = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in gradients if grad is not None]\n",
    "    train_op = optimizer.apply_gradients(capped_gradients)"
   ]
  },
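  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the shapes that follow from the graph above (each batch has `batch_size` rows and `seq_length` columns): `input_text` is `[batch, seq]` of word ids, `embed` is `[batch, seq, embed_dim]`, the RNN `outputs` are `[batch, seq, rnn_size]`, and `logits`/`probs` are `[batch, seq, vocab_size]`; `sequence_loss` with the all-ones weights averages the per-word cross-entropy over the whole `[batch, seq]` grid."
   ]
  },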
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Train the network "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "pickle.dump((seq_length, save_dir), open('params.p', 'wb'))\n",
    "batches = get_batches(potter_int, batch_size, seq_length)\n",
    "num_batches = len(batches)\n",
    "start_time = time.time()\n",
    "\n",
    "with tf.Session(graph=train_graph) as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    saver = tf.train.Saver()\n",
    "    \n",
    "    for epoch in range(num_epochs):\n",
    "        state = sess.run(initial_state, {input_text: batches[0][0]})\n",
    "        \n",
    "        for batch_index, (x, y) in enumerate(batches):\n",
    "            feed_dict = {\n",
    "                input_text: x,\n",
    "                targets: y,\n",
    "                initial_state: state,\n",
    "                lr: learning_rate\n",
    "            }\n",
    "            train_loss, state, _ = sess.run([cost, final_state, train_op], feed_dict)\n",
    "            \n",
    "            time_elapsed = time.time() - start_time\n",
    "            print('Epoch {:>3} Batch {:>4}/{} train_loss = {:.3f} time_elapsed = {:.3f} time_remaining = {:.0f}'.format(\n",
    "                epoch + 1,\n",
    "                batch_index + 1,\n",
    "                len(batches),\n",
    "                train_loss,\n",
    "                time_elapsed,\n",
    "                ((num_batches * num_epochs) / ((epoch + 1) * (batch_index + 1))) * time_elapsed - time_elapsed))\n",
    "\n",
    "        # save a checkpoint every 10 epochs\n",
    "        if epoch % 10 == 0:\n",
    "            saver.save(sess, save_dir)\n",
    "            print('Model Trained and Saved')"
   ]
  },
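  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The commit stops after training, so here is a rough, hedged sketch of how text could be sampled from the saved checkpoint. It only reuses names defined above ('input', 'initial_state', 'final_state', 'probs', plus `vocab_to_int`, `int_to_vocab`, `token_dict`, `seq_length`, `save_dir`); the prime word and the number of generated words are arbitrary choices, and the state is fed the same way the training loop feeds it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rough sketch only: sample text from the checkpoint written by the training loop\n",
    "gen_length = 200          # arbitrary number of words to generate\n",
    "prime_word = 'harry'      # arbitrary seed word, assumed to be in the vocabulary\n",
    "\n",
    "loaded_graph = tf.Graph()\n",
    "with tf.Session(graph=loaded_graph) as sess:\n",
    "    # Restore the trained graph and its weights\n",
    "    loader = tf.train.import_meta_graph(save_dir + '.meta')\n",
    "    loader.restore(sess, save_dir)\n",
    "\n",
    "    input_text = loaded_graph.get_tensor_by_name('input:0')\n",
    "    initial_state = loaded_graph.get_tensor_by_name('initial_state:0')\n",
    "    final_state = loaded_graph.get_tensor_by_name('final_state:0')\n",
    "    probs = loaded_graph.get_tensor_by_name('probs:0')\n",
    "\n",
    "    gen_sentences = [prime_word]\n",
    "    prev_state = sess.run(initial_state, {input_text: np.array([[1]])})\n",
    "\n",
    "    for _ in range(gen_length):\n",
    "        # Feed the last seq_length generated words, mirroring the training feed\n",
    "        dyn_input = [[vocab_to_int[word] for word in gen_sentences[-seq_length:]]]\n",
    "        probabilities, prev_state = sess.run(\n",
    "            [probs, final_state],\n",
    "            {input_text: dyn_input, initial_state: prev_state})\n",
    "        # Sample the next word from the distribution at the last timestep\n",
    "        dist = probabilities[0, -1]\n",
    "        next_word = int_to_vocab[np.random.choice(len(dist), p=dist / dist.sum())]\n",
    "        gen_sentences.append(next_word)\n",
    "\n",
    "    # Turn the punctuation tokens back into punctuation\n",
    "    potter_text = ' '.join(gen_sentences)\n",
    "    for token, replacement in token_dict.items():\n",
    "        potter_text = potter_text.replace(' ' + replacement, token)\n",
    "    print(potter_text)"
   ]
  },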
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
