8000 feat: add decimal/numeric support · googleapis/python-spanner-django@61e63f7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 61e63f7

Browse files
committed
feat: add decimal/numeric support
1 parent ad8e43e commit 61e63f7

File tree

11 files changed

+418
-41
lines changed

11 files changed

+418
-41
lines changed

django_spanner/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
3434
"CharField": "STRING(%(max_length)s)",
3535
"DateField": "DATE",
3636
"DateTimeField": "TIMESTAMP",
37-
"DecimalField": "FLOAT64",
37+
"DecimalField": "NUMERIC",
3838
"DurationField": "INT64",
3939
"EmailField": "STRING(%(max_length)s)",
4040
"FileField": "STRING(%(max_length)s)",

django_spanner/introspection.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
2424
TypeCode.INT64: "IntegerField",
2525
TypeCode.STRING: "CharField",
2626
TypeCode.TIMESTAMP: "DateTimeField",
27+
TypeCode.NUMERIC: "DecimalField",
2728
}
2829

2930
def get_field_type(self, data_type, description):

django_spanner/lookups.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# license that can be found in the LICENSE file or at
55
# https://developers.google.com/open-source/licenses/bsd
66

7-
from django.db.models import DecimalField
87
from django.db.models.lookups import (
98
Contains,
109
EndsWith,
@@ -233,13 +232,8 @@ def cast_param_to_float(self, compiler, connection):
233232
"""
234233
sql, params = self.as_sql(compiler, connection)
235234
if params:
236-
# Cast to DecimaField lookup values to float because
237-
# google.cloud.spanner_v1._helpers._make_value_pb() doesn't serialize
238-
# decimal.Decimal.
239-
if isinstance(self.lhs.output_field, DecimalField):
240-
params[0] = float(params[0])
241235
# Cast remote field lookups that must be integer but come in as string.
242-
elif hasattr(self.lhs.output_field, "get_path_info"):
236+
if hasattr(self.lhs.output_field, "get_path_info"):
243237
for i, field in enumerate(
244238
self.lhs.output_field.get_path_info()[-1].target_fields
245239
):

django_spanner/operations.py

Lines changed: 7 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import re
99
from base64 import b64decode
1010
from datetime import datetime, time
11-
from decimal import Decimal
1211
from uuid import UUID
1312

1413
from django.conf import settings
@@ -190,10 +189,11 @@ def adapt_decimalfield_value(
190189
self, value, max_digits=None, decimal_places=None
191190
):
192191
"""
193-
Convert value from decimal.Decimal into float, for a direct mapping
194-
and correct serialization with RPCs to Cloud Spanner.
192+
Convert value from decimal.Decimal to spanner compatible value.
193+
Since spanner supports Numeric storage of decimal and python spanner
194+
takes care of the conversion so this is a no-op method call.
195195
196-
:type value: :class:`~google.cloud.spanner_v1.types.Numeric`
196+
:type value: :class:`decimal.Decimal`
197197
:param value: A decimal field value.
198198
199199
:type max_digits: int
@@ -203,12 +203,10 @@ def adapt_decimalfield_value(
203203
:param decimal_places: (Optional) The number of decimal places to store
204204
with the number.
205205
206-
:rtype: float
207-
:returns: Formatted value.
206+
:rtype: decimal.Decimal
207+
:returns: decimal value.
208208
"""
209-
if value is None:
210-
return None
211-
return float(value)
209+
return value
212210

213211
def adapt_timefield_value(self, value):
214212
"""
@@ -244,8 +242,6 @@ def get_db_converters(self, expression):
244242
internal_type = expression.output_field.get_internal_type()
245243
if internal_type == "DateTimeField":
246244
converters.append(self.convert_datetimefield_value)
247-
elif internal_type == "DecimalField":
248-
converters.append(self.convert_decimalfield_value)
249245
elif internal_type == "TimeField":
250246
converters.append(self.convert_timefield_value)
251247
elif internal_type == "BinaryField":
@@ -311,26 +307,6 @@ def convert_datetimefield_value(self, value, expression, connection):
311307
else dt
312308
)
313309

314-
def convert_decimalfield_value(self, value, expression, connection):
315-
"""Convert Spanner DecimalField value for Django.
316-
317-
:type value: float
318-
:param value: A decimal field.
319-
320-
:type expression: :class:`django.db.models.expressions.BaseExpression`
321-
:param expression: A query expression.
322-
323-
:type connection: :class:`~google.cloud.cpanner_dbapi.connection.Connection`
324-
:param connection: Reference to a Spanner database connection.
325-
326-
:rtype: :class:`Decimal`
327-
:returns: A converted decimal field.
328-
"""
329-
if value is None:
330-
return value
331-
# Cloud Spanner returns a float.
332-
return Decimal(str(value))
333-
334310
def convert_timefield_value(self, value, expression, connection):
335311
"""Convert Spanner TimeField value for Django.
336312

noxfile.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from __future__ import absolute_import
1111

1212
import os
13+
import pathlib
1314
import shutil
1415

1516
import nox
@@ -25,7 +26,9 @@
2526

2627
DEFAULT_PYTHON_VERSION = "3.8"
2728
SYSTEM_TEST_PYTHON_VERSIONS = ["3.8"]
28-
UNIT_TEST_PYTHON_VERSIONS = ["3.6", "3.7", "3.8"]
29+
UNIT_TEST_PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"]
30+
31+
CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute()
2932

3033

3134
@nox.session(python=DEFAULT_PYTHON_VERSION)
@@ -81,7 +84,7 @@ def default(session):
8184
"--cov-report=",
8285
"--cov-fail-under=20",
8386
os.path.join("tests", "unit"),
84-
*session.posargs
87+
*session.posargs,
8588
)
8689

8790

@@ -91,6 +94,56 @@ def unit(session):
9194
default(session)
9295

9396

97+
@nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS)
98+
def system(session):
99+
"""Run the system test suite."""
100+
constraints_path = str(
101+
CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt"
102+
)
103+
system_test_path = os.path.join("tests", "system.py")
104+
system_test_folder_path = os.path.join("tests", "system")
105+
106+
# Check the value of `RUN_SYSTEM_TESTS` env var. It defaults to true.
107+
if os.environ.get("RUN_SYSTEM_TESTS", "true") == "false":
108+
session.skip("RUN_SYSTEM_TESTS is set to false, skipping")
109+
# Sanity check: Only run tests if the environment variable is set.
110+
if not os.environ.get(
111+
"GOOGLE_APPLICATION_CREDENTIALS", ""
112+
) and not os.environ.get("SPANNER_EMULATOR_HOST", ""):
113+
session.skip(
114+
"Credentials or emulator host must be set via environment variable"
115+
)
116+
117+
system_test_exists = os.path.exists(system_test_path)
118+
system_test_folder_exists = os.path.exists(system_test_folder_path)
119+
# Sanity check: only run tests if found.
120+
if not system_test_exists and not system_test_folder_exists:
121+
session.skip("System tests were not found")
122+
123+
# Use pre-release gRPC for system tests.
124+
session.install("--pre", "grpcio")
125+
126+
# Install all test dependencies, then install this package into the
127+
# virtualenv's dist-packages.
128+
session.install(
129+
"django~=2.2",
130+
"mock",
131+
"pytest",
132+
"google-cloud-testutils",
133+
"-c",
134+
constraints_path,
135+
)
136+
session.install("-e", ".[tracing]", "-c", constraints_path)
137+
138+
# Run py.test against the system tests.
139+
if system_test_exists:
140+
session.run("py.test", "--quiet", system_test_path, *session.posargs)
141+
if system_test_folder_exists:
142+
session.run(
143+
"py.test", "--quiet", system_test_folder_path, *session.posargs
144+
)
145+
146+
94147
@nox.session(python=DEFAULT_PYTHON_VERSION)
95148
def cover(session):
96149
"""Run the final coverage report.

tests/system/conftest.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Use of this source code is governed by a BSD-style
4+
# license that can be found in the LICENSE file or at
5+
# https://developers.google.com/open-source/licenses/bsd
6+
7+
import os
8+
import django
9+
from django.conf import settings
10+
11+
# We manually designate which settings we will be using in an environment
12+
# variable. This is similar to what occurs in the `manage.py` file.
13+
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "tests.system.settings")
14+
15+
16+
# `pytest` automatically calls this function once when tests are run.
17+
def pytest_configure():
18+
settings.DEBUG = False
19+
django.setup()

