From 22fb2b0c25e6f60f95e6ef5d8588e586bdbc06f7 Mon Sep 17 00:00:00 2001
From: Glenn Jocher <glenn.jocher@ultralytics.com>
Date: Fri, 26 Jun 2020 18:56:13 -0700
Subject: [PATCH] refactor dataloader

---
 test.py           | 22 +++-------------------
 train.py          | 35 ++++++-----------------------------
 utils/datasets.py | 22 +++++++++++++++++++++-
 3 files changed, 30 insertions(+), 49 deletions(-)

diff --git a/test.py b/test.py
index 8b94f54e97ee..c0bda5fbd4d3 100644
--- a/test.py
+++ b/test.py
@@ -1,8 +1,6 @@
 import argparse
 import json
 
-from torch.utils.data import DataLoader
-
 from utils import google_utils
 from utils.datasets import *
 from utils.utils import *
@@ -56,30 +54,16 @@ def test(data,
         data = yaml.load(f, Loader=yaml.FullLoader)  # model dict
     nc = 1 if single_cls else int(data['nc'])  # number of classes
     iouv = torch.linspace(0.5, 0.95, 10).to(device)  # iou vector for mAP@0.5:0.95
-    # iouv = iouv[0].view(1)  # comment for mAP@0.5:0.95
     niou = iouv.numel()
 
     # Dataloader
     if dataloader is None:  # not training
+        merge = opt.merge  # use Merge NMS
         img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
         _ = model(img.half() if half else img) if device.type != 'cpu' else None  # run once
-
-        merge = opt.merge  # use Merge NMS
         path = data['test'] if opt.task == 'test' else data['val']  # path to val/test images
-        dataset = LoadImagesAndLabels(path,
-                                      imgsz,
-                                      batch_size,
-                                      rect=True,  # rectangular inference
-                                      single_cls=opt.single_cls,  # single class mode
-                                      stride=int(max(model.stride)),  # model stride
-                                      pad=0.5)  # padding
-        batch_size = min(batch_size, len(dataset))
-        nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
-        dataloader = DataLoader(dataset,
-                                batch_size=batch_size,
-                                num_workers=nw,
-                                pin_memory=True,
-                                collate_fn=dataset.collate_fn)
+        dataloader = create_dataloader(path, imgsz, batch_size, int(max(model.stride)), opt,
+                                       hyp=None, augment=False, cache=False, pad=0.5, rect=True)[0]
 
     seen = 0
     names = model.names if hasattr(model, 'names') else model.module.names
diff --git a/train.py b/train.py
index 94139a033918..4238713fbc68 100644
--- a/train.py
+++ b/train.py
@@ -155,38 +155,15 @@ def train(hyp):
         model = torch.nn.parallel.DistributedDataParallel(model)
         # pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html
 
-    # Dataset
-    dataset = LoadImagesAndLabels(train_path, imgsz, batch_size,
-                                  augment=True,
-                                  hyp=hyp,  # augmentation hyperparameters
-                                  rect=opt.rect,  # rectangular training
-                                  cache_images=opt.cache_images,
-                                  single_cls=opt.single_cls,
-                                  stride=gs)
+    # Trainloader
+    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
+                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect)
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
     assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)
 
-    # Dataloader
-    batch_size = min(batch_size, len(dataset))
-    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
-    dataloader = torch.utils.data.DataLoader(dataset,
-                                             batch_size=batch_size,
-                                             num_workers=nw,
-                                             shuffle=not opt.rect,  # Shuffle=True unless rectangular training is used
-                                             pin_memory=True,
-                                             collate_fn=dataset.collate_fn)
-
     # Testloader
-    testloader = torch.utils.data.DataLoader(LoadImagesAndLabels(test_path, imgsz_test, batch_size,
-                                                                 hyp=hyp,
-                                                                 rect=True,
-                                                                 cache_images=opt.cache_images,
-                                                                 single_cls=opt.single_cls,
-                                                                 stride=gs),
-                                             batch_size=batch_size,
-                                             num_workers=nw,
-                                             pin_memory=True,
-                                             collate_fn=dataset.collate_fn)
+    testloader = create_dataloader(test_path, imgsz_test, batch_size, gs, opt,
+                                            hyp=hyp, augment=False, cache=opt.cache_images, rect=True)[0]
 
     # Model parameters
     hyp['cls'] *= nc / 80.  # scale coco-tuned hyp['cls'] to current dataset
@@ -218,7 +195,7 @@ def train(hyp):
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
     print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
-    print('Using %g dataloader workers' % nw)
+    print('Using %g dataloader workers' % dataloader.num_workers)
     print('Starting training for %g epochs...' % epochs)
     # torch.autograd.set_detect_anomaly(True)
     for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
diff --git a/utils/datasets.py b/utils/datasets.py
index 37d773efa394..00f23384295a 100755
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -41,6 +41,26 @@ def exif_size(img):
     return s
 
 
+def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False):
+    dataset = LoadImagesAndLabels(path, imgsz, batch_size,
+                                  augment=augment,  # augment images
+                                  hyp=hyp,  # augmentation hyperparameters
+                                  rect=rect,  # rectangular training
+                                  cache_images=cache,
+                                  single_cls=opt.single_cls,
+                                  stride=stride,
+                                  pad=pad)
+
+    batch_size = min(batch_size, len(dataset))
+    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 0])  # number of workers
+    dataloader = torch.utils.data.DataLoader(dataset,
+                                             batch_size=batch_size,
+                                             num_workers=nw,
+                                             pin_memory=True,
+                                             collate_fn=LoadImagesAndLabels.collate_fn)
+    return dataloader, dataset
+
+
 class LoadImages:  # for inference
     def __init__(self, path, img_size=416):
         path = str(Path(path))  # os-agnostic
@@ -712,7 +732,7 @@ def random_affine(img, targets=(), degrees=10, translate=.1, scale=.1, shear=10,
         area = w * h
         area0 = (targets[:, 3] - targets[:, 1]) * (targets[:, 4] - targets[:, 2])
         ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16))  # aspect ratio
-        i = (w > 4) & (h > 4) & (area / (area0 * s + 1e-16) > 0.2) & (ar < 10)
+        i = (w > 2) & (h > 2) & (area / (area0 * s + 1e-16) > 0.2) & (ar < 20)
 
         targets = targets[i]
         targets[:, 1:5] = xy[i]