网创优客建站品牌官网
为成都网站建设公司企业提供高品质网站建设
热线:028-86922220
成都专业网站建设公司

定制建站费用3500元

符合中小企业对网站设计、功能常规化式的企业展示型网站建设

成都品牌网站建设

品牌网站建设费用6000元

本套餐主要针对企业品牌型网站、中高端设计、前端互动体验...

成都商城网站建设

商城网站建设费用8000元

商城网站建设因基本功能的需求不同费用上面也有很大的差别...

成都微信网站建设

手机微信网站建站3000元

手机微信网站开发、微信官网、微信商城网站...

建站知识

当前位置:首页 > 建站知识

Tensorflow之Saver的用法详解-创新互联

Saver的用法

成都创新互联-专业网站定制、快速模板网站建设、高性价比常宁网站开发、企业建站全套包干低至880元,成熟完善的模板库,直接使用。一站式常宁网站制作公司更省心,省钱,快速模板网站建设找我们,业务覆盖常宁地区。费用合理售后完善,十年实体公司更值得信赖。

1. Saver的背景介绍

我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试。Tensorflow针对这一需求提供了Saver类。


Saver类提供了向checkpoints文件保存和从checkpoints文件中恢复变量的相关方法。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值 。


只要提供一个计数器,当计数器触发时,Saver类可以自动的生成checkpoint文件。这让我们可以在训练过程中保存多个中间结果。例如,我们可以保存每一步训练的结果。


为了避免填满整个磁盘,Saver可以自动的管理Checkpoints文件。例如,我们可以指定保存最近的N个Checkpoints文件。


2. Saver的实例

下面以一个例子来讲述如何使用Saver类 


import tensorflow as tf 
import numpy as np  
x = tf.placeholder(tf.float32, shape=[None, 1]) 
y = 4 * x + 4  
w = tf.Variable(tf.random_normal([1], -1, 1)) 
b = tf.Variable(tf.zeros([1])) 
y_predict = w * x + b 
loss = tf.reduce_mean(tf.square(y - y_predict)) 
optimizer = tf.train.GradientDescentOptimizer(0.5) 
train = optimizer.minimize(loss)  
isTrain = False 
train_steps = 100 
checkpoint_steps = 50 
checkpoint_dir = ''  
saver = tf.train.Saver() # defaults to saving all variables - in this case w and b 
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))  
with tf.Session() as sess: 
  sess.run(tf.initialize_all_variables()) 
  if isTrain: 
    for i in xrange(train_steps): 
      sess.run(train, feed_dict={x: x_data}) 
      if (i + 1) % checkpoint_steps == 0: 
        saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1) 
  else: 
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 
    if ckpt and ckpt.model_checkpoint_path: 
      saver.restore(sess, ckpt.model_checkpoint_path) 
    else: 
      pass 
    print(sess.run(w)) 
    print(sess.run(b)) 

当前文章:Tensorflow之Saver的用法详解-创新互联
分享路径:http://bjjierui.cn/article/psspo.html

其他资讯