摘要
本文为系列博客tensorflow模型部署系列的一部分,用于为模型部署提供最开始的输入————标准化的模型文件。相关源码见链接
引言
本文为系列博客tensorflow模型部署系列的一部分,用于为模型部署提供最开始的输入————标准化的模型文件。相关示例代码放在gdyshi的github上
主题
可保存的模型格式有多种,本文仅针对 tensorflow
的默认格式ckpt
、 keras
的默认格式h5
和tensorflow
例程常用格式pb
进行说明
tensorflow的ckpt格式
ckpt
是checkpoint
的简称,是tensorflow
官方使用的模型文件。保存模型时会同时生成4个文件,这些文件实质上是代码,有被注入的风险,详见SECURITY
-
checkpoint
checkpoint
是文本文件,里面记录了保存的最新的checkpoint文件以及其它checkpoint文件列表。在inference时,可以通过修改这个文件,指定使用哪个model -
*.meta
meta文件是pb(protocol buffer)格式文件,保存的是图结构。包含变量、op、集合等
*.index
-
*.data-*
ckpt文件和index文件是二进制文件,保存了所有的weights、biases、gradients等数据。
ckpt文件的保存语句为
saver = tf.train.Saver()
saver.save(sess, './saved_tf/tf_model')
ckpt文件的恢复语句为
new_saver = tf.train.import_meta_graph('./saved_tf/tf_model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint(''./saved_tf'))
keras的h5格式
h5格式是keras框架所使用的保存格式,使用HDF5编码。
h5文件的保存语句为
# 保存模型和权重
model.save('./saved_keras/save.h5')
# 仅保存模型
model.save_weights('./saved_keras/save_weights.h5')
h5文件的恢复语句为
# 恢复模型和权重
model = keras.models.load_model( filepath )
# 恢复权重
model.load_weights('my_model_weights.h5',by_name=True)
pb格式
前面所述在ckpt文件变量数据和图是各自独立的文件存储的。这种解耦形式存在的方法对以后的迁移学习以及对程序进行微小的改动提供了极大的便利性。但是对于已经训练好,需要部署的模型来说,把整个模型保存为一个文件则更方便。tensorflow例程常见的是pb文件的形式。pb文件实际上是一个较广范围的概念,泛指以protocol buffer
格式存储的文件,我写的tensorflow模型部署系列中所说的pb文件是狭义的概念,指protocol buffer
格式存储的模型图和模型参数文件,这一系列博客也将以pb格式为主要格式来进行部署
pb文件可以保存整个图表(元+数据),并将所有的变量固化为常量
tensorflow模型转换为pb格式文件的语句为
frozen_graph = convert_variables_to_constants(session, input_graph_def,
output_names, freeze_var_names)
graph_io.write_graph(frozen_graph, output_path, pb_model_name, as_text=False)
示例代码
-
模型训练&保存代码
-
模型转换为pb文件
-
模型恢复&推理代码