
What if I do not have a pre-training model? #15

Open
1999kevin opened this issue Oct 21, 2024 · 2 comments

Comments

@1999kevin

Hi, Zhengyang:

Nice work and thank you so much for sharing the codebase.

I notice that this codebase requires a pretrained model from EDM to serve as the initialization. What if I want to run eCM on my own dataset? Do I need to run EDM on my own dataset first?

More importantly, does eCM only support initialization with EDM? What if I initialize eCM with other diffusion models? And what about running eCM without an EDM initialization? Does it just converge to consistency training with the continuous-time schedule and better weighting functions?

@Gsunshine
Member

Hi @1999kevin ,

If you want to run on your own dataset, you will need either a pretrained diffusion/flow model or to do the pretraining yourself, since the r=0 case reduces to diffusion pretraining. But it needn't be EDM: you can pick any forward process and model parameterization satisfying the boundary condition. For many datasets, EDM and its hyperparameters can be suboptimal.

I'd also advise using positional embeddings instead of Fourier embeddings, since the former usually offer better stability.


If you have a pretrained model using other forward processes, you can simply change the model definition to accommodate your forward process. Here is an example concerning vanilla flow matching,

x_t = (1-t) * x_0 + t * eps.

Note that dx_t/dt = eps - x_0, x_0 = x_t - t * dx_t/dt.
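As a quick sanity check of that identity (a standalone NumPy sketch, not code from this repo), inverting the linear path with the exact velocity recovers x_0:

```python
import numpy as np

rng = np.random.default_rng(0)
x0 = rng.standard_normal(8)   # clean sample
eps = rng.standard_normal(8)  # Gaussian noise
t = 0.7

x_t = (1 - t) * x0 + t * eps  # linear interpolation path
dxdt = eps - x0               # exact velocity dx_t/dt

x0_rec = x_t - t * dxdt       # invert the path to recover x_0
assert np.allclose(x0_rec, x0)
```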

Define your velocity prediction model as v(x_t, t). The flow pretraining can be formulated as

min || v(x_t, t) - dx_t/dt ||^2 = || v(x_t, t) - (eps - x_0) ||^2.

Then parameterize your consistency function as f(x_t, t) = x_t - t * v(x_t, t) using the pretrained flow, followed by ECT or ECD shrinking $\Delta t \to \mathrm{d}t$. In this case, || f(x_t, t) - f(x_0, 0) ||^2 = t^2 * || v(x_t, t) - (eps - x_0) ||^2, so combined with the weighting w(t) = 1/t^2 it recovers the standard flow pretraining objective, too.
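A minimal NumPy sketch checking this squared-distance identity for an arbitrary velocity prediction (the names `f` and `v` here are illustrative, not the repo's API):

```python
import numpy as np

rng = np.random.default_rng(1)
x0 = rng.standard_normal(8)
eps = rng.standard_normal(8)
t = 0.3

x_t = (1 - t) * x0 + t * eps
v = rng.standard_normal(8)  # stand-in for an arbitrary prediction v(x_t, t)

def f(x, t, v):
    """Consistency parameterization built on a velocity model."""
    return x - t * v

# f(x_0, 0) = x_0 by construction (boundary condition), so the
# consistency distance reduces to a t^2-weighted flow-matching loss.
lhs = np.sum((f(x_t, t, v) - f(x0, 0.0, v)) ** 2)
rhs = t ** 2 * np.sum((v - (eps - x0)) ** 2)
assert np.isclose(lhs, rhs)
```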

If your pretrained model utilizes the cosine schedule or other forward processes, you can derive the model definition similarly.
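For instance, assuming the cosine path x_t = cos(πt/2) x_0 + sin(πt/2) eps (an illustrative choice, not this repo's schedule), the same algebra gives x_0 = cos(πt/2) x_t - sin(πt/2) · (2/π) dx_t/dt:

```python
import numpy as np

rng = np.random.default_rng(2)
x0 = rng.standard_normal(8)
eps = rng.standard_normal(8)
t = 0.4
s = np.pi * t / 2  # rescaled time for the cosine schedule

x_t = np.cos(s) * x0 + np.sin(s) * eps
dxdt = (np.pi / 2) * (-np.sin(s) * x0 + np.cos(s) * eps)  # exact velocity

# Invert the cosine path: cos^2 + sin^2 = 1 cancels the eps terms.
x0_rec = np.cos(s) * x_t - np.sin(s) * (2 / np.pi) * dxdt
assert np.allclose(x0_rec, x0)
```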


The key takeaway is that you can safely put over 90% of your training budget into the pretraining stage, which provides the stability needed to shrink $\Delta t \to \mathrm{d}t$ as precisely as possible. The scaling works for both pretraining and tuning/distillation.

Thanks,
Zhengyang

@1999kevin
Author

Thank you very much for your response!

I wonder whether it is possible to include the pre-training stage in your codebase. In fact, according to your paper, when r=0 the CM objective $$d(f(x_t), f(x_r))$$ reduces to the DM objective $$d(f(x_t), x_0)$$. I think this can be achieved by adjusting your mapping function from t to r.

As in equation (18), we have $$\frac{r}{t}=1-\frac{n(t)}{q^{a}}=1-\frac{n(t)}{q^{\lfloor\text{iters}/d\rfloor}}.$$ At the beginning of training, we have $$\frac{r}{t}=0$$, which meets the condition r=0.

Then, if we want to train the model with the DM objective for the first M iterations, we can simply modify this mapping into a piecewise function: $$\frac{r}{t}=\begin{cases}0, & \text{iters} \le M,\\[4pt] 1-\dfrac{n(t)}{q^{\lfloor(\text{iters}-M)/d\rfloor}}, & \text{iters} > M.\end{cases}$$
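A small sketch of that piecewise mapping (the function name and the `n(t)` placeholder are illustrative; ECT's actual `n(t)` and hyperparameters `q`, `d` would be substituted):

```python
import math

def ratio_r_over_t(t, iters, M, q=2.0, d=1000, n=lambda t: 1.0):
    """Piecewise schedule: pure DM objective (r = 0) for the first M
    iterations, then the usual shrinking schedule with the iteration
    count shifted by the warmup length M."""
    if iters <= M:
        return 0.0                    # r = 0 -> diffusion pretraining
    a = math.floor((iters - M) / d)   # schedule restarts after warmup
    return 1.0 - n(t) / (q ** a)

# During warmup the mapping pins r/t to 0:
assert ratio_r_over_t(0.5, iters=100, M=1000) == 0.0
```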

Is it correct to include the pre-training stage in your codebase using such a modification?

Additionally, when running your code on CIFAR-10, I find that even though the FID improves and the visual quality gets better during training, the loss does not decrease and stays at about 15.8. If the loss does not decrease, what does the model learn from the data and the objective function?
