# Source code for pyathena.arrow.result_set

from __future__ import annotations

import logging
from collections.abc import Callable
from typing import (
    TYPE_CHECKING,
    Any,
    ClassVar,
)

from pyathena import OperationalError
from pyathena.arrow.util import to_column_info
from pyathena.converter import Converter
from pyathena.error import ProgrammingError
from pyathena.model import AthenaQueryExecution
from pyathena.result_set import AthenaResultSet
from pyathena.util import RetryConfig, parse_output_location

if TYPE_CHECKING:
    import polars as pl
    from pyarrow import Table

    from pyathena.connection import Connection

_logger = logging.getLogger(__name__)


class AthenaArrowResultSet(AthenaResultSet):
    """Result set that provides Apache Arrow Table results with columnar optimization.

    This result set handles CSV and Parquet result files from S3, converting them
    to Apache Arrow Tables which provide efficient columnar data processing and
    memory usage. It's optimized for analytical workloads and large dataset operations.

    Features:
        - Efficient columnar data processing with Apache Arrow
        - Support for both CSV and Parquet result formats
        - Optimized memory usage for large datasets
        - Advanced timestamp parsing with multiple format support
        - Zero-copy operations where possible

    Attributes:
        DEFAULT_BLOCK_SIZE: Default block size for Arrow operations (128MB).

    Example:
        >>> # Used automatically by ArrowCursor
        >>> cursor = connection.cursor(ArrowCursor)
        >>> cursor.execute("SELECT * FROM large_table")
        >>>
        >>> # Get Arrow Table
        >>> table = cursor.fetchall()
        >>>
        >>> # Convert to pandas if needed
        >>> df = table.to_pandas()
        >>>
        >>> # Or work with Arrow directly
        >>> print(f"Table has {table.num_rows} rows and {table.num_columns} columns")

    Note:
        This class is used internally by ArrowCursor and typically not
        instantiated directly by users. Requires pyarrow to be installed.
    """

    # Block size passed to the Arrow CSV reader: 128 MiB.
    DEFAULT_BLOCK_SIZE = 1024 * 1024 * 128

    # Timestamp formats tried by the Arrow CSV reader. The ``timestamp_parsers``
    # property prepends pyarrow's ISO8601 parser to this list.
    _timestamp_parsers: ClassVar[list[str]] = [
        "%Y-%m-%d",
        "%Y-%m-%d %H:%M:%S",
        "%Y-%m-%d %H:%M:%S %Z",
        "%Y-%m-%d %H:%M:%S %z",
        "%Y-%m-%d %H:%M:%S.%f",
        "%Y-%m-%d %H:%M:%S.%f %Z",
        "%Y-%m-%d %H:%M:%S.%f %z",
        "%Y-%m-%dT%H:%M:%S",
        "%Y-%m-%dT%H:%M:%S %Z",
        "%Y-%m-%dT%H:%M:%S %z",
        "%Y-%m-%dT%H:%M:%S.%f",
        "%Y-%m-%dT%H:%M:%S.%f %Z",
        "%Y-%m-%dT%H:%M:%S.%f %z",
    ]
    def __init__(
        self,
        connection: Connection[Any],
        converter: Converter,
        query_execution: AthenaQueryExecution,
        arraysize: int,
        retry_config: RetryConfig,
        block_size: int | None = None,
        unload: bool = False,
        unload_location: str | None = None,
        connect_timeout: float | None = None,
        request_timeout: float | None = None,
        result_set_type_hints: dict[str | int, str] | None = None,
        **kwargs,
    ) -> None:
        """Initialize the Arrow result set and eagerly load the result data.

        Args:
            connection: Open Athena connection used to reach S3 and the API.
            converter: Type converter mapping Athena types to Python values.
            query_execution: Metadata of the completed query execution.
            arraysize: Batch size used when iterating the Arrow table as rows.
            retry_config: Retry policy forwarded to the base result set.
            block_size: Arrow CSV reader block size; defaults to
                ``DEFAULT_BLOCK_SIZE`` (128 MiB) when falsy.
            unload: Whether the query used UNLOAD (Parquet output).
            unload_location: Explicit S3 location of UNLOAD output, if any.
            connect_timeout: Optional S3 connect timeout in seconds.
            request_timeout: Optional S3 request timeout in seconds.
            result_set_type_hints: Optional column-type hints by name or index.
            **kwargs: Extra options retained for downstream use.
        """
        super().__init__(
            connection=connection,
            converter=converter,
            query_execution=query_execution,
            arraysize=1,  # Fetch one row to retrieve metadata
            retry_config=retry_config,
            result_set_type_hints=result_set_type_hints,
        )
        self._rows.clear()  # Clear pre_fetch data
        self._arraysize = arraysize
        self._block_size = block_size if block_size else self.DEFAULT_BLOCK_SIZE
        self._unload = unload
        self._unload_location = unload_location
        self._connect_timeout = connect_timeout
        self._request_timeout = request_timeout
        self._kwargs = kwargs
        self._fs = self.__s3_file_system()
        # Prefer reading result files straight from S3; fall back to the
        # GetQueryResults API when no output location is available (e.g.
        # managed query result storage). Failed queries get an empty table.
        if self.state == AthenaQueryExecution.STATE_SUCCEEDED and self.output_location:
            self._table = self._as_arrow()
        elif self.state == AthenaQueryExecution.STATE_SUCCEEDED:
            self._table = self._as_arrow_from_api()
        else:
            import pyarrow as pa

            self._table = pa.Table.from_pydict({})
        # Row-oriented fetch methods consume the table batch by batch.
        self._batches = iter(self._table.to_batches(arraysize))
