diff --git a/CHANGES.rst b/CHANGES.rst index 1f42b690..84dd336b 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -5,6 +5,13 @@ Changes for crate Unreleased ========== +- Added gzip compression for outgoing request bodies via the ``compress`` + parameter (default: ``8192`` bytes). + Pass ``True`` to always compress, ``False`` to disable, or an integer + as a byte threshold. The driver always sends ``Accept-Encoding: gzip, + deflate`` to negotiate compressed responses from the server when + compression is enabled. + - Added named parameter support (``pyformat`` paramstyle). Passing a :class:`py:dict` as ``parameters`` to ``cursor.execute()`` now accepts ``%(name)s`` placeholders and converts them to positional ``?`` markers diff --git a/docs/connect.rst b/docs/connect.rst index fca3a667..afc8f59c 100644 --- a/docs/connect.rst +++ b/docs/connect.rst @@ -266,6 +266,32 @@ with the rest of your arguments. However, you can query any schema you like by specifying it in the query. +.. _compression: + +Request and response compression +================================= + +The ``compress`` parameter controls gzip compression of outgoing request +bodies. The default ``8192`` compresses payloads larger than 8 KB:: + + >>> connection = client.connect('localhost:4200') + # compress=8192 is the default — payloads > 8 KB are gzip-compressed + +To always compress, regardless of payload size:: + + >>> connection = client.connect('localhost:4200', compress=True) + +To disable compression entirely:: + + >>> connection = client.connect('localhost:4200', compress=False) + +To use a custom threshold (bytes):: + + >>> connection = client.connect('localhost:4200', compress=4096) + +The driver always sends ``Accept-Encoding: gzip, deflate`` so the server +may return compressed responses if compression is enabled. + Next steps ========== diff --git a/src/crate/client/connection.py b/src/crate/client/connection.py index c9fa1340..f722b848 100644 --- a/src/crate/client/connection.py +++ b/src/crate/client/connection.py @@ -19,6 +19,8 @@ # with Crate these terms will supersede the license and you may use the # software solely pursuant to the terms of the relevant commercial agreement. +from typing import Union + from verlib2 import Version from .blob import BlobContainer @@ -51,6 +53,7 @@ def __init__( converter=None, time_zone=None, jwt_token=None, + compress: Union[int, bool] = 8192, ): """ :param servers: @@ -131,6 +134,12 @@ def __init__( converted from UTC to use the given time zone. :param jwt_token: the JWT token to authenticate with the server. + :param compress: + (optional, defaults to ``8192``) + Controls gzip compression of outgoing request bodies. + ``False`` disables compression entirely. + ``True`` compresses every request regardless of size. + An integer compresses only when the payload exceeds that many bytes. """ # noqa: E501 self._converter = converter @@ -158,6 +167,7 @@ def __init__( socket_tcp_keepintvl=socket_tcp_keepintvl, socket_tcp_keepcnt=socket_tcp_keepcnt, jwt_token=jwt_token, + compress=compress, ) self.lowest_server_version = self._lowest_server_version() self._closed = False diff --git a/src/crate/client/http.py b/src/crate/client/http.py index 2026cdbb..139330ff 100644 --- a/src/crate/client/http.py +++ b/src/crate/client/http.py @@ -22,6 +22,7 @@ import calendar import datetime as dt +import gzip import heapq import io import logging @@ -463,6 +464,7 @@ def __init__( socket_tcp_keepintvl=None, socket_tcp_keepcnt=None, jwt_token=None, + compress: t.Union[int, bool] = 8192, ): if not servers: servers = [self.default_server] @@ -487,7 +489,7 @@ def __init__( ) self._active_servers = servers - self._inactive_servers = [] + self._inactive_servers: t.List[t.Tuple[float, str, str]] = [] pool_kw = _pool_kw_args( verify_ssl_cert, ca_cert, @@ -506,7 +508,7 @@ def __init__( ) self.ssl_relax_minimum_version = ssl_relax_minimum_version self.backoff_factor = backoff_factor - self.server_pool = {} + self.server_pool: t.Dict[str, Server] = {} self._update_server_pool(servers, **pool_kw) self._pool_kw = pool_kw self._lock = threading.RLock() @@ -516,6 +518,12 @@ def __init__( self.jwt_token = jwt_token self.schema = schema + if not isinstance(compress, (bool, int)): + raise TypeError( + f"compress must be bool or int, got {type(compress).__name__!r}" + ) + self.compress = compress + self.path = self.SQL_PATH if error_trace: self.path += "&error_trace=true" @@ -678,8 +686,16 @@ def _json_request(self, method, path, data): """ Issue request against the crate HTTP API. """ + headers = {"Accept-Encoding": "gzip, deflate"} + + compress_enabled = self.compress is True or ( + not isinstance(self.compress, bool) and len(data) >= self.compress + ) + if compress_enabled: + data = gzip.compress(data, compresslevel=6) + headers["Content-Encoding"] = "gzip" - response = self._request(method, path, data=data) + response = self._request(method, path, data=data, headers=headers) _raise_for_status(response) if len(response.data) > 0: return _json_from_response(response) diff --git a/tests/client/test_http.py b/tests/client/test_http.py index e3c49cb1..32350ab9 100644 --- a/tests/client/test_http.py +++ b/tests/client/test_http.py @@ -19,6 +19,7 @@ # with Crate these terms will supersede the license and you may use the # software solely pursuant to the terms of the relevant commercial agreement. +import gzip import json import os import queue @@ -735,3 +736,83 @@ def test_credentials_and_token(serve_http): assert excinfo.match( "Either JWT tokens are accepted, or user credentials, but not both" ) + +def test_compress_accept_encoding_always_sent(): + """Accept-Encoding is sent even when compression is disabled.""" + captured = {} + + def capturing(*_, **kwargs): + captured["headers"] = kwargs.get("headers") or {} + return fake_response(200) + + with patch(REQUEST_PATH, side_effect=capturing): + Client(servers="localhost:4200", compress=False).sql("SELECT 1") + assert captured["headers"].get("Accept-Encoding") == "gzip, deflate" + + +def test_compress_false_no_content_encoding(): + """No Content-Encoding header when compress=False.""" + captured = {} + + def capturing(*_, **kwargs): + captured["headers"] = kwargs.get("headers") or {} + return fake_response(200) + + with patch(REQUEST_PATH, side_effect=capturing): + Client(servers="localhost:4200", compress=False).sql("SELECT 1") + assert "Content-Encoding" not in captured["headers"] + + +def test_compress_true_always_compresses(): + """compress=True compresses regardless of payload size.""" + captured = {} + + def capturing(*_, **kwargs): + captured["data"] = kwargs.get("data", b"") + captured["headers"] = kwargs.get("headers") or {} + return fake_response(200) + + with patch(REQUEST_PATH, side_effect=capturing): + Client(servers="localhost:4200", compress=True).sql("SELECT 1") + assert captured["headers"].get("Content-Encoding") == "gzip" + assert b'"stmt"' in gzip.decompress(captured["data"]) + + +def test_compress_threshold_above(): + """Payload above threshold is compressed.""" + captured = {} + + def capturing(*_, **kwargs): + captured["headers"] = kwargs.get("headers") or {} + return fake_response(200) + + with patch(REQUEST_PATH, side_effect=capturing): + Client(servers="localhost:4200", compress=0).sql("SELECT 1") + assert captured["headers"].get("Content-Encoding") == "gzip" + + +def test_compress_threshold_below(): + """Payload below threshold is not compressed.""" + captured = {} + + def capturing(*_, **kwargs): + captured["headers"] = kwargs.get("headers") or {} + return fake_response(200) + + with patch(REQUEST_PATH, side_effect=capturing): + Client(servers="localhost:4200", compress=999_999).sql("SELECT 1") + assert "Content-Encoding" not in captured["headers"] + + +def test_compress_default(): + """Default args: Accept-Encoding sent, small payload not compressed.""" + captured = {} + + def capturing(*_, **kwargs): + captured["headers"] = kwargs.get("headers") or {} + return fake_response(200) + + with patch(REQUEST_PATH, side_effect=capturing): + Client(servers="localhost:4200").sql("SELECT 1") + assert captured["headers"].get("Accept-Encoding") == "gzip, deflate" + assert "Content-Encoding" not in captured["headers"]