translation with trimmed mbart
ed-fish committed Dec 9, 2024
1 parent bb22ea0 commit 9b89386
Showing 10 changed files with 305 additions and 235 deletions.
16 changes: 12 additions & 4 deletions data/base_datasets.py
@@ -11,7 +11,7 @@
import decord

class VAT_dataset(Dataset):
def __init__(self, args):
def __init__(self, args, translation_tokenizer=None):
super().__init__()
self.video_decode_backend = args.video_decode_backend
self.num_frames = args.num_frames
@@ -24,6 +24,14 @@ def __init__(self, args):
self.weight = [0.2, 0.2, 0.2, 0.2] + [0.2 / 8] * 8
self.title = self.text_type == 'raw'
self.data_root = ''
if translation_tokenizer:
print("using translationtokenizerrr")
print("type of translation_tokenizer:", type(translation_tokenizer))
self.translate = True
self.tokenizer = translation_tokenizer
else:
self.translate = False  # fall back to the default CLIP tokenizer
self.tokenizer = get_tokenizer(HF_HUB_PREFIX + self.model, cache_dir=self.cache_dir)

if args.clip_type != 'al':
with open(self.train_data, 'r') as f:
self.id2title_folder_caps = json.load(f)
@@ -33,7 +41,6 @@ def __init__(self, args):

self.clip_type = args.clip_type

self.tokenizer = get_tokenizer(HF_HUB_PREFIX + self.model, cache_dir=self.cache_dir)
self.video_transform = get_video_transform(args)

def __len__(self):
@@ -75,6 +82,7 @@ def get_text(self, id):
text = self.id2title_folder_caps[id]['ofa'][ofa_number]
else:
text = self.id2title_folder_caps[id][text_type]

text_output = load_and_transform_text(text, self.tokenizer, title=text_type=='raw')
return text_output, ofa_number

@@ -113,8 +121,8 @@ def __init__(self, base_dataset, chunk_size=8, stride=4):
self.base_dataset = base_dataset
self.chunk_size = chunk_size
self.stride = stride
self.ids = self.base_dataset.ids

self.ids = base_dataset.ids
def __len__(self):
return len(self.base_dataset)

8 changes: 4 additions & 4 deletions data/build_datasets.py
@@ -59,8 +59,8 @@ def get_VAT_dataset(args):



def get_VAT_batched_dataset(args, chunk_size=8, stride=4):
base_dataset = VAT_dataset(args)
def get_VAT_batched_dataset(args, translation_tokenizer, chunk_size=8, stride=4):
base_dataset = VAT_dataset(args, translation_tokenizer)
dataset = VATBatchedDataset(base_dataset, chunk_size=chunk_size, stride=stride)
num_samples = len(dataset)
sampler = DistributedSampler(dataset) if args.distributed else None
@@ -174,13 +174,13 @@ def collate_fn(batch):
}


def get_data(args, epoch=0):
def get_data(args, epoch=0, translation_tokenizer=None):
data = {}
if args.do_train:
print(args.train_data)
if args.train_data.endswith(".json"):
if args.use_batched_dataset:
data[f"{args.clip_type}_pt"] = get_VAT_batched_dataset(args)
data[f"{args.clip_type}_pt"] = get_VAT_batched_dataset(args, translation_tokenizer)
else:
data[f"{args.clip_type}_pt"] = get_VAT_dataset(args)
elif args.train_data.endswith(".tar"):
24 changes: 15 additions & 9 deletions data/process_text.py
@@ -185,17 +185,23 @@ def clean_youtube(text, is_tags=False):
return text

def load_and_transform_text(text, tokenizer, title=True):
if title:
title_hashtags = text.split('#')
title, hashtags = title_hashtags[0], '#' + '#'.join(title_hashtags[1:])
title = clean_youtube(title)
hashtags = clean_youtube(hashtags, is_tags=True)
text = title + ', ' + hashtags
if text == '' or text.isspace():
raise ValueError('text is empty')
input_ids, attention_mask = tokenizer(text)
encoding = tokenizer(text, truncation=True, padding='max_length', max_length=77, return_tensors='pt')
# print("DEBUG tokenization:", encoding)
# encoding should be a dictionary with 'input_ids' and 'attention_mask' as tensors.
input_ids, attention_mask = encoding['input_ids'], encoding['attention_mask']
# print("DEBUG tokenized input_ids shape:", input_ids.shape)
return {'input_ids': input_ids.squeeze(), 'attention_mask': attention_mask.squeeze()}

# if title:
# title_hashtags = text.split('#')
# title, hashtags = title_hashtags[0], '#' + '#'.join(title_hashtags[1:])
# title = clean_youtube(title)
# hashtags = clean_youtube(hashtags, is_tags=True)
# text = title + ', ' + hashtags
# if text == '' or text.isspace():
# raise ValueError('text is empty')
# input_ids, attention_mask = tokenizer(text)
# return {'input_ids': input_ids.squeeze(), 'attention_mask': attention_mask.squeeze()}


if __name__ == '__main__':
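For reference, a minimal sketch (not part of this commit) of what the reworked load_and_transform_text returns when handed a Hugging Face tokenizer; the MBart_trimmed path and the sample German sentence are assumptions:

# Sketch only: assumes the trimmed tokenizer written by get_model.py below.
from transformers import MBartTokenizer
from data.process_text import load_and_transform_text

tokenizer = MBartTokenizer.from_pretrained('pretrain_models/MBart_trimmed')
out = load_and_transform_text("wechselhaftes wetter am mittwoch", tokenizer)
# padding='max_length' with max_length=77, so both tensors are squeezed to shape (77,)
print(out['input_ids'].shape, out['attention_mask'].shape)
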
47 changes: 47 additions & 0 deletions get_model.py
@@ -0,0 +1,47 @@

import json
from transformers import MBartForConditionalGeneration, MBartTokenizer, MBartConfig
from hftrim.ModelTrimmers import MBartTrimmer
from hftrim.TokenizerTrimmer import TokenizerTrimmer

# Load your JSON file
with open('data/phoenix/phoenix_train.json', 'r') as f:
raw_data = json.load(f)

# Extract "polish_mplug" values
data = []

for key, value in raw_data.items():
polish_mplug = value.get('polish_mplug', '').strip()
if polish_mplug:
data.append(polish_mplug)

# Initialize tokenizer and model
tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25", src_lang="de_DE", tgt_lang="de_DE")
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
configuration = model.config

# Trim tokenizer
tt = TokenizerTrimmer(tokenizer)
tt.make_vocab(data)
tt.make_tokenizer()

# Trim model
mt = MBartTrimmer(model, configuration, tt.trimmed_tokenizer)
mt.make_weights(tt.trimmed_vocab_ids)
mt.make_model()

new_tokenizer = tt.trimmed_tokenizer
new_model = mt.trimmed_model

# Save the trimmed tokenizer and model
new_tokenizer.save_pretrained('pretrain_models/MBart_trimmed')
new_model.save_pretrained('pretrain_models/MBart_trimmed')

# Configure and save the MyTran model
configuration = MBartConfig.from_pretrained('pretrain_models/MBart_trimmed/config.json')
configuration.vocab_size = new_model.config.vocab_size
mytran_model = MBartForConditionalGeneration._from_config(config=configuration)
mytran_model.model.shared = new_model.model.shared

mytran_model.save_pretrained('pretrain_models/mytran/')
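
A minimal sketch of loading the trimmed artifacts back; the paths match the save calls above, and this is one way the translation_tokenizer passed into get_data/VAT_dataset could be obtained (an assumption, not shown in this commit):

# Sketch only: reload the trimmed tokenizer and model saved by this script.
from transformers import MBartTokenizer, MBartForConditionalGeneration

translation_tokenizer = MBartTokenizer.from_pretrained('pretrain_models/MBart_trimmed')
trimmed_model = MBartForConditionalGeneration.from_pretrained('pretrain_models/MBart_trimmed')
# vocab sizes of tokenizer and model should agree after trimming
print(len(translation_tokenizer), trimmed_model.config.vocab_size)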
1 change: 1 addition & 0 deletions main.py
@@ -86,6 +86,7 @@ def get_latest_checkpoint(path: str, remote: bool):
def SET_GLOBAL_VALUE(k, v):
set_global_value(k, v)


def main(args):
args = parse_args(args)

2 changes: 0 additions & 2 deletions model/build_model.py
@@ -6,14 +6,12 @@
from torch import nn
from transformers import AutoConfig, CLIPPreTrainedModel


from model.base_model import CLIPModel
from model.process_clip import add_time_attn_block, convert_model_to_lora, set_global_value, resize_pos
from open_clip import convert_weights_to_lp
from open_clip.transformer import PatchDropout
from training.distributed import is_master


def SET_GLOBAL_VALUE(k, v):
set_global_value(k, v)

192 changes: 135 additions & 57 deletions model/translation_model.py
@@ -1,75 +1,153 @@
from torch import Tensor
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch import nn
from transformers import MBartForConditionalGeneration, AutoConfig, PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput
import torchvision
from transformers import MBartForConditionalGeneration, MBartConfig
from transformers.models.mbart.modeling_mbart import shift_tokens_right
import numpy as np
from pathlib import Path

class VideoTranslationModel(PreTrainedModel):
def __init__(self, clip_model, mbart_model_name_or_path):
# Initialize with the MBart configuration
config = AutoConfig.from_pretrained(mbart_model_name_or_path)
super().__init__(config)
def make_resnet(name='resnet18'):
if name == 'resnet18':
model = torchvision.models.resnet18(pretrained=True)
elif name == 'resnet34':
model = torchvision.models.resnet34(pretrained=True)
elif name == 'resnet50':
model = torchvision.models.resnet50(pretrained=True)
elif name == 'resnet101':
model = torchvision.models.resnet101(pretrained=True)
else:
raise Exception('Unsupported resnet model.')
inchannel = model.fc.in_features
model.fc = nn.Identity()
return model

self.clip_model = clip_model # Your CLIP video encoder
self.mbart_model = MBartForConditionalGeneration.from_pretrained(mbart_model_name_or_path)

# If encoder and MBart hidden sizes differ, add a projection layer
if self.clip_model.config.projection_dim != self.mbart_model.config.d_model:
self.encoder_projection = nn.Linear(
self.clip_model.config.projection_dim, self.mbart_model.config.d_model
)
else:
self.encoder_projection = nn.Identity()
# Scaling factor for embeddings (optional)
self.embed_scale = math.sqrt(self.mbart_model.config.d_model)
class TemporalConv(nn.Module):
def __init__(self, input_size, hidden_size, conv_type=2):
super(TemporalConv, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.conv_type = conv_type

if self.conv_type == 0:
self.kernel_size = ['K3']
elif self.conv_type == 1:
self.kernel_size = ['K5', "P2"]
elif self.conv_type == 2:
self.kernel_size = ['K5', "P2", 'K5', 'P2']

def forward(self, pixel_values, labels=None, decoder_input_ids=None, decoder_attention_mask=None):
# Obtain encoder outputs from the CLIP model
with torch.no_grad():
encoder_hidden_states = self.clip_model.encode_image(pixel_values)
# encoder_hidden_states shape: (batch_size, seq_len, hidden_size)
modules = []
for layer_idx, ks in enumerate(self.kernel_size):
input_sz = self.input_size if layer_idx == 0 else self.hidden_size
if ks[0] == 'P':
modules.append(nn.MaxPool1d(kernel_size=int(ks[1]), ceil_mode=False))
elif ks[0] == 'K':
modules.append(nn.Conv1d(input_sz, self.hidden_size, kernel_size=int(ks[1]), stride=1, padding=0))
modules.append(nn.BatchNorm1d(self.hidden_size))
modules.append(nn.ReLU(inplace=True))
self.temporal_conv = nn.Sequential(*modules)

def forward(self, x):
x = self.temporal_conv(x.permute(0,2,1))
return x.permute(0,2,1)

# Project encoder outputs to match MBart's hidden size
encoder_hidden_states = self.encoder_projection(encoder_hidden_states)
# Apply scaling
encoder_hidden_states = self.embed_scale * encoder_hidden_states
class FeatureExtracter(nn.Module):
def __init__(self, frozen=False):
super(FeatureExtracter, self).__init__()
self.conv_2d = make_resnet(name='resnet18')
self.conv_1d = TemporalConv(input_size=512, hidden_size=1024, conv_type=2)
if frozen:
for param in self.conv_2d.parameters():
param.requires_grad = False

# Create an attention mask for the encoder (assuming no padding)
encoder_attention_mask = torch.ones(
encoder_hidden_states.size()[:-1], dtype=torch.long, device=encoder_hidden_states.device
def forward(self, src: Tensor, src_length_batch):
src = self.conv_2d(src)
x_batch = []
start = 0
for length in src_length_batch:
end = start + length
x_batch.append(src[start:end])
start = end
x = nn.utils.rnn.pad_sequence(x_batch, batch_first=True)
x = self.conv_1d(x)
return x

class V_encoder(nn.Module):
def __init__(self, emb_size, feature_size):
super(V_encoder, self).__init__()
self.src_emb = nn.Linear(feature_size, emb_size)
self.bn_ac = nn.Sequential(
nn.BatchNorm1d(emb_size),
nn.ReLU(inplace=True)
)
print(encoder_hidden_states.shape)
for m in self.modules():
if isinstance(m, (nn.Conv1d,nn.Linear)):
nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)

def forward(self, src: Tensor):
src = self.src_emb(src)
return src

def config_decoder():
from transformers import AutoConfig
return MBartForConditionalGeneration.from_pretrained(
"pretrain_models/MBart_trimmed/",
ignore_mismatched_sizes=True,
config=AutoConfig.from_pretrained(Path('pretrain_models/MBart_trimmed/config.json'))
)

# Prepare encoder outputs
encoder_outputs = BaseModelOutput(last_hidden_state=encoder_hidden_states)
class gloss_free_model(nn.Module):
def __init__(self, embed_dim=1024, pretrain=None, embed_layer=True):
super(gloss_free_model, self).__init__()
self.mbart = config_decoder()
if embed_layer:
self.sign_emb = V_encoder(emb_size=embed_dim, feature_size=768)
self.embed_scale = math.sqrt(embed_dim)
else:
self.sign_emb = nn.Identity()
self.embed_scale = 1.0

# If decoder_input_ids are not provided, generate them from labels
if decoder_input_ids is None and labels is not None:
decoder_input_ids = self.mbart_model.prepare_decoder_input_ids_from_labels(labels)
def forward(self, input_embeds, attention_mask, tgt_input):
# DEBUG PRINTS:
# print("DEBUG FORWARD in gloss_free_model")
# print("input_embeds.shape:", input_embeds.shape)
# print("attention_mask.shape:", attention_mask.shape)
# print("tgt_input['input_ids'].shape:", tgt_input["input_ids"].shape)
# print("tgt_input['input_ids'] min/max:", tgt_input["input_ids"].min().item(), tgt_input["input_ids"].max().item())

# Pass encoder outputs and masks to the MBart model
outputs = self.mbart_model(
encoder_outputs=encoder_outputs,
input_embeds = self.embed_scale * self.sign_emb(input_embeds)

decoder_input_ids = shift_tokens_right(
tgt_input["input_ids"], pad_token_id=self.mbart.config.pad_token_id
)
decoder_attention_mask = (decoder_input_ids != self.mbart.config.pad_token_id).long()

# print("decoder_input_ids.shape:", decoder_input_ids.shape)
# print("decoder_input_ids min/max:", decoder_input_ids.min().item(), decoder_input_ids.max().item())
# print("decoder_attention_mask.shape:", decoder_attention_mask.shape)

out = self.mbart(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
labels=labels,
use_cache=False,
labels=tgt_input["input_ids"],
return_dict=True,
)
return outputs
return out["loss"], out["logits"]

def generate(self, pixel_values, max_length=50, num_beams=5):
encoder_hidden_states = self.clip_model.encode_image(pixel_values)
encoder_hidden_states = self.encoder_projection(encoder_hidden_states)
encoder_hidden_states = self.embed_scale * encoder_hidden_states
encoder_attention_mask = torch.ones(
encoder_hidden_states.size()[:-1], dtype=torch.long, device=encoder_hidden_states.device
)
encoder_outputs = BaseModelOutput(last_hidden_state=encoder_hidden_states)
generated_tokens = self.mbart_model.generate(
encoder_outputs=encoder_outputs,
attention_mask=encoder_attention_mask,
max_length=max_length,
def generate(self, src_input, max_new_tokens, num_beams, decoder_start_token_id):
inputs_embeds, attention_mask = self.share_forward(src_input)
out = self.mbart.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
decoder_start_token_id=decoder_start_token_id
)
return generated_tokens
return out
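
A minimal usage sketch of gloss_free_model with dummy inputs, assuming the trimmed MBart from get_model.py exists under pretrain_models/MBart_trimmed/; the shapes and target strings are illustrative only, and the generate path relies on share_forward, which sits in a portion of this file not shown in the rendered diff:

# Sketch only: dummy forward pass through gloss_free_model.
import torch
from transformers import MBartTokenizer
from model.translation_model import gloss_free_model

tokenizer = MBartTokenizer.from_pretrained('pretrain_models/MBart_trimmed')
model = gloss_free_model(embed_dim=1024)  # config_decoder() loads pretrain_models/MBart_trimmed/

batch, frames = 2, 16
input_embeds = torch.randn(batch, frames, 768)            # V_encoder expects feature_size=768
attention_mask = torch.ones(batch, frames, dtype=torch.long)
tgt_input = tokenizer(["ein beispielsatz", "noch ein satz"], padding=True, return_tensors='pt')

loss, logits = model(input_embeds, attention_mask, tgt_input)
print(loss.item(), logits.shape)  # logits: (batch, target_len, vocab_size)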
