Commit 60e62bd (1 parent: 4e19280)
aditya sanghi committed May 2, 2022
Showing 32 changed files with 13,208 additions and 1 deletion.
@@ -0,0 +1,5 @@
exps/*
.ipynb_checkpoints/
nohup.out
*pyc
__pycache__
@@ -1 +1,96 @@
# Clip-Forge: Code Coming Soon

## CLIP-Forge: Towards Zero-Shot Text-to-Shape Generation (CVPR 2022)

![CLIP](/images/main.png)

Generating shapes using natural language can enable new ways of imagining and creating the things around us. While significant recent progress has been made in text-to-image generation, text-to-shape generation remains a challenging problem due to the unavailability of paired text and shape data at a large scale. We present a simple yet effective method for zero-shot text-to-shape generation that circumvents such data scarcity. Our proposed method, named CLIP-Forge, is based on a two-stage training process, which only depends on an unlabelled shape dataset and a pre-trained image-text network such as CLIP. Our method has the benefits of avoiding expensive inference-time optimization, as well as the ability to generate multiple shapes for a given text. We not only demonstrate promising zero-shot generalization of the CLIP-Forge model qualitatively and quantitatively, but also provide extensive comparative evaluations to better understand its behavior.
Paper Link: [[Paper]](https://arxiv.org/pdf/2110.02624.pdf)

If you find our code or paper useful, please cite:

    @article{sanghi2021clip,
      title={Clip-forge: Towards zero-shot text-to-shape generation},
      author={Sanghi, Aditya and Chu, Hang and Lambourne, Joseph G and Wang, Ye and Cheng, Chin-Yi and Fumero, Marco},
      journal={arXiv preprint arXiv:2110.02624},
      year={2021}
    }
## Installation

First create an anaconda environment called `clip_forge` using
```
conda env create -f environment.yaml
conda activate clip_forge
```

Then, [install PyTorch 1.7.1](https://pytorch.org/get-started/locally/) (or later) and torchvision. Please change the CUDA version based on your requirements.

```bash
conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
pip install git+https://github.com/openai/CLIP.git
pip install scikit-learn
```
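As a quick sanity check that CLIP installed correctly (not part of the original instructions), the sketch below loads a CLIP model and embeds a few text queries; the `ViT-B/32` model name is an assumption, chosen because it is the common default.

```python
# Quick check that CLIP can be loaded and can embed text (illustrative only).
import torch
import clip  # installed above from the OpenAI repo

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)  # model choice is an assumption

with torch.no_grad():
    tokens = clip.tokenize(["a chair", "a limo", "a jet plane"]).to(device)
    text_features = model.encode_text(tokens)

print(text_features.shape)  # e.g. torch.Size([3, 512]) for ViT-B/32
```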
Choose a folder to download the data, classifier and model:
```
wget https://clip-forge-pretrained.s3.us-west-2.amazonaws.com/exps.zip
unzip exps.zip
```

## Training

For training, you first need to set up the dataset. We use the data prepared by Occupancy Networks (https://github.com/autonomousvision/occupancy_networks).
```
## Stage 1
python train_autoencoder.py --dataset_path /path/to/dataset/
## Stage 2
python train_post_clip.py --dataset_path /path/to/dataset/ --checkpoint best_iou --num_views 1 --text_query "a chair" "a limo" "a jet plane"
```
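For orientation, here is a heavily simplified sketch of what the two stages amount to. It uses toy modules and random tensors only, and replaces the conditional normalizing-flow prior from the paper with a plain MLP purely to show the data flow; it is not the repo's code.

```python
# Toy two-stage sketch: stage 1 learns a shape latent space without labels,
# stage 2 learns to map CLIP embeddings into that latent space.
import torch
import torch.nn as nn
import torch.nn.functional as F

latent_dim, clip_dim = 128, 512

# Stage 1: voxel autoencoder (stand-in modules, random voxels).
encoder = nn.Sequential(nn.Flatten(), nn.Linear(32 ** 3, latent_dim))
decoder = nn.Sequential(nn.Linear(latent_dim, 32 ** 3), nn.Sigmoid())
voxels = torch.rand(8, 32, 32, 32)                      # stand-in for ShapeNet voxels
recon = decoder(encoder(voxels)).view_as(voxels)
stage1_loss = F.binary_cross_entropy(recon, (voxels > 0.5).float())

# Stage 2: learn a prior from CLIP *image* embeddings to shape latents.
# At inference time, the CLIP *text* embedding of a query is fed in instead,
# which is what makes the method zero-shot.
prior = nn.Sequential(nn.Linear(clip_dim, 256), nn.ReLU(), nn.Linear(256, latent_dim))
clip_image_emb = torch.randn(8, clip_dim)               # stand-in for CLIP image features
with torch.no_grad():
    target_latent = encoder(voxels)
stage2_loss = F.mse_loss(prior(clip_image_emb), target_latent)

print(stage1_loss.item(), stage2_loss.item())
```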
## Inference

To generate shape renderings based on a text query:
```
python test_post_clip.py --checkpoint_dir_base "./exps/models/autoencoder" --checkpoint best_iou --checkpoint_nf best --experiment_mode save_voxel_on_query --checkpoint_dir_prior "./exps/models/prior" --text_query "a truck" "a round chair" "a limo" --threshold 0.1 --output_dir "./exps/hello_world"
```

The image renderings of the shapes will be saved in `output_dir`.
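If you want to inspect a generated voxel grid directly rather than the saved renderings, a minimal sketch follows. It assumes the voxels were written out as a NumPy array; the file name is only a placeholder for whatever `test_post_clip.py` wrote into your `--output_dir`.

```python
# Minimal sketch for viewing a saved voxel grid with matplotlib (file name is a placeholder).
import numpy as np
import matplotlib.pyplot as plt

voxels = np.load("./exps/hello_world/a_truck.npy")  # placeholder path; adjust to your output

fig = plt.figure()
ax = fig.add_subplot(projection="3d")
ax.voxels(voxels > 0.5)  # binarize, analogous to the --threshold argument
plt.savefig("voxel_preview.png")
```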
To calculate accuracy, please make sure you have the classifier model.
```
python test_post_clip.py --checkpoint_dir_base "./exps/models/autoencoder/" --checkpoint best_iou --checkpoint_nf best --experiment_mode cls_cal_category --checkpoint_dir_prior "./exps/models/prior/" --threshold 0.05 --classifier_checkpoint "./exps/classifier/checkpoints/best.pt"
```
To calculate FID, please make sure you have the classifier model and the data loaded.
```
python test_post_clip.py --checkpoint_dir_base "./exps/models/autoencoder/" --checkpoint best_iou --checkpoint_nf best --experiment_mode fid_cal --dataset_path /path/to/dataset/ --checkpoint_dir_prior "./exps/models/prior/" --threshold 0.05 --classifier_checkpoint "./exps/classifier/checkpoints/best.pt"
```
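For reference, FID is computed from classifier activations of real and generated shapes (see `get_activations` in the classifier module added by this commit). Below is a minimal sketch of the standard Fréchet distance computation, not the repo's exact `fid_cal` implementation.

```python
# Standard Frechet distance between two sets of activations (illustrative only).
import numpy as np
from scipy import linalg

def frechet_distance(act_real, act_fake):
    mu1, sigma1 = act_real.mean(axis=0), np.cov(act_real, rowvar=False)
    mu2, sigma2 = act_fake.mean(axis=0), np.cov(act_fake, rowvar=False)
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):  # numerical noise can leave tiny imaginary parts
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1 + sigma2 - 2.0 * covmean)

# act_real, act_fake: (N, 512) arrays, e.g. from get_activations on real / generated voxels
```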
## Inference Tips

To get the best results, try different threshold values, controlled by the `threshold` argument, as shown in Figure 10 of the paper. We also recommend using word synonyms and text augmentation. As the network is trained on ShapeNet, we recommend limiting queries to the 13 categories present in ShapeNet. Note that we believe this method scales with data, but unfortunately public 3D data is limited.
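One lightweight way to do the text augmentation mentioned above (an illustration, not the repo's code) is to embed several phrasings of the same query with CLIP and average the normalized embeddings; the template list below is an assumption.

```python
# Illustrative prompt augmentation: average CLIP embeddings of several phrasings.
import torch
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("ViT-B/32", device=device)  # model choice is an assumption

templates = ["a {}", "a photo of a {}", "a 3D model of a {}"]  # assumed templates
prompts = [t.format("round chair") for t in templates]

with torch.no_grad():
    emb = model.encode_text(clip.tokenize(prompts).to(device))
    emb = emb / emb.norm(dim=-1, keepdim=True)   # normalize each embedding
    query_emb = emb.mean(dim=0)                  # averaged text embedding
    query_emb = query_emb / query_emb.norm()
```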
## Releasing Soon

- [ ] Pointcloud code
- [ ] Pretrained models for pointcloud experiments

## Other interesting ideas

- ClipMatrix (https://arxiv.org/pdf/2109.12922.pdf)
- Text2Mesh (https://threedle.github.io/text2mesh/)
- DreamFields (https://arxiv.org/pdf/2112.01455.pdf)
- https://arxiv.org/pdf/2203.13333.pdf
@@ -0,0 +1,90 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from tqdm import tqdm


class VoxelEncoderBN(nn.Module):
    """Encodes a (B, 32, 32, 32) voxel grid into a c_dim feature vector."""

    def __init__(self, dim=3, c_dim=128):
        super().__init__()
        self.actvn = F.relu
        self.conv_in = nn.Conv3d(1, 32, 3, padding=1)
        self.conv_0 = nn.Conv3d(32, 64, 3, padding=1, stride=2)
        self.conv_1 = nn.Conv3d(64, 128, 3, padding=1, stride=2)
        self.conv_2 = nn.Conv3d(128, 256, 3, padding=1, stride=2)
        self.conv_3 = nn.Conv3d(256, 512, 3, padding=1, stride=2)
        self.fc = nn.Linear(512 * 2 * 2 * 2, c_dim)

        self.conv0_bn = nn.BatchNorm3d(32)
        self.conv1_bn = nn.BatchNorm3d(64)
        self.conv2_bn = nn.BatchNorm3d(128)
        self.conv3_bn = nn.BatchNorm3d(256)

    def forward(self, x):
        batch_size = x.size(0)
        x = x.unsqueeze(1)  # add a channel dimension: (B, 1, 32, 32, 32)
        net = self.conv_in(x)
        net = self.conv_0(self.actvn(self.conv0_bn(net)))
        net = self.conv_1(self.actvn(self.conv1_bn(net)))
        net = self.conv_2(self.actvn(self.conv2_bn(net)))
        net = self.conv_3(self.actvn(self.conv3_bn(net)))
        hidden = net.view(batch_size, 512 * 2 * 2 * 2)  # 2x2x2 spatial grid after four stride-2 convs
        x = self.fc(self.actvn(hidden))
        return x
class classifier_32(nn.Module):
    """Voxel classifier: VoxelEncoderBN backbone plus an MLP projection head.

    forward() returns (class logits, 512-d encoder embedding).
    """

    def __init__(self, encoder_type, num_classes, dropout=0.5):
        super(classifier_32, self).__init__()

        self.encoder_head = VoxelEncoderBN(c_dim=512)

        self.projection = nn.Sequential(
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        z = self.encoder_head(x)
        x = self.projection(z)
        return x, z
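A quick shape check for these modules (illustrative only; the batch size is arbitrary, and 13 matches the ShapeNet category count mentioned in the README):

```python
# Illustrative shape check for classifier_32.
import torch

model = classifier_32(encoder_type=None, num_classes=13).eval()
voxels = torch.rand(4, 32, 32, 32)           # a batch of 32^3 voxel grids
with torch.no_grad():
    logits, embedding = model(voxels)
print(logits.shape, embedding.shape)         # torch.Size([4, 13]) torch.Size([4, 512])
```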
def get_activations(datapoints, model, args):
    """Run the classifier over voxel datapoints; return (embeddings, predicted labels)."""
    model.eval()

    dataset = torch.utils.data.TensorDataset(torch.from_numpy(datapoints).squeeze())
    loader = torch.utils.data.DataLoader(dataset, batch_size=32)

    all_activation = []
    all_labels = []
    softmax = nn.Softmax(dim=-1).to(args.device)

    with torch.no_grad():
        for data in tqdm(loader):
            try:
                data_mod = data[0].type(torch.FloatTensor).to(args.device)
                out, embeddings = model(data_mod)
                pred_label = softmax(out)
                _, pred_label = torch.max(pred_label, dim=-1)
            except Exception:
                print("Some error happened")
                print(data[0])
                raise
            all_activation.append(embeddings.detach().cpu().numpy())
            all_labels.append(pred_label.detach().cpu().numpy())

    all_activation = np.concatenate(all_activation)
    all_labels = np.concatenate(all_labels)
    return all_activation, all_labels
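For example, a sketch of calling `get_activations` with random voxels and a minimal args object (only `args.device` is used by the function):

```python
# Illustrative call to get_activations with random data.
import argparse
import numpy as np

args = argparse.Namespace(device="cpu")                # only .device is needed here
model = classifier_32(encoder_type=None, num_classes=13).to(args.device)
fake_voxels = np.random.rand(8, 32, 32, 32).astype(np.float32)

activations, labels = get_activations(fake_voxels, model, args)
print(activations.shape, labels.shape)                 # (8, 512) (8,)
```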