Skip to content

Commit

Permalink
The building class is divided into Facade and Roof during GANCraft tr…
Browse files Browse the repository at this point in the history
…aining.
  • Loading branch information
hzxie committed Jun 15, 2023
1 parent 0d40c7e commit 5e428f2
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 98 deletions.
14 changes: 7 additions & 7 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# @Author: Haozhe Xie
# @Date: 2023-04-05 20:14:54
# @Last Modified by: Haozhe Xie
# @Last Modified at: 2023-06-09 16:26:19
# @Last Modified at: 2023-06-15 15:12:49
# @Email: [email protected]

from easydict import EasyDict
Expand Down Expand Up @@ -97,16 +97,17 @@
cfg.NETWORK.SAMPLER.TOTAL_STEPS = 256
# GANCraft
cfg.NETWORK.GANCRAFT = EasyDict()
cfg.NETWORK.GANCRAFT.BUILDING_MODE = False
cfg.NETWORK.GANCRAFT.STYLE_DIM = 128
cfg.NETWORK.GANCRAFT.N_SAMPLE_POINTS_PER_RAY = 24
cfg.NETWORK.GANCRAFT.DIST_SCALE = 0.25
cfg.NETWORK.GANCRAFT.ENCODER = "GLOBAL"
cfg.NETWORK.GANCRAFT.ENCODER_OUT_DIM = 2
cfg.NETWORK.GANCRAFT.ENCODER = "LOCAL"
cfg.NETWORK.GANCRAFT.ENCODER_OUT_DIM = 64 if cfg.NETWORK.GANCRAFT.BUILDING_MODE else 32
cfg.NETWORK.GANCRAFT.GLOBAL_ENCODER_N_BLOCKS = 6
cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM = "GROUP_NORM"
cfg.NETWORK.GANCRAFT.POS_EMD = "HASH_GRID"
cfg.NETWORK.GANCRAFT.POS_EMD = "SIN_COS"
cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES = True
cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS = True
cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS = False
cfg.NETWORK.GANCRAFT.HASH_GRID_N_LEVELS = 16
cfg.NETWORK.GANCRAFT.HASH_GRID_LEVEL_DIM = 8
cfg.NETWORK.GANCRAFT.SIN_COS_FREQ_BENDS = 10
Expand All @@ -115,7 +116,6 @@
cfg.NETWORK.GANCRAFT.RENDER_OUT_DIM_SIGMA = 1
cfg.NETWORK.GANCRAFT.RENDER_OUT_DIM_COLOR = 64
cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE = 128
cfg.NETWORK.GANCRAFT.BUILDING_MODE = False

#
# Train
Expand Down Expand Up @@ -154,7 +154,7 @@
cfg.TRAIN.GANCRAFT.EPS = 1e-7
cfg.TRAIN.GANCRAFT.WEIGHT_DECAY = 0
cfg.TRAIN.GANCRAFT.BETAS = (0., 0.999)
cfg.TRAIN.GANCRAFT.CROP_SIZE = (224, 224)
cfg.TRAIN.GANCRAFT.CROP_SIZE = (192, 192)
cfg.TRAIN.GANCRAFT.PERCEPTUAL_LOSS_MODEL = "vgg19"
cfg.TRAIN.GANCRAFT.PERCEPTUAL_LOSS_LAYERS = ["relu_3_1", "relu_4_1", "relu_5_1"]
cfg.TRAIN.GANCRAFT.PERCEPTUAL_LOSS_WEIGHTS = [0.125, 0.25, 1.0]
Expand Down
42 changes: 24 additions & 18 deletions models/gancraft.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
# @Author: Haozhe Xie
# @Date: 2023-04-12 19:53:21
# @Last Modified by: Haozhe Xie
# @Last Modified at: 2023-06-14 18:59:49
# @Last Modified at: 2023-06-15 15:11:46
# @Email: [email protected]
# @Ref: https://github.com/FrozenBurning/SceneDreamer

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.models

