From 243564927c127da4cc3319e7cad12b2bc9da8e98 Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" Date: Wed, 25 Oct 2017 22:45:09 -0700 Subject: [PATCH] fix version comparison (#200) --- python/nnvm/frontend/mxnet.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/python/nnvm/frontend/mxnet.py b/python/nnvm/frontend/mxnet.py index 6cbd497fd..39a118d24 100644 --- a/python/nnvm/frontend/mxnet.py +++ b/python/nnvm/frontend/mxnet.py @@ -13,14 +13,6 @@ def _get_nnvm_op(op_name): raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name)) return op -def _get_mxnet_version(): - try: - import mxnet as mx - version = mx.__version__ - except ImportError: - version = '0.11.1' - return [int(x) for x in version.split('.')] - def _required_attr(attr, key): assert isinstance(attr, dict) if key not in attr: @@ -127,14 +119,19 @@ def _conv2d_transpose(inputs, attrs): return _get_nnvm_op(op_name)(*inputs, **new_attrs) def _dense(inputs, attrs): + import mxnet as mx op_name, new_attrs = 'dense', {} new_attrs['units'] = _required_attr(attrs, 'num_hidden') new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias') - major, minor, micro = _get_mxnet_version() - if major >= 0 and minor >= 11 and micro >= 1: - use_flatten = _parse_bool_str(attrs, 'flatten', 'True') - if use_flatten: - inputs[0] = _sym.flatten(inputs[0]) + try: + _ = mx.sym.FullyConnected(mx.sym.var('x'), num_hidden=1, flatten=True) + has_flatten = True + except mx.base.MXNetError: + # no flatten attribute in old mxnet + has_flatten = False + use_flatten = _parse_bool_str(attrs, 'flatten', 'True') + if has_flatten and use_flatten: + inputs[0] = _sym.flatten(inputs[0]) return _get_nnvm_op(op_name)(*inputs, **new_attrs) def _dropout(inputs, attrs):