# mlpractical/train_evaluate_emnist_classification_system.py
# Repository viewer metadata: 96 lines, 4.7 KiB, Python; last changed 2024-10-22 19:59:06 +02:00.
import numpy as np
import data_providers as data_providers
from arg_extractor import get_args
from data_augmentations import Cutout
from experiment_builder import ExperimentBuilder
from model_architectures import ConvolutionalNetwork
# Parse the command-line options; get_args() hands back an (args, device) pair.
args, device = get_args()

# NumPy random state seeded from the command-line seed.
rng = np.random.RandomState(args.seed)

# These imports stay after get_args(), matching the original statement order.
from torchvision import transforms
import torch

# Seed PyTorch with the same value used for the NumPy random state.
torch.manual_seed(args.seed)
# Build the train / validation / test data sources for the dataset named on the
# command line and record how many classes the network has to predict.
# Defines: train_data, val_data, test_data, num_output_classes.
if args.dataset_name == 'emnist':
    # The three providers share `rng` and are created in train, valid, test
    # order (generator unpacking evaluates left to right).
    train_data, val_data, test_data = (
        data_providers.EMNISTDataProvider(which_set, batch_size=args.batch_size,
                                          rng=rng, flatten=False)
        for which_set in ('train', 'valid', 'test')
    )
    num_output_classes = train_data.num_classes
elif args.dataset_name in ('cifar10', 'cifar100'):
    # Both CIFAR variants use the same augmentation and loader settings; only
    # the dataset class and the number of classes differ.
    if args.dataset_name == 'cifar10':
        cifar_dataset = data_providers.CIFAR10
        classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
        num_output_classes = 10
    else:
        cifar_dataset = data_providers.CIFAR100
        num_output_classes = 100

    # NOTE(review): the same normalisation constants are applied to both CIFAR
    # variants; confirm they are the intended statistics for CIFAR-100.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    # Training images are randomly cropped, cut out and flipped before being
    # converted to tensors; evaluation images are only converted and normalised.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        Cutout(n_holes=1, length=14),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # NOTE(review): the loaders use a fixed batch size of 100 rather than
    # args.batch_size; confirm this is intentional.
    trainset = cifar_dataset(root='data', set_name='train', download=True, transform=transform_train)
    train_data = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True, num_workers=2)

    valset = cifar_dataset(root='data', set_name='val', download=True, transform=transform_test)
    val_data = torch.utils.data.DataLoader(valset, batch_size=100, shuffle=False, num_workers=2)

    testset = cifar_dataset(root='data', set_name='test', download=True, transform=transform_test)
    test_data = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
else:
    # Fail immediately with a clear message instead of hitting a NameError on
    # `num_output_classes` / `train_data` further down the script.
    raise ValueError(
        f"Unsupported dataset_name {args.dataset_name!r}; "
        "expected one of 'emnist', 'cifar10', 'cifar100'."
    )
# Instantiate the convolutional classifier from the command-line settings.
# NOTE(review): args.image_height is passed for both of the last two entries of
# input_shape; confirm whether a separate width argument was intended.
custom_conv_net = ConvolutionalNetwork(
    input_shape=(
        args.batch_size,
        args.image_num_channels,
        args.image_height,
        args.image_height,
    ),
    dim_reduction_type=args.dim_reduction_type,
    num_filters=args.num_filters,
    num_layers=args.num_layers,
    use_bias=False,
    num_output_classes=num_output_classes,
)
# Wire the model and the three data sources into an experiment object.
conv_experiment = ExperimentBuilder(
    network_model=custom_conv_net,
    use_gpu=args.use_gpu,
    experiment_name=args.experiment_name,
    num_epochs=args.num_epochs,
    weight_decay_coefficient=args.weight_decay_coefficient,
    continue_from_epoch=args.continue_from_epoch,
    train_data=train_data,
    val_data=val_data,
    test_data=test_data,
)

# Run the experiment; run_experiment() returns the experiment metrics and the test metrics.
experiment_metrics, test_metrics = conv_experiment.run_experiment()