import extensions.grid_encoder
import extensions.voxlib
Expand Down Expand Up @@ -127,25 +126,30 @@ def _forward_perpix(
building_stats,
)
# Generate per-sample segmentation label
mc_masks = torch.gather(voxel_id, -2, new_idx)
# print(mc_masks.size()) # torch.Size([N, H, W, n_samples + 1, 1])
mc_masks = mc_masks.long()
mc_masks_onehot = torch.zeros(
seg_map_bev = torch.gather(voxel_id, -2, new_idx)
# print(seg_map_bev.size()) # torch.Size([N, H, W, n_samples + 1, 1])
# In Building Mode, the one more channel is used for building roofs
n_seg_map_classes = (
self.cfg.DATASETS.OSM_LAYOUT.N_CLASSES + 1
if self.cfg.NETWORK.GANCRAFT.BUILDING_MODE
else self.cfg.DATASETS.OSM_LAYOUT.N_CLASSES
)
seg_map_bev_onehot = torch.zeros(
[
mc_masks.size(0),
mc_masks.size(1),
mc_masks.size(2),
mc_masks.size(3),
self.cfg.DATASETS.OSM_LAYOUT.N_CLASSES,
seg_map_bev.size(0),
seg_map_bev.size(1),
seg_map_bev.size(2),
seg_map_bev.size(3),
n_seg_map_classes,
],
dtype=torch.float,
device=voxel_id.device,
)
# print(mc_masks_onehot.size()) # torch.Size([N, H, W, n_samples + 1, 1])
mc_masks_onehot.scatter_(-1, mc_masks, 1.0)
# print(seg_map_bev_onehot.size()) # torch.Size([N, H, W, n_samples + 1, 1])
seg_map_bev_onehot.scatter_(-1, seg_map_bev.long(), 1.0)

net_out_s, net_out_c = self._forward_perpix_sub(
features, normalized_cord, z, mc_masks_onehot
features, normalized_cord, z, seg_map_bev_onehot
)
# Blending
weights = self._volum_rendering_relu(
Expand Down Expand Up @@ -340,14 +344,14 @@ def _cumsum_exclusive(self, tensor, dim):
)
return cumsum

def _forward_perpix_sub(self, features, normalized_cord, z, mc_masks_onehot):
def _forward_perpix_sub(self, features, normalized_cord, z, seg_map_bev_onehot):
r"""Forwarding the MLP.
Args:
features (N x C1 x ...? tensor): Local features determined by the current pixel.
normalized_coord (N x H x W x L x 3 tensor): 3D world coordinates of sampled points. L is number of samples; N is batch size, always 1.
z (N x C3 tensor): Intermediate style vectors.
mc_masks_onehot (N x H x W x L x C4): One-hot segmentation maps.
seg_map_bev_onehot (N x H x W x L x C4): One-hot segmentation maps.
Returns:
net_out_s (N x H x W x L x 1 tensor): Opacities.
net_out_c (N x H x W x L x C5 tensor): Color embeddings.
Expand Down Expand Up @@ -421,7 +425,7 @@ def _forward_perpix_sub(self, features, normalized_cord, z, mc_masks_onehot):
elif self.cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES:
feature_in = feature_in

net_out_s, net_out_c = self.render_net(feature_in, z, mc_masks_onehot)
net_out_s, net_out_c = self.render_net(feature_in, z, seg_map_bev_onehot)
return net_out_s, net_out_c

def _forward_global(self, net_out, z):
Expand Down Expand Up @@ -593,7 +597,9 @@ def __init__(self, cfg):
in_dim = f_dim

self.fc_m_a = torch.nn.Linear(
cfg.DATASETS.OSM_LAYOUT.N_CLASSES,
cfg.DATASETS.OSM_LAYOUT.N_CLASSES + 1
if cfg.NETWORK.GANCRAFT.BUILDING_MODE
else cfg.DATASETS.OSM_LAYOUT.N_CLASSES,
cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM,
bias=False,
)
Expand Down
102 changes: 67 additions & 35 deletions scripts/dataset_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# @Author: Haozhe Xie
# @Date: 2023-03-31 15:04:25
# @Last Modified by: Haozhe Xie
# @Last Modified at: 2023-06-15 09:10:16
# @Last Modified at: 2023-06-15 14:04:17
# @Email: [email protected]

import argparse
Expand Down Expand Up @@ -33,6 +33,29 @@
import extensions.voxlib as voxlib
from extensions.extrude_tensor import TensorExtruder

# Global constants
HEIGHTS = {
"ROAD": 4,
"GREEN_LANDS": 8,
"CONSTRUCTION": 10,
"COAST_ZONES": 0,
}
CLASSES = {
"NULL": 0,
"ROAD": 1,
"BLD_FACADE": 2,
"GREEN_LANDS": 3,
"CONSTRUCTION": 4,
"COAST_ZONES": 5,
"OTHERS": 6,
"BLD_ROOF": 7,
}
# NOTE: ID > 10 are reserved for building instances.
# Assume the ID of a facade instance is 2k, the corresponding roof instance is 2k - 1.
CONSTANTS = {
"BLD_INS_LABEL_MIN": 10,
}


def _tag_equals(tags, key, values=None):
if key not in tags:
Expand All @@ -44,11 +67,14 @@ def _tag_equals(tags, key, values=None):

def _get_highway_color(map_name, highway_tags):
if map_name == "height_field":
return 0
return HEIGHTS["ROAD"]
elif map_name == "seg_map":
# Ignore underground highways
return 0 if "layer" in highway_tags and highway_tags["layer"] < 0 else 1
# return 1
return (
CLASSES["NULL"]
if "layer" in highway_tags and highway_tags["layer"] < 0
else CLASSES["ROAD"]
)
else:
raise Exception("Unknown map name: %s" % map_name)

Expand All @@ -65,20 +91,20 @@ def _get_footprint_color(map_name, footprint_tags):
return None
elif _tag_equals(footprint_tags, "landuse", ["construction"]):
# "building" in footprint_tags and footprint_tags["building"] == "construction"
return 10
return HEIGHTS["CONSTRUCTION"]
else:
raise Exception("Unknown height for tag: %s" % footprint_tags)
elif map_name == "seg_map":
if _tag_equals(footprint_tags, "role", ["inner"]):
return 2
return CLASSES["BLD_FACADE"]
elif _tag_equals(footprint_tags, "building") or _tag_equals(
footprint_tags, "building:part"
):
return 2
return CLASSES["BLD_FACADE"]
elif _tag_equals(footprint_tags, "landuse", ["construction"]):
return 4
return CLASSES["CONSTRUCTION"]
else:
return 0
return CLASSES["NULL"]
elif map_name == "footprint_contour":
if _tag_equals(footprint_tags, "building"):
return 1
Expand Down Expand Up @@ -154,9 +180,9 @@ def get_osm_images(osm_file_path, osm_tile_img_path, zoom_level):
)
else:
green_lands = get_green_lands(osm_tile_img_path, seg_map)
seg_map[green_lands != 0] = 3
seg_map[green_lands != 0] = CLASSES["GREEN_LANDS"]
coast_zones = get_coast_zones(osm_tile_img_path, seg_map.shape)
seg_map[coast_zones != 0] = 5
seg_map[coast_zones != 0] = CLASSES["COAST_ZONES"]
# Plot footprint at the end to make building masks more complete
seg_map = utils.osm_helper.plot_footprints(
"seg_map",
Expand All @@ -167,7 +193,7 @@ def get_osm_images(osm_file_path, osm_tile_img_path, zoom_level):
xy_bounds,
)
# Assign ID=6 to unlabelled pixels (regarded as ground)
seg_map[seg_map == 0] = 6
seg_map[seg_map == 0] = CLASSES["OTHERS"]

# Generate the contours of footprints
logging.debug("Generating footprint contours ...")
Expand Down Expand Up @@ -196,11 +222,11 @@ def get_osm_images(osm_file_path, osm_tile_img_path, zoom_level):
# xy_bounds,
# resolution,
# )
height_field[height_field == 0] = 4
height_field[height_field == 0] = HEIGHTS["ROAD"]
if coast_zones is not None:
height_field[coast_zones != 0] = 0
height_field[coast_zones != 0] = HEIGHTS["COAST_ZONES"]
if green_lands is not None:
height_field[green_lands != 0] = 8
height_field[green_lands != 0] = HEIGHTS["GREEN_LANDS"]
# Follow the order in plotting seg maps
height_field = utils.osm_helper.plot_footprints(
"height_field",
Expand Down Expand Up @@ -337,28 +363,26 @@ def _get_img_patch(img, cx, cy, patch_size):


def _get_instance_seg_map(seg_map, contours, use_contours=False):
BULIDING_MASK_ID = 2
BLD_INS_LABEL_MIN = 10
N_PIXELS_THRES = 16
if use_contours:
_, labels, stats, _ = cv2.connectedComponentsWithStats(
(1 - contours).astype(np.uint8), connectivity=4
)
else:
_, labels, stats, _ = cv2.connectedComponentsWithStats(
(seg_map == BULIDING_MASK_ID).astype(np.uint8), connectivity=4
(seg_map == CLASSES["BLD_FACADE"]).astype(np.uint8), connectivity=4
)

# Remove non-building instance masks
labels[seg_map != BULIDING_MASK_ID] = 0
labels[seg_map != CLASSES["BLD_FACADE"]] = 0
# Building instance mask
building_mask = labels != 0

# Building Instance Mask starts from 10 (labels + 10)
seg_map[seg_map == BULIDING_MASK_ID] = 0
seg_map = (
seg_map * (1 - building_mask) + (labels + BLD_INS_LABEL_MIN) * building_mask
)
# Make building instance IDs are even numbers and start from 10
# Assume the ID of a facade instance is 2k, the corresponding roof instance is 2k - 1.
labels = (labels + CONSTANTS["BLD_INS_LABEL_MIN"]) * 2

seg_map[seg_map == CLASSES["BLD_FACADE"]] = 0
seg_map = seg_map * (1 - building_mask) + labels * building_mask
assert np.max(labels) < 2147483648
return seg_map.astype(np.int32), stats[:, :4]

Expand Down Expand Up @@ -447,13 +471,11 @@ def get_google_earth_aligned_seg_maps(
patch_size,
).astype(np.int32)
# Recalculate the center offsets of buildings
BLD_INS_LABEL_MIN = 10
buildings = (
np.unique(part_seg_map[part_seg_map > BLD_INS_LABEL_MIN]) - BLD_INS_LABEL_MIN
)
buildings = np.unique(part_seg_map[part_seg_map > CONSTANTS["BLD_INS_LABEL_MIN"]])
part_building_stats = {}
for bid in buildings:
_stats = building_stats[bid].copy().astype(np.float32)
_bid = bid // 2 - CONSTANTS["BLD_INS_LABEL_MIN"]
_stats = building_stats[_bid].copy().astype(np.float32)
# NOTE: assert building_stats.shape[1] == 4, represents x, y, w, h of the components.
# Convert x and y to dx and dy, where dx and dy denote the offsets to the center.
_stats[0] = _stats[0] - tr_cx + _stats[2] / 2
Expand All @@ -466,10 +488,21 @@ def get_google_earth_aligned_seg_maps(
torch.from_numpy(part_hf[None, None, ...]).cuda(),
).squeeze()
logging.debug("The shape of SegVolume: %s" % (seg_volume.size(),))
# Convert camera position to the voxel coordinate system
vol_cx, vol_cy = ((patch_size - 1) // 2, (patch_size - 1) // 2)
# Change the top-level voxel of the "Building Facade" to "Building Roof"
roof_seg_map = part_seg_map.copy()
non_roof_msk = part_seg_map <= CONSTANTS["BLD_INS_LABEL_MIN"]
# Assume the ID of a facade instance is 2k, the corresponding roof instance is 2k - 1.
roof_seg_map = roof_seg_map - 1
roof_seg_map[non_roof_msk] = part_seg_map[non_roof_msk]
seg_volume = seg_volume.scatter_(
dim=2,
index=torch.from_numpy(part_hf[..., None]).long().cuda(),
src=torch.from_numpy(roof_seg_map[..., None]).cuda(),
)

seg_maps = []
# Convert camera position to the voxel coordinate system
vol_cx, vol_cy = ((patch_size - 1) // 2, (patch_size - 1) // 2)
for gcp in tqdm(ge_camera_poses["poses"], desc="Project: %s" % ge_project_name):
x, y = utils.osm_helper.lnglat2xy(
gcp["coordinate"]["longitude"],
Expand Down Expand Up @@ -560,10 +593,9 @@ def get_google_earth_aligned_seg_maps(


def get_ambiguous_seg_mask(voxel_id, est_seg_map):
BULIDING_MASK_ID = 2
BLD_INS_LABEL_MIN = 10
seg_map = voxel_id.squeeze()[..., 0]
seg_map[seg_map >= BLD_INS_LABEL_MIN] = BULIDING_MASK_ID
# All facade and roof instances are mapped into BLD_FACADE
seg_map[seg_map >= CONSTANTS["BLD_INS_LABEL_MIN"]] = CLASSES["BLD_FACADE"]
est_seg_map = np.array(est_seg_map.convert("P"))
return seg_map == est_seg_map

Expand Down
Loading

0 comments on commit 5e428f2

Please sign in to comment.