0

_rebuild_tensor_v2?pytorch版本间模型兼容性脱坑实践

最近使用Pytorch 0.4.0 进行模型训练,之后使用一个转模型的工具时,报了一个错,就是标题里面的_rebuild_tensor_v2相关的错误。最后发现是本地使用的pytorch的版本是0.3.0,和0.4.0模型上不兼容。各论坛上的解决方案都是说pytorch版本不向后兼容,建议升级pytorch。无奈我这里不方便升级pytorch版本。那么问题就来了,有没有什么不需要修改pytorch源码,或是不升级pytorch,又能让老版本的pytorch读取新版本模型的方案呢?

当然是有的,而且工作量很小。

一、Pytorch模型存储和读取的流程

首先,我们使用pytorch存储模型会使用 torch.save 这个函数,直接将模型的state_dict()保存下来。类似下面的代码:

读取参数的代码也十分简单:

而低版本的pytorch就是在load_state_dict这里报了错。

二、State Dict

我们首先要知道,model.state_dict()的返回值究竟是什么。

这里我直接给出结论:

model.state_dict()的返回值是一个collections.OrderedDict对象,它的键是一个字符串,它的值是Tensor的对象。所以造成兼容性问题的其实是Tensor对象的不兼容。

那么是不是可以将Tensor转化成一个新的非Pytorch内置的数据类型呢?这样就可以避免兼容性问题。

numpy.ndarray就是我们需要的中间态。

三、模型转换

首先,我们需要将state_dict的参数转换成numpy.ndarray保存下来。这里使用高版本的pytorch。

之后,用低版本的pytorch载入这个numpy的state_dict。

四、总结

对于这个问题,还有很多的解决方案,这里是比较简单的一种。

PS. 这是目前为止,写的最快的一篇博客了。。。

 

转载请注明出处,谢谢!

打赏作者
miao

miao

发表评论

电子邮件地址不会被公开。