Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Disallow scalar parameters in Dirichlet and Categorical (#11589)
Summary: This adds a small check in `Dirichlet` and `Categorical` `__init__` methods to ensure that scalar parameters are not admissible. **Motivation** Currently, `Dirichlet` throws no error when provided with a scalar parameter, but if we `expand` a scalar instance, it inherits the empty event shape from the original instance and gives unexpected results. The alternative to this check is to promote `event_shape` to be `torch.Size((1,))` if the original instance was a scalar, but that seems to add a bit more complexity (and changes the behavior of `expand` in that it would affect the `event_shape` as well as the `batch_shape` now). Does this seem reasonable? cc. alicanb, fritzo. ```python In [4]: d = dist.Dirichlet(torch.tensor(1.)) In [5]: d.sample() Out[5]: tensor(1.0000) In [6]: d.log_prob(d.sample()) Out[6]: tensor(0.) In [7]: e = d.expand([3]) In [8]: e.sample() Out[8]: tensor([0.3953, 0.1797, 0.4250]) # interpreted as events In [9]: e.log_prob(e.sample()) Out[9]: tensor(0.6931) # wrongly summed out In [10]: e.batch_shape Out[10]: torch.Size([3]) In [11]: e.event_shape Out[11]: torch.Size([]) # cannot be empty ``` Additionally, based on review comments, this removes `real_vector` constraint. This was only being used in `MultivariateNormal`, but I am happy to revert this if we want to keep it around for backwards compatibility. Pull Request resolved: pytorch/pytorch#11589 Differential Revision: D9818271 Pulled By: soumith fbshipit-source-id: f9bbba90ed6f04e0b5bdfa169e70ca20b280fc74
- Loading branch information