Remove num_channels variable, since it is now reduntant
This commit is contained in:
parent
b71ea0c5fe
commit
e6f570109d
@ -44,7 +44,7 @@ dropout_rate = tf.placeholder(tf.float32, name='dropout-prob')
|
||||
|
||||
classifier_network = ClassifierNetworkGraph(input_x=data_inputs, target_placeholder=data_targets,
|
||||
dropout_rate=dropout_rate, batch_size=batch_size,
|
||||
num_channels=train_data.inputs.shape[2], n_classes=train_data.num_classes,
|
||||
n_classes=train_data.num_classes,
|
||||
is_training=training_phase, augment_rotate_flag=rotate_data,
|
||||
strided_dim_reduction=strided_dim_reduction,
|
||||
use_batch_normalization=batch_norm) # initialize our computational graph
|
||||
|
@ -44,8 +44,8 @@ dropout_rate = tf.placeholder(tf.float32, name='dropout-prob')
|
||||
|
||||
classifier_network = ClassifierNetworkGraph(input_x=data_inputs, target_placeholder=data_targets,
|
||||
dropout_rate=dropout_rate, batch_size=batch_size,
|
||||
num_channels=train_data.inputs.shape[2], n_classes=train_data.num_classes,
|
||||
is_training=training_phase, augment_rotate_flag=rotate_data,
|
||||
n_classes=train_data.num_classes, is_training=training_phase,
|
||||
augment_rotate_flag=rotate_data,
|
||||
strided_dim_reduction=strided_dim_reduction,
|
||||
use_batch_normalization=batch_norm) # initialize our computational graph
|
||||
|
||||
|
@ -44,7 +44,7 @@ dropout_rate = tf.placeholder(tf.float32, name='dropout-prob')
|
||||
|
||||
classifier_network = ClassifierNetworkGraph(input_x=data_inputs, target_placeholder=data_targets,
|
||||
dropout_rate=dropout_rate, batch_size=batch_size,
|
||||
num_channels=train_data.inputs.shape[2], n_classes=train_data.num_classes,
|
||||
n_classes=train_data.num_classes,
|
||||
is_training=training_phase, augment_rotate_flag=rotate_data,
|
||||
strided_dim_reduction=strided_dim_reduction,
|
||||
use_batch_normalization=batch_norm) # initialize our computational graph
|
||||
|
@ -44,7 +44,7 @@ dropout_rate = tf.placeholder(tf.float32, name='dropout-prob')
|
||||
classifier_network = ClassifierNetworkGraph(network_name='FCCClassifier',
|
||||
input_x=data_inputs, target_placeholder=data_targets,
|
||||
dropout_rate=dropout_rate, batch_size=batch_size,
|
||||
num_channels=1, n_classes=train_data.num_classes,
|
||||
n_classes=train_data.num_classes,
|
||||
is_training=training_phase, augment_rotate_flag=rotate_data,
|
||||
strided_dim_reduction=strided_dim_reduction,
|
||||
use_batch_normalization=batch_norm) # initialize our computational graph
|
||||
|
@ -44,7 +44,7 @@ dropout_rate = tf.placeholder(tf.float32, name='dropout-prob')
|
||||
classifier_network = ClassifierNetworkGraph(network_name='FCCClassifier',
|
||||
input_x=data_inputs, target_placeholder=data_targets,
|
||||
dropout_rate=dropout_rate, batch_size=batch_size,
|
||||
num_channels=1, n_classes=train_data.num_classes,
|
||||
n_classes=train_data.num_classes,
|
||||
is_training=training_phase, augment_rotate_flag=rotate_data,
|
||||
strided_dim_reduction=strided_dim_reduction,
|
||||
use_batch_normalization=batch_norm) # initialize our computational graph
|
||||
|
@ -6,7 +6,7 @@ from utils.network_summary import count_parameters
|
||||
|
||||
|
||||
class VGGClassifier:
|
||||
def __init__(self, batch_size, layer_stage_sizes, name, num_classes, num_channels=1, batch_norm_use=False,
|
||||
def __init__(self, batch_size, layer_stage_sizes, name, num_classes, batch_norm_use=False,
|
||||
inner_layer_depth=2, strided_dim_reduction=True):
|
||||
|
||||
"""
|
||||
@ -28,7 +28,6 @@ class VGGClassifier:
|
||||
"""
|
||||
self.reuse = False
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.layer_stage_sizes = layer_stage_sizes
|
||||
self.name = name
|
||||
self.num_classes = num_classes
|
||||
@ -89,7 +88,7 @@ class VGGClassifier:
|
||||
|
||||
|
||||
class FCCLayerClassifier:
|
||||
def __init__(self, batch_size, layer_stage_sizes, name, num_classes, num_channels=1, batch_norm_use=False,
|
||||
def __init__(self, batch_size, layer_stage_sizes, name, num_classes, batch_norm_use=False,
|
||||
inner_layer_depth=2, strided_dim_reduction=True):
|
||||
|
||||
"""
|
||||
@ -111,7 +110,6 @@ class FCCLayerClassifier:
|
||||
"""
|
||||
self.reuse = False
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.layer_stage_sizes = layer_stage_sizes
|
||||
self.name = name
|
||||
self.num_classes = num_classes
|
||||
|
@ -1,11 +1,10 @@
|
||||
import tensorflow as tf
|
||||
|
||||
from network_architectures import VGGClassifier, FCCLayerClassifier
|
||||
|
||||
|
||||
class ClassifierNetworkGraph:
|
||||
def __init__(self, input_x, target_placeholder, dropout_rate,
|
||||
batch_size=100, num_channels=1, n_classes=100, is_training=True, augment_rotate_flag=True,
|
||||
batch_size=100, n_classes=100, is_training=True, augment_rotate_flag=True,
|
||||
tensorboard_use=False, use_batch_normalization=False, strided_dim_reduction=True,
|
||||
network_name='VGG_classifier'):
|
||||
|
||||
@ -30,14 +29,12 @@ class ClassifierNetworkGraph:
|
||||
self.batch_size = batch_size
|
||||
if network_name == "VGG_classifier":
|
||||
self.c = VGGClassifier(self.batch_size, name="classifier_neural_network",
|
||||
batch_norm_use=use_batch_normalization, num_channels=num_channels,
|
||||
num_classes=n_classes, layer_stage_sizes=[64, 128, 256],
|
||||
strided_dim_reduction=strided_dim_reduction)
|
||||
batch_norm_use=use_batch_normalization, num_classes=n_classes,
|
||||
layer_stage_sizes=[64, 128, 256], strided_dim_reduction=strided_dim_reduction)
|
||||
elif network_name == "FCCClassifier":
|
||||
self.c = FCCLayerClassifier(self.batch_size, name="classifier_neural_network",
|
||||
batch_norm_use=use_batch_normalization, num_channels=num_channels,
|
||||
num_classes=n_classes, layer_stage_sizes=[64, 128, 256],
|
||||
strided_dim_reduction=strided_dim_reduction)
|
||||
batch_norm_use=use_batch_normalization, num_classes=n_classes,
|
||||
layer_stage_sizes=[64, 128, 256], strided_dim_reduction=strided_dim_reduction)
|
||||
|
||||
self.input_x = input_x
|
||||
self.dropout_rate = dropout_rate
|
||||
|
Loading…
Reference in New Issue
Block a user