def get_parser():
    """Build the command-line parser for the metrics table script.

    Options:
        --config / -c: path to a csv with ``name`` and ``path`` columns
            (default ``config.csv``).
        --extension / -e: output format, one of ``html``, ``latex``, ``csv``
            (default ``csv``).
        --output / -o: path of the table to write (default ``output.csv``).
        --precision / -p: integer precision used in the final table
            (default ``4``).

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    # The text must be passed as ``description``. The first positional
    # parameter of ArgumentParser is ``prog``, so passing the sentence
    # positionally replaced the program name in the usage line and left
    # the parser without any description.
    parser = argparse.ArgumentParser(
        description="Prints a table from metrics files\n"
    )
    parser.add_argument(
        '--config', '-c', type=str, default='config.csv',
        help='Path to the config csv with `name` and `path` columns. '
             '`name` is a model name, and '
             '`path` is a path to metrics file'
    )
    parser.add_argument(
        '--extension', '-e', type=str,
        choices=['html', 'latex', 'csv'],
        default='csv',
        help='Format of a table'
    )
    parser.add_argument(
        '--output', '-o', type=str,
        default='output.csv',
        help='Path to the output table'
    )
    parser.add_argument(
        '--precision', '-p', type=int,
        default=4, help='Precision in final table'
    )
    return parser
+ }[config.extension] + it_pattern = { + 'csv': '{}', + 'html': '{}', + 'latex': r'!it1! {} !it2!' + }[config.extension] + + arrow = { + 'csv': [' (↓)', ' (↑)', ''], + 'html': [' (↓)', ' (↑)', ''], + 'latex': [r' ($\downarrow$)', r' ($\uparrow$)', ''] + }[config.extension] + + for col, d in zip(targets[1:], directions[1:]): + metrics[col] = metrics[col] \ + .astype(float) \ + .round(config.precision) + max_val = (2 * d - 1) * np.max( + [(2 * d - 1) * m for m, n in zip(metrics[col], + metrics['Model']) + if n != 'Train']) + metrics[col] = [str(x) if x != max_val or n == 'Train' + else bf_pattern.format(x) + for x, n in zip(metrics[col], + metrics['Model'])] + for col in targets[::-1]: + metrics[col] = [it_pattern.format(x) + if n == 'Train' else x + for x, n in zip(metrics[col], + metrics['Model'])] + + metrics = metrics.round(config.precision) + if config.extension == 'csv': + metrics.to_csv(config.output, index=None) + elif config.extension == 'html': + html = metrics.to_html(index=None) + html = re.sub('<', '<', html) + html = re.sub('>', '>', html) + header, footer = html.split('') + header += '' + header = header.split('\n') + values = [x.strip()[4:-5] + for x in header[3:-2]] + spans = ['rowspan' if '/' not in x else 'colspan' + for x in values] + first_header = [x.split('/')[0] for x in values] + second_header = [x.split('/')[1] for x in values + if '/' in x] + new_header = header[:3] + i = 0 + total = 0 + while i < len(first_header): + h = first_header[i] + new_header.append( + ' ' * 6 + '