Skip to content

Commit

Permalink
Adding bytes compatibility to new_branch
Browse files Browse the repository at this point in the history
  • Loading branch information
noahnewberger committed Aug 7, 2020
1 parent 52ae2ba commit a2f0c59
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 6 deletions.
30 changes: 30 additions & 0 deletions punctuator/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,36 @@ def load(file_path, minibatch_size, x, p=None):
return net, (gsums, state["learning_rate"], state["validation_ppl_history"], state["epoch"], rng)


def loads(file_bytes, minibatch_size, x, p=None):
    """Rebuild a saved network from an in-memory pickle payload.

    Counterpart of a path-based loader for callers that already hold the
    serialized model as ``bytes``.

    SECURITY NOTE: the payload is unpickled directly. Unpickling can run
    arbitrary code, so only pass bytes that come from a trusted source.

    Args:
        file_bytes: Pickled state dictionary, as bytes.
        minibatch_size: Forwarded unchanged to the model constructor.
        x: Forwarded unchanged to the model constructor.
        p: Optional extra input forwarded to the model constructor.

    Returns:
        A pair ``(net, (gsums, learning_rate, validation_ppl_history,
        epoch, rng))``. ``gsums`` is a list of ``theano.shared`` values, or
        ``None`` when the stored ``"gsums"`` entry is empty/falsy.

    Raises:
        KeyError: If a required key is missing from the unpickled state, or
            if ``state["type"]`` does not name a module-level object.
    """
    state = cPickle.loads(file_bytes, **cpickle_options)

    model_name = state["type"]
    logging.info('Looking up %s.', model_name)
    # The model class is resolved by name among this module's globals.
    Model = globals()[model_name]

    # NOTE: ``rng`` is the ``np.random`` module object itself, so restoring
    # the saved state here affects every user of that module-level generator
    # (assuming ``np`` is NumPy -- confirm against the file's imports).
    rng = np.random
    rng.set_state(state["random_state"])

    constructor_kwargs = {
        "rng": rng,
        "x": x,
        "minibatch_size": minibatch_size,
        "n_hidden": state["n_hidden"],
        "x_vocabulary": state["x_vocabulary"],
        "y_vocabulary": state["y_vocabulary"],
        "stage1_model_file_name": state.get("stage1_model_file_name"),
        "p": p,
    }
    net = Model(**constructor_kwargs)

    # Copy the stored parameter values into the freshly built network,
    # pairing them positionally with the network's own parameters.
    for target, saved_value in zip(net.params, state["params"]):
        target.set_value(saved_value, borrow=True)

    saved_gsums = state["gsums"]
    if saved_gsums:
        gsums = [theano.shared(value) for value in saved_gsums]
    else:
        gsums = None

    training_state = (
        gsums,
        state["learning_rate"],
        state["validation_ppl_history"],
        state["epoch"],
        rng,
    )
    return net, training_state


class GRULayer:

def __init__(self, rng, n_in, n_out, minibatch_size):
Expand Down
14 changes: 10 additions & 4 deletions punctuator/punc.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ def restore(output_file, text, word_vocabulary, reverse_punctuation_vocabulary,
class Punctuator:

def model_exists(self, fn):
if isinstance(fn, bytes):
return fn
if os.path.isfile(fn):
return fn
_fn = os.path.join(PUNCTUATOR_DATA_DIR, fn)
Expand All @@ -186,16 +188,20 @@ def __init__(self, model_file, use_pauses=False):
p = T.matrix('p')

logging.info("Loading model parameters...")
net, _ = models.load(model_file, 1, x, p)

if isinstance(model_file, bytes):
net, _ = models.loads(model_file, 1, x, p)
else:
net, _ = models.load(model_file, 1, x, p)
logging.info("Building model...")
self.predict = theano.function(inputs=[x, p], outputs=net.y)

else:

logging.info("Loading model parameters...")
net, _ = models.load(model_file, 1, x)

if isinstance(model_file, bytes):
net, _ = models.loads(model_file, 1, x)
else:
net, _ = models.load(model_file, 1, x)
logging.info("Building model...")
self.predict = theano.function(inputs=[x], outputs=net.y)

Expand Down
10 changes: 8 additions & 2 deletions punctuator/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,18 @@ def test_punctuate(self):
if not os.path.isfile(model_file):
model_file = download_model()
print('Model file:', model_file)

# Check if file can be read in as bytes
infile = open(model_file, 'rb')
data = infile.read()
t0 = time.time()
p = Punctuator(data)
td = time.time() - t0
print('Loaded in %s seconds as bytes.' % td)
# Create punctuator.
t0 = time.time()
p = Punctuator(model_file=model_file)
td = time.time() - t0
print('Loaded in %s seconds.' % td)
print('Loaded in %s seconds from path.' % td)

# Add punctuation.
for input_text, expect_output_text in samples:
Expand Down

0 comments on commit a2f0c59

Please sign in to comment.