Skip to content

Commit

Permalink
updated dwa with a fix on the way tensor scalar items were being upda…
Browse files Browse the repository at this point in the history
…ted. Also added Early Stopping Criterion.
  • Loading branch information
vskadandale committed Feb 9, 2020
1 parent e188eb8 commit 3ad993a
Show file tree
Hide file tree
Showing 23 changed files with 93,336 additions and 110 deletions.
Binary file modified __pycache__/settings.cpython-36.pyc
Binary file not shown.
Binary file modified dataset/__pycache__/dataloaders.cpython-36.pyc
Binary file not shown.
1 change: 1 addition & 0 deletions dataset/dataloaders.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import Path
import torch
import torch.utils.data
import numpy as np
import random
from settings import *
Expand Down
6 changes: 1 addition & 5 deletions dataset/downsample_gt.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import sys

sys.path.append('../')
from pydub import AudioSegment
from settings import *
import pandas as pd
import numpy as np
from utils import create_folder
from utils.utils import create_folder
import librosa
from settings import *

Expand Down
2 changes: 1 addition & 1 deletion dataset/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from functools import partial
import librosa
import torch
from utils import create_folder
from utils.utils import create_folder
import librosa.display
from sklearn.model_selection import train_test_split
import shutil
Expand Down
2 changes: 1 addition & 1 deletion eval/eval_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from settings import *
import pandas as pd
import numpy as np
from utils import create_folder
from utils.utils import create_folder
import librosa
import mir_eval

Expand Down
2 changes: 1 addition & 1 deletion eval/stitch_audio.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
sys.path.append('../')
from pydub import AudioSegment
from utils import create_folder
from utils.utils import create_folder
from settings import *

SAMPLING_RATE=TARGET_SAMPLING_RATE
Expand Down
Binary file modified models/__pycache__/wrapper.cpython-36.pyc
Binary file not shown.
2 changes: 1 addition & 1 deletion models/wrapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from utils import warpgrid
from utils.utils import warpgrid
import torch.nn.functional as F
from settings import *

Expand Down
51 changes: 13 additions & 38 deletions settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ def set_path(path):
return path


TYPE = '2src' # '4src'
ISOLATED = True
TYPE = '4src' # '4src'
ISOLATED = False
ISOLATED_SOURCE_ID = 0
SOURCES = ['vocals', 'accompaniment', 'drums', 'bass', 'other']
if TYPE == '2src':
Expand All @@ -29,7 +29,7 @@ def set_path(path):
EPOCHS = 60000 # 500
DWA_TEMP = 2
MOMENTUM = 0.9
DROPOUT = 0.05
DROPOUT = 0.1
WEIGHT_DECAY = 0
INITIALIZER = 'xavier'
OPTIMIZER = 'SGD'
Expand All @@ -38,6 +38,7 @@ def set_path(path):
TRACKGRAD = False
ACTIVATION = None
INPUT_CHANNELS = 1
EARLY_STOPPING_PATIENCE = 70

# CUNet Settings
FILTERS_LAYER_1 = 16
Expand All @@ -47,8 +48,6 @@ def set_path(path):
N_CONDITIONS = 1008 # 4064
N_NEURONS = [16, 128, 1024]



##### ENERGY STATS #####
ACC_ENERGY = 687.5261
BAS_ENERGY = 252.7046
Expand All @@ -57,53 +56,29 @@ def set_path(path):
OTH_ENERGY = 216.4932
VOC_ENERGY = 173.4346

###### WEIGHTS #######
PRETRAINED_UNET_CONFIG = '2019-12-01 15:01:42'

#### TENSORBOARD CONFIG #####
PARAMETER_SAVE_FREQUENCY = 100

"""
MUSDB_FOLDER_PATH='/media/venkatesh/slave/dataset/musdb'
INDIAN_SAMPLE_DATA='/media/venkatesh/slave/dataset/Indian_Music/sample/X'
EXPERIMENTS_FOLDER='/media/venkatesh/slave/weights'
DUMPS_FOLDER='/media/venkatesh/slave/dumps'
SEPERATED_OUTPUT_PATH='/media/venkatesh/slave/dataset/output_crumbs/'
OUTPUT_PATH='/media/venkatesh/slave/dataset/output'
RESULTS_PATH='/media/venkatesh/slave/dataset/results'
"""
##### Main Directory Path #####
# MAIN_DIR_PATH = '/media/venkatesh/slave'
MAIN_DIR_PATH = '/mnt/DATA'
# MAIN_DIR_PATH = '/homedtic/vshenoykadandale'

