# Source code for pyathena.aio.spark.cursor

from __future__ import annotations

import asyncio
import logging
from typing import Any, cast

from pyathena.aio.util import async_retry_api_call
from pyathena.error import DatabaseError, NotSupportedError, OperationalError, ProgrammingError
from pyathena.model import (
    AthenaCalculationExecution,
    AthenaCalculationExecutionStatus,
    AthenaQueryExecution,
)
from pyathena.spark.common import SparkBaseCursor, WithCalculationExecution
from pyathena.util import parse_output_location

# Module-level logger; handed to the retry helper and used to record API failures.
_logger = logging.getLogger(__name__)


class AioSparkCursor(SparkBaseCursor, WithCalculationExecution):
    """Native asyncio cursor for executing PySpark code on Athena.

    Overrides post-init I/O methods of ``SparkBaseCursor`` with async
    equivalents. Session management (``_exists_session``,
    ``_start_session``, etc.) stays synchronous because ``__init__`` runs
    inside ``asyncio.to_thread``.

    Since ``SparkBaseCursor.__init__`` performs I/O (session management),
    cursor creation must be wrapped in ``asyncio.to_thread``::

        cursor = await asyncio.to_thread(conn.cursor)

    Example:
        >>> import asyncio
        >>> async with await pyathena.aio_connect(
        ...     work_group="spark-workgroup",
        ...     cursor_class=AioSparkCursor,
        ... ) as conn:
        ...     cursor = await asyncio.to_thread(conn.cursor)
        ...     await cursor.execute("spark.sql('SELECT 1').show()")
        ...     print(await cursor.get_std_out())
    """

    def __init__(
        self,
        session_id: str | None = None,
        description: str | None = None,
        engine_configuration: dict[str, Any] | None = None,
        notebook_version: str | None = None,
        session_idle_timeout_minutes: int | None = None,
        **kwargs,
    ) -> None:
        """Create the cursor by forwarding every argument to the base class.

        Args:
            session_id: Passed through to ``SparkBaseCursor.__init__``.
            description: Passed through to ``SparkBaseCursor.__init__``.
            engine_configuration: Passed through to ``SparkBaseCursor.__init__``.
            notebook_version: Passed through to ``SparkBaseCursor.__init__``.
            session_idle_timeout_minutes: Passed through to
                ``SparkBaseCursor.__init__``.
            **kwargs: Any further keyword arguments, forwarded unchanged.
        """
        super().__init__(
            session_id=session_id,
            description=description,
            engine_configuration=engine_configuration,
            notebook_version=notebook_version,
            session_idle_timeout_minutes=session_idle_timeout_minutes,
            **kwargs,
        )

    @property
    def calculation_execution(self) -> AthenaCalculationExecution | None:
        """The calculation execution stored on the cursor, if any."""
        return self._calculation_execution

    # --- async overrides of SparkBaseCursor I/O methods ---

    async def _get_calculation_execution_status(  # type: ignore[override]
        self, query_id: str
    ) -> AthenaCalculationExecutionStatus:
        """Fetch the status of a calculation execution.

        Args:
            query_id: Sent as ``CalculationExecutionId`` in the request.

        Returns:
            The API response wrapped in ``AthenaCalculationExecutionStatus``.

        Raises:
            OperationalError: If the API call raises any exception.
        """
        request: dict[str, Any] = {"CalculationExecutionId": query_id}
        try:
            response = await async_retry_api_call(
                self._connection.client.get_calculation_execution_status,
                config=self._retry_config,
                logger=_logger,
                **request,
            )
        except Exception as e:
            _logger.exception("Failed to get calculation execution status.")
            raise OperationalError(*e.args) from e
        else:
            return AthenaCalculationExecutionStatus(response)

    async def _get_calculation_execution(  # type: ignore[override]
        self, query_id: str
    ) -> AthenaCalculationExecution:
        """Fetch the calculation execution record.

        Args:
            query_id: Sent as ``CalculationExecutionId`` in the request.

        Returns:
            The API response wrapped in ``AthenaCalculationExecution``.

        Raises:
            OperationalError: If the API call raises any exception.
        """
        request: dict[str, Any] = {"CalculationExecutionId": query_id}
        try:
            response = await async_retry_api_call(
                self._connection.client.get_calculation_execution,
                config=self._retry_config,
                logger=_logger,
                **request,
            )
        except Exception as e:
            _logger.exception("Failed to get calculation execution.")
            raise OperationalError(*e.args) from e
        else:
            return AthenaCalculationExecution(response)

    async def _calculate(  # type: ignore[override]
        self,
        session_id: str,
        code_block: str,
        description: str | None = None,
        client_request_token: str | None = None,
    ) -> str:
        """Start a calculation execution and return its identifier.

        Args:
            session_id: Session in which to run the calculation.
            code_block: Code to execute.
            description: Optional calculation description.
            client_request_token: Optional idempotency token.

        Returns:
            The ``CalculationExecutionId`` value taken from the response.

        Raises:
            DatabaseError: If building the response or the API call raises.
        """
        request = self._build_start_calculation_execution_request(
            session_id=session_id,
            code_block=code_block,
            description=description,
            client_request_token=client_request_token,
        )
        try:
            response = await async_retry_api_call(
                self._connection.client.start_calculation_execution,
                config=self._retry_config,
                logger=_logger,
                **request,
            )
            calculation_id = response.get("CalculationExecutionId")
        except Exception as e:
            _logger.exception("Failed to execute calculation.")
            raise DatabaseError(*e.args) from e
        return cast(str, calculation_id)

    async def __poll(self, query_id: str) -> AthenaQueryExecution | AthenaCalculationExecution:
        """Poll the status until it is COMPLETED, FAILED or CANCELED.

        Sleeps ``self._poll_interval`` between status checks and returns the
        full calculation execution once a terminal state is seen.
        """
        while True:
            calculation_status = await self._get_calculation_execution_status(query_id)
            if calculation_status.state in [
                AthenaCalculationExecutionStatus.STATE_COMPLETED,
                AthenaCalculationExecutionStatus.STATE_FAILED,
                AthenaCalculationExecutionStatus.STATE_CANCELED,
            ]:
                return await self._get_calculation_execution(query_id)
            await asyncio.sleep(self._poll_interval)

    async def _poll(  # type: ignore[override]
        self, query_id: str
    ) -> AthenaQueryExecution | AthenaCalculationExecution:
        """Wait for a calculation to finish, handling task cancellation.

        If the awaiting task is cancelled while polling and
        ``self._kill_on_interrupt`` is truthy, the calculation is stopped and
        polled again until it reaches a terminal state; in that branch the
        ``CancelledError`` is not re-raised. Otherwise the cancellation
        propagates.
        """
        try:
            query_execution = await self.__poll(query_id)
        except asyncio.CancelledError:
            if self._kill_on_interrupt:
                _logger.warning("Query canceled by user.")
                await self._cancel(query_id)
                query_execution = await self.__poll(query_id)
            else:
                raise
        return query_execution

    async def _cancel(self, query_id: str) -> None:  # type: ignore[override]
        """Request that a calculation execution be stopped.

        Args:
            query_id: Sent as ``CalculationExecutionId`` in the request.

        Raises:
            OperationalError: If the API call raises any exception.
        """
        request: dict[str, Any] = {"CalculationExecutionId": query_id}
        try:
            await async_retry_api_call(
                self._connection.client.stop_calculation_execution,
                config=self._retry_config,
                logger=_logger,
                **request,
            )
        except Exception as e:
            _logger.exception("Failed to cancel calculation.")
            raise OperationalError(*e.args) from e

    async def _terminate_session(self) -> None:  # type: ignore[override]
        """Terminate the session identified by ``self._session_id``.

        Raises:
            OperationalError: If the API call raises any exception.
        """
        request: dict[str, Any] = {"SessionId": self._session_id}
        try:
            await async_retry_api_call(
                self._connection.client.terminate_session,
                config=self._retry_config,
                logger=_logger,
                **request,
            )
        except Exception as e:
            _logger.exception("Failed to terminate session.")
            raise OperationalError(*e.args) from e

    async def _read_s3_file_as_text(self, uri) -> str:  # type: ignore[override]
        """Download an S3 object and return it as stripped UTF-8 text.

        Args:
            uri: Object location, split into bucket and key by
                ``parse_output_location``.

        Returns:
            The object's content decoded as UTF-8 with surrounding
            whitespace removed.
        """
        bucket, key = parse_output_location(uri)
        response = await async_retry_api_call(
            self._client.get_object,
            config=self._retry_config,
            logger=_logger,
            Bucket=bucket,
            Key=key,
        )
        # ``read()`` on the response body is a synchronous call (for a boto3
        # streaming body it pulls the payload over the network), so run it in
        # a worker thread instead of stalling the event loop.
        payload = await asyncio.to_thread(response["Body"].read)
        return cast(str, payload.decode("utf-8").strip())

    # --- public API ---

    async def get_std_out(self) -> str | None:
        """Get the standard output from the Spark calculation execution.

        Returns:
            The standard output as a string, or None if no output is available.
        """
        if not self._calculation_execution or not self._calculation_execution.std_out_s3_uri:
            return None
        return await self._read_s3_file_as_text(self._calculation_execution.std_out_s3_uri)

    async def get_std_error(self) -> str | None:
        """Get the standard error from the Spark calculation execution.

        Returns:
            The standard error as a string, or None if no error output is available.
        """
        if not self._calculation_execution or not self._calculation_execution.std_error_s3_uri:
            return None
        return await self._read_s3_file_as_text(self._calculation_execution.std_error_s3_uri)

    async def execute(  # type: ignore[override]
        self,
        operation: str,
        parameters: dict[str, Any] | list[str] | None = None,
        session_id: str | None = None,
        description: str | None = None,
        client_request_token: str | None = None,
        work_group: str | None = None,
        **kwargs,
    ) -> AioSparkCursor:
        """Execute PySpark code asynchronously.

        Args:
            operation: PySpark code to execute.
            parameters: Unused, kept for API compatibility.
            session_id: Spark session ID override.
            description: Calculation description.
            client_request_token: Idempotency token.
            work_group: Unused, kept for API compatibility.
            **kwargs: Additional parameters.

        Returns:
            Self reference for method chaining.

        Raises:
            DatabaseError: If the calculation cannot be started.
            OperationalError: If the calculation ends in any state other than
                COMPLETED; the message is the calculation's standard error
                (which may be None).
        """
        self._calculation_id = await self._calculate(
            # A falsy override falls back to the cursor's own session.
            session_id=session_id if session_id else self._session_id,
            code_block=operation,
            description=description,
            client_request_token=client_request_token,
        )
        self._calculation_execution = cast(
            AthenaCalculationExecution, await self._poll(self._calculation_id)
        )
        if self._calculation_execution.state != AthenaCalculationExecutionStatus.STATE_COMPLETED:
            std_error = await self.get_std_error()
            raise OperationalError(std_error)
        return self

    async def cancel(self) -> None:
        """Cancel the currently running calculation.

        Raises:
            ProgrammingError: If no calculation is running.
        """
        if not self.calculation_id:
            raise ProgrammingError("CalculationExecutionId is none or empty.")
        await self._cancel(self.calculation_id)

    async def close(self) -> None:  # type: ignore[override]
        """Close the cursor by terminating the Spark session."""
        await self._terminate_session()

    async def executemany(  # type: ignore[override]
        self,
        operation: str,
        seq_of_parameters: list[dict[str, Any] | list[str] | None],
        **kwargs,
    ) -> None:
        """Not supported by this cursor; always raises ``NotSupportedError``."""
        raise NotSupportedError

    def __aiter__(self):
        """Return the cursor itself as its own async iterator."""
        return self

    async def __anext__(self):
        """End iteration immediately; the cursor yields no rows."""
        raise StopAsyncIteration

    async def __aenter__(self):
        """Enter the async context and return the cursor."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Leave the async context by closing the cursor."""
        await self.close()