Skip to content

Commit

Permalink
update, added discriminative training for RAT-SPNs
Browse files Browse the repository at this point in the history
  • Loading branch information
R. Peharz committed Aug 1, 2019
1 parent 9831956 commit a11ab87
Showing 18 changed files with 879 additions and 24 deletions.
27 changes: 22 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# RAT-SPN
Code for UAI'19: Random Sum-Product Networks: A Simple and Effective Approach to Probabilistic Deep Learning

# V0.1
# V0.2
* RAT-SPN model
* Experiments for generative learning using EM
* Experiments for generative learning of RAT-SPNs using EM
* Experiments for discriminative learning of RAT-SPNs using Adam

# Quick Start
# Setup
git clone https://github.com/cambridge-mlg/RAT-SPN

cd RAT-SPN
@@ -16,6 +17,22 @@ source ratspn_venv/bin/activate

python download_preprocess_data.py

python quick_generative_rat_spn.py
# Quick Run for Generative Experiments
This will simply train a single RAT-SPN (no cross-validation).

python quick_run_rat_spn_generative.py

python quick_eval_rat_spn_generative.py

# Quick Run for Discriminative Training on MNIST
This will simply train a single RAT-SPN for each depth.

python quick_run_rat_spn_mnist.py

python quick_eval_rat_spn_discriminative.py

# Full Training
See the run_*.py and eval_*.py files



python eval_quick_generative_rat_spn.py
1 change: 1 addition & 0 deletions configurations.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"worker_time_limit": 42900}
8 changes: 4 additions & 4 deletions download_preprocess_data.py
Original file line number Diff line number Diff line change
@@ -63,7 +63,7 @@ def maybe_download_fashion_mnist():


def maybe_download_DEBD():
if os.path.isfile('data/DEBD'):
if os.path.isdir('data/DEBD'):
print('DEBD already exists')
return
subprocess.run(['git', 'clone', 'https://github.com/arranger1044/DEBD', 'data/DEBD'])
@@ -160,7 +160,7 @@ def process_imdb(out_path='data/imdb',
"""Adopted from keras/datasets/imdb/
"""

out_file = os.path.join(out_path, 'imdb-dense-nmf-{}.pklz'.format(max_topics))
out_file = os.path.join(out_path, 'imdb-dense-nmf-{}.pkl'.format(max_topics))
if os.path.isfile(out_file):
print('Already exists: {}'.format(out_file))
return
@@ -295,9 +295,9 @@ def process_imdb(out_path='data/imdb',
x_test.shape))

# saving to pickle
with gzip.open(out_file, 'wb') as f:
with open(out_file, 'wb') as f:
pickle.dump((x_train, y_train, x_valid, y_valid, x_test, y_test), f)
print('Saved to gzipped pickle to {}'.format(out_file))
print('Saved to pickle to {}'.format(out_file))

# return x_train, y_train, x_valid, y_valid, x_test, y_test

4 changes: 0 additions & 4 deletions eval_quick_generative_rat_spn.py

This file was deleted.

63 changes: 63 additions & 0 deletions eval_rat_spn_discriminative.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import numpy as np
import pickle
import os

# Datasets for which discriminative RAT-SPN results are evaluated.
datasets = ['mnist', 'fashion-mnist', 'wine', 'theorem', 'higgs', 'imdb']
# Root folder holding one sub-folder per dataset; the quick_eval_* scripts
# rebind this module attribute before calling evaluate().
result_basefolder = 'results/ratspn/'


