Skip to content
forked from AgentMaker/PaTTA

a test times augmentation toolkit based on paddle2.0.

License

Notifications You must be signed in to change notification settings

hua-xu123/PaTTA

 
 

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

9 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Patta

Image Test Time Augmentation with Paddle2.0!

           Input
             |           # input batch of images 
        / / /|\ \ \      # apply augmentations (flips, rotation, scale, etc.)
       | | | | | | |     # pass augmented batches through model
       | | | | | | |     # reverse transformations for each batch of masks/labels
        \ \ \ / / /      # merge predictions (mean, max, gmean, etc.)
             |           # output batch of masks/labels
           Output

Table of Contents

  1. Quick Start
  2. Transforms
  3. Aliases
  4. Merge modes
  5. Installation

Quick start (Default Transforms)

Test

We support that you can use the following to test after defining the network.

Segmentation model wrapping [docstring]:
import patta as tta
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode='mean')
Classification model wrapping [docstring]:
tta_model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform())
Keypoints model wrapping [docstring]:
tta_model = tta.KeypointsTTAWrapper(model, tta.aliases.flip_transform(), scaled=True)

Note: the model must return keypoints in the format Tensor([x1, y1, ..., xn, yn])

Predict

We support that you can use the following to test when you have the predictive model: __model____params__.

Load model [docstring]:
import patta as tta
model = tta.load_model(path='model', model_filename='__model__', params_filename='__params__')
Segmentation model wrapping [docstring]:
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode='mean')
Classification model wrapping [docstring]:
tta_model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform())
Keypoints model wrapping [docstring]:
tta_model = tta.KeypointsTTAWrapper(model, tta.aliases.flip_transform(), scaled=True)

Advanced Examples (DIY Transforms)

Custom transform:
# defined 2 * 2 * 3 * 3 = 36 augmentations !
transforms = tta.Compose(
    [
        tta.HorizontalFlip(),
        tta.Rotate90(angles=[0, 180]),
        tta.Scale(scales=[1, 2, 4]),
        tta.Multiply(factors=[0.9, 1, 1.1]),        
    ]
)

tta_model = tta.SegmentationTTAWrapper(model, transforms)
Custom model (multi-input / multi-output)
# Example how to process ONE batch on images with TTA
# Here `image`/`mask` are 4D tensors (B, C, H, W), `label` is 2D tensor (B, N)

for transformer in transforms: # custom transforms or e.g. tta.aliases.d4_transform() 
    
    # augment image
    augmented_image = transformer.augment_image(image)
    
    # pass to model
    model_output = model(augmented_image, another_input_data)
    
    # reverse augmentation for mask and label
    deaug_mask = transformer.deaugment_mask(model_output['mask'])
    deaug_label = transformer.deaugment_label(model_output['label'])
    
    # save results
    labels.append(deaug_mask)
    masks.append(deaug_label)
    
# reduce results as you want, e.g mean/max/min
label = mean(labels)
mask = mean(masks)

Optional Transforms

Transform Parameters Values
HorizontalFlip - -
VerticalFlip - -
Rotate90 angles List[0, 90, 180, 270]
Scale scales
interpolation
List[float]
"nearest"/"linear"
Resize sizes
original_size
interpolation
List[Tuple[int, int]]
Tuple[int,int]
"nearest"/"linear"
Add values List[float]
Multiply factors List[float]
FiveCrops crop_height
crop_width
int
int

Aliases (Combos)

  • flip_transform (horizontal + vertical flips)
  • hflip_transform (horizontal flip)
  • d4_transform (flips + rotation 0, 90, 180, 270)
  • multiscale_transform (scale transform, take scales as input parameter)
  • five_crop_transform (corner crops + center crop)
  • ten_crop_transform (five crops + five crops on horizontal flip)

Merge modes

Installation

PyPI:

# After downloading the whole dir
$ git clone https://github.com/AgentMaker/PaTTA.git
$ pip install PaTTA/

# or

$ pip install git+https://github.com/AgentMaker/PaTTA.git

Run tests

# run test_transforms.py and test_base.py for test
python test/test_transforms.py
python test/test_base.py

About

a test times augmentation toolkit based on paddle2.0.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

  • Python 100.0%