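I am trying to convert the PyTorch FOMM model to TorchScript. As soon as I started annotating some classes with @torch.jit.script, I got this error:
OSError: Can't get source for <class 'collections.deque'>. TorchScript requires source access in order to carry out compilation, make sure original .py files are available.
As far as I understand, the class is implemented in C inside CPython, so the TorchScript compiler cannot read its source. I could not find a pure-Python implementation. How can I get around this?
Here is the class I am trying to annotate: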
import queue
import collections
import threading
import torch


@torch.jit.script
class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as the data parallel will trigger an callback of each module, all slave devices should
    call `register(id)` and obtain an `SlavePipe` to communicate with the master.
    - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
    and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine to message passed
    back to each slave devices.
    """

    def __init__(self, master_callback):
        """
        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register an slave device.
        Args:
            identifier: an identifier, usually is the device id.
        Returns: a `SlavePipe` object which can be used to communicate with the master device.
        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages were first collected from each devices (including the master device), and then
        an callback will be invoked to compute the message to be sent back to each devices
        (including the master device).
        Args:
            master_msg: the message that the master want to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
        Returns: the message to be sent back to the master device.
        """
        self._activated = True
        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())
        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belongs to the master.'
        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)
        for i in range(self.nr_slaves):
            assert self._queue.get() is True
        return results[0][1]

    @property
    def nr_slaves(self):
        return len(self._registry)
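
The error above can be reproduced in miniature without torch. This is a rough sketch, assuming TorchScript's source lookup behaves much like the standard library's `inspect.getsource`; the helper name `has_python_source` is made up for illustration. One likely route to the `deque` in the error is `queue.Queue`, which is written in Python but stores its items in a `collections.deque`. That is an inference, since the error message in the question does not say where the `deque` comes from.

```python
import collections
import inspect
import queue


def has_python_source(obj) -> bool:
    """Return True if `inspect` can retrieve Python source for `obj`.

    Hypothetical helper for illustration; not part of torch.
    """
    try:
        inspect.getsource(obj)
        return True
    except (OSError, TypeError):
        # Raised when no .py source can be located, e.g. for types
        # that CPython implements in C.
        return False


# queue.Queue is defined in Lib/queue.py, while collections.deque comes
# from the C extension module _collections.
queue_has_source = has_python_source(queue.Queue)
deque_has_source = has_python_source(collections.deque)

print("queue.Queue has source:", queue_has_source)
print("collections.deque has source:", deque_has_source)
```

If the check fails for a type that a class depends on, annotating that class with @torch.jit.script will probably hit the same wall, because the compiler has no Python source to read for that type.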
2 Answers
mf98qq94 #1
Switch the TorchScript generation method from torch.jit.script to torch.jit.trace. It works and does not require annotating anything. Alternatively, torch.onnx.export sometimes works too.
xnifntxz #2
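I ran into this problem when trying to use PyInstaller on a Python script that uses torch. I followed step 3 in this GitHub thread and changed the decorator in modeling_deberta.py to @torch.jit._script_if_tracing. (Note that the git clone command in the GitHub answer has a typo, "transormers" instead of "transformers", and the file path is slightly different: src/transformers/models/deberta/modeling_deberta.py. To be safe, I made the same change in modeling_deberta_v2.py.)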
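
The two suggestions in this thread can be sketched together on a toy model. `TinyNet` and `clamp_positive` are hypothetical stand-ins, not code from FOMM or transformers, and the sketch uses the public `torch.jit.script_if_tracing` rather than the private `torch.jit._script_if_tracing` named in the answer, on the assumption that the two behave the same way.

```python
import torch
import torch.nn as nn


# Hypothetical helper. With script_if_tracing the function is compiled only
# if it is called while tracing, not at import time as with @torch.jit.script.
@torch.jit.script_if_tracing
def clamp_positive(x: torch.Tensor) -> torch.Tensor:
    return torch.clamp(x, min=0.0)


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real model."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.linear(x))


model = TinyNet().eval()
example_input = torch.randn(1, 4)

# trace() runs the model once on the example input and records the tensor
# operations that execute, so no class has to be annotated.
traced = torch.jit.trace(model, example_input)

with torch.no_grad():
    outputs_match = torch.allclose(traced(example_input), model(example_input))

# Outside of tracing, the decorated helper runs as ordinary Python.
clamped = clamp_positive(torch.tensor([-1.0, 2.0]))
```

Keep in mind that tracing records only the operations executed for the example input, so data-dependent control flow in `forward` is not captured in the traced module.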