Skip to content

Commit

Permalink
修改当batch size大于样本总数量的时候data provider递归超限的bug
Browse files Browse the repository at this point in the history
  • Loading branch information
MaybeShewill-CV committed Dec 12, 2018
1 parent c8327fd commit 6d2a963
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions data_provider/lanenet_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ def next_batch(self, batch_size):
idx_start = batch_size * self._next_batch_loop_count
idx_end = batch_size * self._next_batch_loop_count + batch_size

if idx_start == 0 and idx_end > len(self._gt_label_binary_list):
raise ValueError('Batch size不能大于样本的总数量')

if idx_end > len(self._gt_label_binary_list):
self._random_dataset()
self._next_batch_loop_count = 0
Expand Down Expand Up @@ -121,9 +124,7 @@ def next_batch(self, batch_size):


if __name__ == '__main__':
val = DataSet('/home/baidu/DataBase/Semantic_Segmentation/Kitti_Vision/data_road/lanenet_training/train.txt')
a1, a2, a3 = val.next_batch(1)
cv2.imwrite('test_binary_label.png', a2[0] * 255)
val = DataSet('/media/baidu/Data/Semantic_Segmentation/TUSimple_Lane_Detection/training/val.txt')
b1, b2, b3 = val.next_batch(50)
c1, c2, c3 = val.next_batch(50)
dd, d2, d3 = val.next_batch(50)

0 comments on commit 6d2a963

Please sign in to comment.