diff --git a/tools/helper/dist_helper.py b/tools/helper/dist_helper.py index 5307aca..4c74874 100644 --- a/tools/helper/dist_helper.py +++ b/tools/helper/dist_helper.py @@ -56,6 +56,7 @@ def all_gather(data): Returns: list[data]: list of data gathered from each rank """ + DistHelper.synchronize() world_size = DistHelper.get_world_size() if world_size == 1: return [data]