def evaluate(base_folder=None, dataset_names=None):
    """Report per-dataset test accuracy, model-selected on validation accuracy.

    Scans ``base_folder`` for per-configuration result folders, picks for
    each dataset the configuration with the highest validation accuracy,
    and prints its test accuracy, hyper-parameters and best epoch.

    Args:
        base_folder: root of the result folders. Defaults to the
            module-level ``result_basefolder``, read at call time so
            callers may rebind it (as the quick_eval_* scripts do).
        dataset_names: datasets to evaluate; defaults to the module-level
            ``datasets`` list.

    Returns:
        dict mapping each dataset with at least one loadable result to
        ``{'test_acc', 'valid_acc', 'model', 'epoch'}``.
    """
    base = base_folder if base_folder is not None else result_basefolder
    names = dataset_names if dataset_names is not None else datasets

    summary = {}
    available = os.listdir(base)

    for dataset in names:
        print()
        if dataset not in available:
            print('Results for {} not found.'.format(dataset))
            continue

        best_valid_acc = -np.inf
        best_test_acc = -np.inf
        best_model = None
        best_epoch = None
        test_accs = []

        for result_folder in os.listdir(os.path.join(base, dataset)):
            # Folder names encode hyper-parameters as key_value pairs
            # joined by '__'; the value follows the last underscore.
            argdict = {}
            for a in result_folder.split('__'):
                last_ = a.rfind('_')
                argdict[a[:last_]] = float(a[last_ + 1:])

            results_file = os.path.join(base, dataset, result_folder, 'results.pkl')
            try:
                # Context manager so the handle is closed even on a bad
                # pickle; narrow except instead of the original bare one.
                with open(results_file, 'rb') as f:
                    results = pickle.load(f)
            except (OSError, pickle.UnpicklingError, EOFError):
                print()
                print("can't load")
                print(result_folder)
                continue

            valid_acc = results['best_valid_acc']
            # Test accuracy at the epoch with the best validation accuracy.
            test_acc = results['test_ACC'][results['epoch_best_valid_acc']]
            test_accs.append(test_acc)

            if valid_acc > best_valid_acc:
                best_valid_acc = valid_acc
                best_test_acc = test_acc
                best_model = argdict
                best_epoch = results['epoch_best_valid_acc']

        if best_model is None:
            # Nothing loadable -- avoid reporting a bogus -inf accuracy.
            print('No loadable results for {}.'.format(dataset))
            continue

        print('Test accuracy: {}'.format(best_test_acc))
        print('Achieved by configuration:')
        print(best_model)
        print('in epoch {} with validation accuracy {}'.format(best_epoch, best_valid_acc))

        summary[dataset] = {
            'test_acc': best_test_acc,
            'valid_acc': best_valid_acc,
            'model': best_model,
            'epoch': best_epoch,
        }

    return summary


if __name__ == '__main__':

    evaluate()
File renamed without changes.
4 changes: 4 additions & 0 deletions quick_eval_rat_spn_discriminative.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Evaluate the discriminative RAT-SPN results produced by the quick run.
import eval_rat_spn_discriminative as evaluator

# Point the evaluator at the quick-run output folder instead of the default.
evaluator.result_basefolder = 'quick_results/ratspn/'
evaluator.evaluate()
4 changes: 4 additions & 0 deletions quick_eval_rat_spn_generative.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Evaluate the generative RAT-SPN results produced by the quick run.
import eval_rat_spn_generative as evaluator

# Point the evaluator at the quick-run output folder instead of the default.
evaluator.result_basefolder = 'quick_results/ratspn/debd/'
evaluator.evaluate()
9 changes: 0 additions & 9 deletions quick_generative_rat_spn.py

This file was deleted.

9 changes: 9 additions & 0 deletions quick_run_rat_spn_generative.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Quick generative experiment: train a single small RAT-SPN (no grid search).
import run_rat_spn_generative as runner

# Replace the full structure grid with one depth-2 configuration and
# shorten training before launching the standard run loop.
runner.structure_dict = {
    2: [{'num_recursive_splits': 10, 'num_input_distributions': 8, 'num_sums': 8}],
}
runner.base_result_path = "quick_results/ratspn/debd/"
runner.num_epochs = 20

runner.run()
16 changes: 16 additions & 0 deletions quick_run_rat_spn_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Quick discriminative experiment on MNIST: one RAT-SPN per split depth.
import run_rat_spn_mnist as runner

