8000 Added INTERVAL type support · psqlpy-python/psqlpy@9af7644 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9af7644

Browse files
committed
Added INTERVAL type support
Signed-off-by: chandr-andr (Kiselev Aleksandr) <chandr@chandr.net>
1 parent b182b3e commit 9af7644

File tree

9 files changed

+141
-7
lines changed

9 files changed

+141
-7
lines changed

Cargo.lock

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,4 @@ openssl = { version = "0.10.64", features = ["vendored"] }
5151
itertools = "0.12.1"
5252
openssl-src = "300.2.2"
5353
openssl-sys = "0.9.102"
54+
pg_interval = { git = "https://github.com/chandr-andr/rust-postgres-interval.git", branch = "psqlpy" }

docs/usage/types/array_types.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ For type safety and better performance we have predefined array types.
3131
| LineArray | LINE ARRAY |
3232
| LsegArray | LSEG ARRAY |
3333
| CircleArray | CIRCLE ARRAY |
34+
| IntervalArray | INTERVAL ARRAY |
3435

3536
### Example:
3637

docs/usage/types/supported_types.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ Here you can find all types supported by `PSQLPy`. If PSQLPy isn't `-`, you can
2525
| datetime.time | - | TIME |
2626
| datetime.datetime | - | TIMESTAMP |
2727
| datetime.datetime | - | TIMESTAMPTZ |
28+
| datetime.timedelta | - | INTERVAL |
2829
| UUID | - | UUID |
2930
| dict | - | JSONB |
3031
| dict | PyJSONB | JSONB |

python/psqlpy/_internal/extra_types.pyi

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import typing
2-
from datetime import date, datetime, time
2+
from datetime import date, datetime, time, timedelta
33
from decimal import Decimal
44
from ipaddress import IPv4Address, IPv6Address
55
from uuid import UUID
@@ -753,5 +753,24 @@ class CircleArray:
753753
"""Create new instance of CircleArray.
754754
755755
### Parameters:
756-
- `inner`: inner value, sequence of PyLineSegment values.
756+
- `inner`: inner value, sequence of PyCircle values.
757+
"""
758+
759+
class IntervalArray:
760+
"""Represent INTERVAL ARRAY in PostgreSQL."""
761+
762+
def __init__(
763+
self: Self,
764+
inner: typing.Sequence[
765+
typing.Union[
766+
timedelta,
767+
typing.Sequence[timedelta],
768+
typing.Any,
769+
],
770+
],
771+
) -> None:
772+
"""Create new instance of IntervalArray.
773+
774+
### Parameters:
775+
- `inner`: inner value, sequence of timedelta values.
757776
"""

python/psqlpy/extra_types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Int32Array,
1515
Int64Array,
1616
Integer,
17+
IntervalArray,
1718
IpAddressArray,
1819
JSONArray,
1920
JSONBArray,
@@ -96,4 +97,5 @@
9697
"LineArray",
9798
"LsegArray",
9899
"CircleArray",
100+
"IntervalArray",
99101
]

