Skip to content

Commit

Permalink
Adjusted color mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
nzantout committed Jul 10, 2024
1 parent bfb690e commit b7be8d0
Show file tree
Hide file tree
Showing 16 changed files with 188 additions and 161 deletions.
38 changes: 18 additions & 20 deletions 3d_data_preprocess/3rscan/3rscan_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@
from utils.glb_to_pointcloud import (
load_meshes,
subdivide_mesh,
sample_semantic_pointcloud_from_uv_mesh,
SEED,
DEVICE
sample_semantic_pointcloud_from_uv_mesh
)
from utils.bbox_utils import calculate_bbox, calculate_bbox_hull, calculate_axis_aligned_bbox
from utils.pointcloud_utils import sort_pointcloud, write_ply_file
from utils.pointcloud_utils import save_pointcloud, sort_pointcloud, write_ply_file
from utils.freespace_generation_new import generate_free_space
from utils.dominant_colors_new_lab import judge_color, generate_color_anchors
from utils.headers import REGION_HEADER, OBJECT_HEADER
Expand All @@ -37,8 +35,8 @@ def __init__(
sampling_density=None,
num_region_samples=70000,
filter_objs_less_than=10,
seed=SEED,
device=DEVICE
device='cuda',
seed=42,
):

# Color Tree
Expand All @@ -52,13 +50,13 @@ def __init__(

self.scan_directory = scan_directory
self.category_mappings_path = category_mappings_path
self.output_directory = output_directory
self.output_directory = os.path.join(output_directory, '3RScan')
self.num_pointcloud_samples = num_pointcloud_samples
self.num_region_samples = num_region_samples
self.filter_objs_less_than = filter_objs_less_than
self.sampling_density = sampling_density
self.seed = seed
self.device = device
self.device = 'cuda' if (device=='cuda' and torch.cuda.is_available()) else 'cpu'

with open('skipped_scans.json', 'r') as f:
skipped_scans = json.load(f)
Expand Down Expand Up @@ -214,7 +212,8 @@ def create_3rscan(self):
self.mesh_objects, self.semantic_mesh_objects = load_meshes(
str(color_mesh_path),
str(semantic_mesh_path),
['objectId', 'globalId', 'red', 'green', 'blue'])
['objectId', 'globalId', 'red', 'green', 'blue'],
device=self.device)

# self.mesh_objects = subdivide_mesh(self.mesh_objects, 0.1)

Expand All @@ -224,21 +223,20 @@ def create_3rscan(self):
n=self.num_pointcloud_samples,
sampling_density=self.sampling_density,
seed=self.seed,
device=self.device
)

filtered_points = self.create_object_csv(points, scan_name)
self.create_region_csv(filtered_points, scan_name)

vertex = torch.cat([
filtered_points[:, :7],
torch.zeros((filtered_points.shape[0], 1), device=DEVICE)
torch.zeros((filtered_points.shape[0], 1), device=self.device)
], dim=1)

vertex, region_indices_out, object_indices_out = sort_pointcloud(vertex)
save_pointcloud(vertex, region_indices_out, object_indices_out, output_path, scan_name)

torch.save(region_indices_out, output_path / f'{scan_name}_region_split.npy')
torch.save(object_indices_out, output_path / f'{scan_name}_object_split.npy')
write_ply_file(vertex[:, :6], output_path / f'{scan_name}_pc_result.ply')



Expand All @@ -250,24 +248,24 @@ def create_3rscan(self):
default='/media/navigation/easystore/Original_dataset/3RScan/scans')
parser.add_argument('--category_mappings_path', default='3rscan_full_mapping.csv')
parser.add_argument('--output_directory',
default='/home/navigation/Dataset/VLA_Dataset_3rscan')
default='/home/navigation/Dataset/VLA_Dataset')
parser.add_argument('--num_points', type=int, default=5000000)
parser.add_argument('--sampling_density', type=float, default=1e-4)
parser.add_argument('--sampling_density', type=float, default=2e-4)
parser.add_argument('--num_region_points', type=int, default=500000)
parser.add_argument('--filter_objs_less_than', type=int, default=10)
# parser.add_argument('--device', default='cuda:0')
parser.add_argument('--device', default='cuda')
parser.add_argument('--seed', type=int, default=42)

args = parser.parse_args()

print(f'Device: {DEVICE}')

ThreeRScanPreprocessor(
args.scan_directory,
args.category_mappings_path,
args.output_directory,
args.num_points,
args.sampling_density,
args.num_region_points,
args.filter_objs_less_than
# args.device
args.filter_objs_less_than,
args.device,
args.seed
).create_3rscan()
75 changes: 38 additions & 37 deletions 3d_data_preprocess/arkit/arkit_data_generation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from time import perf_counter
from scipy.spatial import KDTree
import os
import os.path as osp
Expand All @@ -17,7 +18,7 @@
from utils.freespace_generation_new import generate_free_space
from utils.dominant_colors_new_lab import judge_color, generate_color_anchors
from utils.bbox_utils import calculate_bbox_hull
from utils.pointcloud_utils import get_regions, get_objects, sort_pointcloud, write_ply_file
from utils.pointcloud_utils import get_regions, get_objects, save_pointcloud, sort_pointcloud, write_ply_file
from utils.headers import OBJECT_HEADER, REGION_HEADER

