[Keras深度学习浅尝]实战二·RNN实现Fashion MNIST 数据集分类
与我们上篇博文[Keras深度学习浅尝]实战一结构相同,修改的地方有,定义网络与模型训练两部分,可以对比着来看。通过使用RNN结构,预测准确率略有提升,可以通过修改超参数以获得更优结果。
代码部分
# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
# Helper libraries
import os
os.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\"
import numpy as np
import matplotlib.pyplot as plt
EAGER = True
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
print(train_images.shape,train_labels.shape)
train_images = train_images.reshape([-1,28,28]) / 255.0
test_images = test_images.reshape([-1,28,28]) / 255.0
model = keras.Sequential([
#(-1,28,28)->(-1,100)
keras. s.SimpleRNN(
# for batch_input_shape, if using tensorflow as the backend, we have to put None for the batch_size.
# Otherwise, model.evaluate() will get error.
input_shape=(28, 28), # Or: input_dim=INPUT_SIZE, input_length=TIME_STEPS,
units=256,
unroll=True),
keras. s.Dropout(rate=0.2),
#(-1,256)->(-1,10)
keras. s.Dense(10, activation=tf.nn.softmax)
])
print(model.summary())
lr = 0.001
epochs = 5
model.compile(optimizer=tf.train.AdamOptimizer(lr),
loss=\'sparse_categorical_crossentropy\',
metrics=[\'accuracy\'])
model.fit(train_images, train_labels, epochs=epochs,validation_data=[test_images[:1000],test_labels[:1000]])
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(np.argmax(model.predict(test_images[:10]),1),test_labels[:10])
输出结果
_________________________________________________________________
(type) Output Shape Param #
=================================================================
simple_rnn (SimpleRNN) (None, 256) 72960
_________________________________________________________________
dropout (Dropout) (None, 256) 0
_________________________________________________________________
dense (Dense) (None, 10) 2570
=================================================================
Total params: 75,530
Trainable params: 75,530
Non-trainable params: 0
_________________________________________________________________
None
Train on 60000 samples, validate on 1000 samples
Epoch 1/5
60000/60000 [==============================] - 56s 927us/step - loss: 0.7429 - acc: 0.7307 - val_loss: 0.6208 - val_acc: 0.7750
Epoch 2/5
60000/60000 [==============================] - 46s 759us/step - loss: 0.5935 - acc: 0.7876 - val_loss: 0.5550 - val_acc: 0.8060
Epoch 3/5
60000/60000 [==============================] - 50s 828us/step - loss: 0.5558 - acc: 0.8004 - val_loss: 0.4969 - val_acc: 0.8220
Epoch 4/5
60000/60000 [==============================] - 53s 886us/step - loss: 0.5267 - acc: 0.8100 - val_loss: 0.5298 - val_acc: 0.8080
Epoch 5/5
60000/60000 [==============================] - 62s 1ms/step - loss: 0.5243 - acc: 0.8115 - val_loss: 0.4916 - val_acc: 0.8180
10000/10000 [==============================] - 4s 435us/step
[9 2 1 1 6 1 6 6 5 7] [9 2 1 1 6 1 4 6 5 7]
yansongdeMacBook-Pro:TFAPP yss$
继续阅读与本文标签相同的文章
下一篇 :
微信小程序总结篇
-
天猫精灵发布“智慧屏”新品 将与平头哥共同定制语音芯片
2026-05-18栏目: 教程
-
【新手小白攻略】企业如何选择阿里云服务器配置
2026-05-18栏目: 教程
-
IntelliJ IDEA 2019 控制台中文乱码问题
2026-05-18栏目: 教程
-
超干货!奇点云《数智商业论坛》金句频出,引爆首日云栖大会
2026-05-18栏目: 教程
-
阿里董事局主席张勇:大数据和算力就是数字经济时代的石油和发动机
2026-05-18栏目: 教程
