8000 PYTHON-1369 Extend driver vector support to arbitrary subtypes and fi… · datastax/python-driver@c4a808d · GitHub
[go: up one dir, main page]

Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit c4a808d

Browse files
authored
PYTHON-1369 Extend driver vector support to arbitrary subtypes and fix handling of variable length types (OSS C* 5.0) (#1217)
1 parent d05e9d3 commit c4a808d

File tree

7 files changed

+504
-72
lines changed

7 files changed

+504
-72
lines changed

cassandra/__init__.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -744,9 +744,3 @@ def __init__(self, msg, excs=[]):
744744
if excs:
745745
complete_msg += ("\nThe following exceptions were observed: \n - " + '\n - '.join(str(e) for e in excs))
746746
Exception.__init__(self, complete_msg)
747-
748-
class VectorDeserializationFailure(DriverException):
749-
"""
750-
The driver was unable to deserialize a given vector
751-
"""
752-
pass

cassandra/cqltypes.py

Lines changed: 79 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@
4848
int32_pack, int32_unpack, int64_pack, int64_unpack,
4949
float_pack, float_unpack, double_pack, double_unpack,
5050
varint_pack, varint_unpack, point_be, point_le,
51-
vints_pack, vints_unpack)
52-
from cassandra import util, VectorDeserializationFailure
51+
vints_pack, vints_unpack, uvint_unpack, uvint_pack)
52+
from cassandra import util
5353

