Skip to content

Commit

Permalink
Make testsets unique in config.py.
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhengPeng7 committed Oct 6, 2024
1 parent 31879ac commit 1785774
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 44 deletions.
36 changes: 23 additions & 13 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@ def __init__(self) -> None:

# TASK settings
self.task = ['DIS5K', 'COD', 'HRSOD', 'General', 'General-2K', 'Matting'][0]
self.validation_set = {
'DIS5K': [],
'COD': [],
'HRSOD': [],
'General': ['DIS-VD', 'TE-P3M-500-NP'],
'General-2K': ['DIS-VD', 'TE-P3M-500-NP'],
'Matting': ['TE-P3M-500-NP'],
self.testsets = {
# Benchmarks
'DIS5K': ','.join(['DIS-VD', 'DIS-TE1', 'DIS-TE2', 'DIS-TE3', 'DIS-TE4']),
'COD': ','.join(['CHAMELEON', 'NC4K', 'TE-CAMO', 'TE-COD10K']),
'HRSOD': ','.join(['DAVIS-S', 'TE-HRSOD', 'TE-UHRSD', 'DUT-OMRON', 'TE-DUTS']),
# Practical use
'General': ','.join(['DIS-VD', 'TE-P3M-500-NP']),
'General-2K': ','.join(['DIS-VD', 'TE-P3M-500-NP']),
'Matting': ','.join(['TE-P3M-500-NP', 'TE-AM-2k']),
}[self.task]
datasets_all = '+'.join([ds for ds in (os.listdir(os.path.join(self.data_root_dir, self.task)) if os.path.isdir(os.path.join(self.data_root_dir, self.task)) else []) if ds not in self.validation_set])
datasets_all = '+'.join([ds for ds in (os.listdir(os.path.join(self.data_root_dir, self.task)) if os.path.isdir(os.path.join(self.data_root_dir, self.task)) else []) if ds not in self.testsets.split(',')])
self.training_set = {
'DIS5K': ['DIS-TR', 'DIS-TR+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'][0],
'COD': 'TR-COD10K+TR-CAMO',
Expand Down Expand Up @@ -184,11 +186,19 @@ def __init__(self) -> None:
self.save_last = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'val_last=' in l][0].split('val_last=')[-1].split()[0])
self.save_step = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'step=' in l][0].split('step=')[-1].split()[0])

def print_task(self) -> None:
# Return task for choosing settings in shell scripts.
print(self.task)

# Return task for choosing settings in shell scripts.
if __name__ == '__main__':
import argparse


parser = argparse.ArgumentParser(description='Only choose one argument to activate.')
parser.add_argument('--print_task', action='store_true', help='print task name')
parser.add_argument('--print_testsets', action='store_true', help='print validation set')
args = parser.parse_args()

config = Config()
config.print_task()

for arg_name, arg_value in args._get_kwargs():
if arg_value:
print(config.__getattribute__(arg_name[len('print_'):]))

9 changes: 1 addition & 8 deletions eval_existingOnes.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,7 @@ def do_eval(args):
default='./e_preds')
parser.add_argument(
'--data_lst', type=str, help='test dataset',
default={
'DIS5K': '+'.join(['DIS-VD', 'DIS-TE1', 'DIS-TE2', 'DIS-TE3', 'DIS-TE4'][:]),
'COD': '+'.join(['TE-COD10K', 'NC4K', 'TE-CAMO', 'CHAMELEON'][:]),
'HRSOD': '+'.join(['DAVIS-S', 'TE-HRSOD', 'TE-UHRSD', 'TE-DUTS', 'DUT-OMRON'][:]),
'General': '+'.join(['DIS-VD', 'TE-P3M-500-NP'][:]),
'General-2K': '+'.join(['DIS-VD', 'TE-P3M-500-NP'][:]),
'Matting': '+'.join(['TE-AM-2k'][:]),
}[config.task])
default=config.testsets.replace(',', '+'))
parser.add_argument(
'--save_dir', type=str, help='candidate competitors',
default='e_results')
Expand Down
14 changes: 2 additions & 12 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,19 +85,9 @@ def main(args):
parser.add_argument('--ckpt_folder', default=sorted(glob(os.path.join('ckpt', '*')))[-1], type=str, help='model folder')
parser.add_argument('--pred_root', default='e_preds', type=str, help='Output folder')
parser.add_argument('--testsets',
default={
'DIS5K': 'DIS-VD+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4',
'COD': 'TE-COD10K+NC4K+TE-CAMO+CHAMELEON',
'HRSOD': 'DAVIS-S+TE-HRSOD+TE-UHRSD+TE-DUTS+DUT-OMRON',
'General': 'DIS-VD+TE-P3M-500-NP',
'General-2K': 'DIS-VD+TE-P3M-500-NP',
'Matting': 'TE-AM-2k',
'DIS5K-': 'DIS-VD',
'COD-': 'TE-COD10K',
'SOD-': 'DAVIS-S+TE-HRSOD+TE-UHRSD',
}[config.task + ''],
default=config.testsets.replace(',', '+'),
type=str,
help="Test all sets: , 'DIS-VD+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'")
help="Test all sets: DIS5K -> 'DIS-VD+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'")

args = parser.parse_args()

Expand Down
12 changes: 3 additions & 9 deletions test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,9 @@ echo Inference finished at $(date)
# Evaluation
log_dir=e_logs && mkdir ${log_dir}

task=$(python3 config.py)
case "${task}" in
"DIS5K") testsets='DIS-VD,DIS-TE1,DIS-TE2,DIS-TE3,DIS-TE4' ;;
"COD") testsets='CHAMELEON,NC4K,TE-CAMO,TE-COD10K' ;;
"HRSOD") testsets='DAVIS-S,TE-HRSOD,TE-UHRSD,DUT-OMRON,TE-DUTS' ;;
"General") testsets='DIS-VD,TE-P3M-500-NP' ;;
"General-2K") testsets='DIS-VD,TE-P3M-500-NP' ;;
"Matting") testsets='TE-AM-2k' ;;
esac
task=$(python3 config.py --print_task)
testsets=$(python3 config.py --print_testsets)

testsets=(`echo ${testsets} | tr ',' ' '`) && testsets=${testsets[@]}

for testset in ${testsets}; do
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _train_batch(self, batch):
gts = batch[1].to(device)
class_labels = batch[2].to(device)
if config.use_fp16:
with amp.autocast(enabled=config.use_fp16):
with amp.autocast(enabled=config.use_fp16, dtype=(torch.float16, torch.bfloat16)[0]):
scaled_preds, class_preds_lst = self.model(inputs)
if config.out_ref:
(outs_gdt_pred, outs_gdt_label), scaled_preds = scaled_preds
Expand Down
2 changes: 1 addition & 1 deletion train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Run script
# Settings of training & test for different tasks.
method="$1"
task=$(python3 config.py)
task=$(python3 config.py --print_task)
case "${task}" in
"DIS5K") epochs=600 && val_last=50 && step=5 ;;
"COD") epochs=150 && val_last=50 && step=5 ;;
Expand Down

0 comments on commit 1785774

Please sign in to comment.