Source code for pyathena.aio.result_set

from __future__ import annotations

import logging
from typing import (
    TYPE_CHECKING,
    Any,
    cast,
)

from pyathena.aio.util import async_retry_api_call
from pyathena.converter import Converter
from pyathena.error import OperationalError, ProgrammingError
from pyathena.model import AthenaQueryExecution
from pyathena.result_set import AthenaDictResultSet, AthenaResultSet
from pyathena.util import RetryConfig

if TYPE_CHECKING:
    from pyathena.connection import Connection

_logger = logging.getLogger(__name__)


class AthenaAioResultSet(AthenaResultSet):
    """Async result set that provides async fetch methods.

    Skips the synchronous ``_pre_fetch`` by passing ``_pre_fetch=False`` to
    the parent ``__init__`` and provides an ``async create()`` classmethod
    factory instead. Instances also support ``async for`` iteration, which
    yields rows until ``fetchone()`` returns ``None``.
    """

    def __init__(
        self,
        connection: Connection[Any],
        converter: Converter,
        query_execution: AthenaQueryExecution,
        arraysize: int,
        retry_config: RetryConfig,
        result_set_type_hints: dict[str | int, str] | None = None,
    ) -> None:
        """Initialize the result set without fetching any data.

        The first page is not loaded here; use :meth:`create` to obtain an
        instance whose initial page has already been awaited.

        Args:
            connection: The database connection.
            converter: Type converter for result values.
            query_execution: Query execution metadata.
            arraysize: Number of rows to fetch per request.
            retry_config: Retry configuration for API calls.
            result_set_type_hints: Optional dictionary mapping column names
                to Athena DDL type signatures for precise type conversion.
        """
        super().__init__(
            connection=connection,
            converter=converter,
            query_execution=query_execution,
            arraysize=arraysize,
            retry_config=retry_config,
            # The parent's pre-fetch is synchronous; it is replaced by
            # ``_async_pre_fetch`` which ``create`` awaits.
            _pre_fetch=False,
            result_set_type_hints=result_set_type_hints,
        )

    @classmethod
    async def create(
        cls,
        connection: Connection[Any],
        converter: Converter,
        query_execution: AthenaQueryExecution,
        arraysize: int,
        retry_config: RetryConfig,
        result_set_type_hints: dict[str | int, str] | None = None,
    ) -> AthenaAioResultSet:
        """Async factory method.

        Creates an ``AthenaAioResultSet`` and awaits the initial data fetch.
        The initial fetch is only performed when the query execution state
        is ``STATE_SUCCEEDED``; otherwise the instance is returned as-is.

        Args:
            connection: The database connection.
            converter: Type converter for result values.
            query_execution: Query execution metadata.
            arraysize: Number of rows to fetch per request.
            retry_config: Retry configuration for API calls.
            result_set_type_hints: Optional dictionary mapping column names
                to Athena DDL type signatures for precise type conversion.

        Returns:
            A fully initialized ``AthenaAioResultSet``.
        """
        result_set = cls(
            connection,
            converter,
            query_execution,
            arraysize,
            retry_config,
            result_set_type_hints=result_set_type_hints,
        )
        if result_set.state == AthenaQueryExecution.STATE_SUCCEEDED:
            await result_set._async_pre_fetch()
        return result_set

    async def __async_get_query_results(
        self, max_results: int, next_token: str | None = None
    ) -> dict[str, Any]:
        """Request one page of query results through the connection client.

        Args:
            max_results: Value sent as ``MaxResults`` in the request.
            next_token: Pagination token; sent as ``NextToken`` only when
                truthy.

        Returns:
            The response of the ``get_query_results`` call.

        Raises:
            ProgrammingError: If the query id is missing, the query state is
                not ``STATE_SUCCEEDED``, or the result set is closed.
            OperationalError: If the API call raises; the original exception
                is chained as the cause.
        """
        if not self.query_id:
            raise ProgrammingError("QueryExecutionId is none or empty.")
        if self.state != AthenaQueryExecution.STATE_SUCCEEDED:
            raise ProgrammingError("QueryExecutionState is not SUCCEEDED.")
        if self.is_closed:
            raise ProgrammingError("AthenaAioResultSet is closed.")
        request: dict[str, Any] = {
            "QueryExecutionId": self.query_id,
            "MaxResults": max_results,
        }
        if next_token:
            request["NextToken"] = next_token
        try:
            response = await async_retry_api_call(
                self.connection.client.get_query_results,
                config=self._retry_config,
                logger=_logger,
                **request,
            )
        except Exception as e:
            _logger.exception("Failed to fetch result set.")
            raise OperationalError(*e.args) from e
        else:
            return cast(dict[str, Any], response)

    async def __async_fetch(self, next_token: str | None = None) -> dict[str, Any]:
        """Fetch one page of at most ``arraysize`` rows."""
        return await self.__async_get_query_results(self._arraysize, next_token)

    async def _async_fetch(self) -> None:
        """Fetch the next page and hand its rows to ``_process_rows``.

        Raises:
            ProgrammingError: If there is no next-page token.
        """
        if not self._next_token:
            raise ProgrammingError("NextToken is none or empty.")
        response = await self.__async_fetch(self._next_token)
        rows, self._next_token = self._parse_result_rows(response)
        self._process_rows(rows)

    async def _async_pre_fetch(self) -> None:
        """Fetch the first page and process its metadata, update count and rows."""
        response = await self.__async_fetch()
        self._process_metadata(response)
        self._process_update_count(response)
        rows, self._next_token = self._parse_result_rows(response)
        # Skip the first row when it is reported to hold the column labels.
        offset = 1 if rows and self._is_first_row_column_labels(rows) else 0
        self._process_rows(rows, offset)

    async def fetchone(  # type: ignore[override]
        self,
    ) -> tuple[Any | None, ...] | dict[Any, Any | None] | None:
        """Fetch the next row of the result set.

        Automatically fetches the next page from Athena when the current
        page is exhausted and more pages are available.

        Returns:
            The next row, or None if no more rows.
        """
        if not self._rows and self._next_token:
            await self._async_fetch()
        if not self._rows:
            return None
        if self._rownumber is None:
            self._rownumber = 0
        self._rownumber += 1
        return self._rows.popleft()

    async def fetchmany(  # type: ignore[override]
        self, size: int | None = None
    ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]:
        """Fetch multiple rows from the result set.

        Args:
            size: Maximum number of rows to fetch. If None, zero or
                negative, uses arraysize.

        Returns:
            List of rows. May contain fewer rows than requested if fewer
            are available.
        """
        if not size or size <= 0:
            size = self._arraysize
        rows: list[tuple[Any | None, ...] | dict[Any, Any | None]] = []
        for _ in range(size):
            row = await self.fetchone()
            # ``fetchone`` signals exhaustion with None; a falsy-but-valid
            # row (e.g. an empty tuple/dict) must not end the fetch early.
            if row is None:
                break
            rows.append(row)
        return rows

    async def fetchall(  # type: ignore[override]
        self,
    ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]:
        """Fetch all remaining rows from the result set.

        Returns:
            List of all remaining rows.
        """
        rows: list[tuple[Any | None, ...] | dict[Any, Any | None]] = []
        # Only None marks the end of the result set (same test as __anext__).
        while (row := await self.fetchone()) is not None:
            rows.append(row)
        return rows

    def __aiter__(self) -> AthenaAioResultSet:
        """Return self so the result set can be used with ``async for``."""
        return self

    async def __anext__(
        self,
    ) -> tuple[Any | None, ...] | dict[Any, Any | None]:
        """Return the next row, raising ``StopAsyncIteration`` when exhausted."""
        row = await self.fetchone()
        if row is None:
            raise StopAsyncIteration
        return row
class AthenaAioDictResultSet(AthenaDictResultSet, AthenaAioResultSet):
    """Dictionary-row variant of the async result set.

    Combines two bases through multiple inheritance:
    ``AthenaDictResultSet`` contributes ``_get_rows`` (rows as
    dictionaries), while ``AthenaAioResultSet`` contributes the async
    fetch methods. No members are defined on this class itself.
    """