Source code for pyathena.polars.result_set

from __future__ import annotations

import logging
from collections import abc
from collections.abc import Callable, Iterator
from multiprocessing import cpu_count
from typing import (
    TYPE_CHECKING,
    Any,
    cast,
)

from pyathena import OperationalError
from pyathena.converter import Converter
from pyathena.error import ProgrammingError
from pyathena.model import AthenaQueryExecution
from pyathena.polars.util import to_column_info
from pyathena.result_set import AthenaResultSet
from pyathena.util import RetryConfig

if TYPE_CHECKING:
    import polars as pl
    from pyarrow import Table

    from pyathena.connection import Connection

_logger = logging.getLogger(__name__)


def _identity(x: Any) -> Any:
    """Identity function for use as default converter."""
    return x


[docs] class PolarsDataFrameIterator(abc.Iterator): # type: ignore[type-arg] """Iterator for chunked DataFrame results from Athena queries. This class wraps either a Polars DataFrame iterator (for chunked reading) or a single DataFrame, providing a unified iterator interface. It applies optional type conversion to each DataFrame chunk as it's yielded. The iterator is used by AthenaPolarsResultSet to provide chunked access to large query results, enabling memory-efficient processing of datasets that would be too large to load entirely into memory. Example: >>> # Iterate over DataFrame chunks >>> for df_chunk in iterator: ... process(df_chunk) >>> >>> # Iterate over individual rows >>> for idx, row in iterator.iterrows(): ... print(row) Note: This class is primarily for internal use by AthenaPolarsResultSet. Most users should access results through PolarsCursor methods. """
[docs] def __init__( self, reader: Iterator[pl.DataFrame] | pl.DataFrame, converters: dict[str, Callable[[str | None], Any | None]], column_names: list[str], ) -> None: """Initialize the iterator. Args: reader: Either a DataFrame iterator (for chunked) or a single DataFrame. converters: Dictionary mapping column names to converter functions. column_names: List of column names in order. """ import polars as pl if isinstance(reader, pl.DataFrame): self._reader: Iterator[pl.DataFrame] = iter([reader]) else: self._reader = reader self._converters = converters self._column_names = column_names
[docs] def __next__(self) -> pl.DataFrame: """Get the next DataFrame chunk. Returns: The next Polars DataFrame chunk. Raises: StopIteration: When no more chunks are available. """ try: return next(self._reader) except StopIteration: self.close() raise
[docs] def __iter__(self) -> PolarsDataFrameIterator: """Return self as iterator.""" return self
[docs] def __enter__(self) -> PolarsDataFrameIterator: """Context manager entry.""" return self
[docs] def __exit__(self, exc_type, exc_value, traceback) -> None: """Context manager exit.""" self.close()
[docs] def close(self) -> None: """Close the iterator and release resources.""" from types import GeneratorType if isinstance(self._reader, GeneratorType): self._reader.close()
[docs] def iterrows(self) -> Iterator[tuple[int, dict[str, Any]]]: """Iterate over rows as (index, row_dict) tuples. Yields: Tuple of (row_index, row_dict) for each row across all chunks. """ row_num = 0 for df in self: for row_dict in df.iter_rows(named=True): # Apply converters (use module-level _identity to avoid creating lambdas) processed_row = { col: self._converters.get(col, _identity)(row_dict.get(col)) for col in self._column_names } yield (row_num, processed_row) row_num += 1
[docs] def as_polars(self) -> pl.DataFrame: """Collect all chunks into a single DataFrame. Returns: Single Polars DataFrame containing all data. """ import polars as pl dfs = cast(list["pl.DataFrame"], list(self)) if not dfs: return pl.DataFrame() if len(dfs) == 1: return dfs[0] return pl.concat(dfs)
[docs] class AthenaPolarsResultSet(AthenaResultSet): """Result set that provides Polars DataFrame results with optional Arrow interoperability. This result set handles CSV and Parquet result files from S3, converting them to Polars DataFrames using Polars' native reading capabilities. It does not require PyArrow for basic functionality, but can optionally provide Arrow Table access when PyArrow is installed. Features: - Native Polars CSV and Parquet reading (no PyArrow required) - Efficient columnar data processing with Polars - Optional Arrow interoperability when PyArrow is available - Support for both CSV and Parquet result formats - Chunked iteration for memory-efficient processing of large datasets - Optimized memory usage through columnar format Example: >>> # Used automatically by PolarsCursor >>> cursor = connection.cursor(PolarsCursor) >>> cursor.execute("SELECT * FROM large_table") >>> >>> # Get Polars DataFrame >>> df = cursor.as_polars() >>> >>> # Work with Polars >>> print(f"DataFrame has {df.height} rows and {df.width} columns") >>> filtered = df.filter(pl.col("value") > 100) >>> >>> # Optional: Get Arrow Table (requires pyarrow) >>> table = cursor.as_arrow() >>> >>> # Memory-efficient chunked iteration >>> cursor = connection.cursor(PolarsCursor, chunksize=50000) >>> cursor.execute("SELECT * FROM huge_table") >>> for chunk in cursor.iter_chunks(): ... process_chunk(chunk) Note: This class is used internally by PolarsCursor and typically not instantiated directly by users. Requires polars to be installed. PyArrow is optional and only needed for as_arrow() functionality. """
[docs] def __init__( self, connection: Connection[Any], converter: Converter, query_execution: AthenaQueryExecution, arraysize: int, retry_config: RetryConfig, unload: bool = False, unload_location: str | None = None, block_size: int | None = None, cache_type: str | None = None, max_workers: int = (cpu_count() or 1) * 5, chunksize: int | None = None, result_set_type_hints: dict[str | int, str] | None = None, **kwargs, ) -> None: """Initialize the Polars result set. Args: connection: The Athena connection object. converter: Type converter for Athena data types. query_execution: Query execution metadata. arraysize: Number of rows to fetch per batch. retry_config: Configuration for retry behavior. unload: Whether this is an UNLOAD query result. unload_location: S3 location for UNLOAD results. block_size: Block size for S3 file reading. cache_type: Cache type for S3 file system. max_workers: Maximum number of worker threads. chunksize: Number of rows per chunk for memory-efficient processing. If specified, data is loaded lazily in chunks for all data access methods including fetchone(), fetchmany(), and iter_chunks(). result_set_type_hints: Optional dictionary mapping column names to Athena DDL type signatures for precise type conversion. **kwargs: Additional arguments passed to Polars read functions. """ super().__init__( connection=connection, converter=converter, query_execution=query_execution, arraysize=1, # Fetch one row to retrieve metadata retry_config=retry_config, result_set_type_hints=result_set_type_hints, ) self._rows.clear() # Clear pre_fetch data self._arraysize = arraysize self._unload = unload self._unload_location = unload_location self._block_size = block_size self._cache_type = cache_type self._max_workers = max_workers self._chunksize = chunksize self._kwargs = kwargs # Build DataFrame iterator (handles both chunked and non-chunked cases) # Note: _create_dataframe_iterator() calls _as_polars() which may update # _metadata for unload queries, so we must cache column names AFTER this. if self.state == AthenaQueryExecution.STATE_SUCCEEDED and self.output_location: self._df_iter = self._create_dataframe_iterator() elif self.state == AthenaQueryExecution.STATE_SUCCEEDED: df = self._as_polars_from_api() self._df_iter = PolarsDataFrameIterator(df, self.converters, self._get_column_names()) else: import polars as pl self._df_iter = PolarsDataFrameIterator( pl.DataFrame(), self.converters, self._get_column_names() ) # Cache column names for efficient access in fetchone() # Must be after _create_dataframe_iterator() which updates _metadata for unload self._column_names_cache: list[str] = self._get_column_names() self._iterrows = self._df_iter.iterrows()
@property def _csv_storage_options(self) -> dict[str, Any]: """Get storage options for Polars CSV reading via fsspec. Polars read_csv uses fsspec for cloud storage access, which works with PyAthena's registered S3FileSystem. Returns: Dictionary with fsspec-compatible options for S3 access. """ return { "connection": self.connection, "default_block_size": self._block_size, "default_cache_type": self._cache_type, "max_workers": self._max_workers, } @property def _parquet_storage_options(self) -> dict[str, Any]: """Get storage options for Polars Parquet reading via native object_store. Polars read_parquet uses Rust's native object_store crate, which requires AWS credentials to be passed directly rather than through fsspec. Returns: Dictionary with AWS credentials and region for S3 access. """ credentials = self.connection.session.get_credentials() options: dict[str, Any] = {} if credentials: frozen_credentials = credentials.get_frozen_credentials() options["aws_access_key_id"] = frozen_credentials.access_key options["aws_secret_access_key"] = frozen_credentials.secret_key if frozen_credentials.token: options["aws_session_token"] = frozen_credentials.token if self.connection.region_name: options["aws_region"] = self.connection.region_name return options @property def dtypes(self) -> dict[str, Any]: """Get Polars-compatible data types for result columns.""" description = self.description if self.description else [] return { d[0]: dtype for d in description if (dtype := self._converter.get_dtype(d[1], d[4], d[5])) is not None } @property def converters(self) -> dict[str, Callable[[str | None], Any | None]]: """Get converter functions for each column. Returns: Dictionary mapping column names to their converter functions. """ description = self.description if self.description else [] return {d[0]: self._converter.get(d[1]) for d in description} def _get_column_names(self) -> list[str]: """Get column names from description. Returns: List of column names. """ description = self.description if self.description else [] return [d[0] for d in description] def _create_dataframe_iterator(self) -> PolarsDataFrameIterator: """Create a DataFrame iterator for the result set. Returns: PolarsDataFrameIterator that handles both chunked and non-chunked cases. """ if self._chunksize is not None: # Chunked mode: create lazy iterator reader: Iterator[pl.DataFrame] | pl.DataFrame = ( self._iter_parquet_chunks() if self.is_unload else self._iter_csv_chunks() ) else: # Non-chunked mode: load entire DataFrame reader = self._as_polars() return PolarsDataFrameIterator(reader, self.converters, self._get_column_names())
[docs] def fetchone( self, ) -> tuple[Any | None, ...] | dict[Any, Any | None] | None: """Fetch the next row of the query result. Returns: A single row as a tuple, or None if no more rows are available. """ try: row = next(self._iterrows) except StopIteration: return None else: self._rownumber = row[0] + 1 return tuple([row[1][col] for col in self._column_names_cache])
[docs] def fetchmany( self, size: int | None = None ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: """Fetch the next set of rows of the query result. Args: size: Number of rows to fetch. Defaults to arraysize. Returns: A list of rows as tuples. """ if not size or size <= 0: size = self._arraysize rows = [] for _ in range(size): row = self.fetchone() if row: rows.append(row) else: break return rows
[docs] def fetchall( self, ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: """Fetch all remaining rows of the query result. Returns: A list of all remaining rows as tuples. """ rows = [] while True: row = self.fetchone() if row: rows.append(row) else: break return rows
def _is_csv_readable(self) -> bool: """Check if CSV output is available and can be read. Returns: True if CSV data is available to read, False otherwise. Raises: ProgrammingError: If output location is not set. """ if not self.output_location: raise ProgrammingError("OutputLocation is none or empty.") if not self.output_location.endswith((".csv", ".txt")): return False if self.substatement_type and self.substatement_type.upper() in ( "UPDATE", "DELETE", "MERGE", "VACUUM_TABLE", ): return False length = self._get_content_length() return length != 0 def _prepare_parquet_location(self) -> bool: """Prepare unload location for Parquet reading. Returns: True if Parquet data is available to read, False otherwise. """ manifests = self._read_data_manifest() if not manifests: return False if not self._unload_location: self._unload_location = "/".join(manifests[0].split("/")[:-1]) + "/" return True def _read_csv(self) -> pl.DataFrame: """Read query results from CSV file in S3. Returns: Polars DataFrame containing the CSV data. Raises: ProgrammingError: If output location is not set. OperationalError: If reading the CSV file fails. """ import polars as pl if not self._is_csv_readable(): return pl.DataFrame() if self.output_location is None: raise ProgrammingError("output_location is not available.") separator, has_header, new_columns = self._get_csv_params() try: df = pl.read_csv( self.output_location, separator=separator, has_header=has_header, schema_overrides=self.dtypes, storage_options=self._csv_storage_options, **self._kwargs, ) if new_columns: df.columns = new_columns return df except Exception as e: _logger.exception("Failed to read %s.", self.output_location) raise OperationalError(*e.args) from e def _read_parquet(self) -> pl.DataFrame: """Read query results from Parquet files in S3. Returns: Polars DataFrame containing the Parquet data. Raises: OperationalError: If reading the Parquet files fails. """ import polars as pl if not self._prepare_parquet_location(): return pl.DataFrame() if self._unload_location is None: raise ProgrammingError("unload_location is not available.") try: return pl.read_parquet( self._unload_location, storage_options=self._parquet_storage_options, **self._kwargs, ) except Exception as e: _logger.exception("Failed to read %s.", self._unload_location) raise OperationalError(*e.args) from e def _read_parquet_schema(self) -> tuple[dict[str, Any], ...]: """Read schema from Parquet files for metadata.""" import polars as pl if not self._unload_location: raise ProgrammingError("UnloadLocation is none or empty.") try: # Use scan_parquet to get schema without reading all data lazy_df = pl.scan_parquet( self._unload_location, storage_options=self._parquet_storage_options, ) schema = lazy_df.collect_schema() return to_column_info(schema) except Exception as e: _logger.exception("Failed to read schema from %s.", self._unload_location) raise OperationalError(*e.args) from e def _as_polars(self) -> pl.DataFrame: """Load query results as a Polars DataFrame. Reads from Parquet for UNLOAD queries, otherwise from CSV. Returns: Polars DataFrame containing the query results. """ if self.is_unload: df = self._read_parquet() if df.is_empty(): self._metadata = () else: self._metadata = self._read_parquet_schema() else: df = self._read_csv() return df def _as_polars_from_api(self, converter: Converter | None = None) -> pl.DataFrame: """Build a Polars DataFrame from GetQueryResults API. Used as a fallback when ``output_location`` is not available (e.g. managed query result storage). Args: converter: Type converter for result values. Defaults to ``DefaultTypeConverter`` if not specified. """ import polars as pl rows = self._fetch_all_rows(converter) if not rows: return pl.DataFrame() description = self.description if self.description else [] columns = [d[0] for d in description] return pl.DataFrame(self._rows_to_columnar(rows, columns))
[docs] def as_polars(self) -> pl.DataFrame: """Return query results as a Polars DataFrame. Returns the query results as a Polars DataFrame. This is the primary method for accessing results with PolarsCursor. Note: When chunksize is set, calling this method will collect all chunks into a single DataFrame, loading all data into memory. Use iter_chunks() for memory-efficient processing of large datasets. Returns: Polars DataFrame containing all query results. Example: >>> cursor = connection.cursor(PolarsCursor) >>> cursor.execute("SELECT * FROM my_table") >>> df = cursor.as_polars() >>> print(f"DataFrame has {df.height} rows") >>> filtered = df.filter(pl.col("value") > 100) """ return self._df_iter.as_polars()
[docs] def as_arrow(self) -> Table: """Return query results as an Apache Arrow Table. Converts the Polars DataFrame to an Apache Arrow Table for interoperability with other Arrow-compatible tools and libraries. Returns: Apache Arrow Table containing all query results. Raises: ImportError: If pyarrow is not installed. Example: >>> cursor = connection.cursor(PolarsCursor) >>> cursor.execute("SELECT * FROM my_table") >>> table = cursor.as_arrow() >>> # Use with other Arrow-compatible libraries """ try: return self._df_iter.as_polars().to_arrow() except ImportError as e: raise ImportError( "pyarrow is required for as_arrow(). Install it with: pip install pyarrow" ) from e
def _get_csv_params(self) -> tuple[str, bool, list[str] | None]: """Get CSV parsing parameters based on file type. Returns: Tuple of (separator, has_header, new_columns). """ if self.output_location and self.output_location.endswith(".txt"): separator = "\t" has_header = False new_columns: list[str] | None = self._get_column_names() else: separator = "," has_header = True new_columns = None return separator, has_header, new_columns def _iter_csv_chunks(self) -> Iterator[pl.DataFrame]: """Iterate over CSV data in chunks using lazy evaluation. Yields: Polars DataFrame for each chunk. Raises: ProgrammingError: If output location is not set. OperationalError: If reading the CSV file fails. """ import polars as pl if not self._is_csv_readable(): return if self.output_location is None: raise ProgrammingError("output_location is not available.") separator, has_header, new_columns = self._get_csv_params() try: # scan_csv uses Rust's native object_store (like scan_parquet), # not fsspec, so we use the same storage options as Parquet lazy_df = pl.scan_csv( self.output_location, separator=separator, has_header=has_header, schema_overrides=self.dtypes, storage_options=self._parquet_storage_options, **self._kwargs, ) for batch in lazy_df.collect_batches(chunk_size=self._chunksize): if new_columns: batch.columns = new_columns yield batch except Exception as e: _logger.exception("Failed to read %s.", self.output_location) raise OperationalError(*e.args) from e def _iter_parquet_chunks(self) -> Iterator[pl.DataFrame]: """Iterate over Parquet data in chunks using lazy evaluation. Yields: Polars DataFrame for each chunk. Raises: OperationalError: If reading the Parquet files fails. """ import polars as pl if not self._prepare_parquet_location(): return if self._unload_location is None: raise ProgrammingError("unload_location is not available.") try: lazy_df = pl.scan_parquet( self._unload_location, storage_options=self._parquet_storage_options, **self._kwargs, ) yield from lazy_df.collect_batches(chunk_size=self._chunksize) except Exception as e: _logger.exception("Failed to read %s.", self._unload_location) raise OperationalError(*e.args) from e
[docs] def iter_chunks(self) -> PolarsDataFrameIterator: """Iterate over result chunks as Polars DataFrames. This method provides an iterator interface for processing large result sets. When chunksize is specified, it yields DataFrames in chunks using lazy evaluation for memory-efficient processing. When chunksize is not specified, it yields the entire result as a single DataFrame. Returns: PolarsDataFrameIterator that yields Polars DataFrames for each chunk of rows, or the entire DataFrame if chunksize was not specified. Example: >>> # With chunking for large datasets >>> cursor = connection.cursor(PolarsCursor, chunksize=50000) >>> cursor.execute("SELECT * FROM large_table") >>> for chunk in cursor.iter_chunks(): ... process_chunk(chunk) # Each chunk is a Polars DataFrame >>> >>> # Without chunking - yields entire result as single chunk >>> cursor = connection.cursor(PolarsCursor) >>> cursor.execute("SELECT * FROM small_table") >>> for df in cursor.iter_chunks(): ... process(df) # Single DataFrame with all data """ return self._df_iter
[docs] def close(self) -> None: """Close the result set and release resources.""" import polars as pl super().close() self._df_iter = PolarsDataFrameIterator(pl.DataFrame(), {}, []) self._iterrows = iter([])