MUSDB_FOLDER_PATH = '/mnt/DATA/datasets/musdb/'
INDIAN_SAMPLE_DATA = '/mnt/DATA/datasets/Indian_Music/sample/X'
EXPERIMENTS_FOLDER = '/mnt/DATA/weights'
DUMPS_FOLDER = '/mnt/DATA/dumps'
SEPERATED_OUTPUT_PATH = '/mnt/DATA/datasets/musdb/output_crumbs'
OUTPUT_PATH = '/mnt/DATA/datasets/musdb/output'
RESULTS_PATH = '/mnt/DATA/datasets/musdb/results'

ROOT_DIR = set_path(EXPERIMENTS_FOLDER)
PRETRAINED_UNET_WEIGHTS_PATH = os.path.join(EXPERIMENTS_FOLDER, PRETRAINED_UNET_CONFIG, 'bestcheckpoint.pth')
TEST_UNET_CONFIG = '2020-01-31 19:22:35' # '2020-01-03 11:42:35'#'2020-01-02 19:19:54'#'baseline'#'2020-01-01 20:03:30'#'2019-12-31 14:27:24'#'2019-12-18 18:53:17'

MUSDB_FOLDER_PATH = os.path.join(MAIN_DIR_PATH, 'dataset', 'musdb')
EXPERIMENTS_FOLDER = os.path.join(MAIN_DIR_PATH, 'weights')
DUMPS_FOLDER = os.path.join(MAIN_DIR_PATH, 'dumps')
ROOT_DIR = set_path(EXPERIMENTS_FOLDER)
TEST_UNET_WEIGHTS_PATH = os.path.join(EXPERIMENTS_FOLDER, TEST_UNET_CONFIG, 'bestcheckpoint.pth')
TEST_UNET_REFINED_CONFIG = '2019-12-18 18:53:17'
TEST_UNET_REFINED_WEIGHTS_PATH = os.path.join(EXPERIMENTS_FOLDER, TEST_UNET_REFINED_CONFIG, 'bestcheckpoint.pth')
RAW_MUSDB_PATH = os.path.join(MUSDB_FOLDER_PATH, 'musdb18')
MUSDB_WAVS_FOLDER_PATH = os.path.join(MUSDB_FOLDER_PATH, 'musdb18_wavs')
ENERGY_PROFILE_FOLDER = os.path.join(MUSDB_FOLDER_PATH, 'energy_profile')
MUSDB_SPLITS_PATH = os.path.join(MUSDB_FOLDER_PATH, 'musdbsplit')
MUSDB_SPLITS_AUG_PATH = os.path.join(MUSDB_FOLDER_PATH, 'musdbaug')
CHUNKS_PATH = os.path.join(MUSDB_FOLDER_PATH, 'musdb_chunks')
SPECTROGRAMS_PATH = os.path.join(MUSDB_FOLDER_PATH, 'musdb_spectrograms')
SOURCE_MIX_PATH = os.path.join(MUSDB_FOLDER_PATH, 'musdb_smix')
STEMS_PATH = os.path.join(MUSDB_FOLDER_PATH, 'musdb_stems')
RECON_GT_PATH = os.path.join(MUSDB_FOLDER_PATH, 'musdbGT')
SOURCE_ESTIMATES_PATH = os.path.join(MUSDB_FOLDER_PATH, 'eval')
TEST_DATA_PATH = os.path.join(RAW_MUSDB_PATH, 'test')
TEST_SPEC_DATA_PATH = os.path.join(MUSDB_SPLITS_PATH, 'test')
TEST_MAPPINGS_PATH = os.path.join(os.path.dirname(SEPERATED_OUTPUT_PATH), 'test_mappings.npy')

SOURCES_SUBSET_ID = [SOURCES.index(i) for i in SOURCES_SUBSET]
SAVE_SEGMENTS = True
ENERGY_THRESHOLD = 0

