ONNX小记
ONNX([ˈo:nʏks])是一种开源的机器学习模型协议,用于进行平台无关的模型推理,能够在一定程度上提升模型的迁移性和推理效率。本文整理了之前记录的笔记,可以简单了解ONNX及模型转换与推理的方法。
一、ONNX是什么&必要性¶
ONNX is an open format built to represent machine learning models. ONNX defines a common set of operators - the building blocks of machine learning and deep learning models - and a common file format to enable AI developers to use models with a variety of frameworks, tools, runtimes, and compilers.
Key benefits
- Interoperability
Develop in your preferred framework without worrying about downstream inferencing implications. ONNX enables you to use your preferred framework with your chosen inference engine. - Hardware Access
ONNX makes it easier to access hardware optimizations. Use ONNX-compatible runtimes and libraries designed to maximize performance across hardware.
二、pytorch模型forward函数定义注意事项¶
2.1 if分支相关¶
forward中存在的if分支,会根据实际情况,在生成ONNX时,仅保留一个分支的路径,因此,如无必要,尽量不要在forward作分支选择。
2.2 入口函数相关¶
实践中,即使不定义forward,使用其他函数名来作为网络结构的定义函数也是可以的,但由于ONNX转换时,无法指定计算图的入口函数,因此它只会默认采用forward作为其入口。在模型定义时,如果后续有转ONNX的需求,则需谨慎定义forward的内容,其内部计算需要限制在我们推理时的使用场景下,即考虑:“当我的模型部署时,forward是否能够接收处理后的内容,并且返回我希望的结果(如分类任务中的logits、解码任务中的下一个token)”,类似计算出loss、beam_search等操作,则应定义在别的forward_xxx函数中;比如,训练、loss计算等op不应杂糅到一个forward中。
2.3 参数、返回值的要求¶
当前,JIT的输入/输出支持list、tuple、Variable;Dict和string也是支持的,但不推荐使用。因此,我们的forward的输入参数和返回值应避免使用不支持的类型,包括int、float、None等,否则会无法通过内部的格式检查,导致ONNX转换失败;
三、ONNX转换流程¶
torch模型支持通过torch.onnx.export来输出ONNX格式的模型,示例代码和解释如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
|