Skip to content

Commit

Permalink
Cakechat refactoring (lukalabs#23)
Browse files Browse the repository at this point in the history
CakeChat refactoring
  • Loading branch information
Nicolas authored Jul 16, 2018
1 parent ec68708 commit 1efee48
Show file tree
Hide file tree
Showing 34 changed files with 530 additions and 448 deletions.
28 changes: 19 additions & 9 deletions cakechat/config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os

from cakechat.utils.env import is_dev_env
from cakechat.utils.data_structures import create_namedtuple_instance
from cakechat.utils.env import is_dev_env

RANDOM_SEED = 42 # Fix the random seed to a certain value to make everything reproducable
RANDOM_SEED = 42 # Fix the random seed to a certain value to make everything reproducible

# AWS S3 params
S3_MODELS_BUCKET_NAME = 'cake-chat-data' # S3 bucket with all the data
Expand All @@ -15,6 +15,7 @@
# data params
DATA_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data') # Directory to store all the data
# e.g. datasets, models, indices
NN_MODELS_DIR = os.path.join(DATA_DIR, 'nn_models') # Path to a directory for saving and restoring dialog models
PROCESSED_CORPUS_DIR = os.path.join(DATA_DIR, 'corpora_processed') # Path to a processed corpora datasets
TOKEN_INDEX_DIR = os.path.join(DATA_DIR, 'tokens_index') # Path to a prepared tokens index file
CONDITION_IDS_INDEX_DIR = os.path.join(DATA_DIR, 'conditions_index') # Path to a prepared conditions index file
Expand All @@ -24,12 +25,13 @@
TRAIN_CORPUS_NAME = 'train_' + BASE_CORPUS_NAME # Corpus name prefix for the training dataset
CONTEXT_SENSITIVE_VAL_CORPUS_NAME = 'val_' + BASE_CORPUS_NAME # Corpus name prefix for the validation dataset

VAL_SUBSET_SIZE = 250 # Subset from the validation dataset to be used in validation metrics calculation
MAX_VAL_LINES_NUM = 10000 # Max lines number from validation set to be used for metrics calculation
VAL_SUBSET_SIZE = 250 # Subset from the validation dataset to be used to calculate some validation metrics
TRAIN_SUBSET_SIZE = int(os.environ['SLICE_TRAINSET']) if 'SLICE_TRAINSET' in os.environ else None # Subset from the
# training dataset to be used during the training. In case of None use all lines in the train dataset (default behavior)

# test data paths
TEST_DATA_DIR = os.path.join(DATA_DIR, 'quality') # Path to datasets for quality metrics calculation
TEST_DATA_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data', 'quality')
CONTEXT_FREE_VAL_CORPUS_NAME = 'context_free_validation_set' # Context-free validation set path
TEST_CORPUS_NAME = 'context_free_test_set' # Context-free test set path
QUESTIONS_CORPUS_NAME = 'context_free_questions' # Context-free questions only path
Expand Down Expand Up @@ -61,9 +63,12 @@
OUTPUT_SEQUENCE_LENGTH = 32 # Output sequence length. Better to keep as INPUT_SEQUENCE_LENGTH+2 for start/end tokens
BATCH_SIZE = 192 # Default batch size which fits into 8GB of GPU memory
SHUFFLE_TRAINING_BATCHES = True # Shuffle training batches in the dataset each epoch
EPOCHES_NUM = 100 # Total epochs num
EPOCHS_NUM = 100 # Total epochs num
GRAD_CLIP = 5.0 # Gradient clipping passed into theano.gradient.grad_clip()
ADADELTA_LEARNING_RATE = 1.0 # Initial AdaDelta learning rate
LEARNING_RATE = 1.0 # Learning rate for the chosen optimizer (currently using Adadelta, see model.py)

# model params
NN_MODEL_PREFIX = 'cakechat' # Specify prefix to be prepended to model's name

# predictions params
MAX_PREDICTIONS_LENGTH = 40 # Max. number of tokens which can be generated on the prediction step
Expand Down Expand Up @@ -94,8 +99,10 @@
LOG_CANDIDATES_NUM = 10 # Number of candidates to be printed to output during the logging
SCREEN_LOG_NUM_TEST_LINES = 10 # Number of first test lines to use when logging outputs on screen
SCREEN_LOG_FREQUENCY_PER_BATCHES = 500 # How many batches to train until next logging of output on screen
LOG_FREQUENCY_PER_BATCHES = 2500 # How many batches to train until next logging of all the output into file
LOG_LOSS_DECAY = 0.99 # Decay for the averaging the loss which is printed in logs
LOG_TO_TB_FREQUENCY_PER_BATCHES = 500 # How many batches to train until the next metrics are computed for TensorBoard
LOG_TO_FILE_FREQUENCY_PER_BATCHES = 2500 # How many batches to train until next logging of all the output into file
SAVE_MODEL_FREQUENCY_PER_BATCHES = 2500 # How many batches to train until the next saving of the model
AVG_LOSS_DECAY = 0.99 # Decay for averaging the loss

# Use reduced sizes for input/output sequences, hidden layers and datasets sizes for the 'Developer Mode'
if is_dev_env():
Expand All @@ -105,10 +112,13 @@
BATCH_SIZE = 128
HIDDEN_LAYER_DIMENSION = 7
SCREEN_LOG_FREQUENCY_PER_BATCHES = 2
LOG_FREQUENCY_PER_BATCHES = 3
LOG_TO_TB_FREQUENCY_PER_BATCHES = 3
LOG_TO_FILE_FREQUENCY_PER_BATCHES = 4
SAVE_MODEL_FREQUENCY_PER_BATCHES = 4
WORD_EMBEDDING_DIMENSION = 15
SAMPLES_NUM_FOR_RERANKING = BEAM_SIZE = 5
LOG_CANDIDATES_NUM = 3
USE_PRETRAINED_W2V_EMBEDDINGS_LAYER = False
VAL_SUBSET_SIZE = 100
MAX_VAL_LINES_NUM = 100
TRAIN_SUBSET_SIZE = 10000
27 changes: 14 additions & 13 deletions cakechat/dialog_model/factory.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import os

from cachetools import cached

from cakechat.config import BASE_CORPUS_NAME, S3_MODELS_BUCKET_NAME, S3_TOKENS_IDX_REMOTE_DIR, \
S3_NN_MODEL_REMOTE_DIR, S3_CONDITIONS_IDX_REMOTE_DIR
from cakechat.dialog_model.model import get_nn_model
from cakechat.utils.s3 import S3FileResolver
from cakechat.utils.files_utils import FileNotFoundException
from cakechat.utils.text_processing import get_index_to_token_path, load_index_to_item, get_index_to_condition_path


Expand All @@ -12,10 +15,10 @@ def _get_index_to_token(fetch_from_s3):
if fetch_from_s3:
tokens_idx_resolver = S3FileResolver(index_to_token_path, S3_MODELS_BUCKET_NAME, S3_TOKENS_IDX_REMOTE_DIR)
if not tokens_idx_resolver.resolve():
raise Exception('Can\'t get index_to_token because file does not exist at S3')
raise FileNotFoundException('Can\'t get index_to_token because file does not exist at S3')
else:
if not os.path.exists(index_to_token_path):
raise Exception('Can\'t get index_to_token because file does not exist. '
raise FileNotFoundException('Can\'t get index_to_token because file does not exist. '
'Run tools/download_model.py first to get all required files or construct it by yourself.')

return load_index_to_item(index_to_token_path)
Expand All @@ -27,30 +30,28 @@ def _get_index_to_condition(fetch_from_s3):
index_to_condition_resolver = S3FileResolver(index_to_condition_path, S3_MODELS_BUCKET_NAME,
S3_CONDITIONS_IDX_REMOTE_DIR)
if not index_to_condition_resolver.resolve():
raise Exception('Can\'t get index_to_condition because file does not exist at S3')
raise FileNotFoundException('Can\'t get index_to_condition because file does not exist at S3')
else:
if not os.path.exists(index_to_condition_path):
raise Exception('Can\'t get index_to_condition because file does not exist. '
raise FileNotFoundException('Can\'t get index_to_condition because file does not exist. '
'Run tools/download_model.py first to get all required files or construct it by yourself.')

return load_index_to_item(index_to_condition_path)


@cached(cache={})
def get_trained_model(reverse=False, fetch_from_s3=True):
if fetch_from_s3:
resolver_factory = S3FileResolver.init_resolver(
bucket_name=S3_MODELS_BUCKET_NAME, remote_dir=S3_NN_MODEL_REMOTE_DIR)
else:
resolver_factory = None

nn_model, model_exists = get_nn_model(
_get_index_to_token(fetch_from_s3),
_get_index_to_condition(fetch_from_s3),
resolver_factory=resolver_factory,
is_reverse_model=reverse)

nn_model, model_exists = get_nn_model(index_to_token=_get_index_to_token(fetch_from_s3),
index_to_condition=_get_index_to_condition(fetch_from_s3),
resolver_factory=resolver_factory,
is_reverse_model=reverse)
if not model_exists:
raise Exception('Can\'t get the model. '
'Run tools/download_model.py first to get all required files or train it by yourself.')

raise FileNotFoundException('Can\'t get the pre-trained model. Run tools/download_model.py first '
'to get all required files or train it by yourself.')
return nn_model
8 changes: 3 additions & 5 deletions cakechat/dialog_model/inference/candidates/beamsearch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from six.moves import zip_longest

import numpy as np
from six.moves import xrange
from six.moves import xrange, zip_longest
import theano

from cakechat.dialog_model.inference.candidates.abstract_generator import AbstractCandidatesGenerator
Expand Down Expand Up @@ -87,7 +85,7 @@ def _update_next_candidates_and_hidden_states(self, token_idx, best_non_finished
# We need to get which original candidate this token in the expanded beam corresponds to.
# (to fill in all the previous tokens from self._cur_candidates)
# Because all the candidates in the expanded beam were filled sequentially, we just use this formula:
original_candidate_idx = candidate_idx / self._beam_size
original_candidate_idx = candidate_idx // self._beam_size

# Construct the candidates for the next step using self._cur_candidates and the last token:

Expand Down Expand Up @@ -123,7 +121,7 @@ def _update_finished_candidates(self, token_idx, best_finished_candidates_indice
# to get all the other tokens we need to get which original candidate this token in the expanded beam
# corresponds to. Because all the candidates in the expanded beam were filled sequentially, we can just
# use this formula:
original_candidate_idx = candidate_idx / self._beam_size
original_candidate_idx = candidate_idx // self._beam_size

# Construct the candidates for the next step using self._cur_candidates and the last token:

Expand Down
13 changes: 1 addition & 12 deletions cakechat/dialog_model/inference/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,6 @@
from cakechat.dialog_model.inference.reranking import DummyReranker, MMIReranker


def _get_reverse_model():
if not hasattr(_get_reverse_model, 'reverse_model'):
try:
reverse_model = get_trained_model(reverse=True)
except:
raise ValueError('Can\'t get reverse nn model for prediction. '
'Try to run \'python tools/train.py --reverse\' or switch prediction mode to sampling.')
_get_reverse_model.reverse_model = reverse_model
return _get_reverse_model.reverse_model


def predictor_factory(nn_model, mode, config):
"""
Expand All @@ -39,7 +28,7 @@ def predictor_factory(nn_model, mode, config):
if config['mmi_reverse_model_score_weight'] <= 0:
raise ValueError('mmi_reverse_model_score_weight should be > 0 for reranking mode')

reverse_model = _get_reverse_model()
reverse_model = get_trained_model(reverse=True)
reranker = MMIReranker(nn_model, reverse_model, config['mmi_reverse_model_score_weight'],
config['repetition_penalization_coefficient'])
else:
Expand Down
40 changes: 22 additions & 18 deletions cakechat/dialog_model/inference/predict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import numpy as np
from six.moves import xrange

from cakechat.config import MAX_PREDICTIONS_LENGTH, BEAM_SIZE, MMI_REVERSE_MODEL_SCORE_WEIGHT, DEFAULT_TEMPERATURE, \
SAMPLES_NUM_FOR_RERANKING, PREDICTION_MODES, REPETITION_PENALIZE_COEFFICIENT
from cakechat.dialog_model.inference.factory import predictor_factory
Expand Down Expand Up @@ -54,7 +51,7 @@ def get_nn_response_ids(context_token_ids,
"""
Predicts several responses for every context.
:param context_token_ids: np.array; shape=(batch_size x context_size x context_len); dtype=int
:param context_token_ids: np.array; shape (batch_size, context_size, context_len); dtype=int
Represents all tokens ids to use for predicting
:param nn_model: CakeChatModel
:param mode: one of PREDICTION_MODES mode
Expand All @@ -65,7 +62,7 @@ def get_nn_response_ids(context_token_ids,
:param output_seq_len: Number of tokens to generate.
:param kwargs: Other prediction parameters, passed into predictor constructor.
Might be different depending on mode. See PredictionConfig for the details.
:return: np.array; shape=(responses_num x output_candidates_num x output_seq_len); dtype=int
:return: np.array; shape (batch_size, output_candidates_num, output_seq_len); dtype=int
Generated predictions.
"""
if mode == PREDICTION_MODES.sampling:
Expand All @@ -75,8 +72,9 @@ def get_nn_response_ids(context_token_ids,
_logger.debug('Generating predicted response for the following params: %s' % prediction_config)

predictor = predictor_factory(nn_model, mode, prediction_config.get_options_dict())
return np.array(
predictor.predict_responses(context_token_ids, output_seq_len, condition_ids, output_candidates_num))
responses = predictor.predict_responses(context_token_ids, output_seq_len, condition_ids, output_candidates_num)

return responses


def get_nn_responses(context_token_ids,
Expand All @@ -87,19 +85,25 @@ def get_nn_responses(context_token_ids,
condition_ids=None,
**kwargs):
"""
Predicts several responses for every context and returns them as proccessed strings.
Predicts output_candidates_num responses for every context and returns them in form of strings.
See get_nn_response_ids for the details.
:return: list of lists of strings
Generated predictions.
:param context_token_ids: numpy array of integers, shape (contexts_num, INPUT_CONTEXT_SIZE, INPUT_SEQUENCE_LENGTH)
:param nn_model: trained model
:param mode: prediction mode, see const PREDICTION_MODES
:param output_candidates_num: number of responses to be generated for each context
:param output_seq_len: max length of generated responses
:param condition_ids: extra info to be taken into account while generating response (emotion, for example)
:return: list of lists of strings, shape (contexts_num, output_candidates_num)
"""
response_tokens_ids = get_nn_response_ids(context_token_ids, nn_model, mode, output_candidates_num, output_seq_len,
condition_ids, **kwargs)
# Reshape to get list of lines to supply into transform_token_ids_to_sentences
response_tokens_ids = np.reshape(response_tokens_ids, (-1, output_seq_len))
response_tokens = transform_token_ids_to_sentences(response_tokens_ids, nn_model.index_to_token)

lines_num = len(response_tokens) // output_candidates_num
responses = [response_tokens[i * output_candidates_num:(i + 1) * output_candidates_num] for i in xrange(lines_num)]

response_tokens_ids = get_nn_response_ids(context_token_ids, nn_model, mode, output_candidates_num,
output_seq_len, condition_ids, **kwargs)
# shape (contexts_num, output_candidates_num, output_seq_len), numpy array of integers

responses = [transform_token_ids_to_sentences(response_candidates_tokens_ids, nn_model.index_to_token)
for response_candidates_tokens_ids in response_tokens_ids]
# responses shape (contexts_num, output_candidates_num), list of lists of strings

return responses
11 changes: 6 additions & 5 deletions cakechat/dialog_model/inference/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ def _select_best_candidates(reranked_candidates, candidates_num):
If for some context we generated fewer than candidates_num candidates, we fill these responses with pads.
"""
batch_size = len(reranked_candidates)
# reranked_candidates is list of lists (we need too keep it this way because we can have different number
# of candidates for each context), so we can't just write rerankied_candidates.shape[2]
# reranked_candidates is a list of lists (we need to keep it this way because we can have different number
# of candidates for each context), so we can't just write reranked_candidates.shape[2]
output_seq_len = reranked_candidates[0][0].size
result = np.zeros((batch_size, candidates_num, output_seq_len))
# Loop here instead of slices because number of candidates for each context can vary here
result = np.zeros((batch_size, candidates_num, output_seq_len), dtype=np.int32)
# Loop here instead of slices because number of candidates for each context may vary here
for i in xrange(batch_size):
for j, candidate in enumerate(reranked_candidates[i]):
if j >= candidates_num:
Expand All @@ -30,4 +30,5 @@ def _select_best_candidates(reranked_candidates, candidates_num):
def predict_responses(self, context_token_ids, output_seq_len, condition_ids=None, candidates_num=1):
all_candidates = self._generator.generate_candidates(context_token_ids, condition_ids, output_seq_len)
reranked_candidates = self._reranker.rerank_candidates(context_token_ids, all_candidates, condition_ids)
return self._select_best_candidates(reranked_candidates, candidates_num)
selected_responses = self._select_best_candidates(reranked_candidates, candidates_num)
return selected_responses
10 changes: 2 additions & 8 deletions cakechat/dialog_model/inference/reranking.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from abc import ABCMeta, abstractmethod
from six.moves import zip_longest

import numpy as np
from six.moves import xrange
from six.moves import xrange, zip_longest

from cakechat.dialog_model.inference.service_tokens import ServiceTokensIDs
from cakechat.dialog_model.inference.utils import get_sequence_score_by_thought_vector, get_sequence_score, \
get_thought_vectors
from cakechat.dialog_model.model_utils import reverse_nn_input
from cakechat.utils.dataset_loader import Dataset
from cakechat.utils.data_types import Dataset
from cakechat.utils.logger import get_logger
from cakechat.utils.profile import timer

Expand Down Expand Up @@ -51,21 +50,18 @@ def __init__(self, nn_model, reverse_model, mmi_reverse_model_score_weight, repe
self._service_tokens_ids = ServiceTokensIDs(nn_model.token_to_index)
self._log_repetition_penalization_coefficient = np.log(repetition_penalization_coefficient)

@timer
def _compute_likelihood_of_output_given_input(self, thought_vector, candidates, condition_id):
# Repeat to get same thought vector for each candidate
thoughts_batch = np.repeat(thought_vector, candidates.shape[0], axis=0)
return get_sequence_score_by_thought_vector(self._nn_model, thoughts_batch, candidates, condition_id)

@timer
def _compute_likelihood_of_input_given_output(self, context, candidates, condition_id):
# Repeat to get same context for each candidate
repeated_context = np.repeat(context, candidates.shape[0], axis=0)
reversed_dataset = reverse_nn_input(
Dataset(x=repeated_context, y=candidates, condition_ids=None), self._service_tokens_ids)
return get_sequence_score(self._reverse_model, reversed_dataset.x, reversed_dataset.y, condition_id)

@timer
def _compute_num_repetitions(self, candidates):
skip_tokens_ids = \
self._service_tokens_ids.special_tokens_ids + self._service_tokens_ids.non_penalizable_tokens_ids
Expand All @@ -76,9 +72,7 @@ def _compute_num_repetitions(self, candidates):
result.append(num_repetitions)
return np.array(result)

@timer
def _compute_candidates_scores(self, context, candidates, condition_id):
_logger.info('Reranking {} candidates...'.format(candidates.shape[0]))
context = context[np.newaxis, :] # from (seq_len,) to (1 x seq_len)
thought_vector = get_thought_vectors(self._nn_model, context)

Expand Down
1 change: 1 addition & 0 deletions cakechat/dialog_model/inference/tests/predict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import sys
import unittest

import numpy as np
from six.moves import xrange

Expand Down
Loading

0 comments on commit 1efee48

Please sign in to comment.