Commit

Add inferencing codes.
xlwangDev committed Oct 27, 2023
1 parent ea173cb commit 415860f
Showing 20 changed files with 2,445 additions and 44 deletions.
62 changes: 57 additions & 5 deletions README.md
@@ -23,7 +23,8 @@ We introduce a novel approach to fine-grained cross-view geo-localization. Our m

## 🔥 News

- [2023-10-01] We release the code for implementing the spherical transform. For usage instructions, please refer to [Spherical_transform.ipynb](https://github.com/xlwangDev/HC-Net/blob/main/Spherical_transform.ipynb).
- [2023-10-27] We release the inference code with [checkpoints](https://drive.google.com/drive/folders/1EL6RISnR5lOgz0WtWUYtFhKGcU_nROX9?usp=sharing) as well as the [demo script](https://github.com/xlwangDev/HC-Net/blob/main/demo_gradio.py). You can now test HC-Net on your own machine.
- [2023-10-01] We release the code for implementing the spherical transform. For usage instructions, please refer to [Spherical_transform.ipynb](https://github.com/xlwangDev/HC-Net/blob/main/demo/Spherical_transform.ipynb).
- [2023-09-21] HC-Net has been accepted by NeurIPS 2023! 🔥🔥🔥
- [2023-08-30] We release the [paper](https://arxiv.org/abs/2308.16906) of HC-Net and an online gradio [demo](http://101.230.144.196:7860).

@@ -35,11 +36,62 @@ You can test our model using the data from the **'same_area_balanced_test.txt'**

<img src="./figure/Demo.png" alt="image-20230831204530724" style="zoom: 80%;" />

## 📦 Inference

### Installation

We tested our code in the following environment:

- Ubuntu 18.04
- CUDA 12.0
- Python 3.8.16
- PyTorch 1.13.0

To get started, follow these steps:

1. Clone this repository.

```bash
git clone https://github.com/xlwangDev/HC-Net.git
cd HC-Net
```

2. Install the required packages.

```bash
conda create -n hcnet python=3.8 -y
conda activate hcnet
pip install -r requirements.txt
```

### Evaluation

To evaluate the HC-Net model, follow these steps:

1. Download the [VIGOR](https://github.com/Jeff-Zilence/VIGOR) dataset and set its path to `/home/<usr>/Data/VIGOR`.
2. Download the [pretrained models](https://drive.google.com/drive/folders/1EL6RISnR5lOgz0WtWUYtFhKGcU_nROX9?usp=sharing) and place them in `./checkpoints/VIGOR`.
3. Run the following command:

```bash
chmod +x val.sh
# Usage: val.sh [same|cross]
# For same-area in VIGOR
./val.sh same 0
# For cross-area in VIGOR
./val.sh cross 0
```

4. You can also view the model's visualization results through a Gradio-based demo. Use the following command to start the demo, then open the local URL: [http://0.0.0.0:7860](http://0.0.0.0:7860/).

```bash
python demo_gradio.py
```

## 🏷️ Label Correction for [VIGOR](https://github.com/Jeff-Zilence/VIGOR) Dataset

<img src="./figure/VIGOR_label.png" alt="image-20230831204530724" style="zoom: 60%;" />

We propose the use of [Mercator projection](https://en.wikipedia.org/wiki/Web_Mercator_projection#References) to directly compute the pixel coordinates of ground images on specified satellite images using the GPS information provided in the dataset.
We propose the use of [Mercator projection](https://en.wikipedia.org/wiki/Web_Mercator_projection#References) to directly compute the pixel coordinates of ground images on specified satellite images using the GPS information provided in the dataset. You can find the specific code at [Mercator.py](https://github.com/xlwangDev/HC-Net/blob/main/models/utils/Mercator.py).
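
As a rough illustration of the idea (not the repository's `Mercator.py`), the sketch below converts GPS coordinates to global Web Mercator pixel coordinates and then to a pixel location inside a satellite tile. The zoom level (20), the 640x640 tile size, and all function names here are assumptions made for this example only; please refer to [Mercator.py](https://github.com/xlwangDev/HC-Net/blob/main/models/utils/Mercator.py) for the actual implementation.

```python
import math

# Illustrative sketch only; not the code in models/utils/Mercator.py.
# Assumptions: zoom level 20 imagery and 640x640 satellite tiles centered
# on the satellite image's own GPS coordinate.

def gps_to_world_pixel(lat, lon, zoom=20, tile_size=256):
    """Project (lat, lon) in degrees to global Web Mercator pixel coordinates."""
    scale = tile_size * (2 ** zoom)
    x = (lon + 180.0) / 360.0 * scale
    lat_rad = math.radians(lat)
    y = (1.0 - math.log(math.tan(lat_rad) + 1.0 / math.cos(lat_rad)) / math.pi) / 2.0 * scale
    return x, y

def gps_to_sat_pixel(ground_gps, sat_center_gps, zoom=20, sat_size=640):
    """Pixel coordinates of a ground image's GPS inside a satellite image
    whose center corresponds to sat_center_gps."""
    gx, gy = gps_to_world_pixel(*ground_gps, zoom=zoom)
    cx, cy = gps_to_world_pixel(*sat_center_gps, zoom=zoom)
    # Offset from the tile center in pixels, shifted to image coordinates.
    return gx - cx + sat_size / 2.0, gy - cy + sat_size / 2.0

# Example: a panorama a few meters north-east of the satellite tile center.
u, v = gps_to_sat_pixel((40.712810, -74.006000), (40.712800, -74.006010))
print(u, v)
```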

To use our corrected labels, add the following content to the `__getitem__` method of the `VIGORDataset` class in the `datasets.py` file of the [CCVPE](https://github.com/tudelft-iv/CCVPE) project:

@@ -84,9 +136,9 @@ Our projection process is implemented entirely in PyTorch, which means **our pro

## 📝 TODO List

- [ ] Add data preparation codes.
- [ ] Add inferencing and serving codes with checkpoints.
- [ ] Add evaluation codes.
- [x] Add data preparation codes.
- [x] Add inferencing and serving codes with checkpoints.
- [x] Add evaluation codes.
- [ ] Add training codes.

## 🔗 Citation
144 changes: 144 additions & 0 deletions dataset.py
@@ -0,0 +1,144 @@
import os
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, random_split
import torch.utils.data as data
import cv2
from models.utils.utils import get_BEV_tensor, get_BEV_projection
from models.utils.augment import train_transform

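# VIGOR dataset wrapper: reads the panorama/satellite lists for the selected
# city split, converts each panorama to a bird's-eye view (BEV) image, and
# returns the ground-truth pixel offset of the panorama inside the satellite tile.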
class VIGOR(Dataset):
    def __init__(self, args, split='train', root='Data/VIGOR/', same_area=True):
        usr = os.getcwd().split('/')[2]
        root = os.path.join('/home', usr, root)
        same_area = not args.cross_area

        self.image_size = args.image_size
        label_root = 'splits'  # 'splits' splits__corrected
        if same_area:
            self.train_city_list = ['NewYork', 'Seattle', 'SanFrancisco', 'Chicago']  # ['NewYork', 'Seattle', 'SanFrancisco', 'Chicago'] ['Seattle']
            self.test_city_list = ['NewYork', 'Seattle', 'SanFrancisco', 'Chicago']
        else:
            self.train_city_list = ['NewYork', 'Seattle']
            self.test_city_list = ['SanFrancisco', 'Chicago']

        pano_list = []
        pano_label = []
        sat_delta = []

        if split == 'train':
            for city in self.train_city_list:
                label_fname = os.path.join(root, label_root, city, 'same_area_balanced_train.txt'
                                           if same_area else 'pano_label_balanced.txt')
                with open(label_fname, 'r') as file:
                    for line in file.readlines():
                        data = np.array(line.split(' '))
                        label = []
                        for i in [1, 4, 7, 10]:
                            label.append(os.path.join(root, city, 'satellite', data[i]))
                        delta = np.array([data[2:4], data[5:7], data[8:10], data[11:13]]).astype(float)
                        pano_list.append(os.path.join(root, city, 'panorama', data[0]))
                        pano_label.append(label)
                        sat_delta.append(delta)
        else:
            for city in self.test_city_list:
                label_fname = os.path.join(root, label_root, city, 'same_area_balanced_test.txt'
                                           if same_area else 'pano_label_balanced.txt')
                with open(label_fname, 'r') as file:
                    for line in file.readlines():
                        data = np.array(line.split(' '))
                        label = []
                        for i in [1, 4, 7, 10]:
                            label.append(os.path.join(root, city, 'satellite', data[i]))
                        delta = np.array([data[2:4], data[5:7], data[8:10], data[11:13]]).astype(float)
                        pano_list.append(os.path.join(root, city, 'panorama', data[0]))
                        pano_label.append(label)
                        sat_delta.append(delta)

        self.pano_list = pano_list
        self.pano_label = pano_label
        self.sat_delta = sat_delta

        self.split = split
        self.transform = train_transform(0) if 'augment' in args and args.augment else None
        self.center = [(self.image_size/2, self.image_size/2), (self.image_size/2, self.image_size/2 - self.image_size/8)] \
            if 'orien' in args and args.orien else [(self.image_size//2.0, self.image_size//2.0), ]
        pona_path = self.pano_list[0]
        pona = cv2.imread(pona_path, 1)[:, :, ::-1]  # BGR ==> RGB
        self.out = get_BEV_projection(pona, self.image_size, self.image_size, Fov=85*2, dty=0, dy=0)
        self.ori_noise = args.ori_noise
        # self.out = None

    def __len__(self):
        return len(self.pano_list)

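    # Returns the BEV panorama tensor, the satellite tensor, their GPS coordinates,
    # the (possibly augmented) BEV center keypoints when training, the ground-truth
    # location in satellite-image pixel coordinates, and the applied orientation noise.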
    def __getitem__(self, idx):
        patch_size = self.image_size
        pona_path = self.pano_list[idx]
        select_ = 0  # random.randint(0,3)
        sat_path = self.pano_label[idx][select_]
        pano_gps = np.array(pona_path[:-5].split(',')[-2:]).astype(float)
        sat_gps = np.array(sat_path[:-4].split('_')[-2:]).astype(float)

        # =================== read satellite map ===================================
        sat = cv2.imread(sat_path, 1)[:, :, ::-1]
        sat = cv2.resize(sat, (patch_size, patch_size))

        # =================== read ground map ===================================
        pona = cv2.imread(pona_path, 1)[:, :, ::-1]

        rotation_range = self.ori_noise
        random_ori = np.random.uniform(-1, 1) * rotation_range / 360
        ori_angle = random_ori * 360
        pona = np.roll(pona, int(random_ori * pona.shape[1]), axis=1)

        if self.split == 'train' and self.transform is not None:
            pona_bev = get_BEV_tensor(pona, patch_size, patch_size, dty=0, dy=0, dataset=False, out=self.out).numpy().astype(np.uint8)  # dataset=False returns a numpy HWC array, otherwise a CHW tensor
            transformed = self.transform(image=pona_bev, keypoints=self.center)
            pona_bev = transformed["image"]
            try:
                transformed_center = [transformed['keypoints'][0], transformed['keypoints'][1]] \
                    if len(self.center) == 2 else transformed['keypoints']
            except IndexError:
                # if the transformed center keypoint is no longer visible, skip this sample
                print('\033[1;93m' + f"Skipping data at index {idx} due to invisible keypoints" + '. \033[0m')
                return None
            img1 = torch.from_numpy(pona_bev).float().permute(2, 0, 1)
        else:
            pona_bev = get_BEV_tensor(pona, 500, 500, dty=0, dy=0, out=self.out).numpy().astype(np.uint8)
            pona_bev = cv2.resize(pona_bev, (patch_size, patch_size))
            img1 = torch.from_numpy(pona_bev).float().permute(2, 0, 1)

        img2 = torch.from_numpy(sat).float().permute(2, 0, 1)
        pano_gps = torch.from_numpy(pano_gps)  # [batch, 2]
        sat_gps = torch.from_numpy(sat_gps)

        sat_delta_init = torch.from_numpy(self.sat_delta[idx][select_] * patch_size / 640.0).float()
        sat_delta = torch.zeros(2)
        sat_delta[1] = sat_delta_init[0] + patch_size / 2.0
        sat_delta[0] = patch_size / 2.0 - sat_delta_init[1]  # convert [y, x] to [x, y] to match the coordinates of the model output
        if self.split == 'train':
            transformed_center = torch.tensor(transformed_center).float() if self.transform is not None else torch.tensor(self.center).float()
            transformed_center = transformed_center.permute(1, 0)
            return img1, img2, pano_gps, sat_gps, transformed_center, sat_delta, torch.tensor(ori_angle)
        else:
            return img1, img2, pano_gps, sat_gps, torch.tensor(ori_angle), sat_delta

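# Builds the data pipeline: for split='train' it returns an 80/20 train/validation
# split of the VIGOR training set; otherwise it wraps the requested split in a DataLoader.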
def fetch_dataloader(args, split='train'):

    train_dataset = VIGOR(args, split)
    print('%s split: %d image pairs' % (split, len(train_dataset)))

    if split == 'train':
        train_size = int(0.8 * len(train_dataset))
        val_size = len(train_dataset) - train_size
        train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])
        print("using {} images for training, {} images for validation.".format(train_size, val_size))
        return train_dataset, val_dataset
    else:
        nw = min([os.cpu_count(), args.batch_size if args.batch_size > 1 else 0, 8])  # number of workers
        print('Using {} dataloader workers per process'.format(nw))
        test_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                      pin_memory=True, shuffle=True, num_workers=nw, drop_last=False)
        return test_loader
25 changes: 15 additions & 10 deletions Spherical_transform.ipynb → demo/Spherical_transform.ipynb
@@ -2,10 +2,15 @@
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"current_directory = os.getcwd()\n",
"if current_directory.split('/')[-1] == 'demo':\n",
" os.chdir('..')\n",
" \n",
"import numpy as np\n",
"import cv2\n",
"import matplotlib.pyplot as plt\n",
@@ -15,16 +20,16 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7ff8df1a08e0>"
"<matplotlib.image.AxesImage at 0x7f969b952ee0>"
]
},
"execution_count": 6,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
},
@@ -55,16 +60,16 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7ff8df9ad8e0>"
"<matplotlib.image.AxesImage at 0x7f969ad5aa60>"
]
},
"execution_count": 8,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
},
@@ -87,16 +92,16 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7ff8df89f5e0>"
"<matplotlib.image.AxesImage at 0x7f969ada0c40>"
]
},
"execution_count": 11,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