python/tests/test_value_converter.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
Int32Array,
2828
Int64Array,
2929
Integer,
30+
IntervalArray,
3031
IpAddressArray,
3132
JSONArray,
3233
JSONBArray,
@@ -198,6 +199,11 @@ async def test_as_class(
198199
PyCircle([1, 2.8, 3]),
199200
((1.0, 2.8), 3.0),
200201
),
202+
(
203+
"INTERVAL",
204+
datetime.timedelta(days=100, microseconds=100),
205+
datetime.timedelta(days=100, microseconds=100),
206+
),
201207
(
202208
"VARCHAR ARRAY",
203209
["Some String", "Some String"],
@@ -598,6 +604,11 @@ async def test_as_class(
598604
[((5.0, 1.8), 10.0)],
599605
],
600606
),
607+
(
608+
"INTERVAL ARRAY",
609+
[datetime.timedelta(days=100, microseconds=100), datetime.timedelta(days=100, microseconds=100)],
610+
[datetime.timedelta(days=100, microseconds=100), datetime.timedelta(days=100, microseconds=100)],
611+
),
601612
),
602613
)
603614
async def test_deserialization_simple_into_python(
@@ -1501,6 +1512,13 @@ async def test_empty_array(
15011512
[((5.0, 1.8), 10.0)],
15021513
],
15031514
),
1515+
(
1516+
"INTERVAL ARRAY",
1517+
IntervalArray(
1518+
[[datetime.timedelta(days=100, microseconds=100)], [datetime.timedelta(days=100, microseconds=100)]],
1519+
),
1520+
[[datetime.timedelta(days=100, microseconds=100)], [datetime.timedelta(days=100, microseconds=100)]],
1521+
),
15041522
),
15051523
)
15061524
async def test_array_types(

src/extra_types.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,7 @@ build_array_type!(PathArray, PythonDTO::PyPathArray);
361361
build_array_type!(LineArray, PythonDTO::PyLineArray);
362362
build_array_type!(LsegArray, PythonDTO::PyLsegArray);
363363
build_array_type!(CircleArray, PythonDTO::PyCircleArray);
364+
build_array_type!(IntervalArray, PythonDTO::PyIntervalArray);
364365

365366
#[allow(clippy::module_name_repetitions)]
366367
#[allow(clippy::missing_errors_doc)]
@@ -410,5 +411,6 @@ pub fn extra_types_module(_py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyRes
410411
pymod.add_class::<LineArray>()?;
411412
pymod.add_class::<LsegArray>()?;
412413
pymod.add_class::<CircleArray>()?;
414+
pymod.add_class::<IntervalArray>()?;
413415
Ok(())
414416
}

src/value_converter.rs

Lines changed: 84 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use chrono::{self, DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime};
22
use geo_types::{coord, Coord, Line as LineSegment, LineString, Point, Rect};
33
use itertools::Itertools;
44
use macaddr::{MacAddr6, MacAddr8};
5+
use pg_interval::Interval;
56
use postgres_types::{Field, FromSql, Kind, ToSql};
67
use rust_decimal::Decimal;
78
use serde_json::{json, Map, Value};
@@ -13,8 +14,8 @@ use postgres_protocol::types;
1314
use pyo3::{
1415
sync::GILOnceCell,
1516
types::{
16-
PyAnyMethods, PyBool, PyBytes, PyDate, PyDateTime, PyDict, PyDictMethods, PyFloat, PyInt,
17-
PyIterator, PyList, PyListMethods, PySequence, PySet, PyString, PyTime, PyTuple, PyType,
17+
PyAnyMethods, PyBool, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyDictMethods, PyFloat,
18+
PyInt, PyList, PyListMethods, PySequence, PySet, PyString, PyTime, PyTuple, PyType,
1819
PyTypeMethods,
1920
},
2021
Bound, FromPyObject, IntoPy, Py, PyAny, PyObject, PyResult, Python, ToPyObject,
@@ -35,6 +36,7 @@ use crate::{
3536
use postgres_array::{array::Array, Dimension};
3637

3738
static DECIMAL_CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();
39+
static TIMEDELTA_CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();
3840

3941
pub type QueryParameter = (dyn ToSql + Sync);
4042

@@ -50,6 +52,18 @@ fn get_decimal_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
5052
.map(|ty| ty.bind(py))
5153
}
5254

55+
fn get_timedelta_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
56+
TIMEDELTA_CLS
57+
.get_or_try_init(py, || {
58+
let type_object = py
59+
.import_bound("datetime")?
60+
.getattr("timedelta")?
61+
.downcast_into()?;
62+
Ok(type_object.unbind())
63+
})
64+
.map(|ty| ty.bind(py))
65+
}
66+
5367
/// Struct for Uuid.
5468
///
5569
/// We use custom struct because we need to implement external traits
@@ -138,13 +152,43 @@ impl<'a> FromSql<'a> for InnerDecimal {
138152
}
139153
}
140154

