Skip to content

Commit

Permalink
Update evaluate.py
Browse files Browse the repository at this point in the history
  • Loading branch information
lvwj19 authored Jul 23, 2021
1 parent 95fea05 commit ed6f3b6
Showing 1 changed file with 100 additions and 43 deletions.
143 changes: 100 additions & 43 deletions tools/IPABunny_msg/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,39 @@
from pprnet.data.pointcloud_transforms import PointCloudShuffle, ToTensor, PointCloudJitter
from torchvision import transforms
from torch.utils.data import DataLoader

import pprnet.utils.show3d_balls as show3d_balls
from sklearn.cluster import MeanShift

def extract_vertexes_from_obj(file_name):
with open(file_name, 'r') as f:
vertexes = []
for line in f.readlines():
line = line.strip()
if line.startswith('v'):
words = line.split()[1:]
xyz = [float(w) for w in words]
vertexes.append(xyz)
ori_model_pc = np.array(vertexes)
# center = ( np.max(ori_model_pc, axis=0) + np.min(ori_model_pc, axis=0) ) / 2.0
# ori_model_pc = ori_model_pc - center
return ori_model_pc


#-----------------------GLOBAL SETTINGS START-----------------------
MODEL_OBJ_DIRS = ['/home/dongzhikai/Desktop/iros_competition_code/CAD/SileaneBunny.obj']
MODEL_OBJ_DIRS = './SileaneBunny.obj'
BATCH_SIZE = 1
NUM_TYPE = 1
NUM_POINT = 16384
CHECKPOINT_PATH = "../../logs/IPABunny_msg/log1_batch8_scale3_continue/checkpoint.tar"
DATASET_DIR = 'your-path/Fraunhofer_IPA_Bin-Picking_dataset/h5_dataset/bunny/train/'
CHECKPOINT_PATH = "../../logs/IPABunny_msg/log1_batch8_scale3/checkpoint.tar"
DATASET_DIR = 'your-path-to-Fraunhofer_IPA_Bin-Picking_dataset/h5_dataset/bunny/train/'
TEST_CYCLE_RANGE = [499,500]
TEST_SCENE_RANGE = [1, 81]
#-----------------------GLOBAL SETTINGS END-----------------------

#-----------------------------------------------------------------
model_point_clouds = [ eval_util.extract_vertexes_from_obj(path)*1000.0 for path in MODEL_OBJ_DIRS ]

