Skip to content

Commit

Permalink
Peek across (allenai#327)
Browse files Browse the repository at this point in the history
* added peek across after first noisy or

* added peek-across and subtract minimum layer

* pylint fixes

* Fixes recommended in code review

* Added get_config

* Fixing a few bugs

* Fixed some dtype issues with masks
  • Loading branch information
BeckySharp authored and matt-gardner committed Apr 30, 2017
1 parent 7f008e4 commit 0b6a735
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 1 deletion.
2 changes: 1 addition & 1 deletion deep_qa/layers/backend/add_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self, mask_value: float=0.0, **kwargs):

@overrides
def compute_mask(self, inputs, mask=None): # pylint: disable=unused-argument
return K.cast(K.not_equal(inputs, self.mask_value), 'uint8')
return K.cast(K.not_equal(inputs, self.mask_value), 'bool')

@overrides
def compute_output_shape(self, input_shape):
Expand Down
57 changes: 57 additions & 0 deletions deep_qa/layers/subtract_minimum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from keras import backend as K
from overrides import overrides

from deep_qa.layers.masked_layer import MaskedLayer
from deep_qa.tensors.backend import VERY_LARGE_NUMBER

class SubtractMinimum(MaskedLayer):
    '''
    Normalizes a tensor along an axis by shifting it so that the smallest value becomes zero.

    The minimum is computed across ``axis`` (keeping that dimension, so it broadcasts), and is
    then subtracted from every value of the input (again, across the specified axis). Because
    ``axis`` is handed straight to ``K.min``, taking the minimum across more than one axis works
    too.

    When a mask is given, masked positions are excluded from the minimum computation: a very
    large number is added to them before ``K.min`` is taken. The minimum is still subtracted from
    every position of the input, masked or not, and the mask itself is passed through unchanged.

    Inputs:
        - A tensor with arbitrary dimension, and a mask of the same shape (currently doesn't
          support masks with other shapes).

    Output:
        - The same tensor, with the minimum across one (or more) of the dimensions subtracted.

    Parameters
    ----------
    axis: int
        The axis (or axes) across which to find the minimum.  Can be a single int, a list of
        ints, or None.  We just call `K.min` with this parameter, so anything that's valid there
        works here too.
    '''
    def __init__(self, axis: int, **kwargs):
        self.axis = axis
        super(SubtractMinimum, self).__init__(**kwargs)

    @overrides
    def compute_output_shape(self, input_shape):  # pylint: disable=no-self-use
        # Subtracting a broadcast minimum never changes the shape.
        return input_shape

    @overrides
    def compute_mask(self, inputs, mask=None):
        # The incoming mask (possibly None) is propagated as-is.
        return mask

    @overrides
    def call(self, inputs, mask=None):
        # Unmasked input: every position takes part in the minimum.
        if mask is None:
            return inputs - K.min(inputs, axis=self.axis, keepdims=True)

        # The "masked out" marker depends on the mask dtype: False for boolean masks, 0 otherwise.
        if K.dtype(mask) == 'bool':
            masked_out_marker = False
        else:
            masked_out_marker = 0

        # 1.0 where the position is masked out, 0.0 elsewhere; scaling by a very large number
        # pushes masked positions far above any real value so they cannot be picked as the minimum.
        is_masked_out = K.cast(K.equal(mask, masked_out_marker), "float32")
        masking_offset = is_masked_out * VERY_LARGE_NUMBER
        unmasked_minimums = K.min(inputs + masking_offset, axis=self.axis, keepdims=True)
        return inputs - unmasked_minimums

    @overrides
    def get_config(self):
        # Start from this layer's own settings, then merge the parent's config on top of them.
        config = {'axis': self.axis}
        config.update(super(SubtractMinimum, self).get_config())
        return config
