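TensorFlow (1.x) runs models as static graphs, which makes viewing intermediate results awkward: an intermediate tensor has no value until it is explicitly fetched with `sess.run`. This post uses a PPO network from reinforcement learning as an example and walks through one approach, with explanations in the code comments.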
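As a toy analogy (plain Python, not TensorFlow), the sketch below mimics deferred execution. Graph nodes hold only an operation and its inputs, and a value appears only when a node is listed in the fetches passed to `run`, which is roughly the role `sess.run` plays. `Node`, `run`, `hidden` and `out` are hypothetical names made up for illustration.

```python
class Node:
    """A symbolic graph node: stores an operation and its inputs, not a value."""

    def __init__(self, name, fn=None, inputs=()):
        self.name = name
        self.fn = fn              # None means placeholder: it must be fed
        self.inputs = list(inputs)

    def __repr__(self):
        return f"<Node {self.name}>"


def run(fetches, feed_dict):
    """Compute the fetched nodes, evaluating each dependency at most once."""
    cache = dict(feed_dict)

    def evaluate(node):
        if node not in cache:
            if node.fn is None:
                raise KeyError(f"placeholder {node.name!r} was not fed")
            cache[node] = node.fn(*[evaluate(i) for i in node.inputs])
        return cache[node]

    return [evaluate(node) for node in fetches]


# Build the graph first; nothing is computed here.
x = Node("x")                                      # placeholder
hidden = Node("hidden", lambda v: 2 * v + 1, [x])  # intermediate node
out = Node("out", lambda h: max(h, 0), [hidden])   # final output

print(out)                          # prints <Node out>: a handle, not a number
print(run([out], {x: -2}))          # prints [0]
print(run([out, hidden], {x: -2}))  # prints [0, -3]: intermediate fetched too
```

The point of the analogy: to see `hidden`, you do not change the graph, you simply add it to the fetch list. The real TensorFlow code below does the same thing with `get_tensor_by_name` and `sess.run`.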
First, save the model during training:
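```python
import os
import tensorflow as tf  # TensorFlow 1.x API

# sess, path, filename and checkpoint_path come from the training code.
# Dump the graph structure as readable text (e.g. filename='model.pbtxt').
tf.train.write_graph(sess.graph_def, path, filename, as_text=True)

# Save all global variables, keeping only the 5 most recent checkpoints.
saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)

# Resume from the latest checkpoint if one exists.
ckpt = tf.train.get_checkpoint_state(os.path.dirname(checkpoint_path))
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)

# ... training steps ...
saver.save(sess, checkpoint_path)
```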
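To see what `ckpt.model_checkpoint_path` and `max_to_keep=5` correspond to on disk, here is a small sketch that reads the text of the `checkpoint` state file the Saver keeps next to the checkpoints. It assumes `saver.save` is called with a `global_step` so that each save gets a distinct prefix (without it, every save overwrites the same prefix). `parse_checkpoint_state` and `files_for_prefix` are hypothetical helpers, and the file content is a made-up example; in real code `tf.train.get_checkpoint_state` or `tf.train.latest_checkpoint` does this for you.

```python
import re


def parse_checkpoint_state(text):
    """Return (latest_prefix, retained_prefixes) from a TF1 `checkpoint` file."""
    latest = re.search(r'^model_checkpoint_path:\s*"([^"]*)"', text, re.M)
    retained = re.findall(r'^all_model_checkpoint_paths:\s*"([^"]*)"', text, re.M)
    return (latest.group(1) if latest else None), retained


def files_for_prefix(prefix):
    """Files a single-shard V2 checkpoint prefix normally maps to (assumption)."""
    return [prefix + ".index", prefix + ".meta", prefix + ".data-00000-of-00001"]


# Hypothetical file content after several saves with max_to_keep=5
sample = """model_checkpoint_path: "model.ckpt-500"
all_model_checkpoint_paths: "model.ckpt-100"
all_model_checkpoint_paths: "model.ckpt-200"
all_model_checkpoint_paths: "model.ckpt-300"
all_model_checkpoint_paths: "model.ckpt-400"
all_model_checkpoint_paths: "model.ckpt-500"
"""

latest, retained = parse_checkpoint_state(sample)
print(latest)          # prints model.ckpt-500
print(len(retained))   # prints 5
print(files_for_prefix(latest))
```

Older prefixes drop out of `all_model_checkpoint_paths` (and their files are deleted) once more than `max_to_keep` checkpoints have been written.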
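Each saved checkpoint should consist of three files: `*.ckpt.index`, `*.ckpt.meta`, and `*.ckpt.data-*` (plus a small `checkpoint` file in the same directory that records which checkpoints exist). We also write a `*.pbtxt` file because inspecting intermediate layers requires their tensor names, and the pbtxt is a human-readable text dump of the model structure in which those names can be looked up directly. Inference and printing the results are then done as follows: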
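Looking names up by eye works, but the `.pbtxt` can also be searched programmatically. A minimal sketch, assuming the file was written with `as_text=True` so that each node appears as a `node { name: ... op: ... }` block with `name` printed immediately before `op`. `list_nodes` is a hypothetical helper, and the sample is a hand-written fragment, not the real PPO graph (in newer TF versions the op behind `+` may show up as `AddV2` rather than `Add`).

```python
import re

# In a text GraphDef each node prints `name` immediately followed by `op`.
NODE_RE = re.compile(r'name:\s*"([^"]+)"\s*\n\s*op:\s*"([^"]+)"')


def list_nodes(pbtxt_text, prefix=""):
    """Return (tensor_name, op) pairs for nodes whose names start with prefix.

    Appending ":0" selects the node's first output, the form that
    graph.get_tensor_by_name expects.
    """
    return [(name + ":0", op)
            for name, op in NODE_RE.findall(pbtxt_text)
            if name.startswith(prefix)]


sample = '''
node {
  name: "ppo/observation"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "ppo/model/vf/MatMul"
  op: "MatMul"
}
node {
  name: "ppo/model/vf/add"
  op: "Add"
}
'''

for tensor_name, op in list_nodes(sample, prefix="ppo/model/vf"):
    print(tensor_name, op)
```

For a real file, read it first, e.g. `with open(pbtxt_path) as f: list_nodes(f.read(), "ppo/")`, where `pbtxt_path` is whatever path was passed to `write_graph`.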
```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

with tf.Session() as sess:
    modelpath = r'../ppo/2/'
    # Rebuild the graph from the .meta file, then load the trained weights.
    saver = tf.train.import_meta_graph(modelpath + 'model.ckpt.meta')
    saver.restore(sess, tf.train.latest_checkpoint(modelpath))
    graph = tf.get_default_graph()
    print('Successfully loaded the pre-trained model!')

    # Load one observation and reshape it to (batch=1, 197, 1).
    observation_data = np.load('../ppo/2/observation.npy')
    observation_data = observation_data.reshape((1, 197, 1))
    print(observation_data.shape)

    # Input placeholder; tensor names are taken from the .pbtxt file.
    in_observation = graph.get_tensor_by_name("ppo/observation:0")
    print(in_observation.shape)

    # Final outputs of the network.
    out_neglogps = graph.get_tensor_by_name("ppo/neglogps:0")
    out_actions = graph.get_tensor_by_name("ppo/actions:0")
    out_values = graph.get_tensor_by_name("ppo/values:0")
    out_fetches = [out_neglogps, out_actions, out_values]

    # Intermediate tensor to inspect (output of the value-function layer).
    mlp_fc0 = graph.get_tensor_by_name("ppo/model/vf/add:0")
    mid_fetches = [mlp_fc0]

    # Fetch final and intermediate tensors in a single run.
    fetches = out_fetches + mid_fetches
    output = sess.run(fetches, feed_dict={in_observation: observation_data})
    for out in output:
        print("out: ", out)
```
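Printing whole arrays gets unwieldy once intermediate layers are large. One option is to pair each fetched value with its tensor name and print only its shape and basic statistics. `summarize` is a hypothetical helper; in the script above it could be called as `summarize([t.name for t in fetches], output)`. The arrays in the demo are made-up stand-ins, not real PPO outputs.

```python
import numpy as np


def summarize(names, values):
    """Return (name, shape, min, max, mean) for each fetched value.

    Assumes every value is a non-empty array (min/max fail on empty arrays).
    """
    rows = []
    for name, value in zip(names, values):
        arr = np.asarray(value, dtype=np.float64)  # also accepts integer actions
        rows.append((name, arr.shape,
                     float(arr.min()), float(arr.max()), float(arr.mean())))
    return rows


# Made-up stand-ins for two entries of `output`
names = ["ppo/neglogps:0", "ppo/model/vf/add:0"]
values = [np.array([0.5, 1.5]), np.array([[1.0, -1.0], [3.0, 5.0]])]

for row in summarize(names, values):
    print(row)
```

Using the tensor name as the label keeps each printed line traceable back to the node in the `.pbtxt`, which matters once several intermediate layers are fetched together.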