from __future__ import annotations
import asyncio
import logging
from multiprocessing import cpu_count
from typing import TYPE_CHECKING, Any, cast
from fsspec.asyn import AsyncFileSystem
from fsspec.callbacks import _DEFAULT_CALLBACK
from pyathena.filesystem.s3 import S3File, S3FileSystem
from pyathena.filesystem.s3_executor import S3AioExecutor
from pyathena.filesystem.s3_object import S3Object
if TYPE_CHECKING:
from datetime import datetime
from pyathena.connection import Connection
_logger = logging.getLogger(__name__)
class AioS3FileSystem(AsyncFileSystem):
    """An async filesystem interface for Amazon S3 using fsspec's AsyncFileSystem.

    This class wraps ``S3FileSystem`` to provide native asyncio support. Instead of
    using ``ThreadPoolExecutor`` for parallel operations, it uses ``asyncio.gather``
    with ``asyncio.to_thread`` for natural integration with the asyncio event loop.

    The implementation uses composition: an internal ``S3FileSystem`` instance handles
    all boto3 calls, while this class delegates to it via ``asyncio.to_thread()``.
    This avoids diamond inheritance issues and keeps all boto3 logic in one place.

    File handles created by ``_open`` use ``S3AioExecutor`` so that parallel
    operations (range reads, multipart uploads) are dispatched via the event loop
    instead of spawning additional threads.

    Attributes:
        _sync_fs: The internal synchronous S3FileSystem instance.

    Example:
        >>> from pyathena.filesystem.s3_async import AioS3FileSystem
        >>> fs = AioS3FileSystem(asynchronous=True)
        >>>
        >>> # Use in async context
        >>> files = await fs._ls('s3://my-bucket/data/')
        >>>
        >>> # Sync wrappers also available (auto-generated by fsspec)
        >>> files = fs.ls('s3://my-bucket/data/')
    """

    # https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteObjects.html
    DELETE_OBJECTS_MAX_KEYS: int = 1000

    protocol = ("s3", "s3a")
    mirror_sync_methods = True
    async_impl = True
    _extra_tokenize_attributes = ("default_block_size",)

    def __init__(
        self,
        connection: Connection[Any] | None = None,
        default_block_size: int | None = None,
        default_cache_type: str | None = None,
        max_workers: int = (cpu_count() or 1) * 5,
        s3_additional_kwargs: dict[str, Any] | None = None,
        asynchronous: bool = False,
        loop: Any | None = None,
        batch_size: int | None = None,
        **kwargs,
    ) -> None:
        """Create the async wrapper and its internal synchronous ``S3FileSystem``.

        Args:
            connection: Optional pyathena connection supplying credentials/session.
            default_block_size: Default block size for file handles opened by ``_open``.
            default_cache_type: Default fsspec cache type for file handles.
            max_workers: Parallelism hint forwarded to the sync filesystem.
            s3_additional_kwargs: Extra parameters merged into S3 API calls.
            asynchronous: Passed to ``AsyncFileSystem``; True when used inside a
                running event loop.
            loop: Event loop for fsspec's sync wrappers (``None`` lets fsspec choose).
            batch_size: fsspec batching hint for bulk async operations.
            **kwargs: Forwarded to both ``AsyncFileSystem`` and ``S3FileSystem``.
        """
        super().__init__(
            asynchronous=asynchronous,
            loop=loop,
            batch_size=batch_size,
            **kwargs,
        )
        self._sync_fs = S3FileSystem(
            connection=connection,
            default_block_size=default_block_size,
            default_cache_type=default_cache_type,
            max_workers=max_workers,
            s3_additional_kwargs=s3_additional_kwargs,
            **kwargs,
        )
        # Share dircache for cache coherence between async and sync instances
        self.dircache = self._sync_fs.dircache

    @staticmethod
    def parse_path(path: str) -> tuple[str, str | None, str | None]:
        """Split an S3 URL into ``(bucket, key, version_id)`` via ``S3FileSystem``."""
        return S3FileSystem.parse_path(path)

    async def _info(self, path: str, **kwargs) -> S3Object:
        """Return object/prefix metadata for *path*."""
        return await asyncio.to_thread(self._sync_fs.info, path, **kwargs)

    async def _ls(self, path: str, detail: bool = False, **kwargs) -> list[S3Object] | list[str]:
        """List *path*; ``detail=True`` returns S3Object entries instead of names."""
        return await asyncio.to_thread(self._sync_fs.ls, path, detail=detail, **kwargs)

    async def _cat_file(
        self, path: str, start: int | None = None, end: int | None = None, **kwargs
    ) -> bytes:
        """Read bytes from *path*, optionally limited to the ``[start, end)`` range."""
        return await asyncio.to_thread(self._sync_fs.cat_file, path, start=start, end=end, **kwargs)

    async def _exists(self, path: str, **kwargs) -> bool:
        """Return True if *path* exists as an object or prefix."""
        return await asyncio.to_thread(self._sync_fs.exists, path, **kwargs)

    async def _rm_file(self, path: str, **kwargs) -> None:
        """Delete a single object."""
        await asyncio.to_thread(self._sync_fs.rm_file, path, **kwargs)

    async def _pipe_file(self, path: str, value: bytes, **kwargs) -> None:
        """Write *value* as the full contents of *path*."""
        await asyncio.to_thread(self._sync_fs.pipe_file, path, value, **kwargs)

    async def _put_file(self, lpath: str, rpath: str, callback=_DEFAULT_CALLBACK, **kwargs) -> None:
        """Upload local file *lpath* to remote *rpath*."""
        await asyncio.to_thread(self._sync_fs.put_file, lpath, rpath, callback=callback, **kwargs)

    async def _get_file(self, rpath: str, lpath: str, callback=_DEFAULT_CALLBACK, **kwargs) -> None:
        """Download remote file *rpath* to local *lpath*."""
        await asyncio.to_thread(self._sync_fs.get_file, rpath, lpath, callback=callback, **kwargs)

    async def _mkdir(self, path: str, create_parents: bool = True, **kwargs) -> None:
        """Create a bucket/prefix; delegates directory semantics to the sync fs."""
        await asyncio.to_thread(self._sync_fs.mkdir, path, create_parents=create_parents, **kwargs)

    async def _makedirs(self, path: str, exist_ok: bool = False) -> None:
        """Recursively create *path*; ``exist_ok`` suppresses already-exists errors."""
        await asyncio.to_thread(self._sync_fs.makedirs, path, exist_ok=exist_ok)

    async def _rm(self, path: str | list[str], recursive: bool = False, **kwargs) -> None:
        """Remove files or directories using async parallel batch deletion.

        For multiple paths, chunks into batches of 1000 (S3 API limit) and uses
        ``asyncio.gather`` with ``asyncio.to_thread`` instead of ThreadPoolExecutor.

        Args:
            path: A single S3 URL or a list of them.
            recursive: Expand prefixes to the objects beneath them.
            **kwargs: ``Quiet`` (default True) is forwarded to DeleteObjects.
        """
        if isinstance(path, str):
            path = [path]
        # NOTE(review): the bucket is taken from the first path only, so all
        # paths are assumed to live in the same bucket — confirm with callers.
        bucket, _, _ = self.parse_path(path[0])

        expand_paths: list[str] = []
        for p in path:
            expanded = await asyncio.to_thread(self._sync_fs.expand_path, p, recursive=recursive)
            expand_paths.extend(expanded)
        if not expand_paths:
            return

        quiet = kwargs.pop("Quiet", True)
        delete_objects: list[dict[str, Any]] = []
        for p in expand_paths:
            _, key, version_id = self.parse_path(p)
            if key:
                object_: dict[str, Any] = {"Key": key}
                if version_id:
                    object_["VersionId"] = version_id
                delete_objects.append(object_)
        if not delete_objects:
            return

        # DeleteObjects accepts at most DELETE_OBJECTS_MAX_KEYS keys per request.
        chunks = [
            delete_objects[i : i + self.DELETE_OBJECTS_MAX_KEYS]
            for i in range(0, len(delete_objects), self.DELETE_OBJECTS_MAX_KEYS)
        ]

        async def _delete_chunk(chunk: list[dict[str, Any]]) -> None:
            request = {
                "Bucket": bucket,
                "Delete": {
                    "Objects": chunk,
                    "Quiet": quiet,
                },
            }
            await asyncio.to_thread(
                self._sync_fs._call, self._sync_fs._client.delete_objects, **request
            )

        await asyncio.gather(*[_delete_chunk(chunk) for chunk in chunks])
        for p in expand_paths:
            self._sync_fs.invalidate_cache(p)

    async def _cp_file(self, path1: str, path2: str, **kwargs) -> None:
        """Copy an S3 object, using async parallel multipart upload for large files."""
        kwargs.pop("onerror", None)
        bucket1, key1, version_id1 = self.parse_path(path1)
        bucket2, key2, version_id2 = self.parse_path(path2)
        if version_id2:
            raise ValueError("Cannot copy to a versioned file.")
        if not key1 or not key2:
            raise ValueError("Cannot copy buckets.")

        info1 = await self._info(path1)
        size1 = info1.get("size", 0)
        if size1 <= S3FileSystem.MULTIPART_UPLOAD_MAX_PART_SIZE:
            # Small enough for a single server-side CopyObject call.
            await asyncio.to_thread(
                self._sync_fs._copy_object,
                bucket1=bucket1,
                key1=key1,
                version_id1=version_id1,
                bucket2=bucket2,
                key2=key2,
                **kwargs,
            )
        else:
            await self._copy_object_with_multipart_upload(
                bucket1=bucket1,
                key1=key1,
                version_id1=version_id1,
                size1=size1,
                bucket2=bucket2,
                key2=key2,
                **kwargs,
            )
        self._sync_fs.invalidate_cache(path2)

    async def _copy_object_with_multipart_upload(
        self,
        bucket1: str,
        key1: str,
        size1: int,
        bucket2: str,
        key2: str,
        block_size: int | None = None,
        version_id1: str | None = None,
        **kwargs,
    ) -> None:
        """Server-side copy of a large object via UploadPartCopy parts in parallel.

        Args:
            bucket1: Source bucket.
            key1: Source key.
            size1: Source object size in bytes (drives part ranges).
            bucket2: Destination bucket.
            key2: Destination key.
            block_size: Part size; defaults to the maximum allowed part size.
            version_id1: Optional source object version.
            **kwargs: Forwarded to CreateMultipartUpload.
        """
        block_size = block_size if block_size else S3FileSystem.MULTIPART_UPLOAD_MAX_PART_SIZE
        if (
            block_size < S3FileSystem.MULTIPART_UPLOAD_MIN_PART_SIZE
            or block_size > S3FileSystem.MULTIPART_UPLOAD_MAX_PART_SIZE
        ):
            raise ValueError("Block size must be greater than 5MiB and less than 5GiB.")

        copy_source: dict[str, Any] = {
            "Bucket": bucket1,
            "Key": key1,
        }
        if version_id1:
            copy_source["VersionId"] = version_id1
        ranges = S3File._get_ranges(
            0,
            size1,
            self._sync_fs.max_workers,
            block_size,
        )
        multipart_upload = await asyncio.to_thread(
            self._sync_fs._create_multipart_upload,
            bucket=bucket2,
            key=key2,
            **kwargs,
        )

        async def _upload_part(i: int, range_: tuple[int, int]) -> dict[str, Any]:
            # S3 part numbers are 1-based.
            result = await asyncio.to_thread(
                self._sync_fs._upload_part_copy,
                bucket=bucket2,
                key=key2,
                copy_source=copy_source,
                upload_id=cast(str, multipart_upload.upload_id),
                part_number=i + 1,
                copy_source_ranges=range_,
            )
            return {
                "ETag": result.etag,
                "PartNumber": result.part_number,
            }

        parts = await asyncio.gather(*[_upload_part(i, r) for i, r in enumerate(ranges)])
        # CompleteMultipartUpload requires parts in ascending PartNumber order.
        parts_list = sorted(parts, key=lambda x: x["PartNumber"])
        await asyncio.to_thread(
            self._sync_fs._complete_multipart_upload,
            bucket=bucket2,
            key=key2,
            upload_id=cast(str, multipart_upload.upload_id),
            parts=parts_list,
        )

    async def _find(
        self,
        path: str,
        maxdepth: int | None = None,
        withdirs: bool = False,
        **kwargs,
    ) -> dict[str, S3Object] | list[str]:
        """Walk *path*; returns names, or a name→S3Object mapping when ``detail=True``."""
        detail = kwargs.pop("detail", False)
        files = await asyncio.to_thread(
            self._sync_fs._find, path, maxdepth=maxdepth, withdirs=withdirs, **kwargs
        )
        if detail:
            return {f.name: f for f in files}
        return [f.name for f in files]

    def _open(
        self,
        path: str,
        mode: str = "rb",
        block_size: int | None = None,
        cache_type: str | None = None,
        autocommit: bool = True,
        cache_options: dict[Any, Any] | None = None,
        **kwargs,
    ) -> AioS3File:
        """Open *path* and return an ``AioS3File`` backed by ``S3AioExecutor``.

        Args:
            path: S3 URL to open.
            mode: File mode (e.g. ``"rb"``, ``"wb"``).
            block_size: Read/write block size; defaults to the sync fs default.
            cache_type: fsspec cache type; defaults to the sync fs default.
            autocommit: Whether writes are committed on close.
            cache_options: Extra options for the fsspec cache.
            **kwargs: May include ``max_workers`` and ``s3_additional_kwargs``.
        """
        if block_size is None:
            block_size = self._sync_fs.default_block_size
        if cache_type is None:
            cache_type = self._sync_fs.default_cache_type
        # Fix: the key was previously "max_worker", so a caller-supplied
        # max_workers= kwarg leaked through **kwargs and collided with the
        # explicit max_workers argument below (TypeError: duplicate keyword).
        max_workers = kwargs.pop("max_workers", self._sync_fs.max_workers)
        s3_additional_kwargs = kwargs.pop("s3_additional_kwargs", {})
        # NOTE(review): filesystem-level kwargs overwrite per-call ones here —
        # confirm this precedence is intended.
        s3_additional_kwargs.update(self._sync_fs.s3_additional_kwargs)
        return AioS3File(
            self._sync_fs,
            path,
            mode,
            version_id=None,
            max_workers=max_workers,
            executor=S3AioExecutor(loop=self._loop),
            block_size=block_size,
            cache_type=cache_type,
            autocommit=autocommit,
            cache_options=cache_options,
            s3_additional_kwargs=s3_additional_kwargs,
            **kwargs,
        )

    def sign(self, path: str, expiration: int = 3600, **kwargs) -> str:
        """Return a presigned URL for *path* valid for *expiration* seconds."""
        return cast(str, self._sync_fs.sign(path, expiration=expiration, **kwargs))

    def checksum(self, path: str, **kwargs) -> int:
        """Return an integer checksum for *path* (delegates to the sync fs)."""
        return cast(int, self._sync_fs.checksum(path, **kwargs))

    def created(self, path: str) -> datetime:
        """Return the creation timestamp of *path*."""
        return self._sync_fs.created(path)

    def modified(self, path: str) -> datetime:
        """Return the last-modified timestamp of *path*."""
        return self._sync_fs.modified(path)

    def invalidate_cache(self, path: str | None = None) -> None:
        """Drop cached listings for *path* (or all paths when ``None``)."""
        self._sync_fs.invalidate_cache(path)

    async def _touch(self, path: str, truncate: bool = True, **kwargs) -> None:
        """Create an empty object at *path*, truncating any existing one by default."""
        await asyncio.to_thread(self._sync_fs.touch, path, truncate=truncate, **kwargs)
class AioS3File(S3File):
    """Async-aware S3 file handle using ``S3AioExecutor``.

    Functionally identical to ``S3File``; exists as a distinct type for
    ``isinstance`` checks and to document the async execution model.

    All parallel operations (range reads, multipart uploads) are dispatched
    through the ``S3Executor`` interface — the ``S3AioExecutor`` provided by
    ``AioS3FileSystem`` uses the event loop instead of threads.
    """