class DmaCopyBackend:
    """cuMemcpyBatchAsync copy backend (background thread).

    A dedicated daemon thread drains a ``SimpleQueue`` of copy requests and
    issues batched memcpys via ``copy_blocks``, recording a CUDA event on the
    direction-appropriate stream after each launch so callers can poll for
    completion via ``events_list``. ``None`` on the queue is the shutdown
    sentinel for the worker thread.
    """

    def __init__(self) -> None:
        # Direction-specific batched-memcpy parameter tables; built by init().
        self._store_params: BatchMemcpyParams | None = None
        self._load_params: BatchMemcpyParams | None = None
        # Streams on which completion events are recorded (set by init()).
        self._load_stream: torch.cuda.Stream | None = None
        self._store_stream: torch.cuda.Stream | None = None
        # Work queue consumed by the background thread; None is the stop sentinel.
        self._queue: queue.SimpleQueue | None = None
        self._thread: threading.Thread | None = None
        self._shutdown: bool = False

    def init(
        self,
        gpu_caches: dict[str, torch.Tensor],
        cpu_caches: dict[str, torch.Tensor],
        device: torch.device,
        load_stream: torch.cuda.Stream,
        store_stream: torch.cuda.Stream,
    ) -> None:
        """Build copy parameter tables and start the background copy thread.

        Args:
            gpu_caches: Per-layer GPU cache tensors (store source / load dest).
            cpu_caches: Per-layer CPU cache tensors (store dest / load source).
            device: Device the worker thread binds to before copying.
            load_stream: Stream on which load (CPU->GPU) events are recorded.
            store_stream: Stream on which store (GPU->CPU) events are recorded.
        """
        self._load_stream = load_stream
        self._store_stream = store_stream
        # store: GPU -> CPU; load: CPU -> GPU (note the swapped cache order).
        self._store_params = build_params(gpu_caches, cpu_caches, store_stream)
        self._load_params = build_params(cpu_caches, gpu_caches, load_stream)
        self._queue = queue.SimpleQueue()
        # Daemon thread so an unclean interpreter exit is not blocked by the
        # worker; orderly teardown goes through shutdown().
        self._thread = threading.Thread(
            target=self._copy_loop,
            args=(self._queue, device, load_stream, store_stream),
            daemon=True,
        )
        self._thread.start()

    def launch_copy(
        self,
        src_blocks: list[int],
        dst_blocks: list[int],
        is_store: bool,
        event_idx: int,
        events_list: list[tuple[int, torch.Event]],
    ) -> None:
        """Enqueue an async block copy; the worker appends a completion event.

        Args:
            src_blocks: Source block indices.
            dst_blocks: Destination block indices (same length as src_blocks).
            is_store: True for GPU->CPU (store), False for CPU->GPU (load).
            event_idx: Caller-side tag paired with the recorded event.
            events_list: Shared list the worker appends (event_idx, event) to
                once the copy has been launched on its stream.

        Raises:
            RuntimeError: If called after shutdown() or before init().
        """
        # After shutdown the worker no longer drains the queue; enqueueing
        # would silently hang callers waiting on events_list.
        if self._shutdown:
            raise RuntimeError("launch_copy() called after shutdown()")
        params = self._store_params if is_store else self._load_params
        # Explicit checks (not assert): must hold even under `python -O`.
        if params is None or self._queue is None:
            raise RuntimeError(
                "DmaCopyBackend.init() must be called before launch_copy()"
            )
        self._queue.put(
            (src_blocks, dst_blocks, params, is_store, event_idx, events_list)
        )

    def shutdown(self) -> None:
        """Stop the background thread (idempotent; bounded wait on join)."""
        if self._shutdown:
            return
        self._shutdown = True
        if self._queue is not None:
            # Sentinel: tells the worker loop to exit after draining prior work.
            self._queue.put(None)
        if self._thread is not None:
            # Bounded join so shutdown cannot hang forever on a stuck copy.
            self._thread.join(timeout=5.0)

    @staticmethod
    def _copy_loop(
        q: queue.SimpleQueue,
        device: torch.device,
        load_stream: torch.cuda.Stream,
        store_stream: torch.cuda.Stream,
    ) -> None:
        """Worker loop: drain the queue, launch copies, record completion events.

        Runs on the background thread until the ``None`` sentinel is received.
        """
        # Bind this thread to the target device before issuing any copies.
        current_platform.set_device(device)
        while True:
            item = q.get()
            if item is None:
                return
            src_blocks, dst_blocks, params, is_store, event_idx, events_list = item
            copy_blocks(src_blocks, dst_blocks, params)
            # Record on the stream matching the copy direction so waiters can
            # synchronize on exactly this transfer.
            stream = store_stream if is_store else load_stream
            event = torch.Event()
            event.record(stream)
            # list.append is atomic under the GIL; safe to share with callers.
            events_list.append((event_idx, event))