tests/system/django_spanner/__init__.py

Whitespace-only changes.

tests/system/django_spanner/models.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Use of this source code is governed by a BSD-style
4+
# license that can be found in the LICENSE file or at
5+
# https://developers.google.com/open-source/licenses/bsd
6+
7+
"""
8+
Different models used by system tests in django-spanner code.
9+
"""
10+
from django.db import models
11+
12+
13+
class Author(models.Model):
14+
first_name = models.CharField(max_length=20)
15+
last_name = models.CharField(max_length=20)
16+
rating = models.DecimalField()
17+
18+
19+
class Number(models.Model):
20+
num = models.DecimalField()
21+
22+
def __str__(self):
23+
return str(self.num)
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Use of this source code is governed by a BSD-style
4+
# license that can be found in the LICENSE file or at
5+
# https://developers.google.com/open-source/licenses/bsd
6+
7+
from .models import Author, Number
8+
from django.test import TransactionTestCase
9+
from django.db import connection, ProgrammingError
10+
from django.db.utils import IntegrityError
11+
from decimal import Decimal
12+
from tests.system.django_spanner.utils import (
13+
setup_instance,
14+
teardown_instance,
15+
setup_database,
16+
teardown_database,
17+
USE_EMULATOR,
18+
)
19+
20+
21+
class TestDecimal(TransactionTestCase):
22+
@classmethod
23+
def setUpClass(cls):
24+
setup_instance()
25+
setup_database()
26+
with connection.schema_editor() as editor:
27+
# Create the tables
28+
editor.create_model(Author)
29+
editor.create_model(Number)
30+
31+
@classmethod
32+
def tearDownClass(cls):
33+
with connection.schema_editor() as editor:
34+
# delete the table
35+
editor.delete_model(Author)
36+
editor.delete_model(Number)
37+
teardown_database()
38+
teardown_instance()
39+
40+
def rating_transform(self, value):
41+
return value["rating"]
42+
43+
def values_transform(self, value):
44+
return value.num
45+
46+
def assertValuesEqual(
47+
self, queryset, expected_values, transformer, ordered=True
48+
):
49+
self.assertQuerysetEqual(
50+
queryset, expected_values, transformer, ordered
51+
)
52+
53+
def test_insert_and_search_decimal_value(self):
54+
"""
55+
Tests model object creation with Author model.
56+
"""
57+
author_kent = Author(
58+
first_name="Arthur", last_name="Kent", rating=Decimal("4.1"),
59+
)
60+
author_kent.save()
61+
qs1 = Author.objects.filter(rating__gte=3).values("rating")
62+
self.assertValuesEqual(
63+
qs1, [Decimal("4.1")], self.rating_transform,
64+
)
65+
# Delete data from Author table.
66+
Author.objects.all().delete()
67+
68+
def test_decimal_filter(self):
69+
"""
70+
Tests decimal filter query.
71+
"""
72+
# Insert data into Number table.
73+
Number.objects.bulk_create(
74+
Number(num=Decimal(i) / Decimal(10)) for i in range(10)
75+
)
76+
qs1 = Number.objects.filter(num__lte=Decimal(2) / Decimal(10))
77+
self.assertValuesEqual(
78+
qs1,
79+
[Decimal(i) / Decimal(10) for i in range(3)],
80+
self.values_transform,
81+
ordered=False,
82+
)
83+
# Delete data from Number table.
84+
Number.objects.all().delete()
85+
86+
def test_decimal_precision_limit(self):
87+
"""
88+
Tests decimal object precission limit.
89+
"""
90+
num_val = Number(num=Decimal(1) / Decimal(3))
91+
if USE_EMULATOR:
92+
msg = "The NUMERIC type supports 38 digits of precision and 9 digits of scale."
93+
with self.assertRaisesRegex(IntegrityError, msg):
94+
num_val.save()
95+
else:
96+
msg = "400 Invalid value for bind parameter a0: Expected NUMERIC."
97+
with self.assertRaisesRegex(ProgrammingError, msg):
98+
num_val.save()
99+
100+
def test_decimal_update(self):
101+
"""
102+
Tests decimal object update.
103+
"""
104+
author_kent = Author(
105+
first_name="Arthur", last_name="Kent", rating=Decimal("4.1"),
106+
)
107+
author_kent.save()
108+
author_kent.rating = Decimal("4.2")
109+
author_kent.save()
110+
qs1 = Author.objects.filter(rating__gte=Decimal("4.2")).values(
111+
"rating"
112+
)
113+
self.assertValuesEqual(
114+
qs1, [Decimal("4.2")], self.rating_transform,
115+
)
116+
# Delete data from Author table.
117+
Author.objects.all().delete()

0 commit comments

Comments
 (0)
0