Tensorflow的学习笔记--滑动平均

滑动平均

滑动平均,又叫影子值,记录了每个参数一段时间内过往值的平均,增加了模型的泛化性。

针对所有参数:w和b。(像是给参数加了影子,参数变化,影子缓慢追随),具体的计算公式如下:

影子=衰减率*影子+(1-衰减率)* 参数

影子初值=参数初值

衰减率=

例如:

MOVING_AVERAGE_DECAY为0.99,参数w1为0,轮数global_step为0,w1的滑动平均值为0,参数w1更新为1.则:

w1滑动平均值=min(0.99,1/10)*0+(1-min(0.99,1/10))*1=0.9
轮数global_step为100是,参数w1更新为10,则:

w1滑动平均值=min(0.99,101/110)*0.9+(1-min(0.99,101/110))*10=0.826+0.818=1.644

再次运行:

w1滑动平均值=min(0.99,101/110)*1.644+(1-min(0.99,101/110))*10=2.328

再次运行:

w1平均值=2.956

使用tensorflow表示如下:

1
2
3
4
5
6
7
8
9
10
ema=tf.train.ExponentialMovingAverage(衰减率MOVING_AVERAGE_DECAY,当前轮数global_step)   

ema_op=ema.apply([])
ema_op=ema.apply(tf.trainable_variables()) # 每运行此据,所有待优化的参数求滑动平均

with tf.control_dependencies([train_step,ema_op]):
train_op=tf.no_op(name='train')

ema.average(查看参数的滑动平均)

我们使用代码来使用模拟上面的计算逻辑:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#coding:utf-8
# 设损失函数 loss=(w+1)^2 令 w初值是常数5,反向传播就是求最优w,即求最小的loss对应的w值
import tensorflow as tf


# 1. 定义变量及滑动平均类
# 定义一个32位浮点变量,初始值为0.0 这个代码就是不断更新w1参数,优化w1参数,滑动平均做了个w1的影子

w1=tf.Variable(0,dtype=tf.float32)

# 定义num_updates(NN的迭代轮数),初始值为0,不可被优化。
global_step=tf.Variable(0,trainable=False)

# 实例化滑动平均类,给删减率为0.99,当前轮数global_step
MOVING_VERAGE_DECAY=0.99
ema=tf.train.ExponentialMovingAverage(MOVING_VERAGE_DECAY,global_step)

#ema.applu后的括号里是更新列表,每次运行sess.run(ema_op)时,对更新列表中的元素求滑动平均值

#在实际应用中会使用tf.trainable_variables()自动将所有的待训练的参数汇总为列表

# ema_op=ema.apply([w1])

ema_op=ema.apply(tf.trainable_variables())

with tf.Session() as sess:
# 初始化
init_op =tf.global_variables_initializer()
sess.run(init_op)
# 用ema.average(w1)获取w1滑动平均值,(要运行多个节点,作为列表中的元素列出,写在sess)
#打印出当前的参数w1和w2
print(sess.run([w1,ema.average(w1)]))

# 参数w1的值赋值为1
sess.run(tf.assign(w1,1))
sess.run(ema_op)
print(sess.run([w1,ema.average(w1)]))

# 更新的step和w1 的值,模拟出100轮迭代后,参数w1变为10
sess.run(tf.assign(global_step,100))
sess.run(tf.assign(w1,10))
sess.run(ema_op)
print(sess.run([w1,ema.average(w1)]))

for x in range(40):
# 每次sess.run会更新一次w1的滑动平均值
sess.run(ema_op)
print(sess.run([w1,ema.average(w1)]))


打印结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
[0.0, 0.0]
[1.0, 0.9]
[10.0, 1.6445453]
[10.0, 2.3281732]
[10.0, 2.955868]
[10.0, 3.532206]
[10.0, 4.061389]
[10.0, 4.547275]
[10.0, 4.9934072]
[10.0, 5.4030375]
[10.0, 5.7791524]
[10.0, 6.1244946]
[10.0, 6.4415812]
[10.0, 6.7327247]
[10.0, 7.000047]
[10.0, 7.2454977]
[10.0, 7.470866]
[10.0, 7.6777954]
[10.0, 7.867794]
[10.0, 8.042247]
[10.0, 8.202427]
[10.0, 8.349501]
[10.0, 8.484542]
[10.0, 8.608534]
[10.0, 8.722381]
[10.0, 8.826913]
[10.0, 8.922893]
[10.0, 9.01102]
[10.0, 9.091936]
[10.0, 9.166232]
[10.0, 9.234449]
[10.0, 9.297086]
[10.0, 9.354597]
[10.0, 9.407403]
[10.0, 9.455888]
[10.0, 9.500406]
[10.0, 9.541282]
[10.0, 9.578814]
[10.0, 9.613275]
[10.0, 9.644916]
[10.0, 9.673968]
[10.0, 9.700644]
[10.0, 9.725137]

可以看到平均值一直趋近于w1。