Skip to content

Commit

Permalink
add bakedsdf mesh extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
niujinshuchong committed Mar 7, 2023
1 parent 810e8f4 commit e9b0ca7
Show file tree
Hide file tree
Showing 4 changed files with 225 additions and 1 deletion.
3 changes: 3 additions & 0 deletions nerfstudio/models/base_surface_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,9 @@ def get_outputs(self, ray_bundle: RayBundle) -> Dict:
"depth": depth,
"normal": normal,
"weights": weights,
"ray_points": self.scene_contraction(
ray_samples.frustums.get_start_positions()
), # used for creating visiblity mask
"directions_norm": ray_bundle.directions_norm, # used to scale z_vals for free space and sdf loss
}
outputs.update(bg_outputs)
Expand Down
49 changes: 49 additions & 0 deletions nerfstudio/pipelines/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,55 @@ def get_average_eval_image_metrics(self, step: Optional[int] = None):
self.train()
return metrics_dict

@profiler.time_function
def get_visibility_mask(self):
"""Iterate over all the images in the eval dataset and get the average.
Returns:
metrics_dict: dictionary of metrics
"""
self.eval()

coarse_mask = torch.ones((1, 1, 512, 512, 512), requires_grad=True).to(self.device)
coarse_mask.retain_grad()

num_images = len(self.datamanager.fixed_indices_eval_dataloader)
with Progress(
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TimeElapsedColumn(),
MofNCompleteColumn(),
transient=True,
) as progress:
task = progress.add_task("[green]Evaluating all eval images...", total=num_images)
for camera_ray_bundle, batch in self.datamanager.fixed_indices_eval_dataloader:
isbasicimages = False
if isinstance(
batch["image"], BasicImages
): # If this is a generalized dataset, we need to get image tensor
isbasicimages = True
batch["image"] = batch["image"].images[0]
camera_ray_bundle = camera_ray_bundle.reshape((*batch["image"].shape[:-1],))
# downsample by factor of 4 to speed up
camera_ray_bundle = camera_ray_bundle[::4, ::4]
height, width = camera_ray_bundle.shape
outputs = self.model.get_outputs_for_camera_ray_bundle(camera_ray_bundle)
ray_points = outputs["ray_points"].reshape(height, width, -1, 3)
weights = outputs["weights"]

valid_points = ray_points.reshape(-1, 3)[weights.reshape(-1) > 0.005]
valid_points = valid_points * 0.5 # normalize from [-2, 2] to [-1, 1]
# update mask based on ray samples
with torch.enable_grad():
out = torch.nn.functional.grid_sample(coarse_mask, valid_points[None, None, None])
out.sum().backward()
progress.advance(task)

coarse_mask = (coarse_mask.grad > 0.0001).float()

self.train()
return coarse_mask

