Fix memory network
ejls committed Jul 24, 2015
1 parent 60e6bc6 commit 7dab7e4
Showing 5 changed files with 124 additions and 50 deletions.
54 changes: 54 additions & 0 deletions config/memory_network_mlp.py
@@ -0,0 +1,54 @@
from blocks.initialization import IsotropicGaussian, Constant

from blocks.bricks import Tanh

import data
from model.memory_network_mlp import Model, Stream

n_begin_end_pts = 5

dim_embeddings = [
('origin_call', data.origin_call_train_size, 10),
('origin_stand', data.stands_size, 10),
('week_of_year', 52, 10),
('day_of_week', 7, 10),
('qhour_of_day', 24 * 4, 10),
('day_type', 3, 10),
]

embed_weights_init = IsotropicGaussian(0.001)

class MLPConfig(object):
    __slots__ = ('dim_input', 'dim_hidden', 'dim_output', 'weights_init', 'biases_init', 'embed_weights_init', 'dim_embeddings')

prefix_encoder = MLPConfig()
prefix_encoder.dim_input = n_begin_end_pts * 2 * 2 + sum(x for (_, _, x) in dim_embeddings)
prefix_encoder.dim_hidden = [100, 100]
prefix_encoder.weights_init = IsotropicGaussian(0.01)
prefix_encoder.biases_init = Constant(0.001)
prefix_encoder.embed_weights_init = embed_weights_init
prefix_encoder.dim_embeddings = dim_embeddings

candidate_encoder = MLPConfig()
candidate_encoder.dim_input = n_begin_end_pts * 2 * 2 + sum(x for (_, _, x) in dim_embeddings)
candidate_encoder.dim_hidden = [100, 100]
candidate_encoder.weights_init = IsotropicGaussian(0.01)
candidate_encoder.biases_init = Constant(0.001)
candidate_encoder.embed_weights_init = embed_weights_init
candidate_encoder.dim_embeddings = dim_embeddings

representation_size = 100
representation_activation = Tanh

normalize_representation = True


batch_size = 32
batch_sort_size = 20

max_splits = 100
num_cuts = 1000

train_candidate_size = 100
valid_candidate_size = 100
test_candidate_size = 100
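For reference, the `dim_input` expression above works out to 80 for both encoders: 5 begin/end points times 2 coordinates times 2 ends contribute 20 inputs, and the six 10-dimensional embeddings contribute 60 more. A minimal sketch re-deriving the same arithmetic (table sizes elided, since only the embedding widths matter here):

```python
n_begin_end_pts = 5
dim_embeddings = [
    ('origin_call', None, 10),
    ('origin_stand', None, 10),
    ('week_of_year', 52, 10),
    ('day_of_week', 7, 10),
    ('qhour_of_day', 24 * 4, 10),
    ('day_type', 3, 10),
]
coord_inputs = n_begin_end_pts * 2 * 2                 # 5 pts x (lat, lon) x (first, last) = 20
embed_inputs = sum(x for (_, _, x) in dim_embeddings)  # 6 embeddings x 10 = 60
assert coord_inputs + embed_inputs == 80               # dim_input for both encoders
```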
17 changes: 2 additions & 15 deletions data/transformers.py
@@ -6,7 +6,7 @@
import fuel

from fuel.schemes import ConstantScheme
-from fuel.transformers import Batch, Mapping, SortMapping, Transformer, Unpack
+from fuel.transformers import Batch, Mapping, SortMapping, Transformer, Unpack, FilterSources

import data

@@ -22,20 +22,7 @@ def at_least_k(k, v, pad_at_begin, is_longitude):
v = numpy.concatenate((v, numpy.full((k - len(v),), v[-1])))
return v


-class Select(Transformer):
-    produces_examples = True
-
-    def __init__(self, data_stream, sources):
-        super(Select, self).__init__(data_stream)
-        self.ids = [data_stream.sources.index(source) for source in sources]
-        self.sources=sources
-
-    def get_data(self, request=None):
-        if request is not None:
-            raise ValueError
-        data=next(self.child_epoch_iterator)
-        return [data[id] for id in self.ids]
+Select = FilterSources

class TaxiExcludeTrips(Transformer):
produces_examples = True
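The hand-rolled `Select` transformer duplicated functionality that fuel already ships, so aliasing it to `FilterSources` keeps the existing call sites unchanged. A minimal usage sketch, on toy data rather than anything from this repo (assuming the stock fuel `IterableDataset`/`DataStream` API):

```python
from collections import OrderedDict

from fuel.datasets import IterableDataset
from fuel.streams import DataStream
from fuel.transformers import FilterSources

dataset = IterableDataset(OrderedDict([
    ('trip_id', [0, 1]),
    ('latitude', [41.15, 41.16]),
    ('longitude', [-8.61, -8.62]),
]))

# Keep only the coordinate sources, exactly what Select was used for.
stream = FilterSources(DataStream(dataset), sources=('latitude', 'longitude'))
for example in stream.get_epoch_iterator():
    print(example)  # (41.15, -8.61) then (41.16, -8.62)
```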
58 changes: 44 additions & 14 deletions model/memory_network.py
@@ -25,12 +25,13 @@ def __init__(self, config, prefix_encoder, candidate_encoder, **kwargs):
self.children = [ self.softmax, prefix_encoder, candidate_encoder ]

self.inputs = self.prefix_encoder.apply.inputs \
-            + ['candidate_%s'%x for x in self.candidate_encoder.apply.inputs]
+            + ['candidate_%s'%x for x in self.candidate_encoder.apply.inputs] \
+            + ['candidate_destination_latitude', 'candidate_destination_longitude']

-    def candidate_destination(**kwargs):
+    def candidate_destination(self, **kwargs):
        return tensor.concatenate(
-            (tensor.shape_padright(kwargs['candidate_last_k_latitude'][:,-1]),
-             tensor.shape_padright(kwargs['candidate_last_k_longitude'][:,-1])),
+            (tensor.shape_padright(kwargs['candidate_destination_latitude']),
+             tensor.shape_padright(kwargs['candidate_destination_longitude'])),
axis=1)

@application(outputs=['cost'])
@@ -43,10 +44,8 @@ def cost(self, **kwargs):

@application(outputs=['destination'])
def predict(self, **kwargs):
-        prefix_representation = self.prefix_encoder.apply(
-            { x: kwargs[x] for x in self.prefix_encoder.apply.inputs })
-        candidate_representatin = self.candidate_encoder.apply(
-            { x: kwargs['candidate_'+x] for x in self.candidate_encoder.apply.inputs })
+        prefix_representation = self.prefix_encoder.apply(**{ x: kwargs[x] for x in self.prefix_encoder.apply.inputs })
+        candidate_representation = self.candidate_encoder.apply(**{ x: kwargs['candidate_'+x] for x in self.candidate_encoder.apply.inputs })

if self.config.normalize_representation:
prefix_representation = prefix_representation \
@@ -130,12 +129,16 @@ def __init__(self, config):

def candidate_stream(self, n_candidates):
candidate_stream = DataStream(self.train_dataset,
-            iteration_scheme=ShuffledExampleScheme(dataset.num_examples))
-        candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, self.valid_trips_ids)
+            iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples))
+        if not data.tvt:
+            candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, self.valid_trips_ids)
candidate_stream = transformers.TaxiExcludeEmptyTrips(candidate_stream)
candidate_stream = transformers.taxi_add_datetime(candidate_stream)
candidate_stream = transformers.taxi_add_first_last_len(candidate_stream,
self.config.n_begin_end_pts)
+        if not data.tvt:
+            candidate_stream = transformers.add_destination(candidate_stream)
+
return Batch(candidate_stream,
iteration_scheme=ConstantScheme(n_candidates))

@@ -180,6 +183,27 @@ def valid(self, req_vars):
stream = MultiProcessing(stream)
return stream

+    def test(self, req_vars):
+        prefix_stream = DataStream(
+            self.test_dataset,
+            iteration_scheme=SequentialExampleScheme(self.test_dataset.num_examples))
+        prefix_stream = transformers.taxi_add_datetime(prefix_stream)
+        prefix_stream = transformers.taxi_add_first_last_len(prefix_stream,
+                                                             self.config.n_begin_end_pts)
+
+        if not data.tvt:
+            prefix_stream = transformers.taxi_remove_test_only_clients(prefix_stream)
+
+        prefix_stream = Batch(prefix_stream,
+                              iteration_scheme=ConstantScheme(self.config.batch_size))
+
+        candidate_stream = self.candidate_stream(self.config.test_candidate_size)
+
+        sources = prefix_stream.sources + tuple('candidate_%s' % k for k in candidate_stream.sources)
+        stream = Merge((prefix_stream, candidate_stream), sources)
+        stream = transformers.Select(stream, tuple(req_vars))
+        stream = MultiProcessing(stream)
+        return stream

class StreamRecurrent(StreamBase):
def __init__(self, config):
@@ -194,10 +218,14 @@ def __init__(self, config):
def candidate_stream(self, n_candidates):
candidate_stream = DataStream(self.train_dataset,
iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples))
-        candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, self.valid_trips_ids)
+        if not data.tvt:
+            candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, self.valid_trips_ids)
candidate_stream = transformers.TaxiExcludeEmptyTrips(candidate_stream)
candidate_stream = transformers.taxi_add_datetime(candidate_stream)

+        if not data.tvt:
+            candidate_stream = transformers.add_destination(candidate_stream)
+
candidate_stream = Batch(candidate_stream,
iteration_scheme=ConstantScheme(n_candidates))

@@ -210,7 +238,8 @@ def train(self, req_vars):
prefix_stream = DataStream(self.train_dataset,
iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples))

-        prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids)
+        if not data.tvt:
+            prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, self.valid_trips_ids)
prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream)
prefix_stream = transformers.TaxiGenerateSplits(prefix_stream,
max_splits=self.config.max_splits)
@@ -238,7 +267,7 @@ def valid(self, req_vars):
self.valid_dataset,
iteration_scheme=SequentialExampleScheme(self.valid_dataset.num_examples))

-        prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream)
+        #prefix_stream = transformers.TaxiExcludeEmptyTrips(prefix_stream)

prefix_stream = transformers.taxi_add_datetime(prefix_stream)

@@ -262,7 +291,8 @@ def test(self, req_vars):
iteration_scheme=SequentialExampleScheme(self.test_dataset.num_examples))

prefix_stream = transformers.taxi_add_datetime(prefix_stream)
-        prefix_stream = transformers.taxi_remove_test_only_clients(prefix_stream)
+        if not data.tvt:
+            prefix_stream = transformers.taxi_remove_test_only_clients(prefix_stream)

prefix_stream = Batch(prefix_stream,
iteration_scheme=ConstantScheme(self.config.batch_size))
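One of the changes above rewrites `candidate_destination` to read the explicit `candidate_destination_latitude`/`candidate_destination_longitude` sources instead of slicing the last trajectory point. The shape manipulation is simple: two flat coordinate vectors are padded to column vectors and concatenated into an `(n_candidates, 2)` matrix. A numpy sketch of the equivalent computation, with illustrative values:

```python
import numpy

# Equivalent of tensor.shape_padright + tensor.concatenate(axis=1):
lat = numpy.array([41.15, 41.16, 41.17])   # candidate_destination_latitude
lon = numpy.array([-8.61, -8.62, -8.63])   # candidate_destination_longitude
dest = numpy.concatenate((lat[:, None], lon[:, None]), axis=1)
assert dest.shape == (3, 2)                # one (lat, lon) row per candidate
```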
17 changes: 10 additions & 7 deletions model/memory_network_bidir.py
@@ -72,22 +72,25 @@ def apply(self, latitude, longitude, latitude_mask, **kwargs):

return outputs

+    @apply.property('inputs')
+    def apply_inputs(self):
+        return self.inputs


class Model(MemoryNetworkBase):
def __init__(self, config, **kwargs):

# Build prefix encoder : recurrent then MLP
-        prefix_encoder = RecurrentEncoder(self.config.prefix_encoder,
-                                          self.config.representation_size,
-                                          self.config.representation_activation(),
+        prefix_encoder = RecurrentEncoder(config.prefix_encoder,
+                                          config.representation_size,
+                                          config.representation_activation(),
name='prefix_encoder')

# Build candidate encoder
-        candidate_encoder = RecurrentEncoder(self.config.candidate_encoder,
-                                             self.config.representation_size,
-                                             self.config.representation_activation(),
+        candidate_encoder = RecurrentEncoder(config.candidate_encoder,
+                                             config.representation_size,
+                                             config.representation_activation(),
name='candidate_encoder')

# And... that's it!
super(Model, self).__init__(config, prefix_encoder, candidate_encoder, **kwargs)
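The `self.config.*` to `config.*` change above matters because `self.config` is only assigned once `MemoryNetworkBase.__init__` runs, which happens after the encoders are built. A miniature sketch of the bug, with illustrative class names rather than the repo's own:

```python
class Base(object):
    def __init__(self, config):
        self.config = config  # the attribute exists only from here on

class Broken(Base):
    def __init__(self, config):
        size = self.config.representation_size  # AttributeError: no self.config yet
        super(Broken, self).__init__(config)

class Fixed(Base):
    def __init__(self, config):
        size = config.representation_size       # use the argument directly
        super(Fixed, self).__init__(config)
```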

28 changes: 14 additions & 14 deletions model/memory_network_mlp.py
@@ -18,17 +18,17 @@

class MLPEncoder(Initializable):
def __init__(self, config, output_dim, activation, **kwargs):
-        super(RecurrentEncoder, self).__init__(**kwargs)
+        super(MLPEncoder, self).__init__(**kwargs)

self.config = config
self.context_embedder = ContextEmbedder(self.config)

-        self.encoder_mlp = MLP(activations=[Rectifier() for _ in config.prefix_encoder.dim_hidden]
-                                           + [config.representation_activation()],
-                               dims=[config.prefix_encoder.dim_input]
-                                    + config.prefix_encoder.dim_hidden
-                                    + [config.representation_size],
-                               name='prefix_encoder')
+        self.encoder_mlp = MLP(activations=[Rectifier() for _ in config.dim_hidden]
+                                           + [activation()],
+                               dims=[config.dim_input]
+                                    + config.dim_hidden
+                                    + [output_dim],
+                               name='encoder')

self.extremities = {'%s_k_%s' % (side, ['latitude', 'longitude'][axis]): axis
for side in ['first', 'last'] for axis in [0, 1]}
@@ -37,7 +37,7 @@ def __init__(self, config, output_dim, activation, **kwargs):
self.encoder_mlp ]

def _push_initialization_config(self):
-        for brick in [self.contex_encoder, self.encoder_mlp]:
+        for brick in [self.context_embedder, self.encoder_mlp]:
brick.weights_init = self.config.weights_init
brick.biases_init = self.config.biases_init

@@ -46,7 +46,7 @@ def apply(self, **kwargs):
embeddings = tuple(self.context_embedder.apply(
**{k: kwargs[k] for k in self.context_embedder.inputs }))
extremities = tuple((kwargs[k] - data.train_gps_mean[v]) / data.train_gps_std[v]
-                            for k, v in self.prefix_extremities.items())
+                            for k, v in self.extremities.items())
inputs = tensor.concatenate(extremities + embeddings, axis=1)

return self.encoder_mlp.apply(inputs)
@@ -60,12 +60,12 @@ class Model(MemoryNetworkBase):
def __init__(self, config, **kwargs):
prefix_encoder = MLPEncoder(config.prefix_encoder,
config.representation_size,
-                                    config.representation_activation())
+                                    config.representation_activation,
+                                    name='prefix_encoder')

-        candidate_encoer = MLPEncoder(config.candidate_encoder,
+        candidate_encoder = MLPEncoder(config.candidate_encoder,
config.representation_size,
-                                       config.representation_activation())
+                                       config.representation_activation,
+                                       name='candidate_encoder')

super(Model, self).__init__(config, prefix_encoder, candidate_encoder, **kwargs)
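With the values from config/memory_network_mlp.py above, the MLP each fixed `MLPEncoder` builds is a [80, 100, 100, 100] stack: one Rectifier per hidden layer, then the Tanh representation activation. A sketch of the equivalent blocks brick, with dims assuming that config:

```python
from blocks.bricks import MLP, Rectifier, Tanh

# dims = [dim_input] + dim_hidden + [output_dim] = [80, 100, 100, 100]
encoder_mlp = MLP(activations=[Rectifier(), Rectifier(), Tanh()],
                  dims=[80, 100, 100, 100],
                  name='encoder')
```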

