tensorflow的学习笔记--反向传播

付威     2019-03-31   1910   5min  

反向传播

训练模型参数,在所有的参数上用梯度下降,是NN模型在训练数据上的损失函数最小。

  1. 损失函数(loss): 预测值(y)与已知答案(y’)的差距

  2. 均方误差MSE

    使用tensorflow表示:
    loss=tf.reduce_mean(tf.square(y'-y))

  3. 反向传播训练方法,以减小loss值为优化目标。 有以下几种方法:

train_step=tf.train.GradientDescentOptimizer(learnig_rate).minimize(loss) 
train_step=tf.train.MomentumOptimizer(learnig_rate,momentum).minimize(loss) 
train_step=tf.train.AdamOptimizer(learnig_rate).minimize(loss) 

其中leaning_rate代表学习率,决定每次更新的幅度。

实现代码

实现一个训练模型:

#coding:utf-8
import tensorflow as tf
import numpy as np
BATCH_SIZE=8
seed=23455


#基于seed产生随机数
rng=np.random.RandomState(seed)
#随机数返回32行2列的矩阵,表示32组,体积和重量,作为输入的数据集
X=rng.rand(32,2)  

# 从X这个32x2的矩阵中,取出一行,判断如果和小于1 给Y复制1,如果不小于1 给Y赋值0
# 作为输入数据集的标签(正确答案)  

Y=[[int(x0+x1<1)] for(x0,x1) in X]

print("X:",X)
print("Y:",Y)

# 定义神经王珞丹额输入和输出,定义前向的传播过程

x=tf.placeholder(tf.float32,shape=(None,2))
y_=tf.placeholder(tf.float32,shape=(None,1))

w1=tf.Variable(tf.random_normal([2,3],stddev=1,seed=1))
w2=tf.Variable(tf.random_normal([3,1],stddev=1,seed=1))  

a=tf.matmul(x,w1)
y=tf.matmul(a,w2)

#定义损失函数 
loss=tf.reduce_mean(tf.square(y-y_))
train_step=tf.train.GradientDescentOptimizer(0.001).minimize(loss)
# train_step=tf.train.MomentumOptimizer(learnig_rate,momentum).minimize(loss) 
# train_step=tf.train.AdamOptimizer(learnig_rate).minimize(loss) 

# 生成会话,训练STEPS轮
with tf.Session() as sess: 
    init_op=tf.global_variables_initializer()
    sess.run(init_op)
    # 输出目前(未经训练)的参数值 
    print("w1:",sess.run(w1))
    print("w1:",sess.run(w2))
    print("\n")
    
    #训练模型
    STEPS=3000
    for i in range(STEPS):
        start=(i*BATCH_SIZE)%32
        end=start+BATCH_SIZE
        sess.run(train_step,feed_dict={
            x:X[start:end],y_:Y[start:end]
        })
        if i%500==0:
            total_loss=sess.run(loss,feed_dict={x:X,y_:Y})
            print("After %d training steps,loss on all data is %s"%(i,total_loss))
      
    print("w1:",sess.run(w1))
    print("w1:",sess.run(w2))
    print("\n")

(本文完)

作者:付威

博客地址:http://blog.laofu.online

如果觉得对您有帮助,可以下方的RSS订阅,谢谢合作

如有任何知识产权、版权问题或理论错误,还请指正。

本文是付威的网络博客原创,自由转载-非商用-非衍生-保持署名,请遵循:创意共享3.0许可证

交流请加群113249828: 点击加群   或发我邮件 laofu_online@163.com

付威

获得最新的博主文章,请关注上方公众号