Commit

initial heterogeneous DA
- {dann,daws}_hda methods have 2 feature extractors (see the sketch below)
- pass is_target into eval_step() during eval to know which FE to use
- script to test how Python's multiple inheritance works
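
As a minimal standalone sketch of that two-extractor idea (layer sizes and input shapes below are hypothetical, not taken from this repo): source and target each get their own feature extractor mapping a different input space into a common latent space, which the shared task and domain classifiers then consume.

    import tensorflow as tf

    latent_dim = 10  # hypothetical shared latent size
    fe_source = tf.keras.Sequential([tf.keras.layers.Dense(latent_dim)])
    fe_target = tf.keras.Sequential([tf.keras.layers.Dense(latent_dim)])

    x_s = tf.zeros([8, 3])  # source batch: 3 features
    x_t = tf.zeros([8, 6])  # target batch: 6 features, a different feature space
    z_s = fe_source(x_s)    # shape [8, 10]
    z_t = fe_target(x_t)    # shape [8, 10], same latent space as z_s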
floft committed Apr 27, 2020
1 parent aba20df commit 241a70c
Showing 5 changed files with 247 additions and 23 deletions.
4 changes: 2 additions & 2 deletions kamiak_config.sh
@@ -5,5 +5,5 @@
modelFolder="kamiak-models"
logFolder="kamiak-logs"
remotessh="kamiak" # in your .ssh/config file
remotedir="/data/vcea/garrett.wilson/codats/"
localdir="/home/garrett/Documents/Github/codats/"
remotedir="/data/vcea/garrett.wilson/heterogeneous-da/"
localdir="/home/garrett/Documents/Github/heterogeneous-da/"
137 changes: 128 additions & 9 deletions methods.py
@@ -244,7 +244,7 @@ def post_data_eval(self, task_y_true, task_y_pred, domain_y_true,
domain_y_pred = tf.nn.softmax(domain_y_pred)
return task_y_true, task_y_pred, domain_y_true, domain_y_pred

-def call_model(self, x, **kwargs):
+def call_model(self, x, is_target=None, **kwargs):
return self.model(x, **kwargs)

def compute_losses(self, x, task_y_true, domain_y_true, task_y_pred,
@@ -277,7 +277,7 @@ def train_step(self, data_sources, data_target):
del tape
self.apply_gradients(gradients)

-def eval_step(self, data):
+def eval_step(self, data, is_target):
""" Evaluate a batch of source or target data, called in metrics.py.
This preprocesses the data to have x, y, domain always be lists so
we can use the same compiled tf.function code in eval_step_list() for
@@ -291,15 +291,16 @@ def eval_step(self, data):
if not isinstance(domain, list):
domain = [domain]

-return self.eval_step_list((x, y, domain))
+return self.eval_step_list((x, y, domain), is_target)

#@tf.function # faster not to compile
-def eval_step_list(self, data):
+def eval_step_list(self, data, is_target):
""" Override preparation in prepare_data_eval() """
x, task_y_true, domain_y_true = self.prepare_data_eval(data)

# Run through model
-task_y_pred, domain_y_pred, fe_output = self.call_model(x, training=False)
+task_y_pred, domain_y_pred, fe_output = self.call_model(x,
+is_target=is_target, training=False)

# Calculate losses
losses = self.compute_losses(x, task_y_true, domain_y_true,
@@ -313,7 +314,7 @@ def eval_step_list(self, data):
return task_y_true, task_y_pred, domain_y_true, domain_y_pred, losses

#
-# Domain adaptation
+# Homogeneous domain adaptation
#

# The base method class performs no adaptation
@@ -420,7 +421,8 @@ def domain_label(self, index, is_target):
@register_method("dann_smooth")
class MethodDannSmooth(MethodDannGS):
""" MDAN Smooth method based on MethodDannGS since we want binary source = 1,
-target = 0 for the domain labels """
+target = 0 for the domain labels, very similar to HeterogeneousBase()
+code except this has multiple DC's not multiple FE's """
def create_model(self):
self.model = models.DannSmoothModel(
self.num_classes, self.domain_outputs, # Note: domain_outputs=2
@@ -468,7 +470,7 @@ def prepare_data_eval(self, data):

return x, y, domain

-def call_model(self, x, **kwargs):
+def call_model(self, x, is_target=None, **kwargs):
""" Run each source-target pair through model separately, using the
corresponding domain classifier. """
task_y_pred = []
@@ -482,7 +484,7 @@ def call_model(self, x, **kwargs):

for i in range(len(x)):
i_task_y_pred, i_domain_y_pred, i_fe_output = \
-self.model(x[i], domain_classifier=i, **kwargs)
+self.model(x[i], which_dc=i, **kwargs)
task_y_pred.append(i_task_y_pred)
domain_y_pred.append(i_domain_y_pred)
fe_output.append(i_fe_output)
@@ -699,6 +701,123 @@ def compute_gradients(self, tape, losses):
return super().compute_gradients(tape, [total_loss, task_loss, d_loss])


#
# Heterogeneous domain adaptation
#

class HeterogeneousBase:
""" Handle multiple feature extractors, very similar to MethodDannSmooth()
code except this has multiple FE's not multiple DC's """
def __init__(self, *args, **kwargs):
# Otherwise, with multiple inheritance, the other init's aren't called.
super().__init__(*args, **kwargs)

def create_model(self):
# For now we assume all sources have the same feature space. So, we need
# two feature extractors -- one for source and one for target.
num_feature_extractors = 2

self.model = models.HeterogeneousDannModel(
self.num_classes, self.domain_outputs,
self.global_step, self.total_steps,
num_feature_extractors)

def prepare_data(self, data_sources, data_target):
""" Prepare a batch of all source(s) data and target data separately,
so we run through the source/target feature extractors separately """
assert data_target is not None, \
"cannot run Heterogeneous DA without target"
x_a, y_a, domain_a = data_sources
x_b, y_b, domain_b = data_target

# Note: x_b, etc. isn't a list so doesn't need concat
x = [tf.concat(x_a, axis=0), x_b]
task_y_true = [tf.concat(y_a, axis=0), y_b]
domain_y_true = [tf.concat(domain_a, axis=0), domain_b]

return x, task_y_true, domain_y_true

def prepare_data_eval(self, data):
""" Don't concatenate elements of the list like in the base class since
we want to handle the source/target domains separately, to pass to the
right feature extractors."""
x, y, domain = data

assert isinstance(x, list), \
"Must pass x=[...] even if only one domain for tf.function consistency"
assert isinstance(y, list), \
"Must pass y=[...] even if only one domain for tf.function consistency"
assert isinstance(domain, list), \
"Must pass domain=[...] even if only one domain for tf.function consistency"

return x, y, domain

def call_model(self, x, is_target=None, training=None, **kwargs):
""" Run each source/target through appropriate feature extractor.
If is_target=None, then this is training. If is_target=True, then this
is evaluation of target data, and if is_target=False, then this is
evaluation of source data. """
task_y_pred = []
domain_y_pred = []
fe_output = []

# len(x) should be 2 during training (source and target) or 1 during evaluation (a single domain)
assert (training is True and is_target is None and len(x) == 2) \
or (training is False and (is_target is True or is_target is False)
and len(x) == 1), \
"is_target=None and len(x)=2 during training but " \
"is_target=True/False and len(x)=1 during testing"

for i in range(len(x)):
# At test time, we set source/target explicitly -- use appropriate
# feature extractor: sources = 0, target = 1 (see ordering in
# prepare_data)
if is_target is not None:
which_fe = 1 if is_target else 0
else:
which_fe = i

i_task_y_pred, i_domain_y_pred, i_fe_output = \
self.model(x[i], which_fe=which_fe, training=training, **kwargs)
task_y_pred.append(i_task_y_pred)
domain_y_pred.append(i_domain_y_pred)
fe_output.append(i_fe_output)

return task_y_pred, domain_y_pred, fe_output

def compute_losses(self, x, task_y_true, domain_y_true, task_y_pred,
domain_y_pred, fe_output, training):
""" Concatenate, then parent class's loss (e.g. DANN or DA-WS) """
x = tf.concat(x, axis=0)
task_y_true = tf.concat(task_y_true, axis=0)
domain_y_true = tf.concat(domain_y_true, axis=0)
task_y_pred = tf.concat(task_y_pred, axis=0)
domain_y_pred = tf.concat(domain_y_pred, axis=0)
fe_output = tf.concat(fe_output, axis=0)
return super().compute_losses(x, task_y_true, domain_y_true, task_y_pred,
domain_y_pred, fe_output, training)

def post_data_eval(self, task_y_true, task_y_pred, domain_y_true,
domain_y_pred):
""" Concatenate, then parent class's post_data_eval """
task_y_true = tf.concat(task_y_true, axis=0)
task_y_pred = tf.concat(task_y_pred, axis=0)
domain_y_true = tf.concat(domain_y_true, axis=0)
domain_y_pred = tf.concat(domain_y_pred, axis=0)
return super().post_data_eval(task_y_true, task_y_pred, domain_y_true,
domain_y_pred)


@register_method("dann_hda")
class MethodHeterogeneousDann(HeterogeneousBase, MethodDann):
pass


@register_method("daws_hda")
class MethodHeterogeneousDaws(HeterogeneousBase, MethodDaws):
pass


#
# Domain generalization
#
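To make the routing in HeterogeneousBase.call_model() above concrete, here is its selection rule pulled out as a plain function (an illustrative paraphrase, not code from the diff):

    def pick_feature_extractor(i, is_target):
        # Ordering from prepare_data(): sources = FE 0, target = FE 1.
        # During evaluation is_target says explicitly which FE to use;
        # during training (is_target=None) we route by position in x.
        if is_target is not None:
            return 1 if is_target else 0
        return i

So a training batch x = [x_sources, x_target] sends index 0 through the source FE and index 1 through the target FE, while an eval batch x = [x_one_domain] goes through whichever FE is_target selects.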
6 changes: 5 additions & 1 deletion metrics.py
@@ -308,7 +308,8 @@ def _run_single_batch(self, data, dataset_name, domain_name):
assert dataset_name in self.datasets, "unknown dataset "+str(dataset_name)
assert domain_name in self.domains, "unknown domain "+str(domain_name)

-results = self.method.eval_step(data)
+is_target = domain_name == "target"
+results = self.method.eval_step(data, is_target)

classifier = "task" # Which classifier's task_y_pred are we looking at?
self._process_batch(results, classifier, domain_name, dataset_name)
@@ -409,6 +410,9 @@ def plots(self, global_step):
first_time = step == 1

# Generate plots
+#
+# TODO probably broken with heterogeneous DA models since they have more
+# than one feature extractor, so this will error
t = time.time()
plots = generate_plots(data_a, data_b,
self.method.model.feature_extractor, first_time)
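One way to act on that TODO, as a hypothetical guard that is not part of this commit: skip plotting when the model holds a list of feature extractors, since generate_plots() receives a single one.

    # Hypothetical guard, not in this commit:
    fe = self.method.model.feature_extractor
    if isinstance(fe, list):
        return  # heterogeneous models have no single FE to plot
    plots = generate_plots(data_a, data_b, fe, first_time)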
75 changes: 64 additions & 11 deletions models.py
@@ -64,9 +64,13 @@ class ModelBase(tf.keras.Model):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

+@property
+def trainable_variables_fe(self):
+return self.feature_extractor.trainable_variables
+
@property
def trainable_variables_task(self):
-return self.feature_extractor.trainable_variables \
+return self.trainable_variables_fe \
+ self.task_classifier.trainable_variables

@property
@@ -75,7 +79,7 @@ def trainable_variables_domain(self):

@property
def trainable_variables_task_domain(self):
-return self.feature_extractor.trainable_variables \
+return self.trainable_variables_fe \
+ self.task_classifier.trainable_variables \
+ self.trainable_variables_domain

@@ -171,6 +175,52 @@ def call_domain_classifier(self, fe, task, **kwargs):
return self.domain_classifier(grl_output, **kwargs)


class HeterogeneousDannModel(DannModel):
""" Heterogeneous DANN model has multiple feature extractors,
very similar to DannSmoothModel() code except this has multiple FE's
not multiple DC's """
def __init__(self, *args, num_feature_extractors, **kwargs):
super().__init__(*args, **kwargs)

# Requires multiple feature extractors
new_feature_extractor = [self.feature_extractor]

# Start at 1 since we already have one
for i in range(1, num_feature_extractors):
new_feature_extractor.append(
tf.keras.models.clone_model(self.feature_extractor))

self.feature_extractor = new_feature_extractor

@property
def trainable_variables_fe(self):
# We have multiple feature extractors, so get all variables
fe_vars = []

for fe in self.feature_extractor:
fe_vars += fe.trainable_variables

return fe_vars

def call_feature_extractor(self, inputs, which_fe=None, **kwargs):
# self.feature_extractor is a list here, so which_fe selects the
# source (0) or target (1) feature extractor
assert which_fe is not None, \
"must specify which feature extractor to use"
return self.feature_extractor[which_fe](inputs, **kwargs)

def call_task_classifier(self, fe, which_fe=None, **kwargs):
# Override so we don't pass which_fe argument to model
return self.task_classifier(fe, **kwargs)

def call_domain_classifier(self, fe, task, which_fe=None, **kwargs):
# Override so we don't pass which_fe argument to model
# Copy of the DANN version only with above arg change
grl_output = self.flip_gradient(fe, **kwargs)
return self.domain_classifier(grl_output, **kwargs)


class SleepModel(DannModel):
""" Sleep model is DANN but concatenating task classifier output (with stop
gradient) with feature extractor output when fed to the domain classifier """
@@ -187,6 +237,9 @@ def call_domain_classifier(self, fe, task, **kwargs):


class DannSmoothModel(DannModel):
""" DANN Smooth model hs multiple domain classifiers,
very similar to HeterogeneousDannModel() code except this has multiple DC's
not multiple FE's """
def __init__(self, *args, num_domain_classifiers, **kwargs):
# For MDAN Smooth, it's binary classification but we have a separate
# discriminator for each source-target pair.
@@ -213,20 +266,20 @@ def trainable_variables_domain(self):

return domain_vars

-def call_feature_extractor(self, inputs, **kwargs):
-# Override so we don't pass domain_classifier argument to model
-return self.feature_extractor(inputs)
+def call_feature_extractor(self, inputs, which_dc=None, **kwargs):
+# Override so we don't pass which_dc argument to model
+return self.feature_extractor(inputs, **kwargs)

-def call_task_classifier(self, fe, **kwargs):
-# Override so we don't pass domain_classifier argument to model
-return self.task_classifier(fe)
+def call_task_classifier(self, fe, which_dc=None, **kwargs):
+# Override so we don't pass which_dc argument to model
+return self.task_classifier(fe, **kwargs)

-def call_domain_classifier(self, fe, task, domain_classifier=None, **kwargs):
-assert domain_classifier is not None, \
+def call_domain_classifier(self, fe, task, which_dc=None, **kwargs):
+assert which_dc is not None, \
"must specify which domain classifier to use with method Smooth"
grl_output = self.flip_gradient(fe, **kwargs)
# 0 = source domain 1 with target, 1 = source domain 2 with target, etc.
-return self.domain_classifier[domain_classifier](grl_output, **kwargs)
+return self.domain_classifier[which_dc](grl_output, **kwargs)


class VradaFeatureExtractor(tf.keras.Model):
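A note on the clone in HeterogeneousDannModel.__init__ above: tf.keras.models.clone_model() copies the architecture but gives the clone freshly initialized weights, so the source and target feature extractors start out independent. A quick self-contained check (shapes here are arbitrary):

    import tensorflow as tf

    fe = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
    fe2 = tf.keras.models.clone_model(fe)  # same layers, fresh weights
    fe2.build((None, 8))
    # identical architecture, independently initialized parameters
    assert fe.count_params() == fe2.count_params()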
48 changes: 48 additions & 0 deletions multiple_inheritance_check.py
@@ -0,0 +1,48 @@
#!/usr/bin/env python3
"""
Test heterogeneous multiple inheritance
"""


class MethodBase:
def __init__(self):
print("MethodBase")


class MethodDann(MethodBase):
def __init__(self):
super().__init__()
print("MethodDann")


class MethodDaws(MethodBase):
def __init__(self):
super().__init__()
print("MethodDaws")


class HeterogeneousBase:
def __init__(self):
super().__init__()
print("HeterogeneousBase")


class HeterogeneousDann(HeterogeneousBase, MethodDann):
pass
# def __init__(self):
# super().__init__()
# print("HeterogeneousDann")
# print(HeterogeneousDann.__mro__)


class HeterogeneousDaws(HeterogeneousBase, MethodDaws):
pass
# def __init__(self):
# super().__init__()
# print("HeterogeneousDaws")
# print(HeterogeneousDaws.__mro__)


if __name__ == "__main__":
a = HeterogeneousDann()
# b = HeterogeneousDaws()
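
For reference, with C3 linearization the MRO of HeterogeneousDann is HeterogeneousDann -> HeterogeneousBase -> MethodDann -> MethodBase -> object, so running the script should print, innermost base first:

    MethodBase
    MethodDann
    HeterogeneousBase

This is why HeterogeneousBase.__init__() calls super().__init__(): without it, the MethodDann (or MethodDaws) side of the hierarchy would never be initialized.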
