mlpractical/data_augmentations.py
2024-10-23 01:59:06 +08:00

55 lines
1.4 KiB
Python

from PIL import Image
from numpy import random
from torchvision import transforms
import numpy as np
import torch
class Cutout:
    """Randomly mask out one or more square patches from an image.

    Each patch is centred on a uniformly sampled pixel and clipped to the
    image borders, so a patch close to an edge is smaller than requested.
    The patch extends ``length // 2`` pixels on either side of its centre,
    which means an odd ``length`` yields a patch of ``length - 1`` pixels.

    Patch centres are drawn from NumPy's global random generator, so
    ``np.random.seed`` controls reproducibility.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): The length (in pixels) of each square patch.
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __repr__(self):
        return "{}(n_holes={}, length={})".format(
            type(self).__name__, self.n_holes, self.length
        )

    def __call__(self, img):
        """Cut ``n_holes`` patches out of ``img``.

        Args:
            img (Tensor or PIL Image): Tensor image whose last two
                dimensions are (H, W), typically (C, H, W), or a PIL image.
                A PIL image is converted to a tensor for masking and
                converted back before returning.

        Returns:
            Tensor or PIL Image: A new image of the same kind as the input
            with the selected patches set to zero. The same patches are
            applied to every channel (and to every leading dimension).
            The input is not modified.
        """
        # Tensors are the common case, so test for them first; the PIL check
        # uses isinstance because Image.open returns subclasses of
        # Image.Image (e.g. JpegImageFile) that an exact type match rejects.
        from_pil = (
            not isinstance(img, torch.Tensor)
            and isinstance(img, Image.Image)
        )
        if from_pil:
            img = transforms.ToTensor()(img)

        h, w = img.shape[-2:]
        # Build the mask on the image's device so the multiplication below
        # also works for non-CPU tensors. It stays float32 so the result
        # dtype follows the same promotion rules as before.
        mask = torch.ones((h, w), dtype=torch.float32, device=img.device)
        half = self.length // 2

        for _ in range(self.n_holes):
            # Draw y before x to keep the random stream order stable.
            y = np.random.randint(0, h)
            x = np.random.randint(0, w)
            y1, y2 = (int(v) for v in np.clip((y - half, y + half), 0, h))
            x1, x2 = (int(v) for v in np.clip((x - half, x + half), 0, w))
            mask[y1:y2, x1:x2] = 0.0

        # The (H, W) mask broadcasts over all leading dimensions.
        img = img * mask

        if from_pil:
            img = transforms.ToPILImage()(img)
        return img