在项目中对某些项目进行tensorrt加速的时候发现会报如下的错误。
上述错误大概是说view输入参数大小是512x7x7的,输出的参数大小确是512x36的,如此造成了输入输出规格不同。但是这在实际的代码中打断点调试根本看不出来,必须深入到tensorrt代码去查。
- 打开tensorrt源码中的view.py文件查看并将输入输出的数据进行shape对比,如下
def convert_view(ctx): input = ctx.method_args[0] print('before input.shape = ',input.shape) input_trt = add_missing_trt_tensors(ctx.network, [input])[0] print('after input.shape = ',input_trt.shape) output = ctx.method_return layer = ctx.network.add_shuffle(input_trt) layer.reshape_dims = tuple(output.shape[1:]) output._trt = layer.get_output(0)
然后运行输出,看调试信息
这就和错误信息对上号了。看来是add_missing_trt_tensors函数改变了原数据的shape
- 打开torch2trt.py文件并定位到目标函数,通过调试发现数据走的是第二个if分支,即是
看样子是t._trt的问题。于是我们在开始的时候添加一个print函数将这两个的shape打印出来
def add_missing_trt_tensors(network, tensors): """Creates missing TensorRT tensors as constants and attaches them to the Torch Tensors""" trt_tensors = [None] * len(tensors) dtype = check_torch_dtype(*tensors) for i, t in enumerate(tensors): print('add_missing i = {}, t.shape = {} t._trt={}'.format(i,t.shape,t._trt.shape)) trt_tensor = None # GET TRT TENSOR (OR CREATE TRT CONSTANT) # get tensor w/ _trt # or... add constant for scalar primitive if isinstance(t, float) or isinstance(t, int): #省略内容 elif hasattr(t, "_trt"): trt_tensor = t._trt # or... add constant for leaf tensor w/o _trt else: #省略内容 assert trt_tensor is not None trt_tensors[i] = trt_tensor return trt_tensors
最后结果如下:
通过结果我们可以看到,前面t和t._trt的shape都是相同的,但是后期就不一样了,而程序里又执行了trt_tensor = t._trt,这就导致了问题的产生。
- 经验证,这是由于输入图像规格太大导致的,将输入数据规格缩小即可避免这个问题,而这个数字如果不正确的话,即使解决了view的问题,也可能会出现下面中所示的问题。
- 我将规格改成196,完美的解决上述两个问题。