From 034822359d9a4c3b2fc63de0676bad512b426112 Mon Sep 17 00:00:00 2001 From: Anjishnu Kumar Date: Sat, 5 Sep 2015 13:36:52 -0700 Subject: [PATCH] Fixed import errors with six.moves.cPickle and model.train typo in the skipgram embeddings example --- examples/skipgram_word_embeddings.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/skipgram_word_embeddings.py b/examples/skipgram_word_embeddings.py index 5da77159828..bfdf86d8b35 100644 --- a/examples/skipgram_word_embeddings.py +++ b/examples/skipgram_word_embeddings.py @@ -32,7 +32,7 @@ import numpy as np import theano -import six.moves.cPickle +from six.moves import cPickle import os, re, json from keras.preprocessing import sequence, text @@ -90,7 +90,7 @@ def text_generator(path=data_path): # model management if load_tokenizer: print('Load tokenizer...') - tokenizer = six.moves.cPickle.load(open(os.path.join(save_dir, tokenizer_fname), 'rb')) + tokenizer = cPickle.load(open(os.path.join(save_dir, tokenizer_fname), 'rb')) else: print("Fit tokenizer...") tokenizer = text.Tokenizer(nb_words=max_features) @@ -99,13 +99,13 @@ def text_generator(path=data_path): print("Save tokenizer...") if not os.path.exists(save_dir): os.makedirs(save_dir) - six.moves.cPickle.dump(tokenizer, open(os.path.join(save_dir, tokenizer_fname), "wb")) + cPickle.dump(tokenizer, open(os.path.join(save_dir, tokenizer_fname), "wb")) # training process if train_model: if load_model: print('Load model...') - model = six.moves.cPickle.load(open(os.path.join(save_dir, model_load_fname), 'rb')) + model = cPickle.load(open(os.path.join(save_dir, model_load_fname), 'rb')) else: print('Build model...') model = Sequential() @@ -129,7 +129,7 @@ def text_generator(path=data_path): if couples: # one gradient update per sentence (one sentence = a few 1000s of word couples) X = np.array(couples, dtype="int32") - loss = model.train(X, labels) + loss = model.fit(X, labels) losses.append(loss) if len(losses) % 100 == 0: progbar.update(i, values=[("loss", np.mean(losses))]) @@ -142,7 +142,7 @@ def text_generator(path=data_path): print("Saving model...") if not os.path.exists(save_dir): os.makedirs(save_dir) - six.moves.cPickle.dump(model, open(os.path.join(save_dir, model_save_fname), "wb")) + cPickle.dump(model, open(os.path.join(save_dir, model_save_fname), "wb")) print("It's test time!")