Skip to content

Commit 354ebb4

Browse files
committed
Fix Trust NCG import
1 parent cfa54f8 commit 354ebb4

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

transfer_model/optimizers/optim_factory.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import torch
2222
import torch.optim as optim
2323
from loguru import logger
24-
from torchtrustncg import TrustRegionNewtonCG
24+
from torchtrustncg import TrustRegion
2525

2626
Tensor = NewType('Tensor', torch.Tensor)
2727

@@ -46,7 +46,7 @@ def build_optimizer(parameters: List[Tensor],
4646
optimizer = optim.LBFGS(parameters, **optim_cfg.get('lbfgs', {}))
4747
create_graph = False
4848
elif optim_type == 'trust_ncg' or optim_type == 'trust-ncg':
49-
optimizer = TrustRegionNewtonCG(
49+
optimizer = TrustRegion(
5050
parameters, **optim_cfg.get('trust_ncg', {}))
5151
create_graph = True
5252
elif optim_type == 'rmsprop':

0 commit comments

Comments
 (0)