Add inference code
abcd40404 committed Apr 7, 2021
1 parent 6dfbbfa commit 6d14e56
Showing 3 changed files with 161 additions and 1 deletion.
5 changes: 5 additions & 0 deletions README.md
@@ -80,6 +80,11 @@ Please refer to [NUSCENES-GUIDE](./NUSCENES-GUIDE.md)

- For nuScenes dataset, please refer to [NUSCENES-GUIDE](./NUSCENES-GUIDE.md)

## Semantic segmentation demo for a folder of lidar scans
```
python demo_folder.py --demo-folder YOUR_FOLDER --save-folder YOUR_SAVE_FOLDER
```

## TODO List
- [x] Release pretrained model for nuScenes.
- [x] Support multiscan semantic segmentation.
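The demo script also accepts a config file through `-y/--config_path`, which defaults to `config/semantickitti.yaml` (see the argparse setup in `demo_folder.py` below), so a fully explicit invocation looks like:

```
python demo_folder.py --demo-folder YOUR_FOLDER --save-folder YOUR_SAVE_FOLDER -y config/semantickitti.yaml
```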
27 changes: 26 additions & 1 deletion dataloader/pc_dataset.py
@@ -25,6 +25,30 @@ def get_pc_model_class(name):
assert name in REGISTERED_PC_DATASET_CLASSES, f"available class: {REGISTERED_PC_DATASET_CLASSES}"
return REGISTERED_PC_DATASET_CLASSES[name]

@register_dataset
class SemKITTI_demo(data.Dataset):
def __init__(self, data_path, imageset='demo',
return_ref=True, label_mapping="semantic-kitti.yaml", nusc=None):
with open(label_mapping, 'r') as stream:
semkittiyaml = yaml.safe_load(stream)
self.learning_map = semkittiyaml['learning_map']
self.return_ref = return_ref

self.im_idx = []
self.im_idx += absoluteFilePaths(data_path)

def __len__(self):
'Denotes the total number of samples'
return len(self.im_idx)

def __getitem__(self, index):
raw_data = np.fromfile(self.im_idx[index], dtype=np.float32).reshape((-1, 4))
annotated_data = np.expand_dims(np.zeros_like(raw_data[:, 0], dtype=int), axis=1)

data_tuple = (raw_data[:, :3], annotated_data.astype(np.uint8))
if self.return_ref:
data_tuple += (raw_data[:, 3],)
return data_tuple

@register_dataset
class SemKITTI_sk(data.Dataset):
@@ -58,7 +82,7 @@ def __getitem__(self, index):
annotated_data = np.expand_dims(np.zeros_like(raw_data[:, 0], dtype=int), axis=1)
else:
annotated_data = np.fromfile(self.im_idx[index].replace('velodyne', 'labels')[:-3] + 'label',
dtype=np.int32).reshape((-1, 1))
dtype=np.uint32).reshape((-1, 1))
annotated_data = annotated_data & 0xFFFF # delete high 16 digits binary
annotated_data = np.vectorize(self.learning_map.__getitem__)(annotated_data)

@@ -108,6 +132,7 @@ def __getitem__(self, index):

def absoluteFilePaths(directory):
for dirpath, _, filenames in os.walk(directory):
        filenames.sort()  # sort so iteration order (and the numbered output files) is deterministic
for f in filenames:
yield os.path.abspath(os.path.join(dirpath, f))

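As a quick check, the new dataset class can be exercised on its own. A minimal sketch, assuming a folder of SemanticKITTI-format `.bin` scans; the placeholder path and the label-mapping YAML location are assumptions, not part of this commit:

```
# Hypothetical usage of the new SemKITTI_demo dataset; paths are placeholders.
from dataloader.pc_dataset import get_pc_model_class

SemKITTI_demo = get_pc_model_class('SemKITTI_demo')
demo_set = SemKITTI_demo('/path/to/demo/scans', imageset='demo',
                         return_ref=True,
                         label_mapping='config/label_mapping/semantic-kitti.yaml')

xyz, labels, intensity = demo_set[0]
# xyz: (N, 3) point coordinates
# labels: (N, 1) uint8 zeros -- dummy annotations, since demo scans have no ground truth
# intensity: (N,) remission values, returned because return_ref=True
```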
130 changes: 130 additions & 0 deletions demo_folder.py
@@ -0,0 +1,130 @@
# -*- coding:utf-8 -*-
# author: Ptzu
# @file: demo_folder.py

import os
import time
import argparse
import sys
import numpy as np
import torch
import torch.optim as optim
from tqdm import tqdm
import yaml

from utils.metric_util import per_class_iu, fast_hist_crop
from dataloader.pc_dataset import get_SemKITTI_label_name
from builder import data_builder, model_builder, loss_builder
from config.config import load_config_data
from dataloader.dataset_semantickitti import get_model_class, collate_fn_BEV
from dataloader.pc_dataset import get_pc_model_class

from utils.load_save_util import load_checkpoint

import warnings

warnings.filterwarnings("ignore")


def build_dataset(dataset_config,
data_dir,
grid_size=[480, 360, 32]):

label_mapping = dataset_config["label_mapping"]

SemKITTI_demo = get_pc_model_class('SemKITTI_demo')

demo_pt_dataset = SemKITTI_demo(data_dir, imageset="demo",
return_ref=True, label_mapping=label_mapping, nusc=None)

demo_dataset = get_model_class(dataset_config['dataset_type'])(
demo_pt_dataset,
grid_size=grid_size,
fixed_volume_space=dataset_config['fixed_volume_space'],
max_volume_space=dataset_config['max_volume_space'],
min_volume_space=dataset_config['min_volume_space'],
ignore_label=dataset_config["ignore_label"],
)
demo_dataset_loader = torch.utils.data.DataLoader(dataset=demo_dataset,
batch_size=1,
collate_fn=collate_fn_BEV,
shuffle=False,
num_workers=4)

return demo_dataset_loader

def main(args):
pytorch_device = torch.device('cuda:0')
config_path = args.config_path
configs = load_config_data(config_path)
dataset_config = configs['dataset_params']
data_dir = args.demo_folder
save_dir = args.save_folder + "/"

demo_batch_size = 1
model_config = configs['model_params']
train_hypers = configs['train_params']

grid_size = model_config['output_shape']
num_class = model_config['num_class']
ignore_label = dataset_config['ignore_label']
model_load_path = train_hypers['model_load_path']

SemKITTI_label_name = get_SemKITTI_label_name(dataset_config["label_mapping"])
unique_label = np.asarray(sorted(list(SemKITTI_label_name.keys())))[1:] - 1
unique_label_str = [SemKITTI_label_name[x] for x in unique_label + 1]

my_model = model_builder.build(model_config)
if os.path.exists(model_load_path):
my_model = load_checkpoint(model_load_path, my_model)

my_model.to(pytorch_device)
    optimizer = optim.Adam(my_model.parameters(), lr=train_hypers["learning_rate"])  # constructed but never stepped; this script only runs inference

loss_func, lovasz_softmax = loss_builder.build(wce=True, lovasz=True,
num_class=num_class, ignore_label=ignore_label)

demo_dataset_loader = build_dataset(dataset_config, data_dir, grid_size=grid_size)
with open(dataset_config["label_mapping"], 'r') as stream:
semkittiyaml = yaml.safe_load(stream)
inv_learning_map = semkittiyaml['learning_map_inv']

my_model.eval()
hist_list = []
demo_loss_list = []
with torch.no_grad():
for i_iter_demo, (_, demo_vox_label, demo_grid, demo_pt_labs, demo_pt_fea) in enumerate(
demo_dataset_loader):
demo_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in
demo_pt_fea]
demo_grid_ten = [torch.from_numpy(i).to(pytorch_device) for i in demo_grid]
demo_label_tensor = demo_vox_label.type(torch.LongTensor).to(pytorch_device)

predict_labels = my_model(demo_pt_fea_ten, demo_grid_ten, demo_batch_size)
            loss = lovasz_softmax(torch.nn.functional.softmax(predict_labels, dim=1).detach(), demo_label_tensor,
                                  ignore=0) + loss_func(predict_labels.detach(), demo_label_tensor)
predict_labels = torch.argmax(predict_labels, dim=1)
predict_labels = predict_labels.cpu().detach().numpy()
for count, i_demo_grid in enumerate(demo_grid):
hist_list.append(fast_hist_crop(predict_labels[
count, demo_grid[count][:, 0], demo_grid[count][:, 1],
demo_grid[count][:, 2]], demo_pt_labs[count],
unique_label))
inv_labels = np.vectorize(inv_learning_map.__getitem__)(predict_labels[count, demo_grid[count][:, 0], demo_grid[count][:, 1], demo_grid[count][:, 2]])
inv_labels = inv_labels.astype('uint32')
outputPath = save_dir + str(i_iter_demo).zfill(6) + '.label'
inv_labels.tofile(outputPath)
print("save " + outputPath)
demo_loss_list.append(loss.detach().cpu().numpy())

if __name__ == '__main__':
    # Inference settings
parser = argparse.ArgumentParser(description='')
parser.add_argument('-y', '--config_path', default='config/semantickitti.yaml')
parser.add_argument('--demo-folder', type=str, default='', help='path to the folder containing demo lidar scans', required=True)
parser.add_argument('--save-folder', type=str, default='', help='path to save your result', required=True)
args = parser.parse_args()

print(' '.join(sys.argv))
print(args)
main(args)
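Each scan's predictions are saved as a SemanticKITTI-style `.label` file of raw uint32 class ids, named by iteration index (`000000.label`, `000001.label`, ...). A small sketch for loading one back; the save path is a placeholder, and the `& 0xFFFF` mask follows the SemanticKITTI packing convention (a no-op here, since the demo writes plain semantic ids):

```
# Sketch: read back a predicted .label file written by demo_folder.py.
import numpy as np

labels = np.fromfile('YOUR_SAVE_FOLDER/000000.label', dtype=np.uint32)
semantic = labels & 0xFFFF  # low 16 bits hold the semantic class id
print(semantic.shape, np.unique(semantic))
```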