# One structure per depth (1-4) instead of the full grid.
runner.structure_dict = {
    1: [{'num_recursive_splits': 14, 'num_input_distributions': 15, 'num_sums': 10}],
    2: [{'num_recursive_splits': 12, 'num_input_distributions': 15, 'num_sums': 15}],
    3: [{'num_recursive_splits': 12, 'num_input_distributions': 14, 'num_sums': 12}],
    4: [{'num_recursive_splits': 10, 'num_input_distributions': 15, 'num_sums': 10}],
}
runner.base_result_path = "quick_results/ratspn/mnist/"
# Single dropout configuration instead of the full grid.
runner.param_configs = [{'dropout_rate_input': 0.5, 'dropout_rate_sums': 0.5}]
runner.num_epochs = 100

runner.run()
130 changes: 130 additions & 0 deletions run_rat_spn_fashion_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import os
import filelock
import utils
import sys
import subprocess
import time
import json

print("")
print("Discriminative Training of RAT-SPNs on fashion-mnist")
print("")

# Shared worker settings (e.g. the cluster time budget) live in a small
# JSON file next to the scripts.
with open('configurations.json') as f:
    configs = json.loads(f.read())

start_time = time.time()
# Total wall-clock budget for this worker, in seconds.
time_limit_seconds = configs['worker_time_limit']
# Do not launch a new training job with less than ~10 minutes of budget left.
dont_start_if_less_than_seconds = 600.0
base_result_path = "results/ratspn/fashion-mnist/"

import itertools

# Candidate RAT-SPN structures per split depth; within each depth the
# configurations grow from small to large models.
structure_dict = {}

# depth 1
structure_dict[1] = [
    {'num_recursive_splits': 9, 'num_input_distributions': 10, 'num_sums': 10},
    {'num_recursive_splits': 14, 'num_input_distributions': 15, 'num_sums': 10},
    {'num_recursive_splits': 19, 'num_input_distributions': 20, 'num_sums': 10},
    {'num_recursive_splits': 29, 'num_input_distributions': 25, 'num_sums': 10},
    {'num_recursive_splits': 40, 'num_input_distributions': 33, 'num_sums': 10}]

# depth 2
structure_dict[2] = [
    {'num_recursive_splits': 8, 'num_input_distributions': 10, 'num_sums': 10},
    {'num_recursive_splits': 12, 'num_input_distributions': 15, 'num_sums': 15},
    {'num_recursive_splits': 19, 'num_input_distributions': 20, 'num_sums': 18},
    {'num_recursive_splits': 30, 'num_input_distributions': 25, 'num_sums': 25},
    {'num_recursive_splits': 40, 'num_input_distributions': 37, 'num_sums': 35}]

# depth 3
structure_dict[3] = [
    {'num_recursive_splits': 10, 'num_input_distributions': 8, 'num_sums': 8},
    {'num_recursive_splits': 12, 'num_input_distributions': 14, 'num_sums': 12},
    {'num_recursive_splits': 15, 'num_input_distributions': 20, 'num_sums': 18},
    {'num_recursive_splits': 30, 'num_input_distributions': 25, 'num_sums': 20},
    {'num_recursive_splits': 40, 'num_input_distributions': 35, 'num_sums': 30}]

# depth 4
structure_dict[4] = [
    {'num_recursive_splits': 5, 'num_input_distributions': 10, 'num_sums': 9},
    {'num_recursive_splits': 10, 'num_input_distributions': 15, 'num_sums': 10},
    {'num_recursive_splits': 14, 'num_input_distributions': 20, 'num_sums': 14},
    {'num_recursive_splits': 28, 'num_input_distributions': 20, 'num_sums': 20},
    {'num_recursive_splits': 40, 'num_input_distributions': 30, 'num_sums': 26}]

