forked from yuweihao/KERN
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add code to draw the figures in the paper
- Loading branch information
Showing
4 changed files
with
201 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
[["on", 0.3155507021783991], ["has", 0.1565315839763612], ["wearing", 0.1142129900996459], ["of", 0.08633462072731068], ["in", 0.06088597329393533], ["near", 0.048708457455777614], ["with", 0.029593467211601], ["behind", 0.02916148095807806], ["holding", 0.026519780634489846], ["above", 0.018631615291349837], ["sitting on", 0.012851992516520664], ["wears", 0.011997655390594263], ["under", 0.010738632257650092], ["riding", 0.01017335656530781], ["in front of", 0.009007475449851856], ["standing on", 0.006532788399001132], ["at", 0.0048658674653327015], ["attached to", 0.004318256638376117], ["carrying", 0.00402116572052577], ["walking on", 0.003762616327153307], ["over", 0.003004633012421612], ["belonging to", 0.0024377514232260863], ["for", 0.00242008655784039], ["looking at", 0.002264314563075614], ["watching", 0.002131025124256269], ["hanging from", 0.0020812423218056703], ["parked on", 0.0018660521434708248], ["laying on", 0.001827510618992942], ["eating", 0.0017070683549995583], ["and", 0.0014983017640776933], ["covering", 0.0013296825944869562], ["using", 0.0013200472133674853], ["between", 0.0012285110927325138], ["covered in", 0.001132157281537807], ["along", 0.0010968275507664143], ["part of", 0.0010020796364249524], ["lying on", 0.0009009081346705102], ["on back of", 0.0008784255787250784], ["to", 0.0008334604668342152], ["mounted on", 0.0007804658706771264], ["walking in", 0.0006841120594824195], ["across", 0.0005845464545812223], ["against", 0.0005106751993319469], ["from", 0.0004914044370930055], ["growing on", 0.00047534546856055435], ["painted on", 0.00042235087240346554], ["made of", 0.00027621425875815997], ["playing", 0.00026657887763868927], ["says", 9.956560490119719e-05], ["flying in", 4.817690559735348e-05]] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
# -*- coding: UTF-8 -*- | ||
|
||
|
||
import sys | ||
import os | ||
import math | ||
import pickle | ||
import numpy as np | ||
import json | ||
import matplotlib.pyplot as plt | ||
import matplotlib | ||
matplotlib.use('Agg') | ||
|
||
|
||
|
||
|
||
|
||
def draw_scatter(): | ||
|
||
with open('nocontrained/motifnet_predcls_sg_eval_result_mean_recall.pkl', 'rb') as f: | ||
cvpr18 = pickle.load(f) | ||
with open('nocontrained/kern_predcls_sg_eval_result_mean_recall.pkl', 'rb') as f: | ||
our = pickle.load(f) | ||
with open('count_pred_all.json', 'r') as f: | ||
data = json.load(f) | ||
rel_list = [] | ||
num_list = [] | ||
for rel, num in data: | ||
rel_list.append(rel) | ||
num_list.append(num) | ||
num_list = np.array(num_list) * 100 | ||
cvpr18_list = [] | ||
our_list = [] | ||
for r in rel_list: | ||
cvpr18_list.append(cvpr18[r]['R@50']) | ||
our_list.append(our[r]['R@50']) | ||
plt.figure(figsize=(10,5)) | ||
x = np.arange(50)+1 | ||
y1 = np.array(cvpr18_list) * 100 | ||
y2 = np.array(our_list) * 100 | ||
y3 = y2 - y1 | ||
plt.scatter(num_list, y3 ,s=30,color='green',marker='o',alpha=0.5) | ||
|
||
plt.xlabel('Relationship proportion (%)', fontsize=14) | ||
plt.ylabel('Improvement proportion (%)', fontsize=14) | ||
|
||
plt.grid(True, linestyle='--', axis='y') | ||
# plt.legend() | ||
plt.tight_layout() | ||
plt.savefig('scatter.pdf') | ||
|
||
|
||
|
||
def draw_difference_compare(): | ||
|
||
|
||
|
||
with open('nocontrained/motifnet_predcls_sg_eval_result_mean_recall.pkl', 'rb') as f: | ||
cvpr18 = pickle.load(f) | ||
with open('nocontrained/kern_predcls_sg_eval_result_mean_recall.pkl', 'rb') as f: | ||
our = pickle.load(f) | ||
with open('count_pred_all.json', 'r') as f: | ||
data = json.load(f) | ||
rel_list = [] | ||
num_list = [] | ||
for rel, num in data: | ||
rel_list.append(rel) | ||
num_list.append(num) | ||
cvpr18_list = [] | ||
our_list = [] | ||
for r in rel_list: | ||
cvpr18_list.append(cvpr18[r]['R@50']) | ||
our_list.append(our[r]['R@50']) | ||
plt.figure(figsize=(10,5)) | ||
x = np.arange(50)+1 | ||
y1 = np.array(cvpr18_list) * 100 | ||
y2 = np.array(our_list) * 100 | ||
y3 = y2 - y1 | ||
plt.bar(x, y3, alpha=0.9, width = 0.7, facecolor = 'green', edgecolor = 'white', label='SMN', lw=1) | ||
y4 = y3.copy() | ||
y4[y4>0] = 0 | ||
plt.bar(x, y4, alpha=0.9, width = 0.7, facecolor = 'red', edgecolor = 'white', label='SMN', lw=1) | ||
|
||
xticks1=rel_list | ||
plt.xticks(x,xticks1,fontsize=14,rotation=90) | ||
|
||
plt.ylabel('R@50 Improvement (%)', fontsize=14) | ||
|
||
for a,b,c,d in zip(x,y1,y2, y3): | ||
plt.text(a, d if d >= 0 else d - 10, '%+-.2f' % d, ha='center', va= 'bottom',fontsize=14, rotation=90) | ||
plt.xlim(0,51) | ||
plt.ylim(-5, 35) | ||
plt.ylim(-15, 45) | ||
plt.grid(True, linestyle='--', axis='y') | ||
# plt.legend() | ||
plt.tight_layout() | ||
plt.savefig('difference_compare.pdf') | ||
|
||
|
||
|
||
|
||
|
||
|
||
def draw_compare(): | ||
with open('nocontrained/motifnet_predcls_sg_eval_result_mean_recall.pkl', 'rb') as f: | ||
cvpr18 = pickle.load(f) | ||
with open('nocontrained/kern_predcls_sg_eval_result_mean_recall.pkl', 'rb') as f: | ||
our = pickle.load(f) | ||
with open('count_pred_all.json', 'r') as f: | ||
data = json.load(f) | ||
|
||
rel_list = [] | ||
num_list = [] | ||
for rel, num in data: | ||
rel_list.append(rel) | ||
num_list.append(num) | ||
cvpr18_list = [] | ||
our_list = [] | ||
for r in rel_list: | ||
cvpr18_list.append(cvpr18[r]['R@50']) | ||
our_list.append(our[r]['R@50']) | ||
plt.figure(figsize=(10,5)) | ||
x = np.arange(50)+1 | ||
y1 = np.array(cvpr18_list) * 100 | ||
y2 = np.array(our_list) * 100 | ||
y3 = y2 - y1 | ||
plt.bar(x-0.175, y2, alpha=1, width = 0.35, facecolor = 'coral', edgecolor = 'white', label='Ours', lw=1) | ||
plt.bar(x+0.175, y1, alpha=1, width = 0.35, facecolor = 'c', edgecolor = 'white', label='SMN', lw=1) | ||
|
||
y4 = y3.copy() | ||
y4[y4>0] = 0 | ||
|
||
xticks1=rel_list | ||
|
||
plt.xticks(x,xticks1,fontsize=14,rotation=90) | ||
|
||
plt.ylabel('R@50 (%)', fontsize=14) | ||
|
||
for a,b,c,d in zip(x,y1,y2, y3): | ||
plt.text(a, b+0.5 if b>c else c+0.5, '%+-.2f' % d, ha='center', va= 'bottom',fontsize=14, rotation=90) | ||
plt.xlim(0,51) | ||
plt.ylim(0, 119) | ||
plt.grid(True, linestyle='--', axis='y') | ||
plt.legend(fontsize=14) | ||
plt.tight_layout() | ||
plt.savefig('compare.pdf') | ||
|
||
|
||
|
||
|
||
|
||
def draw_bar(file_name): | ||
with open(file_name, 'r') as f: | ||
data = json.load(f) | ||
rel_list = [] | ||
num_list = [] | ||
for rel, num in data: | ||
rel_list.append(rel) | ||
num_list.append(num) | ||
|
||
|
||
|
||
# in CVPR paper, the font size is 14 | ||
|
||
plt.figure(figsize=(10,5)) | ||
|
||
x=np.arange(50)+1 | ||
|
||
y=np.array(num_list) * 100 | ||
xticks1=rel_list | ||
|
||
plt.bar(x,y,width = 0.7,align='center',color = 'lightcoral',alpha=1, edgecolor = 'white') | ||
|
||
plt.tick_params(labelsize=14) | ||
plt.xticks(x,xticks1,fontsize=12,rotation=90) | ||
|
||
# plt.xlabel('Relationship') | ||
plt.ylabel('Proportion (%)', fontsize=14) | ||
|
||
for a,b in zip(x,y): | ||
|
||
plt.text(a, b+0.5, '%.3f' % b, ha='center', va= 'bottom',fontsize=12, rotation=90) | ||
|
||
plt.ylim(0, 42) | ||
plt.xlim(0, 51) | ||
plt.grid(True, linestyle='--', axis='y') | ||
plt.tight_layout() | ||
|
||
plt.savefig('count_pred_all.pdf') | ||
|
||
|
||
if __name__ == "__main__": | ||
|
||
|
||
draw_bar('count_pred_all.json') | ||
|
||
draw_compare() | ||
|
||
draw_scatter() | ||
draw_difference_compare() |
Binary file added
BIN
+7.38 KB
draw_figures_in_the_paper/nocontrained/kern_predcls_sg_eval_result_mean_recall.pkl
Binary file not shown.
Binary file added
BIN
+7.38 KB
draw_figures_in_the_paper/nocontrained/motifnet_predcls_sg_eval_result_mean_recall.pkl
Binary file not shown.