Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ Changes for crate
Unreleased
==========

- Added gzip compression for outgoing request bodies via the ``compress``
parameter (default: ``8192`` bytes).
Pass ``True`` to always compress, ``False`` to disable, or an integer
as a byte threshold. The driver always sends ``Accept-Encoding: gzip,
deflate`` to negotiate compressed responses from the server when
compression is enabled.

- Added named parameter support (``pyformat`` paramstyle). Passing a
:class:`py:dict` as ``parameters`` to ``cursor.execute()`` now accepts
``%(name)s`` placeholders and converts them to positional ``?`` markers
Expand Down
26 changes: 26 additions & 0 deletions docs/connect.rst
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,32 @@ with the rest of your arguments.

However, you can query any schema you like by specifying it in the query.

.. _compression:

Request and response compression
=================================

The ``compress`` parameter controls gzip compression of outgoing request
bodies. The default ``8192`` compresses payloads larger than 8 KB::

>>> connection = client.connect('localhost:4200')
# compress=8192 is the default — payloads > 8 KB are gzip-compressed

To always compress, regardless of payload size::

>>> connection = client.connect('localhost:4200', compress=True)

To disable compression entirely::

>>> connection = client.connect('localhost:4200', compress=False)

To use a custom threshold (bytes)::

>>> connection = client.connect('localhost:4200', compress=4096)

The driver always sends ``Accept-Encoding: gzip, deflate`` so the server
may return compressed responses if compression is enabled.

Next steps
==========

Expand Down
10 changes: 10 additions & 0 deletions src/crate/client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
# with Crate these terms will supersede the license and you may use the
# software solely pursuant to the terms of the relevant commercial agreement.

from typing import Union

from verlib2 import Version

from .blob import BlobContainer
Expand Down Expand Up @@ -51,6 +53,7 @@ def __init__(
converter=None,
time_zone=None,
jwt_token=None,
compress: Union[int, bool] = 8192,
):
"""
:param servers:
Expand Down Expand Up @@ -131,6 +134,12 @@ def __init__(
converted from UTC to use the given time zone.
:param jwt_token:
the JWT token to authenticate with the server.
:param compress:
(optional, defaults to ``8192``)
Controls gzip compression of outgoing request bodies.
``False`` disables compression entirely.
``True`` compresses every request regardless of size.
An integer compresses only when the payload exceeds that many bytes.
""" # noqa: E501

self._converter = converter
Expand Down Expand Up @@ -158,6 +167,7 @@ def __init__(
socket_tcp_keepintvl=socket_tcp_keepintvl,
socket_tcp_keepcnt=socket_tcp_keepcnt,
jwt_token=jwt_token,
compress=compress,
)
self.lowest_server_version = self._lowest_server_version()
self._closed = False
Expand Down
22 changes: 19 additions & 3 deletions src/crate/client/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import calendar
import datetime as dt
import gzip
import heapq
import io
import logging
Expand Down Expand Up @@ -463,6 +464,7 @@ def __init__(
socket_tcp_keepintvl=None,
socket_tcp_keepcnt=None,
jwt_token=None,
compress: t.Union[int, bool] = 8192,
):
if not servers:
servers = [self.default_server]
Expand All @@ -487,7 +489,7 @@ def __init__(
)

