Skip to content

Commit

Permalink
Refactor spanner to avoid creating large array
Browse files Browse the repository at this point in the history
  • Loading branch information
XiaohanZhangCMU committed Sep 3, 2024
1 parent fac1852 commit 46a4ef7
Showing 1 changed file with 28 additions and 5 deletions.
33 changes: 28 additions & 5 deletions streaming/base/spanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import numpy as np
from numpy.typing import NDArray


class Spanner:
"""Given a list of shards, construct a mapping of global index to shard and relative index.
Expand All @@ -25,10 +24,34 @@ def __init__(self, shard_sizes: NDArray[np.int64], span_size: int = 1 << 10) ->
underflow = span_size - overflow if overflow else 0
self.shard_sizes[-1] += underflow

sample_shards = np.repeat(np.arange(len(shard_sizes)), self.shard_sizes)
sample_shards = sample_shards.reshape(-1, span_size)
span_lowest_shards = sample_shards.min(1)
span_highest_shards = sample_shards.max(1)
n_shards = len(shard_sizes)
current_shard = 0
current_position_in_shard = 0

span_lowest_shards = []
span_highest_shards = []

while current_shard < n_shards:
span_min_shard = current_shard
span_max_shard = current_shard

remaining_span_size = span_size
while remaining_span_size > 0 and current_shard < n_shards:
available_in_current_shard = shard_sizes[current_shard] - current_position_in_shard

if remaining_span_size >= available_in_current_shard:
remaining_span_size -= available_in_current_shard
current_shard += 1
current_position_in_shard = 0
else:
current_position_in_shard += remaining_span_size
remaining_span_size = 0

if current_shard < n_shards:
span_max_shard = current_shard

span_lowest_shards.append(span_min_shard)
span_highest_shards.append(span_max_shard)

self.spans = []
for low, high in zip(span_lowest_shards, span_highest_shards):
Expand Down

0 comments on commit 46a4ef7

Please sign in to comment.