Shihanmax's blog
Loading Multiple Models at Once with TensorFlow
With a single model, one way to save and load it is as follows:
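import tensorflow as tf  # TensorFlow 1.x graph-mode API

# Input/output definitions
x = tf.placeholder(tf.float32, shape=[None, 1], name="x")
y = tf.placeholder(tf.float32, shape=[None, 1], name="y")
# Weight definition
weight = tf.Variable(tf.zeros([1, 1]), name="weight")
# Op definitions (a single matmul and a squared-error loss stand in for the real model)
output = tf.matmul(x, weight, name="output")
loss = tf.reduce_mean(tf.square(output - y))
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss, name="train_op")
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    # ... run the training steps here ...
    saver.save(sess, "./model/model_path")  # save the model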
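It helps to know what `saver.save` actually writes, since the restore code depends on it. The sketch below is an illustration rather than part of the original experiment: it assumes the TF1-style API is reached through `tensorflow.compat.v1`, and the helper `save_toy_model`, the one-weight model, and the temporary directory are all made up for the example.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF1-style API via the compat module

tf.disable_eager_execution()


def save_toy_model(prefix):
    """Save a hypothetical one-weight model under the checkpoint prefix `prefix`."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.float32, shape=[None, 1], name="x")
        weight = tf.Variable([[2.0]], dtype=tf.float32, name="weight")
        tf.matmul(x, weight, name="output")
        saver = tf.train.Saver()
        with tf.Session(graph=graph) as sess:
            sess.run(tf.global_variables_initializer())
            saver.save(sess, prefix)


model_dir = tempfile.mkdtemp()
prefix = os.path.join(model_dir, "model_path")
save_toy_model(prefix)

# The prefix is not a single file: the Saver writes several files that share it.
saved_files = sorted(os.listdir(model_dir))
# Variable names stored in the checkpoint, as (name, shape) pairs.
saved_variables = tf.train.list_variables(prefix)
print(saved_files)
print(saved_variables)
```

The `.meta` file holds the graph structure, while the `.index` and `.data` files hold the variable values. That is why `tf.train.import_meta_graph` takes the prefix plus `.meta`, whereas `saver.restore` takes the bare prefix.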
To restore the model defined above:
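# The graph (x, weight, output, ...) must already be defined as in the training code
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # optional: restore overwrites these values
    saver.restore(sess, "./model/model_path")  # restore the model into sess
    # feed_dict maps the placeholder x to the input batch
    result = sess.run([output], feed_dict=feed_dict)  # predict with the restored model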
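Restoring this way works fine for a single model. But if we have trained several models with the same structure, we might expect to restore them like this: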
all_sessions = []
for i in range(model_nums):
    saver = tf.train.Saver()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, "./model/model_path")  # restore the model into sess
    all_sessions.append(sess)
Then predict with the restored sessions:
all_result = []
for sess in all_sessions:
    all_result.append(sess.run([output], feed_dict=feed_dict))
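However, this leads to wrong parameters and abnormal predictions. The models' variables conflict with one another because they are all loaded into the default graph of the same thread. The fix is to use a different default graph for each model: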
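class ImportGraph:
    """Load one saved model into its own graph and session."""

    def __init__(self, loc):
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph)
        with self.graph.as_default():
            # Rebuild the graph structure from the .meta file, then load the weights
            saver = tf.train.import_meta_graph(loc + ".meta")
            saver.restore(self.sess, loc)
        # Assumes the input placeholder and output op were saved under the names "x" and "output"
        self.x = self.graph.get_tensor_by_name("x:0")
        self.output = self.graph.get_tensor_by_name("output:0")

    def predict(self, data):
        return self.sess.run([self.output], feed_dict={self.x: data})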
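To see the per-graph approach end to end, here is a minimal sketch that saves two tiny models and loads each into its own graph. It assumes the TF1-style API via `tensorflow.compat.v1`. The helper `build_and_save`, the function `demo`, the one-weight model, and the tensor names `x` and `output` are hypothetical choices for the illustration, not taken from the original experiment.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF1-style API via the compat module

tf.disable_eager_execution()


def build_and_save(path, value):
    """Build a hypothetical one-weight model in a fresh graph and save it to `path`."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.float32, shape=[None, 1], name="x")
        weight = tf.Variable([[value]], dtype=tf.float32, name="weight")
        tf.matmul(x, weight, name="output")
        saver = tf.train.Saver()
        with tf.Session(graph=graph) as sess:
            sess.run(tf.global_variables_initializer())
            saver.save(sess, path)


class ImportGraph:
    """Hold one restored model in a private graph and session."""

    def __init__(self, loc):
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph)
        with self.graph.as_default():
            saver = tf.train.import_meta_graph(loc + ".meta")
            saver.restore(self.sess, loc)
        # Tensor names are the ones chosen in build_and_save above.
        self.x = self.graph.get_tensor_by_name("x:0")
        self.output = self.graph.get_tensor_by_name("output:0")

    def predict(self, data):
        return self.sess.run(self.output, feed_dict={self.x: data})


def demo():
    """Save two models with different weights, then query both side by side."""
    tmp = tempfile.mkdtemp()
    paths = [os.path.join(tmp, "model_a"), os.path.join(tmp, "model_b")]
    build_and_save(paths[0], 2.0)
    build_and_save(paths[1], 3.0)
    models = [ImportGraph(p) for p in paths]
    data = np.array([[1.0], [2.0]], dtype=np.float32)
    return [m.predict(data).ravel().tolist() for m in models]


print(demo())
```

Because every `ImportGraph` instance owns its graph and session, both models can stay loaded at the same time, and each one answers with its own weight.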
I found the approach above in a blog post, but in my experiments it did not restore multiple models successfully. The way I restore them is:
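class ImportGraph:
    def __init__(self):
        # The default graph is a property of the current thread; reset it
        tf.reset_default_graph()
        # Rebuild the model here (placeholders, variables, ops):
        # tf.train.Saver() needs variables in the current default graph
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver()
        self.saver.restore(self.sess, "./model/model_path")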
The important part is tf.reset_default_graph(). The official TensorFlow documentation describes it as follows:
tf.reset_default_graph()
Defined in tensorflow/python/framework/ops.py.
Clears the default graph stack and resets the global default graph.
NOTE: The default graph is a property of the current thread. This function applies only to the current thread. Calling this function while a tf.Session or tf.InteractiveSession is active will result in undefined behavior. Using any previously created tf.Operation or tf.Tensor objects after calling this function will result in undefined behavior.
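The effect described in the documentation can be checked directly. The following is a small sketch, assuming the TF1-style API via `tensorflow.compat.v1`; the function name `inspect_reset` is made up for the example.

```python
import tensorflow.compat.v1 as tf  # assumption: TF1-style API via the compat module

tf.disable_eager_execution()


def inspect_reset():
    """Return (op count before reset, op count after reset, old tensor orphaned?)."""
    tf.reset_default_graph()  # start from a clean default graph
    x = tf.placeholder(tf.float32, name="x")
    ops_before = len(tf.get_default_graph().get_operations())
    tf.reset_default_graph()
    ops_after = len(tf.get_default_graph().get_operations())
    # x still belongs to the discarded graph, not to the new default graph.
    orphaned = x.graph is not tf.get_default_graph()
    return ops_before, ops_after, orphaned


print(inspect_reset())
```

After the reset the default graph is empty, so the model has to be rebuilt or re-imported before a `tf.train.Saver` can be created, and tensors created earlier should not be reused.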