Skip to content

Commit

Permalink
Merge pull request #3 from Cerenaut/replay-all_runs
Browse files Browse the repository at this point in the history
Replay all runs
  • Loading branch information
abdel authored Dec 15, 2020
2 parents 1af1082 + 6762782 commit f4e9374
Show file tree
Hide file tree
Showing 19 changed files with 1,633 additions and 314 deletions.
13 changes: 13 additions & 0 deletions .idea/inspectionProfiles/Project_Default.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

115 changes: 110 additions & 5 deletions aha/components/episodic_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from pagi.utils.dual import DualData
from pagi.utils.hparam_multi import HParamMulti
from pagi.utils.layer_utils import type_activation_fn
from pagi.utils.np_utils import np_uniform
from pagi.utils.tf_utils import tf_build_interpolate_distributions

from pagi.components.summarize_levels import SummarizeLevels
from pagi.components.composite_component import CompositeComponent
Expand All @@ -49,6 +51,7 @@
from aha.components.diff_plasticity_component import DifferentiablePlasticityComponent

from aha.utils.interest_filter import InterestFilter
from aha.utils.generic_utils import normalize_minmax


class PCMode(enum.Enum):
Expand Down Expand Up @@ -200,6 +203,10 @@ def is_build_ll_vc(self):
def is_build_ll_pc(self):
return self._hparams.ll_pc_type != 'none'

def is_build_ll_ensemble(self):
build_ll_ensemble = True
return self.is_build_ll_vc() and self.is_build_ll_pc() and build_ll_ensemble

def is_build_pc(self):
return self._hparams.pc_type != 'none'

Expand Down Expand Up @@ -333,6 +340,10 @@ def _build_ll_vc(self, target_output, train_input, test_input, name='ll_vc'):
"""Build the label learning component for LTM."""
ll_vc = None

# Don't normalize this yet
train_input = normalize_minmax(train_input)
test_input = normalize_minmax(test_input)

if self._hparams.ll_vc_type == 'fc':
ll_vc = LabelLearnerFC()
self._add_sub_component(ll_vc, name)
Expand All @@ -346,6 +357,9 @@ def _build_ll_pc(self, target_output, train_input, test_input, name='ll_pc'):
"""Build the label learning component for PC."""
ll_pc = None

train_input = normalize_minmax(train_input)
test_input = normalize_minmax(test_input)

if self._hparams.ll_pc_type == 'fc':
ll_pc = LabelLearnerFC()
self._add_sub_component(ll_pc, name)
Expand Down Expand Up @@ -480,6 +494,47 @@ def _build_pc(self, input_next, input_next_vis_shape, dg_sparsity):

return pc_output, pc_output_shape

def _build_ll_ensemble(self):
"""Builds ensemble of VC and PC classifiers."""
distributions = []
distribution_mass = []
num_classes = self._label_values.get_shape().as_list()[-1]

aha_mass = 0.495
ltm_mass = 0.495
uniform_mass = 0.01

if aha_mass > 0.0:
aha_prediction = self.get_ll_pc().get_op('preds')
distributions.append(aha_prediction)
distribution_mass.append(aha_mass)

if ltm_mass > 0.0:
ltm_prediction = self.get_ll_vc().get_op('preds')
distributions.append(ltm_prediction)
distribution_mass.append(ltm_mass)

if uniform_mass > 0.0:
uniform = np_uniform(num_classes)
distributions.append(uniform)
distribution_mass.append(uniform_mass)

unseen_sum = 1
unseen_idxs = (0, unseen_sum)

# Build the final distribution, calculate loss
ensemble_preds = tf_build_interpolate_distributions(distributions, distribution_mass, num_classes)

ensemble_correct_preds = tf.equal(tf.argmax(ensemble_preds, 1), tf.argmax(self._label_values, 1))
ensemble_correct_preds = tf.cast(ensemble_correct_preds, tf.float32)

ensemble_accuracy = tf.reduce_mean(ensemble_correct_preds)
ensemble_accuracy_unseen = tf.reduce_mean(ensemble_correct_preds[unseen_idxs[0]:unseen_idxs[1]])

self._dual.set_op('ensemble_preds', ensemble_preds)
self._dual.set_op('ensemble_accuracy', ensemble_accuracy)
self._dual.set_op('ensemble_accuracy_unseen', ensemble_accuracy_unseen)

