
function(Jacobian)-dot-vector and vector-Jacobian-vector function #1056

Open
veya2ztn opened this issue Oct 28, 2022 · 0 comments
Comments

@veya2ztn
Hi,

I'd like to use functorch to realize the following loss:

Question demonstration

Assume that:

  • the dimension of the output tensor is $O$, and we use $y^\gamma$ to mark each element;
  • the dimension of the input tensor (primal) is $I$, and we use $x_\alpha$ to mark each element;
  • we have a PyTorch model $f$ with parameters marked as $W$ that maps the input to the output, $f:\vec{x}\in\mathbb{R}^I \rightarrow \vec{y}\in\mathbb{R}^O$.

Then there exists the $O\times I$ Jacobian matrix $J_\alpha^\gamma=\frac{\partial y^{\gamma}}{\partial x_\alpha}$.

I want to calculate two terms,

$$ L1=\sum_\gamma(\sum_\alpha J_\alpha^{\gamma}-1)^2 $$

$$ L2 =\sum_\gamma [\sum_\alpha (J_\alpha^{\gamma})^2-1]^2 $$

as well as their gradients with respect to $W$, $\frac{\partial L1}{\partial W}$ and $\frac{\partial L2}{\partial W}$, for the gradient descent update.

This is easier to realize with the help of functorch; I post a toy example below:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
import torch
import torch.nn.functional as F
import functorch
from functorch import jacrev,jacfwd
from functorch import make_functional, vmap, grad
B=200
I=100
O=300
class MyModel(torch.nn.Module):
    def __init__(self, in_chan, out_chan):
        super().__init__()
        self.backbone = torch.nn.Linear(in_chan, out_chan,bias=False)
    def forward(self,x):
        return self.backbone(x)**2
model = MyModel(I, O).cuda()
x = torch.randn(B, I).cuda()
cotangents = torch.ones(B, I).cuda()  # all-ones tangent, so jvp returns sum_alpha J_alpha^gamma
func_model, params = make_functional(model)

### ---> to calculate the dL1/dW term
def Normlization_Term_1(params, x):
    # jvp with the all-ones tangent gives sum_alpha J_alpha^gamma for every output element
    jvp_out = functorch.jvp(lambda x_: func_model(params, x_), (x,), (cotangents,))[1]
    return ((jvp_out - 1)**2).mean()
Derivation_Term_1 = jacrev(Normlization_Term_1, argnums=0)(params, x)

### ---> to calculate the dL2/dW term
# this path materializes the full (B, O, I) Jacobian before squaring and reducing it
def Normlization_Term_2(params, x):
    full_jacobian = vmap(jacrev(func_model, argnums=1), (None, 0))(params, x)  # (B, O, I)
    return (((full_jacobian**2).sum(-1) - 1)**2).mean()  # reduce over gamma (and batch), as in L2
Derivation_Term_2 = jacrev(Normlization_Term_2, argnums=0)(params, x)

Problem

The idea is to calculate:

  • $\sum_\alpha J_\alpha^{\gamma}$: this term is easy to realize with functorch.jvp or torch.autograd.functional.jvp by setting the tangents (named cotangents in my code) to an all-ones tensor torch.ones(B,I). If we do the summation $\sum_\gamma$ inside the wrapped function and then compute the Jacobian with respect to the model's parameters $W$, it runs fast and uses little memory.
  • However, when calculating the next term $\sum_\alpha (J_\alpha^{\gamma})^2$, there is no jvp-like shortcut: I have to build the full Jacobian of the primal and then apply a .sum() to obtain the result. In that case I hit an OOM problem; my machine is an A100-80G. A sketch of the column-by-column workaround I have in mind follows this list.
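For reference, here is a minimal sketch of that workaround, built only on functorch.jvp (the helper name sum_of_squared_jacobian and the loop-per-dimension strategy are mine, just for illustration): it accumulates $\sum_\alpha (J_\alpha^{\gamma})^2$ one input dimension at a time, so only a single $(B, O)$ Jacobian column is alive at any moment, trading compute for memory.

def sum_of_squared_jacobian(params, x):
    # accumulate sum_alpha (J_alpha^gamma)^2 one input dimension at a time,
    # so only a single (B, O) Jacobian column is stored at any moment
    acc = x.new_zeros(B, O)
    for alpha in range(I):
        e_alpha = torch.zeros(B, I, device=x.device)
        e_alpha[:, alpha] = 1.0  # basis tangent selecting input dimension alpha
        # jvp with the basis tangent returns the Jacobian column J[:, :, alpha], shape (B, O)
        col = functorch.jvp(lambda x_: func_model(params, x_), (x,), (e_alpha,))[1]
        acc = acc + col**2
    return ((acc - 1)**2).mean()  # same reduction as Normlization_Term_2 above

# gradient with respect to the parameters, analogous to Derivation_Term_2
Derivation_Term_2_loop = jacrev(sum_of_squared_jacobian, argnums=0)(params, x)

It is of course much slower (one jvp per input dimension), which is exactly why I am asking whether a native f(Jacobian)-dot-vector operator could do this efficiently.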

I suppose this is because we have to access the full Jacobian matrix $J_\alpha^{\gamma}$ in the second case, and it is too large to store during computation.

The OOM issue is also reported in #636 (comment) and is (possibly) addressed by the recent update that adds a chunks option in #680 (comment).

My ideas are:

  • Can we build a native function that produces an f(Jacobian)-dot-vector output, $f(J)\cdot \vec{n}\rightarrow \vec{v}$?

    If $f: x\rightarrow x$, then this is exactly functorch.jvp, $J\cdot \vec{n}\rightarrow \vec{v}$.

    If $f: x\rightarrow x^2$, then it is the second term in my example, but this time, since it does not need to access the full Jacobian, it becomes much more memory efficient.

  • Some usages of the Jacobian only require:

    • Jacobian-dot-vector, producing a vector: covered by functorch.jvp
    • vector-dot-Jacobian, producing a vector: covered by functorch.vjp
    • vector-dot-Jacobian-dot-vector, producing a scalar: currently has to be realized through jvp or vjp

    When doing a gradient calculation on those outputs, the memory needed for the intermediate tensors is roughly (dimension of the vector) × (number of parameters). Is it possible to realize a native vector-dot-Jacobian-dot-vector that never touches those large intermediates and is therefore memory efficient? A minimal sketch of what I mean, built on jvp, follows this list.
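For concreteness, here is how I currently realize the vector-dot-Jacobian-dot-vector case with jvp alone (the name vjv and the probe vectors u and v are made up for illustration; this never materializes the $(B, O, I)$ Jacobian, but the intermediates of the outer jacrev can still be large):

u = torch.randn(B, O).cuda()  # left probe vector (output side)
v = torch.randn(B, I).cuda()  # right probe vector (input side)

def vjv(params, x, u, v):
    # jvp gives J . v with shape (B, O); contracting with u gives the per-sample scalar u . J . v
    jv = functorch.jvp(lambda x_: func_model(params, x_), (x,), (v,))[1]
    return (u * jv).sum(-1).mean()

# gradient with respect to the parameters, as in the examples above
dvjv_dW = jacrev(vjv, argnums=0)(params, x, u, v)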


I checked the source code of jvp: it directly uses the dual mode of PyTorch forward-mode AD and returns the jvp term straight from _unpack_dual, so I am afraid this problem may be beyond the scope of the functorch pipeline.
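For context, this is my rough understanding of that dual-number path, written with the public torch.autograd.forward_ad API rather than functorch's internals (only an illustration, not functorch's actual code):

import torch.autograd.forward_ad as fwAD

with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, cotangents)  # pack primal and tangent into a dual tensor
    dual_y = model(dual_x)                  # the forward pass propagates the tangent
    y, jvp_out = fwAD.unpack_dual(dual_y)   # jvp_out == J . cotangents, shape (B, O)

So anything beyond a linear contraction with the tangent seems hard to express at that level, which is why I suspect this request may be outside what jvp itself can do.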

Anyway, I look forward to your discussion.
