Skip to content

Commit

Permalink
reduce duplicate model code, better name
Browse files Browse the repository at this point in the history
- Multiple inheritance so we don't need DannRnnModel which only changed
  what it inherited from.
- Create BasicModel not FcnModelBase in methods.py
  • Loading branch information
floft committed Apr 29, 2020
1 parent 7370b74 commit 24465b7
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 27 deletions.
4 changes: 2 additions & 2 deletions methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def create_optimizers(self):
self.checkpoint_variables["opt"] = self.opt

def create_model(self):
self.model = models.FcnModelBase(self.num_classes, self.domain_outputs)
self.model = models.BasicModel(self.num_classes, self.domain_outputs)

def create_losses(self):
self.task_loss = make_loss()
Expand Down Expand Up @@ -1028,7 +1028,7 @@ def mle_for_p_d_given_y(self):
self.p_d_given_y = p_d_given_y

def create_model(self):
self.model = models.FcnModelBase(self.num_classes, self.domain_outputs)
self.model = models.BasicModel(self.num_classes, self.domain_outputs)

def compute_losses(self, x, task_y_true, domain_y_true, task_y_pred,
domain_y_pred, fe_output, training):
Expand Down
50 changes: 25 additions & 25 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,18 @@ def __init__(self, num_classes, num_domains, **kwargs):
])


class DannModel(FcnModelBase):
""" DANN adds a gradient reversal layer before the domain classifier """
class BasicModel(FcnModelBase):
""" Model without adaptation (i.e. no DANN) """
pass


class DannModelBase:
""" DANN adds a gradient reversal layer before the domain classifier
Note: we don't inherit from FcnModelBase or any other specific model because
we want to support either FcnModelBase, RnnModelBase, etc. with multiple
inheritance.
"""
def __init__(self, num_classes, num_domains, global_step,
total_steps, **kwargs):
super().__init__(num_classes, num_domains, **kwargs)
Expand All @@ -175,7 +185,12 @@ def call_domain_classifier(self, fe, task, **kwargs):
return self.domain_classifier(grl_output, **kwargs)


class HeterogeneousDannModel(DannModel):
class DannModel(DannModelBase, FcnModelBase):
""" Model with adaptation (i.e. with DANN) """
pass


class HeterogeneousDannModel(DannModelBase, FcnModelBase):
""" Heterogeneous DANN model has multiple feature extractors,
very similar to DannSmoothModel() code except this has multiple FE's
not multiple DC's """
Expand Down Expand Up @@ -219,7 +234,7 @@ def call_domain_classifier(self, fe, task, which_fe=None, **kwargs):
return self.domain_classifier(grl_output, **kwargs)


class SleepModel(DannModel):
class SleepModel(DannModelBase, FcnModelBase):
""" Sleep model is DANN but concatenating task classifier output (with stop
gradient) with feature extractor output when fed to the domain classifier """
def __init__(self, *args, **kwargs):
Expand All @@ -234,7 +249,7 @@ def call_domain_classifier(self, fe, task, **kwargs):
return self.domain_classifier(domain_input, **kwargs)


class DannSmoothModel(DannModel):
class DannSmoothModel(DannModelBase, FcnModelBase):
""" DANN Smooth model hs multiple domain classifiers,
very similar to HeterogeneousDannModel() code except this has multiple DC's
not multiple FE's """
Expand Down Expand Up @@ -319,7 +334,7 @@ def call(self, inputs, **kwargs):

class RnnModelBase(ModelBase):
""" RNN-based model - for R-DANN and VRADA """
def __init__(self, vrada, num_classes, num_domains, **kwargs):
def __init__(self, num_classes, num_domains, vrada, **kwargs):
super().__init__(**kwargs)
self.num_classes = num_classes
self.num_domains = num_domains
Expand Down Expand Up @@ -348,26 +363,11 @@ def call(self, inputs, training=None, **kwargs):
return task, domain, fe


class DannRnnModel(RnnModelBase):
""" DannModel but for RnnModelBase not FcnModelBase """
def __init__(self, vrada, num_classes, num_domains, global_step,
total_steps, **kwargs):
super().__init__(vrada, num_classes, num_domains, **kwargs)
grl_schedule = DannGrlSchedule(total_steps)
self.flip_gradient = FlipGradient(global_step, grl_schedule)

def call_domain_classifier(self, fe, task, **kwargs):
grl_output = self.flip_gradient(fe, **kwargs)
return self.domain_classifier(grl_output, **kwargs)


class VradaModel(DannRnnModel):
class VradaModel(DannModelBase, RnnModelBase):
def __init__(self, *args, **kwargs):
vrada = True
super().__init__(vrada, *args, **kwargs)
super().__init__(*args, vrada=True, **kwargs)


class RDannModel(DannRnnModel):
class RDannModel(DannModelBase, RnnModelBase):
def __init__(self, *args, **kwargs):
vrada = False
super().__init__(vrada, *args, **kwargs)
super().__init__(*args, vrada=False, **kwargs)

0 comments on commit 24465b7

Please sign in to comment.