You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
27 lines
1.2 KiB
Python
27 lines
1.2 KiB
Python
import tensorflow as tf
from tensorflow.keras.callbacks import CSVLogger

# Train a small convolutional classifier on the MNIST digits with the SGD
# optimizer (Keras default hyperparameters) and write per-epoch metrics to a
# CSV log file via the CSVLogger callback.

NUM_CLASSES = 10            # width of the softmax output layer / one-hot labels
IMAGE_SHAPE = (28, 28, 1)   # height, width, channels fed to the first Conv2D
BATCH_SIZE = 32
EPOCHS = 20
LOG_PATH = 'SGD_01_b32.log'

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Add the trailing channel axis expected by Conv2D, then scale pixel values
# by 1/255. Casting to float32 first avoids materialising float64 copies of
# both image arrays.
x_train = x_train.reshape(x_train.shape[0], *IMAGE_SHAPE).astype('float32') / 255.0
x_test = x_test.reshape(x_test.shape[0], *IMAGE_SHAPE).astype('float32') / 255.0

# One-hot encode the labels. num_classes is passed explicitly so the encoded
# width always matches the output layer instead of being inferred from the
# largest label present in each split.
y_train = tf.keras.utils.to_categorical(y_train, num_classes=NUM_CLASSES)
y_test = tf.keras.utils.to_categorical(y_test, num_classes=NUM_CLASSES)

# Two conv + max-pool stages followed by a dense classifier head.
model = tf.keras.models.Sequential()
model.add(
    tf.keras.layers.Conv2D(
        24,
        kernel_size=5,
        padding='same',
        activation='relu',
        input_shape=IMAGE_SHAPE,
    )
)
model.add(tf.keras.layers.MaxPool2D())
model.add(
    tf.keras.layers.Conv2D(64, kernel_size=5, padding='same', activation='relu')
)
model.add(tf.keras.layers.MaxPool2D(padding='same'))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(256, activation='relu'))
model.add(tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'))

model.compile(
    optimizer=tf.keras.optimizers.SGD(),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

# Stream per-epoch training/validation metrics to LOG_PATH.
csv_logger = CSVLogger(LOG_PATH)

# The test split doubles as the validation set reported after every epoch.
history = model.fit(
    x_train,
    y_train,
    validation_data=(x_test, y_test),
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[csv_logger],
)