diff --git a/src/model/train.py b/src/model/train.py index f6524ae..726e5cb 100644 --- a/src/model/train.py +++ b/src/model/train.py @@ -163,27 +163,15 @@ def validation_step(self, batch, batch_idx): candidates = [] references = [] - columns = [ - "prediction", - "references", - ] - log_data = [] - for pred, captions in zip(predictions, labels): - predicted_text = ( - parts[1] if len(parts := pred.split("\n", 1)) > 1 else pred - ) + # predicted_text = ( + # parts[1] if len(parts := pred.split("\n", 1)) > 1 else pred + # ) + predicted_text = pred.replace("caption en", "").replace("\n", "") + captions = [caption.replace("\n", "") for caption in captions] candidates.append(predicted_text) references.append(captions) - if self.config.train_wandb_logger is not None: - log_data.append([predicted_text, captions]) - - if len(log_data) > 0: - self.config.train_wandb_logger.log_text( - key=f"log_samples_{batch_idx}", columns=columns, data=log_data - ) - corpus_scores = self.calculate_corpus_scores(references, candidates) bleu_1_scores.append(corpus_scores["bleu_1"].item()) diff --git a/src/utils/train_utils.py b/src/utils/train_utils.py index 8926967..82befa6 100644 --- a/src/utils/train_utils.py +++ b/src/utils/train_utils.py @@ -10,7 +10,7 @@ def train_collate_fn(examples, processor, device): - images = [example["image"] for example in examples] + images = [example["image"].convert("RGB") for example in examples] texts = [PROMPT for _ in examples] captions = [example["caption"][0] for example in examples] @@ -41,7 +41,7 @@ def train_collate_fn(examples, processor, device): def eval_collate_fn(examples, processor, device): - images = [example["image"] for example in examples] + images = [example["image"].convert("RGB") for example in examples] texts = [PROMPT for _ in examples] captions = [example["caption"] for example in examples]