Skip to content

Commit

Permalink
In the second stage (swa) with different weighting, the original code…
Browse files Browse the repository at this point in the history
… used different loss function as compared with that chosed in the first stage. I modified the default loss function in the second stage as UnversalLoss, which may just be right for my purpose of studies. I would suggest to define a separate function and invoke it in these two stages to keep consistence.
  • Loading branch information
Elysiron authored Jul 6, 2024
1 parent 95ef29e commit fae8de4
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,12 +620,14 @@ def run(args: argparse.Namespace) -> None:
f"Using stochastic weight averaging (after {args.start_swa} epochs) with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, dipole weight : {args.swa_dipole_weight} and learning rate : {args.swa_lr}"
)
else:
loss_fn_energy = modules.WeightedEnergyForcesLoss(
loss_fn_energy = modules.UniversalLoss(
energy_weight=args.swa_energy_weight,
forces_weight=args.swa_forces_weight,
stress_weight=args.swa_stress_weight,
huber_delta=args.huber_delta,
)
logging.info(
f"Using stochastic weight averaging (after {args.start_swa} epochs) with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight} and learning rate : {args.swa_lr}"
f"Using stochastic weight averaging (after {args.start_swa} epochs) with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, stress weight : {args.swa_stress_weight} and learning rate : {args.swa_lr}"
)
swa = tools.SWAContainer(
model=AveragedModel(model),
Expand Down

0 comments on commit fae8de4

Please sign in to comment.