48
48
int32_pack , int32_unpack , int64_pack , int64_unpack ,
49
49
float_pack , float_unpack , double_pack , double_unpack ,
50
50
varint_pack , varint_unpack , point_be , point_le ,
51
- vints_pack , vints_unpack )
52
- from cassandra import util , VectorDeserializationFailure
51
+ vints_pack , vints_unpack , uvint_unpack , uvint_pack )
52
+ from cassandra import util
53
53
54
54
_little_endian_flag = 1 # we always serialize LE
55
55
import ipaddress
@@ -392,6 +392,9 @@ def cass_parameterized_type(cls, full=False):
392
392
"""
393
393
return cls .cass_parameterized_type_with (cls .subtypes , full = full )
394
394
395
@classmethod
def serial_size(cls):
    """Return the fixed serialized size of this type in bytes, if any.

    The base implementation reports ``None``, meaning the type has no
    fixed size; fixed-width subclasses override this with a byte count.
    """
    return None
395
398
396
399
# it's initially named with a _ to avoid registering it as a real type, but
397
400
# client programs may want to use the name still for isinstance(), etc
@@ -457,10 +460,12 @@ def serialize(uuid, protocol_version):
457
460
except AttributeError :
458
461
raise TypeError ("Got a non-UUID object for a UUID value" )
459
462
463
@classmethod
def serial_size(cls):
    """Return the fixed serialized size of this type: 16 bytes."""
    return 16
460
466
461
467
class BooleanType (_CassandraType ):
462
468
typename = 'boolean'
463
- serial_size = 1
464
469
465
470
@staticmethod
466
471
def deserialize (byts , protocol_version ):
@@ -470,6 +475,10 @@ def deserialize(byts, protocol_version):
470
475
def serialize (truth , protocol_version ):
471
476
return int8_pack (truth )
472
477
478
@classmethod
def serial_size(cls):
    """Return the fixed serialized size of this type: 1 byte."""
    return 1
481
+
473
482
class ByteType (_CassandraType ):
474
483
typename = 'tinyint'
475
484
@@ -500,7 +509,6 @@ def serialize(var, protocol_version):
500
509
501
510
class FloatType (_CassandraType ):
502
511
typename = 'float'
503
- serial_size = 4
504
512
505
513
@staticmethod
506
514
def deserialize (byts , protocol_version ):
@@ -510,10 +518,12 @@ def deserialize(byts, protocol_version):
510
518
def serialize (byts , protocol_version ):
511
519
return float_pack (byts )
512
520
521
@classmethod
def serial_size(cls):
    """Return the fixed serialized size of this type: 4 bytes."""
    return 4
513
524
514
525
class DoubleType (_CassandraType ):
515
526
typename = 'double'
516
- serial_size = 8
517
527
518
528
@staticmethod
519
529
def deserialize (byts , protocol_version ):
@@ -523,10 +533,12 @@ def deserialize(byts, protocol_version):
523
533
def serialize (byts , protocol_version ):
524
534
return double_pack (byts )
525
535
536
@classmethod
def serial_size(cls):
    """Return the fixed serialized size of this type: 8 bytes."""
    return 8
526
539
527
540
class LongType (_CassandraType ):
528
541
typename = 'bigint'
529
- serial_size = 8
530
542
531
543
@staticmethod
532
544
def deserialize (byts , protocol_version ):
@@ -536,10 +548,12 @@ def deserialize(byts, protocol_version):
536
548
def serialize (byts , protocol_version ):
537
549
return int64_pack (byts )
538
550
551
@classmethod
def serial_size(cls):
    """Return the fixed serialized size of this type: 8 bytes."""
    return 8
539
554
540
555
class Int32Type (_CassandraType ):
541
556
typename = 'int'
542
- serial_size = 4
543
557
544
558
@staticmethod
545
559
def deserialize (byts , protocol_version ):
@@ -549,6 +563,9 @@ def deserialize(byts, protocol_version):
549
563
def serialize (byts , protocol_version ):
550
564
return int32_pack (byts )
551
565
566
@classmethod
def serial_size(cls):
    """Return the fixed serialized size of this type: 4 bytes."""
    return 4
552
569
553
570
class IntegerType (_CassandraType ):
554
571
typename = 'varint'
@@ -645,14 +662,16 @@ def serialize(v, protocol_version):
645
662
646
663
return int64_pack (int (timestamp ))
647
664
665
@classmethod
def serial_size(cls):
    """Return the fixed serialized size of this type: 8 bytes."""
    return 8
648
668
649
669
class TimestampType(DateType):
    """Subclass of ``DateType`` that adds no attributes or methods of its own."""
651
671
652
672
653
673
class TimeUUIDType (DateType ):
654
674
typename = 'timeuuid'
655
- serial_size = 16
656
675
657
676
def my_timestamp (self ):
658
677
return util .unix_time_from_uuid1 (self .val )
@@ -668,6 +687,9 @@ def serialize(timeuuid, protocol_version):
668
687
except AttributeError :
669
688
raise TypeError ("Got a non-UUID object for a UUID value" )
670
689
690
@classmethod
def serial_size(cls):
    """Return the fixed serialized size of this type: 16 bytes."""
    return 16
671
693
672
694
class SimpleDateType (_CassandraType ):
673
695
typename = 'date'
@@ -699,7 +721,6 @@ def serialize(val, protocol_version):
699
721
700
722
class ShortType (_CassandraType ):
701
723
typename = 'smallint'
702
- serial_size = 2
703
724
704
725
@staticmethod
705
726
def deserialize (byts , protocol_version ):
@@ -709,10 +730,14 @@ def deserialize(byts, protocol_version):
709
730
def serialize (byts , protocol_version ):
710
731
return int16_pack (byts )
711
732
712
-
713
733
class TimeType (_CassandraType ):
714
734
typename = 'time'
715
- serial_size = 8
735
+ # Time should be a fixed-size, 8-byte type, but the Cassandra 5.0 code marks it as
736
+ # variable-size, and we have to match what the server expects because the server
737
+ # uses that specification to encode data of this type.
738
+ #@classmethod
739
+ #def serial_size(cls):
740
+ # return 8
716
741
717
742
@staticmethod
718
743
def deserialize (byts , protocol_version ):
@@ -1409,6 +1434,11 @@ class VectorType(_CassandraType):
1409
1434
vector_size = 0
1410
1435
subtype = None
1411
1436
1437
@classmethod
def serial_size(cls):
    """Return the total serialized size of the vector in bytes, if fixed.

    The result is ``vector_size`` times the subtype's own ``serial_size()``.
    When the subtype reports ``None`` (no fixed size), the vector has no
    fixed size either and ``None`` is returned.
    """
    element_size = cls.subtype.serial_size()
    if element_size is None:
        return None
    return cls.vector_size * element_size
1441
+
1412
1442
@classmethod
1413
1443
def apply_parameters (cls , params , names ):
1414
1444
assert len (params ) == 2
@@ -1418,19 +1448,50 @@ def apply_parameters(cls, params, names):
1418
1448
1419
1449
@classmethod
def deserialize(cls, byts, protocol_version):
    """Decode a serialized vector into a list of ``vector_size`` elements.

    If the subtype reports a fixed ``serial_size()``, the elements are read
    as consecutive slices of that width and ``byts`` must be exactly
    ``serial_size * vector_size`` bytes long. Otherwise each element is
    preceded by a length read with ``uvint_unpack``, and the element is the
    following ``size`` bytes.

    :param byts: the serialized vector.
    :param protocol_version: passed through to the subtype's ``deserialize``.
    :returns: a list of deserialized elements.
    :raises ValueError: if the input length does not match the expected
        fixed size, if an element cannot be read (including an element whose
        declared length runs past the end of the input), or if bytes remain
        after ``vector_size`` elements have been read.
    """
    serialized_size = cls.subtype.serial_size()
    if serialized_size is not None:
        expected_byte_size = serialized_size * cls.vector_size
        if len(byts) != expected_byte_size:
            raise ValueError(
                "Expected vector of type {0} and dimension {1} to have serialized size {2}; observed serialized size of {3} instead" \
                .format(cls.subtype.typename, cls.vector_size, expected_byte_size, len(byts)))
        indexes = (serialized_size * x for x in range(0, cls.vector_size))
        return [cls.subtype.deserialize(byts[idx:idx + serialized_size], protocol_version) for idx in indexes]

    idx = 0
    rv = []
    total_size = len(byts)
    while len(rv) < cls.vector_size:
        try:
            size, bytes_read = uvint_unpack(byts[idx:])
            idx += bytes_read
            # Slicing past the end of the buffer does not raise, so a
            # truncated element must be detected explicitly.
            if idx + size > total_size:
                raise ValueError("Element length exceeds the remaining serialized data")
            rv.append(cls.subtype.deserialize(byts[idx:idx + size], protocol_version))
            idx += size
        except Exception:
            # Catch Exception rather than using a bare except so that
            # KeyboardInterrupt and SystemExit still propagate.
            raise ValueError("Error reading additional data during vector deserialization after successfully adding {} elements" \
                .format(len(rv)))

    # If we have any additional data in the serialized vector treat that as an error as well
    if idx < total_size:
        raise ValueError("Additional bytes remaining after vector deserialization completed")
    return rv
1426
1477
1427
1478
@classmethod
def serialize(cls, v, protocol_version):
    """Encode the sequence ``v`` as a serialized vector.

    Each item is serialized with the subtype's ``serialize``. When the
    subtype has no fixed ``serial_size()``, every item is preceded by its
    byte length encoded with ``uvint_pack``.

    :param v: a sized sequence with exactly ``vector_size`` items.
    :param protocol_version: passed through to the subtype's ``serialize``.
    :returns: the concatenated bytes of all items.
    :raises ValueError: if ``len(v)`` differs from ``vector_size``.
    """
    v_length = len(v)
    if cls.vector_size != v_length:
        raise ValueError(
            "Expected sequence of size {0} for vector of type {1} and dimension {0}, observed sequence of length {2}" \
            .format(cls.vector_size, cls.subtype.typename, v_length))

    # Only variable-size subtypes need a per-item length prefix.
    needs_length_prefix = cls.subtype.serial_size() is None
    out = io.BytesIO()
    for element in v:
        payload = cls.subtype.serialize(element, protocol_version)
        if needs_length_prefix:
            out.write(uvint_pack(len(payload)))
        out.write(payload)
    return out.getvalue()
1433
1494
1434
1495
@classmethod
def cql_parameterized_type(cls):
    """Return the type rendered as ``typename<subtype, vector_size>``.

    The subtype is rendered through its own ``cql_parameterized_type()``.
    """
    element_cql = cls.subtype.cql_parameterized_type()
    return "%s<%s, %s>" % (cls.typename, element_cql, cls.vector_size)
0 commit comments