Add an Ensemble Module that is constructed from a list of Modules and encapsulates the necessary state #992

Open
sinking-point opened this issue Aug 1, 2022 · 6 comments

Comments

@sinking-point

sinking-point commented Aug 1, 2022

Most of the examples I've seen use vmap at the top level, to create an 'outer' ensemble of models or to factor out the batch dimension. However, my use case is 'inner' ensembles of modules within a larger model. This means I have to register the parameters and buffers from combine_state_for_ensemble with the parent module, which is annoying and messy.

An obvious solution is to create an Ensemble module that internally calls combine_state_for_ensemble and vmap and stores the necessary state:

self.ens = Ensemble(my_modules, in_dims=(0, 0, 2), out_dims=(0, 0, 2))
...
x = self.ens(x)
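In the proposed call, the first two entries of in_dims refer to the stacked parameters and buffers returned by combine_state_for_ensemble (ensemble dimension 0), and the remaining entries refer to the forward inputs; the (0, 0, 2) above would take each member's input from dim 2. A minimal sketch of that convention with plain functorch, assuming four identical nn.Linear members as a hypothetical stand-in for my_modules: in_dims=(0, 0, None) feeds one shared batch to every member, while (0, 0, 0) gives each member its own slice of the input.

```python
import torch
from torch import nn
from functorch import combine_state_for_ensemble, vmap

torch.manual_seed(0)
# Four hypothetical ensemble members with identical architecture.
members = [nn.Linear(3, 2) for _ in range(4)]

# fmodel is a stateless copy; every stacked tensor has the ensemble dim first,
# e.g. the stacked Linear.weight has shape (4, 2, 3).
fmodel, params, buffers = combine_state_for_ensemble(members)

# in_dims=(0, 0, None): slice params/buffers along dim 0, share one input batch.
x = torch.randn(5, 3)
shared = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, x)  # shape (4, 5, 2)

# in_dims=(0, 0, 0): each member also receives its own slice of the input.
xs = torch.randn(4, 5, 3)
separate = vmap(fmodel, in_dims=(0, 0, 0))(params, buffers, xs)  # shape (4, 5, 2)
```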

Even if registering the state weren't an issue, I still think this would be a popular feature. It's more intuitive than the current method of creating ensembles.

@sinking-point
Author

Something like this, perhaps:

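import torch.nn as nn
from functorch import combine_state_for_ensemble, vmap

class Ensemble(nn.Module):
    def __init__(self, modules, **kwargs):
        super().__init__()
        # Stack each member's parameters and buffers along a new leading dim.
        fmodel, params, buffers = combine_state_for_ensemble(modules)
        # kwargs (in_dims, out_dims, ...) apply to the call fmodel(params, buffers, *args).
        self.vmap_model = vmap(fmodel, **kwargs)
        # Register the stacked state instead of keeping plain self.params /
        # self.buffers attributes (self.buffers would shadow nn.Module.buffers()),
        # so it appears in .parameters() and state_dict() and follows .to(device).
        self.num_params = len(params)
        self.num_buffers = len(buffers)
        for i, param in enumerate(params):
            self.register_parameter(f'param_{i}', nn.Parameter(param))
        for i, buffer in enumerate(buffers):
            self.register_buffer(f'buffer_{i}', buffer)

    def forward(self, *args, **kwargs):
        # Read the registered tensors at call time so that gradients reach them
        # and device moves are respected.
        params = tuple(getattr(self, f'param_{i}') for i in range(self.num_params))
        buffers = tuple(getattr(self, f'buffer_{i}') for i in range(self.num_buffers))
        return self.vmap_model(params, buffers, *args, **kwargs)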
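For comparison, a sketch of the same idea on the torch.func API that later PyTorch releases provide (stack_module_state, functional_call and vmap), assuming PyTorch 2.x. The class name FuncEnsemble is hypothetical, and unlike the proposal its in_dims covers only the forward inputs, since the stacked state always uses dim 0. The stateless copy of the first member lives on the meta device, so functional_call has to supply every parameter and buffer.

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap


class FuncEnsemble(nn.Module):
    """Hypothetical ensemble module built on torch.func (PyTorch 2.x)."""

    def __init__(self, modules, in_dims=0, out_dims=0):
        super().__init__()
        # Dicts mapping names such as "weight" to tensors stacked along dim 0.
        params, buffers = stack_module_state(modules)
        self._param_names = list(params)
        self._buffer_names = list(buffers)
        # Names may contain dots, which register_* rejects, so use indices.
        for i, name in enumerate(self._param_names):
            self.register_parameter(f"param_{i}", nn.Parameter(params[name]))
        for i, name in enumerate(self._buffer_names):
            self.register_buffer(f"buffer_{i}", buffers[name])
        # Stateless skeleton on the meta device; stored in a tuple so that
        # nn.Module does not register it as a submodule.
        self._base = (copy.deepcopy(modules[0]).to("meta"),)
        self._in_dims = in_dims  # in_dims for the forward inputs only
        self._out_dims = out_dims

    def forward(self, *args):
        params = {n: getattr(self, f"param_{i}") for i, n in enumerate(self._param_names)}
        buffers = {n: getattr(self, f"buffer_{i}") for i, n in enumerate(self._buffer_names)}

        def call_one(p, b, *inputs):
            # Run the skeleton with one member's slice of the stacked state.
            return functional_call(self._base[0], (p, b), inputs)

        if isinstance(self._in_dims, tuple):
            input_dims = self._in_dims
        else:
            input_dims = (self._in_dims,) * len(args)
        return vmap(call_one, in_dims=(0, 0) + input_dims, out_dims=self._out_dims)(
            params, buffers, *args
        )


torch.manual_seed(0)
members = [nn.Linear(3, 2) for _ in range(4)]
ens = FuncEnsemble(members, in_dims=None)  # share one input batch across members
x = torch.randn(5, 3)
y = ens(x)  # shape (4, 5, 2)
```

Because the stacked tensors are registered on the module, an optimizer built from ens.parameters() trains all members at once; the original modules passed in keep their own, separate tensors and are not updated.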

@zou3519
Contributor

zou3519 commented Aug 2, 2022

This seems convenient to have. I am not sure whether this would go into functorch or torch.nn in the long term, but we can certainly toss something like this into functorch to start. cc @samdow, who is thinking about functional modules. Also curious to hear @jbschlosser's and @albanD's opinions as torch.nn maintainers.

@albanD
Contributor

albanD commented Aug 2, 2022

This would need to be part of a bigger plan to move things like combine_state_for_ensemble as well?
Also this seems to be very vmap specific?
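To make the vmap question concrete: the same interface can be implemented without vmap by looping over the members, which works for any module but runs them one after another. A minimal sketch, where the name LoopEnsemble is hypothetical and every member is assumed to take a single tensor and return a tensor of the same shape:

```python
import torch
from torch import nn


class LoopEnsemble(nn.Module):
    """Hypothetical non-vmap ensemble: runs each member in turn and stacks outputs."""

    def __init__(self, modules):
        super().__init__()
        # ModuleList registers every member's parameters and buffers as usual.
        self.members = nn.ModuleList(modules)

    def forward(self, x):
        # Output shape: (num_members, *member_output_shape).
        return torch.stack([m(x) for m in self.members])


torch.manual_seed(0)
members = [nn.Linear(3, 2) for _ in range(4)]
ens = LoopEnsemble(members)
x = torch.randn(5, 3)
y = ens(x)  # shape (4, 5, 2)
```

A vmap-based version is meant to replace this Python loop with batched operations over the stacked state, which is the part that ties the proposal to vmap.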

@zou3519
Contributor

zou3519 commented Aug 2, 2022

> Also this seems to be very vmap specific?

Are you suggesting that we should put the nn.Ensemble API into functorch because it is vmap specific?

@sinking-point
Author

sinking-point commented Aug 3, 2022

I did wonder about this, because my suggestion is not really functional and doesn't fit the theme of this package. However, this is the only place it can go, since torch can't have functorch as a dependency, unless we create a new package for it.

@albanD
Contributor

albanD commented Aug 3, 2022

> Are you suggesting that we should put the nn.Ensemble API into functorch because it is vmap specific?

Not necessarily, but it does sound much "higher level" than what is currently in torch.nn, so I'm not sure where it should live.

3 participants