forked from google-research/pix2seq
Commit 6170222 (1 parent: 0da2302)
Showing 12 changed files with 1,219 additions and 432 deletions.
configs/config_scene_graph.py
@@ -0,0 +1,136 @@
# coding=utf-8
# Copyright 2022 The Pix2Seq Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Config file for object detection fine-tuning and evaluation.""" | ||

import copy

from configs import dataset_configs
from configs.config_base import architecture_config_map
from configs.config_base import D

# pylint: disable=invalid-name,line-too-long,missing-docstring


def get_config(config_str=None):
  """config_str is either empty or contains task,architecture variants."""

  task_variant = 'scene_graph_generation@visual_genome'
  encoder_variant = 'vit-b'  # Set model architecture.
  image_size = (640, 640)  # Set image size.

  tasks_and_datasets = []
  for task_and_ds in task_variant.split('+'):
    tasks_and_datasets.append(task_and_ds.split('@'))

  task_config_map = {
      'scene_graph_generation': D(
          name='scene_graph_generation',
          vocab_id=10,
          image_size=image_size,
          quantization_bins=1000,
          max_instances_per_image=100,
          max_instances_per_image_test=100,
          object_order='random',
          color_jitter_strength=0.,
          jitter_scale_min=0.3,
          jitter_scale_max=2.0,
          # Train on both ground-truth and (augmented) noisy objects.
          noise_bbox_weight=1.0,
          eos_token_weight=0.1,
          # Train on just ground-truth objects (with an ending token):
          # noise_bbox_weight=0.0,
          # eos_token_weight=0.1,
          class_label_corruption='rand_n_fake_cls',
          top_k=0,
          top_p=0.4,
          temperature=1.0,
          weight=1.0,
      ),
  }

  task_d_list = []
  dataset_list = []
  for tv, ds_name in tasks_and_datasets:
    task_d_list.append(task_config_map[tv])
    dataset_config = copy.deepcopy(dataset_configs.dataset_configs[ds_name])
    dataset_list.append(dataset_config)

  config = D(
      dataset=dataset_list[0],
      datasets=dataset_list,

      task=task_d_list[0],
      tasks=task_d_list,

      model=D(
          name='encoder_ar_decoder',
          image_size=image_size,
          max_seq_len=512,
          vocab_size=3000,  # Note: should be large enough for 100 + num_classes + quantization_bins + (optional) text.
          coord_vocab_shift=1000,  # Note: make sure num_classes <= coord_vocab_shift - 100.
          text_vocab_shift=3000,  # Note: make sure coord_vocab_shift + quantization_bins <= text_vocab_shift.
          use_cls_token=False,
          shared_decoder_embedding=True,
          decoder_output_bias=True,
          patch_size=16,
          drop_path=0.1,
          drop_units=0.1,
          drop_att=0.0,
          dec_proj_mode='mlp',
          pos_encoding='sin_cos',
          pos_encoding_dec='learned',
          # pretrained_ckpt=get_obj365_pretrained_checkpoint(encoder_variant),
      ),

      optimization=D(
          optimizer='adamw',
          learning_rate=3e-5,
          end_lr_factor=0.01,
          warmup_epochs=2,
          warmup_steps=0,  # Set to >0 to override warmup_epochs.
          weight_decay=0.05,
          global_clipnorm=-1,
          beta1=0.9,
          beta2=0.95,
          eps=1e-8,
          learning_rate_schedule='linear',
          learning_rate_scaling='none',
      ),

      train=D(
          batch_size=32,
          epochs=40,
          steps=0,  # Set to >0 to override epochs.
          checkpoint_epochs=1,
          checkpoint_steps=0,  # Set to >0 to override checkpoint_epochs.
          keep_checkpoint_max=5,
          loss_type='xent',
      ),

      eval=D(
          tag='eval',
          checkpoint_dir='',  # checkpoint_dir will be model_dir if not set.
          # checkpoint_dir=get_coco_finetuned_checkpoint(encoder_variant, image_size[0]),
          batch_size=8,  # Needs to be divisible by total eval examples.
          steps=0,  # 0 means eval over full validation set.
      ),
  )

  # Update model with architecture variant.
  for key, value in architecture_config_map[encoder_variant].items():
    config.model[key] = value

  return config
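A quick illustrative sketch (not from the commit; the class count is an assumed example) of how the task variant string parses and of the vocab-layout constraints stated in the Notes above:

# Illustration only: variant-string parsing and vocab-layout checks.
variant = 'scene_graph_generation@visual_genome'
pairs = [v.split('@') for v in variant.split('+')]
assert pairs == [['scene_graph_generation', 'visual_genome']]

num_classes = 150  # assumed example count (e.g. VG-150); not set in this config
coord_vocab_shift, quantization_bins, text_vocab_shift = 1000, 1000, 3000
assert num_classes <= coord_vocab_shift - 100
assert coord_vocab_shift + quantization_bins <= text_vocab_shift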
data/scripts/create_vg_tfrecord.py
@@ -0,0 +1,102 @@
import json
import os

from absl import app
from absl import flags
from absl import logging
import h5py
import tensorflow as tf
from tqdm import tqdm

from data.scripts import tfrecord_lib

flags.DEFINE_string('vg_image_dir', '', 'Directory containing images.')
flags.DEFINE_string('vg_ann_file', '', 'Instance annotation file.')
flags.DEFINE_string('image_data_file', '', 'Image data file.')
flags.DEFINE_string('vg_ann_label_file', '', 'Json containing label information.')
flags.DEFINE_bool('train', False, 'Is train split.')
flags.DEFINE_string('output_dir', '', 'Output directory.')
flags.DEFINE_integer('num_shards', 32, 'Number of shards for output file.')
FLAGS = flags.FLAGS


def create_anno_iter(img_data, label_data, ann):
  """Yields (feature_dict, skipped) pairs, one per Visual Genome image."""
  for j, img in enumerate(tqdm(img_data)):
    with open(os.path.join(FLAGS.vg_image_dir, f"{img['image_id']}.jpg"), 'rb') as fid:
      encoded_jpg = fid.read()
    feature_dict = tfrecord_lib.image_info_to_feature_dict(
        img['height'], img['width'], f"{img['image_id']}.JPG",
        img['image_id'], encoded_jpg, 'jpg')

    # if img['image_id'] > 108072:
    #   yield None, True
    #   continue

    # Keep only the requested split (0 = train, 2 = test in VG-SGG-with-attri.h5).
    split = 0 if FLAGS.train else 2
    if ann['split'][j] != split:
      yield None, True
      continue

    # Skip images that have no annotated relationships.
    first_rel = ann['img_to_first_rel'][j]
    last_rel = ann['img_to_last_rel'][j]
    img_rels = ann['relationships'][first_rel:last_rel + 1]
    if len(img_rels) == 0:
      yield None, True
      continue

    # Subject/object box indices and predicate ids for each relationship.
    box1_ids = img_rels[:, 0]
    box2_ids = img_rels[:, 1]
    pred_ids = ann['predicates'][first_rel:last_rel + 1]
    box1 = [ann['boxes_1024'][i] for i in box1_ids]
    box2 = [ann['boxes_1024'][i] for i in box2_ids]
    pred_label = [label_data['idx_to_predicate'][str(i[0])].encode('utf-8') for i in pred_ids]
    box1_label = [label_data['idx_to_label'][str(ann['labels'][i][0])].encode('utf-8') for i in box1_ids]
    box2_label = [label_data['idx_to_label'][str(ann['labels'][i][0])].encode('utf-8') for i in box2_ids]

    feature_dict.update({
        'box1': tfrecord_lib.convert_to_feature(box1, 'int64_list'),
        'pred': tfrecord_lib.convert_to_feature(pred_ids, 'int64_list'),
        'box2': tfrecord_lib.convert_to_feature(box2, 'int64_list'),
        'pred_label': tfrecord_lib.convert_to_feature(pred_label, 'bytes_list'),
        'box1_label': tfrecord_lib.convert_to_feature(box1_label, 'bytes_list'),
        'box2_label': tfrecord_lib.convert_to_feature(box2_label, 'bytes_list'),
    })

    yield feature_dict, False


def create_example(feature_dict, skipped):
  example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
  return example, 1 if skipped else 0


def main(_):
  logging.info('Building instance index.')

  directory = os.path.dirname(FLAGS.output_dir)
  if not os.path.isdir(directory):
    os.makedirs(directory, exist_ok=True)

  ann = h5py.File(FLAGS.vg_ann_file, 'r')
  with open(FLAGS.image_data_file, 'r') as img_data_f, \
      open(FLAGS.vg_ann_label_file, 'r') as label_f:
    img_data = json.load(img_data_f)
    label_data = json.load(label_f)
  print(ann.keys())
  print(img_data[0].keys())
  print(label_data.keys())

  anno_iter = create_anno_iter(img_data, label_data, ann)

  tfrecord_lib.write_tf_record_dataset(
      output_path=FLAGS.output_dir,
      annotation_iterator=anno_iter,
      process_func=create_example,
      num_shards=FLAGS.num_shards,
      multiple_processes=8)


def run_main():
  app.run(main)


if __name__ == '__main__':
  run_main()
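A read-back check can confirm the schema this script writes. The sketch below is illustrative only; the shard glob and the 'image/source_id' key are assumptions about tfrecord_lib's output naming and image feature layout:

import tensorflow as tf

# Illustrative sanity check: parse one written example and print its fields.
feature_spec = {
    'image/source_id': tf.io.FixedLenFeature((), tf.string),  # assumed key
    'box1': tf.io.VarLenFeature(tf.int64),   # subject boxes (boxes_1024)
    'pred': tf.io.VarLenFeature(tf.int64),   # predicate ids
    'box2': tf.io.VarLenFeature(tf.int64),   # object boxes
    'pred_label': tf.io.VarLenFeature(tf.string),
    'box1_label': tf.io.VarLenFeature(tf.string),
    'box2_label': tf.io.VarLenFeature(tf.string),
}
files = tf.io.gfile.glob('/data/hulab/zcai75/pix2seq/vg/train*')  # assumed shard pattern
for raw in tf.data.TFRecordDataset(files).take(1):
  ex = tf.io.parse_single_example(raw, feature_spec)
  print(ex['image/source_id'].numpy(), ex['pred_label'].values.numpy())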
@@ -0,0 +1,39 @@
import utils
from data import dataset as dataset_lib
from data import decode_utils
import tensorflow as tf


@dataset_lib.DatasetRegistry.register('visual_genome')
class VisualGenomeTFRecordDataset(dataset_lib.TFRecordDataset):
  """TFRecord dataset for Visual Genome scene graphs."""

  def get_feature_map(self):
    image_feature_map = decode_utils.get_feature_map_for_image()
    # Dtypes mirror what data/scripts/create_vg_tfrecord.py writes: boxes and
    # predicate ids as int64 lists, labels as byte strings.
    vg_feature_map = {
        'box1': tf.io.VarLenFeature(tf.int64),
        'pred': tf.io.VarLenFeature(tf.int64),
        'box2': tf.io.VarLenFeature(tf.int64),
        'pred_label': tf.io.VarLenFeature(tf.string),
        'box1_label': tf.io.VarLenFeature(tf.string),
        'box2_label': tf.io.VarLenFeature(tf.string),
    }
    # Combine the image features with the scene-graph features.
    return {**image_feature_map, **vg_feature_map}

  def filter_example(self, example, training):
    # Drop training examples that carry no relationships.
    if training:
      return tf.shape(example['box1'])[0] > 0
    else:
      return True

  def extract(self, example, training):
    features = {
        'image': decode_utils.decode_image(example),
        'image/id': tf.strings.to_number(example['image/source_id'], tf.int64),
    }

    # Normalize box coordinates by the image size; scale_points takes the
    # points and a scale factor, so pass the computed scale explicitly.
    scale = 1. / utils.tf_float32(tf.shape(features['image'])[:2])
    box1 = utils.scale_points(utils.tf_float32(example['box1']), scale)
    box2 = utils.scale_points(utils.tf_float32(example['box2']), scale)

    labels = {k: v for k, v in example.items() if 'image' not in k}
    labels.update({'box1': box1, 'box2': box2})  # use the normalized boxes

    return features, labels
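For context, upstream pix2seq resolves datasets by name through this registry; a minimal sketch under that assumption (lookup API and constructor arguments as in upstream run.py):

# Minimal sketch, assuming the upstream DatasetRegistry.lookup API.
from data import dataset as dataset_lib

dataset_cls = dataset_lib.DatasetRegistry.lookup('visual_genome')
# dataset = dataset_cls(config)  # where config.dataset.name == 'visual_genome'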
File renamed without changes.
@@ -0,0 +1,10 @@
export PYTHONPATH=$PYTHONPATH:~/Github/pix2seq
split=train

# absl boolean flags must be passed as --train=True/--train=False; a
# space-separated "--train False" leaves the flag True and "False" as a stray
# argument. Derive the flag from $split so it agrees with the output directory.
if [ "$split" = "train" ]; then train_flag=True; else train_flag=False; fi

python ../data/scripts/create_vg_tfrecord.py \
  --vg_image_dir /data/hulab/zcai75/visual_genome/VG_100K \
  --vg_ann_file /data/hulab/zcai75/visual_genome/VG-SGG-with-attri.h5 \
  --image_data_file /data/hulab/zcai75/visual_genome/image_data.json \
  --vg_ann_label_file /data/hulab/zcai75/visual_genome/vg_motif_anno/VG-SGG-dicts-with-attri.json \
  --train=${train_flag} \
  --output_dir /data/hulab/zcai75/pix2seq/vg/${split}
@@ -1,9 +1,9 @@
-python ../run.py \
-    --model_dir /data/hulab/zcai75/checkpoints/pix2seq/coco_det_finetune \
-    --mode eval
-
-config=configs/config_multi_task.py:object_detection@coco/2017_object_detection,vit-b
-model_dir=/tmp/pix2seq_eval_det
+config=../configs/config_multi_task.py:object_detection@coco/2017_object_detection,vit-b
+model_dir=/data/hulab/zcai75/checkpoints/pix2seq/multi_task/vit_b_640x640
 # Path to save the detected boxes for evaluating other tasks.
 boxes_json_path=$model_dir/boxes.json
-python3 run.py --config=$config --model_dir=$model_dir --mode=eval --config.task.eval_outputs_json_path=$boxes_json_path
+python ../run.py \
+    --config=$config \
+    --model_dir=$model_dir \
+    --mode=eval \
+    --config.task.eval_outputs_json_path=$boxes_json_path
@@ -0,0 +1,10 @@
config=../configs/config_scene_graph.py
model_dir=/data/hulab/zcai75/checkpoints/pix2seq/scene_graph

python ../run.py \
  --mode=train \
  --model_dir=$model_dir \
  --config=$config \
  --config.train.batch_size=32 \
  --config.train.epochs=20 \
  --config.optimization.learning_rate=3e-5