"""Defines useful types and utilities for working with bytestrings."""
from __future__ import annotations
import zlib
from abc import abstractmethod, ABCMeta
from collections.abc import Iterable, Sequence
from io import BytesIO
from itertools import chain
from numbers import Number
from typing import final, Any, Final, TypeAlias, TypeVar, TypeGuard, \
SupportsBytes, SupportsIndex, Protocol
__all__ = [
    'MaybeBytes',
    'MaybeBytesT',
    'has_bytes',
    'WriteStream',
    'Writeable',
    'BytesFormat',
]

#: Anything convertible to a bytestring: either :class:`bytes` itself or an
#: object exposing a ``__bytes__`` method.
MaybeBytes: TypeAlias = bytes | SupportsBytes

#: A type variable whose upper bound is :class:`MaybeBytes`.
MaybeBytesT = TypeVar('MaybeBytesT', bound=MaybeBytes)

# Private alias: an index-like value (``__index__``) or a MaybeBytes.
# NOTE(review): presumably the argument type for BytesFormat, which is
# defined outside this view — confirm.
_FormatArg: TypeAlias = SupportsIndex | MaybeBytes
[docs]
def has_bytes(value: object) -> TypeGuard[MaybeBytes]:
"""Checks if the *value* is :class:`bytes` or implements the ``__bytes__``
method to be converted to bytes.
Args:
value: The value to check.
"""
return isinstance(value, bytes) or isinstance(value, SupportsBytes)
class WriteStream(Protocol):
    """Structural type for any object with a :meth:`.write` method that
    accepts bytes.

    Explicit subclassing is optional: for type checkers, any object with a
    compatible ``write`` method satisfies this protocol. The protocol is not
    decorated with ``runtime_checkable``, so it cannot be used with
    :func:`isinstance`.

    See Also:
        :class:`~asyncio.StreamWriter`, :class:`~typing.BinaryIO`

    """

    @abstractmethod
    def write(self, data: bytes) -> Any:
        """Write *data* to the underlying stream or buffer.

        The return type is ``Any`` so both ``BinaryIO.write`` (which
        returns a count) and ``StreamWriter.write`` (which returns
        ``None``) conform.

        Args:
            data: The bytes to write.

        """
        ...
class HashStream(WriteStream):
    """Write-only sink that folds everything written into a running
    :func:`zlib.adler32` checksum.

    A :class:`Writeable` can be written into it to obtain a cheap,
    non-cryptographic fingerprint of its serialized bytes.

    """

    # NOTE(review): WriteStream declares no __slots__, so instances still
    # carry a __dict__ and this declaration saves little — confirm whether
    # that matters.
    __slots__ = ['_digest']

    def __init__(self) -> None:
        super().__init__()
        # Seed with the checksum of empty input (the Adler-32 start value).
        self._digest = zlib.adler32(b'')

    def write(self, data: bytes) -> None:
        """Fold *data* into the running checksum."""
        previous = self._digest
        self._digest = zlib.adler32(data, previous)

    def digest(self, data: Writeable | None = None) -> bytes:
        """Optionally write *data*, then return the current checksum.

        Args:
            data: A writeable to feed into the stream first. Writing it
                updates the stream's state, so later calls include it too.

        Returns:
            The 32-bit checksum as four big-endian bytes.

        """
        if data is not None:
            data.write(self)
        checksum = self._digest
        return checksum.to_bytes(4, 'big')
class Writeable(metaclass=ABCMeta):
    """Abstract base for objects that serialize themselves to a
    :class:`WriteStream`.

    Subclasses must implement ``__bytes__``. They may also override
    :meth:`.write` to emit their output across several calls instead of
    building one bytestring.

    """

    __slots__: Sequence[str] = []

    @final
    def tobytes(self) -> bytes:
        """Serialize the object by running :meth:`.write` against an
        in-memory buffer and returning what was collected.

        """
        with BytesIO() as buffer:
            self.write(buffer)
            return buffer.getvalue()

    @classmethod
    def empty(cls) -> Writeable:
        """Create a :class:`Writeable` that produces no bytes."""
        return _EmptyWriteable()

    @classmethod
    def wrap(cls, data: MaybeBytes) -> Writeable:
        """Create a :class:`Writeable` holding ``bytes(data)``.

        Args:
            data: The object to convert and wrap.

        """
        return _WrappedWriteable(data)

    @classmethod
    def concat(cls, data: Iterable[MaybeBytes]) -> Writeable:
        """Create a :class:`Writeable` that writes every item of *data* in
        order.

        Args:
            data: The items to combine. The iterable is consumed
                immediately.

        """
        return _ConcatWriteable(data)

    def write(self, writer: WriteStream) -> None:
        """Send this object's bytes to *writer*.

        The default makes a single :meth:`~WriteStream.write` call with
        ``bytes(self)``. Subclasses may split the output across several
        calls.

        Args:
            writer: The destination stream.

        """
        writer.write(bytes(self))

    def __bool__(self) -> bool:
        # Because __bool__ is defined, truth tests never fall back to
        # __len__ (which serializes the object). As a result every
        # Writeable is truthy, even one whose bytes are empty.
        return True

    def __len__(self) -> int:
        # Length of the full serialization; this builds the bytestring.
        return len(bytes(self))

    @abstractmethod
    def __bytes__(self) -> bytes:
        ...

    def __str__(self) -> str:
        # Decode as UTF-8, substituting U+FFFD for invalid sequences.
        return bytes(self).decode('utf-8', 'replace')
class _EmptyWriteable(Writeable):
    """Writeable whose serialization is the empty bytestring.

    ``__bool__`` is not overridden here, so instances are still truthy
    even though their ``len()`` is 0.

    """

    __slots__: Sequence[str] = []

    def write(self, writer: WriteStream) -> None:
        # Nothing to emit, so writer.write is never called.
        return None

    def __bytes__(self) -> bytes:
        return bytes()

    def __repr__(self) -> str:
        return '<Writeable empty>'
class _WrappedWriteable(Writeable):
    """Writeable backed by a bytestring captured at construction time."""

    __slots__ = ['data']

    def __init__(self, data: MaybeBytes) -> None:
        # Converted once, up front: the wrapper holds the resulting bytes,
        # not the original object.
        self.data = bytes(data)

    def __bytes__(self) -> bytes:
        return self.data

    def __repr__(self) -> str:
        return '<Writeable %r>' % (self.data,)
class _ConcatWriteable(Writeable):
    """Writeable made of a sequence of parts that are emitted in order."""

    __slots__ = ['data']

    def __init__(self, data: Iterable[MaybeBytes]) -> None:
        # Materialized as a list: write(), __bytes__, __str__ and __repr__
        # each iterate it, so a one-shot iterator would not do.
        self.data = list(map(self._wrap, data))

    @classmethod
    def _wrap(cls, val: MaybeBytes) -> Writeable:
        # Writeables are kept as-is so their own write() is used; anything
        # else is converted to bytes and wrapped.
        return val if isinstance(val, Writeable) else _WrappedWriteable(val)

    def write(self, writer: WriteStream) -> None:
        # Delegate to each part so nested writeables can stream themselves.
        for part in self.data:
            part.write(writer)

    def __bytes__(self) -> bytes:
        return BytesFormat(b'').join(self.data)

    def __str__(self) -> str:
        # Each part is rendered with its own __str__. NOTE(review): a
        # multi-byte UTF-8 sequence split across two parts decodes
        # differently here than it would from bytes(self) — confirm that
        # is intended.
        return ''.join(map(str, self.data))

    def __repr__(self) -> str:
        return f'<Writeable {self.data!r}>'