Skip to content

Commit

Permalink
Fix irfft bug from torch >= 1.7.0
Browse files Browse the repository at this point in the history
  • Loading branch information
greentfrapp committed Mar 24, 2021
1 parent c5ff13c commit 3191907
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion lucent/optvis/param/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
TORCH_VERSION = torch.__version__


def pixel_image(shape, sd=None):
Expand Down Expand Up @@ -54,7 +55,14 @@ def fft_image(shape, sd=None, decay_power=1):

def inner():
scaled_spectrum_t = scale * spectrum_real_imag_t
image = torch.irfft(scaled_spectrum_t, 2, normalized=True, signal_sizes=(h, w))
if TORCH_VERSION >= "1.7.0":
import torch.fft
if type(scaled_spectrum_t) is not torch.complex64:
scaled_spectrum_t = torch.view_as_complex(scaled_spectrum_t)
image = torch.fft.irfftn(scaled_spectrum_t, s=(h, w), norm='ortho')
else:
import torch
image = torch.irfft(scaled_spectrum_t, 2, normalized=True, signal_sizes=(h, w))
image = image[:batch, :channels, :h, :w]
magic = 4.0 # Magic constant from Lucid library; increasing this seems to reduce saturation
image = image / magic
Expand Down

0 comments on commit 3191907

Please sign in to comment.