Implementation of MIL-NCE loss #2

Open

ChenYutongTHU opened this issue Jan 7, 2022 · 0 comments
Hi, thanks for your interesting work.
I found the implementation of the NCE loss to be somewhat different from what is described in your paper.

  1. The dictionary videos have many redundant entries, since a class label can appear in multiple videos and be collected in multiple batches. I notice that all pairs of BSL-1K and dictionary features sharing the same class label are treated as positive pairs, even when they belong to two different batches, which suggests that some positive pairs can be counted multiple times in the numerator.

    bsldict/loss/loss.py

    Lines 149 to 154 in eea308a

    for i, t in enumerate(num_unique_dicts):
        # find the set of pairs with the current dictionary class label
        curr_dict = targets_dict == t
        # find the bsl1k embeddings that share the same class label
        curr_bsl1k = match_multi[:, curr_dict][:, 0]
  2. Pairs from different batches that share the same label are excluded from the denominator, yet they are included in the numerator.

    bsldict/loss/loss.py

    Lines 169 to 172 in eea308a

    # Account for matches that occur in different batches
    pos_neg_mask = (curr_dict.unsqueeze(0) | curr_bsl1k.unsqueeze(1))
    pos_neg_mask *= ~diff_batch_match
    where_mask = torch.where(pos_neg_mask)
  3. Finally, for each batch, the log ratio is computed by iterating over the (duplicated) dictionary entries (num_unique_dicts). According to the paper, it seems more reasonable to iterate over the BSL-1K videos (see the sketch after this list).

    bsldict/loss/loss.py

    Lines 177 to 179 in eea308a

    for i, t in enumerate(num_unique_dicts):
        numerator[i] = torch.logsumexp(distances[pos_mask_list[i]], dim=0)
        denominator[i] = torch.logsumexp(distances[mask_list[i]], dim=0)
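
To make points 2 and 3 concrete, here is a minimal sketch of how I read the MIL-NCE objective: one term per BSL-1K anchor, with the numerator summed over that anchor's positive dictionary candidates and the denominator summed over positives and negatives together. The tensor names (emb_bsl1k, emb_dict, labels_bsl1k, labels_dict) are hypothetical, and this is only my reading of the paper, not the repo's implementation:

    import torch

    def mil_nce_loss(emb_bsl1k, emb_dict, labels_bsl1k, labels_dict, temperature=0.07):
        # emb_bsl1k: (B, D) BSL-1K clip embeddings (anchors)
        # emb_dict:  (M, D) dictionary clip embeddings (candidates)
        # labels_*:  class labels deciding which pairs count as positive
        sim = emb_bsl1k @ emb_dict.t() / temperature                      # (B, M)
        # boolean mask: each (anchor, dictionary) pair appears at most once,
        # so no positive pair can be double-counted in the numerator
        pos_mask = labels_bsl1k.unsqueeze(1) == labels_dict.unsqueeze(0)  # (B, M)
        # numerator: log-sum-exp over the anchor's positive candidates only
        numerator = torch.logsumexp(sim.masked_fill(~pos_mask, float("-inf")), dim=1)
        # denominator: log-sum-exp over positives AND negatives, so every pair
        # in the numerator also appears in the denominator
        denominator = torch.logsumexp(sim, dim=1)
        # keep only anchors that have at least one positive candidate
        valid = pos_mask.any(dim=1)
        return (denominator - numerator)[valid].mean()

With a formulation like this, the duplicate counting from point 1 would also disappear, since the boolean mask selects each (BSL-1K, dictionary) pair at most once.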

Moreover, the difference between using the BSL-1K video and the dictionary video as the anchor for contrastive sampling is not reflected in the code. Is it just a conceptual device for explaining how the positive/negative pairs are constructed?
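
To state my assumption explicitly: with a similarity matrix sim of shape (B, M) as in the sketch above, I would expect the anchor choice to show up only in which dimension the denominator normalizes over, e.g.:

    # BSL-1K clip as anchor: negatives are other dictionary entries (normalize over columns)
    denom_bsl1k_anchor = torch.logsumexp(sim, dim=1)  # shape (B,)
    # dictionary entry as anchor: negatives are other BSL-1K clips (normalize over rows)
    denom_dict_anchor = torch.logsumexp(sim, dim=0)   # shape (M,)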

Thanks a lot for your help~
