# textgenrnn

Generate text using a pretrained neural network with a few lines of code, or easily train your own text-generating neural network of any size and complexity on any text dataset.

textgenrnn is a Python 3 module on top of [Keras](https://github.com/fchollet/keras)/[TensorFlow](https://www.tensorflow.org) for creating [char-rnn](http://karpathy.github.io/2015/05/21/rnn-effectiveness/)s, with many cool features:

* A modern neural network architecture which utilizes new techniques such as attention-weighting and skip-embedding to accelerate training and improve model quality.
* Able to train and predict at either the character level or the word level.
* Able to configure the RNN size, the number of RNN layers, and whether to use bidirectional RNNs.
* Able to train on any generic input text file.
* Able to train models on a GPU and then use them with a CPU.
* Able to utilize a powerful CuDNN implementation of RNNs when trained on the GPU, which massively speeds up training compared to standard LSTM implementations.
* Able to train the model using contextual labels, allowing it to learn faster and produce better results in some cases.

You can play with textgenrnn and train on any text file with a GPU *for free* in this [Colaboratory Notebook](https://drive.google.com/file/d/1mMKGnVxirJnqDViH7BDJxFqWrsXlPSoK/view?usp=sharing)!

## Examples

```python
from textgenrnn import textgenrnn

textgen = textgenrnn()
textgen.generate()
```

```text
[Spoiler] Anyone else find this post and their person that was a little more than I really like the Star Wars in the fire or health and posting a personal house of the 2016 Letter for the game in a report of my backyard.
```

The model can easily be trained on new texts, and can generate appropriate text *even after a single pass of the input data*.

```python
textgen.train_from_file('hacker-news-2000.txt', num_epochs=1)
textgen.generate()
```

```text
Project State Project Firefox
```

The model weights are relatively small (2 MB on disk), and they can easily be saved and loaded into a new textgenrnn instance. As a result, you can play with models which have been trained on hundreds of passes through the data. (In fact, textgenrnn learns *so well* that you have to increase the temperature significantly for creative output!)

```python
textgen_2 = textgenrnn('/weights/hacker_news.hdf5')
textgen_2.generate(3, temperature=1.0)
```

```text
Why we got money “regular alter”

Urburg to Firefox acquires Nelf Multi Shamn

Kubernetes by Google’s Bern
```

You can also train a new model, with support for word-level embeddings and bidirectional RNN layers.
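
A minimal sketch of such a from-scratch training run is below. The keyword arguments follow the demo notebook; treat the exact parameter names and values as illustrative assumptions rather than a definitive interface, and check the notebook for what your installed version supports.

```python
from textgenrnn import textgenrnn

textgen = textgenrnn(name='my_word_model')

# new_model=True builds a fresh network instead of fine-tuning the included
# weights; word_level and rnn_bidirectional enable word embeddings and
# bidirectional RNN layers (parameter names assumed from the demo notebook).
textgen.train_from_file('hacker-news-2000.txt',
                        new_model=True,
                        word_level=True,
                        rnn_bidirectional=True,
                        rnn_size=128,
                        rnn_layers=3,
                        num_epochs=10)
textgen.generate()
```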

## Usage

textgenrnn can be installed [from pypi](https://pypi.python.org/pypi/textgenrnn) via `pip`:

```sh
pip3 install textgenrnn
```

You can view a demo of common features and configuration options in [this Jupyter Notebook](/docs/textgenrnn-demo.ipynb).

`/datasets` contains example datasets using Hacker News/Reddit data for training textgenrnn.

`/weights` contains further-pretrained models on the aforementioned datasets which can be loaded into textgenrnn.

`/outputs` contains examples of text generated from the above pretrained models.

## Neural Network Architecture and Implementation

textgenrnn is based off of the [char-rnn](https://github.com/karpathy/char-rnn) project by [Andrej Karpathy](https://twitter.com/karpathy) with a few modern optimizations, such as the ability to work with very small text sequences.

The included pretrained model follows a [neural network architecture](https://github.com/bfelbo/DeepMoji/blob/master/deepmoji/model_def.py) inspired by [DeepMoji](https://github.com/bfelbo/DeepMoji). For the default model, textgenrnn takes in an input of up to 40 characters, converts each character to a 100-D character embedding vector, and feeds those into a 128-cell long short-term memory (LSTM) recurrent layer. Those outputs are then fed into *another* 128-cell LSTM. All three layers are then fed into an Attention layer to weight the most important temporal features and average them together (and since the embeddings and the first LSTM are skip-connected into the Attention layer, model updates can backpropagate to them more easily, preventing vanishing gradients). That output is mapped to probabilities over up to [394 different characters](/textgenrnn/textgenrnn_vocab.json), including uppercase and lowercase letters, punctuation, and emoji, indicating how likely each is to be the next character in the sequence. (If training a new model on a new dataset, all of the numeric parameters above can be configured.)
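
To make that description concrete, here is a rough Keras sketch of the default architecture. This is not the code shipped with the package: the attention layer below is a simplified stand-in for the DeepMoji-style attention-weighted average, and the layer sizes simply mirror the numbers above.

```python
from keras import backend as K
from keras.layers import Dense, Embedding, Input, Layer, LSTM, concatenate
from keras.models import Model

MAX_LENGTH = 40      # characters of preceding context fed to the model
NUM_CLASSES = 394    # size of the character vocabulary


class AttentionWeightedAverage(Layer):
    """Simplified stand-in: learn a score per timestep, softmax the scores,
    and return the weighted average of the sequence."""

    def build(self, input_shape):
        self.attention_w = self.add_weight(name='attention_w',
                                           shape=(int(input_shape[-1]), 1),
                                           initializer='glorot_uniform',
                                           trainable=True)
        super(AttentionWeightedAverage, self).build(input_shape)

    def call(self, x):
        scores = K.squeeze(K.dot(x, self.attention_w), axis=-1)    # (batch, time)
        weights = K.softmax(scores)                                 # attention over timesteps
        return K.sum(x * K.expand_dims(weights, axis=-1), axis=1)   # (batch, features)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[-1])


inputs = Input(shape=(MAX_LENGTH,))
embedded = Embedding(NUM_CLASSES, 100, input_length=MAX_LENGTH)(inputs)
rnn_1 = LSTM(128, return_sequences=True)(embedded)
rnn_2 = LSTM(128, return_sequences=True)(rnn_1)

# Skip-connect the embeddings and the first LSTM alongside the second LSTM so
# gradients can reach the lower layers directly through the attention layer.
attention = AttentionWeightedAverage()(concatenate([rnn_1, rnn_2, embedded]))
outputs = Dense(NUM_CLASSES, activation='softmax')(attention)

model = Model(inputs=inputs, outputs=outputs)
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
```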

Alternatively, if context labels are provided with each text document, the model can be trained in a contextual mode, where the model learns the text *given the context* so the recurrent layers learn the *decontextualized* language. The text-only path can piggyback off the decontextualized layers; in all, this results in much faster training and better quantitative and qualitative model performance than training the model on the text alone.
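
For illustration, a contextual training call might look like the sketch below. It assumes the `context_labels` argument to `train_on_texts` (see the demo notebook for the exact interface), and the texts and labels are made-up examples.

```python
from textgenrnn import textgenrnn

textgen = textgenrnn()

# Hypothetical documents paired with one context label each (e.g. where the
# post came from). The model learns each text given its label, while the
# recurrent layers learn the decontextualized language shared across labels.
texts = ['Never gonna give you up, never gonna let you down',
         'Show HN: A Python module for generating text with RNNs']
context_labels = ['reddit', 'hacker-news']

textgen.train_on_texts(texts, context_labels=context_labels, num_epochs=5)
```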

The model weights included with the package are trained on hundreds of thousands of text documents from Reddit submissions ([via BigQuery](http://minimaxir.com/2015/10/reddit-bigquery/)), from a very *diverse* variety of subreddits. The network was also trained using the decontextualized approach noted above in order to both improve training performance and mitigate authorial bias.

When fine-tuning the model on a new dataset of texts using textgenrnn, all layers are retrained. However, since the original pretrained network has a much more robust "knowledge" initially, the new textgenrnn trains faster and more accurately in the end, and can potentially learn new relationships not present in the original dataset (e.g. the [pretrained character embeddings](http://minimaxir.com/2017/04/char-embeddings/) include the context of each character across all the types of modern internet grammar).

Additionally, the retraining is done with a momentum-based optimizer and a linearly decaying learning rate, both of which prevent exploding gradients and make it much less likely that the model diverges after training for a long time.
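
As a rough illustration of the decay schedule (not the package's exact internals), such a linearly decaying learning rate can be expressed as a Keras callback like the one below; the base learning rate and epoch count are placeholder values.

```python
from keras.callbacks import LearningRateScheduler

# Illustrative linear decay: the learning rate shrinks toward zero over
# num_epochs. The base_lr here is a placeholder, not textgenrnn's default.
base_lr, num_epochs = 4e-3, 10
lr_decay = LearningRateScheduler(lambda epoch: base_lr * (1.0 - epoch / num_epochs))

# The callback would then be passed to Keras via model.fit(..., callbacks=[lr_decay]).
```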

## Notes

* **You will not get quality generated text 100% of the time**, even with a heavily-trained neural network. That's the primary reason viral [blog posts](http://aiweirdness.com/post/170685749687/candy-heart-messages-written-by-a-neural-network)/[Twitter tweets](https://twitter.com/botnikstudios/status/955870327652970496) utilizing NN text generation often generate lots of texts and curate/edit the best ones afterward.

* **Results will vary greatly between datasets**. Because the pretrained neural network is relatively small, it cannot store as much data as the RNNs typically flaunted in blog posts. For best results, use a dataset with at least 2,000-5,000 documents. If a dataset is smaller, you'll need to train it for longer by setting `num_epochs` higher when calling a training method, and/or train a new model from scratch. Even then, there is currently no good heuristic for determining a "good" model.

* A GPU is not required to retrain textgenrnn, but it will take much longer to train on a CPU. If you do use a GPU, I recommend increasing the `batch_size` parameter for better hardware utilization, as in the sketch below.
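
The following is a rough example of adjusting both knobs on a training call; the values are arbitrary, and `batch_size` is assumed to be accepted by the training methods as in the demo notebook.

```python
from textgenrnn import textgenrnn

textgen = textgenrnn()

# Illustrative values only: raise num_epochs for small datasets and raise
# batch_size to keep a GPU busy; tune both for your own data and hardware.
textgen.train_from_file('hacker-news-2000.txt', num_epochs=20, batch_size=1024)
```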

## Future Plans for textgenrnn

* More formal documentation

* A web-based implementation using tensorflow.js (works especially well due to the network's small size)

* A way to visualize the attention-layer outputs to see how the network "learns."

* A supervised text generation mode: have the model present the top *n* options and let the user select the next char/word ([reference](https://fivethirtyeight.com/features/some-like-it-bot/))

* A mode to allow the model architecture to be used for chatbot conversations (may be released as a separate project)

* More depth toward context (positional context + allowing multiple context labels)

* A larger pretrained network which can accommodate longer character sequences and a more in-depth understanding of language, creating better generated sentences.

* Hierarchical softmax activation for word-level models (once Keras has good support for it).

* FP16 for superfast training on Volta/TPUs (once Keras has good support for it).

## Projects using textgenrnn

* [Tweet Generator](https://github.com/minimaxir/tweet-generator) — Train a neural network optimized for generating tweets based off of any number of Twitter users

## Maintainer/Creator

Max Woolf ([@minimaxir](http://minimaxir.com))

*Max's open-source projects are supported by his [Patreon](https://www.patreon.com/minimaxir). If you found this project helpful, any monetary contributions to the Patreon are appreciated and will be put to good creative use.*

## Credits

Andrej Karpathy for the original proposal of the char-rnn via the blog post [The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/).

## License

MIT

Attention-layer code used from [DeepMoji](https://github.com/bfelbo/DeepMoji) (MIT Licensed)