5454
_little_endian_flag = 1 # we always serialize LE
5555
import ipaddress
@@ -392,6 +392,9 @@ def cass_parameterized_type(cls, full=False):
392392
"""
393393
return cls.cass_parameterized_type_with(cls.subtypes, full=full)
394394

395+
@classmethod
396+
def serial_size(cls):
397+
return None
395398

396399
# it's initially named with a _ to avoid registering it as a real type, but
397400
# client programs may want to use the name still for isinstance(), etc
@@ -457,10 +460,12 @@ def serialize(uuid, protocol_version):
457460
except AttributeError:
458461
raise TypeError("Got a non-UUID object for a UUID value")
459462

463+
@classmethod
464+
def serial_size(cls):
465+
return 16
460466

461467
class BooleanType(_CassandraType):
462468
typename = 'boolean'
463-
serial_size = 1
464469

465470
@staticmethod
466471
def deserialize(byts, protocol_version):
@@ -470,6 +475,10 @@ def deserialize(byts, protocol_version):
470475
def serialize(truth, protocol_version):
471476
return int8_pack(truth)
472477

478+
@classmethod
479+
def serial_size(cls):
480+
return 1
481+
473482
class ByteType(_CassandraType):
474483
typename = 'tinyint'
475484

@@ -500,7 +509,6 @@ def serialize(var, protocol_version):
500509

501510
class FloatType(_CassandraType):
502511
typename = 'float'
503-
serial_size = 4
504512

505513
@staticmethod
506514
def deserialize(byts, protocol_version):
@@ -510,10 +518,12 @@ def deserialize(byts, protocol_version):
510518
def serialize(byts, protocol_version):
511519
return float_pack(byts)
512520

521+
@classmethod
522+
def serial_size(cls):
523+
return 4
513524

514525
class DoubleType(_CassandraType):
515526
typename = 'double'
516-
serial_size = 8
517527

518528
@staticmethod
519529
def deserialize(byts, protocol_version):
@@ -523,10 +533,12 @@ def deserialize(byts, protocol_version):
523533
def serialize(byts, protocol_version):
524534
return double_pack(byts)
525535

536+
@classmethod
537+
def serial_size(cls):
538+
return 8
526539

527540
class LongType(_CassandraType):
528541
typename = 'bigint'
529-
serial_size = 8
530542

531543
@staticmethod
532544
def deserialize(byts, protocol_version):
@@ -536,10 +548,12 @@ def deserialize(byts, protocol_version):
536548
def serialize(byts, protocol_version):
537549
return int64_pack(byts)
538550

551+
@classmethod
552+
def serial_size(cls):
553+
return 8
539554

540555
class Int32Type(_CassandraType):
541556
typename = 'int'
542-
serial_size = 4
543557

544558
@staticmethod
545559
def deserialize(byts, protocol_version):
@@ -549,6 +563,9 @@ def deserialize(byts, protocol_version):
549563
def serialize(byts, protocol_version):
550564
return int32_pack(byts)
551565

566+
@classmethod
567+
def serial_size(cls):
568+
return 4
552569

553570
class IntegerType(_CassandraType):
554571
typename = 'varint'
@@ -645,14 +662,16 @@ def serialize(v, protocol_version):
645662

646663
return int64_pack(int(timestamp))
647664

665+
@classmethod
666+
def serial_size(cls):
667+
return 8
648668

649669
class TimestampType(DateType):
650670
pass
651671

652672

653673
class TimeUUIDType(DateType):
654674
typename = 'timeuuid'
655-
serial_size = 16
656675

657676
def my_timestamp(self):
658677
return util.unix_time_from_uuid1(self.val)
@@ -668,6 +687,9 @@ def serialize(timeuuid, protocol_version):
668687
except AttributeError:
669688
raise TypeError("Got a non-UUID object for a UUID value")
670689

690+
@classmethod
691+
def serial_size(cls):
692+
return 16
671693

672694
class SimpleDateType(_CassandraType):
673695
typename = 'date'
@@ -699,7 +721,6 @@ def serialize(val, protocol_version):
699721

700722
class ShortType(_CassandraType):
701723
typename = 'smallint'
702-
serial_size = 2
703724

704725
@staticmethod
705726
def deserialize(byts, protocol_version):
@@ -709,10 +730,14 @@ def deserialize(byts, protocol_version):
709730
def serialize(byts, protocol_version):
710731
return int16_pack(byts)
711732

712-
713733
class TimeType(_CassandraType):
714734
typename = 'time'
715-
serial_size = 8
735+
# Time should be a fixed size 8 byte type but Cassandra 5.0 code marks it as
736+
# variable size... and we have to match what the server expects since the server
737+
# uses that specification to encode data of that type.
738+
#@classmethod
739+
#def serial_size(cls):
740+
# return 8
716741

717742
@staticmethod
718743
def deserialize(byts, protocol_version):
@@ -1409,6 +1434,11 @@ class VectorType(_CassandraType):
14091434
vector_size = 0
14101435
subtype = None
14111436

1437+
@classmethod
1438+
def serial_size(cls):
1439+
serialized_size = cls.subtype.serial_size()
1440+
return cls.vector_size * serialized_size if serialized_size is not None else None
1441+
14121442
@classmethod
14131443
def apply_parameters(cls, params, names):
14141444
assert len(params) == 2
@@ -1418,19 +1448,50 @@ def apply_parameters(cls, params, names):
14181448

14191449
@classmethod
14201450
def deserialize(cls, byts, protocol_version):
1421-
serialized_size = getattr(cls.subtype, "serial_size", None)
1422-
if not serialized_size:
1423-
raise VectorDeserializationFailure("Cannot determine serialized size for vector with subtype %s" % cls.subtype.__name__)
1424-
indexes = (serialized_size * x for x in range(0, cls.vector_size))
1425-
return [cls.subtype.deserialize(byts[idx:idx + serialized_size], protocol_version) for idx in indexes]
1451+
serialized_size = cls.subtype.serial_size()
1452+
if serialized_size is not None:
1453+
expected_byte_size = serialized_size * cls.vector_size
1454+
if len(byts) != expected_byte_size:
1455+
raise ValueError(
1456+
"Expected vector of type {0} and dimension {1} to have serialized size {2}; observed serialized size of {3} instead"\
1457+
.format(cls.subtype.typename, cls.vector_size, expected_byte_size, len(byts)))
1458+
indexes = (serialized_size * x for x in range(0, cls.vector_size))
1459+
return [cls.subtype.deserialize(byts[idx:idx + serialized_size], protocol_version) for idx in indexes]
1460+
1461+
idx = 0
1462+
rv = []
1463+
while (len(rv) < cls.vector_size):
1464+
try:
1465+
size, bytes_read = uvint_unpack(byts[idx:])
1466+
idx += bytes_read
1467+
rv.append(cls.subtype.deserialize(byts[idx:idx + size], protocol_version))
1468+
idx += size
1469+
except:
1470+
raise ValueError("Error reading additional data during vector deserialization after successfully adding {} elements"\
1471+
.format(len(rv)))
1472+
1473+
# If we have any additional data in the serialized vector treat that as an error as well
1474+
if idx < len(byts):
1475+
raise ValueError("Additional bytes remaining after vector deserialization completed")
1476+
return rv
14261477

14271478
@classmethod
14281479
def serialize(cls, v, protocol_version):
1480+
v_length = len(v)
1481+
if cls.vector_size != v_length:
1482+
raise ValueError(
1483+
"Expected sequence of size {0} for vector of type {1} and dimension {0}, observed sequence of length {2}"\
1484+
.format(cls.vector_size, cls.subtype.typename, v_length))
1485+
1486+
serialized_size = cls.subtype.serial_size()
14291487
buf = io.BytesIO()
14301488
for item in v:
1431-
buf.write(cls.subtype.serialize(item, protocol_version))
1489+
item_bytes = cls.subtype.serialize(item, protocol_version)
1490+
if serialized_size is None:
1491+
buf.write(uvint_pack(len(item_bytes)))
1492+
buf.write(item_bytes)
14321493
return buf.getvalue()
14331494

14341495
@classmethod
14351496
def cql_parameterized_type(cls):
1436-
return "%s<%s, %s>" % (cls.typename, cls.subtype.typename, cls.vector_size)
1497+
return "%s<%s, %s>" % (cls.typename, cls.subtype.cql_parameterized_type(), cls.vector_size)

cassandra/encoder.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
log = logging.getLogger(__name__)
2222

2323
from binascii import hexlify
24+
from decimal import Decimal
2425
import calendar
2526
import datetime
2627
import math
@@ -59,6 +60,7 @@ class Encoder(object):
5960
def __init__(self):
6061
self.mapping = {
6162
float: self.cql_encode_float,
63+
Decimal: self.cql_encode_decimal,
6264
bytearray: self.cql_encode_bytes,
6365
str: self.cql_encode_str,
6466
int: self.cql_encode_object,
@@ -217,3 +219,6 @@ def cql_encode_ipaddress(self, val):
217219
is suitable for ``inet`` type columns.
218220
"""
219221
return "'%s'" % val.compressed
222+
223+
def cql_encode_decimal(self, val):
224+
return self.cql_encode_float(float(val))

