Update test_batched_inference_image_captioning_conditioned (huggingface#23391)

* fix

* fix

* fix test + add more docs

---------

Co-authored-by: ydshieh <[email protected]>
Co-authored-by: younesbelkada <[email protected]>
3 people authored May 16, 2023
1 parent d765717 commit 21741e8
Showing 2 changed files with 8 additions and 3 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/model_doc/pix2struct.mdx
@@ -25,6 +25,8 @@ Tips:
Pix2Struct has been fine-tuned on a variety of tasks and datasets, ranging from image captioning and visual question answering (VQA) over different inputs (books, charts, science diagrams) to captioning UI components. The full list can be found in Table 1 of the paper.
We therefore advise you to use these models for the tasks they have been fine-tuned on. For instance, if you want to use Pix2Struct for UI captioning, you should use the model fine-tuned on the UI dataset; if you want to use Pix2Struct for image captioning, you should use the model fine-tuned on the natural images captioning dataset, and so on.

If you want to use the model to perform conditional text captioning, make sure to use the processor with `add_special_tokens=False`.

This model was contributed by [ybelkada](https://huggingface.co/ybelkada).
The original code can be found [here](https://github.com/google-research/pix2struct).

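The conditional-captioning tip added above can be exercised with a short snippet. Below is a minimal sketch; the `google/pix2struct-textcaps-base` checkpoint and the image URL are illustrative choices (mirroring the integration test), not part of this commit:

```python
import requests
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

# Illustrative checkpoint; any natural-image captioning Pix2Struct checkpoint works the same way.
processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")

# Illustrative image (a stop-sign photo).
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Pass the prompt with add_special_tokens=False so it is used as a generation
# prefix to continue, rather than a complete (EOS-terminated) target sequence.
inputs = processor(images=image, text="A picture of", return_tensors="pt", add_special_tokens=False)
predictions = model.generate(**inputs)
print(processor.decode(predictions[0], skip_special_tokens=True))
```
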
9 changes: 6 additions & 3 deletions tests/models/pix2struct/test_modeling_pix2struct.py
@@ -749,17 +749,20 @@ def test_batched_inference_image_captioning_conditioned(self):
        texts = ["A picture of", "An photography of"]

        # image only
-       inputs = processor(images=[image_1, image_2], text=texts, return_tensors="pt").to(torch_device)
+       inputs = processor(images=[image_1, image_2], text=texts, return_tensors="pt", add_special_tokens=False).to(
+           torch_device
+       )

        predictions = model.generate(**inputs)

        self.assertEqual(
-           processor.decode(predictions[0], skip_special_tokens=True), "A picture of a stop sign that says yes."
+           processor.decode(predictions[0], skip_special_tokens=True),
+           "A picture of a stop sign with a red stop sign on it.",
        )

        self.assertEqual(
            processor.decode(predictions[1], skip_special_tokens=True),
-           "An photography of the Temple Bar and a few other places.",
+           "An photography of the Temple Bar and the Temple Bar.",
        )

    def test_vqa_model(self):