self._active_servers = servers
self._inactive_servers = []
self._inactive_servers: t.List[t.Tuple[float, str, str]] = []
pool_kw = _pool_kw_args(
verify_ssl_cert,
ca_cert,
Expand All @@ -506,7 +508,7 @@ def __init__(
)
self.ssl_relax_minimum_version = ssl_relax_minimum_version
self.backoff_factor = backoff_factor
self.server_pool = {}
self.server_pool: t.Dict[str, Server] = {}
self._update_server_pool(servers, **pool_kw)
self._pool_kw = pool_kw
self._lock = threading.RLock()
Expand All @@ -516,6 +518,12 @@ def __init__(
self.jwt_token = jwt_token
self.schema = schema

if not isinstance(compress, (bool, int)):
raise TypeError(
f"compress must be bool or int, got {type(compress).__name__!r}"
)
self.compress = compress

self.path = self.SQL_PATH
if error_trace:
self.path += "&error_trace=true"
Expand Down Expand Up @@ -678,8 +686,16 @@ def _json_request(self, method, path, data):
"""
Issue request against the crate HTTP API.
"""
headers = {"Accept-Encoding": "gzip, deflate"}

compress_enabled = self.compress is True or (
not isinstance(self.compress, bool) and len(data) >= self.compress
)
Comment on lines +691 to +693
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
compress_enabled = self.compress is True or (
not isinstance(self.compress, bool) and len(data) >= self.compress
)
compress_enabled = self.compress is True or (
isinstance(self.compress, int) and len(data) >= self.compress
)

Or might even make it stricter and fail if it's neither bool nor int.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added check in __init__ and will fail if it's not int or bool

if compress_enabled:
data = gzip.compress(data, compresslevel=6)
headers["Content-Encoding"] = "gzip"

response = self._request(method, path, data=data)
response = self._request(method, path, data=data, headers=headers)
_raise_for_status(response)
if len(response.data) > 0:
return _json_from_response(response)
Expand Down
81 changes: 81 additions & 0 deletions tests/client/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# with Crate these terms will supersede the license and you may use the
# software solely pursuant to the terms of the relevant commercial agreement.

import gzip
import json
import os
import queue
Expand Down Expand Up @@ -735,3 +736,83 @@ def test_credentials_and_token(serve_http):
assert excinfo.match(
"Either JWT tokens are accepted, or user credentials, but not both"
)

def test_compress_accept_encoding_always_sent():
"""Accept-Encoding is sent even when compression is disabled."""
captured = {}

def capturing(*_, **kwargs):
captured["headers"] = kwargs.get("headers") or {}
return fake_response(200)

with patch(REQUEST_PATH, side_effect=capturing):
Client(servers="localhost:4200", compress=False).sql("SELECT 1")
assert captured["headers"].get("Accept-Encoding") == "gzip, deflate"


def test_compress_false_no_content_encoding():
"""No Content-Encoding header when compress=False."""
captured = {}

def capturing(*_, **kwargs):
captured["headers"] = kwargs.get("headers") or {}
return fake_response(200)

with patch(REQUEST_PATH, side_effect=capturing):
Client(servers="localhost:4200", compress=False).sql("SELECT 1")
assert "Content-Encoding" not in captured["headers"]


def test_compress_true_always_compresses():
"""compress=True compresses regardless of payload size."""
captured = {}

def capturing(*_, **kwargs):
captured["data"] = kwargs.get("data", b"")
captured["headers"] = kwargs.get("headers") or {}
return fake_response(200)

with patch(REQUEST_PATH, side_effect=capturing):
Client(servers="localhost:4200", compress=True).sql("SELECT 1")
assert captured["headers"].get("Content-Encoding") == "gzip"
assert b'"stmt"' in gzip.decompress(captured["data"])


def test_compress_threshold_above():
"""Payload above threshold is compressed."""
captured = {}

def capturing(*_, **kwargs):
captured["headers"] = kwargs.get("headers") or {}
return fake_response(200)

with patch(REQUEST_PATH, side_effect=capturing):
Client(servers="localhost:4200", compress=0).sql("SELECT 1")
assert captured["headers"].get("Content-Encoding") == "gzip"


def test_compress_threshold_below():
"""Payload below threshold is not compressed."""
captured = {}

def capturing(*_, **kwargs):
captured["headers"] = kwargs.get("headers") or {}
return fake_response(200)

with patch(REQUEST_PATH, side_effect=capturing):
Client(servers="localhost:4200", compress=999_999).sql("SELECT 1")
assert "Content-Encoding" not in captured["headers"]


def test_compress_default():
"""Default args: Accept-Encoding sent, small payload not compressed."""
captured = {}

def capturing(*_, **kwargs):
captured["headers"] = kwargs.get("headers") or {}
return fake_response(200)

with patch(REQUEST_PATH, side_effect=capturing):
Client(servers="localhost:4200").sql("SELECT 1")
assert captured["headers"].get("Accept-Encoding") == "gzip, deflate"
assert "Content-Encoding" not in captured["headers"]
Loading