From fb1734fbce3e5c9dadff8d8bd838f71b1aa2130b Mon Sep 17 00:00:00 2001 From: ArrayRecord Team Date: Wed, 3 Jun 2026 15:35:23 -0700 Subject: [PATCH] Integrate BoundedReaderPool into PyGrain to support thread-safe, concurrent data loading. PiperOrigin-RevId: 926309038 --- python/array_record_data_source.py | 298 +++++++++++++++++++----- python/array_record_data_source_test.py | 11 +- 2 files changed, 250 insertions(+), 59 deletions(-) diff --git a/python/array_record_data_source.py b/python/array_record_data_source.py index 117d244..8166955 100644 --- a/python/array_record_data_source.py +++ b/python/array_record_data_source.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """array_record_data_source module. Warning: this is an experimental module. The interface might change in the @@ -23,15 +22,20 @@ ``` class RandomAccessDataSource(Protocol, Generic[T]): + def __len__(self) -> int: ... - def __getitem__(self, record_keys: Sequence[int]) -> Sequence[T]: + def __getitem__(self, record_key: SupportsIndex) -> T: + ... + + def __getitems__(self, record_keys: Sequence[SupportsIndex]) -> Sequence[T]: ... ``` """ import bisect +import collections from concurrent import futures import dataclasses import hashlib @@ -39,6 +43,7 @@ def __getitem__(self, record_keys: Sequence[int]) -> Sequence[T]: import os import pathlib import re +import threading import typing from typing import Any, Callable, Iterator, List, Mapping, Protocol, Sequence, SupportsIndex, Tuple, TypeVar, Union @@ -48,6 +53,29 @@ def __getitem__(self, record_keys: Sequence[int]) -> Sequence[T]: from . import array_record_module +T = TypeVar("T") + + +@typing.runtime_checkable +class FileInstruction(Protocol): + """Protocol with same interface as FileInstruction returned by TFDS. + + ArrayRecordDataSource would accept objects implementing this protocol without + depending on TFDS. + """ + + filename: str + skip: int + take: int + examples_in_shard: int + + +PathLikeOrFileInstruction = Union[epath.PathLike, FileInstruction] +ArrayRecordDataSourcePaths = Union[ + PathLikeOrFileInstruction, Sequence[PathLikeOrFileInstruction] +] + + # TODO(jolesiak): Decide what to do with these flags, e.g., remove them (could # be appropriate if we decide to use asyncio) or move them somewhere else and # pass the number of threads as an argument. For now, since we experiment, it's @@ -70,8 +98,11 @@ def __getitem__(self, record_keys: Sequence[int]) -> Sequence[T]: "records)." ), ) - -T = TypeVar("T") +_ARRAY_RECORD_READER_POOL_SIZE = flags.DEFINE_integer( + "array_record_reader_pool_size", + None, + "The default reader pool size per shard in ArrayRecordDataSource.", +) def _run_in_parallel( @@ -96,6 +127,7 @@ def _run_in_parallel( """ if num_workers < 1: raise ValueError("num_workers must be >=1 for parallelism.") + thread_futures = [] with futures.ThreadPoolExecutor(num_workers) as executor: for kwargs in list_of_kwargs_to_function: @@ -125,23 +157,6 @@ def __post_init__(self): object.__setattr__(self, "num_records", self.end - self.start) -@typing.runtime_checkable -class FileInstruction(Protocol): - """Protocol with same interface as FileInstruction returned by TFDS. - - ArrayRecordDataSource would accept objects implementing this protocol without - depending on TFDS. - """ - - filename: str - skip: int - take: int - examples_in_shard: int - - -PathLikeOrFileInstruction = Union[epath.PathLike, FileInstruction] - - def _get_read_instructions( paths: Sequence[PathLikeOrFileInstruction], ) -> Sequence[_ReadInstruction]: @@ -212,8 +227,162 @@ def _check_group_size( ) +class _BoundedReaderPoolBorrowContext: + """Context manager for borrowing a reader safely from a _BoundedReaderPool. + + Ensures that the borrowed reader is always returned to the pool, even if + exceptions are raised within the borrowing thread's critical section. + """ + + def __init__(self, pool: "_BoundedReaderPool"): + self._pool = pool + self._reader = None + + def __enter__(self) -> Any: + self._reader = self._pool.get() + return self._reader + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + if self._reader is not None: + self._pool.put(self._reader) + + +class _BoundedReaderPool: + """A semaphore-throttled thread-safe connection pool for a single shard. + + This pool maintains and recycles expensive, non-thread-safe reader instances + (such as `ArrayRecordReader`) to enable parallel reads without lock + contention. + + This is a private class. Since it is not RAII, directly calling `get()` and + `put()` is subject to a risk of deadlock upon exception handling. Callers + MUST use the context-manager based borrowing pattern instead: + with pool.borrow() as reader: + # Perform read operations + + Concurrency Model (Permit/Ownership Flow): + 1. A thread calls `get()` to acquire a reader. This blocks if the number + of active readers has reached `max_size` (acquires a semaphore permit). + 2. The calling thread is now the exclusive owner of the reader and can + safely perform non-thread-safe read operations on it. + 3. Once reading is complete, the thread MUST call `put(reader)` to return + the reader. This recycles the reader and releases the connection slot + (releases the semaphore permit). + + To guarantee safe lease return, callers are strongly encouraged to use the + context-manager based borrowing pattern: + with pool.borrow() as reader: + # Perform read operations + + WARNING: Failing to return a borrowed reader via `put()` will permanently + leak a semaphore permit, eventually causing all subsequent `get()` calls + to deadlock when the cap is reached. + + Teardown & Lifecycle: + Calling `close_all()` marks the pool as closed and immediately closes all + idle readers. Any outstanding borrowed readers will be closed immediately + upon their return via `put()`, ensuring zero file descriptor leaks during + concurrent shutdown sequences. + """ + + def __init__(self, filename: str, options_string: str, max_size: int = 1): + self._filename = filename + self._options_string = options_string + self._max_size = max_size + self._readers = collections.deque() + # Use BoundedSemaphore to strictly enforce the max_size cap + self._semaphore = threading.BoundedSemaphore(max_size) + self._lock = threading.Lock() + self._group_size_checked = False + self._closed = False + + def get(self) -> Any: + """Acquires a reader from the pool, blocking if the active reader cap is reached. + + If the pool is empty but the cap has not been reached, a new reader is + instantiated. If the pool already has idle readers, one is returned + instantly without blocking. + + Returns: + A reader instance. Callers must use the borrow() context manager. + """ + self._semaphore.acquire() + + # Try to get an existing reader from the deque (lock-free popleft) + try: + return self._readers.popleft() + except IndexError: + pass + + # No idle reader; create a new one under lock + reader = None + try: + with self._lock: + if self._closed: + raise RuntimeError( + f"Cannot get reader from closed pool: {self._filename}" + ) + reader = _create_reader(self._filename, self._options_string) + if not self._group_size_checked: + _check_group_size(self._filename, reader) + self._group_size_checked = True + return reader + except Exception: + if reader and hasattr(reader, "close"): + reader.close() + self._semaphore.release() + raise + + def put(self, reader: Any) -> None: + """Returns a reader to the pool, recycling it for future operations. + + If the pool has been closed in the interim, the reader is closed + immediately. + + Args: + reader: The reader instance previously obtained from `get()`. + """ + with self._lock: + if self._closed: + # If the pool was closed while the reader was borrowed, close it + # immediately. + if reader and hasattr(reader, "close"): + reader.close() + self._semaphore.release() + return + + self._readers.append(reader) + self._semaphore.release() + + def borrow(self) -> _BoundedReaderPoolBorrowContext: + """Returns a context manager to borrow a reader safely. + + Usage: + with pool.borrow() as reader: + # Perform read operations + """ + return _BoundedReaderPoolBorrowContext(self) + + def close_all(self) -> None: + """Closes all pooled readers and prevents future allocations.""" + with self._lock: + self._closed = True + + while True: + try: + reader = self._readers.popleft() + if reader and hasattr(reader, "close"): + reader.close() + except IndexError: + break + + def peek_readers(self) -> List[Any]: + """Returns the list of readers (for testing only).""" + return list(self._readers) + + class ArrayRecordDataSource: - """Datasource for ArrayRecord files.""" + """Datasource for ArrayRecord files using a Lock-Free Connection Pool.""" def __init__( self, @@ -221,6 +390,7 @@ def __init__( PathLikeOrFileInstruction, Sequence[PathLikeOrFileInstruction] ], reader_options: dict[str, str] | None = None, + reader_pool_size: int | None = None, ): """Creates a new ArrayRecordDataSource object. @@ -242,6 +412,7 @@ def __init__( initialization faster. reader_options: string of comma-separated options to be passed when creating a reader. + reader_pool_size: The maximum number of readers to keep open per shard. """ if isinstance(paths, (str, pathlib.Path, FileInstruction)): paths = [paths] @@ -270,8 +441,18 @@ def __init__( ) self._read_instructions = _get_read_instructions(paths) self._paths = [ri.filename for ri in self._read_instructions] - # We open readers lazily when we need to read from them. - self._readers = [None] * len(self._read_instructions) + self._reader_pool_size = ( + reader_pool_size or _get_flag_value(_ARRAY_RECORD_READER_POOL_SIZE) or 1 + ) + + # Lock-free connection pool per shard + self._shard_pools = [ + _BoundedReaderPool( + ri.filename, self._reader_options_string, self._reader_pool_size + ) + for ri in self._read_instructions + ] + self._num_records = sum( map(lambda x: x.num_records, self._read_instructions) ) @@ -286,10 +467,8 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): logging.debug("__exit__ for ArrayRecordDataSource is called.") - for reader in self._readers: - if reader: - reader.close() - self._readers = [None] * len(self._read_instructions) + for pool in self._shard_pools: + pool.close_all() def __len__(self) -> int: return self._num_records @@ -329,48 +508,44 @@ def _split_keys_per_reader( positions_and_indices[reader_idx] = [(position, idx)] return positions_and_indices - def _ensure_reader_exists(self, reader_idx: int) -> None: - """Threadsafe method to create corresponding reader if it doesn't exist.""" - if self._readers[reader_idx] is not None: - return - filename = self._read_instructions[reader_idx].filename - reader = _create_reader(filename, self._reader_options_string) - _check_group_size(filename, reader) - self._readers[reader_idx] = reader + def _read_record(self, reader: Any, position: int) -> bytes: + """Helper to read a record using the best available method.""" + if hasattr(reader, "read_record"): + return reader.read_record(position) + if hasattr(reader, "read"): + return reader.read([position])[0] + return reader[position] def __getitem__(self, record_key: SupportsIndex) -> bytes: - reader_idx, position = self._reader_idx_and_position(record_key) - self._ensure_reader_exists(reader_idx) - if hasattr(self._readers[reader_idx], "read"): - return self._readers[reader_idx].read([position])[0] - return self._readers[reader_idx][position] + pool_idx, position = self._reader_idx_and_position(record_key) + with self._shard_pools[pool_idx].borrow() as reader: + return self._read_record(reader, position) def __getitems__( self, record_keys: Sequence[SupportsIndex] ) -> Sequence[bytes]: + def read_records( - reader_idx: int, reader_positions_and_indices: Sequence[Tuple[int, int]] + pool_idx: int, reader_positions_and_indices: Sequence[Tuple[int, int]] ) -> Sequence[Tuple[Any, int]]: """Reads records using the given reader keeping track of the indices.""" - # Initialize readers lazily when we need to read from them. - self._ensure_reader_exists(reader_idx) - positions, indices = list(zip(*reader_positions_and_indices)) - if hasattr(self._readers[reader_idx], "read"): - records = self._readers[reader_idx].read(positions) # pytype: disable=attribute-error - else: - records = [self._readers[reader_idx][p] for p in positions] - return list(zip(records, indices)) + with self._shard_pools[pool_idx].borrow() as reader: + records = [] + for position, _ in reader_positions_and_indices: + records.append(self._read_record(reader, position)) + indices = [idx for _, idx in reader_positions_and_indices] + return list(zip(records, indices)) positions_and_indices = self._split_keys_per_reader(record_keys) num_threads = _get_flag_value(_GRAIN_NUM_THREADS_FETCHING_RECORDS) num_workers = min(len(positions_and_indices), num_threads) list_of_kwargs_to_read_records = [] for ( - reader_idx, + pool_idx, reader_positions_and_indices, ) in positions_and_indices.items(): list_of_kwargs_to_read_records.append({ - "reader_idx": reader_idx, + "pool_idx": pool_idx, "reader_positions_and_indices": reader_positions_and_indices, }) records_with_indices: Sequence[Sequence[Tuple[Any, int]]] = ( @@ -390,7 +565,7 @@ def read_records( def __getstate__(self): logging.debug("__getstate__ for ArrayRecordDataSource is called.") state = self.__dict__.copy() - del state["_readers"] + state.pop("_shard_pools", None) return state def __setstate__(self, state): @@ -398,7 +573,14 @@ def __setstate__(self, state): self.__dict__.update(state) # We open readers lazily when we need to read from them. Thus, we don't # need to re-open the same files as before pickling. - self._readers = [None] * len(self._read_instructions) + self._shard_pools = [ + _BoundedReaderPool( + ri.filename, + self._reader_options_string, + getattr(self, "_reader_pool_size", 1), + ) + for ri in self._read_instructions + ] def __repr__(self) -> str: """Storing a hash of paths since paths can be a very long list.""" @@ -407,6 +589,14 @@ def __repr__(self) -> str: h.update(p.encode()) return f"ArrayRecordDataSource(hash_of_paths={h.hexdigest()})" + def _peek_readers(self) -> List[Any]: + """Returns a list of readers (one per shard) or None (for testing only).""" + readers = [] + for pool in self._shard_pools: + pooled_readers = pool.peek_readers() + readers.append(pooled_readers[-1] if pooled_readers else None) + return readers + def _get_flag_value(flag: flags.FlagHolder[int]) -> int: """Retrieves the flag value or the default if run outside of absl.""" diff --git a/python/array_record_data_source_test.py b/python/array_record_data_source_test.py index 8977a27..b4de730 100644 --- a/python/array_record_data_source_test.py +++ b/python/array_record_data_source_test.py @@ -17,6 +17,7 @@ import dataclasses import os import pathlib +import pickle from unittest import mock from absl import flags @@ -109,7 +110,7 @@ def test_array_record_data_source_single_path(self): ) as ar: actual_data = [ar[x] for x in indices_to_read] self.assertEqual(expected_data, actual_data) - self.assertTrue(all(reader is None for reader in ar._readers)) + self.assertTrue(all(reader is None for reader in ar._peek_readers())) def test_array_record_data_source_string_read_instructions(self): indices_to_read = [0, 1, 2, 3, 4] @@ -132,7 +133,7 @@ def test_array_record_data_source_reverse_order(self): ]) as ar: actual_data = [ar[x] for x in indices_to_read] self.assertEqual(expected_data, actual_data) - self.assertTrue(all(reader is None for reader in ar._readers)) + self.assertTrue(all(reader is None for reader in ar._peek_readers())) def test_array_record_data_source_random_order(self): # some random permutation @@ -144,7 +145,7 @@ def test_array_record_data_source_random_order(self): ]) as ar: actual_data = [ar[x] for x in indices_to_read] self.assertEqual(expected_data, actual_data) - self.assertTrue(all(reader is None for reader in ar._readers)) + self.assertTrue(all(reader is None for reader in ar._peek_readers())) def test_array_record_data_source_random_order_batched(self): # some random permutation @@ -156,7 +157,7 @@ def test_array_record_data_source_random_order_batched(self): ]) as ar: actual_data = ar.__getitems__(indices_to_read) self.assertEqual(expected_data, actual_data) - self.assertTrue(all(reader is None for reader in ar._readers)) + self.assertTrue(all(reader is None for reader in ar._peek_readers())) def test_array_record_data_source_file_instructions(self): file_instruction_one = DummyFileInstruction( @@ -187,7 +188,7 @@ def test_array_record_data_source_file_instructions(self): actual_data = [ar[x] for x in indices_to_read] self.assertEqual(expected_data, actual_data) - self.assertTrue(all(reader is None for reader in ar._readers)) + self.assertTrue(all(reader is None for reader in ar._peek_readers())) def test_array_record_source_reader_idx_and_position(self): file_instructions = [