cassandra/marshal.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,6 @@ def vints_unpack(term): # noqa
111111

112112
return tuple(values)
113113

114-
115114
def vints_pack(values):
116115
revbytes = bytearray()
117116
values = [int(v) for v in values[::-1]]
@@ -143,3 +142,48 @@ def vints_pack(values):
143142

144143
revbytes.reverse()
145144
return bytes(revbytes)
145+
146+
def uvint_unpack(bytes):
147+
first_byte = bytes[0]
148+
149+
if (first_byte & 128) == 0:
150+
return (first_byte,1)
151+
152+
num_extra_bytes = 8 - (~first_byte & 0xff).bit_length()
153+
rv = first_byte & (0xff >> num_extra_bytes)
154+
for idx in range(1,num_extra_bytes + 1):
155+
new_byte = bytes[idx]
156+
rv <<= 8
157+
rv |= new_byte & 0xff
158+
159+
return (rv, num_extra_bytes + 1)
160+
161+
def uvint_pack(val):
162+
rv = bytearray()
163+
if val < 128:
164+
rv.append(val)
165+
else:
166+
v = val
167+
num_extra_bytes = 0
168+
num_bits = v.bit_length()
169+
# We need to reserve (num_extra_bytes+1) bits in the first byte
170+
# ie. with 1 extra byte, the first byte needs to be something like '10XXXXXX' # 2 bits reserved
171+
# ie. with 8 extra bytes, the first byte needs to be '11111111' # 8 bits reserved
172+
reserved_bits = num_extra_bytes + 1
173+
while num_bits > (8-(reserved_bits)):
174+
num_extra_bytes += 1
175+
num_bits -= 8
176+
reserved_bits = min(num_extra_bytes + 1, 8)
177+
rv.append(v & 0xff)
178+
v >>= 8
179+
180+
if num_extra_bytes > 8:
181+
raise ValueError('Value %d is too big and cannot be encoded as vint' % val)
182+
183+
# We can now store the last bits in the first byte
184+
n = 8 - num_extra_bytes
185+
v |= (0xff >> n << n)
186+
rv.append(abs(v))
187+
188+
rv.reverse()
189+
return bytes(rv)

tests/integration/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -330,9 +330,10 @@ def _id_and_mark(f):
330330
greaterthanorequalcass36 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.6'), 'Cassandra version 3.6 or greater required')
331331
greaterthanorequalcass3_10 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.10'), 'Cassandra version 3.10 or greater required')
332332
greaterthanorequalcass3_11 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.11'), 'Cassandra version 3.11 or greater required')
333-
greaterthanorequalcass40 = unittest.skipUnless(CASSANDRA_VERSION >= Version('4.0-a'), 'Cassandra version 4.0 or greater required')
334-
lessthanorequalcass40 = unittest.skipUnless(CASSANDRA_VERSION <= Version('4.0-a'), 'Cassandra version less or equal to 4.0 required')
335-
lessthancass40 = unittest.skipUnless(CASSANDRA_VERSION < Version('4.0-a'), 'Cassandra version less than 4.0 required')
333+
greaterthanorequalcass40 = unittest.skipUnless(CASSANDRA_VERSION >= Version('4.0'), 'Cassandra version 4.0 or greater required')
334+
greaterthanorequalcass50 = unittest.skipUnless(CASSANDRA_VERSION >= Version('5.0-beta'), 'Cassandra version 5.0 or greater required')
335+
lessthanorequalcass40 = unittest.skipUnless(CASSANDRA_VERSION <= Version('4.0'), 'Cassandra version less or equal to 4.0 required')
336+
lessthancass40 = unittest.skipUnless(CASSANDRA_VERSION < Version('4.0'), 'Cassandra version less than 4.0 required')
336337
lessthancass30 = unittest.skipUnless(CASSANDRA_VERSION < Version('3.0'), 'Cassandra version less then 3.0 required')
337338
greaterthanorequaldse68 = unittest.skipUnless(DSE_VERSION and DSE_VERSION >= Version('6.8'), "DSE 6.8 or greater required for this test")
338339
greaterthanorequaldse67 = unittest.skipUnless(DSE_VERSION and DSE_VERSION >= Version('6.7'), "DSE 6.7 or greater required for this test")

0 commit comments

Comments
 (0)
0