def __s3_file_system(self): from pyarrow import fs connection = self.connection # Build timeout parameters dict timeout_kwargs = {} if self._connect_timeout is not None: timeout_kwargs["connect_timeout"] = self._connect_timeout if self._request_timeout is not None: timeout_kwargs["request_timeout"] = self._request_timeout if connection._kwargs.get("role_arn"): external_id = connection._kwargs.get("external_id") fs = fs.S3FileSystem( role_arn=connection._kwargs["role_arn"], session_name=connection._kwargs["role_session_name"], external_id="" if external_id is None else external_id, load_frequency=connection._kwargs["duration_seconds"], region=connection.region_name, **timeout_kwargs, ) elif connection.profile_name: profile = connection.session._session.full_config["profiles"][connection.profile_name] fs = fs.S3FileSystem( access_key=profile.get("aws_access_key_id", None), secret_key=profile.get("aws_secret_access_key", None), session_token=profile.get("aws_session_token", None), region=connection.region_name, **timeout_kwargs, ) else: # Try explicit credentials first explicit_access_key = connection._kwargs.get("aws_access_key_id") explicit_secret_key = connection._kwargs.get("aws_secret_access_key") if explicit_access_key and explicit_secret_key: # Use explicitly provided credentials fs = fs.S3FileSystem( access_key=explicit_access_key, secret_key=explicit_secret_key, session_token=connection._kwargs.get("aws_session_token"), region=connection.region_name, **timeout_kwargs, ) else: # Fall back to dynamic credentials from boto3 session # This handles EC2 instance profiles, temporary credentials, etc. 
try: credentials = connection.session._session.get_credentials() if credentials: fs = fs.S3FileSystem( access_key=credentials.access_key, secret_key=credentials.secret_key, session_token=credentials.token, region=connection.region_name, **timeout_kwargs, ) else: # Fall back to default (no explicit credentials) fs = fs.S3FileSystem(region=connection.region_name, **timeout_kwargs) except Exception: # Fall back to default if credential retrieval fails fs = fs.S3FileSystem(region=connection.region_name, **timeout_kwargs) return fs @property def timestamp_parsers(self) -> list[str]: from pyarrow.csv import ISO8601 return [ISO8601, *self._timestamp_parsers] @property def column_types(self) -> dict[str, type[Any]]: description = self.description if self.description else [] return { d[0]: dtype for d in description if (dtype := self._converter.get_dtype(d[1], d[4], d[5])) is not None } @property def converters(self) -> dict[str, Callable[[str | None], Any | None]]: description = self.description if self.description else [] return {d[0]: self._converter.get(d[1]) for d in description} def _fetch(self) -> None: try: rows = next(self._batches) except StopIteration: return else: dict_rows = rows.to_pydict() column_names = dict_rows.keys() processed_rows = [ tuple(self.converters[k](v) for k, v in zip(column_names, row, strict=False)) for row in zip(*dict_rows.values(), strict=False) ] self._rows.extend(processed_rows)
[docs] def fetchone( self, ) -> tuple[Any | None, ...] | dict[Any, Any | None] | None: if not self._rows: self._fetch() if not self._rows: return None if self._rownumber is None: self._rownumber = 0 self._rownumber += 1 return self._rows.popleft()
[docs] def fetchmany( self, size: int | None = None ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: if not size or size <= 0: size = self._arraysize rows = [] for _ in range(size): row = self.fetchone() if row: rows.append(row) else: break return rows
[docs] def fetchall( self, ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: rows = [] while True: row = self.fetchone() if row: rows.append(row) else: break return rows
    def _read_csv(self) -> Table:
        """Read the query's CSV/TSV result file from S3 into an Arrow Table.

        Returns an empty table when there is nothing to read (non-CSV output,
        DML substatements, or a zero-length result file).

        Raises:
            ProgrammingError: If ``output_location`` is missing.
            OperationalError: If the S3 read or CSV parse fails.
        """
        import pyarrow as pa
        from pyarrow import csv

        if not self.output_location:
            raise ProgrammingError("OutputLocation is none or empty.")
        if not self.output_location.endswith((".csv", ".txt")):
            return pa.Table.from_pydict({})
        # DML-style substatements produce no tabular result to read.
        if self.substatement_type and self.substatement_type.upper() in (
            "UPDATE",
            "DELETE",
            "MERGE",
            "VACUUM_TABLE",
        ):
            return pa.Table.from_pydict({})
        length = self._get_content_length()
        if length and self.output_location.endswith(".txt"):
            # .txt output is headerless TSV: supply column names from the
            # cursor description and disable quoting entirely.
            description = self.description if self.description else []
            column_names = [d[0] for d in description]
            read_opts = csv.ReadOptions(
                skip_rows=0,
                column_names=column_names,
                block_size=self._block_size,
                use_threads=True,
            )
            parse_opts = csv.ParseOptions(
                delimiter="\t",
                quote_char=False,
                double_quote=False,
                escape_char=False,
            )
        elif length and self.output_location.endswith(".csv"):
            # .csv output has a header row and standard double-quote quoting.
            read_opts = csv.ReadOptions(skip_rows=0, block_size=self._block_size, use_threads=True)
            parse_opts = csv.ParseOptions(
                delimiter=",",
                quote_char='"',
                double_quote=True,
                escape_char=False,
            )
        else:
            # Zero-length object: nothing to parse.
            return pa.Table.from_pydict({})
        bucket, key = parse_output_location(self.output_location)
        try:
            return csv.read_csv(
                self._fs.open_input_stream(f"{bucket}/{key}"),
                read_options=read_opts,
                parse_options=parse_opts,
                convert_options=csv.ConvertOptions(
                    quoted_strings_can_be_null=False,
                    timestamp_parsers=self.timestamp_parsers,
                    column_types=self.column_types,
                ),
            )
        except Exception as e:
            _logger.exception("Failed to read %s/%s.", bucket, key)
            raise OperationalError(*e.args) from e

    def _read_parquet(self) -> Table:
        """Read UNLOAD Parquet output (located via the data manifest) from S3.

        Returns an empty table when the manifest lists no output files.

        Raises:
            OperationalError: If the Parquet dataset cannot be read.
        """
        import pyarrow as pa
        from pyarrow import parquet

        manifests = self._read_data_manifest()
        if not manifests:
            return pa.Table.from_pydict({})
        if not self._unload_location:
            # Derive the unload directory from the first manifest entry.
            self._unload_location = "/".join(manifests[0].split("/")[:-1]) + "/"
        bucket, key = parse_output_location(self._unload_location)
        try:
            dataset = parquet.ParquetDataset(f"{bucket}/{key}", filesystem=self._fs)
            return dataset.read(use_threads=True)
        except Exception as e:
            _logger.exception("Failed to read %s/%s.", bucket, key)
            raise OperationalError(*e.args) from e

    def _as_arrow(self) -> Table:
        """Load the result table from S3: Parquet for UNLOAD, CSV otherwise."""
        if self.is_unload:
            table = self._read_parquet()
            # UNLOAD results carry their schema in the Parquet file; refresh
            # the column metadata from it.
            self._metadata = to_column_info(table.schema)
        else:
            table = self._read_csv()
        return table

    def _as_arrow_from_api(self, converter: Converter | None = None) -> Table:
        """Build an Arrow Table from GetQueryResults API.

        Used as a fallback when ``output_location`` is not available
        (e.g. managed query result storage).

        Args:
            converter: Type converter for result values. Defaults to
                ``DefaultTypeConverter`` if not specified.
        """
        import pyarrow as pa

        rows = self._fetch_all_rows(converter)
        if not rows:
            return pa.Table.from_pydict({})
        description = self.description if self.description else []
        columns = [d[0] for d in description]
        return pa.table(self._rows_to_columnar(rows, columns))
[docs] def as_arrow(self) -> Table: return self._table
[docs] def as_polars(self) -> pl.DataFrame: """Return query results as a Polars DataFrame. Converts the Apache Arrow Table to a Polars DataFrame for interoperability with the Polars data processing library. Returns: Polars DataFrame containing all query results. Raises: ImportError: If polars is not installed. Example: >>> cursor = connection.cursor(ArrowCursor) >>> cursor.execute("SELECT * FROM my_table") >>> df = cursor.as_polars() >>> # Use with Polars operations """ try: import polars as pl return pl.from_arrow(self._table) # type: ignore[return-value] except ImportError as e: raise ImportError( "polars is required for as_polars(). Install it with: pip install polars" ) from e
[docs] def close(self) -> None: import pyarrow as pa super().close() self._table = pa.Table.from_pydict({}) self._batches = []