【TensorFlow6】持久化



2018年10月29日    Author:Guofei

文章归类: TensorFlow    文章编号: 286

版权声明:本文作者是郭飞。转载随意,但需要标明原文链接,并通知本人
原文链接:https://www.guofei.site/2018/10/29/tf6.html


保存模型

saver=tf.train.Saver()
saver.save(sess,'ckpt/mnist.ckpt')

会在ckpt目录中生成4个文件:checkpoint, munist.ckpt.index, mnist.ckpt.meta, mnist.ckpt.data-00000-of-00001

参数

max_to_keep

saver=tf.train.Saver(max_to_keep=3)
saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)

max_to_keep 的作用:

  1. 可以多次调用saver.save来存储sess,以保存最近的 max_to_keep 个 sess
  2. max_to_keep 默认为 5. 为None或0时,每次 saver.save 都会保存
  3. 所以写在迭代中,外加一些代码,可以保存迭代过程中,最好的 max_to_keep 个 sess

恢复模型

import tensorflow as tf
sess=tf.Session()
saver=tf.train.import_meta_graph('ckpt/mnist.ckpt.meta')
saver.restore(sess,'ckpt/mnist.ckpt')

# 取 tensor
x=tf.get_default_graph().get_tensor_by_name('x:0')
keep_prob=tf.get_default_graph().get_tensor_by_name('keep_prob:0')
y_hat=tf.get_default_graph().get_tensor_by_name('Softmax:0')

# 然后模型就可以运行了
Y_test_predict=sess.run(y_hat,feed_dict={x:X_test,keep_prob:1})

您的支持将鼓励我继续创作!