# Full 4x4 Cartesian grid of input/sum dropout rates -- generated instead of
# the hand-unrolled 16-entry literal; order and content are identical.
# NOTE(review): 1.0 appears to mean "no dropout" -- confirm in train_rat_spn.py.
_dropout_rates = [1.0, 0.75, 0.5, 0.25]
param_configs = [
    {'dropout_rate_input': rate_input, 'dropout_rate_sums': rate_sums}
    for rate_input, rate_sums in itertools.product(_dropout_rates, _dropout_rates)]

num_epochs = 200

def run():
    """Grid search over structures and dropout configs on fashion-mnist.

    For every (split depth, structure, dropout config) combination,
    launches train_rat_spn.py as a subprocess unless the combination is
    already marked done or is currently locked by another worker.
    Exits the process when the worker time budget runs out or a training
    job reports a timeout (exit code 7).
    """
    for split_depth in structure_dict:
        for structure_config in structure_dict[split_depth]:
            for config_dict in param_configs:

                # Respect the worker time budget: do not start a job that
                # cannot reasonably finish.
                remaining_time = time_limit_seconds - (time.time() - start_time)
                if remaining_time < dont_start_if_less_than_seconds:
                    print("Only {} seconds remaining, stop worker".format(remaining_time))
                    sys.exit(0)

                cmd = "python train_rat_spn.py --store_best_valid_loss --store_best_valid_acc --num_epochs {}".format(num_epochs)
                cmd += " --timeout_seconds {}".format(remaining_time)
                cmd += " --split_depth {}".format(split_depth)
                cmd += " --data_path data/fashion-mnist/"

                for key in sorted(structure_config.keys()):
                    cmd += " --{} {}".format(key, structure_config[key])
                for key in sorted(config_dict.keys()):
                    cmd += " --{} {}".format(key, config_dict[key])

                # Encode the configuration in the result path so other
                # workers can detect finished/locked combinations.
                comb_string = ""
                comb_string += "split_depth_{}".format(split_depth)
                for key in sorted(structure_config.keys()):
                    comb_string += "__{}_{}".format(key, structure_config[key])
                for key in sorted(config_dict.keys()):
                    comb_string += "__{}_{}".format(key, config_dict[key])

                result_path = base_result_path + comb_string
                cmd += " --result_path " + result_path

                ###
                print(cmd)

                utils.mkdir_p(result_path)
                lock_file = result_path + "/file.lock"
                done_file = result_path + "/file.done"
                lock = filelock.FileLock(lock_file)
                try:
                    lock.acquire(timeout=0.1)
                except filelock.Timeout:
                    # Another worker is training this combination.
                    print(" locked -> skip")
                    continue
                try:
                    if os.path.isfile(done_file):
                        print(" already done -> skip")
                    else:
                        sys.stdout.flush()
                        # NOTE(review): shell=True on an internally-built
                        # command string; acceptable for trusted configs only.
                        ret_val = subprocess.call(cmd, shell=True)
                        if ret_val == 7:
                            # Exit code 7 signals a training timeout.
                            print("Task timed out, stop worker")
                            sys.exit(0)
                        # Portable marker file (was: os.system("touch ...")).
                        open(done_file, 'a').close()
                finally:
                    # Always release -- the original leaked the lock on the
                    # "already done" path and on any exception.
                    lock.release()

if __name__ == '__main__':
    run()
8 changes: 6 additions & 2 deletions generative_rat_spn.py → run_rat_spn_generative.py
Original file line number Diff line number Diff line change
@@ -5,14 +5,17 @@
import subprocess
import time
import utils
import json

print("")
print("Generative Training of RAT-SPNs on 20 binary datasets")
print("")

with open('configurations.json') as f:
configs = json.loads(f.read())

start_time = time.time()
time_limit_seconds = 30758400.
# time_limit_seconds = 42900.0
time_limit_seconds = configs['worker_time_limit']
dont_start_if_less_than_seconds = 600.0

optimizer = "em"
@@ -52,6 +55,7 @@

num_epochs = 100


def run():
for dataset in datasets.DEBD:
for split_depth in structure_dict:
Loading

0 comments on commit a11ab87

Please sign in to comment.