Skip to content

Commit

Permalink
add some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 22, 2021
1 parent 2c368d1 commit 976f489
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 1 deletion.
33 changes: 33 additions & 0 deletions .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Test

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

jobs:
build:

runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest
run: |
python setup.py test
8 changes: 7 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.25.1',
version = '0.25.3',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
Expand All @@ -19,6 +19,12 @@
'torch>=1.6',
'torchvision'
],
setup_requires=[
'pytest-runner',
],
tests_require=[
'pytest'
],
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
Expand Down
20 changes: 20 additions & 0 deletions tests/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import torch
from vit_pytorch import ViT

def test():
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img)
assert preds.shape == (1, 1000), 'correct logits outputted'

0 comments on commit 976f489

Please sign in to comment.