Recognizing Handwritten Digit Images with Deep Learning

Editor 2026-06-29

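This article uses the TensorFlow framework and the MNIST handwritten digit dataset to demonstrate the introductory concepts of deep learning. The training set contains 60,000 samples (images and labels) and the test set contains 10,000 samples. Each handwritten digit image is a 28*28 single-channel grayscale image.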

First, import the required libraries:

# This article uses the TensorFlow 1.x API (tf.placeholder, tf.Session, tf.contrib)
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os

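Next, set the parameters of the fully connected neural network. Its structure is 784*500*10: an input layer with 784 nodes, one hidden layer with 500 nodes, and an output layer with 10 nodes. Every layer except the output layer uses ReLU as its activation function. Finally, tf.argmax() returns the index of the largest value among the output nodes; this index, in the range 0~9, is the predicted digit.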

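Note: the figure above is only a schematic. The actual number of nodes per layer and the number of hidden layers are those given in the code.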
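The forward pass described above can be sketched without TensorFlow. The following is a minimal pure-Python illustration, not the article's implementation: the helper names (`relu`, `dense`, `forward`, `argmax`) and the toy 2-3-2 weights are hypothetical, chosen only so the arithmetic is easy to follow by hand.

```python
def relu(v):
    """Element-wise ReLU on a list of floats."""
    return [max(0.0, a) for a in v]


def dense(v, weight, bias):
    """Compute v @ weight + bias for a row vector v.

    weight has shape [len(v)][n_out]; bias has length n_out.
    """
    return [sum(v[i] * weight[i][j] for i in range(len(v))) + bias[j]
            for j in range(len(bias))]


def forward(x, params):
    """params is a list of (weight, bias) pairs, one per layer.

    ReLU follows every layer except the last, which returns raw scores.
    """
    h = x
    for k, (w, b) in enumerate(params):
        h = dense(h, w, b)
        if k < len(params) - 1:
            h = relu(h)
    return h


def argmax(v):
    """Index of the largest element, like tf.argmax on a single row."""
    return max(range(len(v)), key=lambda i: v[i])


# Toy 2-3-2 network with hand-picked (hypothetical) weights.
toy_params = [
    ([[1.0, -1.0, 0.5],
      [0.0, 2.0, -1.0]], [0.0, 0.0, 0.0]),
    ([[1.0, 0.0],
      [0.0, 1.0],
      [1.0, 1.0]], [0.0, -1.0]),
]
logits = forward([1.0, 2.0], toy_params)
predicted = argmax(logits)
print(logits, predicted)
```

The real network has the same shape of computation, only with 784 inputs, 500 hidden nodes, and 10 outputs.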

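# Model path
MODEL_SAVE_PATH = "/model_path/"
MODEL_NAME = "MNIST_model1.ckpt"
INPUT_NODE = 28*28  # each 28*28-pixel image is flattened into 784 = 28*28 input nodes
OUTPUT_NODE = 10  # 10 outputs, one per digit 0~9
BATCH_SIZE = 100  # size of each training batch
LEARNING_RATE_BASE = 0.8  # base learning rate
LEARNING_RATE_DECAY = 0.99  # learning-rate decay factor
REGULARIZATION_RATE = 0.0001  # L2 regularization rate
TRAINING_STEPS = 30000  # total number of training steps
MOVING_AVERAGE_DECAY = 0.99  # moving-average decay factor
# Network structure, 784*500*10 (784 input nodes, one hidden layer of 500 nodes, 10 output nodes)
layer_dimension = [INPUT_NODE, 500, OUTPUT_NODE]
# Several hidden layers also work, e.g. layer_dimension = [INPUT_NODE, 50, 100, 20, OUTPUT_NODE]
n_layers = len(layer_dimension)  # total number of layers, including input and output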
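To get a feel for the size of the model these settings define, the number of trainable parameters can be counted directly from `layer_dimension`. This is a small sketch; `count_parameters` is a hypothetical helper introduced here, not part of the article's code.

```python
INPUT_NODE = 28 * 28
OUTPUT_NODE = 10
layer_dimension = [INPUT_NODE, 500, OUTPUT_NODE]


def count_parameters(dims):
    """Sum weights (fan_in * fan_out) and biases (fan_out) over all layers."""
    total = 0
    for fan_in, fan_out in zip(dims[:-1], dims[1:]):
        total += fan_in * fan_out + fan_out
    return total


n_params = count_parameters(layer_dimension)
print(n_params)  # 397510
```

Almost all of these parameters sit in the first weight matrix (784*500 = 392000), which is typical for a fully connected network on flattened images.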

Forward propagation:

def inference(input_x, avg_class, reuse=True):
    '''Forward propagation. Returns the raw scores (logits) of the output layer.'''
    current_layer = input_x
    in_dimension = INPUT_NODE
    with tf.variable_scope("layers", reuse=reuse):
        for i in range(1, n_layers):  # one iteration per layer after the input layer
            out_dimension = layer_dimension[i]
            # Create weights and biases with tf.get_variable so that repeated calls share the same variables
            weight = tf.get_variable("weight_"+str(i), [in_dimension, out_dimension], initializer=tf.truncated_normal_initializer(stddev=0.1))
            # L2 regularization. This line runs on every call to inference(), so each call adds the term to 'losses' again
            tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)(weight))
            bias = tf.get_variable("bias_"+str(i), [out_dimension], initializer=tf.constant_initializer(0.0))

            # Use the raw variables, or their moving averages when avg_class is given
            w = weight if avg_class is None else avg_class.average(weight)
            b = bias if avg_class is None else avg_class.average(bias)
            z = tf.matmul(current_layer, w) + b
            # ReLU on every layer except the output layer
            current_layer = z if i == n_layers - 1 else tf.nn.relu(z)
            in_dimension = out_dimension
    return current_layer  # output
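When `avg_class` is given, the forward pass reads a smoothed "shadow" copy of each variable instead of the variable itself. The update rule documented for `tf.train.ExponentialMovingAverage` with a `num_updates` argument can be sketched in plain Python as follows. The function names and the sample values are hypothetical, and the sketch ignores how TensorFlow schedules the update relative to the gradient step.

```python
def ema_decay(base_decay, num_updates):
    """Effective decay: small early in training, capped at base_decay later."""
    return min(base_decay, (1.0 + num_updates) / (10.0 + num_updates))


def ema_update(shadow, value, base_decay, num_updates):
    """shadow <- decay * shadow + (1 - decay) * value"""
    d = ema_decay(base_decay, num_updates)
    return d * shadow + (1.0 - d) * value


# Hypothetical trajectory of one scalar parameter over a few steps.
values = [0.0, 1.0, 1.0, 1.0]
shadow = values[0]  # the shadow copy starts at the variable's initial value
for step, value in enumerate(values[1:]):
    shadow = ema_update(shadow, value, 0.99, step)
    print(step, round(shadow, 4))
```

With a base decay of 0.99, the cap takes effect from step 890 onward; before that the shadow copy follows the variable more closely.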

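Training the neural network means running backpropagation: gradient descent is used to optimize every weight tensor W and every bias tensor B.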

def train(mnist):
    '''Build the training graph, train the model, then evaluate and save it.'''
    x = tf.placeholder(tf.float32, [None, INPUT_NODE], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')
    y = inference(x, None, reuse=False)
    global_step = tf.Variable(0, trainable=False)
    # Moving-average model
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())  # maintain averages of all trainable variables
    average_y = inference(x, variable_averages, reuse=True)

    # Loss: mean cross-entropy plus the L2 terms collected in 'losses'
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    tf.add_to_collection('losses', cross_entropy_mean)
    loss = tf.add_n(tf.get_collection('losses'))
    #loss = cross_entropy_mean
    # Learning rate with exponential decay, lowered once per pass over the training set
    learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, global_step, mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY, staircase=True)
    #learning_rate = 0.01
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    # Run the gradient step and the moving-average update together
    train_op = tf.group(train_step, variables_averages_op)
    # Accuracy is measured with the moving-average model
    correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))  # tf.cast converts bool to float
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
        test_feed = {x: mnist.test.images, y_: mnist.test.labels}
        steps = []  # only for plot
        accs = []  # only for plot
        losses = []  # only for plot
        for i in range(TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
            if i % 200 == 0:
                # Every 200 steps, record validation accuracy and the loss on the current training batch
                validate_acc = sess.run(accuracy, feed_dict=validate_feed)
                steps.append(step)
                accs.append(validate_acc*100)
                losses.append(loss_value)
                #print("After %d training steps, validation accuracy using average model is %g%%, loss on training batch is %g"%(step, validate_acc*100, loss_value))

        test_acc = sess.run(accuracy, feed_dict=test_feed)
        print("After %d training steps, test accuracy using average model is %g%%"%
              (TRAINING_STEPS, test_acc*100))

        # Only for TensorBoard: write the graph, then close the writer to flush it to disk
        writer = tf.summary.FileWriter("E://TensorBoard//test", sess.graph)
        writer.close()

        # Save the trained model (adjust the path to your own machine)
        saver = tf.train.Saver()
        saver.save(sess, r"E:\Python36\my tensorflow\ckpt files\mode_mnist.ckpt")
    
    # Only for plotting; this step is optional
    from matplotlib import pyplot as plt
    import matplotlib.ticker as mtick
    plt.subplot(211)
    plt.plot(steps, losses, color="red")
    plt.scatter(steps, losses, s=20, color="red")
    plt.xlabel("step")
    plt.ylabel("loss(including L2 regularization loss) on training set")
    plt.subplot(212)
    plt.plot(steps, accs, color="green")
    plt.scatter(steps, accs, s=20, color="green")
    yticks = mtick.FormatStrFormatter("%.3f%%")
    plt.gca().yaxis.set_major_formatter(yticks)
    plt.xlabel("step")
    plt.ylabel("accuracy on validation set")  # accs holds validation accuracy
    plt.show()
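The quantity being minimized in `train` is the mean softmax cross-entropy over the batch plus an L2 penalty on the weights. The sketch below reproduces that arithmetic in plain Python under two stated assumptions: the penalty for one weight matrix is `rate * sum(w^2) / 2`, which is how `tf.contrib.layers.l2_regularizer` is documented to scale `tf.nn.l2_loss`, and each weight matrix is counted once. The function names are hypothetical.

```python
import math


def softmax_cross_entropy(logits, label):
    """Cross-entropy of softmax(logits) against an integer class label."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def l2_penalty(weights, rate):
    """rate * sum(w^2) / 2 for one weight matrix given as nested lists."""
    return rate * sum(w * w for row in weights for w in row) / 2.0


def total_loss(batch_logits, batch_labels, weight_matrices, rate):
    """Mean cross-entropy over the batch plus the L2 penalty of every matrix."""
    ce = [softmax_cross_entropy(z, y)
          for z, y in zip(batch_logits, batch_labels)]
    mean_ce = sum(ce) / len(ce)
    return mean_ce + sum(l2_penalty(w, rate) for w in weight_matrices)


# Hypothetical two-class example: uniform logits give a loss of ln(2).
print(softmax_cross_entropy([0.0, 0.0], 0))
print(total_loss([[0.0, 0.0]], [0], [[[1.0, 2.0], [3.0, 4.0]]], 0.0001))
```

Because the penalty is part of the plotted loss, the loss curve does not approach zero even when the training batches are classified almost perfectly.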

Now load the MNIST dataset and train the model:

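def main(argv=None):
    # Reads MNIST from this directory; one_hot=True encodes each label as a 10-element vector
    mnist = input_data.read_data_sets(r"E:\Python36\my tensorflow\MNIST_data", one_hot=True)
    train(mnist)

if __name__ == "__main__":
    tf.app.run()  # parses flags and then calls main()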
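The learning-rate schedule used during training can be previewed without running TensorFlow. With `staircase=True`, `tf.train.exponential_decay` computes `base * decay ** floor(global_step / decay_steps)`. The sketch below assumes `mnist.train.num_examples` is 55000, which is what the default split of `read_data_sets` produces (5000 of the 60000 training images are held out for validation); `staircase_lr` is a hypothetical helper.

```python
import math


def staircase_lr(base, decay, global_step, decay_steps):
    """Learning rate under exponential decay with staircase=True."""
    return base * decay ** math.floor(global_step / decay_steps)


# Assumption: 55000 training examples and a batch size of 100.
decay_steps = 55000 / 100
for step in (0, 549, 550, 30000):
    print(step, staircase_lr(0.8, 0.99, step, decay_steps))
```

Under these assumptions the rate drops by 1% once every 550 steps, so after 30000 steps it has been lowered 54 times and is a little under 60% of its starting value.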

The training curves show that the model has roughly converged after about 5,000 steps.

After 30,000 iterations, the accuracy on the test set reaches 98.5%.

After 30000 training steps, test accuracy using average model is 98.5%

Next, we use the trained model to make a prediction:

import tensorflow as tf
import matplotlib.pyplot as plt
INPUT_NODE = 28*28
OUTPUT_NODE = 10
layer_dimension = [INPUT_NODE, 500, OUTPUT_NODE]  # must match the structure used for training
n_layers = len(layer_dimension)
def inference(input_x, avg_class, reuse=True):
    '''Forward propagation with the restored variables (avg_class is unused here and kept only to match the training signature).'''
    current_layer = input_x
    in_dimension = INPUT_NODE
    with tf.variable_scope("layers", reuse=reuse):
        for i in range(1, n_layers):
            out_dimension = layer_dimension[i]
            # Look up the variables by the same names that were used during training
            w = tf.get_variable("weight_"+str(i), [in_dimension, out_dimension])
            b = tf.get_variable("bias_"+str(i), [out_dimension])
            z = tf.matmul(current_layer, w) + b
            # ReLU on every layer except the output layer
            current_layer = z if i == n_layers - 1 else tf.nn.relu(z)
            in_dimension = out_dimension
    return current_layer  # output
# Declare the variables first so that tf.train.Saver has something to restore into
with tf.variable_scope("layers", reuse=tf.AUTO_REUSE):
    weight1 = tf.get_variable("weight_1", [INPUT_NODE, 500])
    bias1 = tf.get_variable("bias_1", [500])
    weight2 = tf.get_variable("weight_2", [500, OUTPUT_NODE])
    bias2 = tf.get_variable("bias_2", [OUTPUT_NODE])
#saver = tf.train.import_meta_graph(r"E:\Python36\my tensorflow\ckpt files\mode_mnist.ckpt.meta")
saver = tf.train.Saver()
with tf.Session() as sess:
    # Restore the trained weights (adjust the paths below to your own machine)
    saver.restore(sess, r"E:\Python36\my tensorflow\ckpt files\mode_mnist.ckpt")
    #graph = tf.get_default_graph()
    #for v in tf.global_variables():
        #print(v.name)
        #print(v.eval())
    image_rawdata = tf.gfile.FastGFile(r"E:\Python36\MNIST picture\test54.jpg", "rb").read()
    img_data = tf.image.decode_jpeg(image_rawdata)
    # Convert uint8 pixels in 0~255 to float32 in 0~1, as the network expects
    if img_data.dtype != tf.float32:
        img_data = tf.image.convert_image_dtype(img_data, dtype=tf.float32)
    image_data_shaped2 = tf.reshape(img_data, (1, INPUT_NODE))  # TensorFlow tensor, one flattened image
    input_x = image_data_shaped2
    y = inference(input_x, avg_class=None, reuse=True)
    print("y: ", sess.run(y))
    print("predict number: ", sess.run(tf.argmax(y, 1)))  # output label

    image_data = img_data.eval()  # returns a numpy array
    image_data_shaped1 = image_data.reshape(image_data.shape[0], image_data.shape[1])  # numpy array, height*width
    #print(image_data_shaped1)
    plt.imshow(image_data_shaped1, cmap='gray')
    plt.show()

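Note that the structure of the neural network must stay the same as it was during training. We load the trained model and use it to predict the 54th image in the test set: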

The output is:

predict number:  [6]

Correct!

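To predict a photo you have taken yourself, remember to first convert it into a 28*28 single-channel image in the same style as the MNIST data. This is straightforward with OpenCV and is not covered here.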
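Once a photo has been reduced to 28*28 grayscale pixels by some image tool (that resizing step is not shown), the remaining preparation is simple arithmetic. The sketch below is a pure-Python illustration with a hypothetical helper, `to_mnist_input`: it scales pixel values from 0~255 to 0~1 and flattens the image into 784 inputs. The optional inversion is there because MNIST digits are bright strokes on a dark background, while a photo of ink on paper is usually the opposite.

```python
def to_mnist_input(gray_rows, invert=False):
    """Turn 28 rows of 28 integer pixels (0..255) into 784 floats in [0, 1]."""
    if len(gray_rows) != 28 or any(len(row) != 28 for row in gray_rows):
        raise ValueError("expected a 28x28 image")
    flat = []
    for row in gray_rows:
        for pixel in row:
            if invert:
                pixel = 255 - pixel  # dark ink on light paper -> bright on dark
            flat.append(pixel / 255.0)
    return flat


# Hypothetical input: a blank white page, inverted to an all-black image.
blank_page = [[255] * 28 for _ in range(28)]
vector = to_mnist_input(blank_page, invert=True)
print(len(vector), min(vector), max(vector))
```

The resulting list has the same layout as one row of the `x` placeholder, so it could be fed to the network as a batch of size one.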
