tensorflow的学习笔记--滑动平均

付威     2019-04-04   2688   7min  

滑动平均

滑动平均,又叫影子值,记录了每个参数一段时间内过往值的平均,增加了模型的泛化性。

针对所有参数:w和b。(像是给参数加了影子,参数变化,影子缓慢追随),具体的计算公式如下:

影子=衰减率*影子+(1-衰减率)* 参数

影子初值=参数初值

衰减率=

例如:

MOVING_AVERAGE_DECAY为0.99,参数w1为0,轮数global_step为0,w1的滑动平均值为0,参数w1更新为1.则:

w1滑动平均值=min(0.99,1/10)*0+(1-min(0.99,1/10))*1=0.9
轮数global_step为100是,参数w1更新为10,则:

w1滑动平均值=min(0.99,101/110)*0.9+(1-min(0.99,101/110))*10=0.826+0.818=1.644

再次运行:

w1滑动平均值=min(0.99,101/110)*1.644+(1-min(0.99,101/110))*10=2.328

再次运行:

w1平均值=2.956

使用tensorflow表示如下:

ema=tf.train.ExponentialMovingAverage(衰减率MOVING_AVERAGE_DECAY,当前轮数global_step)   

ema_op=ema.apply([])
ema_op=ema.apply(tf.trainable_variables()) # 每运行此据,所有待优化的参数求滑动平均    

with tf.control_dependencies([train_step,ema_op]):
    train_op=tf.no_op(name='train') 

ema.average(查看参数的滑动平均)

我们使用代码来使用模拟上面的计算逻辑:

#coding:utf-8
# 设损失函数 loss=(w+1)^2 令 w初值是常数5,反向传播就是求最优w,即求最小的loss对应的w值  
import tensorflow as tf 


# 1. 定义变量及滑动平均类 
# 定义一个32位浮点变量,初始值为0.0  这个代码就是不断更新w1参数,优化w1参数,滑动平均做了个w1的影子   

w1=tf.Variable(0,dtype=tf.float32)  

# 定义num_updates(NN的迭代轮数),初始值为0,不可被优化。 
global_step=tf.Variable(0,trainable=False) 

# 实例化滑动平均类,给删减率为0.99,当前轮数global_step
MOVING_VERAGE_DECAY=0.99 
ema=tf.train.ExponentialMovingAverage(MOVING_VERAGE_DECAY,global_step)

#ema.applu后的括号里是更新列表,每次运行sess.run(ema_op)时,对更新列表中的元素求滑动平均值  

#在实际应用中会使用tf.trainable_variables()自动将所有的待训练的参数汇总为列表  

# ema_op=ema.apply([w1])  

ema_op=ema.apply(tf.trainable_variables()) 

with tf.Session() as sess:   
    # 初始化  
    init_op =tf.global_variables_initializer()
    sess.run(init_op)
    # 用ema.average(w1)获取w1滑动平均值,(要运行多个节点,作为列表中的元素列出,写在sess)
    #打印出当前的参数w1和w2
    print(sess.run([w1,ema.average(w1)]))

    # 参数w1的值赋值为1   
    sess.run(tf.assign(w1,1))
    sess.run(ema_op)
    print(sess.run([w1,ema.average(w1)]))

    # 更新的step和w1 的值,模拟出100轮迭代后,参数w1变为10
    sess.run(tf.assign(global_step,100))
    sess.run(tf.assign(w1,10))
    sess.run(ema_op)
    print(sess.run([w1,ema.average(w1)]))
    
    for x in range(40):
       # 每次sess.run会更新一次w1的滑动平均值
        sess.run(ema_op)
        print(sess.run([w1,ema.average(w1)]))
   

打印结果:

[0.0, 0.0]
[1.0, 0.9]
[10.0, 1.6445453]
[10.0, 2.3281732]
[10.0, 2.955868]
[10.0, 3.532206]
[10.0, 4.061389]
[10.0, 4.547275]
[10.0, 4.9934072]
[10.0, 5.4030375]
[10.0, 5.7791524]
[10.0, 6.1244946]
[10.0, 6.4415812]
[10.0, 6.7327247]
[10.0, 7.000047]
[10.0, 7.2454977]
[10.0, 7.470866]
[10.0, 7.6777954]
[10.0, 7.867794]
[10.0, 8.042247]
[10.0, 8.202427]
[10.0, 8.349501]
[10.0, 8.484542]
[10.0, 8.608534]
[10.0, 8.722381]
[10.0, 8.826913]
[10.0, 8.922893]
[10.0, 9.01102]
[10.0, 9.091936]
[10.0, 9.166232]
[10.0, 9.234449]
[10.0, 9.297086]
[10.0, 9.354597]
[10.0, 9.407403]
[10.0, 9.455888]
[10.0, 9.500406]
[10.0, 9.541282]
[10.0, 9.578814]
[10.0, 9.613275]
[10.0, 9.644916]
[10.0, 9.673968]
[10.0, 9.700644]
[10.0, 9.725137]

可以看到平均值一直趋近于w1。

(本文完)

作者:付威

博客地址:http://blog.laofu.online

如果觉得对您有帮助,可以下方的RSS订阅,谢谢合作

如有任何知识产权、版权问题或理论错误,还请指正。

本文是付威的网络博客原创,自由转载-非商用-非衍生-保持署名,请遵循:创意共享3.0许可证

交流请加群113249828: 点击加群   或发我邮件 laofu_online@163.com

付威

获得最新的博主文章,请关注上方公众号