Skip to content

Commit 36bb03d

Browse files
committed
Add tests for training
1 parent c49f518 commit 36bb03d

File tree

6 files changed

+165
-32
lines changed

6 files changed

+165
-32
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,8 @@ summary/
3939
#Built graphs
4040
built_graph/
4141

42+
#Training checkpoints
43+
ckpt/*
44+
4245
#pytest cache
4346
.cache/

test/test_darkflow.py

+74-32
Original file line numberDiff line numberDiff line change
@@ -12,26 +12,36 @@
1212
# locally if you don't want this happening!)
1313

1414
#Settings
15-
imgWidth = 640
16-
imgHeight = 424
1715
buildPath = os.environ.get("TRAVIS_BUILD_DIR")
1816

1917
if buildPath is None:
2018
print()
2119
print("TRAVIS_BUILD_DIR environment variable was not found - is this running on TravisCI?")
2220
print("If you want to test this locally, set TRAVIS_BUILD_DIR to the base directory of the cloned darkflow repository.")
2321
exit()
24-
testImgPath = os.path.join(buildPath, "sample_img", "sample_person.jpg")
25-
expectedDetectedObjectsV1 = [{"label": "dog","confidence": 0.46,"topleft": {"x": 84, "y": 249},"bottomright": {"x": 208,"y": 367}},
26-
{"label": "person","confidence": 0.60,"topleft": {"x": 159, "y": 102},"bottomright": {"x": 304,"y": 365}}]
2722

28-
expectedDetectedObjectsV2 = [{"label":"person","confidence":0.82,"topleft":{"x":189,"y":96},"bottomright":{"x":271,"y":380}},
29-
{"label":"dog","confidence":0.79,"topleft":{"x":69,"y":258},"bottomright":{"x":209,"y":354}},
30-
{"label":"horse","confidence":0.89,"topleft":{"x":397,"y":127},"bottomright":{"x":605,"y":352}}]
23+
testImg = {"path": os.path.join(buildPath, "sample_img", "sample_person.jpg"), "width": 640, "height": 424,
24+
"expected-objects": {"yolo-small": [{"label": "dog", "confidence": 0.46, "topleft": {"x": 84, "y": 249}, "bottomright": {"x": 208, "y": 367}},
25+
{"label": "person", "confidence": 0.60, "topleft": {"x": 159, "y": 102}, "bottomright": {"x": 304, "y": 365}}],
26+
"yolo": [{"label": "person", "confidence": 0.82, "topleft": {"x": 189, "y": 96}, "bottomright": {"x": 271, "y": 380}},
27+
{"label": "dog", "confidence": 0.79, "topleft": {"x": 69, "y": 258}, "bottomright": {"x": 209, "y": 354}},
28+
{"label": "horse", "confidence": 0.89, "topleft": {"x": 397, "y": 127}, "bottomright": {"x": 605, "y": 352}}]}}
29+
30+
trainImgBikePerson = {"path": os.path.join(buildPath, "test", "training", "images", "1.jpg"), "width": 500, "height": 375,
31+
"expected-objects": {"tiny-yolo-voc-TRAINED": [{"label": "bicycle", "confidence": 0.46, "topleft": {"x": 121, "y": 126}, "bottomright": {"x": 234, "y": 244}},
32+
{"label": "cow", "confidence": 0.54, "topleft": {"x": 262, "y": 218}, "bottomright": {"x": 385, "y": 311}},
33+
{"label": "person", "confidence": 0.70, "topleft": {"x": 132, "y": 34}, "bottomright": {"x": 232, "y": 167}}]}}
34+
35+
trainImgHorsePerson = {"path": os.path.join(buildPath, "test", "training", "images", "2.jpg"), "width": 500, "height": 332,
36+
"expected-objects": {"tiny-yolo-voc-TRAINED": [{"label": "horse", "confidence": 0.97, "topleft": {"x": 157, "y": 95}, "bottomright": {"x": 420, "y": 304}},
37+
{"label": "person", "confidence": 0.89, "topleft": {"x": 258, "y": 53}, "bottomright": {"x": 300, "y": 218}}]}}
38+
39+
3140
posCompareThreshold = 0.05 #Comparisons must match be within 5% of width/height when compared to expected value
3241
threshCompareThreshold = 0.1 #Comparisons must match within 0.1 of expected threshold for each prediction
33-
yoloDownloadV1 = "https://pjreddie.com/media/files/yolo-small.weights"
34-
yoloDownloadV2 = "https://pjreddie.com/media/files/yolo.weights"
42+
yolo_small_Download = "https://pjreddie.com/media/files/yolo-small.weights" #YOLOv1
43+
yolo_Download = "https://pjreddie.com/media/files/yolo.weights" #YOLOv2
44+
tiny_yolo_voc_Download = "https://pjreddie.com/media/files/tiny-yolo-voc.weights" #YOLOv2
3545

3646
def download_file(url, savePath):
3747
fileName = savePath.split("/")[-1]
@@ -47,19 +57,23 @@ def download_file(url, savePath):
4757
else:
4858
print("Found existing " + fileName + " file.")
4959

50-
yoloWeightPathV1 = os.path.join(buildPath, "bin", yoloDownloadV1.split("/")[-1])
51-
yoloCfgPathV1 = os.path.join(buildPath, "cfg", "v1", "{0}.cfg".format(os.path.splitext(os.path.basename(yoloWeightPathV1))[0]))
60+
yolo_small_WeightPath = os.path.join(buildPath, "bin", yolo_small_Download.split("/")[-1])
61+
yolo_small_CfgPath = os.path.join(buildPath, "cfg", "v1", "{0}.cfg".format(os.path.splitext(os.path.basename(yolo_small_WeightPath))[0]))
5262

53-
yoloWeightPathV2 = os.path.join(buildPath, "bin", yoloDownloadV2.split("/")[-1])
54-
yoloCfgPathV2 = os.path.join(buildPath, "cfg", "{0}.cfg".format(os.path.splitext(os.path.basename(yoloWeightPathV2))[0]))
63+
yolo_WeightPath = os.path.join(buildPath, "bin", yolo_Download.split("/")[-1])
64+
yolo_CfgPath = os.path.join(buildPath, "cfg", "{0}.cfg".format(os.path.splitext(os.path.basename(yolo_WeightPath))[0]))
5565

56-
pbPath = os.path.join(buildPath, "built_graph", os.path.splitext(os.path.basename(yoloWeightPathV2))[0] + ".pb")
57-
metaPath = os.path.join(buildPath, "built_graph", os.path.splitext(os.path.basename(yoloWeightPathV2))[0] + ".meta")
66+
tiny_yolo_voc_WeightPath = os.path.join(buildPath, "bin", tiny_yolo_voc_Download.split("/")[-1])
67+
tiny_yolo_voc_CfgPath = os.path.join(buildPath, "cfg", "{0}.cfg".format(os.path.splitext(os.path.basename(tiny_yolo_voc_WeightPath))[0]))
68+
69+
pbPath = os.path.join(buildPath, "built_graph", os.path.splitext(os.path.basename(yolo_WeightPath))[0] + ".pb")
70+
metaPath = os.path.join(buildPath, "built_graph", os.path.splitext(os.path.basename(yolo_WeightPath))[0] + ".meta")
5871

5972
generalConfigPath = os.path.join(buildPath, "cfg")
6073

61-
download_file(yoloDownloadV1, yoloWeightPathV1) #Check if we need to download (and if so download) the YOLOv1 weights
62-
download_file(yoloDownloadV2, yoloWeightPathV2) #Check if we need to download (and if so download) the YOLOv2 weights
74+
download_file(yolo_small_Download, yolo_small_WeightPath) #Check if we need to download (and if so download) the yolo-small weights (YOLOv1)
75+
download_file(yolo_Download, yolo_WeightPath) #Check if we need to download (and if so download) the yolo weights (YOLOv2)
76+
download_file(tiny_yolo_voc_Download, tiny_yolo_voc_WeightPath) #Check if we need to download (and if so download) the tiny-yolo-voc weights (YOLOv2)
6377

6478
def executeCLI(commandString):
6579
print()
@@ -95,43 +109,45 @@ def compareObjectData(defaultObjects, newObjects, width, height):
95109
return True
96110

97111
#Delete all images that won't be tested on so forwarding the whole folder doesn't take forever
98-
filelist = [f for f in os.listdir(os.path.dirname(testImgPath)) if os.path.isfile(os.path.join(os.path.dirname(testImgPath), f)) and f != os.path.basename(testImgPath)]
112+
filelist = [f for f in os.listdir(os.path.dirname(testImg["path"])) if os.path.isfile(os.path.join(os.path.dirname(testImg["path"]), f)) and f != os.path.basename(testImg["path"])]
99113
for f in filelist:
100-
os.remove(os.path.join(os.path.dirname(testImgPath), f))
114+
os.remove(os.path.join(os.path.dirname(testImg["path"]), f))
115+
101116

117+
#TESTS FOR INFERENCE
102118
def test_CLI_IMG_YOLOv2():
103119
#Test predictions outputted to an image using the YOLOv2 model through CLI
104120
#NOTE: This test currently does not verify anything about the image created (i.e. proper labeling, proper positioning of prediction boxes, etc.)
105121
# it simply verifies that the code executes properly and that the expected output image is indeed created in ./test/img/out
106122

107-
testString = "flow --imgdir {0} --model {1} --load {2} --config {3} --threshold 0.4".format(os.path.dirname(testImgPath), yoloCfgPathV2, yoloWeightPathV2, generalConfigPath)
123+
testString = "flow --imgdir {0} --model {1} --load {2} --config {3} --threshold 0.4".format(os.path.dirname(testImg["path"]), yolo_CfgPath, yolo_WeightPath, generalConfigPath)
108124
executeCLI(testString)
109125

110-
outputImgPath = os.path.join(os.path.dirname(testImgPath), "out", os.path.basename(testImgPath))
126+
outputImgPath = os.path.join(os.path.dirname(testImg["path"]), "out", os.path.basename(testImg["path"]))
111127
assert os.path.exists(outputImgPath), "Expected output image: {0} was not found.".format(outputImgPath)
112128

113129
def test_CLI_JSON_YOLOv2():
114130
#Test predictions outputted to a JSON file using the YOLOv2 model through CLI
115131
#NOTE: This test verifies that the code executes properly, the JSON file is created properly and the predictions generated are within a certain
116132
# margin of error when compared to the expected predictions.
117133

118-
testString = "flow --imgdir {0} --model {1} --load {2} --config {3} --threshold 0.4 --json".format(os.path.dirname(testImgPath), yoloCfgPathV2, yoloWeightPathV2, generalConfigPath)
134+
testString = "flow --imgdir {0} --model {1} --load {2} --config {3} --threshold 0.4 --json".format(os.path.dirname(testImg["path"]), yolo_CfgPath, yolo_WeightPath, generalConfigPath)
119135
executeCLI(testString)
120136

121-
outputJSONPath = os.path.join(os.path.dirname(testImgPath), "out", os.path.splitext(os.path.basename(testImgPath))[0] + ".json")
137+
outputJSONPath = os.path.join(os.path.dirname(testImg["path"]), "out", os.path.splitext(os.path.basename(testImg["path"]))[0] + ".json")
122138
assert os.path.exists(outputJSONPath), "Expected output JSON file: {0} was not found.".format(outputJSONPath)
123139

124140
with open(outputJSONPath) as json_file:
125141
loadedPredictions = json.load(json_file)
126142

127-
assert compareObjectData(expectedDetectedObjectsV2, loadedPredictions, imgWidth, imgHeight), "Generated object predictions from JSON were not within margin of error compared to expected values."
143+
assert compareObjectData(testImg["expected-objects"]["yolo"], loadedPredictions, testImg["width"], testImg["height"]), "Generated object predictions from JSON were not within margin of error compared to expected values."
128144

129145
def test_CLI_SAVEPB_YOLOv2():
130146
#Save .pb and .meta as generated from the YOLOv2 model through CLI
131-
#NOTE: This test verifies that the code executes properly, and the .pb and .meta files are successfully created. A subsequent test will verify the
147+
#NOTE: This test verifies that the code executes properly, and the .pb and .meta files are successfully created. The subsequent test will verify the
132148
# contents of those files.
133149

134-
testString = "flow --model {0} --load {1} --config {2} --threshold 0.4 --savepb".format(yoloCfgPathV2, yoloWeightPathV2, generalConfigPath)
150+
testString = "flow --model {0} --load {1} --config {2} --threshold 0.4 --savepb".format(yolo_CfgPath, yolo_WeightPath, generalConfigPath)
135151

136152
with pytest.raises(SystemExit):
137153
executeCLI(testString)
@@ -146,18 +162,44 @@ def test_RETURNPREDICT_PBLOAD_YOLOv2():
146162

147163
options = {"pbLoad": pbPath, "metaLoad": metaPath, "threshold": 0.4}
148164
tfnet = TFNet(options)
149-
imgcv = cv2.imread(testImgPath)
165+
imgcv = cv2.imread(testImg["path"])
150166
loadedPredictions = tfnet.return_predict(imgcv)
151167

152-
assert compareObjectData(expectedDetectedObjectsV2, loadedPredictions, imgWidth, imgHeight), "Generated object predictions from return_predict() were not within margin of error compared to expected values."
168+
assert compareObjectData(testImg["expected-objects"]["yolo"], loadedPredictions, testImg["width"], testImg["height"]), "Generated object predictions from return_predict() were not within margin of error compared to expected values."
153169

154170
def test_RETURNPREDICT_YOLOv1():
155171
#Test YOLOv1 using normal .weights and .cfg
156172
#NOTE: This test verifies that the code executes properly, and that the predictions generated are within the accepted margin of error to the expected predictions.
157173

158-
options = {"model": yoloCfgPathV1, "load": yoloWeightPathV1, "config": generalConfigPath, "threshold": 0.4}
174+
options = {"model": yolo_small_CfgPath, "load": yolo_small_WeightPath, "config": generalConfigPath, "threshold": 0.4}
175+
tfnet = TFNet(options)
176+
imgcv = cv2.imread(testImg["path"])
177+
loadedPredictions = tfnet.return_predict(imgcv)
178+
179+
assert compareObjectData(testImg["expected-objects"]["yolo-small"], loadedPredictions, testImg["width"], testImg["height"]), "Generated object predictions from return_predict() were not within margin of error compared to expected values."
180+
181+
#TESTS FOR TRAINING
182+
def test_TRAIN_FROM_WEIGHTS_CLI__LOAD_CHECKPOINT_RETURNPREDICT_YOLOv2():
183+
#Test training using pre-generated weights for tiny-yolo-voc
184+
#NOTE: This test verifies that the code executes properly, and that the expected checkpoint file (tiny-yolo-voc-20.meta in this case) is generated.
185+
# In addition, predictions are generated using the checkpoint file to verify that training completed successfully.
186+
187+
testString = "flow --model {0} --load {1} --train --dataset {2} --annotation {3} --epoch 20".format(tiny_yolo_voc_CfgPath, tiny_yolo_voc_WeightPath, os.path.join(buildPath, "test", "training", "images"), os.path.join(buildPath, "test", "training", "annotations"))
188+
with pytest.raises(SystemExit):
189+
executeCLI(testString)
190+
191+
checkpointPath = os.path.join(buildPath, "ckpt", "tiny-yolo-voc-20.meta")
192+
assert os.path.exists(checkpointPath), "Expected output checkpoint file: {0} was not found.".format(checkpointPath)
193+
194+
options = {"model": tiny_yolo_voc_CfgPath, "load": 20, "config": generalConfigPath, "threshold": 0.4}
159195
tfnet = TFNet(options)
160-
imgcv = cv2.imread(testImgPath)
196+
197+
#Make sure predictions match the expected values for image with bike and person
198+
imgcv = cv2.imread(trainImgBikePerson["path"])
161199
loadedPredictions = tfnet.return_predict(imgcv)
200+
assert compareObjectData(trainImgBikePerson["expected-objects"]["tiny-yolo-voc-TRAINED"], loadedPredictions, trainImgBikePerson["width"], trainImgBikePerson["height"]), "Generated object predictions from training were not within margin of error compared to expected values for the image with the bike and person.\nTraining may not have completed successfully."
162201

163-
assert compareObjectData(expectedDetectedObjectsV1, loadedPredictions, imgWidth, imgHeight), "Generated object predictions from return_predict() were not within margin of error compared to expected values."
202+
#Make sure predictions match the expected values for image with horse and person
203+
imgcv = cv2.imread(trainImgHorsePerson["path"])
204+
loadedPredictions = tfnet.return_predict(imgcv)
205+
assert compareObjectData(trainImgHorsePerson["expected-objects"]["tiny-yolo-voc-TRAINED"], loadedPredictions, trainImgHorsePerson["width"], trainImgHorsePerson["height"]), "Generated object predictions from training were not within margin of error compared to expected values for the image with the bike and person.\nTraining may not have completed successfully."

test/training/annotations/1.xml

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
<annotation>
2+
<folder>VOC2007</folder>
3+
<filename>1.jpg</filename>
4+
<source>
5+
<database>The VOC2007 Database</database>
6+
<annotation>PASCAL VOC2007</annotation>
7+
<image>flickr</image>
8+
<flickrid>336426776</flickrid>
9+
</source>
10+
<owner>
11+
<flickrid>Elder Timothy Chaves</flickrid>
12+
<name>Tim Chaves</name>
13+
</owner>
14+
<size>
15+
<width>500</width>
16+
<height>375</height>
17+
<depth>3</depth>
18+
</size>
19+
<segmented>0</segmented>
20+
<object>
21+
<name>person</name>
22+
<pose>Left</pose>
23+
<truncated>0</truncated>
24+
<difficult>0</difficult>
25+
<bndbox>
26+
<xmin>135</xmin>
27+
<ymin>25</ymin>
28+
<xmax>236</xmax>
29+
<ymax>188</ymax>
30+
</bndbox>
31+
</object>
32+
<object>
33+
<name>bicycle</name>
34+
<pose>Left</pose>
35+
<truncated>0</truncated>
36+
<difficult>0</difficult>
37+
<bndbox>
38+
<xmin>95</xmin>
39+
<ymin>85</ymin>
40+
<xmax>232</xmax>
41+
<ymax>253</ymax>
42+
</bndbox>
43+
</object>
44+
</annotation>

test/training/annotations/2.xml

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
<annotation>
2+
<folder>VOC2007</folder>
3+
<filename>2.jpg</filename>
4+
<source>
5+
<database>The VOC2007 Database</database>
6+
<annotation>PASCAL VOC2007</annotation>
7+
<image>flickr</image>
8+
<flickrid>329950741</flickrid>
9+
</source>
10+
<owner>
11+
<flickrid>Lothar Lenz</flickrid>
12+
<name>Lothar Lenz</name>
13+
</owner>
14+
<size>
15+
<width>500</width>
16+
<height>332</height>
17+
<depth>3</depth>
18+
</size>
19+
<segmented>0</segmented>
20+
<object>
21+
<name>person</name>
22+
<pose>Left</pose>
23+
<truncated>0</truncated>
24+
<difficult>0</difficult>
25+
<bndbox>
26+
<xmin>235</xmin>
27+
<ymin>51</ymin>
28+
<xmax>309</xmax>
29+
<ymax>222</ymax>
30+
</bndbox>
31+
</object>
32+
<object>
33+
<name>horse</name>
34+
<pose>Left</pose>
35+
<truncated>0</truncated>
36+
<difficult>0</difficult>
37+
<bndbox>
38+
<xmin>157</xmin>
39+
<ymin>106</ymin>
40+
<xmax>426</xmax>
41+
<ymax>294</ymax>
42+
</bndbox>
43+
</object>
44+
</annotation>

test/training/images/1.jpg

104 KB
Loading

test/training/images/2.jpg

90.1 KB
Loading

0 commit comments

Comments
 (0)