55 lines
1.4 KiB
Python
55 lines
1.4 KiB
Python
|
from PIL import Image
|
||
|
from numpy import random
|
||
|
from torchvision import transforms
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
|
||
|
class Cutout(object):
|
||
|
"""Randomly mask out one or more patches from an image.
|
||
|
Args:
|
||
|
n_holes (int): Number of patches to cut out of each image.
|
||
|
length (int): The length (in pixels) of each square patch.
|
||
|
"""
|
||
|
|
||
|
def __init__(self, n_holes, length):
|
||
|
self.n_holes = n_holes
|
||
|
self.length = length
|
||
|
|
||
|
def __call__(self, img):
|
||
|
"""
|
||
|
Args:
|
||
|
img (Tensor): Tensor image of size (C, H, W).
|
||
|
Returns:
|
||
|
Tensor: Image with n_holes of dimension length x length cut out of it.
|
||
|
"""
|
||
|
|
||
|
from_PIL = False
|
||
|
|
||
|
if type(img) == Image.Image:
|
||
|
from_PIL = True
|
||
|
img = transforms.ToTensor()(img)
|
||
|
|
||
|
h = img.size(1)
|
||
|
w = img.size(2)
|
||
|
|
||
|
mask = np.ones((h, w), np.float32)
|
||
|
|
||
|
for n in range(self.n_holes):
|
||
|
y = random.randint(0, h)
|
||
|
x = random.randint(0, w)
|
||
|
|
||
|
y1 = np.clip(y - self.length // 2, 0, h)
|
||
|
y2 = np.clip(y + self.length // 2, 0, h)
|
||
|
x1 = np.clip(x - self.length // 2, 0, w)
|
||
|
x2 = np.clip(x + self.length // 2, 0, w)
|
||
|
|
||
|
mask[y1: y2, x1: x2] = 0.
|
||
|
|
||
|
mask = torch.from_numpy(mask)
|
||
|
mask = mask.expand_as(img)
|
||
|
img = img * mask
|
||
|
|
||
|
if from_PIL:
|
||
|
img = transforms.ToPILImage()(img)
|
||
|
|
||
|
return img
|