# mlpractical/network_builder.py
#
# (Source-viewer metadata: 178 lines, 9.0 KiB, Python.)

import tensorflow as tf
from network_architectures import VGGClassifier, FCCLayerClassifier
class ClassifierNetworkGraph:
    """Builds the TensorFlow 1.x graph ops needed to train and evaluate an image classifier.

    Wraps one of the architectures imported from ``network_architectures`` and adds optional
    rotation augmentation, cross-entropy loss / accuracy computation, TensorBoard summaries
    and an Adam training op.
    """

    def __init__(self, input_x, target_placeholder, dropout_rate,
                 batch_size=100, n_classes=100, is_training=True, augment_rotate_flag=True,
                 tensorboard_use=False, use_batch_normalization=False, strided_dim_reduction=True,
                 network_name='VGG_classifier'):
        """
        Initializes a Classifier Network Graph that can build models, train, compute losses and save summary
        statistics and images.
        :param input_x: A placeholder that will feed the input images, usually of size
            [batch_size, height, width, channels]
        :param target_placeholder: A target placeholder of size [batch_size,]. The classes should be in index
            form, i.e. not one-hot encoded; the sparse cross-entropy op consumes the indices directly.
        :param dropout_rate: A placeholder holding a single float that defines the amount of dropout to apply
            to the network, i.e. 0.1 drops 10% of the units. It is forwarded unchanged to the network.
        :param batch_size: The batch size
        :param n_classes: Number of classes we will be classifying
        :param is_training: A placeholder that will indicate whether we are training or not; forwarded to
            the network as ``training``
        :param augment_rotate_flag: A boolean placeholder (or a plain Python bool) indicating whether to apply
            random rotation augmentations to the input data
        :param tensorboard_use: Whether to use tensorboard in this experiment (stored as self.is_tensorboard)
        :param use_batch_normalization: Whether to use batch normalization between layers
        :param strided_dim_reduction: Whether to use strided dim reduction instead of max pooling
        :param network_name: Which architecture to build: "VGG_classifier" or "FCCClassifier"
        :raises ValueError: If network_name is not one of the supported architectures
        """
        self.batch_size = batch_size
        architectures = {
            "VGG_classifier": VGGClassifier,
            "FCCClassifier": FCCLayerClassifier,
        }
        if network_name not in architectures:
            # Fail fast: otherwise self.c would never be set and loss()/train() would later die
            # with an unhelpful AttributeError.
            raise ValueError("Unknown network_name {!r}; expected one of {}".format(
                network_name, sorted(architectures)))
        self.c = architectures[network_name](self.batch_size, name="classifier_neural_network",
                                             batch_norm_use=use_batch_normalization, num_classes=n_classes,
                                             layer_stage_sizes=[64, 128, 256],
                                             strided_dim_reduction=strided_dim_reduction)
        self.input_x = input_x
        self.dropout_rate = dropout_rate
        self.targets = target_placeholder
        self.training_phase = is_training
        self.n_classes = n_classes
        self.iterations_trained = 0
        self.augment_rotate = augment_rotate_flag
        self.is_tensorboard = tensorboard_use
        self.strided_dim_reduction = strided_dim_reduction
        self.use_batch_normalization = use_batch_normalization

    def loss(self):
        """Builds the model, calculates losses and saves summary statistics and images.

        Returns:
            dict with keys "crossentropy_losses" and "accuracy". Each value is the tf.add_n sum of every
            tensor in the graph collection of the same name, so if loss() is called more than once in a
            graph the per-call values are summed (not averaged).
        """
        with tf.name_scope("losses"):
            image_inputs = self.data_augment_batch(self.input_x)  # conditionally apply augmentations
            true_outputs = self.targets
            # produce predictions and get layer features to save for visual inspection
            preds, layer_features = self.c(image_input=image_inputs, training=self.training_phase,
                                           dropout_rate=self.dropout_rate)
            # compute loss and accuracy
            correct_prediction = tf.equal(tf.argmax(preds, 1), tf.cast(true_outputs, tf.int64))
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
            crossentropy_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(labels=true_outputs, logits=preds))
            # add loss and accuracy to collections
            tf.add_to_collection('crossentropy_losses', crossentropy_loss)
            tf.add_to_collection('accuracy', accuracy)
            # Image summaries only make sense for [batch, height, width, channels] inputs: the batch is
            # concatenated along the height axis into a single [1, batch*height, width, channels] image.
            if len(self.input_x.get_shape().as_list()) == 4:
                self.save_features(name="VGG_features", features=layer_features)
                tf.summary.image('image', [tf.concat(tf.unstack(self.input_x, axis=0), axis=0)])
                tf.summary.image('augmented_image', [tf.concat(tf.unstack(image_inputs, axis=0), axis=0)])
            tf.summary.scalar('crossentropy_losses', crossentropy_loss)
            tf.summary.scalar('accuracy', accuracy)
        return {"crossentropy_losses": tf.add_n(tf.get_collection('crossentropy_losses'),
                                                name='total_classification_loss'),
                "accuracy": tf.add_n(tf.get_collection('accuracy'), name='total_accuracy')}

    def save_features(self, name, features, num_rows_in_grid=4):
        """
        Saves layer features as TensorBoard image summaries, tiling each layer's channels into a grid.
        :param name: Prefix for the summary names; the summary for features[i] is named '<name>_<i>'
        :param features: A list of rank-4 feature tensors [batch, height, width, channels]
        :param num_rows_in_grid: Number of grid rows; must evenly divide each feature's channel count
        :raises ValueError: If a feature's channel count is not divisible by num_rows_in_grid
        """
        for i, feature in enumerate(features):
            _, height, width, channels = feature.get_shape().as_list()
            if channels % num_rows_in_grid != 0:
                raise ValueError("Feature {} has {} channels, which cannot be split into {} grid rows".format(
                    i, channels, num_rows_in_grid))
            num_columns = channels // num_rows_in_grid
            # Split the channel axis into (rows, columns); -1 keeps the batch dimension dynamic.
            grid = tf.reshape(feature, shape=(-1, height, width, num_rows_in_grid, num_columns))
            # Place the columns side by side along the width axis: [batch, height, width*columns, rows] ...
            grid = tf.concat(tf.unstack(grid, axis=4), axis=2)
            # ... then stack the rows along the height axis: [batch, height*rows, width*columns].
            grid = tf.concat(tf.unstack(grid, axis=3), axis=1)
            # Add a single channel axis so tf.summary.image receives a rank-4 tensor.
            tf.summary.image('{}_{}'.format(name, i), tf.expand_dims(grid, axis=3))

    def rotate_image(self, image):
        """
        Randomly rotates a single image.
        With probability 1/2 the image is rotated by k*90 degrees, with k drawn uniformly from {1, 2, 3};
        otherwise it is returned unchanged.
        :param image: A single image tensor, as accepted by tf.image.rot90 (presumably [height, width, channels])
        :return: A rotated or a non rotated image depending on the result of the flip
        """
        # tf.random_uniform's integer range is [minval, maxval): maxval=101 draws from 1..100,
        # so "<= 50" is an exact 50% chance.
        flip_draw = tf.random_uniform([], minval=1, maxval=101, dtype=tf.int32)
        flip_boolean = tf.less_equal(flip_draw, 50)
        # maxval=4 draws k from {1, 2, 3}, i.e. 90, 180 or 270 degrees (maxval=3 would never give 270).
        num_quarter_turns = tf.random_uniform([], minval=1, maxval=4, dtype=tf.int32)
        return tf.cond(flip_boolean,
                       lambda: tf.image.rot90(image, k=num_quarter_turns),
                       lambda: image)

    def rotate_batch(self, batch_images):
        """
        Rotate a batch of images, deciding independently per image whether (and how far) to rotate.
        Inputs with fewer than 4 dimensions (e.g. flattened vectors) are returned unchanged.
        Requires a fully defined static shape, since tf.unstack needs a known batch size.
        NOTE(review): the final reshape presumably assumes square images (height == width); a 90/270 degree
        rotation of a non-square image would not fit back into the original shape.
        :param batch_images: A batch of images
        :return: A rotated batch of images (some images will not be rotated if their rotation flip ends up False)
        """
        shape = batch_images.get_shape().as_list()
        if len(shape) < 4:
            return batch_images
        batch_size, x, y, c = shape
        with tf.name_scope('augment'):
            new_images = tf.stack([self.rotate_image(image) for image in tf.unstack(batch_images)])
            # tf.cond inside rotate_image drops the static shape; restore it for downstream layers.
            new_images = tf.reshape(new_images, (batch_size, x, y, c))
        return new_images

    def data_augment_batch(self, batch_images):
        """
        Augments data with a variety of augmentations; currently only random rotations.
        :param batch_images: A batch of images to augment
        :return: Augmented data (or the input unchanged when rotation augmentation is off)
        """
        if isinstance(self.augment_rotate, bool):
            # TF1 graph-mode tf.cond raises TypeError for a Python bool predicate (e.g. the default
            # augment_rotate_flag=True), so resolve a static flag at graph-construction time.
            return self.rotate_batch(batch_images) if self.augment_rotate else batch_images
        return tf.cond(self.augment_rotate, lambda: self.rotate_batch(batch_images), lambda: batch_images)

    def train(self, losses, learning_rate=1e-3, beta1=0.9):
        """
        Builds the Adam training op for the classifier.
        :param losses: dict as returned by loss(); only losses["crossentropy_losses"] is used
        :param learning_rate: Adam learning rate
        :param beta1: Adam first-moment decay rate
        :return: An op that applies one optimisation step to the classifier's variables (self.c.variables)
        """
        c_opt = tf.train.AdamOptimizer(beta1=beta1, learning_rate=learning_rate)
        # UPDATE_OPS holds ops such as batch-norm moving-statistics updates; make every step run them.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            c_error_opt_op = c_opt.minimize(losses["crossentropy_losses"], var_list=self.c.variables,
                                            colocate_gradients_with_ops=True)
        return c_error_opt_op

    def init_train(self):
        """
        Builds the loss, training and merged-summary ops and returns them.
        :return: Tuple (summary_op, losses_ops, c_error_opt_op), where losses_ops is the dict from loss()
        """
        losses_ops = self.loss()
        c_error_opt_op = self.train(losses_ops)
        # Merge after loss() so the summaries it creates are included.
        summary_op = tf.summary.merge_all()
        return summary_op, losses_ops, c_error_opt_op