gdyshi | 孤独隐士 MindMap, Machine Learning, Medical Equipment

tensorflow模型部署系列————预训练模型导出(附代码)


摘要

本文为系列博客tensorflow模型部署系列的一部分,用于为模型部署提供最开始的输入————标准化的模型文件。相关源码见链接


引言

本文为系列博客tensorflow模型部署系列的一部分,用于为模型部署提供最开始的输入————标准化的模型文件。相关示例代码放在gdyshi的github

主题

可保存的模型格式有多种,本文仅针对 tensorflow 的默认格式ckptkeras 的默认格式h5tensorflow例程常用格式pb进行说明

tensorflow的ckpt格式

ckptcheckpoint的简称,是tensorflow官方使用的模型文件。保存模型时会同时生成4个文件,这些文件实质上是代码,有被注入的风险,详见SECURITY

  • checkpoint

    checkpoint是文本文件,里面记录了保存的最新的checkpoint文件以及其它checkpoint文件列表。在inference时,可以通过修改这个文件,指定使用哪个model

  • *.meta

    meta文件是pb(protocol buffer)格式文件,保存的是图结构。包含变量、op、集合等

  • *.index
  • *.data-*

    ckpt文件和index文件是二进制文件,保存了所有的weights、biases、gradients等数据。

ckpt文件的保存语句为

saver = tf.train.Saver()
saver.save(sess, './saved_tf/tf_model')

ckpt文件的恢复语句为

new_saver = tf.train.import_meta_graph('./saved_tf/tf_model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint(''./saved_tf'))

keras的h5格式

h5格式是keras框架所使用的保存格式,使用HDF5编码。

h5文件的保存语句为

# 保存模型和权重
model.save('./saved_keras/save.h5')
# 仅保存模型
model.save_weights('./saved_keras/save_weights.h5')

h5文件的恢复语句为

# 恢复模型和权重
model = keras.models.load_model( filepath )
# 恢复权重
model.load_weights('my_model_weights.h5',by_name=True)

pb格式

前面所述在ckpt文件变量数据和图是各自独立的文件存储的。这种解耦形式存在的方法对以后的迁移学习以及对程序进行微小的改动提供了极大的便利性。但是对于已经训练好,需要部署的模型来说,把整个模型保存为一个文件则更方便。tensorflow例程常见的是pb文件的形式。pb文件实际上是一个较广范围的概念,泛指以protocol buffer格式存储的文件,我写的tensorflow模型部署系列中所说的pb文件是狭义的概念,指protocol buffer格式存储的模型图和模型参数文件,这一系列博客也将以pb格式为主要格式来进行部署

pb文件可以保存整个图表(元+数据),并将所有的变量固化为常量

tensorflow模型转换为pb格式文件的语句为

frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                              output_names, freeze_var_names)
graph_io.write_graph(frozen_graph, output_path, pb_model_name, as_text=False)

示例代码

附录

参考



Similar Posts

Content