Skip to content

Commit

Permalink
Add [Compressor|Decompressor].process for API compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
sethmlarson committed Nov 28, 2020
1 parent 9f42c41 commit 31f57cf
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 14 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ dist/
.eggs/
.tox/
.coverage
.hypothesis/
.hypothesis/
*.so
30 changes: 17 additions & 13 deletions src/brotlicffi/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ._brotlicffi import ffi, lib


class Error(Exception):
class error(Exception):
"""
Raised whenever an error is encountered with compressing or decompressing
data using brotlicffi.
Expand All @@ -15,11 +15,11 @@ class Error(Exception):
pass


#: An alias of :class:`Error <brotlicffi.Error>` that
#: exists for compatibility with the original C brotli module.
#: An alias of :class:`error <brotli.error>` that
#: exists for compatibility with the original CFFI brotli module.
#:
#: .. versionadded: 0.5.1
error = Error
#: .. versionadded: 0.8.0
Error = error


class BrotliEncoderMode(enum.IntEnum):
Expand Down Expand Up @@ -159,15 +159,15 @@ def _validate_mode(val):
try:
val = BrotliEncoderMode(val)
except ValueError:
raise Error("%s is not a valid encoder mode" % val)
raise error("%s is not a valid encoder mode" % val)


def _validate_quality(val):
"""
Validate that the quality setting is valid.
"""
if not (0 <= val <= 11):
raise Error(
raise error(
"%d is not a valid quality, must be between 0 and 11" % val
)

Expand All @@ -177,15 +177,15 @@ def _validate_lgwin(val):
Validate that the lgwin setting is valid.
"""
if not (10 <= val <= 24):
raise Error("%d is not a valid lgwin, must be between 10 and 24" % val)
raise error("%d is not a valid lgwin, must be between 10 and 24" % val)


def _validate_lgblock(val):
"""
Validate that the lgblock setting is valid.
"""
if (val != 0) and not (16 <= val <= 24):
raise Error(
raise error(
"%d is not a valid lgblock, must be either 0 or between 16 and 24"
% val
)
Expand Down Expand Up @@ -214,7 +214,7 @@ def _set_parameter(encoder, parameter, parameter_name, val):
# function returns a value we can live in hope that the brotli folks will
# enforce their own constraints.
if rc != lib.BROTLI_TRUE: # pragma: no cover
raise Error(
raise error(
"Error setting parameter %s: %d" % (parameter_name, val)
)

Expand Down Expand Up @@ -309,7 +309,7 @@ def _compress(self, data, operation):
ffi.NULL
)
if rc != lib.BROTLI_TRUE: # pragma: no cover
raise Error("Error encountered compressing data.")
raise error("Error encountered compressing data.")

assert not input_size[0]

Expand All @@ -327,6 +327,8 @@ def compress(self, data):
"""
return self._compress(data, lib.BROTLI_OPERATION_PROCESS)

process = compress

def flush(self):
"""
Flush the compressor. This will emit the remaining output data, but
Expand Down Expand Up @@ -414,7 +416,7 @@ def decompress(self, data):
if rc == lib.BROTLI_DECODER_RESULT_ERROR:
error_code = lib.BrotliDecoderGetErrorCode(self._decoder)
error_message = lib.BrotliDecoderErrorString(error_code)
raise Error(
raise error(
"Decompression error: %s" % ffi.string(error_message)
)

Expand All @@ -433,6 +435,8 @@ def decompress(self, data):

return b''.join(chunks)

process = decompress

def flush(self):
"""
Complete the decompression, return whatever data is remaining to be
Expand Down Expand Up @@ -460,7 +464,7 @@ def finish(self):
lib.BrotliDecoderHasMoreOutput(self._decoder) == lib.BROTLI_FALSE
)
if not self.is_finished():
raise Error("Decompression error: incomplete compressed stream.")
raise error("Decompression error: incomplete compressed stream.")

return b''

Expand Down
15 changes: 15 additions & 0 deletions test/test_simple_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,21 @@ def test_compressed_data_with_dictionaries(s, dictionary):
assert uncompressed == s


@given(binary())
def test_process_alias(s):
c1 = brotlicffi.Compressor()
c2 = brotlicffi.Compressor()
d1 = brotlicffi.Decompressor()
d2 = brotlicffi.Decompressor()
s1 = c1.compress(s) + c1.finish()
s2 = c2.process(s) + c2.finish()
assert (
(d1.decompress(s1) + d1.finish())
== (d2.process(s2) + d2.finish())
== s
)


@pytest.mark.parametrize(
"params",
[
Expand Down

0 comments on commit 31f57cf

Please sign in to comment.