model_pointcloud = extract_vertexes_from_obj(MODEL_OBJ_DIRS)
model_pointcloud *= 1000.0
transforms = transforms.Compose(
[
PointCloudShuffle(NUM_POINT),
Expand All @@ -63,58 +80,98 @@



def eval_one_epoch():
def eval_one_epoch(loader):
net.eval() #

for batch_idx, batch_samples in enumerate(test_loader):
# labels = {
# 'rot_label':batch_samples['rot_label'].to(device),
# 'trans_label':batch_samples['trans_label'].to(device),
# }
input_point = batch_samples['point_clouds'][0].cpu().numpy().copy()
for batch_idx, batch_samples in enumerate(loader):
input_point_ori = batch_samples['point_clouds'].numpy()[0]
inputs = {
'point_clouds': batch_samples['point_clouds'].to(device),
'labels': None
}

# Forward pass
with torch.no_grad():
time_start = time.time()
pred_results, _ = net(inputs)
print("Forward time:", time.time()-time_start)

pred_trans_val = pred_results[0][0].cpu().numpy()
pred_mat_val = pred_results[1][0].cpu().numpy()
if pred_results[3] is not None:
pred_cls_val = pred_results[3][0].cpu().numpy()
pred_cls_val = np.argmax(pred_cls_val, -1)
else:
pred_cls_val = np.zeros(pred_trans_val.shape[0])


pc_list = []
trans_list = []
rot_list = []
cls_list = []
for k in range(NUM_TYPE):
cls_k_idx = np.where(pred_cls_val==k)[0]
if len(cls_k_idx) == 0:
continue
cls_k_points = input_point[cls_k_idx]
cls_k_pred_trans = pred_trans_val[cls_k_idx]
cls_k_pred_mat = pred_mat_val[cls_k_idx]

meanshift_args = {'bandwidth':40, 'bin_seeding':True, 'cluster_all':False, 'min_bin_freq':40}
n_cluster_cls_k, ins_label_cls_k, centroid_cls_k, rot_mat_cls_k, pc_segments_cls_k = \
eval_util.cluster_and_average(cls_k_points, cls_k_pred_trans, cls_k_pred_mat, meanshift_args)

pc_list += pc_segments_cls_k
trans_list += centroid_cls_k
rot_list += rot_mat_cls_k
cls_list += [k]*len(pc_segments_cls_k)
pred_vis_val = pred_results[2][0].cpu().numpy()


vs_picked_idx = pred_vis_val > 0.45

input_point = input_point_ori[vs_picked_idx]
pred_trans_val = pred_trans_val[vs_picked_idx]
pred_mat_val = pred_mat_val[vs_picked_idx]

# print('pred_trans_val', pred_trans_val.shape)
# print('pred_mat_val', pred_mat_val.shape)
# pred_trans_val = pred_trans_val[0]
# pred_mat_val = pred_mat_val

ms = MeanShift(bandwidth=40, bin_seeding=True, cluster_all=False, min_bin_freq=40)
ms.fit(pred_trans_val)
labels = ms.labels_
cluster_centers = ms.cluster_centers_


# # Number of clusters in labels, ignoring noise if present.
n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
# print(n_clusters)


color_cluster = [np.array([random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)]) for i in range(n_clusters)]
color_per_point = np.ones([pred_trans_val.shape[0], pred_trans_val.shape[1]]) * 255
for idx in range(color_per_point.shape[0]):
if labels[idx] != -1:
color_per_point[idx, :] = color_cluster[labels[idx]]

visualize_util.show_points([input_point], radius=5)
visualize_util.show_points(pc_list, color_array='random', radius=5)
visualize_util.show_models(model_point_clouds, trans_list, rot_list, cls_list, color_array='random', radius=5)

pred_trans_cluster = [[] for _ in range(n_clusters)]
pred_mat_cluster = [[] for _ in range(n_clusters)]
for idx in range(pred_trans_val.shape[0]):
if labels[idx] != -1:
pred_trans_cluster[labels[idx]].append(np.reshape(pred_trans_val[idx], [1, 3]))
pred_mat_cluster[labels[idx]].append(np.reshape(pred_mat_val[idx], [1, 3, 3]))
pred_trans_cluster = [np.concatenate(cluster, axis=0) for cluster in pred_trans_cluster]
pred_mat_cluster = [np.concatenate(cluster, axis=0) for cluster in pred_mat_cluster]

cluster_center_pred = [ np.mean(cluster, axis=0) for cluster in pred_trans_cluster]


cluster_mat_pred = []
for mat_cluster in pred_mat_cluster:
# print(mat_cluster)
# print(mat_cluster.shape)
all_quat = np.zeros([mat_cluster.shape[0], 4])
for idx in range(mat_cluster.shape[0]):
all_quat[idx] = eulerangles.mat2quat(mat_cluster[idx])
quat = eulerangles.average_quat(all_quat)
# print(ea.shape)
# print(ea.shape)
cluster_mat_pred.append( eulerangles.quat2mat(quat) )


all_model_point = np.zeros([model_pointcloud.shape[0]*n_clusters, 3])
all_model_color = np.zeros([model_pointcloud.shape[0]*n_clusters, 3])
for cluster_idx in range(n_clusters):
begin_idx = cluster_idx * model_pointcloud.shape[0]
end_idx = (cluster_idx+1) * model_pointcloud.shape[0]
all_model_color[begin_idx:end_idx, :] = color_cluster[cluster_idx]
all_model_point[begin_idx:end_idx, :] = np.dot(model_pointcloud, cluster_mat_pred[cluster_idx].T) + \
np.tile(np.reshape(cluster_center_pred[cluster_idx], [1, 3]), [model_pointcloud.shape[0], 1])


show3d_balls.showpoints(input_point_ori, ballradius=5)
show3d_balls.showpoints(pred_trans_val, c_gt=color_per_point, ballradius=5)
show3d_balls.showpoints(input_point, c_gt=color_per_point, ballradius=5)
show3d_balls.showpoints(all_model_point, c_gt=all_model_color, ballradius=5)




if __name__ == "__main__":
eval_one_epoch()
eval_one_epoch(test_loader)

0 comments on commit ed6f3b6

Please sign in to comment.