Skip to content

Commit

Permalink
Update info about ResNeXt-14 (32x2d) model for IN1K
Browse files Browse the repository at this point in the history
  • Loading branch information
osmr committed Jul 13, 2019
1 parent cb0ecda commit d62810e
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 104 deletions.
42 changes: 21 additions & 21 deletions chainer_/chainercv2/models/model_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,13 +400,13 @@

def get_model_name_suffix_data(model_name):
    """
    Get registry data for a pretrained model.

    Parameters
    ----------
    model_name : str
        Name of the model as registered in `_model_sha1`.

    Returns
    -------
    tuple of (str, str, str)
        Error value, SHA1 hash of the weight file, and repository release tag.

    Raises
    ------
    ValueError
        If no pretrained weights are registered for `model_name`.
    """
    if model_name not in _model_sha1:
        # NOTE: the diff artifact duplicated this raise (old single-quoted +
        # new double-quoted versions); only the post-commit line is kept.
        raise ValueError("Pretrained model for {name} is not available.".format(name=model_name))
    error, sha1_hash, repo_release_tag = _model_sha1[model_name]
    return error, sha1_hash, repo_release_tag


def get_model_file(model_name,
local_model_store_dir_path=os.path.join('~', '.chainer', 'models')):
local_model_store_dir_path=os.path.join("~", ".chainer", "models")):
"""
Return location for the pretrained model on the local file system. This function will download from online model zoo when
model cannot be found or has mismatch. The root directory will be created if it doesn't exist.
Expand All @@ -425,7 +425,7 @@ def get_model_file(model_name,
"""
error, sha1_hash, repo_release_tag = get_model_name_suffix_data(model_name)
short_sha1 = sha1_hash[:8]
file_name = '{name}-{error}-{short_sha1}.npz'.format(
file_name = "{name}-{error}-{short_sha1}.npz".format(
name=model_name,
error=error,
short_sha1=short_sha1)
Expand All @@ -435,16 +435,16 @@ def get_model_file(model_name,
if _check_sha1(file_path, sha1_hash):
return file_path
else:
logging.warning('Mismatch in the content of model file detected. Downloading again.')
logging.warning("Mismatch in the content of model file detected. Downloading again.")
else:
logging.info('Model file not found. Downloading to {}.'.format(file_path))
logging.info("Model file not found. Downloading to {}.".format(file_path))

if not os.path.exists(local_model_store_dir_path):
os.makedirs(local_model_store_dir_path)

zip_file_path = file_path + '.zip'
zip_file_path = file_path + ".zip"
_download(
url='{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip'.format(
url="{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip".format(
repo_url=imgclsmob_repo_url,
repo_release_tag=repo_release_tag,
file_name=file_name),
Expand All @@ -457,7 +457,7 @@ def get_model_file(model_name,
if _check_sha1(file_path, sha1_hash):
return file_path
else:
raise ValueError('Downloaded file has different hash. Please try again.')
raise ValueError("Downloaded file has different hash. Please try again.")


def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
Expand Down Expand Up @@ -494,21 +494,21 @@ class requests_failed_to_import(object):
requests = requests_failed_to_import

if path is None:
fname = url.split('/')[-1]
fname = url.split("/")[-1]
# Empty filenames are invalid
assert fname, 'Can\'t construct file-name from this URL. Please set the `path` option manually.'
assert fname, "Can't construct file-name from this URL. Please set the `path` option manually."
else:
path = os.path.expanduser(path)
if os.path.isdir(path):
fname = os.path.join(path, url.split('/')[-1])
fname = os.path.join(path, url.split("/")[-1])
else:
fname = path
assert retries >= 0, "Number of retries should be at least 0"

if not verify_ssl:
warnings.warn(
'Unverified HTTPS request is being made (verify_ssl=False). '
'Adding certificate verification is strongly advised.')
"Unverified HTTPS request is being made (verify_ssl=False). "
"Adding certificate verification is strongly advised.")

if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)):
dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
Expand All @@ -518,27 +518,27 @@ class requests_failed_to_import(object):
# Disable pylint too broad Exception
# pylint: disable=W0703
try:
print('Downloading {} from {}...'.format(fname, url))
print("Downloading {} from {}...".format(fname, url))
r = requests.get(url, stream=True, verify=verify_ssl)
if r.status_code != 200:
raise RuntimeError("Failed downloading url {}".format(url))
with open(fname, 'wb') as f:
with open(fname, "wb") as f:
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
if sha1_hash and not _check_sha1(fname, sha1_hash):
raise UserWarning('File {} is downloaded but the content hash does not match.'
' The repo may be outdated or download may be incomplete. '
'If the "repo_url" is overridden, consider switching to '
'the default repo.'.format(fname))
raise UserWarning("File {} is downloaded but the content hash does not match."
" The repo may be outdated or download may be incomplete. "
"If the 'repo_url' is overridden, consider switching to "
"the default repo.".format(fname))
break
except Exception as e:
retries -= 1
if retries <= 0:
raise e
else:
print("download failed, retrying, {} attempt{} left"
.format(retries, 's' if retries > 1 else ''))
.format(retries, "s" if retries > 1 else ""))

return fname

Expand All @@ -559,7 +559,7 @@ def _check_sha1(filename, sha1_hash):
Whether the file content matches the expected hash.
"""
sha1 = hashlib.sha1()
with open(filename, 'rb') as f:
with open(filename, "rb") as f:
while True:
data = f.read(1048576)
if not data:
Expand Down
16 changes: 8 additions & 8 deletions gluon/gluoncv2/models/model_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,13 +467,13 @@

def get_model_name_suffix_data(model_name):
    """
    Get registry data for a pretrained model.

    Parameters
    ----------
    model_name : str
        Name of the model as registered in `_model_sha1`.

    Returns
    -------
    tuple of (str, str, str)
        Error value, SHA1 hash of the weight file, and repository release tag.

    Raises
    ------
    ValueError
        If no pretrained weights are registered for `model_name`.
    """
    if model_name not in _model_sha1:
        # NOTE: the diff artifact duplicated this raise (old single-quoted +
        # new double-quoted versions); only the post-commit line is kept.
        raise ValueError("Pretrained model for {name} is not available.".format(name=model_name))
    error, sha1_hash, repo_release_tag = _model_sha1[model_name]
    return error, sha1_hash, repo_release_tag


def get_model_file(model_name,
local_model_store_dir_path=os.path.join('~', '.mxnet', 'models')):
local_model_store_dir_path=os.path.join("~", ".mxnet", "models")):
"""
Return location for the pretrained model on the local file system. This function will download from online model zoo when
model cannot be found or has mismatch. The root directory will be created if it doesn't exist.
Expand All @@ -492,7 +492,7 @@ def get_model_file(model_name,
"""
error, sha1_hash, repo_release_tag = get_model_name_suffix_data(model_name)
short_sha1 = sha1_hash[:8]
file_name = '{name}-{error}-{short_sha1}.params'.format(
file_name = "{name}-{error}-{short_sha1}.params".format(
name=model_name,
error=error,
short_sha1=short_sha1)
Expand All @@ -502,16 +502,16 @@ def get_model_file(model_name,
if check_sha1(file_path, sha1_hash):
return file_path
else:
logging.warning('Mismatch in the content of model file detected. Downloading again.')
logging.warning("Mismatch in the content of model file detected. Downloading again.")
else:
logging.info('Model file not found. Downloading to {}.'.format(file_path))
logging.info("Model file not found. Downloading to {}.".format(file_path))

if not os.path.exists(local_model_store_dir_path):
os.makedirs(local_model_store_dir_path)

zip_file_path = file_path + '.zip'
zip_file_path = file_path + ".zip"
download(
url='{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip'.format(
url="{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip".format(
repo_url=imgclsmob_repo_url,
repo_release_tag=repo_release_tag,
file_name=file_name),
Expand All @@ -524,4 +524,4 @@ def get_model_file(model_name,
if check_sha1(file_path, sha1_hash):
return file_path
else:
raise ValueError('Downloaded file has different hash. Please try again.')
raise ValueError("Downloaded file has different hash. Please try again.")
56 changes: 28 additions & 28 deletions keras_/kerascv/models/model_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def get_model_name_suffix_data(model_name):


def get_model_file(model_name,
local_model_store_dir_path=os.path.join('~', '.keras', 'models')):
local_model_store_dir_path=os.path.join("~", ".keras", "models")):
"""
Return location for the pretrained model on the local file system. This function will download from online model zoo when
model cannot be found or has mismatch. The root directory will be created if it doesn't exist.
Expand Down Expand Up @@ -252,9 +252,9 @@ class requests_failed_to_import(object):
requests = requests_failed_to_import

if path is None:
fname = url.split('/')[-1]
fname = url.split("/")[-1]
# Empty filenames are invalid
assert fname, "Can\'t construct file-name from this URL. Please set the `path` option manually."
assert fname, "Can't construct file-name from this URL. Please set the `path` option manually."
else:
path = os.path.expanduser(path)
if os.path.isdir(path):
Expand Down Expand Up @@ -295,7 +295,7 @@ class requests_failed_to_import(object):
raise e
else:
print("download failed, retrying, {} attempt{} left"
.format(retries, 's' if retries > 1 else ''))
.format(retries, "s" if retries > 1 else ""))

return fname

Expand All @@ -316,7 +316,7 @@ def _check_sha1(filename, sha1_hash):
Whether the file content matches the expected hash.
"""
sha1 = hashlib.sha1()
with open(filename, 'rb') as f:
with open(filename, "rb") as f:
while True:
data = f.read(1048576)
if not data:
Expand Down Expand Up @@ -372,41 +372,41 @@ def _load_weights_from_hdf5_group(f,
if weights:
filtered_layers.append(layer)

layer_names = load_attributes_from_hdf5_group(f, 'layer_names')
layer_names = load_attributes_from_hdf5_group(f, "layer_names")
filtered_layer_names = []
for name in layer_names:
g = f[name]
weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
weight_names = load_attributes_from_hdf5_group(g, "weight_names")
if weight_names:
filtered_layer_names.append(name)
layer_names = filtered_layer_names
if len(layer_names) != len(filtered_layers):
raise ValueError('You are trying to load a weight file '
'containing ' + str(len(layer_names)) +
' layers into a model with ' +
str(len(filtered_layers)) + ' layers.')
raise ValueError("You are trying to load a weight file "
"containing " + str(len(layer_names)) +
" layers into a model with " +
str(len(filtered_layers)) + " layers.")

weight_value_tuples = []
for k, name in enumerate(layer_names):
g = f[name]
weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
weight_names = load_attributes_from_hdf5_group(g, "weight_names")
weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names]
layer = filtered_layers[k]
symbolic_weights = layer.weights
weight_values = _preprocess_weights_for_loading(
layer=layer,
weights=weight_values)
if len(weight_values) != len(symbolic_weights):
raise ValueError('Layer #' + str(k) +
' (named "' + layer.name +
'" in the current model) was found to '
'correspond to layer ' + name +
' in the save file. '
'However the new layer ' + layer.name +
' expects ' + str(len(symbolic_weights)) +
' weights, but the saved weights have ' +
raise ValueError("Layer #" + str(k) +
" (named `" + layer.name +
"` in the current model) was found to "
"correspond to layer " + name +
" in the save file. "
"However the new layer " + layer.name +
" expects " + str(len(symbolic_weights)) +
" weights, but the saved weights have " +
str(len(weight_values)) +
' elements.')
" elements.")
weight_value_tuples += zip(symbolic_weights, weight_values)
K.batch_set_value(weight_value_tuples)

Expand All @@ -424,7 +424,7 @@ def _load_weights_from_hdf5_group_by_name(f,
List of target layers.
"""
# New file format.
layer_names = load_attributes_from_hdf5_group(f, 'layer_names')
layer_names = load_attributes_from_hdf5_group(f, "layer_names")

# Reverse index of layer name to list of layers with name.
index = {}
Expand All @@ -435,7 +435,7 @@ def _load_weights_from_hdf5_group_by_name(f,
weight_value_tuples = []
for k, name in enumerate(layer_names):
g = f[name]
weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
weight_names = load_attributes_from_hdf5_group(g, "weight_names")
weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names]

for layer in index.get(name, []):
Expand All @@ -444,15 +444,15 @@ def _load_weights_from_hdf5_group_by_name(f,
layer=layer,
weights=weight_values)
if len(weight_values) != len(symbolic_weights):
warnings.warn('Skipping loading of weights for layer {} due to mismatch in number of weights ({} vs'
' {}).'.format(layer, len(symbolic_weights), len(weight_values)))
warnings.warn("Skipping loading of weights for layer {} due to mismatch in number of weights ({} vs"
" {}).".format(layer, len(symbolic_weights), len(weight_values)))
continue
# Set values.
for i in range(len(weight_values)):
symbolic_shape = K.int_shape(symbolic_weights[i])
if symbolic_shape != weight_values[i].shape:
warnings.warn('Skipping loading of weights for layer {} due to mismatch in shape ({} vs'
' {}).'.format(layer, symbolic_weights[i].shape, weight_values[i].shape))
warnings.warn("Skipping loading of weights for layer {} due to mismatch in shape ({} vs"
" {}).".format(layer, symbolic_weights[i].shape, weight_values[i].shape))
continue
else:
weight_value_tuples.append((symbolic_weights[i],
Expand Down Expand Up @@ -498,7 +498,7 @@ def load_model(net,

def download_model(net,
model_name,
local_model_store_dir_path=os.path.join('~', '.keras', 'models')):
local_model_store_dir_path=os.path.join("~", ".keras", "models")):
"""
Load model state dictionary from a file with downloading it if necessary.
Expand Down
Loading

0 comments on commit d62810e

Please sign in to comment.