diff --git a/methods.py b/methods.py
index 6a27e6d..1d3ebda 100644
--- a/methods.py
+++ b/methods.py
@@ -114,7 +114,7 @@ def create_optimizers(self):
         self.checkpoint_variables["opt"] = self.opt
 
     def create_model(self):
-        self.model = models.FcnModelBase(self.num_classes, self.domain_outputs)
+        self.model = models.BasicModel(self.num_classes, self.domain_outputs)
 
     def create_losses(self):
         self.task_loss = make_loss()
@@ -1028,7 +1028,7 @@ def mle_for_p_d_given_y(self):
         self.p_d_given_y = p_d_given_y
 
     def create_model(self):
-        self.model = models.FcnModelBase(self.num_classes, self.domain_outputs)
+        self.model = models.BasicModel(self.num_classes, self.domain_outputs)
 
     def compute_losses(self, x, task_y_true, domain_y_true, task_y_pred,
             domain_y_pred, fe_output, training):
diff --git a/models.py b/models.py
index 41e6516..43033e9 100644
--- a/models.py
+++ b/models.py
@@ -162,8 +162,18 @@ def __init__(self, num_classes, num_domains, **kwargs):
         ])
 
 
-class DannModel(FcnModelBase):
-    """ DANN adds a gradient reversal layer before the domain classifier """
+class BasicModel(FcnModelBase):
+    """ Model without adaptation (i.e. no DANN) """
+    pass
+
+
+class DannModelBase:
+    """ DANN adds a gradient reversal layer before the domain classifier
+
+    Note: we don't inherit from FcnModelBase or any other specific model because
+    we want to support either FcnModelBase, RnnModelBase, etc. with multiple
+    inheritance.
+    """
     def __init__(self, num_classes, num_domains, global_step,
             total_steps, **kwargs):
         super().__init__(num_classes, num_domains, **kwargs)
@@ -175,7 +185,12 @@ def call_domain_classifier(self, fe, task, **kwargs):
         return self.domain_classifier(grl_output, **kwargs)
 
 
-class HeterogeneousDannModel(DannModel):
+class DannModel(DannModelBase, FcnModelBase):
+    """ Model with adaptation (i.e. with DANN) """
+    pass
+
+
+class HeterogeneousDannModel(DannModelBase, FcnModelBase):
     """ Heterogeneous DANN model has multiple feature extractors, very
     similar to DannSmoothModel() code except this has multiple FE's not
     multiple DC's """
@@ -219,7 +234,7 @@ def call_domain_classifier(self, fe, task, which_fe=None, **kwargs):
         return self.domain_classifier(grl_output, **kwargs)
 
 
-class SleepModel(DannModel):
+class SleepModel(DannModelBase, FcnModelBase):
     """ Sleep model is DANN but concatenating task classifier output (with
     stop gradient) with feature extractor output when fed to the domain classifier """
     def __init__(self, *args, **kwargs):
@@ -234,7 +249,7 @@ def call_domain_classifier(self, fe, task, **kwargs):
         return self.domain_classifier(domain_input, **kwargs)
 
 
-class DannSmoothModel(DannModel):
+class DannSmoothModel(DannModelBase, FcnModelBase):
     """ DANN Smooth model hs multiple domain classifiers, very
     similar to HeterogeneousDannModel() code except this has multiple DC's
     not multiple FE's """
@@ -319,7 +334,7 @@ def call(self, inputs, **kwargs):
 
 class RnnModelBase(ModelBase):
     """ RNN-based model - for R-DANN and VRADA """
-    def __init__(self, vrada, num_classes, num_domains, **kwargs):
+    def __init__(self, num_classes, num_domains, vrada, **kwargs):
         super().__init__(**kwargs)
         self.num_classes = num_classes
         self.num_domains = num_domains
@@ -348,26 +363,11 @@ def call(self, inputs, training=None, **kwargs):
         return task, domain, fe
 
 
-class DannRnnModel(RnnModelBase):
-    """ DannModel but for RnnModelBase not FcnModelBase """
-    def __init__(self, vrada, num_classes, num_domains, global_step,
-            total_steps, **kwargs):
-        super().__init__(vrada, num_classes, num_domains, **kwargs)
-        grl_schedule = DannGrlSchedule(total_steps)
-        self.flip_gradient = FlipGradient(global_step, grl_schedule)
-
-    def call_domain_classifier(self, fe, task, **kwargs):
-        grl_output = self.flip_gradient(fe, **kwargs)
-        return self.domain_classifier(grl_output, **kwargs)
-
-
-class VradaModel(DannRnnModel):
+class VradaModel(DannModelBase, RnnModelBase):
     def __init__(self, *args, **kwargs):
-        vrada = True
-        super().__init__(vrada, *args, **kwargs)
+        super().__init__(*args, vrada=True, **kwargs)
 
 
-class RDannModel(DannRnnModel):
+class RDannModel(DannModelBase, RnnModelBase):
     def __init__(self, *args, **kwargs):
-        vrada = False
-        super().__init__(vrada, *args, **kwargs)
+        super().__init__(*args, vrada=False, **kwargs)
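
The refactoring above turns the GRL logic into a mixin: DannModelBase deliberately inherits from nothing concrete, so it can sit in front of either FcnModelBase or RnnModelBase in the MRO, with each __init__ consuming its own arguments and forwarding the rest via super() and **kwargs. A minimal standalone sketch of that cooperative-inheritance pattern (FeatureModel, GrlMixin, and AdaptModel are hypothetical stand-ins, not code from this repo):

    class FeatureModel:
        # Stand-in for FcnModelBase / RnnModelBase.
        def __init__(self, num_classes, num_domains, **kwargs):
            super().__init__(**kwargs)  # keep the MRO chain going
            self.num_classes = num_classes
            self.num_domains = num_domains

    class GrlMixin:
        # Stand-in for DannModelBase: listed first in the bases, so its
        # __init__ runs first, consumes the GRL-specific arguments, and
        # forwards the rest along the MRO.
        def __init__(self, num_classes, num_domains, global_step,
                total_steps, **kwargs):
            super().__init__(num_classes, num_domains, **kwargs)
            self.global_step = global_step
            self.total_steps = total_steps

    class AdaptModel(GrlMixin, FeatureModel):
        pass

    m = AdaptModel(10, 2, global_step=0, total_steps=1000)
    print([c.__name__ for c in type(m).__mro__])
    # ['AdaptModel', 'GrlMixin', 'FeatureModel', 'object']

This is also why RnnModelBase's vrada parameter moves after num_classes and num_domains: VradaModel and RDannModel now pass it by keyword, so it travels untouched through DannModelBase's **kwargs down to RnnModelBase.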
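For reference, the gradient reversal layer that DannModelBase's docstring describes is conventionally an identity on the forward pass with a negated, scheduled gradient on the backward pass. A generic TensorFlow sketch of that idea follows; it is an illustration of the technique, not this repo's FlipGradient or DannGrlSchedule:

    import tensorflow as tf

    def make_flip_gradient(grl_lambda=1.0):
        # Identity forward; -grl_lambda * gradient backward.
        @tf.custom_gradient
        def flip_gradient(x):
            def grad(dy):
                return -grl_lambda * dy
            return tf.identity(x), grad
        return flip_gradient

    x = tf.constant([1.0, 2.0])
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = make_flip_gradient(0.5)(x)
    print(tape.gradient(y, x))  # tf.Tensor([-0.5 -0.5], ...)

In the DANN paper's schedule, grl_lambda ramps from 0 to 1 as 2 / (1 + exp(-10 * p)) - 1, where p is the fraction of training completed, which is presumably what a schedule driven by global_step and total_steps computes here.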