15 changes: 15 additions & 0 deletions deep_qa/models/multiple_choice_qa/tuple_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ...layers import NoisyOr
from ...layers.attention import MaskedSoftmax
from ...layers.backend import Repeat
from ...layers.subtract_minimum import SubtractMinimum
from ...layers.tuple_matchers import tuple_matchers, WordOverlapTupleMatcher
from ...layers.wrappers import TimeDistributedWithMask
from ...training import TextTrainer
Expand Down Expand Up @@ -49,6 +50,12 @@ class TupleInferenceModel(TextTrainer):
num_options: int, default=4
The number of answer options/candidates.
normalize_tuples_across_answers: bool, default=False
Whether or not to normalize each question tuple's score across the answer options. This
assumes that the tuples are in the same order for all answer options. Normalization is
currently done by subtracting the minimum score for a given tuple "position" from all the
tuples in that position.
display_text_wrap: int, default=150
This is used by the debug output methods to wrap long tuple strings.
Expand All @@ -64,6 +71,7 @@ def __init__(self, params: Params):
self.num_tuple_slots = params.pop('num_tuple_slots', 4)
self.num_slot_words = params.pop('num_sentence_words', 5)
self.num_options = params.pop('num_answer_options', 4)
self.normalize_tuples_across_answers = params.pop('normalize_tuples_across_answers', False)
self.display_text_wrap = params.pop('display_text_wrap', 150)
self.display_num_tuples = params.pop('display_num_tuples', 5)
tuple_matcher_params = params.pop('tuple_matcher', {})
Expand Down Expand Up @@ -100,6 +108,7 @@ def _get_custom_objects(cls):
custom_objects['NoisyOr'] = NoisyOr
custom_objects['Repeat'] = Repeat
custom_objects['TimeDistributedWithMask'] = TimeDistributedWithMask
custom_objects['SubtractMinimum'] = SubtractMinimum
return custom_objects

@overrides
Expand Down Expand Up @@ -186,6 +195,12 @@ def _build_model(self):
combine_background_evidence.name = "noisy_or_1"
qi_probabilities = combine_background_evidence(matches)

# If desired, peek across the options, and normalize the amount that a given answer tuple template "counts"
# towards a correct answer.
if self.normalize_tuples_across_answers:
normalize_across_options = SubtractMinimum(axis=1)
qi_probabilities = normalize_across_options(qi_probabilities)

# Find the probability that any given option is correct, given the entailment scores of each of its
# question tuples given the set of background tuples.
# shape: (batch size, num_options)
Expand Down
8 changes: 8 additions & 0 deletions doc/layers/core_layers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,14 @@ RecurrenceModes
:undoc-members:
:show-inheritance:

SubtractMinimum
---------------

.. automodule:: deep_qa.layers.subtract_minimum
:members:
:undoc-members:
:show-inheritance:

TimeDistributedEmbedding
------------------------

Expand Down
44 changes: 44 additions & 0 deletions tests/layers/test_subtract_minimum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# pylint: disable=no-self-use
import numpy as np
from numpy.testing import assert_array_almost_equal

from keras.layers import Input
from keras.models import Model
from deep_qa.layers.subtract_minimum import SubtractMinimum
from deep_qa.layers.backend.add_mask import AddMask
from ..common.test_case import DeepQaTestCase


class TestSubtractMinimum(DeepQaTestCase):
    def test_general_case(self):
        # One batch instance holding four rows of three scores; the layer normalizes over axis 1.
        scores = np.array([[[0.1, 0.1, 0.1],
                            [0.2, 0.3, 0.4],
                            [0.5, 0.4, 0.6],
                            [0.5, 0.4, 0.6]]])

        scores_input = Input(shape=(4, 3), dtype='float32', name="input")
        layer = SubtractMinimum(axis=1)

        # Unmasked case: the first row holds the minimum of every column, so it becomes all zeros
        # and each of the other rows is shifted down by 0.1.
        unmasked_model = Model([scores_input], layer(scores_input))
        unmasked_output = unmasked_model.predict([scores])
        expected_unmasked = np.array([[[0.0, 0.0, 0.0],
                                       [0.1, 0.2, 0.3],
                                       [0.4, 0.3, 0.5],
                                       [0.4, 0.3, 0.5]]])
        assert_array_almost_equal(unmasked_output, expected_unmasked)

        # Masked case: with the mask value set to 0.1, the first row is masked out and must be
        # ignored when the minimum is chosen.  The second row then supplies the minimum of every
        # column, while the masked row is still shifted (and so ends up negative).
        masked_scores = AddMask(mask_value=0.1)(scores_input)
        masked_model = Model([scores_input], layer(masked_scores))
        masked_output = masked_model.predict([scores])
        expected_masked = np.array([[[-0.1, -0.2, -0.3],
                                     [0.0, 0.0, 0.0],
                                     [0.3, 0.1, 0.2],
                                     [0.3, 0.1, 0.2]]])
        assert_array_almost_equal(masked_output, expected_masked)
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@ def test_model_trains_and_loads_correctly(self):
"num_tuple_slots": 4,
"num_sentence_words": 10,
"num_answer_options": 4,
"normalize_tuples_across_answers": True,
"save_models": True})
self.ensure_model_trains_and_loads(TupleInferenceModel, args)

0 comments on commit 0b6a735

Please sign in to comment.