Commit

rnns
sjchoi86 committed Jul 19, 2016
1 parent ccddc0b commit 67043a2
Showing 7 changed files with 1,502 additions and 6,390 deletions.
Binary file added notebooks/Hangulpy.pyc
Binary file not shown.
Binary file added notebooks/TextLoader.pyc
Binary file not shown.
145 changes: 81 additions & 64 deletions notebooks/char_rnn_sample_hangul.ipynb
100644 → 100755

Large diffs are not rendered by default.

159 changes: 123 additions & 36 deletions notebooks/char_rnn_train_hangul.ipynb
100644 → 100755
@@ -1,5 +1,12 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train Hangul RNN"
]
},
{
"cell_type": "code",
"execution_count": 1,
@@ -30,6 +37,13 @@
"print (\"Packages Imported\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load dataset using TextLoader"
]
},
{
"cell_type": "code",
"execution_count": 2,
@@ -42,19 +56,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"reading text file\n",
"type of 'data_loader' is <type 'dict'>, length is 76\n",
"\n",
"\n",
"data_loader.vocab looks like \n",
"{u'_': 69, u'6': 59, u':': 57, u'\\n': 19, u'4': 67, u'5': 63, u'>': 75, u'!': 52, u' ': 1, u'\"': 28, u'\\u1d25': 0, u\"'\": 49, u')': 46, u'(': 45, u'-': 65, u',': 27, u'.': 24, u'\\u3131': 7, u'0': 73, u'\\u3133': 60, u'\\u3132': 29, u'\\u3135': 50, u'\\u3134': 4, u'\\u3137': 13, u'\\u3136': 44, u'\\u3139': 5, u'\\u3138': 32, u'\\u313b': 55, u'\\u313a': 48, u'\\u313c': 54, u'?': 41, u'3': 66, u'\\u3141': 12, u'\\u3140': 51, u'\\u3143': 47, u'\\u3142': 17, u'\\u3145': 10, u'\\u3144': 43, u'\\u3147': 2, u'\\u3146': 22, u'\\u3149': 40, u'\\u3148': 15, u'\\u314b': 42, u'\\u314a': 23, u'\\u314d': 31, u'\\u314c': 30, u'\\u314f': 3, u'\\u314e': 14, u'\\u3151': 34, u'\\u3150': 21, u'\\u3153': 11, u'\\u3152': 74, u'\\u3155': 18, u'\\u3154': 20, u'\\u3157': 9, u'\\u3156': 39, u'\\u3159': 53, u'\\u3158': 26, u'\\u315b': 38, u'\\u315a': 33, u'\\u315d': 36, u'\\u315c': 16, u'\\u315f': 35, u'\\u315e': 61, u'\\u3161': 8, u'\\u3160': 37, u'\\u3163': 6, u'\\u3162': 25, u'\\x1a': 72, u'9': 64, u'7': 71, u'2': 62, u'1': 58, u'\\u313f': 56, u'\\u313e': 70, u'8': 68} \n",
"\n",
"\n",
"type of 'data_loader.chars' is <type 'tuple'>, length is 76\n",
"\n",
"\n",
"data_loader.chars looks like \n",
"(u'\\u1d25', u' ', u'\\u3147', u'\\u314f', u'\\u3134', u'\\u3139', u'\\u3163', u'\\u3131', u'\\u3161', u'\\u3157', u'\\u3145', u'\\u3153', u'\\u3141', u'\\u3137', u'\\u314e', u'\\u3148', u'\\u315c', u'\\u3142', u'\\u3155', u'\\n', u'\\u3154', u'\\u3150', u'\\u3146', u'\\u314a', u'.', u'\\u3162', u'\\u3158', u',', u'\"', u'\\u3132', u'\\u314c', u'\\u314d', u'\\u3138', u'\\u315a', u'\\u3151', u'\\u315f', u'\\u315d', u'\\u3160', u'\\u315b', u'\\u3156', u'\\u3149', u'?', u'\\u314b', u'\\u3144', u'\\u3136', u'(', u')', u'\\u3143', u'\\u313a', u\"'\", u'\\u3135', u'\\u3140', u'!', u'\\u3159', u'\\u313c', u'\\u313b', u'\\u313f', u':', u'1', u'6', u'\\u3133', u'\\u315e', u'2', u'5', u'9', u'-', u'3', u'4', u'8', u'_', u'\\u313e', u'7', u'\\x1a', u'0', u'\\u3152', u'>') \n"
"loading preprocessed files\n",
"type of 'data_loader' is <type 'dict'>, length is 76\n"
]
}
],
@@ -68,18 +71,77 @@
"vocab_size = data_loader.vocab_size\n",
"vocab = data_loader.vocab\n",
"chars = data_loader.chars\n",
"print ( \"type of 'data_loader' is %s, length is %d\" % (type(data_loader.vocab), len(data_loader.vocab)) )\n",
"print ( \"\\n\" )\n",
"print (\"data_loader.vocab looks like \\n%s \" % (data_loader.vocab))\n",
"print ( \"\\n\" )\n",
"print ( \"type of 'data_loader.chars' is %s, length is %d\" % (type(data_loader.chars), len(data_loader.chars)) )\n",
"print ( \"\\n\" )\n",
"print ( \"type of 'data_loader' is %s, length is %d\" \n",
" % (type(data_loader.vocab), len(data_loader.vocab)) )"
]
},
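TextLoader is the repository's own helper; this commit only adds the compiled notebooks/TextLoader.pyc, so its source is not visible in the diff. The cell above just reports that it loads preprocessed files and exposes a 76-symbol vocabulary. A rough, hypothetical sketch of the kind of character-level preprocessing such a loader typically performs (the function name and return values below are assumptions, not the repo's code):

import collections

def build_char_vocab(text):
    # Count characters and order them by frequency (most common first);
    # each character's id is its index into the resulting chars tuple.
    counts = collections.Counter(text)
    chars = tuple(ch for ch, _ in counts.most_common())
    vocab = {ch: i for i, ch in enumerate(chars)}
    data = [vocab[ch] for ch in text]          # corpus as a sequence of ids
    return chars, vocab, data

chars, vocab, data = build_char_vocab(u"dummy jamo-decomposed corpus text")
print(vocab)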
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# data_loader.vocab"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"data_loader.vocab looks like \n",
"{u'_': 69, u'6': 59, u':': 57, u'\\n': 19, u'4': 67, u'5': 63, u'>': 75, u'!': 52, u' ': 1, u'\"': 28, u'\\u1d25': 0, u\"'\": 49, u')': 46, u'(': 45, u'-': 65, u',': 27, u'.': 24, u'\\u3131': 7, u'0': 73, u'\\u3133': 60, u'\\u3132': 29, u'\\u3135': 50, u'\\u3134': 4, u'\\u3137': 13, u'\\u3136': 44, u'\\u3139': 5, u'\\u3138': 32, u'\\u313b': 55, u'\\u313a': 48, u'\\u313c': 54, u'?': 41, u'3': 66, u'\\u3141': 12, u'\\u3140': 51, u'\\u3143': 47, u'\\u3142': 17, u'\\u3145': 10, u'\\u3144': 43, u'\\u3147': 2, u'\\u3146': 22, u'\\u3149': 40, u'\\u3148': 15, u'\\u314b': 42, u'\\u314a': 23, u'\\u314d': 31, u'\\u314c': 30, u'\\u314f': 3, u'\\u314e': 14, u'\\u3151': 34, u'\\u3150': 21, u'\\u3153': 11, u'\\u3152': 74, u'\\u3155': 18, u'\\u3154': 20, u'\\u3157': 9, u'\\u3156': 39, u'\\u3159': 53, u'\\u3158': 26, u'\\u315b': 38, u'\\u315a': 33, u'\\u315d': 36, u'\\u315c': 16, u'\\u315f': 35, u'\\u315e': 61, u'\\u3161': 8, u'\\u3160': 37, u'\\u3163': 6, u'\\u3162': 25, u'\\x1a': 72, u'9': 64, u'7': 71, u'2': 62, u'1': 58, u'\\u313f': 56, u'\\u313e': 70, u'8': 68} \n"
]
}
],
"source": [
"print (\"data_loader.vocab looks like \\n%s \" % (data_loader.vocab))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# data_loader.chars"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"type of 'data_loader.chars' is <type 'tuple'>, length is 76\n",
"data_loader.chars looks like \n",
"(u'\\u1d25', u' ', u'\\u3147', u'\\u314f', u'\\u3134', u'\\u3139', u'\\u3163', u'\\u3131', u'\\u3161', u'\\u3157', u'\\u3145', u'\\u3153', u'\\u3141', u'\\u3137', u'\\u314e', u'\\u3148', u'\\u315c', u'\\u3142', u'\\u3155', u'\\n', u'\\u3154', u'\\u3150', u'\\u3146', u'\\u314a', u'.', u'\\u3162', u'\\u3158', u',', u'\"', u'\\u3132', u'\\u314c', u'\\u314d', u'\\u3138', u'\\u315a', u'\\u3151', u'\\u315f', u'\\u315d', u'\\u3160', u'\\u315b', u'\\u3156', u'\\u3149', u'?', u'\\u314b', u'\\u3144', u'\\u3136', u'(', u')', u'\\u3143', u'\\u313a', u\"'\", u'\\u3135', u'\\u3140', u'!', u'\\u3159', u'\\u313c', u'\\u313b', u'\\u313f', u':', u'1', u'6', u'\\u3133', u'\\u315e', u'2', u'5', u'9', u'-', u'3', u'4', u'8', u'_', u'\\u313e', u'7', u'\\x1a', u'0', u'\\u3152', u'>') \n"
]
}
],
"source": [
"print ( \"type of 'data_loader.chars' is %s, length is %d\" \n",
" % (type(data_loader.chars), len(data_loader.chars)) )\n",
"print (\"data_loader.chars looks like \\n%s \" % (data_loader.chars,))"
]
},
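As the two printouts above suggest, vocab maps each character to an integer id while chars is the inverse lookup (id to character). A quick consistency check, assuming the data_loader objects from the cells above are in scope:

# chars[i] should be the character whose id is i; vocab is the inverse mapping.
assert all(chars[idx] == ch for ch, idx in vocab.items())
print("vocab and chars agree on all %d symbols" % len(vocab))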
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Define network"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {
"collapsed": false
},
Expand All @@ -88,7 +150,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Network Ready\n"
"Network ready\n"
]
}
],
@@ -117,7 +179,33 @@
" inputs = tf.split(1, seq_length, tf.nn.embedding_lookup(embedding\n",
" , input_data))\n",
" inputs = [tf.squeeze(input_, [1]) for input_ in inputs]\n",
"# Loop function for seq2seq\n",
"print (\"Network ready\")"
]
},
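The fragment above embeds the integer-coded input and then applies tf.split followed by tf.squeeze, turning a [batch_size, seq_length, rnn_size] tensor into a Python list of seq_length tensors of shape [batch_size, rnn_size], the list-of-steps layout the 2016-era tf.nn.seq2seq decoders consumed. A framework-free NumPy sketch of the same reshaping (the sizes below are illustrative assumptions, since the notebook's actual settings are collapsed in this diff):

import numpy as np

batch_size, seq_length, rnn_size = 50, 50, 128      # assumed values
embedded = np.zeros((batch_size, seq_length, rnn_size))

# Split along the time axis and drop the singleton dimension, giving a list
# of seq_length arrays shaped (batch_size, rnn_size).
inputs = [np.squeeze(step, axis=1) for step in np.split(embedded, seq_length, axis=1)]
print(len(inputs), inputs[0].shape)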
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Define functions"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Network Ready\n"
]
}
],
"source": [
"# Loop function for seq2seq (not used)\n",
"def loop(prev, _):\n",
" prev = tf.nn.xw_plus_b(prev, softmax_w, softmax_b)\n",
" prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))\n",
@@ -146,30 +234,29 @@
"print (\"Network Ready\")"
]
},
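The loop function defined above (marked "not used" during training) implements greedy feedback: it projects the previous output to vocabulary logits with softmax_w/softmax_b, stops the gradient, and takes the argmax as the next symbol; the rest of the function is collapsed in this view. A toy NumPy illustration of that idea, with random placeholder weights rather than the notebook's graph:

import numpy as np

def greedy_loop(prev_output, softmax_w, softmax_b):
    # Project the previous RNN output to vocabulary logits and feed the
    # most likely symbol back in as the next input (greedy decoding).
    logits = np.dot(prev_output, softmax_w) + softmax_b
    return int(np.argmax(logits))

rng = np.random.RandomState(0)
rnn_size, vocab_size = 128, 76
print(greedy_loop(rng.randn(rnn_size), rng.randn(rnn_size, vocab_size), rng.randn(vocab_size)))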
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Run"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0/9950 (epoch: 0), loss: 4.893, time/batch: 8.269\n",
"model saved to /tmp/tf_logs/char_rnn_hangul/model.ckpt\n"
]
}
],
"outputs": [],
"source": [
"# Train the model!\n",
"num_epochs = 50\n",
"num_epochs = 500\n",
"save_every = 1000\n",
"learning_rate = 0.0002\n",
"decay_rate = 0.97\n",
"\n",
"save_dir = '/tmp/tf_logs/char_rnn_hangul'\n",
"save_dir = 'data/nine_dreams'\n",
"sess = tf.Session()\n",
"sess.run(tf.initialize_all_variables())\n",
"summary_writer = tf.train.SummaryWriter(save_dir\n",
@@ -229,7 +316,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.11"
"version": "2.7.6"
}
},
"nbformat": 4,