
get training to run
JackCai1206 committed Feb 13, 2023
1 parent 14c5c06 commit e46ed13
Showing 9 changed files with 56 additions and 22 deletions.
2 changes: 2 additions & 0 deletions architectures/transformers.py
@@ -648,6 +648,8 @@ def call(self, images, training, ret_list=False):
tokens = self.stem_conv(images)
bsz, h, w, dim = get_shape(tokens)
tokens = self.stem_ln(tf.reshape(tokens, [bsz, h * w, dim]))
+ # tf.print('tokens', tokens.shape)
+ # tf.print('images', images.shape)

tokens = tokens + tf.expand_dims(self.vis_pos_emb, 0)
if self.use_cls_token:
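Note: several files in this commit gain commented-out tf.print lines for shape debugging. As general TensorFlow behavior (not part of the commit itself): inside a tf.function, a plain Python print fires only at trace time, while tf.print is embedded in the graph and runs on every call, which is why it is the usual tool for inspecting runtime shapes. A minimal sketch:

import tensorflow as tf

@tf.function
def step(x):
    print('traced with', x.shape)              # runs once, when the function is traced
    tf.print('runtime shape:', tf.shape(x))    # runs on every call, inside the graph
    return x * 2

step(tf.zeros([2, 3]))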
11 changes: 6 additions & 5 deletions configs/config_scene_graph.py
@@ -28,8 +28,8 @@ def get_config(config_str=None):
"""config_str is either empty or contains task,architecture variants."""

task_variant = 'scene_graph_generation@visual_genome'
- encoder_variant = 'vit-b' # Set model architecture.
- image_size = (640, 640) # Set image size.
+ encoder_variant = 'resnet' # Set model architecture.
+ image_size = (480, 480) # Set image size.

tasks_and_datasets = []
for task_and_ds in task_variant.split('+'):
@@ -41,8 +41,9 @@
vocab_id=10,
image_size=image_size,
quantization_bins=1000,
- max_instances_per_image=100,
- max_instances_per_image_test=100,
+ max_instances_per_image=20,
+ max_instances_per_image_test=20,
+ max_seq_len=512,
object_order='random',
color_jitter_strength=0.,
jitter_scale_min=0.3,
@@ -111,7 +112,7 @@
),

train=D(
- batch_size=32,
+ batch_size=16,
epochs=40,
steps=0, # set to >0 to override epochs.
checkpoint_epochs=1,
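Rough arithmetic behind these numbers (an illustration, not repo code): under the relation serialization added in tasks/scene_graph_generation.py below, each subject-predicate-object triple costs about 14 tokens, so 20 instances per image fit comfortably in the explicit max_seq_len of 512.

# qbox1 (4) + label1 + SUB + pred + PRED + qbox2 (4) + label2 + OBJ
tokens_per_relation = 4 + 1 + 1 + 1 + 1 + 4 + 1 + 1    # = 14
max_instances_per_image = 20
print(tokens_per_relation * max_instances_per_image)   # 280, under max_seq_len = 512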
6 changes: 4 additions & 2 deletions configs/dataset_configs.py
@@ -53,8 +53,10 @@
vg_dataset_config = D(
train_file_pattern=VG_TRAIN_TFRECORD_PATTERN,
val_file_pattern=VG_VAL_TFRECORD_PATTERN,
- train_num_examples=108073-12928,
- eval_num_examples=108073-81627,
+ # train_num_examples=108073-12928,
+ # eval_num_examples=108073-81627,
+ train_num_examples=16,
+ eval_num_examples=14,
train_split='train',
eval_split='validation',
**_shared_dataset_config
3 changes: 1 addition & 2 deletions data/scripts/create_vg_tfrecord.py
@@ -65,8 +65,7 @@ def create_anno_iter(img_data, label_data, ann):

yield feature_dict, skipped
skipped = 0
- # break
-
+ break

def create_example(feature_dict, skipped):
example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
2 changes: 2 additions & 0 deletions models/model.py
@@ -92,6 +92,8 @@ def train_step(self, examples, tasks, strategy):
strategy: tensorflow strategy such as `TPUStrategy` or `MirroredStrategy`.
"""
logging.info('train_step begins...')
+ # for key in examples[0][1]:
+ # tf.print(key, examples[0][1][key].shape, examples[0][1][key][0])
preprocessed_outputs = [
t.preprocess_batched(e, training=True) for e, t in zip(examples, tasks)]

2 changes: 1 addition & 1 deletion run.py
@@ -41,7 +41,7 @@
# pylint: enable=unused-import
from tasks import task as task_lib
import tensorflow as tf

+ tf.get_logger().setLevel('ERROR')

TRAIN = 'train'
EVAL = 'eval'
2 changes: 1 addition & 1 deletion run_scripts/create_vg_tfrecord.sh
@@ -1,5 +1,5 @@
export PYTHONPATH=$PYTHONPATH:~/Github/pix2seq
- train=false
+ train=true
if $train; then
split=train
else
8 changes: 6 additions & 2 deletions run_scripts/train.sh
@@ -1,11 +1,15 @@
config=../configs/config_scene_graph.py
model_dir=/data/hulab/zcai75/checkpoints/pix2seq/scene_graph
+ export TF_CPP_MIN_LOG_LEVEL=1
+ # export NCCL_DEBUG=INFO
+ export CUDA_VISIBLE_DEVICES=1
+ # export AUTOGRAPH_VERBOSITY=0

python ../run.py \
--mode=train \
--model_dir=$model_dir \
--config=$config \
- --config.train.batch_size=32 \
- --config.train.epochs=20 \
+ --config.train.batch_size=12 \
+ --config.train.epochs=2 \
--config.optimization.learning_rate=3e-5 \
--run_eagerly > train.out
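The logging changes in run.py and train.sh work at two different levels; roughly (general TensorFlow knobs, not specific to this repo):

import os
# C++-side logs; must be set before TensorFlow is imported (train.sh exports it).
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')   # 1 hides INFO, 2 also hides WARNING
import tensorflow as tf
# Python-side logger, as added to run.py in this commit.
tf.get_logger().setLevel('ERROR')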
42 changes: 33 additions & 9 deletions tasks/scene_graph_generation.py
@@ -41,7 +41,7 @@ def __init__(self,
super().__init__(config)

if config.task.get('max_seq_len', 'auto') == 'auto':
- self.config.task.max_seq_len = config.task.max_instances_per_image * 5
+ self.config.task.max_seq_len = config.task.max_instances_per_image * 15
# self._category_names = task_utils.get_category_names(
# config.dataset.get('category_names_path'))
metric_config = config.task.get('metric')
@@ -135,7 +135,9 @@ def preprocess_batched(self, batched_examples, training):

# Create input/target seq.
ret = build_response_seq_from_bbox(
- labels['bbox'], labels['label'], config.quantization_bins,
+ labels['box1'], labels['box2'],
+ labels['box1_id'], labels['box2_id'],
+ labels['pred'], config.quantization_bins,
config.noise_bbox_weight, mconfig.coord_vocab_shift,
class_label_corruption=config.class_label_corruption)
response_seq, response_seq_cm, token_weights = ret
@@ -145,6 +147,8 @@
target_seq = tf.concat([prompt_seq, response_seq], -1)

# Pad sequence to a unified maximum length.
+ # tf.print(prompt_seq.shape, response_seq_cm.shape, response_seq.shape, target_seq.shape, token_weights.shape)
+ # tf.print(config.max_seq_len)
assert input_seq.shape[-1] <= config.max_seq_len + 1
input_seq = utils.pad_to_max_len(input_seq, config.max_seq_len + 1, -1)
target_seq = utils.pad_to_max_len(target_seq, config.max_seq_len + 1, -1)
@@ -156,6 +160,7 @@
target_seq == vocab.PADDING_TOKEN,
tf.zeros_like(token_weights) + config.eos_token_weight, token_weights)

+ # tf.print(features['image'].shape)
if training:
return features['image'], input_seq, target_seq, token_weights
else:
@@ -412,7 +417,7 @@ def build_response_seq_from_bbox(box1, box2,
qbox1 = qbox1 + coord_vocab_shift
qbox1 = tf.where(is_padding, tf.zeros_like(qbox1), qbox1)

- is_padding = tf.expand_dims(tf.equal(box1_id, 0), -1)
+ is_padding = tf.expand_dims(tf.equal(box2_id, 0), -1)
qbox2 = utils.quantize(box2, quantization_bins)
qbox2 = qbox2 + coord_vocab_shift
qbox2 = tf.where(is_padding, tf.zeros_like(qbox2), qbox1)
@@ -426,7 +431,14 @@
lb_shape = tf.shape(new_label2)

# Bbox and label serialization.
- response_seq = tf.concat([qbox1, new_label1, vocab.SUB, pred, vocab.PRED, qbox2, new_label2, vocab.OBJ], axis=-1)
+ # tf.print('qbox1', qbox1.shape)
+ # tf.print('new_label1', new_label1.shape)
+ # tf.print(tf.broadcast_to(tf.cast(vocab.SUB, tf.int64), new_label1.shape).shape)
+ # tf.print('pred', tf.expand_dims(pred, -1).shape)
+ response_seq = tf.concat([
+ qbox1, new_label1, tf.broadcast_to(tf.cast(vocab.SUB, tf.int64), new_label1.shape),
+ tf.expand_dims(pred, -1), tf.broadcast_to(tf.cast(vocab.PRED, tf.int64), new_label1.shape),
+ qbox2, new_label2, tf.broadcast_to(tf.cast(vocab.OBJ, tf.int64), new_label1.shape)], axis=-1)
response_seq = utils.flatten_non_batch_dims(response_seq, 2)
rand_cls = vocab.BASE_VOCAB_SHIFT + tf.random.uniform(
lb_shape,
@@ -447,14 +459,26 @@
'real_n_rand_n_fake_cls': real_n_rand_n_fake_cls}
new_label_m1 = label_mapping[class_label_corruption]
new_label_m1 = tf.where(is_padding, tf.zeros_like(new_label_m1), new_label_m1)
- response_seq = tf.concat([qbox1, new_label_m1, vocab.SUB, pred, vocab.PRED, qbox2, new_label_m2, vocab.OBJ], axis=-1)
+ new_label_m2 = label_mapping[class_label_corruption]
+ new_label_m2 = tf.where(is_padding, tf.zeros_like(new_label_m2), new_label_m2)
+ response_seq_class_m = tf.concat([
+ qbox1, new_label_m1, tf.broadcast_to(tf.cast(vocab.SUB, tf.int64), new_label_m1.shape),
+ tf.expand_dims(pred, -1), tf.broadcast_to(tf.cast(vocab.PRED, tf.int64), new_label_m1.shape),
+ qbox2, new_label_m2, tf.broadcast_to(tf.cast(vocab.OBJ, tf.int64), new_label_m2.shape)], axis=-1)
response_seq_class_m = utils.flatten_non_batch_dims(response_seq_class_m, 2)

# Get token weights.
- is_real = tf.cast(tf.not_equal(new_label, vocab.FAKE_CLASS_TOKEN), tf.float32)
- bbox_weight = tf.tile(is_real, [1, 1, 4])
- label_weight = is_real + (1. - is_real) * noise_bbox_weight
- token_weights = tf.concat([bbox_weight, label_weight], -1)
+ is_real1 = tf.cast(tf.not_equal(new_label1, vocab.FAKE_CLASS_TOKEN), tf.float32)
+ bbox1_weight = tf.tile(is_real1, [1, 1, 4])
+ label1_weight = is_real1 + (1. - is_real1) * noise_bbox_weight
+ is_real2 = tf.cast(tf.not_equal(new_label2, vocab.FAKE_CLASS_TOKEN), tf.float32)
+ bbox2_weight = tf.tile(is_real2, [1, 1, 4])
+ label2_weight = is_real2 + (1. - is_real2) * noise_bbox_weight
+ token_weights = tf.concat([
+ bbox1_weight, label1_weight, tf.broadcast_to(0., new_label_m1.shape),
+ tf.expand_dims(tf.broadcast_to(1., pred.shape), -1), tf.broadcast_to(0., new_label_m1.shape),
+ bbox2_weight, label2_weight, tf.broadcast_to(0., new_label_m2.shape)
+ ], -1)
token_weights = utils.flatten_non_batch_dims(token_weights, 2)

return response_seq, response_seq_class_m, token_weights
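For reference, a self-contained sketch of the serialization this file now builds (token ids and shapes are placeholders; the real values come from vocab and the quantized boxes): each relation is laid out as subject box, subject label, SUB, predicate, PRED, object box, object label, OBJ, with the scalar separator tokens broadcast to the per-relation shape before concatenation.

import tensorflow as tf

SUB, PRED, OBJ = 3, 4, 5                     # illustrative separator ids only
bsz, n = 2, 20                               # batch size, relations per image

qbox1 = tf.zeros([bsz, n, 4], tf.int64)      # quantized subject box tokens
label1 = tf.ones([bsz, n, 1], tf.int64)      # subject class token
qbox2 = tf.zeros([bsz, n, 4], tf.int64)      # quantized object box tokens
label2 = tf.ones([bsz, n, 1], tf.int64)      # object class token
pred = tf.ones([bsz, n], tf.int64)           # predicate class token (no trailing dim)

seq = tf.concat([
    qbox1, label1, tf.broadcast_to(tf.cast(SUB, tf.int64), label1.shape),
    tf.expand_dims(pred, -1), tf.broadcast_to(tf.cast(PRED, tf.int64), label1.shape),
    qbox2, label2, tf.broadcast_to(tf.cast(OBJ, tf.int64), label1.shape)], axis=-1)
print(seq.shape)                             # (2, 20, 14); flattened to one sequence per image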
