Shihanmax's blog

Loading Multiple Models at the Same Time with TensorFlow

When working with a single model, one way to save and load it is as follows:

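import tensorflow as tf  # TensorFlow 1.x API

# Input/output definitions
x = tf.placeholder(tf.float32, shape=[None, 1], name="x")
y = tf.placeholder(tf.float32, shape=[None, 1], name="y")

# Weight definition
weight = tf.Variable(tf.zeros([1, 1]), name="weight")

# Op definitions (a linear model and squared error stand in for your own ops)
output = tf.matmul(x, weight, name="output")
loss = tf.reduce_mean(tf.square(output - y))
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss, name="train_op")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()

    # ... run train_op here to train the model ...

    saver.save(sess, "./model/model_path")  # save the model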

The model above can be restored as follows:

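# Assumes the same graph (x, weight, output) has already been built in this process
saver = tf.train.Saver()

sess = tf.Session()
sess.run(tf.global_variables_initializer())  # optional: restore overwrites these values
saver.restore(sess, "./model/model_path")  # restore the model into sess

# feed_dict maps the input placeholder to the data to predict on
result = sess.run([output], feed_dict=feed_dict)  # predict with the restored model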

For a single model this works fine. But if we have trained several models with the same structure, we would expect to restore them like this:

all_sessions = []
for i in range(model_nums):
    saver = tf.train.Saver()

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, "./model/model_path")  # restore the model into sess

    all_sessions.append(sess)

Then use the restored sessions to make predictions:

all_result = []
for sess in all_sessions:
    all_result.append(sess.run([output], feed_dict=feed_dict))

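However, doing this leads to incorrect parameters and abnormal predictions. The cause is that the variables of the different models conflict, because all of the models' variables are loaded into the default graph of the same thread. The fix is to use a separate default graph for each model: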

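class ImportGraph():
    def __init__(self, loc):
        # Each instance owns a separate graph and a session bound to that graph
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph)
        with self.graph.as_default():
            # Rebuild the graph structure from the .meta file, then load the weights
            saver = tf.train.import_meta_graph(loc + ".meta")
            saver.restore(self.sess, loc)
        # Assumption: the saved graph names its input "x" and its output "output"
        self.x = self.graph.get_tensor_by_name("x:0")
        self.output = self.graph.get_tensor_by_name("output:0")

    def predict(self, data):
        return self.sess.run([self.output], feed_dict={self.x: data})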
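To make the one-graph-per-model idea concrete, here is a minimal end-to-end sketch. It is not taken from the original blog: the toy model (y = x * weight), the helper `save_toy_model`, the tensor names "x" and "output", and the temporary checkpoint paths are all assumptions made for illustration. It uses `tensorflow.compat.v1` so that the TF1-style calls also run on TensorFlow 2.x.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # TF1-style API, also shipped with TF 2.x


def save_toy_model(path, weight_value):
    """Hypothetical stand-in for training: build y = x * weight and save it."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.float32, shape=[None], name="x")
        w = tf.Variable(weight_value, dtype=tf.float32, name="weight")
        tf.multiply(x, w, name="output")
        saver = tf.train.Saver()
        with tf.Session(graph=graph) as sess:
            sess.run(tf.global_variables_initializer())
            saver.save(sess, path)  # also writes path + ".meta"


class ImportGraph:
    """Restore one checkpoint into its own graph and session."""

    def __init__(self, loc):
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph)
        with self.graph.as_default():
            saver = tf.train.import_meta_graph(loc + ".meta")
            saver.restore(self.sess, loc)
        # Look the tensors up by the names assumed in save_toy_model
        self.x = self.graph.get_tensor_by_name("x:0")
        self.output = self.graph.get_tensor_by_name("output:0")

    def predict(self, data):
        return self.sess.run(self.output, feed_dict={self.x: data})


# Save two models with different weights, then load both side by side
tmp_dir = tempfile.mkdtemp()
paths = [os.path.join(tmp_dir, "model_%d" % i) for i in range(2)]
for path, weight_value in zip(paths, [2.0, 3.0]):
    save_toy_model(path, weight_value)

models = [ImportGraph(path) for path in paths]
data = np.array([1.0, 2.0], dtype=np.float32)
results = [model.predict(data) for model in models]
```

Because every `ImportGraph` instance keeps its own graph, session, and tensor handles, each prediction runs against the weights from its own checkpoint only.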

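I came across the approach above on a blog, but in my experiments it did not succeed in restoring multiple models. This is how I restore them instead: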

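class ImportGraph():
    def __init__(self):
        # The default graph is a property of the current thread.
        # Reset the default graph of the current thread.
        tf.reset_default_graph()
        # ... rebuild the model graph here; tf.train.Saver() needs variables to exist ...
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver()
        self.saver.restore(self.sess, "./model/model_path")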
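The reset-based approach can be sketched end to end as follows. This is a hedged illustration rather than the exact code used in the experiments: the toy model (y = x * weight), the helpers `build_model` and `save_toy_model`, the class name `ResetAndRestore`, and the temporary checkpoint paths are all hypothetical. Each instance resets the default graph, rebuilds the model in the fresh graph, and keeps references to its own tensors so that predictions never use tensors from another instance's graph.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # TF1-style API, also shipped with TF 2.x

tf.disable_eager_execution()  # needed on TF 2.x to work with a global default graph


def build_model(weight_value=0.0):
    """Hypothetical toy model: output = x * weight, built in the default graph."""
    x = tf.placeholder(tf.float32, shape=[None], name="x")
    w = tf.Variable(weight_value, dtype=tf.float32, name="weight")
    output = tf.multiply(x, w, name="output")
    return x, output


def save_toy_model(path, weight_value):
    """Stand-in for training: initialize the weight and save a checkpoint."""
    tf.reset_default_graph()
    build_model(weight_value)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, path)


class ResetAndRestore:
    def __init__(self, loc):
        tf.reset_default_graph()                 # start from an empty default graph
        self.x, self.output = build_model()      # rebuild the model in the new graph
        self.sess = tf.Session()                 # bound to the new default graph
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver()
        self.saver.restore(self.sess, loc)       # overwrite the initial values

    def predict(self, data):
        return self.sess.run(self.output, feed_dict={self.x: data})


tmp_dir = tempfile.mkdtemp()
paths = [os.path.join(tmp_dir, "model_%d" % i) for i in range(2)]
for path, weight_value in zip(paths, [2.0, 3.0]):
    save_toy_model(path, weight_value)

models = [ResetAndRestore(path) for path in paths]
data = np.array([1.0, 2.0], dtype=np.float32)
results = [model.predict(data) for model in models]
```

Note that the earlier sessions stay open while the default graph is reset for the next model, so this sketch relies on each session continuing to hold a reference to the graph it was created with.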

The key part is tf.reset_default_graph(). The official TensorFlow documentation describes it as follows:

tf.reset_default_graph()
Defined in tensorflow/python/framework/ops.py.

Clears the default graph stack and resets the global default graph.

NOTE: The default graph is a property of the current thread. This function applies only to the current thread. Calling this function while a tf.Session or tf.InteractiveSession is active will result in undefined behavior. Using any previously created tf.Operation or tf.Tensor objects after calling this function will result in undefined behavior.
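As a small illustration of the note above, the sketch below checks which graph a tensor belongs to before and after a reset. The tensor name "a" is arbitrary, and `tensorflow.compat.v1` with eager execution disabled is an assumption made so that the snippet also runs on TensorFlow 2.x.

```python
import tensorflow.compat.v1 as tf  # TF1-style API, also shipped with TF 2.x

tf.disable_eager_execution()  # needed on TF 2.x to work with a global default graph

old_graph = tf.get_default_graph()
a = tf.constant(1.0, name="a")  # created in the old default graph

tf.reset_default_graph()
new_graph = tf.get_default_graph()

# The reset installs a fresh graph; `a` still points at the old one
graph_was_replaced = new_graph is not old_graph
tensor_in_new_graph = a.graph is new_graph
```

This is why tensors and operations created before the reset should not be reused afterwards: they are not part of the graph that new sessions pick up by default.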