From e60aad94d8efb4d43522ebb5ac2531ce3b8aa513 Mon Sep 17 00:00:00 2001 From: "chandr-andr (Kiselev Aleksandr)" Date: Sat, 17 May 2025 02:58:52 +0200 Subject: [PATCH] Added python Portal class and logic --- src/connection/impls.rs | 47 ++++++- src/driver/connection.rs | 69 +++++---- src/driver/portal.rs | 282 +++++++++++++++++++++++++------------ src/lib.rs | 2 + src/statement/statement.rs | 6 +- src/transaction/impls.rs | 35 +++++ src/transaction/mod.rs | 2 + src/transaction/structs.rs | 7 + 8 files changed, 337 insertions(+), 113 deletions(-) create mode 100644 src/transaction/impls.rs create mode 100644 src/transaction/mod.rs create mode 100644 src/transaction/structs.rs diff --git a/src/connection/impls.rs b/src/connection/impls.rs index ee6bab4b..84683edb 100644 --- a/src/connection/impls.rs +++ b/src/connection/impls.rs @@ -1,15 +1,22 @@ +use std::sync::{Arc, RwLock}; + use bytes::Buf; use pyo3::{PyAny, Python}; -use tokio_postgres::{CopyInSink, Row, Statement, ToStatement}; +use tokio_postgres::{CopyInSink, Portal as tp_Portal, Row, Statement, ToStatement}; use crate::{ + driver::portal::Portal, exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, options::{IsolationLevel, ReadVariant}, query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, statement::{statement::PsqlpyStatement, statement_builder::StatementBuilder}, + transaction::structs::PSQLPyTransaction, value_converter::to_python::postgres_to_py, }; +use deadpool_postgres::Transaction as dp_Transaction; +use tokio_postgres::Transaction as tp_Transaction; + use super::{ structs::{PSQLPyConnection, PoolConnection, SingleConnection}, traits::{CloseTransaction, Connection, Cursor, StartTransaction, Transaction}, @@ -516,4 +523,42 @@ impl PSQLPyConnection { } } } + + pub async fn transaction(&mut self) -> PSQLPyResult { + match self { + PSQLPyConnection::PoolConn(conn) => { + let transaction = unsafe { + std::mem::transmute::, dp_Transaction<'static>>( + conn.connection.transaction().await?, + ) + }; + Ok(PSQLPyTransaction::PoolTransaction(transaction)) + } + PSQLPyConnection::SingleConnection(conn) => { + let transaction = unsafe { + std::mem::transmute::, tp_Transaction<'static>>( + conn.connection.transaction().await?, + ) + }; + Ok(PSQLPyTransaction::SingleTransaction(transaction)) + } + } + } + + pub async fn portal( + &mut self, + querystring: String, + parameters: Option>, + ) -> PSQLPyResult<(PSQLPyTransaction, tp_Portal)> { + let statement = StatementBuilder::new(querystring, parameters, self, Some(false)) + .build() + .await?; + + let transaction = self.transaction().await?; + let inner_portal = transaction + .portal(statement.raw_query(), &statement.params()) + .await?; + + Ok((transaction, inner_portal)) + } } diff --git a/src/driver/connection.rs b/src/driver/connection.rs index 9635a836..fa480386 100644 --- a/src/driver/connection.rs +++ b/src/driver/connection.rs @@ -18,7 +18,9 @@ use crate::{ runtime::tokio_runtime, }; -use super::{connection_pool::connect_pool, cursor::Cursor, transaction::Transaction}; +use super::{ + connection_pool::connect_pool, cursor::Cursor, portal::Portal, transaction::Transaction, +}; /// Make new connection pool. /// @@ -396,17 +398,16 @@ impl Connection { read_variant: Option, deferrable: Option, ) -> PSQLPyResult { - if let Some(db_client) = &self.db_client { - return Ok(Transaction::new( - Some(db_client.clone()), - self.pg_config.clone(), - isolation_level, - read_variant, - deferrable, - )); - } - - Err(RustPSQLDriverError::ConnectionClosedError) + let Some(conn) = &self.db_client else { + return Err(RustPSQLDriverError::ConnectionClosedError); + }; + Ok(Transaction::new( + Some(conn.clone()), + self.pg_config.clone(), + isolation_level, + read_variant, + deferrable, + )) } /// Create new cursor object. @@ -428,19 +429,39 @@ impl Connection { scroll: Option, prepared: Option, ) -> PSQLPyResult { - if let Some(db_client) = &self.db_client { - return Ok(Cursor::new( - db_client.clone(), - self.pg_config.clone(), - querystring, - parameters, - fetch_number.unwrap_or(10), - scroll, - prepared, - )); - } + let Some(conn) = &self.db_client else { + return Err(RustPSQLDriverError::ConnectionClosedError); + }; + + Ok(Cursor::new( + conn.clone(), + self.pg_config.clone(), + querystring, + parameters, + fetch_number.unwrap_or(10), + scroll, + prepared, + )) + } - Err(RustPSQLDriverError::ConnectionClosedError) + #[pyo3(signature = ( + querystring, + parameters=None, + fetch_number=None, + ))] + pub fn portal( + &self, + querystring: String, + parameters: Option>, + fetch_number: Option, + ) -> PSQLPyResult { + println!("{:?}", fetch_number); + Ok(Portal::new( + self.db_client.clone(), + querystring, + parameters, + fetch_number, + )) } #[allow(clippy::needless_pass_by_value)] diff --git a/src/driver/portal.rs b/src/driver/portal.rs index 0c280637..f6b4d755 100644 --- a/src/driver/portal.rs +++ b/src/driver/portal.rs @@ -1,87 +1,195 @@ -// use std::sync::Arc; - -// use pyo3::{pyclass, pymethods, Py, PyObject, Python}; -// use tokio_postgres::Portal as tp_Portal; - -// use crate::{ -// exceptions::rust_errors::PSQLPyResult, query_result::PSQLDriverPyQueryResult, -// runtime::rustdriver_future, -// }; - -// use super::inner_transaction::PsqlpyTransaction; - -// #[pyclass] -// pub struct Portal { -// transaction: Arc, -// inner: tp_Portal, -// array_size: i32, -// } - -// impl Portal { -// pub fn new(transaction: Arc, inner: tp_Portal, array_size: i32) -> Self { -// Self { -// transaction, -// inner, -// array_size, -// } -// } - -// async fn query_portal(&self, size: i32) -> PSQLPyResult { -// let result = self.transaction.query_portal(&self.inner, size).await?; -// Ok(PSQLDriverPyQueryResult::new(result)) -// } -// } - -// #[pymethods] -// impl Portal { -// #[getter] -// fn get_array_size(&self) -> i32 { -// self.array_size -// } - -// #[setter] -// fn set_array_size(&mut self, value: i32) { -// self.array_size = value; -// } - -// fn __aiter__(slf: Py) -> Py { -// slf -// } - -// fn __await__(slf: Py) -> Py { -// slf -// } - -// fn __anext__(&self) -> PSQLPyResult> { -// let transaction = self.transaction.clone(); -// let portal = self.inner.clone(); -// let size = self.array_size.clone(); - -// let py_future = Python::with_gil(move |gil| { -// rustdriver_future(gil, async move { -// let result = transaction.query_portal(&portal, size).await?; - -// Ok(PSQLDriverPyQueryResult::new(result)) -// }) -// }); - -// Ok(Some(py_future?)) -// } - -// async fn fetch_one(&self) -> PSQLPyResult { -// self.query_portal(1).await -// } - -// #[pyo3(signature = (size=None))] -// async fn fetch_many(&self, size: Option) -> PSQLPyResult { -// self.query_portal(size.unwrap_or(self.array_size)).await -// } - -// async fn fetch_all(&self) -> PSQLPyResult { -// self.query_portal(-1).await -// } - -// async fn close(&mut self) { -// let _ = Arc::downgrade(&self.transaction); -// } -// } +use std::sync::Arc; + +use pyo3::{ + exceptions::PyStopAsyncIteration, pyclass, pymethods, Py, PyAny, PyErr, PyObject, Python, +}; +use tokio::sync::RwLock; +use tokio_postgres::Portal as tp_Portal; + +use crate::{ + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, + query_result::PSQLDriverPyQueryResult, + runtime::rustdriver_future, + transaction::structs::PSQLPyTransaction, +}; + +use crate::connection::structs::PSQLPyConnection; + +#[pyclass] +pub struct Portal { + conn: Option>>, + querystring: String, + parameters: Option>, + array_size: i32, + + transaction: Option>, + inner: Option, +} + +impl Portal { + pub fn new( + conn: Option>>, + querystring: String, + parameters: Option>, + array_size: Option, + ) -> Self { + Self { + conn, + transaction: None, + inner: None, + querystring, + parameters, + array_size: array_size.unwrap_or(1), + } + } + + async fn query_portal(&self, size: i32) -> PSQLPyResult { + let Some(transaction) = &self.transaction else { + return Err(RustPSQLDriverError::TransactionClosedError); + }; + let Some(portal) = &self.inner else { + return Err(RustPSQLDriverError::TransactionClosedError); + }; + transaction.query_portal(&portal, size).await + } +} + +impl Drop for Portal { + fn drop(&mut self) { + self.transaction = None; + self.conn = None; + } +} + +#[pymethods] +impl Portal { + #[getter] + fn get_array_size(&self) -> i32 { + self.array_size + } + + #[setter] + fn set_array_size(&mut self, value: i32) { + self.array_size = value; + } + + fn __aiter__(slf: Py) -> Py { + slf + } + + fn __await__(slf: Py) -> Py { + slf + } + + async fn __aenter__<'a>(slf: Py) -> PSQLPyResult> { + let (conn, querystring, parameters) = Python::with_gil(|gil| { + let self_ = slf.borrow(gil); + ( + self_.conn.clone(), + self_.querystring.clone(), + self_.parameters.clone(), + ) + }); + + let Some(conn) = conn else { + return Err(RustPSQLDriverError::CursorClosedError); + }; + let mut write_conn_g = conn.write().await; + + let (txid, inner_portal) = write_conn_g.portal(querystring, parameters).await?; + + Python::with_gil(|gil| { + let mut self_ = slf.borrow_mut(gil); + + self_.transaction = Some(Arc::new(txid)); + self_.inner = Some(inner_portal); + }); + + Ok(slf) + } + + #[allow(clippy::needless_pass_by_value)] + async fn __aexit__<'a>( + &mut self, + _exception_type: Py, + exception: Py, + _traceback: Py, + ) -> PSQLPyResult<()> { + self.close(); + + let (is_exc_none, py_err) = pyo3::Python::with_gil(|gil| { + ( + exception.is_none(gil), + PyErr::from_value(exception.into_bound(gil)), + ) + }); + + if !is_exc_none { + return Err(RustPSQLDriverError::RustPyError(py_err)); + } + Ok(()) + } + + fn __anext__(&self) -> PSQLPyResult> { + let txid = self.transaction.clone(); + let portal = self.inner.clone(); + let size = self.array_size.clone(); + + let py_future = Python::with_gil(move |gil| { + rustdriver_future(gil, async move { + let Some(txid) = &txid else { + return Err(RustPSQLDriverError::TransactionClosedError); + }; + let Some(portal) = &portal else { + return Err(RustPSQLDriverError::TransactionClosedError); + }; + let result = txid.query_portal(&portal, size).await?; + + if result.is_empty() { + return Err(PyStopAsyncIteration::new_err( + "Iteration is over, no more results in portal", + ) + .into()); + }; + + Ok(result) + }) + }); + + Ok(Some(py_future?)) + } + + async fn start(&mut self) -> PSQLPyResult<()> { + let Some(conn) = &self.conn else { + return Err(RustPSQLDriverError::ConnectionClosedError); + }; + let mut write_conn_g = conn.write().await; + + let (txid, inner_portal) = write_conn_g + .portal(self.querystring.clone(), self.parameters.clone()) + .await?; + + self.transaction = Some(Arc::new(txid)); + self.inner = Some(inner_portal); + + Ok(()) + } + + async fn fetch_one(&self) -> PSQLPyResult { + self.query_portal(1).await + } + + #[pyo3(signature = (size=None))] + async fn fetch_many(&self, size: Option) -> PSQLPyResult { + self.query_portal(size.unwrap_or(self.array_size)).await + } + + async fn fetch_all(&self) -> PSQLPyResult { + self.query_portal(-1).await + } + + fn close(&mut self) { + self.transaction = None; + self.conn = None; + } +} diff --git a/src/lib.rs b/src/lib.rs index 0eaac910..3229e675 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,6 +9,7 @@ pub mod query_result; pub mod row_factories; pub mod runtime; pub mod statement; +pub mod transaction; pub mod value_converter; use common::add_module; @@ -35,6 +36,7 @@ fn psqlpy(py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyResult<()> { pymod.add_function(wrap_pyfunction!(driver::connection::connect, pymod)?)?; pymod.add_class::()?; pymod.add_class::()?; + pymod.add_class::()?; pymod.add_class::()?; pymod.add_class::()?; pymod.add_class::()?; diff --git a/src/statement/statement.rs b/src/statement/statement.rs index addaae89..fd77eb55 100644 --- a/src/statement/statement.rs +++ b/src/statement/statement.rs @@ -32,7 +32,11 @@ impl PsqlpyStatement { pub fn statement_query(&self) -> PSQLPyResult<&Statement> { match &self.prepared_statement { Some(prepared_stmt) => return Ok(prepared_stmt), - None => return Err(RustPSQLDriverError::ConnectionExecuteError("No".into())), + None => { + return Err(RustPSQLDriverError::ConnectionExecuteError( + "No prepared parameters".into(), + )) + } } } diff --git a/src/transaction/impls.rs b/src/transaction/impls.rs new file mode 100644 index 00000000..a2a7c147 --- /dev/null +++ b/src/transaction/impls.rs @@ -0,0 +1,35 @@ +use crate::{exceptions::rust_errors::PSQLPyResult, query_result::PSQLDriverPyQueryResult}; + +use super::structs::PSQLPyTransaction; +use tokio_postgres::{Portal as tp_Portal, ToStatement}; + +impl PSQLPyTransaction { + pub async fn query_portal( + &self, + portal: &tp_Portal, + size: i32, + ) -> PSQLPyResult { + let portal_res = match self { + PSQLPyTransaction::PoolTransaction(txid) => txid.query_portal(portal, size).await?, + PSQLPyTransaction::SingleTransaction(txid) => txid.query_portal(portal, size).await?, + }; + + Ok(PSQLDriverPyQueryResult::new(portal_res)) + } + + pub async fn portal( + &self, + querystring: &T, + params: &[&(dyn postgres_types::ToSql + Sync)], + ) -> PSQLPyResult + where + T: ?Sized + ToStatement, + { + let portal: tp_Portal = match self { + PSQLPyTransaction::PoolTransaction(conn) => conn.bind(querystring, params).await?, + PSQLPyTransaction::SingleTransaction(conn) => conn.bind(querystring, params).await?, + }; + + Ok(portal) + } +} diff --git a/src/transaction/mod.rs b/src/transaction/mod.rs new file mode 100644 index 00000000..4bc01193 --- /dev/null +++ b/src/transaction/mod.rs @@ -0,0 +1,2 @@ +pub mod impls; +pub mod structs; diff --git a/src/transaction/structs.rs b/src/transaction/structs.rs new file mode 100644 index 00000000..0f8946cd --- /dev/null +++ b/src/transaction/structs.rs @@ -0,0 +1,7 @@ +use deadpool_postgres::Transaction as dp_Transaction; +use tokio_postgres::Transaction as tp_Transaction; + +pub enum PSQLPyTransaction { + PoolTransaction(dp_Transaction<'static>), + SingleTransaction(tp_Transaction<'static>), +}