You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

27 lines
1.2 KiB
Python

import tensorflow as tf
from tensorflow.keras.callbacks import CSVLogger
# Train a small CNN on MNIST with *full-batch* gradient descent: every
# optimizer step sees the whole training set at once (batch_size equals the
# number of training samples), so the 20 epochs below amount to 20 gradient
# updates. Per-epoch train/validation loss and accuracy are written to
# GD_1.log by the CSVLogger callback (CSVLogger overwrites the file by default).

NUM_CLASSES = 10  # digits 0-9; must match the width of the output layer

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Add a trailing channel axis so the images fit Conv2D's (H, W, C) layout.
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)

# Scale pixel values from [0, 255] to [0, 1]. Cast to float32 *before*
# dividing: dividing the integer arrays by a Python float would yield float64,
# doubling host memory for data the float32 model would downcast anyway.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# One-hot encode labels for categorical_crossentropy. Passing num_classes
# explicitly keeps the encoding width tied to the output layer instead of
# inferring it from the largest label that happens to be present.
y_train = tf.keras.utils.to_categorical(y_train, num_classes=NUM_CLASSES)
y_test = tf.keras.utils.to_categorical(y_test, num_classes=NUM_CLASSES)

model = tf.keras.models.Sequential()
# Declare the input shape with an explicit Input instead of the deprecated
# `input_shape=` layer argument; the resulting model is the same.
model.add(tf.keras.Input(shape=(28, 28, 1)))
model.add(tf.keras.layers.Conv2D(24, kernel_size=5, padding='same', activation='relu'))
model.add(tf.keras.layers.MaxPool2D())                # default 2x2 pool: 28x28 -> 14x14
model.add(tf.keras.layers.Conv2D(64, kernel_size=5, padding='same', activation='relu'))
model.add(tf.keras.layers.MaxPool2D(padding='same'))  # 14x14 -> 7x7
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(256, activation='relu'))
model.add(tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'))

model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

csv_logger = CSVLogger('GD_1.log')
# NOTE: batch_size is the full training-set size (full-batch GD, by design for
# this experiment), so the entire training set goes through the network in a
# single step; this can need a large amount of accelerator/host memory.
# The test split doubles as validation data: fine for monitoring curves, but
# it should not be used for model selection.
history = model.fit(
    x_train,
    y_train,
    validation_data=(x_test, y_test),
    batch_size=x_train.shape[0],
    epochs=20,
    callbacks=[csv_logger],
)