improve logic for special parameter learning rates
albertfgu committed Jun 5, 2022
1 parent edb3e8b commit 33ecdf5
Showing 3 changed files with 61 additions and 37 deletions.
6 changes: 6 additions & 0 deletions s4/dss.py
@@ -425,6 +425,12 @@ class DSSLayer(nn.Module):
l_max: int
decode: bool = False

lr = {
"Lambda_re": 0.1,
"Lambda_im": 0.1,
"log_step": 0.1,
}

def setup(self):
# Learned Parameters
hippo_Lambda_real_initializer, hippo_Lambda_imag_initializer, hippo_p_initializer, hippo_B_initializer = s4.hippo_initializer(self.N)
9 changes: 9 additions & 0 deletions s4/s4.py
@@ -1220,6 +1220,15 @@ class S4Layer(nn.Module):
l_max: int
decode: bool = False

# Special parameters with multiplicative factor on lr and no weight decay (handled by main train script)
lr = {
"Lambda_re": 0.1,
"Lambda_im": 0.1,
"P": 0.1,
"B": 0.1,
"log_step": 0.1,
}

def setup(self):
# Learned Parameters (Ct is complex!)
hippo_Lambda_real_initializer, hippo_Lambda_imag_initializer, hippo_p_initializer, hippo_B_initializer = hippo_initializer(self.N)
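Note: the `lr` class attributes added above are consumed in `s4/train.py` (diff below) via `map_nested_fn` and `optax.multi_transform`. The helper's body is not part of this diff; the following is a minimal sketch of the pattern, consistent with the `def map_fn(nested_dict)` context line shown below but not the repo's code verbatim. It assumes `params` is a plain (unfrozen) nested dict, and `make_tx`/`base_lr` are illustrative names.

import optax

def map_nested_fn(fn):
    """Recursively apply fn(key, leaf) to a nested dict of parameters,
    returning a dict of the same structure (here: optimizer-group labels)."""
    def map_fn(nested_dict):
        return {
            k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
            for k, v in nested_dict.items()
        }
    return map_fn

def make_tx(lr_layer, base_lr=1e-3):
    # Hypothetical helper for illustration: one Adam instance per special parameter,
    # with the multiplier applied to the base LR and no weight decay.
    optimizers = {
        name: optax.adam(learning_rate=mult * base_lr)
        for name, mult in lr_layer.items()
    }
    # Every other parameter: AdamW with weight decay.
    optimizers["__default__"] = optax.adamw(learning_rate=base_lr, weight_decay=0.01)
    # Label each leaf by its own key if it is special, otherwise "__default__".
    label_fn = map_nested_fn(lambda k, _: k if k in lr_layer else "__default__")
    return optax.multi_transform(optimizers, label_fn)

# e.g. tx = make_tx({"Lambda_re": 0.1, "Lambda_im": 0.1, "P": 0.1, "B": 0.1, "log_step": 0.1}, base_lr=4e-3)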
83 changes: 46 additions & 37 deletions s4/train.py
@@ -57,13 +57,13 @@ def map_fn(nested_dict):


def create_train_state(
model_name,
model_cls,
rng,
in_dim=1,
bsz=128,
seq_len=784,
lr=1e-3,
lr_layer=None,
lr_schedule=False,
total_steps=-1,
):
@@ -72,49 +72,55 @@ def create_train_state(
params = model.init(
{"params": init_rng, "dropout": dropout_rng},
np.ones((bsz, seq_len, in_dim)),
)[
"params"
].unfreeze() # Note: Added immediate `unfreeze()` to play well w/ Optax. See below!
)
params = params["params"].unfreeze() # Note: Added immediate `unfreeze()` to play well w/ Optax. See below!


# Handle learning rates:
# - LR scheduler
# - Set custom learning rates on some SSM parameters

# Note for Debugging... this is all undocumented and so weird. The following links are helpful...
#
# > Flax "Recommended" interplay w/ Optax (this bridge needs ironing):
# https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md#multi-optimizer
#
# > But... masking doesn't work like the above example suggests!
# Root Explanation: https://github.com/deepmind/optax/issues/159
# Fix: https://github.com/deepmind/optax/discussions/167
#
# > Also... Flax FrozenDict doesn't play well with rest of Jax + Optax...
# https://github.com/deepmind/optax/issues/160#issuecomment-896460796
#
# > Solution: Use Optax.multi_transform!

# Implement LR Schedule (cosine one-cycle: warm up over the first ~10% of training, then cosine-decay to ~0)
if lr_schedule:
lr = optax.cosine_onecycle_schedule(
schedule_fn = lambda lr: optax.cosine_onecycle_schedule(
peak_value=lr,
transition_steps=total_steps,
pct_start=0.1,
)
else:
schedule_fn = lambda lr: lr
# lr_layer is a dictionary from parameter name to LR multiplier
if lr_layer is None: lr_layer = {}
optimizers = {
k: optax.adam(learning_rate=schedule_fn(v*lr))
for k, v in lr_layer.items()
}
# Add default optimizer
# Note: it would be better to use a dummy key such as None that can't conflict with parameter names,
# but this causes a hard-to-trace error; it seems that the transforms keys list is being sorted inside optax.multi_transform
# which causes an error since None can't be compared to str
optimizers["__default__"] = optax.adamw(
learning_rate=schedule_fn(lr),
weight_decay=0.01,
)
tx = optax.multi_transform(optimizers, map_nested_fn(lambda k, _: k if k in lr_layer else "__default__"))
# For debugging, this would be the default transform with no scheduler or special params
# tx = optax.adamw(learning_rate=lr, weight_decay=0.01)

# S4 uses a Fixed LR = 1e-3 with NO weight decay for the S4 Matrices, higher LR elsewhere
if "s4" in model_name or "dss" in model_name or "s4d" in model_name:
# Note for Debugging... this is all undocumented and so weird. The following links are helpful...
#
# > Flax "Recommended" interplay w/ Optax (this bridge needs ironing):
# https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md#multi-optimizer
#
# > But... masking doesn't work like the above example suggests!
# Root Explanation: https://github.com/deepmind/optax/issues/159
# Fix: https://github.com/deepmind/optax/discussions/167
#
# > Also... Flax FrozenDict doesn't play well with rest of Jax + Optax...
# https://github.com/deepmind/optax/issues/160#issuecomment-896460796
#
# > Solution: Use Optax.multi_transform!
s4_fn = map_nested_fn(
lambda k, _: "s4"
if k in ["Lambda", "Lambda_re", "Lambda_im", "p", "B", "W"]
else ("none" if k in [] else "regular")
)
tx = optax.multi_transform(
{
"none": optax.sgd(learning_rate=0.0),
"s4": optax.adam(learning_rate=min(1e-3, lr)),
"regular": optax.adamw(learning_rate=lr, weight_decay=0.01),
},
s4_fn,
)

else:
tx = optax.adamw(learning_rate=lr, weight_decay=0.01)

return train_state.TrainState.create(
apply_fn=model.apply, params=params, tx=tx
@@ -310,6 +316,9 @@ def example_train(
layer_args = {} if ssm_n is None else {"N": ssm_n}
layer_args["l_max"] = seq_len if classification else seq_len - 1

# Extract custom hyperparameters from model class
lr_layer = getattr(model_cls, "lr", None)


print(f"[*] Starting `{model}` Training on `{dataset}` =>> Initializing...")

@@ -324,13 +333,13 @@
classification=classification,
)
state = create_train_state(
model,
model_cls,
rng,
in_dim=in_dim,
bsz=bsz,
seq_len=seq_len if classification else seq_len - 1,
lr=lr,
lr_layer=lr_layer,
lr_schedule=lr_schedule,
total_steps=len(trainloader) * epochs,
)
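As a quick sanity check of what the new per-parameter multipliers and the schedule imply, the standalone snippet below prints the peak learning rate of each optimizer group and samples the cosine one-cycle schedule. The values `base_lr = 4e-3` and `total_steps = 10_000` are assumptions for illustration only; the real values come from the training config.

import optax

base_lr = 4e-3          # assumed base learning rate (set by the training config in practice)
total_steps = 10_000    # assumed number of training steps
lr_layer = {"Lambda_re": 0.1, "Lambda_im": 0.1, "P": 0.1, "B": 0.1, "log_step": 0.1}

# Peak LR per optimizer group: special SSM parameters use Adam with no weight decay,
# everything else uses AdamW with weight_decay=0.01.
for name, mult in lr_layer.items():
    print(f"{name:11s} peak lr = {mult * base_lr:.1e}  (adam, no weight decay)")
print(f"__default__ peak lr = {base_lr:.1e}  (adamw, weight decay 0.01)")

# With lr_schedule=True, each group's LR follows a cosine one-cycle schedule:
# warm up over the first 10% of steps, then cosine-decay towards ~0.
sched = optax.cosine_onecycle_schedule(
    transition_steps=total_steps, peak_value=base_lr, pct_start=0.1
)
print(sched(0), sched(total_steps // 10), sched(total_steps - 1))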
