Skip to content

Commit

Permalink
Fixed import errors with six.moves.cPickle and model.train typo in th…
Browse files Browse the repository at this point in the history
…e skipgram embeddings example
  • Loading branch information
anjishnu committed Sep 5, 2015
1 parent 2e60c99 commit 0348223
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions examples/skipgram_word_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

import numpy as np
import theano
import six.moves.cPickle
from six.moves import cPickle
import os, re, json

from keras.preprocessing import sequence, text
Expand Down Expand Up @@ -90,7 +90,7 @@ def text_generator(path=data_path):
# model management
if load_tokenizer:
print('Load tokenizer...')
tokenizer = six.moves.cPickle.load(open(os.path.join(save_dir, tokenizer_fname), 'rb'))
tokenizer = cPickle.load(open(os.path.join(save_dir, tokenizer_fname), 'rb'))
else:
print("Fit tokenizer...")
tokenizer = text.Tokenizer(nb_words=max_features)
Expand All @@ -99,13 +99,13 @@ def text_generator(path=data_path):
print("Save tokenizer...")
if not os.path.exists(save_dir):
os.makedirs(save_dir)
six.moves.cPickle.dump(tokenizer, open(os.path.join(save_dir, tokenizer_fname), "wb"))
cPickle.dump(tokenizer, open(os.path.join(save_dir, tokenizer_fname), "wb"))

# training process
if train_model:
if load_model:
print('Load model...')
model = six.moves.cPickle.load(open(os.path.join(save_dir, model_load_fname), 'rb'))
model = cPickle.load(open(os.path.join(save_dir, model_load_fname), 'rb'))
else:
print('Build model...')
model = Sequential()
Expand All @@ -129,7 +129,7 @@ def text_generator(path=data_path):
if couples:
# one gradient update per sentence (one sentence = a few 1000s of word couples)
X = np.array(couples, dtype="int32")
loss = model.train(X, labels)
loss = model.fit(X, labels)
losses.append(loss)
if len(losses) % 100 == 0:
progbar.update(i, values=[("loss", np.mean(losses))])
Expand All @@ -142,7 +142,7 @@ def text_generator(path=data_path):
print("Saving model...")
if not os.path.exists(save_dir):
os.makedirs(save_dir)
six.moves.cPickle.dump(model, open(os.path.join(save_dir, model_save_fname), "wb"))
cPickle.dump(model, open(os.path.join(save_dir, model_save_fname), "wb"))


print("It's test time!")
Expand Down

0 comments on commit 0348223

Please sign in to comment.