def build(self, input_values, input_shape, hparams, label_values=None, name='episodic'):
"""Initializes the model parameters.
Expand All @@ -500,40 +555,78 @@ def build(self, input_values, input_shape, hparams, label_values=None, name='epi
self._input_shape = input_shape
self._label_values = label_values

self.set_signal('input', input_values, input_shape)

input_area = np.prod(input_shape[1:])

logging.debug('Input Shape: %s', input_shape)
logging.debug('Input Area: %s', input_area)

with tf.variable_scope(self._name, reuse=tf.AUTO_REUSE):

# Replay mode
# ------------------------------------------------------------------------
replay_mode = 'pixel' # pixel or encoding
replay = self._dual.add('replay', shape=[], default_value=False).add_pl(
default=True, dtype=tf.bool)

# Replace labels during replay
replay_labels = self._dual.add('replay_labels', shape=label_values.shape, default_value=0.0).add_pl(
default=True, dtype=label_values.dtype)

self._label_values = tf.cond(tf.equal(replay, True), lambda: replay_labels, lambda: self._label_values)

# Replay pixel inputs during replay, if using 'pixel' replay mode
if replay_mode == 'pixel':
replay_inputs = self._dual.add('replay_inputs', shape=input_values.shape, default_value=0.0).add_pl(
default=True, dtype=input_values.dtype)

self._input_values = tf.cond(tf.equal(replay, True), lambda: replay_inputs, lambda: self._input_values)

self.set_signal('input', self._input_values, self._input_shape)

# Build the LTM
# ------------------------------------------------------------------------

# Optionally degrade input to VC
degrade_step_pl = self._dual.add('degrade_step', shape=[], # e.g. hidden, input, none
default_value='none').add_pl(default=True, dtype=tf.string)
degrade_random_pl = self._dual.add('degrade_random', shape=[],
default_value=0.0).add_pl(default=True, dtype=tf.float32)
input_values = self.degrader(degrade_step_pl, self._degrade_type, degrade_random_pl, input_values,
input_values = self.degrader(degrade_step_pl, self._degrade_type, degrade_random_pl, self._input_values,
degrade_step='input', name='vc_input_values')

print('vc', 'input', input_values)
self.set_signal('vc_input', input_values, input_shape)

# Build the VC
input_next, input_next_vis_shape = self._build_vc(input_values, input_shape)

vc_encoding = input_next

# Replace the encoding during replay, if using 'encoding' replay mode
if replay_mode == 'encoding':
replay_inputs = self._dual.add('replay_inputs', shape=vc_encoding.shape, default_value=0.0).add_pl(
default=True, dtype=vc_encoding.dtype)

vc_encoding = tf.cond(tf.equal(replay, True), lambda: replay_inputs, lambda: vc_encoding)

self.set_signal('vc', vc_encoding, input_next_vis_shape)
self._dual.set_op('vc_encoding', vc_encoding)

# Build the softmax classifier
if self.is_build_ll_vc() and self._label_values is not None:
self._build_ll_vc(self._label_values, vc_encoding, vc_encoding)

# Build AHA
# ------------------------------------------------------------------------

# Build the DG
dg_sparsity = 0
if self.is_build_dg():
input_next, input_next_vis_shape, dg_sparsity = self._build_dg(input_next, input_next_vis_shape)
dg_encoding = input_next
self.set_signal('dg', dg_encoding, input_next_vis_shape)

# Build the PC
if self.is_build_pc():
# Optionally degrade input to PC

Expand All @@ -554,6 +647,9 @@ def build(self, input_values, input_shape, hparams, label_values=None, name='epi
if self.is_build_ll_pc() and self.is_build_dg() and self._label_values is not None:
self._build_ll_pc(self._label_values, dg_encoding, pc_output)

if self.is_build_ll_ensemble():
self._build_ll_ensemble()

self.reset()

def get_vc(self):
Expand Down Expand Up @@ -641,8 +737,13 @@ def add_fetches(self, fetches, batch_type='training'):
# ------------------------------
# Interest Filter and other
names = []

if self._hparams.use_interest_filter:
names = ['masked_encodings', 'positional_encodings']
names.extend(['masked_encodings', 'positional_encodings'])

if self.is_build_ll_ensemble():
names.extend(['ensemble_preds', 'ensemble_accuracy', 'ensemble_accuracy_unseen'])

# Other
names.extend(['vc_encoding'])

Expand All @@ -666,8 +767,12 @@ def set_fetches(self, fetched, batch_type='training'):
# ----------------------------
# Interest Filter
names = []

if self._hparams.use_interest_filter:
names = ['masked_encodings', 'positional_encodings']
names.extend(['masked_encodings', 'positional_encodings'])

if self.is_build_ll_ensemble():
names.extend(['ensemble_preds', 'ensemble_accuracy', 'ensemble_accuracy_unseen'])

# other
names.extend(['vc_encoding'])
Expand Down
Loading

0 comments on commit f4e9374

Please sign in to comment.