import warnings
Expand All @@ -32,7 +33,8 @@ def __init__(
floor_height: float, # the floor height for generating free space
color_standard: str, # colors standars to use for domain color calculation (css21, css3, html4, css2)
generate_freespace=False,
skipped_scenes = []
skipped_scenes = [],
device='cuda'
):
self.input_folder = input_folder
self.mapping_file = mapping_file
Expand All @@ -41,6 +43,7 @@ def __init__(
self.anchor_colors_array, self.anchor_colors_array_hsv, self.anchor_colors_name = generate_color_anchors(color_standard)
self.tree = KDTree(self.anchor_colors_array_hsv)
self.skipped_scenes = skipped_scenes
self.device = 'cuda' if (device=='cuda' and torch.cuda.is_available()) else 'cpu'

self.generate_freespace = generate_freespace

Expand All @@ -54,26 +57,26 @@ def inside_test(self, points, cube3d):
b1, b2, b3, b4, t1, t2, t3, t4 = cube3d

dir1 = (t1-b1)
size1 = np.linalg.norm(dir1)
size1 = torch.norm(dir1)
dir1 = dir1 / size1

dir2 = (b2-b1)
size2 = np.linalg.norm(dir2)
size2 = torch.norm(dir2)
dir2 = dir2 / size2

dir3 = (b4-b1)
size3 = np.linalg.norm(dir3)
size3 = torch.norm(dir3)
dir3 = dir3 / size3

cube3d_center = (b1 + t3)/2.0

dir_vec = points - cube3d_center

res1 = np.where((np.absolute(dir_vec @ dir1) * 2) <= size1)[0]
res2 = np.where((np.absolute(dir_vec @ dir2) * 2) <= size2)[0]
res3 = np.where((np.absolute(dir_vec @ dir3) * 2) <= size3)[0]
res1 = torch.abs(dir_vec @ dir1) <= size1/2
res2 = torch.abs(dir_vec @ dir2) <= size2/2
res3 = torch.abs(dir_vec @ dir3) <= size3/2

return list(set(res1) & set(res2) & set(res3))
return res1 & res2 & res3

def get_bbox(self, center, size, R): # function for calculating the corner points of a bbox based on its center, size and a 3x3 rotation matrix

Expand All @@ -94,7 +97,7 @@ def get_bbox(self, center, size, R): # function for calculating the corner point
corners_3d[0,:] += center[0]
corners_3d[1,:] += center[1]
corners_3d[2,:] += center[2]
return np.transpose(corners_3d)
return torch.from_numpy(corners_3d).T.to(self.device)


def crop_pc(self, xyz, bbox_center, bbox_length, bbox_rotation): # based on bbox information, crop the points inside it
Expand All @@ -108,7 +111,7 @@ def crop_pc(self, xyz, bbox_center, bbox_length, bbox_rotation): # based on bbox
bbox = o3d.geometry.LineSet()
bbox.lines = o3d.utility.Vector2iVector(bbox_lines)
bbox.colors = o3d.utility.Vector3dVector(colors)
bbox.points = o3d.utility.Vector3dVector(cube3d)
bbox.points = o3d.utility.Vector3dVector(cube3d.cpu().numpy())
return index, bbox


Expand Down Expand Up @@ -140,15 +143,15 @@ def create_arkit(self):
x = np.asarray(pc.elements[0].data['x'])
y = np.asarray(pc.elements[0].data['y'])
z = np.asarray(pc.elements[0].data['z'])
xyz = np.vstack((x,y,z)).transpose()
xyz = torch.from_numpy(np.vstack((x,y,z)).transpose()).to(self.device)
region_id = 0
# region_ids = np.repeat(np.array([[region_id]]), len(xyz), axis = 0)
unlabeled_obj_filter = np.ones(xyz.shape[0], dtype=bool)
rgba = np.vstack((r,g,b,a)).transpose()
unlabeled_obj_filter = torch.ones(xyz.shape[0], dtype=bool)
rgba = torch.from_numpy(np.vstack((r,g,b,a)).transpose()).to(self.device)
obj_pcs, obj_rgbs, obj_ids = [], [], []

region_center = (np.max(xyz, axis=0) + np.min(xyz, axis=0))/2
region_size = np.max(xyz, axis=0) - np.min(xyz, axis=0)
region_center = (torch.max(xyz, dim=0)[0] + torch.min(xyz, dim=0)[0])/2
region_size = torch.max(xyz, dim=0)[0] - torch.min(xyz, dim=0)[0]

object_file_name = scan_name + '_object_result.csv'
with open(object_file_name, 'w', newline='') as f:
Expand All @@ -174,7 +177,7 @@ def create_arkit(self):
# draw.append(bbox_to_draw)
object_pc = xyz[index, :]
# obj_vertex = np.vstack((obj_vertex, object_pc))
center, size, heading = calculate_bbox_hull(object_pc.astype(np.float64))
center, size, heading = calculate_bbox_hull(object_pc.to(torch.float64).cpu().numpy())

front_heading = ['_']

Expand All @@ -185,10 +188,10 @@ def create_arkit(self):

object_colors = rgba[index, 0:3]/255
# obj_colors = np.vstack((obj_colors, object_colors))
color_3 = judge_color(object_colors, self.tree, self.anchor_colors_array, self.anchor_colors_name)
color_3 = judge_color(object_colors.cpu().numpy(), self.tree, self.anchor_colors_array, self.anchor_colors_name)

unlabeled_obj_filter[index] = False
obj_ids.append(np.ones(object_pc.shape[0]) * object_id)
obj_ids.append(torch.ones(object_pc.shape[0], device=self.device) * object_id)
obj_pcs.append(object_pc)
obj_rgbs.append(rgba[index, 0:3])

Expand Down Expand Up @@ -217,21 +220,21 @@ def create_arkit(self):
region_info = []
region_info.append(region_id)
region_info.append(region_label)
region_info += list(region_center)
region_info += list(region_size)
region_info += region_center.cpu().tolist()
region_info += region_size.cpu().tolist()
region_info.append(region_heading)
region_writer.writerow(region_info)

unlabeled_obj_pc = xyz[unlabeled_obj_filter]
unlabeled_obj_color = rgba[unlabeled_obj_filter, :3]
obj_ids.append(np.ones(unlabeled_obj_pc.shape[0]) * -1)
obj_ids.append(torch.ones(unlabeled_obj_pc.shape[0], device=self.device) * -1)
obj_pcs.append(unlabeled_obj_pc)
obj_rgbs.append(unlabeled_obj_color)

obj_ids = np.concatenate(obj_ids)
obj_pcs = np.vstack(obj_pcs)
obj_rgbs = np.vstack(obj_rgbs)
region_ids = np.zeros_like(obj_ids)
obj_ids = torch.concatenate(obj_ids)
obj_pcs = torch.vstack(obj_pcs)
obj_rgbs = torch.vstack(obj_rgbs)
region_ids = torch.zeros_like(obj_ids)


# region_folder = osp.join(output_scan_folder, 'regions')
Expand All @@ -255,18 +258,14 @@ def create_arkit(self):
# o3d.t.io.write_point_cloud(ply_file_name, pcd)

vertex = torch.cat([
torch.from_numpy(obj_pcs),
torch.from_numpy(obj_rgbs),
torch.from_numpy(obj_ids)[:, None],
torch.from_numpy(region_ids)[:, None]
obj_pcs,
obj_rgbs,
obj_ids[:, None],
region_ids[:, None]
], dim=1)

vertex, region_indices_out, object_indices_out = sort_pointcloud(vertex)

torch.save(region_indices_out, f'{scan_name}_region_split.npy')
torch.save(object_indices_out, f'{scan_name}_object_split.npy')
write_ply_file(vertex[:, :6], f'{scan_name}_pc_result.ply')

save_pointcloud(vertex, region_indices_out, object_indices_out, '', scan_name)

if self.generate_freespace:
floor_center, floor_size, floor_heading = calculate_bbox_hull(xyz)
Expand All @@ -284,13 +283,14 @@ def create_arkit(self):
help="Input file of the mesh")
parser.add_argument('--mapping_folder', default='./arkit_cat_mapping.csv',
help="Input folder of the category mapping")
parser.add_argument('--output_folder', default='/home/navigation/Dataset/VLA_Dataset_more',
parser.add_argument('--output_folder', default='/home/navigation/Dataset/VLA_Dataset/ARKitScenes',
help="Output PLY file to save")
parser.add_argument('--floor_height', default=0.35,
help="floor heigh for generating free space")
parser.add_argument('--color_standard', default='css3',
help="color standard, chosen from css2, css21, css3, html4")
parser.add_argument('--generate_freespace', action='store_true', help='Generate free spaces')
parser.add_argument('--device', default='cuda')

args = parser.parse_args()

Expand All @@ -310,7 +310,8 @@ def create_arkit(self):
args.output_folder,
args.floor_height,
args.color_standard,
args.generate_freespace
args.generate_freespace,
device=args.device
).create_arkit()
print('====================End processing arkit training set====================')

Expand Down
Loading

0 comments on commit b7be8d0

Please sign in to comment.