-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
@yang-yucheng03
committed
May 27, 2023
0 parents
commit 2a703ca
Showing
480 changed files
with
72,989 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# Contents | ||
|
||
- [Contents](#contents) | ||
- [CMT Description](#cmt-description) | ||
- [Model architecture](#model-architecture) | ||
- [Dataset](#dataset) | ||
- [Environment Requirements](#environment-requirements) | ||
- [Script description](#script-description) | ||
- [Script and sample code](#script-and-sample-code) | ||
- [Eval process](#eval-process) | ||
- [Usage](#usage) | ||
- [Launch](#launch) | ||
- [Result](#result) | ||
- [Description of Random Situation](#description-of-random-situation) | ||
- [ModelZoo Homepage](#modelzoo-homepage) | ||
|
||
## [CMT Description](#contents) | ||
|
||
This paper aims to develop a network that can outperform not only the canonical transformers, but also the high-performance convolutional models. We propose a new transformer based hybrid network by taking advantage of transformers to capture long-range dependencies, and of CNNs to model local features. Furthermore, we scale it to obtain a family of models, called CMTs, obtaining much better accuracy and efficiency than previous convolution and transformer based models. | ||
|
||
[Paper](https://arxiv.org/pdf/2107.06263.pdf): Jianyuan Guo, Kai Han, Han Wu, Chang Xu, Yehui Tang, Chunjing Xu, Yunhe Wang. CMT: Convolutional Neural Networks Meet Vision Transformers. Accepted in CVPR 2022. | ||
|
||
## [Model architecture](#contents) | ||
|
||
A block of CMT is shown below: | ||
|
||
![image-20211026160438718](./fig/CMT.PNG) | ||
|
||
## [Dataset](#contents) | ||
|
||
Dataset used: [ImageNet2012] | ||
|
||
- Dataset size 224*224 colorful images in 1000 classes | ||
- Train:1,281,167 images | ||
- Test: 50,000 images | ||
- Data format:jpeg | ||
- Note:Data will be processed in dataset.py | ||
|
||
## [Environment Requirements](#contents) | ||
|
||
- Hardware(Ascend/GPU) | ||
- Prepare hardware environment with Ascend or GPU. | ||
- Framework | ||
- [MindSpore](https://www.mindspore.cn/install/en) | ||
- For more information, please check the resources below£º | ||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/master/index.html) | ||
- [MindSpore Python API](https://www.mindspore.cn/docs/en/master/index.html) | ||
|
||
## [Script description](#contents) | ||
|
||
### [Script and sample code](#contents) | ||
|
||
```bash | ||
CMT | ||
├── eval.py # inference entry | ||
├── fig | ||
│ └── CMT.PNG # the illustration of CMT network | ||
├── readme.md # Readme | ||
└── src | ||
├── dataset.py # dataset loader | ||
└── cmt.py # CMT network | ||
``` | ||
|
||
## [Eval process](#contents) | ||
|
||
### Usage | ||
|
||
After installing MindSpore via the official website, you can start evaluation as follows: | ||
|
||
### Launch | ||
|
||
```bash | ||
# CMT infer example | ||
GPU: python eval.py --model cmt --dataset_path dataset_path --platform GPU --checkpoint_path [CHECKPOINT_PATH] | ||
``` | ||
|
||
> checkpoint can be downloaded at https://download.mindspore.cn/model_zoo/. | ||
### Result | ||
|
||
```bash | ||
result: {'acc': 0.832} ckpt= ./cmt_s_ms.ckpt | ||
``` | ||
|
||
## [Description of Random Situation](#contents) | ||
|
||
In dataset.py, we set the seed inside "create_dataset" function. We also use random seed in train.py. | ||
|
||
## [ModelZoo Homepage](#contents) | ||
|
||
Please check the official [homepage](https://gitee.com/mindspore/models). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Copyright 2022 Huawei Technologies Co., Ltd | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================ | ||
""" | ||
eval. | ||
""" | ||
import os | ||
import argparse | ||
from mindspore import context | ||
from mindspore import nn | ||
from mindspore.train.model import Model | ||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | ||
from src.dataset import create_dataset | ||
from src.cmt import cmt_s | ||
|
||
parser = argparse.ArgumentParser(description='Image classification') | ||
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') | ||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') | ||
parser.add_argument('--platform', type=str, default='Ascend', help='run platform') | ||
parser.add_argument('--model', type=str, default='cmt', help='eval model') | ||
args_opt = parser.parse_args() | ||
|
||
|
||
if __name__ == '__main__': | ||
config_platform = None | ||
if args_opt.platform == "Ascend": | ||
device_id = int(os.getenv('DEVICE_ID')) | ||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", | ||
device_id=device_id, save_graphs=False) | ||
elif args_opt.platform == "GPU": | ||
context.set_context(mode=context.GRAPH_MODE, | ||
device_target=args_opt.platform, save_graphs=False) | ||
else: | ||
raise ValueError("Unsupported platform.") | ||
|
||
if args_opt.model == 'cmt': | ||
net = cmt_s() | ||
else: | ||
raise ValueError("Unsupported model.") | ||
|
||
if args_opt.checkpoint_path: | ||
param_dict = load_checkpoint(args_opt.checkpoint_path) | ||
load_param_into_net(net, param_dict) | ||
net.set_train(False) | ||
|
||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') | ||
|
||
dataset = create_dataset(args_opt.dataset_path, do_train=False, batch_size=128) | ||
|
||
model = Model(net, loss_fn=loss, metrics={'acc'}) | ||
res = model.eval(dataset, dataset_sink_mode=False) | ||
print("result:", res, "ckpt=", args_opt.checkpoint_path) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
#!/bin/bash | ||
# Copyright 2022 Huawei Technologies Co., Ltd | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================ | ||
if [ $# -lt 3 ] | ||
then | ||
echo "Usage: bash ./scripts/run_cmt_eval.sh [DATA_PATH] [PLATFORM] [CHECKPOINT_PATH]" | ||
exit 1 | ||
fi | ||
|
||
DATA_PATH=$1 | ||
PLATFORM=$2 | ||
CHECKPOINT_PATH=$3 | ||
|
||
rm -rf evaluation | ||
mkdir ./evaluation | ||
cd ./evaluation || exit | ||
echo "start training for device id $DEVICE_ID" | ||
env > env.log | ||
python eval.py --model cmt --dataset_path=$DATA_PATH --platform=$PLATFORM --checkpoint_path=$CHECKPOINT_PATH > eval.log 2>&1 & | ||
cd ../ |
Oops, something went wrong.