diff --git a/README.md b/README.md index de3e4d0..ba001c5 100644 --- a/README.md +++ b/README.md @@ -3,11 +3,14 @@
------------------------------------------ +📣📣📣 **[*SUSTech1K*](https://lidargait.github.io) is released; please check the [tutorial](datasets/SUSTech1K/README.md).** 📣📣📣 + 🎉🎉🎉 **[*OpenGait*](https://openaccess.thecvf.com/content/CVPR2023/papers/Fan_OpenGait_Revisiting_Gait_Recognition_Towards_Better_Practicality_CVPR_2023_paper.pdf) has been accepted by CVPR 2023 as a highlight paper!** 🎉🎉🎉 OpenGait is a flexible and extensible gait recognition project provided by the [Shiqi Yu Group](https://faculty.sustech.edu.cn/yusq/) and supported in part by [WATRIX.AI](http://www.watrix.ai). ## What's New +- **[July 2023]** [SUSTech1K](datasets/SUSTech1K/README.md) is released and supported by OpenGait. - **[May 2023]** A real gait recognition system, [All-in-One-Gait](https://github.com/jdyjjj/All-in-One-Gait), provided by [Dongyang Jin](https://github.com/jdyjjj) is available. - [Apr 2023] [CASIA-E](datasets/CASIA-E/README.md) is supported by OpenGait. - [Feb 2023] The [HID 2023 competition](https://hid2023.iapr-tc4.org/) is open; welcome to participate. Additionally, the tutorial for the competition has been updated in [datasets/HID/](./datasets/HID). @@ -50,7 +53,7 @@ Results and models are available in the [model zoo](docs/1.model_zoo.md). ## Authors: **Open Gait Team (OGT)** - [Chao Fan (樊超)](https://chaofan996.github.io), 12131100@mail.sustech.edu.cn -- [Chuanfu Shen (沈川福)](https://faculty.sustech.edu.cn/?p=95396&tagid=yusq&cat=2&iscss=1&snapid=1&orderby=date), 11950016@mail.sustech.edu.cn +- [Chuanfu Shen (沈川福)](https://chuanfushen.github.io), 11950016@mail.sustech.edu.cn - [Junhao Liang (梁峻豪)](https://faculty.sustech.edu.cn/?p=95401&tagid=yusq&cat=2&iscss=1&snapid=1&orderby=date), 12132342@mail.sustech.edu.cn ## Acknowledgement diff --git a/configs/lidargait/lidargait_sustech1k.yaml b/configs/lidargait/lidargait_sustech1k.yaml new file mode 100644 index 0000000..d1c73b9 --- /dev/null +++ b/configs/lidargait/lidargait_sustech1k.yaml @@ -0,0 +1,101 @@ +data_cfg: + dataset_name: SUSTech1K + dataset_root: your_path_of_SUSTech1K-Released-pkl + dataset_partition: ./datasets/SUSTech1K/SUSTech1K.json + num_workers: 4 + data_in_use: [false, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false] + remove_no_gallery: false # remove a probe if it has no matching gallery + test_dataset_name: SUSTech1K + +evaluator_cfg: + enable_float16: true + restore_ckpt_strict: true + restore_hint: 40000 + save_name: LidarGait + eval_func: evaluate_indoor_dataset #evaluate_Gait3D + sampler: + batch_shuffle: false + batch_size: 4 + sample_type: all_ordered # 'all' means the whole sequence is used for testing; 'ordered' keeps frames in their natural order. Other options: fixed_unordered + frames_all_limit: 720 # limit the number of sampled frames to prevent out-of-memory errors + metric: euc # cos + transform: + - type: BaseSilTransform + +loss_cfg: + - loss_term_weight: 1.0 + margin: 0.2 + type: TripletLoss + log_prefix: triplet + - loss_term_weight: 1.0 + scale: 16 + type: CrossEntropyLoss + log_prefix: softmax + log_accuracy: true + +model_cfg: + model: Baseline + backbone_cfg: + type: ResNet9 + in_channel: 3 + block: BasicBlock + channels: # layer widths for automatic model construction + - 64 + - 128 + - 256 + - 512 + layers: + - 1 + - 1 + - 1 + - 1 + strides: + - 1 + - 2 + - 2 + - 1 + maxpool: false + SeparateFCs: + in_channels: 512 + out_channels: 256 + parts_num: 16 + SeparateBNNecks: + class_num: 250 + in_channels: 256 + parts_num: 16 + bin_num: + - 16 +
+optimizer_cfg: + lr: 0.1 + momentum: 0.9 + solver: SGD + weight_decay: 0.0005 + +scheduler_cfg: + gamma: 0.1 + milestones: # the learning rate is reduced at each milestone + - 20000 + - 30000 + scheduler: MultiStepLR +trainer_cfg: + enable_float16: true # half-precision floats for memory reduction and speedup + fix_BN: false + with_test: true + log_iter: 100 + restore_ckpt_strict: true + restore_hint: 0 + save_iter: 5000 + save_name: LidarGait + sync_BN: true + total_iter: 40000 + sampler: + batch_shuffle: true + batch_size: + - 8 # TripletSampler: batch_size[0] is the number of identities + - 8 # batch_size[1] is the number of sequences sampled per identity + frames_num_fixed: 10 # fixed number of frames for training + sample_type: fixed_unordered # 'fixed' means a fixed number of input frames; 'unordered' means frame order is not preserved. Other options: unfixed_ordered or all_ordered + type: TripletSampler + transform: + - type: BaseSilTransform \ No newline at end of file diff --git a/datasets/SUSTech1K/README.md b/datasets/SUSTech1K/README.md new file mode 100644 index 0000000..925c5fa --- /dev/null +++ b/datasets/SUSTech1K/README.md @@ -0,0 +1,33 @@ +# Tutorial for [SUSTech1K](https://lidargait.github.io) + +## Download the SUSTech1K dataset +Download the dataset from the [link](https://lidargait.github.io). +Decompress the downloaded archive with the following command: +```shell +unzip -P password SUSTech1K-pkl.zip | xargs -n1 tar xzvf +``` +The password can be obtained by signing the [agreement](https://lidargait.github.io/static/resources/SUSTech1KAgreement.pdf) and sending it to shencf2019@mail.sustech.edu.cn. + +## Train on the dataset +Modify `dataset_root` in `configs/lidargait/lidargait_sustech1k.yaml`, and then run this command: +```shell +CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 opengait/main.py --cfgs configs/lidargait/lidargait_sustech1k.yaml --phase train +``` + + +## Process the raw dataset + +### Preprocess the dataset (Optional) +Download the raw dataset from the [official link](https://lidargait.github.io). You will get a checksum file and two compressed files: `DATASET_DOWNLOAD.md5`, `SUSTech1K-RAW.zip`, and `SUSTech1K-pkl.zip`. +We recommend using the provided pickle files for convenience; alternatively, process the raw dataset into pickle files with this command: +```shell +python datasets/SUSTech1K/pretreatment_SUSTech1K.py -i SUSTech1K-Released-2023 -o SUSTech1K-pkl -n 8 +``` + +### Project point clouds into depth images (Optional) +You can use our processed depth images, or generate them with this command: +```shell +python datasets/SUSTech1K/point2depth.py -i SUSTech1K-Released-2023/ -o SUSTech1K-Released-2023/ -n 8 +``` +We recommend using our provided depth images for convenience.
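After pretreatment, it can be useful to sanity-check one of the generated sequence pickles before training. A minimal sketch (the path below is illustrative; actual file names follow the `<idx>-<view>-<sensor>-<modality>.pkl` pattern written by `pretreatment_SUSTech1K.py`):

```python
import pickle

# Load one pretreated sequence and inspect it.
# Hypothetical example path; adjust it to your own layout.
pkl_path = 'SUSTech1K-pkl/0000/00-nm/000/03-000-LiDAR-PCDs_depths.pkl'
with open(pkl_path, 'rb') as f:
    seq = pickle.load(f)

# Depth-image sequences are expected as (num_frames, C, H, W) arrays.
print(type(seq), getattr(seq, 'shape', len(seq)))
```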
+ diff --git a/datasets/SUSTech1K/SUSTech1K.json b/datasets/SUSTech1K/SUSTech1K.json new file mode 100644 index 0000000..dd80292 --- /dev/null +++ b/datasets/SUSTech1K/SUSTech1K.json @@ -0,0 +1,1056 @@ +{ + "TRAIN_SET": [ + "0000", + "0002", + "0005", + "0018", + "0021", + "0026", + "0033", + "0039", + "0040", + "0044", + "0046", + "0047", + "0050", + "0052", + "0060", + "0062", + "0066", + "0073", + "0075", + "0078", + "0080", + "0092", + "0093", + "0096", + "0097", + "0102", + "0103", + "0116", + "0118", + "0126", + "0128", + "0144", + "0149", + "0151", + "0153", + "0154", + "0156", + "0157", + "0158", + "0164", + "0165", + "0168", + "0169", + "0174", + "0180", + "0183", + "0193", + "0194", + "0196", + "0199", + "0201", + "0203", + "0204", + "0208", + "0212", + "0220", + "0226", + "0227", + "0231", + "0232", + "0243", + "0253", + "0257", + "0267", + "0268", + "0282", + "0284", + "0287", + "0289", + "0291", + "0293", + "0295", + "0297", + "0301", + "0306", + "0311", + "0315", + "0318", + "0320", + "0321", + "0323", + "0325", + "0327", + "0341", + "0347", + "0351", + "0356", + "0357", + "0360", + "0367", + "0372", + "0379", + "0380", + "0395", + "0402", + "0412", + "0421", + "0425", + "0434", + "0438", + "0440", + "0441", + "0444", + "0452", + "0457", + "0458", + "0464", + "0468", + "0473", + "0474", + "0478", + "0488", + "0493", + "0495", + "0496", + "0497", + "0502", + "0514", + "0522", + "0521", + "0525", + "0533", + "0535", + "0537", + "0540", + "0544", + "0545", + "0546", + "0547", + "0551", + "0552", + "0553", + "0555", + "0557", + "0559", + "0577", + "0581", + "0583", + "0584", + "0585", + "0591", + "0597", + "0600", + "0605", + "0610", + "0611", + "0612", + "0616", + "0631", + "0632", + "0634", + "0636", + "0637", + "0641", + "0649", + "0653", + "0655", + "0664", + "0665", + "0671", + "0675", + "0677", + "0687", + "0695", + "0701", + "0702", + "0707", + "0717", + "0720", + "0721", + "0723", + "0726", + "0731", + "0756", + "0757", + "0759", + "0760", + "0767", + "0770", + "0773", + "0779", + "0780", + "0783", + "0791", + "0792", + "0796", + "0805", + "0810", + "0811", + "0823", + "0828", + "0830", + "0839", + "0841", + "0844", + "0845", + "0846", + "0850", + "0853", + "0860", + "0862", + "0863", + "0865", + "0868", + "0869", + "0876", + "0883", + "0884", + "0888", + "0897", + "0904", + "0906", + "0907", + "0908", + "0918", + "0922", + "0923", + "0925", + "0933", + "0938", + "0942", + "0943", + "0944", + "0948", + "0951", + "0959", + "0965", + "0966", + "0969", + "0970", + "0973", + "0978", + "0979", + "0982", + "0996", + "0997", + "1002", + "1004", + "1011", + "1013", + "1015", + "1019", + "1024", + "1026", + "1027", + "1036", + "1038", + "1046", + "1056", + "1057" + ], + "TEST_SET": [ + "0001", + "0003", + "0004", + "0006", + "0007", + "0008", + "0009", + "0010", + "0011", + "0012", + "0013", + "0014", + "0015", + "0016", + "0017", + "0019", + "0020", + "0022", + "0023", + "0024", + "0025", + "0027", + "0028", + "0029", + "0030", + "0031", + "0032", + "0034", + "0035", + "0036", + "0037", + "0038", + "0041", + "0042", + "0043", + "0045", + "0048", + "0049", + "0051", + "0053", + "0054", + "0055", + "0056", + "0057", + "0058", + "0059", + "0061", + "0063", + "0064", + "0065", + "0067", + "0068", + "0069", + "0070", + "0071", + "0072", + "0074", + "0076", + "0077", + "0079", + "0081", + "0082", + "0083", + "0084", + "0085", + "0086", + "0087", + "0088", + "0089", + "0090", + "0091", + "0094", + "0095", + "0098", + "0099", + "0100", + "0101", + "0104", + "0105", + "0106", + "0107", + 
"0108", + "0109", + "0110", + "0111", + "0112", + "0113", + "0114", + "0115", + "0117", + "0119", + "0120", + "0121", + "0122", + "0123", + "0124", + "0125", + "0127", + "0129", + "0130", + "0131", + "0132", + "0133", + "0134", + "0135", + "0136", + "0137", + "0138", + "0139", + "0140", + "0141", + "0142", + "0143", + "0145", + "0146", + "0147", + "0148", + "0150", + "0152", + "0155", + "0159", + "0160", + "0161", + "0162", + "0163", + "0166", + "0167", + "0170", + "0171", + "0172", + "0173", + "0175", + "0176", + "0177", + "0178", + "0179", + "0181", + "0182", + "0184", + "0185", + "0186", + "0187", + "0188", + "0189", + "0190", + "0191", + "0192", + "0195", + "0197", + "0198", + "0200", + "0202", + "0205", + "0206", + "0207", + "0209", + "0210", + "0211", + "0213", + "0214", + "0215", + "0216", + "0217", + "0218", + "0219", + "0221", + "0222", + "0223", + "0224", + "0225", + "0228", + "0229", + "0230", + "0233", + "0234", + "0235", + "0236", + "0237", + "0238", + "0239", + "0240", + "0241", + "0242", + "0244", + "0245", + "0246", + "0247", + "0248", + "0249", + "0250", + "0251", + "0252", + "0254", + "0255", + "0256", + "0258", + "0259", + "0260", + "0261", + "0262", + "0263", + "0264", + "0265", + "0266", + "0269", + "0270", + "0271", + "0272", + "0273", + "0274", + "0275", + "0276", + "0277", + "0278", + "0279", + "0280", + "0281", + "0283", + "0285", + "0286", + "0288", + "0290", + "0292", + "0294", + "0296", + "0298", + "0299", + "0300", + "0302", + "0303", + "0304", + "0305", + "0307", + "0308", + "0309", + "0310", + "0312", + "0313", + "0314", + "0316", + "0317", + "0319", + "0322", + "0324", + "0326", + "0328", + "0329", + "0330", + "0331", + "0332", + "0333", + "0334", + "0335", + "0336", + "0337", + "0338", + "0339", + "0340", + "0342", + "0343", + "0344", + "0345", + "0346", + "0348", + "0349", + "0350", + "0353", + "0354", + "0355", + "0358", + "0359", + "0361", + "0362", + "0363", + "0364", + "0365", + "0366", + "0368", + "0369", + "0370", + "0371", + "0373", + "0374", + "0375", + "0376", + "0377", + "0378", + "0381", + "0382", + "0383", + "0384", + "0385", + "0386", + "0387", + "0388", + "0389", + "0390", + "0391", + "0392", + "0393", + "0394", + "0396", + "0397", + "0398", + "0399", + "0400", + "0401", + "0403", + "0404", + "0405", + "0406", + "0407", + "0408", + "0409", + "0410", + "0411", + "0413", + "0414", + "0415", + "0416", + "0417", + "0418", + "0419", + "0420", + "0422", + "0423", + "0424", + "0426", + "0427", + "0428", + "0429", + "0430", + "0431", + "0432", + "0433", + "0435", + "0436", + "0437", + "0439", + "0442", + "0443", + "0445", + "0446", + "0447", + "0448", + "0449", + "0450", + "0451", + "0453", + "0454", + "0455", + "0456", + "0459", + "0460", + "0461", + "0462", + "0463", + "0465", + "0466", + "0467", + "0469", + "0470", + "0471", + "0472", + "0475", + "0476", + "0477", + "0479", + "0480", + "0481", + "0482", + "0483", + "0484", + "0485", + "0486", + "0487", + "0489", + "0490", + "0491", + "0492", + "0494", + "0498", + "0499", + "0500", + "0501", + "0503", + "0504", + "0505", + "0506", + "0507", + "0508", + "0509", + "0510", + "0511", + "0512", + "0513", + "0515", + "0516", + "0517", + "0518", + "0519", + "0520", + "0523", + "0524", + "0526", + "0527", + "0528", + "0529", + "0530", + "0531", + "0532", + "0534", + "0536", + "0538", + "0539", + "0541", + "0542", + "0543", + "0548", + "0549", + "0550", + "0554", + "0556", + "0558", + "0560", + "0561", + "0562", + "0563", + "0564", + "0565", + "0566", + "0567", + "0568", + "0569", + "0570", + "0571", + 
"0572", + "0573", + "0574", + "0575", + "0576", + "0578", + "0579", + "0580", + "0582", + "0586", + "0587", + "0588", + "0589", + "0590", + "0592", + "0593", + "0594", + "0595", + "0596", + "0598", + "0599", + "0601", + "0602", + "0603", + "0604", + "0606", + "0607", + "0608", + "0609", + "0613", + "0614", + "0615", + "0617", + "0618", + "0619", + "0620", + "0621", + "0622", + "0623", + "0624", + "0625", + "0626", + "0627", + "0628", + "0629", + "0630", + "0633", + "0635", + "0638", + "0639", + "0640", + "0642", + "0643", + "0644", + "0645", + "0646", + "0647", + "0648", + "0650", + "0651", + "0652", + "0654", + "0656", + "0657", + "0658", + "0659", + "0660", + "0661", + "0662", + "0663", + "0666", + "0667", + "0668", + "0669", + "0670", + "0672", + "0673", + "0674", + "0676", + "0678", + "0679", + "0680", + "0681", + "0682", + "0683", + "0684", + "0685", + "0686", + "0688", + "0689", + "0690", + "0691", + "0692", + "0693", + "0694", + "0696", + "0697", + "0698", + "0699", + "0700", + "0703", + "0704", + "0705", + "0706", + "0708", + "0709", + "0710", + "0711", + "0712", + "0713", + "0714", + "0715", + "0716", + "0718", + "0719", + "0722", + "0724", + "0725", + "0727", + "0728", + "0729", + "0730", + "0732", + "0733", + "0734", + "0735", + "0736", + "0737", + "0738", + "0739", + "0740", + "0741", + "0742", + "0743", + "0744", + "0745", + "0746", + "0747", + "0748", + "0749", + "0750", + "0751", + "0752", + "0753", + "0754", + "0755", + "0758", + "0761", + "0762", + "0763", + "0764", + "0765", + "0766", + "0768", + "0769", + "0771", + "0772", + "0774", + "0775", + "0776", + "0777", + "0778", + "0781", + "0782", + "0784", + "0785", + "0786", + "0787", + "0788", + "0789", + "0790", + "0793", + "0794", + "0795", + "0797", + "0798", + "0799", + "0800", + "0801", + "0802", + "0803", + "0804", + "0806", + "0807", + "0808", + "0809", + "0812", + "0813", + "0814", + "0815", + "0816", + "0817", + "0818", + "0819", + "0820", + "0821", + "0822", + "0824", + "0825", + "0826", + "0827", + "0829", + "0831", + "0832", + "0833", + "0834", + "0835", + "0836", + "0837", + "0838", + "0840", + "0842", + "0843", + "0847", + "0848", + "0849", + "0851", + "0852", + "0854", + "0855", + "0856", + "0857", + "0858", + "0859", + "0861", + "0864", + "0866", + "0867", + "0870", + "0871", + "0872", + "0873", + "0874", + "0875", + "0877", + "0878", + "0879", + "0880", + "0881", + "0882", + "0885", + "0886", + "0887", + "0889", + "0890", + "0891", + "0892", + "0893", + "0894", + "0895", + "0896", + "0898", + "0899", + "0900", + "0901", + "0902", + "0903", + "0905", + "0909", + "0910", + "0911", + "0912", + "0913", + "0914", + "0915", + "0916", + "0917", + "0919", + "0920", + "0921", + "0924", + "0926", + "0927", + "0928", + "0929", + "0930", + "0931", + "0932", + "0934", + "0935", + "0936", + "0937", + "0939", + "0940", + "0941", + "0945", + "0946", + "0947", + "0949", + "0950", + "0952", + "0953", + "0954", + "0955", + "0956", + "0957", + "0958", + "0960", + "0961", + "0962", + "0963", + "0964", + "0967", + "0968", + "0971", + "0972", + "0974", + "0975", + "0976", + "0977", + "0980", + "0981", + "0983", + "0984", + "0985", + "0986", + "0987", + "0988", + "0989", + "0990", + "0991", + "0992", + "0993", + "0994", + "0995", + "0998", + "0999", + "1000", + "1001", + "1003", + "1005", + "1006", + "1007", + "1008", + "1009", + "1010", + "1012", + "1014", + "1016", + "1017", + "1018", + "1020", + "1021", + "1022", + "1023", + "1025", + "1028", + "1029", + "1030", + "1031", + "1032", + "1033", + "1034", + "1035", + "1037", + 
"1039", + "1040", + "1041", + "1042", + "1043", + "1044", + "1045", + "1049", + "1055" + ] +} \ No newline at end of file diff --git a/datasets/SUSTech1K/point2depth.py b/datasets/SUSTech1K/point2depth.py new file mode 100644 index 0000000..c685a0a --- /dev/null +++ b/datasets/SUSTech1K/point2depth.py @@ -0,0 +1,279 @@ +import matplotlib.pyplot as plt + +import open3d as o3d +# This source is based on https://github.com/AbnerHqC/GaitSet/blob/master/pretreatment.py +import argparse +import logging +import multiprocessing as mp +import os +import pickle +from collections import defaultdict +from functools import partial +from pathlib import Path +from typing import Tuple + +import cv2 +import numpy as np +from tqdm import tqdm + +def align_img(img: np.ndarray, img_size: int = 64) -> np.ndarray: + """Aligns the image to the center. + Args: + img (np.ndarray): Image to align. + img_size (int, optional): Image resizing size. Defaults to 64. + Returns: + np.ndarray: Aligned image. + """ + if img.sum() <= 10000: + y_top = 0 + y_btm = img.shape[0] + else: + # Get the upper and lower points + # img.sum + y_sum = img.sum(axis=2).sum(axis=1) + y_top = (y_sum != 0).argmax(axis=0) + y_btm = (y_sum != 0).cumsum(axis=0).argmax(axis=0) + + img = img[y_top: y_btm, :,:] + + # As the height of a person is larger than the width, + # use the height to calculate resize ratio. + ratio = img.shape[1] / img.shape[0] + img = cv2.resize(img, (int(img_size * ratio), img_size), interpolation=cv2.INTER_CUBIC) + + # Get the median of the x-axis and take it as the person's x-center. + x_csum = img.sum(axis=2).sum(axis=0).cumsum() + x_center = img.shape[1] // 2 + for idx, csum in enumerate(x_csum): + if csum > img.sum() / 2: + x_center = idx + break + + # if not x_center: + # logging.warning(f'{img_file} has no center.') + # continue + + # Get the left and right points + half_width = img_size // 2 + left = x_center - half_width + right = x_center + half_width + if left <= 0 or right >= img.shape[1]: + left += half_width + right += half_width + # _ = np.zeros((img.shape[0], half_width,3)) + # img = np.concatenate([_, img, _], axis=1) + + img = img[:, left: right,:].astype('uint8') + return img + + + + + +def lidar_to_2d_front_view(points, + v_res, + h_res, + v_fov, + val="depth", + cmap="jet", + saveto=None, + y_fudge=0.0 + ): + """ Takes points in 3D space from LIDAR data and projects them to a 2D + "front view" image, and saves that image. + + Args: + points: (np array) + The numpy array containing the lidar points. + The shape should be Nx4 + - Where N is the number of points, and + - each point is specified by 4 values (x, y, z, reflectance) + v_res: (float) + vertical resolution of the lidar sensor used. + h_res: (float) + horizontal resolution of the lidar sensor used. + v_fov: (tuple of two floats) + (minimum_negative_angle, max_positive_angle) + val: (str) + What value to use to encode the points that get plotted. + One of {"depth", "height", "reflectance"} + cmap: (str) + Color map to use to color code the `val` values. + NOTE: Must be a value accepted by matplotlib's scatter function + Examples: "jet", "gray" + saveto: (str or None) + If a string is provided, it saves the image as this filename. + If None, then it just shows the image. + y_fudge: (float) + A hacky fudge factor to use if the theoretical calculations of + vertical range do not match the actual data. + + For a Velodyne HDL 64E, set this value to 5. 
+ """ + + # DUMMY PROOFING + assert len(v_fov) ==2, "v_fov must be list/tuple of length 2" + assert v_fov[0] <= 0, "first element in v_fov must be 0 or negative" + assert val in {"depth", "height", "reflectance"}, \ + 'val must be one of {"depth", "height", "reflectance"}' + + + x_lidar = - points[:, 0] + y_lidar = - points[:, 1] + z_lidar = points[:, 2] + # Distance relative to origin when looked from top + d_lidar = np.sqrt(x_lidar ** 2 + y_lidar ** 2) + # Absolute distance relative to origin + # d_lidar = np.sqrt(x_lidar ** 2 + y_lidar ** 2, z_lidar ** 2) + + v_fov_total = -v_fov[0] + v_fov[1] + + # Convert to Radians + v_res_rad = v_res * (np.pi/180) + h_res_rad = h_res * (np.pi/180) + + # PROJECT INTO IMAGE COORDINATES + x_img = np.arctan2(-y_lidar, x_lidar)/ h_res_rad + y_img = np.arctan2(z_lidar, d_lidar)/ v_res_rad + + # SHIFT COORDINATES TO MAKE 0,0 THE MINIMUM + x_min = -360.0 / h_res / 2 # Theoretical min x value based on sensor specs + x_img -= x_min # Shift + x_max = 360.0 / h_res # Theoretical max x value after shifting + + y_min = v_fov[0] / v_res # theoretical min y value based on sensor specs + y_img -= y_min # Shift + y_max = v_fov_total / v_res # Theoretical max x value after shifting + + y_max += y_fudge # Fudge factor if the calculations based on + # spec sheet do not match the range of + # angles collected by in the data. + + # WHAT DATA TO USE TO ENCODE THE VALUE FOR EACH PIXEL + if val == "reflectance": + pass + elif val == "height": + pixel_values = z_lidar + else: + pixel_values = -d_lidar + # pixel_values = 'w' + + # PLOT THE IMAGE + cmap = "jet" # Color map to use + dpi = 100 # Image resolution + fig, ax = plt.subplots(figsize=(x_max/dpi, y_max/dpi), dpi=dpi) + ax.scatter(x_img,y_img, s=1, c=pixel_values, linewidths=0, alpha=1, cmap=cmap) + ax.set_facecolor((0, 0, 0)) # Set regions with no points to black + ax.axis('scaled') # {equal, scaled} + ax.xaxis.set_visible(False) # Do not draw axis tick marks + ax.yaxis.set_visible(False) # Do not draw axis tick marks + plt.xlim([0, x_max]) # prevent drawing empty space outside of horizontal FOV + plt.ylim([0, y_max]) # prevent drawing empty space outside of vertical FOV + + saveto = saveto.replace('.pcd','.png') + fig.savefig(saveto, dpi=dpi, bbox_inches='tight', pad_inches=0.0) + plt.close() + img = cv2.imread(saveto) + img = align_img(img) + + aligned_path = saveto.replace('offline','aligned') + os.makedirs(os.path.dirname(aligned_path), exist_ok=True) + cv2.imwrite(aligned_path, img) + # fig, ax = plt.subplots(figsize=(x_max/dpi, y_max/dpi), dpi=dpi) + # ax.scatter(x_img,y_img, s=1, c='white', linewidths=0, alpha=1) + # ax.set_facecolor((0, 0, 0)) # Set regions with no points to black + # ax.axis('scaled') # {equal, scaled} + # ax.xaxis.set_visible(False) # Do not draw axis tick marks + # ax.yaxis.set_visible(False) # Do not draw axis tick marks + # plt.xlim([0, x_max]) # prevent drawing empty space outside of horizontal FOV + # plt.ylim([0, y_max]) # prevent drawing empty space outside of vertical FOV + + # fig.savefig(saveto.replace('depth','sils'), dpi=dpi, bbox_inches='tight', pad_inches=0.0) + # plt.close() + + +def pcd2depth(img_groups: Tuple, output_path: Path, img_size: int = 64, verbose: bool = False, dataset='CASIAB') -> None: + """Reads a group of images and saves the data in pickle format. + Args: + img_groups (Tuple): Tuple of (sid, seq, view) and list of image paths. + output_path (Path): Output path. + img_size (int, optional): Image resizing size. Defaults to 64. 
+ verbose (bool, optional): Display debug info. Defaults to False. + """ + sinfo = img_groups[0] + img_paths = img_groups[1] + for img_file in sorted(img_paths): + pcd_name = img_file.split('/')[-1] + pcd = o3d.io.read_point_cloud(img_file) + points = np.asarray(pcd.points) + HRES = 0.19188 # horizontal resolution (assuming 20Hz setting) + VRES = 0.2 + VFOV = (-25.0, 15.0) # Field of view (-ve, +ve) along vertical axis + Y_FUDGE = 0 # y fudge factor for velodyne HDL 64E + dst_path = os.path.join(output_path, *sinfo) + os.makedirs(dst_path, exist_ok=True) + dst_path = os.path.join(dst_path,pcd_name) + lidar_to_2d_front_view(points, v_res=VRES, h_res=HRES, v_fov=VFOV, val="depth", + saveto=dst_path, y_fudge=Y_FUDGE) + # if len(points) == 0: + # print(img_file) + # to_pickle.append(points) + # dst_path = os.path.join(output_path, *sinfo) + # os.makedirs(dst_path, exist_ok=True) + # pkl_path = os.path.join(dst_path, f'pcd-{sinfo[2]}.pkl') + # pickle.dump(to_pickle, open(pkl_path, 'wb')) + # if len(to_pickle) < 5: + # logging.warning(f'{sinfo} has less than 5 valid data.') + + + +def pretreat(input_path: Path, output_path: Path, img_size: int = 64, workers: int = 4, verbose: bool = False, dataset: str = 'CASIAB') -> None: + """Reads a dataset and saves the data in pickle format. + Args: + input_path (Path): Dataset root path. + output_path (Path): Output path. + img_size (int, optional): Image resizing size. Defaults to 64. + workers (int, optional): Number of thread workers. Defaults to 4. + verbose (bool, optional): Display debug info. Defaults to False. + """ + img_groups = defaultdict(list) + logging.info(f'Listing {input_path}') + total_files = 0 + for sid in tqdm(sorted(os.listdir(input_path))): + for seq in os.listdir(os.path.join(input_path,sid)): + for view in os.listdir(os.path.join(input_path,sid,seq)): + for img_path in os.listdir(os.path.join(input_path,sid,seq,view,'PCDs')): + img_groups[(sid, seq, view,'PCDs_offline_depths')].append(os.path.join(input_path,sid,seq,view, 'PCDs',img_path)) + total_files += 1 + + logging.info(f'Total files listed: {total_files}') + + progress = tqdm(total=len(img_groups), desc='Pretreating', unit='folder') + + with mp.Pool(workers) as pool: + logging.info(f'Start pretreating {input_path}') + for _ in pool.imap_unordered(partial(pcd2depth, output_path=output_path, img_size=img_size, verbose=verbose, dataset=dataset), img_groups.items()): + progress.update(1) + logging.info('Done') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='OpenGait dataset pretreatment module.') + parser.add_argument('-i', '--input_path', default='', type=str, help='Root path of raw dataset.') + parser.add_argument('-o', '--output_path', default='', type=str, help='Output path of pickled dataset.') + parser.add_argument('-l', '--log_file', default='./pretreatment.log', type=str, help='Log file path. Default: ./pretreatment.log') + parser.add_argument('-n', '--n_workers', default=4, type=int, help='Number of thread workers. Default: 4') + parser.add_argument('-r', '--img_size', default=64, type=int, help='Image resizing size. 
Default 64') + parser.add_argument('-d', '--dataset', default='CASIAB', type=str, help='Dataset for pretreatment.') + parser.add_argument('-v', '--verbose', default=False, action='store_true', help='Display debug info.') + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO, filename=args.log_file, filemode='w', format='[%(asctime)s - %(levelname)s]: %(message)s') + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + logging.info('Verbose mode is on.') + for k, v in args.__dict__.items(): + logging.debug(f'{k}: {v}') + + pretreat(input_path=Path(args.input_path), output_path=Path(args.output_path), img_size=args.img_size, workers=args.n_workers, verbose=args.verbose, dataset=args.dataset) diff --git a/datasets/SUSTech1K/pretreatment_SUSTech1K.py b/datasets/SUSTech1K/pretreatment_SUSTech1K.py new file mode 100644 index 0000000..e379ffd --- /dev/null +++ b/datasets/SUSTech1K/pretreatment_SUSTech1K.py @@ -0,0 +1,221 @@ +# This source is based on https://github.com/AbnerHqC/GaitSet/blob/master/pretreatment.py +import argparse +import logging +import multiprocessing as mp +import os +import pickle +from collections import defaultdict +from functools import partial +from pathlib import Path +from typing import Tuple + +import cv2 +import numpy as np +from tqdm import tqdm + +import json +import open3d as o3d + +def compare_pcd_rgb_timestamp(pcd_file,rgb_file): + pcd_time = float(pcd_file.split('/')[-1].replace('.pcd','')) + 0.05 + rgb_time = float(rgb_file.split('/')[-1].replace('.jpg','')[:10] + '.' + rgb_file.split('/')[-1].replace('.jpg','')[10:]) + return pcd_time, rgb_time + + + +def imgs2pickle(img_groups: Tuple, output_path: Path, img_size: int = 64, verbose: bool = False, dataset='CASIAB') -> None: + """Reads a group of images and saves the data in pickle format. + + Args: + img_groups (Tuple): Tuple of (sid, seq, view) and list of image paths. + output_path (Path): Output path. + img_size (int, optional): Image resizing size. Defaults to 64. + verbose (bool, optional): Display debug info. Defaults to False. 
+ """ + sinfo = img_groups[0] + img_paths = img_groups[1] # path with modality name + to_pickle = [] + cnt = 0 + pcd_list = [] + rgb_list = [] + + threshold = 0.020 # 20 ms + + for index, modality_files in enumerate(img_paths): + data_files = modality_files[1] + modality = modality_files[0] + if modality == 'PCDs': + data = [np.asarray(o3d.io.read_point_cloud(points).points) for points in data_files] + pcd_list = data_files + elif modality == 'RGB_raw': + imgs = [cv2.imread(rgb) for rgb in data_files] + rgb_list = data_files + imgs = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in imgs] + HWs = [img.shape[:2] for img in imgs] + # transpose to (C, H W) + data = [cv2.resize(img, (img_size, img_size), interpolation=cv2.INTER_CUBIC) for img in imgs] + imgs = [img.transpose(2, 0, 1) for img in imgs] + data = np.asarray(data) + HWs = np.asarray(HWs) + elif modality == 'Sils_raw': + sils = [cv2.imread(sil, cv2.IMREAD_GRAYSCALE) for sil in data_files] + data = [cv2.resize(sil, (img_size, img_size), interpolation=cv2.INTER_CUBIC) for sil in sils] + data = np.asarray(data) + elif modality == 'Sils_aligned': + sils = [cv2.imread(sil, cv2.IMREAD_GRAYSCALE) for sil in data_files] + data = [cv2.resize(sil, (img_size, img_size), interpolation=cv2.INTER_CUBIC) for sil in sils] + data = np.asarray(data) + elif modality == 'Pose': + data = [json.load(open(pose)) for pose in data_files] + data = np.asarray(data) + elif modality == 'PCDs_depths': + imgs = [cv2.imread(rgb) for rgb in data_files] + imgs = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in imgs] + data = [img.transpose(2, 0, 1) for img in imgs] + data = np.asarray(data) + elif modality == 'PCDs_sils': + data = [cv2.imread(sil, cv2.IMREAD_GRAYSCALE) for sil in data_files] + data = np.asarray(data) + + dst_path = os.path.join(output_path, *sinfo) + os.makedirs(dst_path, exist_ok=True) + if modality == 'RGB_raw': + pkl_path = os.path.join(dst_path, f'{cnt:02d}-{sinfo[2]}-Camera-Ratios-HW.pkl') + pickle.dump(HWs, open(pkl_path, 'wb')) + cnt += 1 + + if 'PCDs' in modality: + pkl_path = os.path.join(dst_path, f'{cnt:02d}-{sinfo[2]}-LiDAR-{modality}.pkl') + pickle.dump(data, open(pkl_path, 'wb')) + else: + pkl_path = os.path.join(dst_path, f'{cnt:02d}-{sinfo[2]}-Camera-{modality}.pkl') + pickle.dump(data, open(pkl_path, 'wb')) + cnt += 1 + + pcd_indexs = [] + rgb_indexs = [] + # print(pcd_list) + for pcd_index in range(len(pcd_list)): + time_diff = 1 + tmp = pcd_index, 0 + for rgb_index in range(len(rgb_list)): + pcd_t, rgb_t = compare_pcd_rgb_timestamp(pcd_list[pcd_index], rgb_list[rgb_index]) + diff = abs(pcd_t - rgb_t) + if diff < time_diff: + tmp = pcd_index, rgb_index + time_diff = diff + if time_diff <= threshold: + pcd_indexs.append(tmp[0]) + rgb_indexs.append(tmp[1]) + + if len(set(pcd_indexs)) != len(pcd_indexs): + print(img_groups[0], pcd_indexs, rgb_indexs, len(pcd_indexs) == len(pcd_indexs)) + + for index, modality_files in enumerate(img_paths): + modality = modality_files[0] + data_files = modality_files[1] + data_files = [data_files[index] for index in pcd_indexs] if 'PCDs' in modality else [data_files[index] for index in rgb_indexs] + + if modality == 'PCDs': + data = [np.asarray(o3d.io.read_point_cloud(points).points) for points in data_files] + pcd_list = data_files + elif modality == 'RGB_raw': + imgs = [cv2.imread(rgb) for rgb in data_files] + rgb_list = data_files + imgs = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in imgs] + HWs = [img.shape[:2] for img in imgs] + # transpose to (C, H W) + data = [cv2.resize(img, (img_size, 
img_size), interpolation=cv2.INTER_CUBIC) for img in imgs] + imgs = [img.transpose(2, 0, 1) for img in imgs] + data = np.asarray(data) + HWs = np.asarray(HWs) + elif modality == 'Sils_raw': + sils = [cv2.imread(sil, cv2.IMREAD_GRAYSCALE) for sil in data_files] + data = [cv2.resize(sil, (img_size, img_size), interpolation=cv2.INTER_CUBIC) for sil in sils] + data = np.asarray(data) + elif modality == 'Sils_aligned': + sils = [cv2.imread(sil, cv2.IMREAD_GRAYSCALE) for sil in data_files] + data = [cv2.resize(sil, (img_size, img_size), interpolation=cv2.INTER_CUBIC) for sil in sils] + data = np.asarray(data) + elif modality == 'Pose': + data = [json.load(open(pose)) for pose in data_files] + data = np.asarray(data) + elif modality == 'PCDs_depths': + imgs = [cv2.imread(rgb) for rgb in data_files] + imgs = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in imgs] + data = [img.transpose(2, 0, 1) for img in imgs] + data = np.asarray(data) + elif modality == 'PCDs_sils': + data = [cv2.imread(sil, cv2.IMREAD_GRAYSCALE) for sil in data_files] + data = np.asarray(data) + + dst_path = os.path.join(output_path, *sinfo) + os.makedirs(dst_path, exist_ok=True) + if modality == 'RGB_raw': + pkl_path = os.path.join(dst_path, f'{cnt:02d}-sync-{sinfo[2]}-Camera-Ratios-HW.pkl') + pickle.dump(HWs, open(pkl_path, 'wb')) + cnt += 1 + + if 'PCDs' in modality: + pkl_path = os.path.join(dst_path, f'{cnt:02d}-sync-{sinfo[2]}-LiDAR-{modality}.pkl') + pickle.dump(data, open(pkl_path, 'wb')) + else: + pkl_path = os.path.join(dst_path, f'{cnt:02d}-sync-{sinfo[2]}-Camera-{modality}.pkl') + pickle.dump(data, open(pkl_path, 'wb')) + cnt += 1 + + +def pretreat(input_path: Path, output_path: Path, img_size: int = 64, workers: int = 4, verbose: bool = False, dataset: str = 'CASIAB') -> None: + """Reads a dataset and saves the data in pickle format. + + Args: + input_path (Path): Dataset root path. + output_path (Path): Output path. + img_size (int, optional): Image resizing size. Defaults to 64. + workers (int, optional): Number of thread workers. Defaults to 4. + verbose (bool, optional): Display debug info. Defaults to False. 
+ """ + img_groups = defaultdict(list) + logging.info(f'Listing {input_path}') + total_files = 0 + for id_ in tqdm(sorted(os.listdir(input_path))): + for type_ in os.listdir(os.path.join(input_path,id_)): + for view_ in os.listdir(os.path.join(input_path,id_,type_)): + for modality in sorted(os.listdir(os.path.join(input_path,id_,type_,view_))): + modality_path = os.path.join(input_path,id_,type_,view_,modality) + file_names = sorted(os.listdir(modality_path)) + file_names = [os.path.join(modality_path, file_name) for file_name in file_names] + img_groups[(id_, type_, view_)].append((modality, file_names)) + total_files += 1 + + logging.info(f'Total files listed: {total_files}') + + progress = tqdm(total=len(img_groups), desc='Pretreating', unit='folder') + + with mp.Pool(workers) as pool: + logging.info(f'Start pretreating {input_path}') + for _ in pool.imap_unordered(partial(imgs2pickle, output_path=output_path, img_size=img_size, verbose=verbose, dataset=dataset), img_groups.items()): + progress.update(1) + logging.info('Done') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='OpenGait dataset pretreatment module.') + parser.add_argument('-i', '--input_path', default='', type=str, help='Root path of raw dataset.') + parser.add_argument('-o', '--output_path', default='', type=str, help='Output path of pickled dataset.') + parser.add_argument('-l', '--log_file', default='./pretreatment.log', type=str, help='Log file path. Default: ./pretreatment.log') + parser.add_argument('-n', '--n_workers', default=4, type=int, help='Number of thread workers. Default: 4') + parser.add_argument('-r', '--img_size', default=64, type=int, help='Image resizing size. Default 64') + parser.add_argument('-d', '--dataset', default='CASIAB', type=str, help='Dataset for pretreatment.') + parser.add_argument('-v', '--verbose', default=False, action='store_true', help='Display debug info.') + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO, filename=args.log_file, filemode='w', format='[%(asctime)s - %(levelname)s]: %(message)s') + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + logging.info('Verbose mode is on.') + for k, v in args.__dict__.items(): + logging.debug(f'{k}: {v}') + + pretreat(input_path=Path(args.input_path), output_path=Path(args.output_path), img_size=args.img_size, workers=args.n_workers, verbose=args.verbose, dataset=args.dataset) \ No newline at end of file diff --git a/opengait/evaluation/evaluator.py b/opengait/evaluation/evaluator.py index 896e4ae..3546531 100644 --- a/opengait/evaluation/evaluator.py +++ b/opengait/evaluation/evaluator.py @@ -74,46 +74,59 @@ def single_view_gallery_evaluation(feature, label, seq_type, view, dataset, metr 'CASIA-E': {'NM': ['H-scene2-nm-1', 'H-scene2-nm-2', 'L-scene2-nm-1', 'L-scene2-nm-2', 'H-scene3-nm-1', 'H-scene3-nm-2', 'L-scene3-nm-1', 'L-scene3-nm-2', 'H-scene3_s-nm-1', 'H-scene3_s-nm-2', 'L-scene3_s-nm-1', 'L-scene3_s-nm-2', ], 'BG': ['H-scene2-bg-1', 'H-scene2-bg-2', 'L-scene2-bg-1', 'L-scene2-bg-2', 'H-scene3-bg-1', 'H-scene3-bg-2', 'L-scene3-bg-1', 'L-scene3-bg-2', 'H-scene3_s-bg-1', 'H-scene3_s-bg-2', 'L-scene3_s-bg-1', 'L-scene3_s-bg-2'], 'CL': ['H-scene2-cl-1', 'H-scene2-cl-2', 'L-scene2-cl-1', 'L-scene2-cl-2', 'H-scene3-cl-1', 'H-scene3-cl-2', 'L-scene3-cl-1', 'L-scene3-cl-2', 'H-scene3_s-cl-1', 'H-scene3_s-cl-2', 'L-scene3_s-cl-1', 'L-scene3_s-cl-2'] - } - + }, + 'SUSTech1K': {'Normal': ['01-nm'], 'Bag': ['bg'], 'Clothing': ['cl'], 'Carrying':['cr'], 'Umberalla': ['ub'], 
'Uniform': ['uf'], 'Occlusion': ['oc'],'Night': ['nt'], 'Overall': ['01','02','03','04']} } gallery_seq_dict = {'CASIA-B': ['nm-01', 'nm-02', 'nm-03', 'nm-04'], 'OUMVLP': ['01'], - 'CASIA-E': ['H-scene1-nm-1', 'H-scene1-nm-2', 'L-scene1-nm-1', 'L-scene1-nm-2']} + 'CASIA-E': ['H-scene1-nm-1', 'H-scene1-nm-2', 'L-scene1-nm-1', 'L-scene1-nm-2'], + 'SUSTech1K': ['00-nm'],} msg_mgr = get_msg_mgr() acc = {} view_list = sorted(np.unique(view)) + num_rank = 1 if dataset == 'CASIA-E': view_list.remove("270") + if dataset == 'SUSTech1K': + num_rank = 5 view_num = len(view_list) - num_rank = 1 + for (type_, probe_seq) in probe_seq_dict[dataset].items(): - acc[type_] = np.zeros((view_num, view_num)) - 1. + acc[type_] = np.zeros((view_num, view_num, num_rank)) - 1. for (v1, probe_view) in enumerate(view_list): pseq_mask = np.isin(seq_type, probe_seq) & np.isin( view, probe_view) + pseq_mask = pseq_mask if 'SUSTech1K' not in dataset else np.any(np.asarray( + [np.char.find(seq_type, probe)>=0 for probe in probe_seq]), axis=0 + ) & np.isin(view, probe_view) # For SUSTech1K only probe_x = feature[pseq_mask, :] probe_y = label[pseq_mask] for (v2, gallery_view) in enumerate(view_list): gseq_mask = np.isin(seq_type, gallery_seq_dict[dataset]) & np.isin( view, [gallery_view]) + gseq_mask = gseq_mask if 'SUSTech1K' not in dataset else np.any(np.asarray( + [np.char.find(seq_type, gallery)>=0 for gallery in gallery_seq_dict[dataset]]), axis=0 + ) & np.isin(view, [gallery_view]) # For SUSTech1K only gallery_y = label[gseq_mask] gallery_x = feature[gseq_mask, :] dist = cuda_dist(probe_x, gallery_x, metric) idx = dist.topk(num_rank, largest=False)[1].cpu().numpy() - acc[type_][v1, v2] = np.round(np.sum(np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx], 1) > 0, + acc[type_][v1, v2, :] = np.round(np.sum(np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx[:, 0:num_rank]], 1) > 0, 0) * 100 / dist.shape[0], 2) result_dict = {} msg_mgr.log_info('===Rank-1 (Exclude identical-view cases)===') out_str = "" - for type_ in probe_seq_dict[dataset].keys(): - sub_acc = de_diag(acc[type_], each_angle=True) - msg_mgr.log_info(f'{type_}: {sub_acc}') - result_dict[f'scalar/test_accuracy/{type_}'] = np.mean(sub_acc) - out_str += f"{type_}: {np.mean(sub_acc):.2f}%\t" - msg_mgr.log_info(out_str) + for rank in range(num_rank): + out_str = "" + for type_ in probe_seq_dict[dataset].keys(): + sub_acc = de_diag(acc[type_][:,:,rank], each_angle=True) + if rank == 0: + msg_mgr.log_info(f'{type_}@R{rank+1}: {sub_acc}') + result_dict[f'scalar/test_accuracy/{type_}@R{rank+1}'] = np.mean(sub_acc) + out_str += f"{type_}@R{rank+1}: {np.mean(sub_acc):.2f}%\t" + msg_mgr.log_info(out_str) return result_dict @@ -122,7 +135,7 @@ def evaluate_indoor_dataset(data, dataset, metric='euc', cross_view_gallery=Fals label = np.array(label) view = np.array(view) - if dataset not in ('CASIA-B', 'OUMVLP', 'CASIA-E'): + if dataset not in ('CASIA-B', 'OUMVLP', 'CASIA-E', 'SUSTech1K'): raise KeyError("DataSet %s hasn't been supported !" 
% dataset) if cross_view_gallery: return cross_view_gallery_evaluation( diff --git a/opengait/modeling/models/baseline.py b/opengait/modeling/models/baseline.py index 4e1c72f..ba130d1 100644 --- a/opengait/modeling/models/baseline.py +++ b/opengait/modeling/models/baseline.py @@ -3,6 +3,7 @@ from ..base_model import BaseModel from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs, SeparateBNNecks +from einops import rearrange class Baseline(BaseModel): @@ -20,6 +21,8 @@ def forward(self, inputs): sils = ipts[0] if len(sils.size()) == 4: sils = sils.unsqueeze(1) + else: + sils = rearrange(sils, 'n s c h w -> n c s h w') del ipts outs = self.Backbone(sils) # [n, c, s, h, w] @@ -33,17 +36,16 @@ def forward(self, inputs): embed_2, logits = self.BNNecks(embed_1) # [n, c, p] embed = embed_1 - n, _, s, h, w = sils.size() retval = { 'training_feat': { 'triplet': {'embeddings': embed_1, 'labels': labs}, 'softmax': {'logits': logits, 'labels': labs} }, 'visual_summary': { - 'image/sils': sils.view(n*s, 1, h, w) + 'image/sils': rearrange(sils,'n c s h w -> (n s) c h w') }, 'inference_feat': { 'embeddings': embed } } - return retval + return retval \ No newline at end of file
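A note on the evaluator change above: `single_view_gallery_evaluation` now accumulates rank-1 through rank-`num_rank` accuracy in one shot via a cumulative-sum over the sorted neighbour matches. A toy NumPy illustration of that trick (the labels below are made up):

```python
import numpy as np

# Each row holds the gallery labels of one probe's nearest neighbours,
# already sorted by distance (rank 1 first).
probe_y = np.array([0, 1, 2])
ranked_gallery_y = np.array([[0, 1],   # probe 0: rank-1 hit
                             [2, 1],   # probe 1: rank-2 hit
                             [1, 0]])  # probe 2: miss
# cumsum(...) > 0 marks, for every rank k, whether the correct label
# appeared within the top-k neighbours.
hits = np.cumsum(probe_y.reshape(-1, 1) == ranked_gallery_y, axis=1) > 0
print(hits.sum(axis=0) * 100 / len(probe_y))  # approx. [33.33, 66.67] for rank-1, rank-2
```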
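Likewise, the `einops` edit in `Baseline.forward` only swaps the frame and channel axes, so that multi-channel inputs (such as the 3-channel depth images used by LidarGait) match the backbone's expected `[n, c, s, h, w]` layout. A self-contained sketch of the equivalence (tensor sizes are arbitrary):

```python
import torch
from einops import rearrange

# 5-D inputs arrive as [n, s, c, h, w]: batch, frames, channels, height, width.
x = torch.randn(2, 10, 3, 64, 64)
y = rearrange(x, 'n s c h w -> n c s h w')

assert y.shape == (2, 3, 10, 64, 64)
# rearrange here is just a readable permute.
assert torch.equal(y, x.permute(0, 2, 1, 3, 4))
```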