Expose multimask_output in EfficientSam.forward #30

Open
MichaelFishmanBDAII opened this issue Dec 19, 2023 · 3 comments

MichaelFishmanBDAII commented Dec 19, 2023

This would involve adding the optional parameter to EfficientSam.forward, and then passing it to EfficientSam.predict_masks.

    def forward(
        self,
        batched_images: torch.Tensor,
        batched_points: torch.Tensor,
        batched_point_labels: torch.Tensor,
        scale_to_original_image_size: bool = True,
        multimask_output: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predicts masks end-to-end from provided images and prompts.
        If prompts are not known in advance, using SamPredictor is
        recommended over calling the model directly.

        Arguments:
          batched_images: A tensor of shape [B, 3, H, W]
          batched_points: A tensor of shape [B, num_queries, max_num_pts, 2]
          batched_point_labels: A tensor of shape [B, num_queries, max_num_pts]
          scale_to_original_image_size: If True, upscale the output masks to the
            original image size. Otherwise, return low-resolution masks.
          multimask_output: If True, generate multiple masks for each query.
            Otherwise, generate one mask per query.

        Returns:
          A tuple of two tensors, where the ith element is obtained by considering
          the first i+1 points:
            low_res_mask: A tensor of shape [B, 256, 256] of predicted masks
            iou_predictions: A tensor of shape [B, max_num_queries] of estimated IOU scores
        """
        batch_size, _, input_h, input_w = batched_images.shape
        image_embeddings = self.get_image_embeddings(batched_images)
        return self.predict_masks(
            image_embeddings,
            batched_points,
            batched_point_labels,
            multimask_output=multimask_output,
            input_h=input_h,
            input_w=input_w,
            output_h=input_h if scale_to_original_image_size else -1,
            output_w=input_w if scale_to_original_image_size else -1,
        )
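
For reference, here's a minimal sketch of how this patched forward could be called with the new flag. The builder name (`build_efficient_sam_vitt`), the example image path, and the point coordinates are only illustrative, taken loosely from the repo's README/demo, so treat them as assumptions rather than exact API:

    # Sketch only: build_efficient_sam_vitt, the image path, and the prompt
    # coordinates below are assumptions based on the repo's example code.
    import torch
    from PIL import Image
    from torchvision.transforms import ToTensor

    from efficient_sam.build_efficient_sam import build_efficient_sam_vitt

    model = build_efficient_sam_vitt()
    model.eval()

    image = ToTensor()(Image.open("figs/examples/dogs.jpg"))[None]  # [1, 3, H, W]
    points = torch.tensor([[[[580.0, 350.0], [650.0, 350.0]]]])     # [1, 1, 2, 2]
    labels = torch.tensor([[[1, 1]]])                                # [1, 1, 2], 1 = foreground

    with torch.no_grad():
        masks, iou_predictions = model(
            image, points, labels, multimask_output=False
        )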

I'm happy to make a PR for this, but I figure it may be easier to just throw this in as part of the ongoing updates you all are making.

Thanks for releasing this code and updating it so frequently!

Edit: I tried the code above, and using multimask_output=False seems to be giving me broken masks, so I'm probably missing something and this may be more involved than I'd thought. The bug could also be in my postprocessing code.

For the dog example image and points, this is what I get with, and without multimask:

with multimask: [screenshot]

without multimask: [screenshot]

yformer (Owner) commented Dec 20, 2023

@MichaelFishmanBDAII, thanks for your interest! We will take a look.

MichaelFishmanBDAII (Author) commented:

The predicted IOU for the single-mask mode masks tends to be very low (1e-5), so I don't think the problem is how I'm unpacking the masks.

I also tried using bounding box prompts instead of point prompts, since the original SAM paper says single-mask mode was designed for multi-prompt mask generation, but this still gave me the same diffuse, grid-like masks as the point prompts.
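
Until single-mask mode works, a possible interim workaround is to keep multimask_output=True and select the highest-scoring mask per query. This is only a sketch, continuing from the example above: the [B, num_queries, num_masks, ...] output layout and the torch.take_along_dim pattern are assumptions based on how the repo's example notebook unpacks the outputs.

    # Sketch of a workaround: run with multimask_output=True (the default) and
    # keep only the mask with the highest predicted IoU for each query.
    # Assumes masks: [B, num_queries, num_masks, H, W], iou: [B, num_queries, num_masks].
    masks, iou_predictions = model(image, points, labels)

    sorted_ids = torch.argsort(iou_predictions, dim=-1, descending=True)
    iou_sorted = torch.take_along_dim(iou_predictions, sorted_ids, dim=2)
    masks_sorted = torch.take_along_dim(masks, sorted_ids[..., None, None], dim=2)

    best_mask = masks_sorted[:, :, 0] >= 0  # threshold the top-ranked logits at 0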

feivellau commented:

I'm experiencing the same problem. When I set multimask_output=False to get a single mask, the mask results are very poor!
