Skip to content

Commit

Permalink
update visualization; add ipynb
Browse files Browse the repository at this point in the history
  • Loading branch information
d1ngn1gefe1 committed Aug 21, 2022
1 parent c0a0247 commit b6fc02f
Show file tree
Hide file tree
Showing 7 changed files with 697 additions and 81 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ target/

# Jupyter Notebook
.ipynb_checkpoints
*.ipynb
#*.ipynb

# IPython
profile_default/
Expand Down
48 changes: 32 additions & 16 deletions momaapi/analysis/stat_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,51 @@
import os.path as osp
from pprint import pprint
import seaborn as sns
import tempfile


class StatVisualizer:
def __init__(self, moma, dir_vis):
def __init__(self, moma, dir_vis=None):
if dir_vis is None:
dir_vis = tempfile.mkdtemp()

self.moma = moma
self.dir_vis = dir_vis

def show(self, with_split):
os.makedirs(osp.join(self.dir_vis, "stats"), exist_ok=True)

keys = [
x for x in self.moma.statistics["all"].keys() if x != "raw" and x != "hoi"
]
if with_split:
distributions, hues = {}, {}
for key in self.moma.distributions_train:
for key in keys:
distributions[key] = (
self.moma.distributions_train[key]
+ self.moma.distributions_val[key]
self.moma.statistics["standard_train"][key]["distribution"]
+ self.moma.statistics["standard_val"][key]["distribution"]
+ self.moma.statistics["standard_test"][key]["distribution"]
)
hues[key] = (
["train"]
* len(self.moma.statistics["standard_train"][key]["distribution"])
+ ["val"]
* len(self.moma.statistics["standard_val"][key]["distribution"])
+ ["test"]
* len(self.moma.statistics["standard_test"][key]["distribution"])
)
hues[key] = ["train"] * len(self.moma.distributions_train[key]) + [
"val"
] * len(self.moma.distributions_val[key])
pprint(self.moma.statistics_train, sort_dicts=False)
pprint(self.moma.statistics_val, sort_dicts=False)

else:
distributions = self.moma.distributions
hues = {key: None for key in self.moma.distributions}
pprint(self.moma.statistics, sort_dicts=False)
distributions = {
key: self.moma.statistics["all"][key]["distribution"] for key in keys
}
hues = {key: None for key in keys}

paths = {}
for key in distributions:
counts = distributions[key]
cnames = (
self.moma.get_taxonomy(key) + self.moma.get_taxonomy(key)
if with_split
else self.moma.get_taxonomy(key)
self.moma.taxonomy[key] * 3 if with_split else self.moma.taxonomy[key]
)
if isinstance(cnames[0], tuple):
cnames = [cname[0] for cname in cnames]
Expand Down Expand Up @@ -64,4 +75,9 @@ def show(self, with_split):
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
ax.set_ylim(bottom=1)
plt.tight_layout()
plt.savefig(osp.join(self.dir_vis, "stats", fname))

path = osp.join(self.dir_vis, "stats", fname)
paths[key] = path
plt.savefig(path)

return paths
File renamed without changes.
21 changes: 21 additions & 0 deletions scripts/save_gifs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import argparse

from momaapi import MOMA, AnnVisualizer


def main():
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--dir-moma", type=str, default="/media/hdd/moma-lrg")
parser.add_argument("-d", "--dir-vis", type=str, default="/media/hdd/moma-lrg/vis")
args = parser.parse_args()

moma = MOMA(args.dir_moma)

visualizer = AnnVisualizer(moma, args.dir_vis)
ids_sact = moma.get_ids_sact()
for id_sact in ids_sact:
visualizer.show_sact(id_sact, vstack=False)


if __name__ == "__main__":
main()
197 changes: 197 additions & 0 deletions scripts/visualize_anns.ipynb

Large diffs are not rendered by default.

446 changes: 446 additions & 0 deletions scripts/visualize_stats.ipynb

Large diffs are not rendered by default.

64 changes: 0 additions & 64 deletions tests/run_visualize.py

This file was deleted.

0 comments on commit b6fc02f

Please sign in to comment.