Shihanmax's blog

< Back

ONNX小记

ONNX([ˈo:nʏks])是一种开源的机器学习模型协议,用于进行平台无关的模型推理,能够在一定程度上提升模型的迁移性和推理效率。本文整理了之前记录的笔记,可以简单了解ONNX及模型转换与推理的方法。

一、ONNX是什么&必要性

ONNX is an open format built to represent machine learning models. ONNX defines a common set of operators - the building blocks of machine learning and deep learning models - and a common file format to enable AI developers to use models with a variety of frameworks, tools, runtimes, and compilers.

Key benefits

  • Interoperability Develop in your preferred framework without worrying about downstream inferencing implications. ONNX enables you to use your preferred framework with your chosen inference engine.
  • Hardware Access ONNX makes it easier to access hardware optimizations. Use ONNX-compatible runtimes and libraries designed to maximize performance across hardware.

二、pytorch模型forward函数定义注意事项

2.1 if分支相关

forward中存在的if分支,会根据实际情况,在生成ONNX时,仅保留一个分支的路径,因此,如无必要,尽量不要在forward作分支选择。

2.2 入口函数相关

实践中,即使不定义forward,使用其他函数名来作为网络结构的定义函数也是可以的,但由于ONNX转换时,无法指定计算图的入口函数,因此它只会默认采用forward作为其入口。在模型定义时,如果后续有转ONNX的需求,则需谨慎定义forward的内容,其内部计算需要限制在我们推理时的使用场景下,即考虑:“当我的模型部署时,forward是否能够接收处理后的内容,并且返回我希望的结果(如分类任务中的logits、解码任务中的下一个token)”,类似计算出loss、beam_search等操作,则应定义在别的forward_xxx函数中;比如,训练、loss计算等op不应杂糅到一个forward中。

2.3 参数、返回值的要求

当前,JIT的输入/输出支持list、tuple、Variable;Dict和string也是支持的,但不推荐使用。因此,我们的forward的输入参数和返回值应避免使用不支持的类型,包括int、float、None等,否则会无法通过内部的格式检查,导致ONNX转换失败;

三、ONNX转换流程

torch模型支持通过torch.onnx.export来输出ONNX格式的模型,示例代码和解释如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
torch.onnx.export(
    model, # 需要export的模型
    dummy_input, # tuple/list格式的输入(forward的输入,bs维可以设置为1)
    "bert_large.onnx", # 输出名
    verbose=True, # 是否在转换完成后打印模型结构信息
    do_constant_folding=True, # constant-folding optimization
    input_names=input_names, # 为输入取别名
    output_names=output_names, # 为输出取别名
    dynamic_axes={ # 设置动态维度(一般为batch_size维设置)
        'input_0': {0: 'batch_size'},
        'input_1': {0: 'batch_size'},
        'output_0': {0: 'batch_size'},
    },
)

四、ONNX的CUDA推理

By default, ONNX Runtime always places input(s) and output(s) on CPU, which is not optimal if the input or output is consumed and produced on a device other than CPU because it introduces data copy between CPU and the device. ONNX Runtime provides a feature, IO Binding, which addresses this issue by enabling users to specify which device to place input(s) and output(s) on. Here are scenarios to use this feature.

ONNX使用CUDA推理的demo:

1
2
3
4
5
6
7
8
9
10
ort_session = ort.InferenceSession("./path/to/your.onnx") # 实例化一个session
io_binding = ort_session.io_binding() # 示例化一个binding对象
input_0, input_1 = inp_
inp_0 = ort.OrtValue.ortvalue_from_numpy(input_0, device_name, device_idx) # 将输入转为OrtValue类型
inp_1 = ort.OrtValue.ortvalue_from_numpy(input_1, device_name, device_idx)
io_binding.bind_ortvalue_input("input_0", inp_0) # 输入绑定
io_binding.bind_ortvalue_input("input_1", inp_1)
io_binding.bind_output('output_0') # 输出绑定
output = ort_session.run_with_iobinding(io_binding) # 使用绑定进行推理
output = io_binding.copy_outputs_to_cpu()[0] # 通过绑定对象将结果拷贝到cpu

五、ONNX在CPU和CUDA上的推理速度与torch对比

本文使用BERT-base-cased在文本分类任务上,分别在CPU和CUDA两种设备上对比了使用ONNX和使用原生torch模型的推理速度。从实验结果来看,在CPU上进行推理时,ONNX相对于torch原生的推理速度快约6~9倍;在CUDA上进行推理时,ONNX相对于torch原生推理约快1.3~1.5倍。

参考

  1. ONNXruntime
  2. demo