Hi, I'd like to use `functorch` to realize the following loss.

**Question demonstration**

Assume that for the primal function (mapping an input of shape `(B, I)` to an output of shape `(B, O)`) there exists the Jacobian matrix $(O\times I)$, denoted $J_\alpha^\gamma=\frac{\partial y^{\gamma}}{\partial x_\alpha}$. I want to calculate the two terms $L1 = \sum_\gamma \sum_\alpha J_\alpha^{\gamma}$ and $L2 = \sum_\gamma \sum_\alpha (J_\alpha^{\gamma})^2$, as well as their gradients with respect to the model parameters $W$, $\frac{\partial L1}{\partial W}$ and $\frac{\partial L2}{\partial W}$, for the gradient descent update. This is easy to realize with the help of `functorch`; I post a toy example below.

**Problem**

The idea is to calculate the two terms above. The first one, $\sum_\alpha J_\alpha^{\gamma}$, is easy to realize with `functorch.jvp` or `torch.autograd.functional.jvp` by setting the tangents to an all-ones tensor `torch.ones(B, I)`. If we do the summation $\sum_\gamma$ inside the wrapped function and pass the result on to compute the gradient with respect to the model's parameters $W$, it runs fast and uses little memory.
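A minimal sketch of this first term (assuming, just for illustration, a toy `nn.Linear` model; the real primal and the sizes are placeholders):

```python
import torch
import functorch

# Toy setup, just for illustration (the real model and sizes are assumptions).
B, I, O = 4, 8, 3
model = torch.nn.Linear(I, O)
fmodel, params = functorch.make_functional(model)
x = torch.randn(B, I)

# L1 = sum_gamma sum_alpha J_alpha^gamma:
# jvp with an all-ones tangent gives sum_alpha J_alpha^gamma per output gamma,
# and the wrapped function already reduces over gamma (and the batch).
def L1(params, x):
    f = lambda x: fmodel(params, x)
    _, jvp_out = functorch.jvp(f, (x,), (torch.ones(B, I),))
    return jvp_out.sum()

dL1_dW = functorch.grad(L1)(params, x)   # dL1/dW, fast and memory-light
```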
However, when calculating the next term, $\sum_\alpha (J_\alpha^{\gamma})^2$, there is no `jvp`-style function that applies, and I have to create the full Jacobian of the primal followed by a `.sum()` to obtain the result. In that case we run into an OOM problem. My machine is an A100-80G.
I suppose this is because we have to access the full Jacobian matrix $J_\alpha^{\gamma}$ in the second case, which is too large to store during the computation.
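In contrast, the only route I found for the second term materializes the per-sample Jacobian explicitly before reducing it (a minimal sketch, continuing the toy setup above):

```python
# L2 = sum_gamma sum_alpha (J_alpha^gamma)^2:
# without a jvp-style shortcut, build the full Jacobian and then reduce it.
def L2_full_jacobian(params, x):
    f = lambda xi: fmodel(params, xi)              # per-sample map: (I,) -> (O,)
    J = functorch.vmap(functorch.jacrev(f))(x)     # full Jacobian, shape (B, O, I)
    return (J ** 2).sum()

dL2_dW = functorch.grad(L2_full_jacobian)(params, x)   # OOMs at realistic sizes
```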
The OOM issue is also reported in #636 (comment) and is (possibly) solved by the recent update with the chunks option in #680 (comment).
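I have not checked whether it is the same option as the one in #680, but for reference, `jacrev` eventually exposed a `chunk_size` keyword (in the `torch.func` namespace) that produces the Jacobian a few rows at a time, trading compute for peak memory; a hedged sketch:

```python
# Hedged sketch: chunked Jacobian computation via the chunk_size keyword of
# torch.func.jacrev (I have not confirmed this matches the option in #680).
# Rows of the Jacobian are produced one chunk at a time, lowering peak memory.
from torch.func import jacrev, vmap

J = vmap(jacrev(lambda xi: fmodel(params, xi), chunk_size=1))(x)   # (B, O, I)
loss2 = (J ** 2).sum()
```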
My ideas are:
- Can we build a native function that produces an "F(Jacobian)-dot-vector" output vector, $f(J)\cdot \vec{n}\rightarrow \vec{v}$?
  - If $f: x\rightarrow x$, then it is `functorch.jvp`: $J\cdot \vec{n}\rightarrow \vec{v}$.
  - If $f: x\rightarrow x^2$, then it is the second term in my example. But this time, since it doesn't need to access the full Jacobian, it becomes more memory efficient (see the workaround sketch after this list).
- Some usages of the Jacobian only require:
  - Jacobian-dot-vector, producing a vector: covered by `functorch.jvp`.
  - vector-dot-Jacobian, producing a vector: covered by `functorch.vjp`.
  - vector-dot-Jacobian-dot-vector, producing a scalar: needs to be realized via `jvp` or `vjp` (also sketched after this list).
When doing the gradient calculation on those outputs, the memory needed to store the intermediate tensors is around (D of the vector) x (N of the parameters). Is it possible to realize a native vector-dot-Jacobian-dot-vector that avoids those large intermediates and becomes memory efficient?
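To make both points concrete, here is a rough sketch using only existing primitives (continuing the toy setup above; this is a workaround, not the native op I am asking for, and the autograd graphs of the repeated forward passes still accumulate for the parameter gradient, which is exactly the memory concern above):

```python
# (a) f: x -> x^2 case: sum_alpha (J_alpha^gamma)^2 without ever forming the
#     (B, O, I) Jacobian at once. Feed the standard basis vectors e_alpha as
#     tangents, so each jvp returns one Jacobian column of shape (B, O);
#     square it and accumulate. Costs I forward passes instead of extra memory
#     for J itself (the graphs kept for grad are another story).
def L2_columnwise(params, x):
    f = lambda x: fmodel(params, x)
    total = 0.0
    for alpha in range(x.shape[1]):
        e = torch.zeros_like(x)
        e[:, alpha] = 1.0                        # basis tangent e_alpha
        _, col = functorch.jvp(f, (x,), (e,))    # col[b, gamma] = J_alpha^gamma
        total = total + (col ** 2).sum()
    return total

dL2_dW = functorch.grad(L2_columnwise)(params, x)

# (b) vector-dot-Jacobian-dot-vector: u . (J v) needs a single jvp plus a dot
#     product, so nothing Jacobian-sized is ever stored.
def uJv(params, x, u, v):
    f = lambda x: fmodel(params, x)
    _, Jv = functorch.jvp(f, (x,), (v,))         # Jv has shape (B, O)
    return (u * Jv).sum()

u, v = torch.randn(B, O), torch.randn(B, I)
duJv_dW = functorch.grad(uJv)(params, x, u, v)
```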
I checked the source code of `jvp`: it directly uses the dual mode of PyTorch's forward-mode AD and returns the jvp term directly from `_unpack_dual`, so I am afraid this problem may be beyond the scope of the `functorch` pipeline.
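For reference, this is roughly what that mechanism looks like when sketched with the public `torch.autograd.forward_ad` API (as far as I can tell, `unpack_dual` here wraps the same `_unpack_dual` primitive):

```python
import torch
import torch.autograd.forward_ad as fwAD

# A jvp boils down to one pass over a "dual" tensor that carries the primal
# and the tangent together; the tangent part of the output is J @ tangent.
def jvp_via_dual(f, x, tangent):
    with fwAD.dual_level():
        dual_x = fwAD.make_dual(x, tangent)
        primal_out, jvp_out = fwAD.unpack_dual(f(dual_x))
    return primal_out, jvp_out
```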
Anyway, I look forward to your discussion.