diff --git a/streaming/base/spanner.py b/streaming/base/spanner.py index 221491f72..0e186ad54 100644 --- a/streaming/base/spanner.py +++ b/streaming/base/spanner.py @@ -6,7 +6,6 @@ import numpy as np from numpy.typing import NDArray - class Spanner: """Given a list of shards, construct a mapping of global index to shard and relative index. @@ -25,10 +24,34 @@ def __init__(self, shard_sizes: NDArray[np.int64], span_size: int = 1 << 10) -> underflow = span_size - overflow if overflow else 0 self.shard_sizes[-1] += underflow - sample_shards = np.repeat(np.arange(len(shard_sizes)), self.shard_sizes) - sample_shards = sample_shards.reshape(-1, span_size) - span_lowest_shards = sample_shards.min(1) - span_highest_shards = sample_shards.max(1) + n_shards = len(shard_sizes) + current_shard = 0 + current_position_in_shard = 0 + + span_lowest_shards = [] + span_highest_shards = [] + + while current_shard < n_shards: + span_min_shard = current_shard + span_max_shard = current_shard + + remaining_span_size = span_size + while remaining_span_size > 0 and current_shard < n_shards: + available_in_current_shard = shard_sizes[current_shard] - current_position_in_shard + + if remaining_span_size >= available_in_current_shard: + remaining_span_size -= available_in_current_shard + current_shard += 1 + current_position_in_shard = 0 + else: + current_position_in_shard += remaining_span_size + remaining_span_size = 0 + + if current_shard < n_shards: + span_max_shard = current_shard + + span_lowest_shards.append(span_min_shard) + span_highest_shards.append(span_max_shard) self.spans = [] for low, high in zip(span_lowest_shards, span_highest_shards):