from __future__ import annotations
import logging
from collections import abc
from collections.abc import Callable, Iterable, Iterator
from multiprocessing import cpu_count
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
)
from pyathena import OperationalError
from pyathena.converter import Converter
from pyathena.error import ProgrammingError
from pyathena.model import AthenaQueryExecution
from pyathena.result_set import AthenaResultSet
from pyathena.util import RetryConfig, parse_output_location
if TYPE_CHECKING:
from pandas import DataFrame
from pandas.io.parsers import TextFileReader
from pyathena.connection import Connection
_logger = logging.getLogger(__name__)
def _no_trunc_date(df: DataFrame) -> DataFrame:
"""Identity pass-through used when date truncation is not needed (e.g. for UNLOAD or empty results)."""
return df
class PandasDataFrameIterator(abc.Iterator): # type: ignore[type-arg]
"""Iterator for chunked DataFrame results from Athena queries.
This class wraps either a pandas TextFileReader (for chunked reading) or
a single DataFrame, providing a unified iterator interface. It applies
optional date truncation to each DataFrame chunk as it's yielded.
The iterator is used by AthenaPandasResultSet to provide chunked access
to large query results, enabling memory-efficient processing of datasets
that would be too large to load entirely into memory.
Example:
>>> # Iterate over DataFrame chunks
>>> for df_chunk in iterator:
... process(df_chunk)
>>>
>>> # Iterate over individual rows
>>> for idx, row in iterator.iterrows():
... print(row)
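>>>
>>> # A minimal construction sketch; "results.csv" and the chunk size are
>>> # illustrative, and the identity lambda stands in for a trunc_date callable.
>>> import pandas as pd
>>> reader = pd.read_csv("results.csv", chunksize=10_000)
>>> with PandasDataFrameIterator(reader, lambda df: df) as it:
...     total_rows = sum(len(chunk) for chunk in it)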
Note:
This class is primarily for internal use by AthenaPandasResultSet.
Most users should access results through PandasCursor methods.
"""
def __init__(
self,
reader: TextFileReader | DataFrame,
trunc_date: Callable[[DataFrame], DataFrame],
) -> None:
"""Initialize the iterator.
Args:
reader: Either a TextFileReader (for chunked reading) or a single DataFrame.
trunc_date: Function to apply date truncation to each chunk.
"""
from pandas import DataFrame
if isinstance(reader, DataFrame):
self._reader = iter([reader])
else:
self._reader = reader
self._trunc_date = trunc_date
def __next__(self) -> DataFrame:
"""Get the next DataFrame chunk.
Returns:
The next pandas DataFrame chunk with date truncation applied.
Raises:
StopIteration: When no more chunks are available.
"""
try:
df = next(self._reader)
return self._trunc_date(df)
except StopIteration:
self.close()
raise
def __iter__(self) -> PandasDataFrameIterator:
"""Return self as iterator."""
return self
def __enter__(self) -> PandasDataFrameIterator:
"""Context manager entry."""
return self
def __exit__(self, exc_type, exc_value, traceback) -> None:
"""Context manager exit."""
self.close()
def close(self) -> None:
"""Close the iterator and release resources."""
from pandas.io.parsers import TextFileReader
if isinstance(self._reader, TextFileReader):
self._reader.close()
def iterrows(self) -> Iterator[tuple[int, dict[str, Any]]]:
"""Iterate over rows as (index, row_dict) tuples.
Row indices are continuous across all chunks, starting from 0.
Yields:
Tuple of (row_index, row_dict) for each row across all chunks.
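Example:
>>> # Hedged sketch ("value" is an illustrative column name); row indices
>>> # keep counting across chunk boundaries instead of restarting per chunk.
>>> for idx, row in iterator.iterrows():
...     print(idx, row["value"])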
"""
row_num = 0
for df in self:
# Use itertuples for memory efficiency instead of to_dict("records")
# which loads all rows into memory at once
columns = df.columns.tolist()
for row in df.itertuples(index=False):
yield (row_num, dict(zip(columns, row, strict=True)))
row_num += 1
def get_chunk(self, size: int | None = None) -> DataFrame:
"""Get a chunk of specified size.
Args:
size: Number of rows to retrieve. Defaults to the reader's configured
chunksize when None. Ignored when the underlying reader is a single
DataFrame, in which case the entire DataFrame is returned.
Returns:
DataFrame chunk.
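Example:
>>> # Hedged sketch: read the next 1,000 rows from a chunked reader.
>>> head = iterator.get_chunk(1000)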
"""
from pandas.io.parsers import TextFileReader
if isinstance(self._reader, TextFileReader):
return self._reader.get_chunk(size)
return next(self._reader)
def as_pandas(self) -> DataFrame:
"""Collect all chunks into a single DataFrame.
Returns:
Single pandas DataFrame containing all data.
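Example:
>>> # Hedged sketch: drain the remaining chunks into one DataFrame.
>>> df = iterator.as_pandas()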
"""
import pandas as pd
dfs: list[DataFrame] = list(self)
if not dfs:
return pd.DataFrame()
if len(dfs) == 1:
return dfs[0]
return pd.concat(dfs, ignore_index=True)
class AthenaPandasResultSet(AthenaResultSet):
"""Result set that provides pandas DataFrame results with memory optimization.
This result set handles CSV and Parquet result files from S3, converting them to
pandas DataFrames with configurable chunking for memory-efficient processing.
When auto_optimize_chunksize is enabled, it determines chunk sizes automatically based on
file size, and it provides iterative processing for large datasets.
Features:
- Automatic chunk size optimization based on file size
- Support for both CSV and Parquet result formats
- Memory-efficient iterative processing
- Automatic date/time parsing for pandas compatibility
- PyArrow integration for Parquet files
Attributes:
LARGE_FILE_THRESHOLD_BYTES: File size threshold for chunking (50MB).
AUTO_CHUNK_SIZE_LARGE: Default chunk size for large files (100,000 rows).
AUTO_CHUNK_SIZE_MEDIUM: Default chunk size for medium files (50,000 rows).
Example:
>>> # Used automatically by PandasCursor
>>> cursor = connection.cursor(PandasCursor)
>>> cursor.execute("SELECT * FROM large_table")
>>>
>>> # Get full DataFrame
>>> df = cursor.fetchall()
>>>
>>> # Or iterate through chunks for memory efficiency
>>> for chunk_df in cursor:
... process_chunk(chunk_df)
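>>>
>>> # Hedged sketch: the chunking thresholds are plain class attributes and can
>>> # be tuned before running a query (values below are illustrative); they
>>> # affect automatic chunk sizing when auto_optimize_chunksize is enabled.
>>> AthenaPandasResultSet.LARGE_FILE_THRESHOLD_BYTES = 100 * 1024 * 1024
>>> AthenaPandasResultSet.AUTO_CHUNK_SIZE_LARGE = 200_000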
Note:
This class is used internally by PandasCursor and typically not
instantiated directly by users.
"""
# File size thresholds and chunking configuration - Public for user customization
PYARROW_MIN_FILE_SIZE_BYTES: int = 100
LARGE_FILE_THRESHOLD_BYTES: int = 50 * 1024 * 1024 # 50MB
ESTIMATED_BYTES_PER_ROW: int = 100
AUTO_CHUNK_THRESHOLD_LARGE: int = 2_000_000
AUTO_CHUNK_THRESHOLD_MEDIUM: int = 1_000_000
AUTO_CHUNK_SIZE_LARGE: int = 100_000
AUTO_CHUNK_SIZE_MEDIUM: int = 50_000
_PARSE_DATES: ClassVar[list[str]] = [
"date",
"time",
"time with time zone",
"timestamp",
"timestamp with time zone",
]
def __init__(
self,
connection: Connection[Any],
converter: Converter,
query_execution: AthenaQueryExecution,
arraysize: int,
retry_config: RetryConfig,
keep_default_na: bool = False,
na_values: Iterable[str] | None = ("",),
quoting: int = 1,
unload: bool = False,
unload_location: str | None = None,
engine: str = "auto",
chunksize: int | None = None,
block_size: int | None = None,
cache_type: str | None = None,
max_workers: int = (cpu_count() or 1) * 5,
auto_optimize_chunksize: bool = False,
result_set_type_hints: dict[str | int, str] | None = None,
**kwargs,
) -> None:
"""Initialize AthenaPandasResultSet with pandas-specific configurations.
Args:
connection: Database connection instance.
converter: Data type converter for Athena types to pandas types.
query_execution: Query execution metadata from Athena.
arraysize: Number of rows to fetch in each batch (not used for pandas processing).
retry_config: Retry configuration for S3 operations.
keep_default_na: pandas option for handling NA values.
na_values: Additional values to recognize as NA.
quoting: CSV quoting behavior.
unload: Whether result uses UNLOAD statement (Parquet format).
unload_location: S3 location for UNLOAD results.
engine: Parsing engine ('auto', 'c', 'python', 'pyarrow').
chunksize: Number of rows per chunk. If specified, takes precedence
over auto_optimize_chunksize.
block_size: S3 read block size.
cache_type: S3 caching strategy.
max_workers: Maximum worker threads for parallel operations.
auto_optimize_chunksize: Enable automatic chunksize determination
for large files when chunksize is None.
result_set_type_hints: Optional dictionary mapping column names to
Athena DDL type signatures for precise type conversion.
**kwargs: Additional arguments passed to pandas.read_csv/read_parquet.
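Example:
>>> # Hedged sketch: these arguments are normally forwarded from the cursor
>>> # rather than passed directly ("some_table" is an illustrative name).
>>> cursor = connection.cursor(PandasCursor, chunksize=100_000)
>>> cursor.execute("SELECT * FROM some_table")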
"""
super().__init__(
connection=connection,
converter=converter,
query_execution=query_execution,
arraysize=1, # Fetch one row to retrieve metadata
retry_config=retry_config,
result_set_type_hints=result_set_type_hints,
)
self._rows.clear() # Clear pre_fetch data
self._arraysize = arraysize
self._keep_default_na = keep_default_na
self._na_values = na_values
self._quoting = quoting
self._unload = unload
self._unload_location = unload_location
self._engine = engine
self._chunksize = chunksize
self._block_size = block_size
self._cache_type = cache_type
self._max_workers = max_workers
self._auto_optimize_chunksize = auto_optimize_chunksize
self._data_manifest: list[str] = []
self._kwargs = kwargs
self._fs = self.__s3_file_system()
# Cache time column names for efficient _trunc_date processing
description = self.description if self.description else []
self._time_columns: list[str] = [
d[0] for d in description if d[1] in ("time", "time with time zone")
]
if self.state == AthenaQueryExecution.STATE_SUCCEEDED and self.output_location:
df = self._as_pandas()
trunc_date = _no_trunc_date if self.is_unload else self._trunc_date
self._df_iter = PandasDataFrameIterator(df, trunc_date)
elif self.state == AthenaQueryExecution.STATE_SUCCEEDED:
df = self._as_pandas_from_api()
self._df_iter = PandasDataFrameIterator(df, self._trunc_date)
else:
import pandas as pd
self._df_iter = PandasDataFrameIterator(pd.DataFrame(), _no_trunc_date)
self._iterrows = self._df_iter.iterrows()
def _get_parquet_engine(self) -> str:
"""Get the parquet engine to use, handling auto-detection.
Returns:
Name of the parquet engine to use ('pyarrow').
Raises:
ImportError: If pyarrow is not available.
"""
if self._engine == "auto":
return self._get_available_engine(["pyarrow"])
return self._engine
def _get_csv_engine(
self, file_size_bytes: int | None = None, chunksize: int | None = None
) -> str:
"""Determine the appropriate CSV engine based on configuration and compatibility.
Args:
file_size_bytes: Size of the CSV file in bytes. Only used for PyArrow
compatibility checks (minimum file size threshold).
chunksize: Chunksize parameter (overrides self._chunksize if provided).
Returns:
CSV engine name ('pyarrow', 'c', or 'python').
"""
if self._engine == "python":
return "python"
# Use PyArrow only when explicitly requested and all compatibility
# checks pass; otherwise fall through to the C engine default.
if self._engine == "pyarrow":
effective_chunksize = chunksize if chunksize is not None else self._chunksize
is_compatible = (
effective_chunksize is None
and self._quoting == 1
and not self.converters
and (file_size_bytes is None or file_size_bytes >= self.PYARROW_MIN_FILE_SIZE_BYTES)
)
if is_compatible:
try:
return self._get_available_engine(["pyarrow"])
except ImportError:
pass
return "c"
def _get_available_engine(self, engine_candidates: list[str]) -> str:
"""Get the first available engine from a list of candidates.
Args:
engine_candidates: List of engine names to try in order.
Returns:
First available engine name.
Raises:
ImportError: If no engines are available.
"""
import importlib
error_msgs = ""
for engine in engine_candidates:
try:
module = importlib.import_module(engine)
return module.__name__
except ImportError as e: # noqa: PERF203
error_msgs += f"\n - {e!s}"
available_engines = ", ".join(f"'{e}'" for e in engine_candidates)
raise ImportError(
f"Unable to find a usable engine; tried using: {available_engines}."
f"Trying to import the above resulted in these errors:"
f"{error_msgs}"
)
def _auto_determine_chunksize(self, file_size_bytes: int) -> int | None:
"""Determine appropriate chunksize for large files based on file size.
This method provides a simple file-size-based chunksize determination.
Users can customize the thresholds and chunk sizes by modifying the class
attributes (e.g., LARGE_FILE_THRESHOLD_BYTES, AUTO_CHUNK_SIZE_LARGE).
Args:
file_size_bytes: Size of the result file in bytes.
Returns:
Suggested chunksize or None if chunking is not needed.
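Example:
>>> # Hedged walk-through with the default settings (``result_set`` is an
>>> # illustrative instance): a 300 MB file gives
>>> # 300 * 1024 * 1024 // 100 = 3,145,728 estimated rows, which exceeds
>>> # AUTO_CHUNK_THRESHOLD_LARGE (2,000,000), so AUTO_CHUNK_SIZE_LARGE is returned.
>>> result_set._auto_determine_chunksize(300 * 1024 * 1024)
100000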
"""
if file_size_bytes <= self.LARGE_FILE_THRESHOLD_BYTES:
return None
# Simple file size-based estimation
estimated_rows = file_size_bytes // self.ESTIMATED_BYTES_PER_ROW
if estimated_rows > self.AUTO_CHUNK_THRESHOLD_LARGE:
return self.AUTO_CHUNK_SIZE_LARGE
if estimated_rows > self.AUTO_CHUNK_THRESHOLD_MEDIUM:
return self.AUTO_CHUNK_SIZE_MEDIUM
return None
def __s3_file_system(self):
from pyathena.filesystem.s3 import S3FileSystem
return S3FileSystem(
connection=self.connection,
default_block_size=self._block_size,
default_cache_type=self._cache_type,
max_workers=self._max_workers,
)
@property
def dtypes(self) -> dict[str, type[Any]]:
"""Get pandas-compatible data types for result columns.
Returns:
Dictionary mapping column names to their corresponding Python types
based on the converter's type mapping.
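Example:
>>> # Hedged sketch: this mapping is what gets passed to pandas as
>>> # ``read_csv(dtype=...)``; ``result_set`` is an illustrative instance.
>>> dtype_map = result_set.dtypes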
"""
description = self.description if self.description else []
return {
d[0]: dtype
for d in description
if (dtype := self._converter.get_dtype(d[1], d[4], d[5])) is not None
}
@property
def converters(
self,
) -> dict[Any | None, Callable[[str | None], Any | None]]:
"""Per-column converter callables for columns whose Athena type has a registered conversion."""
description = self.description if self.description else []
return {
d[0]: self._converter.get(d[1]) for d in description if d[1] in self._converter.mappings
}
@property
def parse_dates(self) -> list[Any | None]:
"""Names of columns whose Athena types should be parsed as dates or timestamps by pandas."""
description = self.description if self.description else []
return [d[0] for d in description if d[1] in self._PARSE_DATES]
def _trunc_date(self, df: DataFrame) -> DataFrame:
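# Athena "time" columns are parsed as datetimes via parse_dates (_PARSE_DATES),
# so only the time component of each value is kept here.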
if self._time_columns:
truncated = df.loc[:, self._time_columns].apply(lambda r: r.dt.time)
for time_col in self._time_columns:
df.isetitem(df.columns.get_loc(time_col), truncated[time_col])
return df
def fetchone(
self,
) -> tuple[Any | None, ...] | dict[Any, Any | None] | None:
try:
row = next(self._iterrows)
except StopIteration:
return None
else:
self._rownumber = row[0] + 1
description = self.description if self.description else []
return tuple([row[1][d[0]] for d in description])
def fetchmany(
self, size: int | None = None
) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]:
if not size or size <= 0:
size = self._arraysize
rows = []
for _ in range(size):
row = self.fetchone()
if row:
rows.append(row)
else:
break
return rows
def fetchall(
self,
) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]:
rows = []
while True:
row = self.fetchone()
if row:
rows.append(row)
else:
break
return rows
def _read_csv(self) -> TextFileReader | DataFrame:
import pandas as pd
if not self.output_location:
raise ProgrammingError("OutputLocation is none or empty.")
if not self.output_location.endswith((".csv", ".txt")):
return pd.DataFrame()
if self.substatement_type and self.substatement_type.upper() in (
"UPDATE",
"DELETE",
"MERGE",
"VACUUM_TABLE",
):
return pd.DataFrame()
length = self._get_content_length()
if length == 0:
return pd.DataFrame()
if self.output_location.endswith(".txt"):
sep = "\t"
header = None
description = self.description if self.description else []
names = [d[0] for d in description]
elif self.output_location.endswith(".csv"):
sep = ","
header = 0
names = None
else:
return pd.DataFrame()
# Chunksize determination with user preference priority
effective_chunksize = self._chunksize
# Only auto-optimize if user hasn't specified chunksize AND auto_optimize is enabled
if effective_chunksize is None and self._auto_optimize_chunksize:
effective_chunksize = self._auto_determine_chunksize(length)
if effective_chunksize:
_logger.debug(
"Auto-determined chunksize: %s for file size: %s bytes",
effective_chunksize,
length,
)
csv_engine = self._get_csv_engine(length, effective_chunksize)
read_csv_kwargs = {
"sep": sep,
"header": header,
"names": names,
"dtype": self.dtypes,
"converters": self.converters,
"parse_dates": self.parse_dates,
"skip_blank_lines": False,
"keep_default_na": self._keep_default_na,
"na_values": self._na_values,
"quoting": self._quoting,
"storage_options": {
"connection": self.connection,
"default_block_size": self._block_size,
"default_cache_type": self._cache_type,
"max_workers": self._max_workers,
},
"chunksize": effective_chunksize,
"engine": csv_engine,
}
# Engine-specific compatibility adjustments
if csv_engine == "pyarrow":
# PyArrow doesn't support these pandas-specific options
read_csv_kwargs.pop("quoting", None)
read_csv_kwargs.pop("converters", None)
read_csv_kwargs.update(self._kwargs)
try:
result = pd.read_csv(self.output_location, **read_csv_kwargs)
# Log performance information for large files
if length > self.LARGE_FILE_THRESHOLD_BYTES:
mode = "chunked" if effective_chunksize else "full"
msg = "Reading %s bytes from S3 in %s mode using %s engine"
args: tuple[object, ...] = (length, mode, csv_engine)
if effective_chunksize:
msg += " with chunksize=%s"
args = (*args, effective_chunksize)
_logger.info(msg, *args)
return result
except Exception as e:
_logger.exception("Failed to read %s.", self.output_location)
raise OperationalError(*e.args) from e
def _read_parquet(self, engine) -> DataFrame:
import pandas as pd
self._data_manifest = self._read_data_manifest()
if not self._data_manifest:
return pd.DataFrame()
if not self._unload_location:
self._unload_location = "/".join(self._data_manifest[0].split("/")[:-1]) + "/"
if engine == "pyarrow":
unload_location = self._unload_location
kwargs = {
"use_threads": True,
}
else:
raise ProgrammingError("Engine must be `pyarrow`.")
kwargs.update(self._kwargs)
try:
return pd.read_parquet(
unload_location,
engine=self._engine,
storage_options={
"connection": self.connection,
"default_block_size": self._block_size,
"default_cache_type": self._cache_type,
"max_workers": self._max_workers,
},
**kwargs,
)
except Exception as e:
_logger.exception("Failed to read %s.", self.output_location)
raise OperationalError(*e.args) from e
def _read_parquet_schema(self, engine) -> tuple[dict[str, Any], ...]:
if engine == "pyarrow":
from pyarrow import parquet
from pyathena.arrow.util import to_column_info
if not self._unload_location:
raise ProgrammingError("UnloadLocation is none or empty.")
bucket, key = parse_output_location(self._unload_location)
try:
dataset = parquet.ParquetDataset(f"{bucket}/{key}", filesystem=self._fs)
return to_column_info(dataset.schema)
except Exception as e:
_logger.exception("Failed to read schema %s/%s.", bucket, key)
raise OperationalError(*e.args) from e
else:
raise ProgrammingError("Engine must be `pyarrow`.")
def _as_pandas(self) -> TextFileReader | DataFrame:
if self.is_unload:
engine = self._get_parquet_engine()
df = self._read_parquet(engine)
if df.empty:
self._metadata = ()
else:
self._metadata = self._read_parquet_schema(engine)
else:
df = self._read_csv()
return df
def _as_pandas_from_api(self, converter: Converter | None = None) -> DataFrame:
"""Build a DataFrame from GetQueryResults API.
Used as a fallback when ``output_location`` is not available
(e.g. managed query result storage).
Args:
converter: Type converter for result values. Defaults to
``DefaultTypeConverter`` if not specified.
"""
import pandas as pd
rows = self._fetch_all_rows(converter)
if not rows:
return pd.DataFrame()
description = self.description if self.description else []
columns = [d[0] for d in description]
return pd.DataFrame(self._rows_to_columnar(rows, columns))
def as_pandas(self) -> PandasDataFrameIterator | DataFrame:
"""Return the full result as a DataFrame, or the chunk iterator when chunksize is set."""
if self._chunksize is None:
return next(self._df_iter)
return self._df_iter
def iter_chunks(self) -> PandasDataFrameIterator:
"""Iterate over result chunks as pandas DataFrames.
This method provides an iterator interface for processing large result sets.
When chunksize is specified, it yields DataFrames in chunks for memory-efficient
processing. When chunksize is not specified, it yields the entire result as a
single DataFrame.
Returns:
PandasDataFrameIterator that yields pandas DataFrames for each chunk
of rows, or the entire DataFrame if chunksize was not specified.
Example:
>>> # With chunking for large datasets
>>> cursor = connection.cursor(PandasCursor, chunksize=50000)
>>> cursor.execute("SELECT * FROM large_table")
>>> for chunk in cursor.iter_chunks():
... process_chunk(chunk) # Each chunk is a pandas DataFrame
>>>
>>> # Without chunking - yields entire result as single chunk
>>> cursor = connection.cursor(PandasCursor)
>>> cursor.execute("SELECT * FROM small_table")
>>> for df in cursor.iter_chunks():
... process(df) # Single DataFrame with all data
"""
return self._df_iter
def close(self) -> None:
import pandas as pd
super().close()
self._df_iter = PandasDataFrameIterator(pd.DataFrame(), _no_trunc_date)
self._iterrows = enumerate([])
self._data_manifest = []