From 241a70c53e251ea766b7223fb7c27c3b494aecfe Mon Sep 17 00:00:00 2001 From: Garrett Wilson Date: Mon, 27 Apr 2020 16:44:25 -0700 Subject: [PATCH] initial heterogeneous DA - {dann,daws}_hda methods have 2 feature extractors - pass is_target into eval_step() during eval to know which FE to use - script to test how Python's multiple inheritance works --- kamiak_config.sh | 4 +- methods.py | 137 +++++++++++++++++++++++++++++++--- metrics.py | 6 +- models.py | 75 ++++++++++++++++--- multiple_inheritance_check.py | 48 ++++++++++++ 5 files changed, 247 insertions(+), 23 deletions(-) create mode 100755 multiple_inheritance_check.py diff --git a/kamiak_config.sh b/kamiak_config.sh index 7939932..a22332b 100644 --- a/kamiak_config.sh +++ b/kamiak_config.sh @@ -5,5 +5,5 @@ modelFolder="kamiak-models" logFolder="kamiak-logs" remotessh="kamiak" # in your .ssh/config file -remotedir="/data/vcea/garrett.wilson/codats/" -localdir="/home/garrett/Documents/Github/codats/" +remotedir="/data/vcea/garrett.wilson/heterogeneous-da/" +localdir="/home/garrett/Documents/Github/heterogeneous-da/" diff --git a/methods.py b/methods.py index e46151f..9494fe8 100644 --- a/methods.py +++ b/methods.py @@ -244,7 +244,7 @@ def post_data_eval(self, task_y_true, task_y_pred, domain_y_true, domain_y_pred = tf.nn.softmax(domain_y_pred) return task_y_true, task_y_pred, domain_y_true, domain_y_pred - def call_model(self, x, **kwargs): + def call_model(self, x, is_target=None, **kwargs): return self.model(x, **kwargs) def compute_losses(self, x, task_y_true, domain_y_true, task_y_pred, @@ -277,7 +277,7 @@ def train_step(self, data_sources, data_target): del tape self.apply_gradients(gradients) - def eval_step(self, data): + def eval_step(self, data, is_target): """ Evaluate a batch of source or target data, called in metrics.py. This preprocesses the data to have x, y, domain always be lists so we can use the same compiled tf.function code in eval_step_list() for @@ -291,15 +291,16 @@ def eval_step(self, data): if not isinstance(domain, list): domain = [domain] - return self.eval_step_list((x, y, domain)) + return self.eval_step_list((x, y, domain), is_target) #@tf.function # faster not to compile - def eval_step_list(self, data): + def eval_step_list(self, data, is_target): """ Override preparation in prepare_data_eval() """ x, task_y_true, domain_y_true = self.prepare_data_eval(data) # Run through model - task_y_pred, domain_y_pred, fe_output = self.call_model(x, training=False) + task_y_pred, domain_y_pred, fe_output = self.call_model(x, + is_target=is_target, training=False) # Calculate losses losses = self.compute_losses(x, task_y_true, domain_y_true, @@ -313,7 +314,7 @@ def eval_step_list(self, data): return task_y_true, task_y_pred, domain_y_true, domain_y_pred, losses # -# Domain adaptation +# Homogeneous domain adaptation # # The base method class performs no adaptation @@ -420,7 +421,8 @@ def domain_label(self, index, is_target): @register_method("dann_smooth") class MethodDannSmooth(MethodDannGS): """ MDAN Smooth method based on MethodDannGS since we want binary source = 1, - target = 0 for the domain labels """ + target = 0 for the domain labels, very similar to HeterogeneousBase() + code except this has multiple DC's not multiple FE's """ def create_model(self): self.model = models.DannSmoothModel( self.num_classes, self.domain_outputs, # Note: domain_outputs=2 @@ -468,7 +470,7 @@ def prepare_data_eval(self, data): return x, y, domain - def call_model(self, x, **kwargs): + def call_model(self, x, is_target=None, **kwargs): """ Run each source-target pair through model separately, using the corresponding domain classifier. """ task_y_pred = [] @@ -482,7 +484,7 @@ def call_model(self, x, **kwargs): for i in range(len(x)): i_task_y_pred, i_domain_y_pred, i_fe_output = \ - self.model(x[i], domain_classifier=i, **kwargs) + self.model(x[i], which_dc=i, **kwargs) task_y_pred.append(i_task_y_pred) domain_y_pred.append(i_domain_y_pred) fe_output.append(i_fe_output) @@ -699,6 +701,123 @@ def compute_gradients(self, tape, losses): return super().compute_gradients(tape, [total_loss, task_loss, d_loss]) +# +# Heterogeneous domain adaptation +# + +class HeterogeneousBase: + """ Handle multiple feature extractors, very similar to MethodDannSmooth() + code except this has multiple FE's not multiple DC's """ + def __init__(self, *args, **kwargs): + # Otherwise, with multiple inheritance, the other init's aren't called. + super().__init__(*args, **kwargs) + + def create_model(self): + # For now we assume all sources have the same feature space. So, we need + # two feature extractors -- one for source and one for target. + num_feature_extractors = 2 + + self.model = models.HeterogeneousDannModel( + self.num_classes, self.domain_outputs, + self.global_step, self.total_steps, + num_feature_extractors) + + def prepare_data(self, data_sources, data_target): + """ Prepare a batch of all source(s) data and target data separately, + so we run through the source/target feature extractors separately """ + assert data_target is not None, \ + "cannot run Heterogeneous DA without target" + x_a, y_a, domain_a = data_sources + x_b, y_b, domain_b = data_target + + # Note: x_b, etc. isn't a list so doesn't need concat + x = [tf.concat(x_a, axis=0), x_b] + task_y_true = [tf.concat(y_a, axis=0), y_b] + domain_y_true = [tf.concat(domain_a, axis=0), domain_b] + + return x, task_y_true, domain_y_true + + def prepare_data_eval(self, data): + """ Don't concatenate elements of the list like in the base class since + we want to handle the source/target domains separately, to pass to the + right feature extractors.""" + x, y, domain = data + + assert isinstance(x, list), \ + "Must pass x=[...] even if only one domain for tf.function consistency" + assert isinstance(y, list), \ + "Must pass y=[...] even if only one domain for tf.function consistency" + assert isinstance(domain, list), \ + "Must pass domain=[...] even if only one domain for tf.function consistency" + + return x, y, domain + + def call_model(self, x, is_target=None, training=None, **kwargs): + """ Run each source/target through appropriate feature extractor. + If is_target=None, then this is training. If is_target=True, then this + is evaluation of target data, and if is_target=False, then this is + evaluation of source data. """ + task_y_pred = [] + domain_y_pred = [] + fe_output = [] + + # Should be 2 for source/target or 1 during evaluation for just one + assert (training is True and is_target is None and len(x) == 2) \ + or (training is False and (is_target is True or is_target is False) + and len(x) == 1), \ + "is_target=None and len(x)=2 during training but " \ + "is_target=True/False and len(x)=1 during testing" + + for i in range(len(x)): + # At test time, we set source/target explicitly -- use appropriate + # feature extractor: sources = 0, target = 1 (see ordering in + # prepare_data) + if is_target is not None: + which_fe = 1 if is_target else 0 + else: + which_fe = i + + i_task_y_pred, i_domain_y_pred, i_fe_output = \ + self.model(x[i], which_fe=which_fe, training=training, **kwargs) + task_y_pred.append(i_task_y_pred) + domain_y_pred.append(i_domain_y_pred) + fe_output.append(i_fe_output) + + return task_y_pred, domain_y_pred, fe_output + + def compute_losses(self, x, task_y_true, domain_y_true, task_y_pred, + domain_y_pred, fe_output, training): + """ Concatenate, then parent class's loss (e.g. DANN or DA-WS) """ + x = tf.concat(x, axis=0) + task_y_true = tf.concat(task_y_true, axis=0) + domain_y_true = tf.concat(domain_y_true, axis=0) + task_y_pred = tf.concat(task_y_pred, axis=0) + domain_y_pred = tf.concat(domain_y_pred, axis=0) + fe_output = tf.concat(fe_output, axis=0) + super().compute_losses(x, task_y_true, domain_y_true, task_y_pred, + domain_y_pred, fe_output, training) + + def post_data_eval(self, task_y_true, task_y_pred, domain_y_true, + domain_y_pred): + """ Concatenate, then parent class's post_data_eval """ + task_y_true = tf.concat(task_y_true, axis=0) + task_y_pred = tf.concat(task_y_pred, axis=0) + domain_y_true = tf.concat(domain_y_true, axis=0) + domain_y_pred = tf.concat(domain_y_pred, axis=0) + return super().post_data_eval(task_y_true, task_y_pred, domain_y_true, + domain_y_pred) + + +@register_method("dann_hda") +class MethodHeterogeneousDann(HeterogeneousBase, MethodDann): + pass + + +@register_method("daws_hda") +class MethodHeterogeneousDaws(HeterogeneousBase, MethodDaws): + pass + + # # Domain generalization # diff --git a/metrics.py b/metrics.py index 2732cd8..c362ba7 100644 --- a/metrics.py +++ b/metrics.py @@ -308,7 +308,8 @@ def _run_single_batch(self, data, dataset_name, domain_name): assert dataset_name in self.datasets, "unknown dataset "+str(dataset_name) assert domain_name in self.domains, "unknown domain "+str(domain_name) - results = self.method.eval_step(data) + is_target = domain_name == "target" + results = self.method.eval_step(data, is_target) classifier = "task" # Which classifier's task_y_pred are we looking at? self._process_batch(results, classifier, domain_name, dataset_name) @@ -409,6 +410,9 @@ def plots(self, global_step): first_time = step == 1 # Generate plots + # + # TODO probably broken with heterogeneous DA models since they have more + # than one feature extractor, so this will error t = time.time() plots = generate_plots(data_a, data_b, self.method.model.feature_extractor, first_time) diff --git a/models.py b/models.py index bc0b3bc..0688b14 100644 --- a/models.py +++ b/models.py @@ -64,9 +64,13 @@ class ModelBase(tf.keras.Model): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + @property + def trainable_variables_fe(self): + return self.feature_extractor.trainable_variables + @property def trainable_variables_task(self): - return self.feature_extractor.trainable_variables \ + return self.trainable_variables_fe \ + self.task_classifier.trainable_variables @property @@ -75,7 +79,7 @@ def trainable_variables_domain(self): @property def trainable_variables_task_domain(self): - return self.feature_extractor.trainable_variables \ + return self.trainable_variables_fe \ + self.task_classifier.trainable_variables \ + self.trainable_variables_domain @@ -171,6 +175,52 @@ def call_domain_classifier(self, fe, task, **kwargs): return self.domain_classifier(grl_output, **kwargs) +class HeterogeneousDannModel(DannModel): + """ Heterogeneous DANN model has multiple feature extractors, + very similar to DannSmoothModel() code except this has multiple FE's + not multiple DC's """ + def __init__(self, *args, num_feature_extractors, **kwargs): + super().__init__(*args, **kwargs) + + # Requires multiple feature extractors + new_feature_extractor = [self.feature_extractor] + + # Start at 1 since we already have one + for i in range(1, num_feature_extractors): + new_feature_extractor.append( + tf.keras.models.clone_model(self.feature_extractor)) + + self.feature_extractor = new_feature_extractor + + @property + def trainable_variables_fe(self): + # We have multiple feature extractors, so get all variables + fe_vars = [] + + for fe in self.feature_extractor: + fe_vars += fe.trainable_variables + + return fe_vars + + def call_feature_extractor(self, inputs, which_fe=None, **kwargs): + # Override so we don't pass which_fe argument to model + return self.feature_extractor(inputs) + + assert which_fe is not None, \ + "must specify which feature extractor to use" + return self.feature_extractor[which_fe](inputs, **kwargs) + + def call_task_classifier(self, fe, which_fe=None, **kwargs): + # Override so we don't pass which_fe argument to model + return self.task_classifier(fe, **kwargs) + + def call_domain_classifier(self, fe, task, which_fe=None, **kwargs): + # Override so we don't pass which_fe argument to model + # Copy of the DANN version only with above arg change + grl_output = self.flip_gradient(fe, **kwargs) + return self.domain_classifier(grl_output, **kwargs) + + class SleepModel(DannModel): """ Sleep model is DANN but concatenating task classifier output (with stop gradient) with feature extractor output when fed to the domain classifier """ @@ -187,6 +237,9 @@ def call_domain_classifier(self, fe, task, **kwargs): class DannSmoothModel(DannModel): + """ DANN Smooth model hs multiple domain classifiers, + very similar to HeterogeneousDannModel() code except this has multiple DC's + not multiple FE's """ def __init__(self, *args, num_domain_classifiers, **kwargs): # For MDAN Smooth, it's binary classification but we have a separate # discriminator for each source-target pair. @@ -213,20 +266,20 @@ def trainable_variables_domain(self): return domain_vars - def call_feature_extractor(self, inputs, **kwargs): - # Override so we don't pass domain_classifier argument to model - return self.feature_extractor(inputs) + def call_feature_extractor(self, inputs, which_dc=None, **kwargs): + # Override so we don't pass which_dc argument to model + return self.feature_extractor(inputs, **kwargs) - def call_task_classifier(self, fe, **kwargs): - # Override so we don't pass domain_classifier argument to model - return self.task_classifier(fe) + def call_task_classifier(self, fe, which_dc=None, **kwargs): + # Override so we don't pass which_dc argument to model + return self.task_classifier(fe, **kwargs) - def call_domain_classifier(self, fe, task, domain_classifier=None, **kwargs): - assert domain_classifier is not None, \ + def call_domain_classifier(self, fe, task, which_dc=None, **kwargs): + assert which_dc is not None, \ "must specify which domain classifier to use with method Smooth" grl_output = self.flip_gradient(fe, **kwargs) # 0 = source domain 1 with target, 1 = source domain 2 with target, etc. - return self.domain_classifier[domain_classifier](grl_output, **kwargs) + return self.domain_classifier[which_dc](grl_output, **kwargs) class VradaFeatureExtractor(tf.keras.Model): diff --git a/multiple_inheritance_check.py b/multiple_inheritance_check.py new file mode 100755 index 0000000..c59badc --- /dev/null +++ b/multiple_inheritance_check.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +""" +Test heterogeneous multiple inheritance +""" + + +class MethodBase: + def __init__(self): + print("MethodBase") + + +class MethodDann(MethodBase): + def __init__(self): + super().__init__() + print("MethodDann") + + +class MethodDaws(MethodBase): + def __init__(self): + super().__init__() + print("MethodDaws") + + +class HeterogeneousBase: + def __init__(self): + super().__init__() + print("HeterogeneousBase") + + +class HeterogeneousDann(HeterogeneousBase, MethodDann): + pass + # def __init__(self): + # super().__init__() + # print("HeterogeneousDann") + # print(HeterogeneousDann.__mro__) + + +class HeterogeneousDaws(HeterogeneousBase, MethodDaws): + pass + # def __init__(self): + # super().__init__() + # print("HeterogeneousDaws") + # print(HeterogeneousDaws.__mro__) + + +if __name__ == "__main__": + a = HeterogeneousDann() + # b = HeterogeneousDaws()