Skip to content

Commit

Permalink
Fix type bug
Browse files Browse the repository at this point in the history
Signed-off-by: cuiyanx <[email protected]>
  • Loading branch information
cuiyanx committed Mar 19, 2020
1 parent 9a4a38b commit 11afc8d
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions src/cts_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,21 +528,21 @@ def DumpJSTest(model, example, js_fd):

# set input and output types
for t in model.GetTypes():
if t.scale == 0.0 and t.zeroPoint == 0 and t.extraParams is None:
if t.type in ["FLOAT32", "INT32", "UINT32"]:
typeDef = " let %s = {type: nn.%s};"%(t, t.type)
else :
typeDef = " let %s = {type: nn.%s, dimensions: [%s]};\n let %s_length = product(%s.dimensions);"%(
t, t.type, t.GetDimensionsString()[1:-1], t, t)
if t.type in ["FLOAT32", "INT32", "FLOAT16", "UINT32"]:
typeDef = " let %s = {type: nn.%s};"%(t, t.type)
elif t.type in ["TENSOR_INT32", "TENSOR_FLOAT16", "TENSOR_FLOAT32"]:
typeDef = " let %s = {type: nn.%s, dimensions: [%s]};\n let %s_length = product(%s.dimensions);" \
%(t, t.type, t.GetDimensionsString()[1:-1], t, t)
elif t.type in ["TENSOR_QUANT8_ASYMM", "TENSOR_QUANT8_ASYMM_SIGNED"]:
typeDef = " let %s = {type: nn.%s, dimensions: [%s], scale: %s, zeroPoint: %d};\n let %s_length = product(%s.dimensions);" \
%(t, t.type, t.GetDimensionsString()[1:-1], tg.PrettyPrintAsFloat(t.scale)[:-1], t.zeroPoint, t, t)
elif t.type in ["TENSOR_QUANT8_SYMM_PER_CHANNEL"]:
typeDef = " let %s = {type: nn.%s, dimensions: [%s]};\n let %s_length = product(%s.dimensions);" \
%(t, t.type, t.GetDimensionsString()[1:-1], t, t)
per_channel_types[str(t)] = t.extraParams.GetJSConstructor()
else:
if t.extraParams is None or t.extraParams.hide:
typeDef = " let %s = {type: nn.%s, dimensions: [%s], scale: %s, zeroPoint: %d};\n let %s_length = product(%s.dimensions);"%(
t, t.type, t.GetDimensionsString()[1:-1], tg.PrettyPrintAsFloat(t.scale)[:-1], t.zeroPoint, t, t)
else:
typeDef = " let %s = {type: nn.%s, dimensions: [%s]};\n let %s_length = product(%s.dimensions);"%(
t, t.type, t.GetDimensionsString()[1:-1], t, t)

per_channel_types[str(t)] = t.extraParams.GetJSConstructor()
traceback.print_exc()
sys.exit("Cannot support tensor of type: {}".format(t.type))

print (typeDef, file = js_fd)
print ("", file = js_fd)
Expand Down

0 comments on commit 11afc8d

Please sign in to comment.