Skip to content

Commit

Permalink
update minimum version for nested tensor of NaViT
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 10, 2024
1 parent 6693d47 commit 0449865
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
python -m pip install --upgrade pip
python -m pip install pytest
python -m pip install wheel
python -m pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu
python -m pip install torch==2.5.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest
run: |
Expand Down
4 changes: 2 additions & 2 deletions vit_pytorch/na_vit_nested_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import torch
import packaging.version as pkg_version

if pkg_version.parse(torch.__version__) < pkg_version.parse('2.4'):
print('nested tensor NaViT was tested on pytorch 2.4')
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
print('nested tensor NaViT was tested on pytorch 2.5')

from torch import nn, Tensor
import torch.nn.functional as F
Expand Down
4 changes: 2 additions & 2 deletions vit_pytorch/na_vit_nested_tensor_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import torch
import packaging.version as pkg_version

if pkg_version.parse(torch.__version__) < pkg_version.parse('2.4'):
print('nested tensor NaViT was tested on pytorch 2.4')
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
print('nested tensor NaViT was tested on pytorch 2.5')

from torch import nn, Tensor
import torch.nn.functional as F
Expand Down

0 comments on commit 0449865

Please sign in to comment.