保存模型
saver=tf.train.Saver()
saver.save(sess,'ckpt/mnist.ckpt')
会在ckpt目录中生成4个文件:checkpoint, munist.ckpt.index, mnist.ckpt.meta, mnist.ckpt.data-00000-of-00001
参数
max_to_keep
saver=tf.train.Saver(max_to_keep=3)
saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
max_to_keep 的作用:
- 可以多次调用
saver.save
来存储sess,以保存最近的max_to_keep
个 sess - max_to_keep 默认为 5. 为None或0时,每次
saver.save
都会保存 - 所以写在迭代中,外加一些代码,可以保存迭代过程中,最好的
max_to_keep
个 sess
恢复模型
import tensorflow as tf
sess=tf.Session()
saver=tf.train.import_meta_graph('ckpt/mnist.ckpt.meta')
saver.restore(sess,'ckpt/mnist.ckpt')
# 取 tensor
x=tf.get_default_graph().get_tensor_by_name('x:0')
keep_prob=tf.get_default_graph().get_tensor_by_name('keep_prob:0')
y_hat=tf.get_default_graph().get_tensor_by_name('Softmax:0')
# 然后模型就可以运行了
Y_test_predict=sess.run(y_hat,feed_dict={x:X_test,keep_prob:1})