[MentorMix] "Beyond Synthetic Noise: Deep Learning on Controlled Noisy Labels" implemented in PyTorch.

tangminji/MentorMix_pytorch

[MentorMix] "Beyond Synthetic Noise: Deep Learning on Controlled Noisy Labels" PyTorch Implementation

This repository is a PyTorch implementation of the paper Beyond Synthetic Noise: Deep Learning on Controlled Noisy Labels. The official code, written in TensorFlow, is provided by google-research.
The code in this repository trains models from scratch on CIFAR10 and CIFAR100.

Requirements

torch 1.7.1
torchvision 0.8.1
tqdm
argparse

How to run

After cloning the repository, you can train each model from scratch on CIFAR10 or CIFAR100. The supported architectures are ResNets; you can adapt the code to train other architectures.

  • Using a threshold function as MentorNet
python train.py --dataset cifar10 \
                --StudentNet ResNet34 --MentorNet threshold --MentorNet_type PD \
                --optimizer SGD --scheduler StepLR \
                --lr 0.1 --batch_size 128 --epoch 500 --wd 2e-4 \
                --noise_rate 0.2 \
                --ema 0.0001 \
                --gamma_p 0.8 --alpha 2. \
                --second_reweight \
                --trial 0 \
                --gpu_id 0
  • Using a DNN as MentorNet
    First, train MentorNet, either Pre-Defined (PD) or Data-Driven (DD), selected with --MentorNet_type.
python3 train_MentorNet.py  --dataset cifar10 \
                            --StudentNet ResNet34 --MentorNet MentorNet --MentorNet_type PD \
                            --optimizer SGD --scheduler CosineAnnealing \
                            --lr 0.1 --batch_size 32 --epoch 100 --wd 2e-4 \
                            --noise_rate 0. \
                            --ema 0.05 \
                            --gamma_p 0.75 \
                            --train_MentorNet \
                            --trial 0 \
                            --gpu_id 0

(If you train MentorNet in the Data-Driven way, the noise rate must be the same as the one used later to train StudentNet.)

Second, train StudentNet with the pre-trained MentorNet.

python train.py --dataset cifar10 \
                --StudentNet ResNet34 --MentorNet MentorNet --MentorNet_type DD \
                --optimizer SGD --scheduler StepLR \
                --lr 0.1 --batch_size 128 --epoch 500 --wd 2e-4 \
                --noise_rate 0.2 \
                --ema 0.0001 \
                --gamma_p 0.8 --alpha 2. \
                --second_reweight \
                --trial 0 \
                --gpu_id 0
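The `--noise_rate` flag sets the fraction of corrupted CIFAR labels, and a Data-Driven MentorNet presumably learns its weighting from the known clean/noisy split at that rate, which is why the two noise rates must match. This README does not describe the exact noise model, so the following is a minimal NumPy sketch assuming symmetric noise, where each corrupted label is flipped to a different class chosen uniformly at random. `corrupt_labels` is a hypothetical helper, not a function from this repository.

```python
import numpy as np


def corrupt_labels(labels, noise_rate, num_classes, seed=0):
    """Hypothetical symmetric label noise (an assumption, not this repo's code).

    Flips round(noise_rate * n) labels to a different class chosen uniformly
    at random and returns (noisy_labels, is_clean).
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    n = labels.shape[0]
    num_noisy = int(round(noise_rate * n))
    # Pick which samples to corrupt, without repetition.
    noisy_idx = rng.choice(n, size=num_noisy, replace=False)
    noisy = labels.copy()
    # An offset in [1, num_classes - 1] guarantees the new label differs.
    offsets = rng.integers(1, num_classes, size=num_noisy)
    noisy[noisy_idx] = (labels[noisy_idx] + offsets) % num_classes
    # Ground-truth clean/noisy mask, the kind of signal a DD MentorNet needs.
    is_clean = np.ones(n, dtype=bool)
    is_clean[noisy_idx] = False
    return noisy, is_clean
```

For example, `corrupt_labels(train_labels, 0.2, 10)` would corrupt 20% of CIFAR10 labels and return the mask of untouched samples.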

Implementation Details

Most hyperparameters follow the values reported in the paper. However, some, such as γp and α, follow the values used in the official code; these are listed per noise level below.
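For the threshold MentorNet (`--MentorNet threshold`), γp decides which samples count as clean. As we understand the paper's predefined curriculum, a sample gets weight 1 when its loss is at or below a moving average of the γp-th percentile of the batch losses, and 0 otherwise. The NumPy sketch below illustrates that rule under stated assumptions: `ThresholdMentorNet` and `ema_decay` are hypothetical names, and whether this repository's `--ema` value is the weight on the previous threshold or on the new percentile is not documented here.

```python
import numpy as np


class ThresholdMentorNet:
    """Hypothetical predefined MentorNet: v_i = 1 if loss_i <= tracked threshold."""

    def __init__(self, gamma_p, ema_decay):
        self.gamma_p = gamma_p      # percentile as a fraction in [0, 1]
        self.ema_decay = ema_decay  # assumed: weight kept on the previous threshold
        self.threshold = None       # moving average of the gamma_p-th loss percentile

    def __call__(self, loss):
        loss = np.asarray(loss, dtype=float)
        current = np.percentile(loss, self.gamma_p * 100)
        if self.threshold is None:
            # First batch: start the moving average at the current percentile.
            self.threshold = current
        else:
            self.threshold = (self.ema_decay * self.threshold
                              + (1 - self.ema_decay) * current)
        # Binary weights: samples with small loss are treated as clean.
        return (loss <= self.threshold).astype(float)
```

A larger γp admits more samples as clean, which matches the trend in the tables below, where γp shrinks as the noise level grows.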

Hyperparameters taken from the paper

| Hyperparameter | Value |
| --- | --- |
| Epochs | 400 |
| Learning rate | 0.1 |
| Weight decay | 0.0002 |
| Optimizer | SGD |
| Momentum | 0.9 |
| Nesterov | False |
| Scheduler | StepLR(0.9) |
| EMA | 0.0001 |
| Second reweight | True |
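The optimizer settings above correspond to `torch.optim.SGD(params, lr=0.1, momentum=0.9, weight_decay=2e-4, nesterov=False)`. We read StepLR(0.9) as `gamma=0.9` in `torch.optim.lr_scheduler.StepLR`, which multiplies the learning rate by `gamma` once every `step_size` scheduler steps. The README does not state `step_size`, so the default used below is only a placeholder. This plain-Python sketch evaluates StepLR's closed form to show how fast the rate decays.

```python
def step_lr(epoch, base_lr=0.1, gamma=0.9, step_size=1):
    """Closed form of StepLR: base_lr * gamma ** (epoch // step_size).

    step_size is not given in the README; 1 (decay every epoch) is an assumption.
    """
    return base_lr * gamma ** (epoch // step_size)


if __name__ == "__main__":
    # Learning rate at a few epochs of a 400-epoch run.
    for epoch in (0, 1, 10, 100, 399):
        print(f"epoch {epoch:3d}: lr = {step_lr(epoch):.3e}")
```

Because 0.9 per step compounds quickly, the choice of `step_size` strongly affects how long training runs at a useful learning rate.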

Hyperparameters taken from the official code

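  • γp and α in CIFAR10

| Noise level | 0.2 | 0.4 | 0.6 | 0.8 |
| --- | --- | --- | --- | --- |
| α | 2 | 8 | 8 | 4 |
| γp | 0.8 | 0.6 | 0.6 | 0.2 |
| Second reweight | False | False | True | True |

  • γp and α in CIFAR100

| Noise level | 0.2 | 0.4 | 0.6 | 0.8 |
| --- | --- | --- | --- | --- |
| α | 2 | 8 | 4 | 8 |
| γp | 0.7 | 0.5 | 0.3 | 0.1 |
| Second reweight | False | False | True | True |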
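α and the second-reweight flag enter the MentorMix step itself. As we read the paper, each sample i is mixed with a partner drawn from the batch with probability proportional to the MentorNet weights v, the mixing weight λ ~ Beta(α, α) is folded to max(λ, 1 - λ) when v_i = 1 and to min(λ, 1 - λ) when v_i = 0, and second reweighting applies MentorNet again to the losses of the mixed samples. The NumPy sketch below is illustrative only: `mentormix_batch` and `second_reweight_loss` are hypothetical helpers, and recomputing the γp percentile on the mixed losses (rather than reusing a moving-average threshold) is our assumption.

```python
import numpy as np


def mentormix_batch(x, y_onehot, v, alpha, rng):
    """Hypothetical MentorMix mixing step (a sketch, not this repo's code).

    x: (n, ...) inputs, y_onehot: (n, num_classes) targets,
    v: (n,) MentorNet weights in [0, 1], alpha: Beta(alpha, alpha) parameter.
    """
    v = np.asarray(v, dtype=float)
    n = x.shape[0]
    # Importance sampling: partners drawn with probability proportional to v
    # (uniform fallback if every weight is 0).
    total = v.sum()
    p = v / total if total > 0 else np.full(n, 1.0 / n)
    partner = rng.choice(n, size=n, p=p)
    # Fold lambda so samples with v = 1 keep the larger share of themselves.
    lam = rng.beta(alpha, alpha, size=n)
    lam = v * np.maximum(lam, 1 - lam) + (1 - v) * np.minimum(lam, 1 - lam)
    # Broadcast lambda over the non-batch dimensions of x.
    lam_x = lam.reshape((n,) + (1,) * (x.ndim - 1))
    x_mix = lam_x * x + (1 - lam_x) * x[partner]
    y_mix = lam[:, None] * y_onehot + (1 - lam[:, None]) * y_onehot[partner]
    return x_mix, y_mix, lam, partner


def second_reweight_loss(mixed_loss, gamma_p, enabled):
    """Mean mixed loss, optionally keeping only losses at or below the gamma_p percentile."""
    mixed_loss = np.asarray(mixed_loss, dtype=float)
    if not enabled:
        return float(mixed_loss.mean())
    # Assumption: the threshold is recomputed on this batch's mixed losses.
    threshold = np.percentile(mixed_loss, gamma_p * 100)
    v2 = (mixed_loss <= threshold).astype(float)
    return float((v2 * mixed_loss).mean())
```

In a training loop, `v` would come from the MentorNet applied to the per-sample losses of the original batch, and `mixed_loss` from the StudentNet evaluated on `x_mix` against the soft targets `y_mix`.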

Accuracy

Below are the test accuracies of ResNet34, averaged over 3 runs with identical settings.
(All values are percentages.)

CIFAR10

| Noise level | 0.2 | 0.4 | 0.6 | 0.8 |
| --- | --- | --- | --- | --- |
| Official | 95.60 | 94.20 | 91.30 | 81.00 |
| This repo | 95.47 | 93.47 | 88.88 | 20.65 |

CIFAR100

| Noise level | 0.2 | 0.4 | 0.6 | 0.8 |
| --- | --- | --- | --- | --- |
| Official | 78.60 | 71.30 | 64.60 | 41.20 |
| This repo | 76.30 | 71.84 | 38.83 | 7.20 |
