diff --git a/utils.py b/utils.py index 888e3f1..b227031 100644 --- a/utils.py +++ b/utils.py @@ -98,7 +98,7 @@ def partition_data(dataset, datadir, logdir, partition, n_parties, beta=0.4): n_train = y_train.shape[0] - elif partition == "homo" or partition == "iid": + if partition == "homo" or partition == "iid": idxs = np.random.permutation(n_train) batch_idxs = np.array_split(idxs, n_parties) net_dataidx_map = {i: batch_idxs[i] for i in range(n_parties)}