Source code for pyathena.sqlalchemy.types

from __future__ import annotations

from datetime import date, datetime
from typing import TYPE_CHECKING, Any

from sqlalchemy import types
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql.type_api import TypeEngine

if TYPE_CHECKING:
    from sqlalchemy import Dialect
    from sqlalchemy.sql.type_api import _LiteralProcessorType


def get_double_type() -> type[Any]:
    """Return the SQLAlchemy type class to use for DOUBLE columns.

    Newer SQLAlchemy releases (2.0+) expose a dedicated ``DOUBLE`` type in
    ``sqlalchemy.types``; older releases only offer ``FLOAT``. The lookup is
    done by attribute access rather than by comparing version strings, so
    whichever name the installed library provides is used.

    Returns:
        ``types.DOUBLE`` when the attribute exists, otherwise ``types.FLOAT``.
    """
    try:
        return types.DOUBLE
    except AttributeError:
        # Older SQLAlchemy: no DOUBLE attribute, fall back to FLOAT.
        return types.FLOAT


class AthenaTimestamp(TypeEngine[datetime]):
    """SQLAlchemy type that renders values as Athena TIMESTAMP literals.

    Python ``datetime`` objects are converted to the literal form
    ``TIMESTAMP 'YYYY-MM-DD HH:MM:SS.mmm'``. The fractional part is cut to
    three digits (milliseconds), matching Athena's TIMESTAMP precision.

    Example:
        >>> from sqlalchemy import Column, Table, MetaData
        >>> from pyathena.sqlalchemy.types import AthenaTimestamp
        >>> metadata = MetaData()
        >>> events = Table('events', metadata,
        ...     Column('event_time', AthenaTimestamp)
        ... )
    """

    # Flags consumed by SQLAlchemy's compiler.
    # NOTE(review): presumably they request explicit casts around rendered
    # literals / bind parameters -- confirm against the SQLAlchemy docs.
    render_literal_cast = True
    render_bind_cast = True

    @staticmethod
    def process(value: datetime | Any | None) -> str:
        """Format ``value`` as an Athena TIMESTAMP literal.

        Args:
            value: A ``datetime`` (formatted with millisecond precision) or
                any other object, which is embedded via ``str(value)``.

        Returns:
            The string ``TIMESTAMP '<text>'``.
        """
        if not isinstance(value, datetime):
            return f"TIMESTAMP '{value!s}'"
        with_microseconds = value.strftime("%Y-%m-%d %H:%M:%S.%f")
        # %f always yields six digits; dropping the last three leaves
        # milliseconds.
        with_milliseconds = with_microseconds[:-3]
        return f"TIMESTAMP '{with_milliseconds}'"

    def literal_processor(self, dialect: Dialect) -> _LiteralProcessorType[datetime] | None:
        """Return the callable used to render literal values inline.

        Args:
            dialect: The dialect in use; not consulted here.

        Returns:
            The ``process`` static method of this type.
        """
        return self.process
class AthenaDate(TypeEngine[date]):
    """SQLAlchemy type for Athena DATE values.

    This type handles the conversion of Python date objects to Athena's
    DATE literal syntax. When used in queries, date values are rendered as
    ``DATE 'YYYY-MM-DD'``.

    Example:
        >>> from sqlalchemy import Column, Table, MetaData
        >>> from pyathena.sqlalchemy.types import AthenaDate
        >>> metadata = MetaData()
        >>> orders = Table('orders', metadata,
        ...     Column('order_date', AthenaDate)
        ... )
    """

    # Flags consumed by SQLAlchemy's compiler.
    # NOTE(review): presumably they request explicit casts around rendered
    # literals / bind parameters -- confirm against the SQLAlchemy docs.
    render_literal_cast = True
    render_bind_cast = True

    @staticmethod
    def process(value: date | datetime | Any) -> str:
        """Format ``value`` as an Athena DATE literal.

        Args:
            value: A ``date`` or ``datetime`` (only the calendar date is
                rendered; any time-of-day part is dropped), or any other
                object, which is embedded via ``str(value)``.

        Returns:
            The string ``DATE 'YYYY-MM-DD'`` for date-like input, otherwise
            ``DATE '<str(value)>'``.
        """
        if isinstance(value, (date, datetime)):
            # The formatted literal must be returned here. Without the
            # return, a datetime falls through to the str() branch and is
            # rendered with its time part, which is not a DATE literal.
            return f"DATE '{value:%Y-%m-%d}'"
        return f"DATE '{value!s}'"

    def literal_processor(self, dialect: Dialect) -> _LiteralProcessorType[date] | None:
        """Return the callable used to render literal values inline.

        Args:
            dialect: The dialect in use; not consulted here.

        Returns:
            The ``process`` static method of this type.
        """
        return self.process
class Tinyint(sqltypes.Integer):
    """Integer type mapped to Athena's TINYINT (8-bit signed integer).

    TINYINT holds values in the range -128 to 127, which makes it a compact
    choice for columns that only ever contain small integers.
    """

    # Name SQLAlchemy uses to look up the compiler's visit method.
    __visit_name__ = "tinyint"
class TINYINT(Tinyint):
    """Uppercase spelling of :class:`Tinyint`.

    Follows the SQLAlchemy convention of offering an all-caps type name.
    """

    # Distinct visit name from the lowercase parent type.
    __visit_name__ = "TINYINT"
