[问题] TensorFlow网络参数储存问题

楼主: Paudse (SICO)   2018-04-23 19:44:55
我想将NN的参数储存下次继续学习
但发现储存时似乎发生问题
叫出的参数每次都一样
我的程式结构如下
还请强者指教
谢谢
class DQ:
def __init__():
self.sess = tf.Session()
saver = tf.train.Saver()
self.sess.run(tf.global_variables_initializer())
with tf.Session() as sess:
if os.path.isfile("save_net.ckpt.index"):
saver.restore(sess, "save_net.ckpt")
print('File exists, loading previous data!')
else:
# save_path = self.saver.save(self.sess, "save_net.ckpt")
print('File does not exist, starting fresh')
def _build_net(self):
省略
def learn(self,save_step):
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
save_path = saver.save(sess, "save_net.ckpt")
print('save parameters')
作者: jameszhan (123)   2018-04-24 08:25:00
用with的话 你要跟训练在同一个with里吧不然你存参数时候的session应该没东西最后的with那行等于一个新的session 你初始化参数后就直接存 中间应该要有训练过程还是你只是想问为何初始化后存的参数会一样?
楼主: Paudse (SICO)   2018-04-24 10:38:00
恩恩 对阿 初始化参数都会一样 是为什么呢 谢谢甚至我把之前处存的ckpt档都删了 跑出来的参数还是一样我朋友后来说他会存很多个ckpt 可以设定几个epoch存一次要restore最后一个ckpt才是最接近训练最后的结果
作者: jameszhan (123)   2018-04-24 13:13:00
当然啊 你可以看一下tensorflow的文件saver(max=n) 可以设定要保留几个档案
楼主: Paudse (SICO)   2018-04-24 13:31:00
我现在用model_file=tf.train.latest_checkpoint('ckpt/')saver.restore(sess,model_file)但还是都从最一开始的开始训练 不知道是怎么回事另外也已经改成saver=tf.train.Saver(max_to_keep=1)
作者: goldflower (金色小黄花)   2018-04-24 13:34:00
=1代表只存一个吧https://stackoverflow.com/questions/48324072/照这个做应该就好了
作者: jameszhan (123)   2018-04-24 13:48:00
参数初始化的部分可以看这个truncated_normal_initiali从最一开始的训练或许是你本来就只有一开始才有存?直接去github看别人完整的code比较快 看人家怎么用的
作者: chchan1111 (123)   2018-04-24 13:54:00
对了 你这个code是不是怪怪的 你一开始就有实体化session了 为何后面还要with tf.se....
楼主: Paudse (SICO)   2018-04-24 13:54:00
感谢各位的建议 我后来发现 我原本把放在saver.restore
作者: chchan1111 (123)   2018-04-24 13:55:00
直接self.sess.run就可以了 不然你等于又实体化一个session
楼主: Paudse (SICO)   2018-04-24 13:55:00
一个if判断句里面检查有没有之前存的ckpt 但一值失败我后来把saver.restore拿出那个if结构外就可以了虽然不太懂为何会有这个问题 不过现在OK了 感谢大家!!大大们说的没错 我后来把with拿掉了

Links booklink

Contact Us: admin [ a t ] ucptt.com