From 5578ac472faf3903d4739ba783f3875b77177e57 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sat, 23 Dec 2023 08:11:39 -0800 Subject: [PATCH] address https://github.com/lucidrains/vit-pytorch/issues/292 --- setup.py | 2 +- vit_pytorch/simple_vit_with_fft.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 23429ca..d27968d 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '1.6.4', + version = '1.6.5', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', long_description_content_type = 'text/markdown', diff --git a/vit_pytorch/simple_vit_with_fft.py b/vit_pytorch/simple_vit_with_fft.py index caf1233..5bf6982 100644 --- a/vit_pytorch/simple_vit_with_fft.py +++ b/vit_pytorch/simple_vit_with_fft.py @@ -1,5 +1,5 @@ import torch -from torch.fft import fft +from torch.fft import fft2 from torch import nn from einops import rearrange, reduce, pack, unpack @@ -128,7 +128,7 @@ def forward(self, img): device, dtype = img.device, img.dtype x = self.to_patch_embedding(img) - freqs = torch.view_as_real(fft(img)) + freqs = torch.view_as_real(fft2(img)) f = self.to_freq_embedding(freqs)