How to update the original model parameters after calling make_functional? #280

Open

trenta3 opened this issue Nov 19, 2021 · 10 comments
Labels: actionable (It is clear what should be done for this issue)

Comments

trenta3 commented Nov 19, 2021

As per the title, I find that updating the tensors pointed to by the params returned by make_functional does not update the real parameters in the original model.
Is there a way to do this? It would be extremely useful for implementing optimization algorithms in a way that is closer to their mathematical description.

To provide more context, here is an example script of what standard gradient descent would look like in this style:

import torch
from torch import nn
from functorch import make_functional

learning_rate = 0.1

def optstep(params, jacobians):
    with torch.no_grad():
        for i, param in enumerate(params):
            param.add_(jacobians[i], alpha=-learning_rate)

if __name__ == '__main__':
    model = nn.Linear(3, 5)
    x, targets = torch.randn(2, 3), torch.randn(2, 5)
    criterion = nn.MSELoss()

    print("INITIAL LOSS:", criterion(model(x), targets).item())
    # Render the model functional and compute the jacobian
    func_model, params = make_functional(model)
    def f(*params):
        out = func_model(params, x)
        return criterion(out, targets)
    jacobian = torch.autograd.functional.jacobian(f, params)

    # Ideally would train on the current input
    optstep(params, jacobian)
    # Now compute the new loss
    print("NEW LOSS:", criterion(model(x), targets).item())

Executing the script shows that the parameters are not updated, since the loss doesn't change:

INITIAL LOSS: 1.2894147634506226
NEW LOSS: 1.2894147634506226
trenta3 (Author) commented Nov 19, 2021

After looking through the source code a bit, I found functorch._src.make_functional.extract_weights and load_weights, which let me do exactly what I wanted.
Could those functions be exposed and documented to support this use case?

zou3519 (Contributor) commented Nov 19, 2021

Couldn't you do

def optstep(model, jacobians):
    with torch.no_grad():
        for i, param in enumerate(model.parameters()):
            param.add_(jacobians[i], alpha=-learning_rate)

?

(Also, you might want to try functorch.jacrev instead of torch.autograd.functional.jacobian -- it may be faster)
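For reference, a minimal sketch (not from the thread itself) combining the original script with this suggestion: the update goes through model.parameters(), so the original nn.Module is modified in place and the second loss print now changes.

import torch
from torch import nn
from functorch import make_functional

learning_rate = 0.1

def optstep(model, jacobians):
    # Step directly on the module's own parameters instead of the
    # detached copies returned by make_functional.
    with torch.no_grad():
        for param, jac in zip(model.parameters(), jacobians):
            param.add_(jac, alpha=-learning_rate)

model = nn.Linear(3, 5)
x, targets = torch.randn(2, 3), torch.randn(2, 5)
criterion = nn.MSELoss()

func_model, params = make_functional(model)

def f(*params):
    return criterion(func_model(params, x), targets)

jacobian = torch.autograd.functional.jacobian(f, params)

print("INITIAL LOSS:", criterion(model(x), targets).item())
optstep(model, jacobian)  # mutates model.parameters() in place
print("NEW LOSS:", criterion(model(x), targets).item())  # now differs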

trenta3 (Author) commented Nov 20, 2021 via email

zou3519 (Contributor) commented Nov 22, 2021

> Is model.parameters() guaranteed to return parameters in the same order as make_functional?

Yes

> If this is the case then I can surely do this, however I would like to ask that it is documented as proper behaviour one can rely on.

Yes, we should document this
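For reference, a minimal sketch (not from the thread) that relies on this ordering guarantee to write updated functional params back into the original module; write_back is just an illustrative helper name:

import torch

def write_back(model, new_params):
    # model.parameters() and the params from make_functional share the same
    # ordering (per the answer above), so a plain zip is enough.
    with torch.no_grad():
        for p_model, p_new in zip(model.parameters(), new_params):
            p_model.copy_(p_new)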

trenta3 (Author) commented Nov 22, 2021

Thank you very much again for all this work.
I think the issue can be closed as soon as the behaviour is documented.

zou3519 (Contributor) commented Nov 29, 2021

@trenta3 out of curiosity, what are you using make_functional for? Are you using any of the other functorch APIs?

trenta3 (Author) commented Nov 30, 2021

I'm currently using make_functional as well as other functorch APIs, in particular jvp and jacrev, to write more complex optimizers that also need second-order information about a neural network, which is unmanageable to do in plain PyTorch.
Earlier this year I needed to compute eigenvectors of the linearizations of some neural networks, and the ability to obtain per-example gradients was crucial.

One thing I do miss is the ability to "lazily" compute parts of the Hessian, such as its diagonal, without the full memory (and compute) cost of calculating the whole Hessian.
More generally, the ability for a PyTorch user to manipulate "lazy tensors" (i.e. a thunk of computation that depends on some data but is not eagerly executed) would be extremely useful for computing the Hessian diagonal, as well as for many kernel-method computations (as pyKeOps does), but I honestly don't know how efficient this could be made.
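For what it's worth, a minimal sketch of the kind of workaround implied here today (not from the thread): the Hessian diagonal is recovered one Hessian-vector product at a time via torch.autograd.functional.hvp, which keeps memory at O(n) but still costs one backward pass per coordinate. loss_fn and w are toy placeholders.

import torch

def loss_fn(w):
    # Toy scalar loss; its Hessian is diagonal with entries 6 * w.
    return (w ** 3).sum()

w = torch.randn(10)

diag = torch.empty_like(w)
for i in range(w.numel()):
    e_i = torch.zeros_like(w)
    e_i[i] = 1.0
    # hvp returns (loss value, H @ e_i); the i-th entry of H @ e_i is H[i, i].
    _, hv = torch.autograd.functional.hvp(loss_fn, w, e_i)
    diag[i] = hv[i]

assert torch.allclose(diag, 6 * w)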

zou3519 added the actionable label Dec 2, 2021
kxhit commented Apr 13, 2022

Hi! Thanks a lot for building this awesome functorch!

I have the same issue. I'm using fmodel, params, buffers = combine_state_for_ensemble(models) to stack models and optimizing the params in a training loop. Afterwards, I want to update each original model's state_dict(). I can't find a nice way to achieve this. What I'm currently doing is:

with torch.no_grad():
    for idx, model in enumerate(models):
        for i, param in enumerate(model.parameters()):
            param.set_(params[i][idx])

I hope there is a nicer way to achieve this, ideally covered by a good tutorial. Thanks!
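For what it's worth, a sketch of essentially the same write-back using copy_ under no_grad instead of set_ (assuming, as above, that models is the list of original modules and params comes from combine_state_for_ensemble, so params[i][idx] is the i-th parameter of the idx-th model):

import torch

with torch.no_grad():
    for idx, model in enumerate(models):
        for i, param in enumerate(model.parameters()):
            # copy_ writes the values into the existing parameter storage,
            # so anything still holding a reference to the parameter
            # (e.g. an optimizer) keeps seeing the updated values.
            param.copy_(params[i][idx])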

zou3519 (Contributor) commented Apr 13, 2022

@kxhit thank you for your feedback. Could you give a little more context about why you want to update each original model's state_dict?

kxhit commented Apr 13, 2022

@zou3519 Hi, thanks for your quick reply.

In my case, I'm training many tiny networks and need to use each network's up-to-date weights every few steps, so I need to assign the batched weights back to the original models frequently.
