from PIL import Image from numpy import random from torchvision import transforms import numpy as np import torch class Cutout(object): """Randomly mask out one or more patches from an image. Args: n_holes (int): Number of patches to cut out of each image. length (int): The length (in pixels) of each square patch. """ def __init__(self, n_holes, length): self.n_holes = n_holes self.length = length def __call__(self, img): """ Args: img (Tensor): Tensor image of size (C, H, W). Returns: Tensor: Image with n_holes of dimension length x length cut out of it. """ from_PIL = False if type(img) == Image.Image: from_PIL = True img = transforms.ToTensor()(img) h = img.size(1) w = img.size(2) mask = np.ones((h, w), np.float32) for n in range(self.n_holes): y = random.randint(0, h) x = random.randint(0, w) y1 = np.clip(y - self.length // 2, 0, h) y2 = np.clip(y + self.length // 2, 0, h) x1 = np.clip(x - self.length // 2, 0, w) x2 = np.clip(x + self.length // 2, 0, w) mask[y1: y2, x1: x2] = 0. mask = torch.from_numpy(mask) mask = mask.expand_as(img) img = img * mask if from_PIL: img = transforms.ToPILImage()(img) return img