FILTERED_SAMPLE_PATHS = os.path.join(MUSDB_FOLDER_PATH, TYPE + '_filtered')
10 changes: 4 additions & 6 deletions test/baseline.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
import sys
sys.path.append('..')
import shutil

from dataset.dataloaders import UnetInput
from flerken import pytorchfw
from flerken.models import UNet
from flerken.framework.pytorchframework import set_training, config, ctx_iter, \
classitems,checkpoint_on_key,assert_workdir
from flerken.framework import train, val
from flerken.framework.pytorchframework import set_training, config, ctx_iter
from flerken.framework import val
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils import *
from utils.utils import *
from models.wrapper import Wrapper
from tqdm import tqdm
from loss.losses import *
from collections import OrderedDict
from collections import OrderedDict
from settings import *


Expand Down
7 changes: 3 additions & 4 deletions test/dwa.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import sys
sys.path.append('..')
import shutil

from dataset.dataloaders import UnetInput
from flerken import pytorchfw
from flerken.models import UNet
from flerken.framework.pytorchframework import set_training, config, ctx_iter, \
classitems,checkpoint_on_key,assert_workdir
from flerken.framework import train, val
classitems
from flerken.framework import val
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils import *
from utils.utils import *
from models.wrapper import Wrapper
from tqdm import tqdm
from loss.losses import *
Expand Down
7 changes: 3 additions & 4 deletions test/energy_based.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import sys
sys.path.append('..')
import shutil

from dataset.dataloaders import UnetInput
from flerken import pytorchfw
from flerken.models import UNet
from flerken.framework.pytorchframework import set_training, config, ctx_iter, \
classitems,checkpoint_on_key,assert_workdir
from flerken.framework import train, val
classitems
from flerken.framework import val
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils import *
from utils.utils import *
from models.wrapper import Wrapper
from tqdm import tqdm
from loss.losses import *
Expand Down
2 changes: 1 addition & 1 deletion test/unit_weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from flerken.framework.pytorchframework import set_training, config, ctx_iter, classitems
from flerken.framework import val
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils import *
from utils.utils import *
from models.wrapper import Wrapper
from tqdm import tqdm
from loss.losses import *
Expand Down
15 changes: 12 additions & 3 deletions train/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from dataset.dataloaders import UnetInput
from flerken import pytorchfw
from flerken.models import UNet
from flerken.framework.pytorchframework import set_training, config, ctx_iter, classitems
from flerken.framework.pytorchframework import set_training, config, ctx_iter
from flerken.framework import train, val
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils import *
from utils.utils import *
from utils.EarlyStopping import EarlyStopping
from models.wrapper import Wrapper
from tqdm import tqdm
from loss.losses import *
Expand All @@ -22,6 +23,7 @@ def __init__(self, model, rootdir, workname, main_device=0, trackgrad=False):
self.visual_dumps_path = os.path.join(DUMPS_FOLDER, 'visuals')
self.grid_unwarp = torch.from_numpy(
warpgrid(BATCH_SIZE, NFFT // 2 + 1, STFT_WIDTH, warp=False)).to('cuda')
self.EarlyStopChecker = EarlyStopping(patience=EARLY_STOPPING_PATIENCE)
self.val_iterations = 0

def print_args(self):
Expand Down Expand Up @@ -95,6 +97,13 @@ def train(self):
with val(self):
self.run_epoch()
self.__update_db__()
stop = self.EarlyStopChecker.check_improvement(self.loss_.data.tuple['val'].epoch_array.val,
self.epoch)
if stop:
print('Early Stopping Epoch : [{0}], '
'Best Checkpoint Epoch : [{1}]'.format(self.epoch,
self.EarlyStopChecker.best_epoch))
break

def train_epoch(self, logger):
j = 0
Expand Down Expand Up @@ -200,7 +209,7 @@ def main():
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# SET MODEL
u_net = UNet([16, 32, 64, 128, 256, 512], 1, None, dropout=DROPOUT, verbose=False, useBN=True)
u_net = UNet([32, 64, 128, 256, 512, 1024, 2048], K, None, verbose=False, useBN=True, dropout=DROPOUT)
model = Wrapper(u_net)

if not os.path.exists(ROOT_DIR):
Expand Down
Loading

0 comments on commit 3ad993a

Please sign in to comment.