155+
struct InnerInterval(Interval);
156+
157+
impl ToPyObject for InnerInterval {
158+
fn to_object(&self, py: Python<'_>) -> PyObject {
159+
let td_cls = get_timedelta_cls(py).expect("failed to load datetime.timedelta");
160+
let pydict = PyDict::new_bound(py);
161+
let months = self.0.months * 30;
162+
let _ = pydict.set_item("days", self.0.days + months);
163+
let _ = pydict.set_item("microseconds", self.0.microseconds);
164+
let ret = td_cls
165+
.call((), Some(&pydict))
166+
.expect("failed to call datetime.timedelta(days=<>, microseconds=<>)");
167+
ret.to_object(py)
168+
}
169+
}
170+
171+
impl<'a> FromSql<'a> for InnerInterval {
172+
fn from_sql(
173+
ty: &Type,
174+
raw: &'a [u8],
175+
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
176+
Ok(InnerInterval(<Interval as FromSql>::from_sql(ty, raw)?))
177+
}
178+
179+
fn accepts(_ty: &Type) -> bool {
180+
true
181+
}
182+
}
183+
141184
/// Additional type for types come from Python.
142185
///
143186
/// It's necessary because we need to pass this
144187
/// enum into `to_sql` method of `ToSql` trait from
145188
/// `postgres` crate.
146189
#[derive(Debug, Clone, PartialEq)]
147190
pub enum PythonDTO {
191+
// Primitive
148192
PyNone,
149193
PyBytes(Vec<u8>),
150194
PyBool(bool),
@@ -164,6 +208,7 @@ pub enum PythonDTO {
164208
PyTime(NaiveTime),
165209
PyDateTime(NaiveDateTime),
166210
PyDateTimeTz(DateTime<FixedOffset>),
211+
PyInterval(Interval),
167212
PyIpAddress(IpAddr),
168213
PyList(Vec<PythonDTO>),
169214
PyArray(Array<PythonDTO>),
@@ -180,6 +225,7 @@ pub enum PythonDTO {
180225
PyLine(Line),
181226
PyLineSegment(LineSegment),
182227
PyCircle(Circle),
228+
// Arrays
183229
PyBoolArray(Array<PythonDTO>),
184230
PyUuidArray(Array<PythonDTO>),
185231
PyVarCharArray(Array<PythonDTO>),
@@ -206,6 +252,7 @@ pub enum PythonDTO {
206252
PyLineArray(Array<PythonDTO>),
207253
PyLsegArray(Array<PythonDTO>),
208254
PyCircleArray(Array<PythonDTO>),
255+
PyIntervalArray(Array<PythonDTO>),
209256
}
210257

211258
impl ToPyObject for PythonDTO {
@@ -267,6 +314,7 @@ impl PythonDTO {
267314
PythonDTO::PyLine(_) => Ok(tokio_postgres::types::Type::LINE_ARRAY),
268315
PythonDTO::PyLineSegment(_) => Ok(tokio_postgres::types::Type::LSEG_ARRAY),
269316
PythonDTO::PyCircle(_) => Ok(tokio_postgres::types::Type::CIRCLE_ARRAY),
317+
PythonDTO::PyInterval(_) => Ok(tokio_postgres::types::Type::INTERVAL_ARRAY),
270318
_ => Err(RustPSQLDriverError::PyToRustValueConversionError(
271319
"Can't process array type, your type doesn't have support yet".into(),
272320
)),
@@ -385,6 +433,9 @@ impl ToSql for PythonDTO {
385433
PythonDTO::PyDateTimeTz(pydatetime_tz) => {
386434
<&DateTime<FixedOffset> as ToSql>::to_sql(&pydatetime_tz, ty, out)?;
387435
}
436+
PythonDTO::PyInterval(pyinterval) => {
437+
<&Interval as ToSql>::to_sql(&pyinterval, ty, out)?;
438+
}
388439
PythonDTO::PyIpAddress(pyidaddress) => {
389440
<&IpAddr as ToSql>::to_sql(&pyidaddress, ty, out)?;
390441
}
@@ -525,6 +576,9 @@ impl ToSql for PythonDTO {
525576
PythonDTO::PyCircleArray(array) => {
526577
array.to_sql(&Type::CIRCLE_ARRAY, out)?;
527578
}
579+
PythonDTO::PyIntervalArray(array) => {
580+
array.to_sql(&Type::INTERVAL_ARRAY, out)?;
581+
}
528582
}
529583

530584
if return_is_null_true {
@@ -787,6 +841,16 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult<
787841
return Ok(PythonDTO::PyTime(parameter.extract::<NaiveTime>()?));
788842
}
789843

844+
if parameter.is_instance_of::<PyDelta>() {
845+
let duration = parameter.extract::<chrono::Duration>()?;
846+
if let Some(interval) = Interval::from_duration(duration) {
847+
return Ok(PythonDTO::PyInterval(interval));
848+
}
849+
return Err(RustPSQLDriverError::PyToRustValueConversionError(format!(
850+
"Cannot convert timedelta from Python to inner Rust type.",
851+
)));
852+
}
853+
790854
if parameter.is_instance_of::<PyList>() | parameter.is_instance_of::<PyTuple>() {
791855
return Ok(PythonDTO::PyArray(py_sequence_into_postgres_array(
792856
parameter,
@@ -1052,6 +1116,12 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult<
10521116
._convert_to_python_dto();
10531117
}
10541118

1119+
if parameter.is_instance_of::<extra_types::IntervalArray>() {
1120+
return parameter
1121+
.extract::<extra_types::IntervalArray>()?
1122+
._convert_to_python_dto();
1123+
}
1124+
10551125
if let Ok(id_address) = parameter.extract::<IpAddr>() {
10561126
return Ok(PythonDTO::PyIpAddress(id_address));
10571127
}
@@ -1065,9 +1135,6 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult<
10651135
}
10661136
}
10671137

1068-
let a = parameter.downcast::<PyIterator>();
1069-
println!("{:?}", a.iter());
1070-
10711138
Err(RustPSQLDriverError::PyToRustValueConversionError(format!(
10721139
"Can not covert you type {parameter} into inner one",
10731140
)))
@@ -1387,6 +1454,13 @@ fn postgres_bytes_to_py(
13871454
None => Ok(py.None().to_object(py)),
13881455
}
13891456
}
1457+
Type::INTERVAL => {
1458+
let interval = _composite_field_postgres_to_py::<Option<Interval>>(type_, buf, is_simple)?;
1459+
if let Some(interval) = interval {
1460+
return Ok(InnerInterval(interval).to_object(py));
1461+
}
1462+
return Ok(py.None())
1463+
}
13901464
// ---------- Array Text Types ----------
13911465
Type::BOOL_ARRAY => Ok(postgres_array_to_py(py, _composite_field_postgres_to_py::<Option<Array<bool>>>(
13921466
type_, buf, is_simple,
@@ -1505,6 +1579,11 @@ fn postgres_bytes_to_py(
15051579

15061580
Ok(postgres_array_to_py(py, circle_array_).to_object(py))
15071581
}
1582+
Type::INTERVAL_ARRAY => {
1583+
let interval_array_ = _composite_field_postgres_to_py::<Option<Array<InnerInterval>>>(type_, buf, is_simple)?;
1584+
1585+
Ok(postgres_array_to_py(py, interval_array_).to_object(py))
1586+
}
15081587
_ => Err(RustPSQLDriverError::RustToPyValueConversionError(
15091588
format!("Cannot convert {type_} into Python type, please look at the custom_decoders functionality.")
15101589
)),

0 commit comments

Comments
 (0)
0