diff --git a/Makefile b/Makefile index 1be335e3..3c14085c 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ SCALA_VERSION?=2.13 KAFKA_VERSION?=2.8.1 DOCKER_IMAGE=aiolibs/kafka:$(SCALA_VERSION)_$(KAFKA_VERSION) DIFF_BRANCH=origin/master -FORMATTED_AREAS=aiokafka/util.py aiokafka/structs.py +FORMATTED_AREAS=aiokafka/util.py aiokafka/structs.py aiokafka/codec.py tests/test_codec.py .PHONY: setup setup: diff --git a/aiokafka/codec.py b/aiokafka/codec.py index 27f06e4b..2ff2bce9 100644 --- a/aiokafka/codec.py +++ b/aiokafka/codec.py @@ -1,7 +1,11 @@ +from __future__ import annotations + import gzip import io import struct +from typing_extensions import Buffer + _XERIAL_V1_HEADER = (-126, b"S", b"N", b"A", b"P", b"P", b"Y", 0, 1, 1) _XERIAL_V1_FORMAT = "bccccccBii" ZSTD_MAX_OUTPUT_SIZE = 1024 * 1024 @@ -12,23 +16,23 @@ cramjam = None -def has_gzip(): +def has_gzip() -> bool: return True -def has_snappy(): +def has_snappy() -> bool: return cramjam is not None -def has_zstd(): +def has_zstd() -> bool: return cramjam is not None -def has_lz4(): +def has_lz4() -> bool: return cramjam is not None -def gzip_encode(payload, compresslevel=None): +def gzip_encode(payload: Buffer, compresslevel: int | None = None) -> bytes: if not compresslevel: compresslevel = 9 @@ -45,7 +49,7 @@ def gzip_encode(payload, compresslevel=None): return buf.getvalue() -def gzip_decode(payload): +def gzip_decode(payload: Buffer) -> bytes: buf = io.BytesIO(payload) # Gzip context manager introduced in python 2.7 @@ -57,7 +61,9 @@ def gzip_decode(payload): gzipper.close() -def snappy_encode(payload, xerial_compatible=True, xerial_blocksize=32 * 1024): +def snappy_encode( + payload: Buffer, xerial_compatible: bool = True, xerial_blocksize: int = 32 * 1024 +) -> bytes: """Encodes the given data with snappy compression. If xerial_compatible is set then the stream is encoded in a fashion @@ -93,12 +99,9 @@ def snappy_encode(payload, xerial_compatible=True, xerial_blocksize=32 * 1024): for fmt, dat in zip(_XERIAL_V1_FORMAT, _XERIAL_V1_HEADER): out.write(struct.pack("!" + fmt, dat)) - # Chunk through buffers to avoid creating intermediate slice copies - def chunker(payload, i, size): - return memoryview(payload)[i : size + i] - + payload = memoryview(payload) for chunk in ( - chunker(payload, i, xerial_blocksize) + payload[i : i + xerial_blocksize] for i in range(0, len(payload), xerial_blocksize) ): block = cramjam.snappy.compress_raw(chunk) @@ -109,7 +112,7 @@ def chunker(payload, i, size): return out.getvalue() -def _detect_xerial_stream(payload): +def _detect_xerial_stream(payload: Buffer) -> bool: """Detects if the data given might have been encoded with the blocking mode of the xerial snappy library. @@ -131,20 +134,21 @@ def _detect_xerial_stream(payload): 1. """ + payload = memoryview(payload) if len(payload) > 16: - header = struct.unpack("!" + _XERIAL_V1_FORMAT, memoryview(payload)[:16]) + header = struct.unpack("!" + _XERIAL_V1_FORMAT, payload[:16]) return header == _XERIAL_V1_HEADER return False -def snappy_decode(payload): +def snappy_decode(payload: Buffer) -> bytes: if not has_snappy(): raise NotImplementedError("Snappy codec is not available") if _detect_xerial_stream(payload): # TODO ? Should become a fileobj ? out = io.BytesIO() - byt = payload[16:] + byt = memoryview(payload)[16:] length = len(byt) cursor = 0 @@ -162,7 +166,7 @@ def snappy_decode(payload): return bytes(cramjam.snappy.decompress_raw(payload)) -def lz4_encode(payload, level=9): +def lz4_encode(payload: Buffer, level: int = 9) -> bytes: # level=9 is used by default by broker itself # https://cwiki.apache.org/confluence/display/KAFKA/KIP-390%3A+Support+Compression+Level if not has_lz4(): @@ -177,14 +181,14 @@ def lz4_encode(payload, level=9): return bytes(compressor.finish()) -def lz4_decode(payload): +def lz4_decode(payload: Buffer) -> bytes: if not has_lz4(): raise NotImplementedError("LZ4 codec is not available") return bytes(cramjam.lz4.decompress(payload)) -def zstd_encode(payload, level=None): +def zstd_encode(payload: Buffer, level: int | None = None) -> bytes: if not has_zstd(): raise NotImplementedError("Zstd codec is not available") @@ -196,7 +200,7 @@ def zstd_encode(payload, level=None): return bytes(cramjam.zstd.compress(payload, level=level)) -def zstd_decode(payload): +def zstd_decode(payload: Buffer) -> bytes: if not has_zstd(): raise NotImplementedError("Zstd codec is not available") diff --git a/pyproject.toml b/pyproject.toml index 1fe6562c..f91db219 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dynamic = ["version"] dependencies = [ "async-timeout", "packaging", + "typing_extensions >=4.6.0", ] [project.optional-dependencies] diff --git a/tests/test_codec.py b/tests/test_codec.py index 0420ec03..ed72ab68 100644 --- a/tests/test_codec.py +++ b/tests/test_codec.py @@ -20,7 +20,7 @@ from ._testutil import random_string -def test_gzip(): +def test_gzip() -> None: for i in range(1000): b1 = random_string(100) b2 = gzip_decode(gzip_encode(b1)) @@ -28,7 +28,7 @@ def test_gzip(): @pytest.mark.skipif(not has_snappy(), reason="Snappy not available") -def test_snappy(): +def test_snappy() -> None: for i in range(1000): b1 = random_string(100) b2 = snappy_decode(snappy_encode(b1)) @@ -36,7 +36,7 @@ def test_snappy(): @pytest.mark.skipif(not has_snappy(), reason="Snappy not available") -def test_snappy_detect_xerial(): +def test_snappy_detect_xerial() -> None: _detect_xerial_stream = codecs._detect_xerial_stream header = b"\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01Some extra bytes" @@ -55,7 +55,7 @@ def test_snappy_detect_xerial(): @pytest.mark.skipif(not has_snappy(), reason="Snappy not available") -def test_snappy_decode_xerial(): +def test_snappy_decode_xerial() -> None: header = b"\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01" random_snappy = snappy_encode(b"SNAPPY" * 50, xerial_compatible=False) block_len = len(random_snappy) @@ -73,7 +73,7 @@ def test_snappy_decode_xerial(): @pytest.mark.skipif(not has_snappy(), reason="Snappy not available") -def test_snappy_encode_xerial(): +def test_snappy_encode_xerial() -> None: to_ensure = ( b"\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01" b"\x00\x00\x00\x18\xac\x02\x14SNAPPY\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00" @@ -88,7 +88,7 @@ def test_snappy_encode_xerial(): @pytest.mark.skipif(not has_lz4(), reason="LZ4 not available") -def test_lz4(): +def test_lz4() -> None: for i in range(1000): b1 = random_string(100) b2 = lz4_decode(lz4_encode(b1)) @@ -97,7 +97,7 @@ def test_lz4(): @pytest.mark.skipif(not has_lz4(), reason="LZ4 not available") -def test_lz4_incremental(): +def test_lz4_incremental() -> None: for i in range(1000): # lz4 max single block size is 4MB # make sure we test with multiple-blocks @@ -108,7 +108,7 @@ def test_lz4_incremental(): @pytest.mark.skipif(not has_zstd(), reason="Zstd not available") -def test_zstd(): +def test_zstd() -> None: for _ in range(1000): b1 = random_string(100) b2 = zstd_decode(zstd_encode(b1))