class AthenaStruct(TypeEngine[dict[str, Any]]):
    """SQLAlchemy type describing an Athena STRUCT/ROW value.

    A STRUCT is a record of named fields, each with its own data type --
    comparable to a database row or a Python dict with typed values.

    Args:
        *fields: Field specifications. Each entry is either:
            - a string: the field name, typed as String by default, or
            - a 2-tuple ``(field_name, field_type)``.

    Raises:
        ValueError: If an entry is neither a string nor a 2-tuple.

    Example:
        >>> from sqlalchemy import Column, Table, MetaData, types
        >>> from pyathena.sqlalchemy.types import AthenaStruct
        >>> metadata = MetaData()
        >>> users = Table('users', metadata,
        ...     Column('address', AthenaStruct(
        ...         ('street', types.String),
        ...         ('city', types.String),
        ...         ('zip_code', types.Integer)
        ...     ))
        ... )

    See Also:
        AWS Athena STRUCT Type:
        https://docs.aws.amazon.com/athena/latest/ug/rows-and-structs.html
    """

    # Name SQLAlchemy uses to look up the compiler's visit method.
    __visit_name__ = "struct"

    def __init__(self, *fields: str | tuple[str, Any]) -> None:
        """Build the field-name -> type mapping from the given specs."""
        self.fields: dict[str, TypeEngine[Any]] = {}
        for spec in fields:
            # Bare name: default the field to a String type.
            if isinstance(spec, str):
                self.fields[spec] = sqltypes.String()
                continue
            if not (isinstance(spec, tuple) and len(spec) == 2):
                raise ValueError(f"Invalid field specification: {spec}")
            name, declared = spec
            # Instances are stored as-is; anything else is assumed to be a
            # SQLAlchemy type class and is instantiated with no arguments.
            self.fields[name] = declared if isinstance(declared, TypeEngine) else declared()

    def __getitem__(self, key: str) -> TypeEngine[Any]:
        """Return the type registered for field ``key`` (KeyError if absent)."""
        return self.fields[key]

    @property
    def python_type(self) -> type:
        """Python type corresponding to this SQL type: ``dict``."""
        return dict
class STRUCT(AthenaStruct):
    """All-caps alias of :class:`AthenaStruct`."""

    # Distinct visit name from the lowercase parent type.
    __visit_name__ = "STRUCT"
class AthenaMap(TypeEngine[dict[str, Any]]):
    """SQLAlchemy type describing an Athena MAP value.

    A MAP is a collection of key/value pairs in which every key shares one
    type and every value shares one type.

    Args:
        key_type: SQLAlchemy type for map keys. Defaults to String.
        value_type: SQLAlchemy type for map values. Defaults to String.

    Example:
        >>> from sqlalchemy import Column, Table, MetaData, types
        >>> from pyathena.sqlalchemy.types import AthenaMap
        >>> metadata = MetaData()
        >>> settings = Table('settings', metadata,
        ...     Column('config', AthenaMap(types.String, types.Integer))
        ... )

    See Also:
        AWS Athena MAP Type:
        https://docs.aws.amazon.com/athena/latest/ug/maps.html
    """

    # Name SQLAlchemy uses to look up the compiler's visit method.
    __visit_name__ = "map"

    def __init__(self, key_type: Any = None, value_type: Any = None) -> None:
        """Store the key and value types, instantiating classes as needed."""
        # Instances are kept as-is; None means "use String"; anything else
        # is assumed to be a SQLAlchemy type class and is instantiated.
        self.key_type: TypeEngine[Any]
        if isinstance(key_type, TypeEngine):
            self.key_type = key_type
        else:
            self.key_type = sqltypes.String() if key_type is None else key_type()

        self.value_type: TypeEngine[Any]
        if isinstance(value_type, TypeEngine):
            self.value_type = value_type
        else:
            self.value_type = sqltypes.String() if value_type is None else value_type()

    @property
    def python_type(self) -> type:
        """Python type corresponding to this SQL type: ``dict``."""
        return dict
class MAP(AthenaMap):
    """All-caps alias of :class:`AthenaMap`."""

    # Distinct visit name from the lowercase parent type.
    __visit_name__ = "MAP"
class AthenaArray(TypeEngine[list[Any]]):
    """SQLAlchemy type describing an Athena ARRAY value.

    An ARRAY is an ordered collection whose elements all share one type.

    Args:
        item_type: SQLAlchemy type for array elements. Defaults to String.

    Example:
        >>> from sqlalchemy import Column, Table, MetaData, types
        >>> from pyathena.sqlalchemy.types import AthenaArray
        >>> metadata = MetaData()
        >>> posts = Table('posts', metadata,
        ...     Column('tags', AthenaArray(types.String))
        ... )

    See Also:
        AWS Athena ARRAY Type:
        https://docs.aws.amazon.com/athena/latest/ug/arrays.html
    """

    # Name SQLAlchemy uses to look up the compiler's visit method.
    __visit_name__ = "array"

    def __init__(self, item_type: Any = None) -> None:
        """Store the element type, instantiating a type class if needed."""
        # An instance is kept as-is; None means "use String"; anything else
        # is assumed to be a SQLAlchemy type class and is instantiated.
        self.item_type: TypeEngine[Any]
        if isinstance(item_type, TypeEngine):
            self.item_type = item_type
        else:
            self.item_type = sqltypes.String() if item_type is None else item_type()

    @property
    def python_type(self) -> type:
        """Python type corresponding to this SQL type: ``list``."""
        return list
class ARRAY(AthenaArray):
    """All-caps alias of :class:`AthenaArray`."""

    # Distinct visit name from the lowercase parent type.
    __visit_name__ = "ARRAY"