28 lines
853 B
Python
28 lines
853 B
Python
def count_parameters(network_variables, name, *, verbose=True):
    """
    Count the total number of scalar parameters in a list of variable objects.

    Each variable contributes the product of the sizes of its shape's
    dimensions. The list is not deduplicated: a variable that appears twice
    is counted twice.

    :param network_variables: A list of tf network variable objects (anything
        exposing ``get_shape()`` whose result iterates over dimensions).
    :param name: Name of the network; used only in the printed summary.
    :param verbose: If True (the default, matching the original behaviour),
        print each variable before counting it.
    :return: The total number of parameters across all variables.
    :raises ValueError: If a variable's shape contains a dimension of unknown
        size, since its parameter count cannot be determined.
    """
    total_parameters = 0
    for variable in network_variables:
        if verbose:
            print(variable)
        variable_parameters = 1
        for dim in variable.get_shape():
            # TF1 shapes iterate as tf.Dimension objects carrying the size in
            # ``.value``; TF2 TensorShape iterates as plain ints (or None).
            size = getattr(dim, "value", dim)
            if size is None:
                raise ValueError(
                    "Cannot count parameters of {!r}: its shape has a "
                    "dimension of unknown size".format(variable)
                )
            variable_parameters *= size
        total_parameters += variable_parameters
    print(name, "has a total of", total_parameters, "parameters")
    return total_parameters


def view_names_of_variables(variables):
    """
    Print every entry of a tf variable list, one ``print`` call per entry.

    The string form of each entry is what gets shown; nothing is returned.

    :param variables: A list of tf variables
    """
    emit = print
    for tf_variable in variables:
        emit(tf_variable)
