Skip to content

Commit

Permalink
Fleet: deal with special case: strategy is None (PaddlePaddle#20359)
Browse files Browse the repository at this point in the history
* special case: strategy is None
  • Loading branch information
mapingshuo authored Oct 15, 2019
1 parent 1d82025 commit f55d1c6
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/paddle/fluid/incubate/fleet/collective/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class CollectiveOptimizer(DistributedOptimizer):

def __init__(self, optimizer, strategy=DistributedStrategy()):
super(CollectiveOptimizer, self).__init__(optimizer, strategy)
if strategy.forward_recompute:
if strategy is not None and strategy.forward_recompute:
self.forward_recompute = True
self.recompute_checkpoints = strategy.recompute_checkpoints
else:
Expand Down
7 changes: 7 additions & 0 deletions python/paddle/fluid/tests/unittests/test_fleet_api_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from paddle.fluid.incubate.fleet.base.role_maker import Role
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import TranspilerOptimizer
from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer


class DistributeTranspilerConfigTest(unittest.TestCase):
Expand Down Expand Up @@ -204,5 +205,11 @@ def testRoleMaker(self):
) # current_id must be less than len(worker_endpoints)


class CollectiveOptimizerTest(unittest.TestCase):
def test_ds_as_None(self):
optimizer = fluid.optimizer.AdamOptimizer()
dist_optimizer = CollectiveOptimizer(optimizer, strategy=None)


if __name__ == '__main__':
unittest.main()

0 comments on commit f55d1c6

Please sign in to comment.