Skip to content

Commit

Permalink
Update tf.py
Browse files Browse the repository at this point in the history
Fix Bug of exporting tflite model.
  • Loading branch information
hukaixuan19970627 authored Jan 22, 2022
1 parent bcffd1f commit 7e84de0
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions models/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None): # detec
super().__init__()
self.stride = tf.convert_to_tensor(w.stride.numpy(), dtype=tf.float32)
self.nc = nc # number of classes
self.no = nc + 5 # number of outputs per anchor
self.no = nc + 185 # number of outputs per anchor
self.nl = len(anchors) # number of detection layers
self.na = len(anchors[0]) // 2 # number of anchors
self.grid = [tf.zeros(1)] * self.nl # init grid
Expand Down Expand Up @@ -272,7 +272,7 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
no = na * (nc + 185) # number of outputs = anchors * (classes + 5)

layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
Expand Down

0 comments on commit 7e84de0

Please sign in to comment.