-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathData_Augmentation.py
36 lines (27 loc) · 1.05 KB
/
Data_Augmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from pathlib import Path

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import datasets
from torchvision.utils import save_image
# Augmentation pipeline applied to every image loaded by the dataset.
# The steps run in the order listed; the final step converts the image
# to a tensor so it can be batched and written out with save_image.
transform_data = transforms.Compose(
    [
        # Random rotation; degrees=30 bounds the rotation angle.
        transforms.RandomRotation(degrees=30),
        # Random brightness change controlled by the 0.4 jitter amount.
        transforms.ColorJitter(brightness=0.4),
        # Horizontal mirror with probability 0.5.
        transforms.RandomHorizontalFlip(p=0.5),
        # Grayscale conversion with probability 0.2.
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
    ]
)
def data_aug(genre, times, file_path):
    """Write randomly augmented copies of the images under *file_path*.

    The images are loaded through ``datasets.ImageFolder`` with the
    module-level ``transform_data`` pipeline, so every pass over the data
    produces a freshly randomized version of each image. Results are saved
    as ``Augmented Images/<genre>/<genre>_image_Aug_<n>.jpg`` with ``n``
    counting up from 1 across all passes.

    Args:
        genre: Label used both as the output sub-directory name and as the
            file-name prefix of every saved image.
        times: Number of full passes over the dataset; each pass writes one
            augmented image per source image. A value of 0 writes nothing.
        file_path: Root directory handed to ``datasets.ImageFolder``.

    Returns:
        None. The function is used for its side effect of writing files.
    """
    dataset = datasets.ImageFolder(file_path, transform=transform_data)
    # batch_size=1 so each saved file holds exactly one augmented image.
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

    # save_image does not create missing directories, so make sure the
    # destination exists before the first write.
    output_dir = Path('Augmented Images') / genre
    output_dir.mkdir(parents=True, exist_ok=True)

    image_num = 1
    for _ in range(times):
        # The class index from ImageFolder is not needed for saving.
        for img, _label in data_loader:
            target = output_dir / (genre + '_image_Aug_' + str(image_num) + '.jpg')
            save_image(img, str(target))
            image_num += 1
# Example usage (file_path is required): data_aug('Hentai', 6, 'path/to/image_folder')