tensorflow的学习笔记--反向传播

反向传播

训练模型参数,在所有的参数上用梯度下降,是NN模型在训练数据上的损失函数最小。

  1. 损失函数(loss): 预测值(y)与已知答案(y’)的差距

  2. 均方误差MSE

    使用tensorflow表示:
    loss=tf.reduce_mean(tf.square(y'-y))

  3. 反向传播训练方法,以减小loss值为优化目标。 有以下几种方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
train_step=tf.train.GradientDescentOptimizer(learnig_rate).minimize(loss) 
train_step=tf.train.MomentumOptimizer(learnig_rate,momentum).minimize(loss)
train_step=tf.train.AdamOptimizer(learnig_rate).minimize(loss)

```
其中`leaning_rate`代表学习率,决定每次更新的幅度。

### 实现代码

实现一个训练模型:

``` python
#coding:utf-8
import tensorflow as tf
import numpy as np
BATCH_SIZE=8
seed=23455


#基于seed产生随机数
rng=np.random.RandomState(seed)
#随机数返回32行2列的矩阵,表示32组,体积和重量,作为输入的数据集
X=rng.rand(32,2)

# 从X这个32x2的矩阵中,取出一行,判断如果和小于1 给Y复制1,如果不小于1 给Y赋值0
# 作为输入数据集的标签(正确答案)

Y=[[int(x0+x1<1)] for(x0,x1) in X]

print("X:",X)
print("Y:",Y)

# 定义神经王珞丹额输入和输出,定义前向的传播过程

x=tf.placeholder(tf.float32,shape=(None,2))
y_=tf.placeholder(tf.float32,shape=(None,1))

w1=tf.Variable(tf.random_normal([2,3],stddev=1,seed=1))
w2=tf.Variable(tf.random_normal([3,1],stddev=1,seed=1))

a=tf.matmul(x,w1)
y=tf.matmul(a,w2)

#定义损失函数
loss=tf.reduce_mean(tf.square(y-y_))
train_step=tf.train.GradientDescentOptimizer(0.001).minimize(loss)
# train_step=tf.train.MomentumOptimizer(learnig_rate,momentum).minimize(loss)
# train_step=tf.train.AdamOptimizer(learnig_rate).minimize(loss)

# 生成会话,训练STEPS轮
with tf.Session() as sess:
init_op=tf.global_variables_initializer()
sess.run(init_op)
# 输出目前(未经训练)的参数值
print("w1:",sess.run(w1))
print("w1:",sess.run(w2))
print("\n")

#训练模型
STEPS=3000
for i in range(STEPS):
start=(i*BATCH_SIZE)%32
end=start+BATCH_SIZE
sess.run(train_step,feed_dict={
x:X[start:end],y_:Y[start:end]
})
if i%500==0:
total_loss=sess.run(loss,feed_dict={x:X,y_:Y})
print("After %d training steps,loss on all data is %s"%(i,total_loss))

print("w1:",sess.run(w1))
print("w1:",sess.run(w2))
print("\n")

作者

付威

发布于

2019-03-31

更新于

2020-08-10

许可协议

评论