def load_pipeline(self, loaded_state: Dict[str, Any]) -> None:
"""Load the checkpoint from the given path
Expand Down
121 changes: 121 additions & 0 deletions nerfstudio/utils/marching_cubes.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,124 @@ def evaluate(points):
meshexport.export(str(output_path))
else:
print("=================================================no surface skip")


def get_surface_sliding_with_contraction(
sdf,
resolution=512,
bounding_box_min=(-1.0, -1.0, -1.0),
bounding_box_max=(1.0, 1.0, 1.0),
return_mesh=False,
level=0,
coarse_mask=None,
output_path: Path = Path("test.ply"),
simplify_mesh=True,
inv_contraction=None,
max_range=32.0,
):
assert resolution % 512 == 0

resN = resolution
cropN = 512
level = 0
N = resN // cropN

grid_min = bounding_box_min
grid_max = bounding_box_max
xs = np.linspace(grid_min[0], grid_max[0], N + 1)
ys = np.linspace(grid_min[1], grid_max[1], N + 1)
zs = np.linspace(grid_min[2], grid_max[2], N + 1)

# print(xs)
# print(ys)
# print(zs)
meshes = []
for i in range(N):
for j in range(N):
for k in range(N):
print(i, j, k)
x_min, x_max = xs[i], xs[i + 1]
y_min, y_max = ys[j], ys[j + 1]
z_min, z_max = zs[k], zs[k + 1]

x = np.linspace(x_min, x_max, cropN)
y = np.linspace(y_min, y_max, cropN)
z = np.linspace(z_min, z_max, cropN)

xx, yy, zz = np.meshgrid(x, y, z, indexing="ij")
points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float).cuda()

@torch.no_grad()
def evaluate(points):
z = []
for _, pnts in enumerate(torch.split(points, 100000, dim=0)):
z.append(sdf(pnts))
z = torch.cat(z, axis=0)
return z

# construct point pyramids
points = points.reshape(cropN, cropN, cropN, 3)

# query coarse grids
points_tmp = points[None].cuda() * 0.5 # normalize from [-2, 2] to [-1, 1]
current_mask = torch.nn.functional.grid_sample(coarse_mask, points_tmp, mode="nearest")

points = points.reshape(-1, 3)
valid_mask = current_mask.reshape(-1) > 0
pts_to_eval = points[valid_mask]
print(current_mask.float().mean())

# breakpoint()
pts_sdf = torch.ones_like(points[..., 0])
print(pts_sdf.shape, pts_to_eval.shape, points.shape)
if pts_to_eval.shape[0] > 0:
pts_sdf_eval = evaluate(pts_to_eval.contiguous())
pts_sdf[valid_mask.reshape(-1)] = pts_sdf_eval

z = pts_sdf.detach().cpu().numpy()

current_mask = (current_mask > 0.0).cpu().numpy()[0, 0]
# skip if no surface found
if current_mask is not None:
valid_z = z.reshape(cropN, cropN, cropN)[current_mask]
if valid_z.shape[0] <= 0 or (np.min(valid_z) > level or np.max(valid_z) < level):
continue

if not (np.min(z) > level or np.max(z) < level):
z = z.astype(np.float32)
verts, faces, normals, _ = measure.marching_cubes(
volume=z.reshape(cropN, cropN, cropN), # .transpose([1, 0, 2]),
level=level,
spacing=(
(x_max - x_min) / (cropN - 1),
(y_max - y_min) / (cropN - 1),
(z_max - z_min) / (cropN - 1),
),
mask=current_mask,
)
verts = verts + np.array([x_min, y_min, z_min])

meshcrop = trimesh.Trimesh(verts, faces, normals)
meshes.append(meshcrop)

combined = trimesh.util.concatenate(meshes)

# inverse contraction and clipping the points range
if inv_contraction is not None:
combined.vertices = inv_contraction(torch.from_numpy(combined.vertices)).numpy()
combined.vertices = np.clip(combined.vertices, -max_range, max_range)

if return_mesh:
return combined
else:
filename = str(output_path)
filename_simplify = str(output_path).replace(".ply", "-simplify.ply")

combined.export(filename)
if simplify_mesh:
ms = pymeshlab.MeshSet()
ms.load_new_mesh(filename)

print("simply mesh")
ms.meshing_decimation_quadric_edge_collapse(targetfacenum=2000000)
ms.save_current_mesh(filename_simplify, save_face_color=False)
53 changes: 52 additions & 1 deletion scripts/extract_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,20 @@
import tyro
from rich.console import Console

from nerfstudio.model_components.ray_samplers import save_points
from nerfstudio.utils.eval_utils import eval_setup
from nerfstudio.utils.marching_cubes import get_surface_occupancy, get_surface_sliding
from nerfstudio.utils.marching_cubes import (
get_surface_occupancy,
get_surface_sliding,
get_surface_sliding_with_contraction,
)

CONSOLE = Console(width=120)

# speedup for when input size to model doesn't change (much)
torch.backends.cudnn.benchmark = True # type: ignore
torch.set_float32_matmul_precision("high")


@dataclass
class ExtractMesh:
Expand All @@ -36,6 +45,12 @@ class ExtractMesh:
bounding_box_min: Tuple[float, float, float] = (-1.0, -1.0, -1.0)
"""Maximum of the bounding box."""
bounding_box_max: Tuple[float, float, float] = (1.0, 1.0, 1.0)
"""marching cube threshold"""
marching_cube_threshold: float = 0.0
"""create visibility mask"""
create_visibility_mask: bool = False
"""save visibility grid"""
save_visibility_grid: bool = False

def main(self) -> None:
"""Main function."""
Expand All @@ -46,6 +61,42 @@ def main(self) -> None:

CONSOLE.print("Extract mesh with marching cubes and may take a while")

if self.create_visibility_mask:
assert self.resolution % 512 == 0

coarse_mask = pipeline.get_visibility_mask()

# TODO reading contraction type from pipeline
def inv_contract(x, order=float("inf")):
mag = torch.linalg.norm(x, ord=order, dim=-1)
mask = mag >= 1
x_new = x.clone()
x_new[mask] = (1 / (2 - mag[mask][..., None])) * (x[mask] / mag[mask][..., None])
return x_new

if self.save_visibility_grid:
offset = torch.linspace(-2.0, 2.0, 512)
x, y, z = torch.meshgrid(offset, offset, offset, indexing="ij")
offset_cube = torch.stack([x, y, z], dim=-1).reshape(-1, 3).to(coarse_mask.device)
points = offset_cube[coarse_mask.reshape(-1) > 0]
points = inv_contract(points)
save_points("mask.ply", points.cpu().numpy())
torch.save(coarse_mask, "coarse_mask.pt")

get_surface_sliding_with_contraction(
sdf=lambda x: (
pipeline.model.field.forward_geonetwork(x)[:, 0] - self.marching_cube_threshold
).contiguous(),
resolution=self.resolution,
bounding_box_min=self.bounding_box_min,
bounding_box_max=self.bounding_box_max,
coarse_mask=coarse_mask,
output_path=self.output_path,
simplify_mesh=self.simplify_mesh,
inv_contraction=inv_contract,
)
return

if self.is_occupancy:
# for unisurf
get_surface_occupancy(
Expand Down

0 comments on commit e9b0ca7

Please sign in to comment.