Skip to content

Commit

Permalink
bugs fixed in train.py and update README
Browse files Browse the repository at this point in the history
  • Loading branch information
qqwweee committed May 4, 2018
1 parent ed8fbaa commit 3f93a89
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ python yolo.py OR python yolo_video.py
Box format: x_min,y_min,x_max,y_max,class_id (no space).
For VOC dataset, try `python voc_annotation.py`

2. Modify train.py and start training.
2. Make sure you have run `python convert.py yolov3.cfg yolov3.weights model_data/yolo.h5`
A file model_data/yolo_weights.h5 will be generated when you run train.py for the first time.
The file is used to load pretrained weights.

3. Modify train.py and start training.
`python train.py`
You will get the trained model model_data/my_yolo.h5.
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ def create_model(input_shape, anchors, num_classes, load_pretrained=True, freeze
if not os.path.exists(weights_path):
print("CREATING WEIGHTS FILE" + weights_path)
yolo_path = os.path.join('model_data', 'yolo.h5')
model_body = load_model(yolo_path, compile=False)
model_body.save_weights(weights_path)
orig_model = load_model(yolo_path, compile=False)
orig_model.save_weights(weights_path)
model_body.load_weights(weights_path, by_name=True, skip_mismatch=True)
if freeze_body:
# Do not freeze 3 output layers.
Expand Down
4 changes: 2 additions & 2 deletions yolo3/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def letterbox_image(image, size):
'''resize image with unchanged aspect ratio using padding'''
image_w, image_h = image.size
w, h = size
new_w = int(image_w * min(w/image_w, h/image_h))
new_h = int(image_h * min(w/image_w, h/image_h))
new_w = int(image_w * min(w*1.0/image_w, h*1.0/image_h))
new_h = int(image_h * min(w*1.0/image_w, h*1.0/image_h))
resized_image = image.resize((new_w,new_h), Image.BICUBIC)

boxed_image = Image.new('RGB', size, (128,128,128))
Expand Down

0 comments on commit 3f93a89

Please sign in to comment.