Skip to content

Commit

Permalink
Stripped newline from text and convert image to RGB.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ngaima Sandiman committed Oct 12, 2024
1 parent 55e60da commit 5ba557a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 19 deletions.
22 changes: 5 additions & 17 deletions src/model/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,27 +163,15 @@ def validation_step(self, batch, batch_idx):
candidates = []
references = []

columns = [
"prediction",
"references",
]
log_data = []

for pred, captions in zip(predictions, labels):
predicted_text = (
parts[1] if len(parts := pred.split("\n", 1)) > 1 else pred
)
# predicted_text = (
# parts[1] if len(parts := pred.split("\n", 1)) > 1 else pred
# )
predicted_text = pred.replace("caption en", "").replace("\n", "")
captions = [caption.replace("\n", "") for caption in captions]
candidates.append(predicted_text)
references.append(captions)

if self.config.train_wandb_logger is not None:
log_data.append([predicted_text, captions])

if len(log_data) > 0:
self.config.train_wandb_logger.log_text(
key=f"log_samples_{batch_idx}", columns=columns, data=log_data
)

corpus_scores = self.calculate_corpus_scores(references, candidates)

bleu_1_scores.append(corpus_scores["bleu_1"].item())
Expand Down
4 changes: 2 additions & 2 deletions src/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


def train_collate_fn(examples, processor, device):
images = [example["image"] for example in examples]
images = [example["image"].convert("RGB") for example in examples]
texts = [PROMPT for _ in examples]
captions = [example["caption"][0] for example in examples]

Expand Down Expand Up @@ -41,7 +41,7 @@ def train_collate_fn(examples, processor, device):


def eval_collate_fn(examples, processor, device):
images = [example["image"] for example in examples]
images = [example["image"].convert("RGB") for example in examples]
texts = [PROMPT for _ in examples]
captions = [example["caption"] for example in examples]

Expand Down

0 comments on commit 5ba557a

Please sign in to comment.