使用mnist的数据集实现对手写数字识别
感慨下,学了这么久终于有能有点实战的东西了,这篇文章本想写于2019-04-14
,可是担心会对学习的进度产生影响,就一直拖后。所以就再今天(2019-04-24)开始去写这篇实战的文章。
目标
写这篇博客的目的就是为了写一个识别手写程序的方案,首先我们准备了mnist
(mnist数据集)的数据集,和一张手写的图片
{:height=”600px” width=”600px”}
为了测试,增加了一些干扰线:
{:height=”600px” width=”600px”}
我们知道,在mnist的数据集中,是对单个28*28像素图片进行处理,而且是黑底白字的数字图片,所以为了能够使用mnist的数据样本训练,我们也需要对手写的图片处理。
具体的处理方法我们可以分为:
1.二值化
2.去噪声
3.裁剪
4.缩放28*28的图像
5.训练样本
6.识别结果
二值化和去噪声
对于这个图片的二值化可以使用opencv相关处理模块,二值化是直接把图片编程只有黑白两种颜色的图:
{:height=”600px” width=”600px”}
我们可以看到上面有几个 使用均值模糊可以去除干扰线和噪点:
{:height=”600px” width=”600px”}
对模糊后的图片再进行二值化,得到结果如下:
{:height=”600px” width=”600px”}
对应的代码如下:
1 | import cv2 as cv |
反向传播代码:minst_backward.py
# coding:utf-8
import tensorflow as tf
import minst_forward
import os
from tensorflow.examples.tutorials.mnist import input_data
BATCH_SIZE = 200
LEARNING_RATE_BASE = 0.1
LEARNING_RATE_DECAY = 0.99
REGULARIZER = 0.0001
STEPS = 100000
MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH = "./model/"
MODEL_NAME = "minst_model"
def backword(minst):
x = tf.placeholder(tf.float32, (None, minst_forward.INPUT_NODE))
y_ = tf.placeholder(tf.float32, (None, minst_forward.OUTPUT_NODE))
y = minst_forward.forward(x, REGULARIZER)
global_step = tf.Variable(0, trainable=False)
# 使用交叉熵的形式定义损失函数
ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
cem = tf.reduce_mean(ce)
#正则化防止过拟合
loss = cem + tf.add_n(tf.get_collection("losses"))
# 使用指数衰减的学习率,实现更好的学习率
learning_rate=tf.train.exponential_decay(LEARNING_RATE_BASE,
global_step,
minst.train.num_examples/BATCH_SIZE,
LEARNING_RATE_DECAY,
staircase=True)
# 使用反向传播训练方法,以减小loss值为优化目标
train_step=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,global_step=global_step)
# 对所有的参数都使用滑动平均,更准确的定义模型。
ema=tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY,global_step)
emp_op=ema.apply(tf.trainable_variables())
with tf.control_dependencies([train_step,emp_op]):
train_op=tf.no_op(name="train")
saver=tf.train.Saver()
with tf.Session() as sess:
init_op=tf.global_variables_initializer()
sess.run(init_op)
ckpt=tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess,ckpt.model_checkpoint_path)
for i in range(STEPS):
xs,ys =minst.train.next_batch(BATCH_SIZE)
_,loss_value,step=sess.run([train_op,loss,global_step],feed_dict={x:xs,y_:ys})#喂入神经网络数据
if i%1000==0:
print("After %d training step(s),loss on training batch is %g."%(step,loss_value))
saver.save(sess,os.path.join(MODEL_SAVE_PATH,MODEL_NAME),global_step=global_step)
def main():
minst=input_data.read_data_sets("./data",one_hot=True)# 读取mnist的数据
backword(minst)
if __name__=="__main__":
main()
运行backword.py的程序,使用mnist的数据,进行10w的训练,得到数据模型如下:
利用模型获得最大可能的预测值,代码如下:
import imageUtils
import tensorflow as tf
import minst_forward
import minst_backward
import numpy as np
import cv2 as cv
def restore_model(imgArr):
with tf.Graph().as_default() as tg:
x = tf.placeholder(tf.float32, [None, minst_forward.INPUT_NODE])
y = minst_forward.forward(x, None)
preValue = tf.argmax(y, 1) # 得到概率最大的预测值
# 实现滑动平均模型,参数的MOVING_AVERAGE_DECAY用于控制模型的速度
variable_averages = tf.train.ExponentialMovingAverage(minst_backward.MOVING_AVERAGE_DECAY)
variable_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variable_to_restore)
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(minst_backward.MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
preValue = sess.run(preValue, feed_dict={x: imgArr})
return preValue
else:
print("No checkpoint file found")
return -1
if __name__ == "__main__":
imgArr = imageUtils.getReadyImage("./pic/01.png")
result = []
i = 0;
for img in imgArr:
try:
im_arr = np.array(img)
nm_arr = im_arr.reshape([1, 784])
nm_arr = nm_arr.astype(np.float32)
img_ready = np.multiply(nm_arr, 1.0 / 255.0)
# testPicArr=pre_pic(testPic)
preValue = restore_model(img_ready)
result.append(str(preValue[0]))
except Exception as ee:
print(ee)
print("".join(result))
识别结果如下:
我们再输入一个正常的手写图片:
输出结果:
识别大部分数据正确,还有有识别错误的现象,因为训练的模型的差异和训练次数过少,所以会导致识别出错的现象。
使用mnist的数据集实现对手写数字识别
http://blog.laofu.online/2019-04-24-tensorflow-mnist-train02/