Skip to content

Commit

Permalink
parallel_loader: Fix TPU memory leak when calling __iter__. (#8039)
Browse files Browse the repository at this point in the history
  • Loading branch information
dudulightricks committed Sep 20, 2024
1 parent d0ea5cc commit d79a37c
Showing 1 changed file with 43 additions and 25 deletions.
68 changes: 43 additions & 25 deletions torch_xla/distributed/parallel_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,14 @@ def __init__(self,
self._done = False
self._queues = dict()
self._input_sharding = input_sharding
self._threads = []
for device in self._devices:
self._queues[device] = PerDeviceQueue(device, loader_prefetch_size,
device_prefetch_size)
thread = threading.Thread(target=self._loader_worker)
thread.daemon = True
thread.start()
self._threads.append(thread)
for dqueue in self._queues.values():
for i in range(host_to_device_transfer_threads):
thread = threading.Thread(
Expand All @@ -111,6 +113,7 @@ def __init__(self,
))
thread.daemon = True
thread.start()
self._threads.append(thread)

def per_device_loader(self, device):
"""Retrieves the loader iterator object for the given device.
Expand Down Expand Up @@ -139,6 +142,9 @@ def close(self):
dqueue.queue.close()
dqueue.loader_queue.close()

for thread in self._threads:
thread.join()

@property
def batches_per_execution(self):
    """Return the value stored in ``self._batches_per_execution``."""
    configured = self._batches_per_execution
    return configured
Expand All @@ -147,18 +153,21 @@ def _loader_worker(self):
queues = list(self._queues.values())
data_iter = enumerate(self._loader)
batch = []
while not self._done:
try:
_, data = next(data_iter)
except StopIteration:
break
batch.append(data)
if len(batch) == len(self._devices):
for queue_no, device_batch in enumerate(batch):
queues[queue_no].loader_queue.put(device_batch)
batch = []
for dqueue in queues:
dqueue.loader_queue.close_write()

try:
while not self._done:
try:
_, data = next(data_iter)
except StopIteration:
break
batch.append(data)
if len(batch) == len(self._devices):
for queue_no, device_batch in enumerate(batch):
queues[queue_no].loader_queue.put(device_batch)
batch = []
finally:
for dqueue in queues:
dqueue.loader_queue.close_write()

def _get_batch(self, dqueue):
batch = []
Expand All @@ -171,16 +180,21 @@ def _get_batch(self, dqueue):

def _worker(self, dqueue, host_to_device_transfer_threads):
device = torch.device(dqueue.device)
while True:
batch = self._get_batch(dqueue)
if not batch:
break
batch = xm.send_cpu_data_to_device(batch, device, self._input_sharding)
for data in batch:
dqueue.queue.put(data)
close_queue_count = next(dqueue.close_queue_count)
if close_queue_count == host_to_device_transfer_threads - 1:
dqueue.queue.close_write()

try:
while True:
batch = self._get_batch(dqueue)
if not batch:
break
with torch.no_grad():
batch = xm.send_cpu_data_to_device(batch, device,
self._input_sharding)
for data in batch:
dqueue.queue.put(data)
finally:
close_queue_count = next(dqueue.close_queue_count)
if close_queue_count == host_to_device_transfer_threads - 1:
dqueue.queue.close_write()


class MpDeviceLoader(object):
Expand All @@ -206,11 +220,15 @@ def __init__(self, loader, device, **kwargs):
self._loader = loader
self._device = device
self._parallel_loader_kwargs = kwargs
self._parallel_loader = None

def __iter__(self):
    """Start a new ``ParallelLoader`` and return its iterator for ``self._device``.

    A loader left over from a previous ``__iter__`` call is closed before the
    new one is created, so repeated iteration (for example once per epoch)
    does not accumulate live loaders and their worker threads.

    Returns:
      The result of ``per_device_loader(self._device)`` on the new loader.
    """
    if self._parallel_loader is not None:
        self._parallel_loader.close()
        # Clear the reference first so a failure while constructing the
        # replacement cannot leave a closed loader to be closed again.
        self._parallel_loader = None
    self._parallel_loader = ParallelLoader(self._loader, [self._device],
                                           **self._parallel_loader_kwargs)
    return self._parallel_loader.per_device_loader(self._device)

def __len__(self):
return len(self._loader)

0 comments on commit d79a37c

Please sign in to comment.