Skip to content

Commit

Permalink
55
Browse files Browse the repository at this point in the history
  • Loading branch information
@yang-yucheng03 committed May 27, 2023
0 parents commit 2a703ca
Show file tree
Hide file tree
Showing 480 changed files with 72,989 additions and 0 deletions.
91 changes: 91 additions & 0 deletions CMT/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Contents

- [Contents](#contents)
- [CMT Description](#cmt-description)
- [Model architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script description](#script-description)
- [Script and sample code](#script-and-sample-code)
- [Eval process](#eval-process)
- [Usage](#usage)
- [Launch](#launch)
- [Result](#result)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

## [CMT Description](#contents)

This paper aims to develop a network that can outperform not only the canonical transformers, but also the high-performance convolutional models. We propose a new transformer based hybrid network by taking advantage of transformers to capture long-range dependencies, and of CNNs to model local features. Furthermore, we scale it to obtain a family of models, called CMTs, obtaining much better accuracy and efficiency than previous convolution and transformer based models.

[Paper](https://arxiv.org/pdf/2107.06263.pdf): Jianyuan Guo, Kai Han, Han Wu, Chang Xu, Yehui Tang, Chunjing Xu, Yunhe Wang. CMT: Convolutional Neural Networks Meet Vision Transformers. Accepted in CVPR 2022.

## [Model architecture](#contents)

A block of CMT is shown below:

![image-20211026160438718](./fig/CMT.PNG)

## [Dataset](#contents)

Dataset used: [ImageNet2012]

- Dataset size 224*224 colorful images in 1000 classes
- Train:1,281,167 images
- Test: 50,000 images
- Data format:jpeg
- Note:Data will be processed in dataset.py

## [Environment Requirements](#contents)

- Hardware(Ascend/GPU)
- Prepare hardware environment with Ascend or GPU.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below£º
- [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/docs/en/master/index.html)

## [Script description](#contents)

### [Script and sample code](#contents)

```bash
CMT
├── eval.py # inference entry
├── fig
│ └── CMT.PNG # the illustration of CMT network
├── readme.md # Readme
└── src
├── dataset.py # dataset loader
└── cmt.py # CMT network
```

## [Eval process](#contents)

### Usage

After installing MindSpore via the official website, you can start evaluation as follows:

### Launch

```bash
# CMT infer example
GPU: python eval.py --model cmt --dataset_path dataset_path --platform GPU --checkpoint_path [CHECKPOINT_PATH]
```

> checkpoint can be downloaded at https://download.mindspore.cn/model_zoo/.
### Result

```bash
result: {'acc': 0.832} ckpt= ./cmt_s_ms.ckpt
```

## [Description of Random Situation](#contents)

In dataset.py, we set the seed inside "create_dataset" function. We also use random seed in train.py.

## [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/models).
63 changes: 63 additions & 0 deletions CMT/eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
eval.
"""
import os
import argparse
from mindspore import context
from mindspore import nn
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.dataset import create_dataset
from src.cmt import cmt_s

parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--platform', type=str, default='Ascend', help='run platform')
parser.add_argument('--model', type=str, default='cmt', help='eval model')
args_opt = parser.parse_args()


if __name__ == '__main__':
config_platform = None
if args_opt.platform == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
device_id=device_id, save_graphs=False)
elif args_opt.platform == "GPU":
context.set_context(mode=context.GRAPH_MODE,
device_target=args_opt.platform, save_graphs=False)
else:
raise ValueError("Unsupported platform.")

if args_opt.model == 'cmt':
net = cmt_s()
else:
raise ValueError("Unsupported model.")

if args_opt.checkpoint_path:
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)

loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

dataset = create_dataset(args_opt.dataset_path, do_train=False, batch_size=128)

model = Model(net, loss_fn=loss, metrics={'acc'})
res = model.eval(dataset, dataset_sink_mode=False)
print("result:", res, "ckpt=", args_opt.checkpoint_path)
Binary file added CMT/fig/CMT.PNG
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
32 changes: 32 additions & 0 deletions CMT/scripts/run_cmt_eval.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -lt 3 ]
then
echo "Usage: bash ./scripts/run_cmt_eval.sh [DATA_PATH] [PLATFORM] [CHECKPOINT_PATH]"
exit 1
fi

DATA_PATH=$1
PLATFORM=$2
CHECKPOINT_PATH=$3

rm -rf evaluation
mkdir ./evaluation
cd ./evaluation || exit
echo "start training for device id $DEVICE_ID"
env > env.log
python eval.py --model cmt --dataset_path=$DATA_PATH --platform=$PLATFORM --checkpoint_path=$CHECKPOINT_PATH > eval.log 2>&1 &
cd ../
Loading

0 comments on commit 2a703ca

Please sign in to comment.