Create scene graph task
JackCai1206 committed Feb 5, 2023
1 parent 0da2302 commit 6170222
Showing 12 changed files with 1,219 additions and 432 deletions.
136 changes: 136 additions & 0 deletions configs/config_scene_graph.py
@@ -0,0 +1,136 @@
# coding=utf-8
# Copyright 2022 The Pix2Seq Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Config file for object detection fine-tuning and evaluation."""

import copy

from configs import dataset_configs
from configs.config_base import architecture_config_map
from configs.config_base import D

# pylint: disable=invalid-name,line-too-long,missing-docstring


def get_config(config_str=None):
  """config_str is either empty or contains task,architecture variants."""

  task_variant = 'scene_graph_generation@visual_genome'
  encoder_variant = 'vit-b'  # Set model architecture.
  image_size = (640, 640)  # Set image size.

  tasks_and_datasets = []
  for task_and_ds in task_variant.split('+'):
    tasks_and_datasets.append(task_and_ds.split('@'))

  task_config_map = {
      'scene_graph_generation': D(
          name='scene_graph_generation',
          vocab_id=10,
          image_size=image_size,
          quantization_bins=1000,
          max_instances_per_image=100,
          max_instances_per_image_test=100,
          object_order='random',
          color_jitter_strength=0.,
          jitter_scale_min=0.3,
          jitter_scale_max=2.0,
          # Train on both ground-truth and (augmented) noisy objects.
          noise_bbox_weight=1.0,
          eos_token_weight=0.1,
          # Train on just ground-truth objects (with an ending token).
          # noise_bbox_weight=0.0,
          # eos_token_weight=0.1,
          class_label_corruption='rand_n_fake_cls',
          top_k=0,
          top_p=0.4,
          temperature=1.0,
          weight=1.0,
      ),
  }

  task_d_list = []
  dataset_list = []
  for tv, ds_name in tasks_and_datasets:
    task_d_list.append(task_config_map[tv])
    dataset_config = copy.deepcopy(dataset_configs.dataset_configs[ds_name])
    dataset_list.append(dataset_config)

  config = D(
      dataset=dataset_list[0],
      datasets=dataset_list,

      task=task_d_list[0],
      tasks=task_d_list,

      model=D(
          name='encoder_ar_decoder',
          image_size=image_size,
          max_seq_len=512,
          vocab_size=3000,  # Note: should be large enough for 100 + num_classes + quantization_bins + (optional) text.
          coord_vocab_shift=1000,  # Note: make sure num_classes <= coord_vocab_shift - 100.
          text_vocab_shift=3000,  # Note: make sure coord_vocab_shift + quantization_bins <= text_vocab_shift.
          use_cls_token=False,
          shared_decoder_embedding=True,
          decoder_output_bias=True,
          patch_size=16,
          drop_path=0.1,
          drop_units=0.1,
          drop_att=0.0,
          dec_proj_mode='mlp',
          pos_encoding='sin_cos',
          pos_encoding_dec='learned',
          # pretrained_ckpt=get_obj365_pretrained_checkpoint(encoder_variant),
      ),

      optimization=D(
          optimizer='adamw',
          learning_rate=3e-5,
          end_lr_factor=0.01,
          warmup_epochs=2,
          warmup_steps=0,  # Set to >0 to override warmup_epochs.
          weight_decay=0.05,
          global_clipnorm=-1,
          beta1=0.9,
          beta2=0.95,
          eps=1e-8,
          learning_rate_schedule='linear',
          learning_rate_scaling='none',
      ),

      train=D(
          batch_size=32,
          epochs=40,
          steps=0,  # Set to >0 to override epochs.
          checkpoint_epochs=1,
          checkpoint_steps=0,  # Set to >0 to override checkpoint_epochs.
          keep_checkpoint_max=5,
          loss_type='xent',
      ),

      eval=D(
          tag='eval',
          checkpoint_dir='',  # checkpoint_dir will be model_dir if not set.
          # checkpoint_dir=get_coco_finetuned_checkpoint(encoder_variant, image_size[0]),
          batch_size=8,  # Needs to be divisible by total eval examples.
          steps=0,  # 0 means eval over full validation set.
      ),
  )

  # Update model with architecture variant.
  for key, value in architecture_config_map[encoder_variant].items():
    config.model[key] = value

  return config
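The three vocabulary Notes above fix the token layout. As an illustration only (derived from the values in this config, not code from the repo), the shared vocabulary is partitioned and a pixel coordinate maps to a token like this:

# Token-id layout implied by the config above (illustration only):
#   [0, 100): special tokens; [100, coord_vocab_shift): class labels;
#   [coord_vocab_shift, coord_vocab_shift + quantization_bins): coordinates;
#   [text_vocab_shift, ...): optional text tokens.
quantization_bins = 1000
coord_vocab_shift = 1000

def coord_token(x, image_dim):
  # Uniformly quantize a pixel coordinate, then shift into the coord range.
  bin_id = min(int(x / image_dim * quantization_bins), quantization_bins - 1)
  return coord_vocab_shift + bin_id

print(coord_token(320, 640))  # 1500: the bin for the image midpoint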
19 changes: 19 additions & 0 deletions configs/dataset_configs.py
@@ -28,6 +28,9 @@
COCO_TRAIN_TFRECORD_PATTERN = 'gs://pix2seq/multi_task/data/coco/tfrecord/train*'
COCO_VAL_TFRECORD_PATTERN = '/data/hulab/zcai75/pix2seq/coco/val*'

VG_TRAIN_TFRECORD_PATTERN = '/data/hulab/zcai75/pix2seq/vg/train*'
VG_VAL_TFRECORD_PATTERN = '/data/hulab/zcai75/pix2seq/vg/val*'

# Download from gs://pix2seq/multi_task/data/coco/json
COCO_ANNOTATIONS_DIR = '/data/hulab/zcai75/coco/annotations/'

@@ -47,6 +50,16 @@
    **_shared_dataset_config
)

vg_dataset_config = D(
    train_file_pattern=VG_TRAIN_TFRECORD_PATTERN,
    val_file_pattern=VG_VAL_TFRECORD_PATTERN,
    # Note: these example counts are COCO 2017's (118287/5000); update them
    # to the actual Visual Genome split sizes.
    train_num_examples=118287,
    eval_num_examples=5000,
    train_split='train',
    eval_split='validation',
    **_shared_dataset_config
)

dataset_configs = {
    'coco/2017_object_detection':
        D(
@@ -83,4 +96,10 @@
            train_filename_for_metrics='captions_train2017_eval_compatible.json',
            val_filename_for_metrics='captions_val2017_eval_compatible.json',
            **_shared_coco_dataset_config),

    'visual_genome':
        D(
            name='visual_genome',
            **vg_dataset_config
        ),
}
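With this entry registered, get_config's 'scene_graph_generation@visual_genome' variant resolves the dataset by the name after the '@'. A minimal sketch of that lookup (assuming D supports attribute access, as the other configs rely on):

from configs import dataset_configs

task_name, ds_name = 'scene_graph_generation@visual_genome'.split('@')
ds = dataset_configs.dataset_configs[ds_name]
print(ds.name, ds.train_file_pattern)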
1 change: 1 addition & 0 deletions data/datasets.py
@@ -16,3 +16,4 @@
"""All registered datasets."""

from data import coco # pylint: disable=unused-import
from data import visual_genome  # pylint: disable=unused-import
102 changes: 102 additions & 0 deletions data/scripts/create_vg_tfrecord.py
@@ -0,0 +1,102 @@
import json
import os

from absl import app
from absl import flags
from absl import logging
import h5py
import numpy as np
from tqdm import tqdm

from data.scripts import tfrecord_lib
import tensorflow as tf

flags.DEFINE_string('vg_image_dir', '', 'Directory containing images.')
flags.DEFINE_string('vg_ann_file', '', 'Instance annotation file.')
flags.DEFINE_string('image_data_file', '', 'Image data file.')
flags.DEFINE_string('vg_ann_label_file', '', 'Json containing label information.')
flags.DEFINE_bool('train', False, 'Is train split.')
flags.DEFINE_string('output_dir', '', 'Output directory.')
flags.DEFINE_integer('num_shards', 32, 'Number of shards for output file.')
FLAGS = flags.FLAGS


def create_anno_iter(img_data, label_data, ann):
  """Yields (feature_dict, skipped) pairs, one per image."""
  for j, img in enumerate(tqdm(img_data)):
    # Images outside the requested split, and images without relationships,
    # are yielded as (None, True); they become empty records that the
    # dataset reader filters out at load time.
    split = 0 if FLAGS.train else 2
    if ann['split'][j] != split:
      yield None, True
      continue

    first_rel = ann['img_to_first_rel'][j]
    last_rel = ann['img_to_last_rel'][j]
    img_rels = ann['relationships'][first_rel:last_rel + 1]
    if len(img_rels) == 0:
      yield None, True
      continue

    with open(os.path.join(FLAGS.vg_image_dir, f"{img['image_id']}.jpg"), 'rb') as fid:
      encoded_jpg = fid.read()
    feature_dict = tfrecord_lib.image_info_to_feature_dict(
        img['height'], img['width'], f"{img['image_id']}.jpg",
        img['image_id'], encoded_jpg, 'jpg')

    box1_ids = img_rels[:, 0]
    box2_ids = img_rels[:, 1]
    pred_ids = ann['predicates'][first_rel:last_rel + 1][:, 0].tolist()
    # Boxes are flattened to [num_rels * 4] so they fit in an int64 list;
    # the reader reshapes them back to [num_rels, 4].
    box1 = np.concatenate([ann['boxes_1024'][i] for i in box1_ids]).tolist()
    box2 = np.concatenate([ann['boxes_1024'][i] for i in box2_ids]).tolist()
    pred_label = [label_data['idx_to_predicate'][str(i)].encode('utf-8') for i in pred_ids]
    box1_label = [label_data['idx_to_label'][str(ann['labels'][i][0])].encode('utf-8') for i in box1_ids]
    box2_label = [label_data['idx_to_label'][str(ann['labels'][i][0])].encode('utf-8') for i in box2_ids]

    feature_dict.update({
        'box1': tfrecord_lib.convert_to_feature(box1, 'int64_list'),
        'pred': tfrecord_lib.convert_to_feature(pred_ids, 'int64_list'),
        'box2': tfrecord_lib.convert_to_feature(box2, 'int64_list'),
        'pred_label': tfrecord_lib.convert_to_feature(pred_label, 'bytes_list'),
        'box1_label': tfrecord_lib.convert_to_feature(box1_label, 'bytes_list'),
        'box2_label': tfrecord_lib.convert_to_feature(box2_label, 'bytes_list'),
    })

    yield feature_dict, False


def create_example(feature_dict, skipped):
  example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
  return example, 1 if skipped else 0


def main(_):
  logging.info('Building instance index.')

  directory = os.path.dirname(FLAGS.output_dir)
  os.makedirs(directory, exist_ok=True)

  ann = h5py.File(FLAGS.vg_ann_file, 'r')
  with open(FLAGS.image_data_file, 'r') as img_data_f, \
       open(FLAGS.vg_ann_label_file, 'r') as label_f:
    img_data = json.load(img_data_f)
    label_data = json.load(label_f)

  anno_iter = create_anno_iter(img_data, label_data, ann)

  tfrecord_lib.write_tf_record_dataset(
      output_path=FLAGS.output_dir,
      annotation_iterator=anno_iter,
      process_func=create_example,
      num_shards=FLAGS.num_shards,
      multiple_processes=8)


if __name__ == '__main__':
  app.run(main)
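For reference, a minimal sketch of the VG-SGG h5 indexing that create_anno_iter relies on, printing one (subject, predicate, object) triplet per relationship. Field names come from the script above; the paths are this setup's:

import json
import h5py

ann = h5py.File('/data/hulab/zcai75/visual_genome/VG-SGG-with-attri.h5', 'r')
with open('/data/hulab/zcai75/visual_genome/vg_motif_anno/VG-SGG-dicts-with-attri.json') as f:
  label_data = json.load(f)

j = 0  # image index; assumes this image has at least one relationship
first = ann['img_to_first_rel'][j]
last = ann['img_to_last_rel'][j]
for k in range(first, last + 1):
  sub_id, obj_id = ann['relationships'][k]
  print(label_data['idx_to_label'][str(ann['labels'][sub_id][0])],
        label_data['idx_to_predicate'][str(ann['predicates'][k][0])],
        label_data['idx_to_label'][str(ann['labels'][obj_id][0])])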
39 changes: 39 additions & 0 deletions data/visual_genome.py
@@ -0,0 +1,39 @@
import utils
from data import dataset as dataset_lib
from data import decode_utils
import tensorflow as tf


@dataset_lib.DatasetRegistry.register('visual_genome')
class VisualGenomeTFRecordDataset(dataset_lib.TFRecordDataset):
  """Visual Genome scene-graph relationships dataset."""

  def get_feature_map(self):
    image_feature_map = decode_utils.get_feature_map_for_image()
    # dtypes must mirror what create_vg_tfrecord.py writes: int64 boxes
    # (flattened [num_rels * 4], in 1024-scale coordinates) and predicate
    # ids, byte-string label names.
    vg_feature_map = {
        'box1': tf.io.VarLenFeature(tf.int64),
        'pred': tf.io.VarLenFeature(tf.int64),
        'box2': tf.io.VarLenFeature(tf.int64),
        'pred_label': tf.io.VarLenFeature(tf.string),
        'box1_label': tf.io.VarLenFeature(tf.string),
        'box2_label': tf.io.VarLenFeature(tf.string),
    }
    return {**image_feature_map, **vg_feature_map}

  def filter_example(self, example, training):
    # Drop the empty records written for images skipped during conversion.
    if training:
      return tf.shape(example['box1'])[0] > 0
    else:
      return True

  def extract(self, example, training):
    features = {
        'image': decode_utils.decode_image(example),
        'image/id': tf.strings.to_number(example['image/source_id'], tf.int64),
    }

    # Normalize box coordinates to [0, 1]. Note: boxes_1024 are stored in
    # 1024-scale coordinates, so the divisor should arguably be the
    # 1024-scaled image size rather than the raw one; revisit if needed.
    scale = 1. / utils.tf_float32(tf.shape(features['image'])[:2])
    box1 = utils.scale_points(tf.cast(example['box1'], tf.float32), scale)
    box2 = utils.scale_points(tf.cast(example['box2'], tf.float32), scale)

    labels = {k: v for k, v in example.items() if 'image' not in k}
    labels.update({'box1': box1, 'box2': box2})

    return features, labels
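A quick way to inspect one converted record against this feature map (a sketch; assumes shards exist at the path from dataset_configs.py):

import tensorflow as tf

feature_map = {
    'box1': tf.io.VarLenFeature(tf.int64),
    'pred': tf.io.VarLenFeature(tf.int64),
    'box2': tf.io.VarLenFeature(tf.int64),
    'pred_label': tf.io.VarLenFeature(tf.string),
    'box1_label': tf.io.VarLenFeature(tf.string),
    'box2_label': tf.io.VarLenFeature(tf.string),
}

files = tf.io.gfile.glob('/data/hulab/zcai75/pix2seq/vg/val*')
for raw in tf.data.TFRecordDataset(files).take(1):
  ex = tf.io.parse_single_example(raw, feature_map)
  # VarLen features parse as SparseTensors; densify before inspecting.
  print(tf.sparse.to_dense(ex['pred_label']).numpy())
  print(tf.reshape(tf.sparse.to_dense(ex['box1']), [-1, 4]).numpy())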
1 change: 1 addition & 0 deletions run.py
@@ -37,6 +37,7 @@
from tasks import instance_segmentation
from tasks import keypoint_detection
from tasks import object_detection
from tasks import scene_graph_generation
# pylint: enable=unused-import
from tasks import task as task_lib
import tensorflow as tf
File renamed without changes.
10 changes: 10 additions & 0 deletions run_scripts/create_vg_tfrecord.sh
@@ -0,0 +1,10 @@
export PYTHONPATH=$PYTHONPATH:~/Github/pix2seq
split=train

# Note: absl boolean flags need the --flag=value form; a bare `--train False`
# is parsed as --train (True) plus a stray positional argument. Tie the flag
# to $split so the two stay consistent.
python ../data/scripts/create_vg_tfrecord.py \
  --vg_image_dir=/data/hulab/zcai75/visual_genome/VG_100K \
  --vg_ann_file=/data/hulab/zcai75/visual_genome/VG-SGG-with-attri.h5 \
  --image_data_file=/data/hulab/zcai75/visual_genome/image_data.json \
  --vg_ann_label_file=/data/hulab/zcai75/visual_genome/vg_motif_anno/VG-SGG-dicts-with-attri.json \
  --train=$([ "$split" = "train" ] && echo True || echo False) \
  --output_dir=/data/hulab/zcai75/pix2seq/vg/${split}
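A sanity check after conversion (a sketch; the count includes the empty records written for skipped images):

import tensorflow as tf

files = tf.io.gfile.glob('/data/hulab/zcai75/pix2seq/vg/train*')
num_records = sum(1 for _ in tf.data.TFRecordDataset(files))
print(f'{len(files)} shards, {num_records} records')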
14 changes: 7 additions & 7 deletions run_scripts/eval.sh
@@ -1,9 +1,9 @@
-python ../run.py \
-    --model_dir /data/hulab/zcai75/checkpoints/pix2seq/coco_det_finetune \
-    --mode eval
-
-config=configs/config_multi_task.py:object_detection@coco/2017_object_detection,vit-b
-model_dir=/tmp/pix2seq_eval_det
+config=../configs/config_multi_task.py:object_detection@coco/2017_object_detection,vit-b
+model_dir=/data/hulab/zcai75/checkpoints/pix2seq/multi_task/vit_b_640x640
 # Path to save the detected boxes for evaluating other tasks.
 boxes_json_path=$model_dir/boxes.json
-python3 run.py --config=$config --model_dir=$model_dir --mode=eval --config.task.eval_outputs_json_path=$boxes_json_path
+python ../run.py \
+    --config=$config \
+    --model_dir=$model_dir \
+    --mode=eval \
+    --config.task.eval_outputs_json_path=$boxes_json_path
10 changes: 10 additions & 0 deletions run_scripts/train.sh
@@ -0,0 +1,10 @@
config=../configs/config_scene_graph.py
model_dir=/data/hulab/zcai75/checkpoints/pix2seq/scene_graph

python ../run.py \
--mode=train \
--model_dir=$model_dir \
--config=$config \
--config.train.batch_size=32 \
--config.train.epochs=20 \
--config.optimization.learning_rate=3e-5
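The dotted --config.… overrides work because run.py defines its config via ml_collections config_flags; a minimal sketch of that pattern (a hypothetical standalone script, not the repo's actual run.py):

from absl import app
from absl import flags
from ml_collections import config_flags

# --config=<path.py>[:<args>] loads get_config(); dotted flags such as
# --config.train.batch_size=32 override individual fields.
_CONFIG = config_flags.DEFINE_config_file('config')
_MODEL_DIR = flags.DEFINE_string('model_dir', '', 'Checkpoint directory.')

def main(_):
  config = _CONFIG.value
  print(_MODEL_DIR.value, config.train.batch_size, config.optimization.learning_rate)

if __name__ == '__main__':
  app.run(main)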