8000 Merge pull request #134 from psqlpy-python/support_dbapi · psqlpy-python/psqlpy@64b8d3b · GitHub
[go: up one dir, main page]

Skip to content

Commit 64b8d3b

Browse files
authored
Merge pull request #134 from psqlpy-python/support_dbapi
Added inner transaction impl
2 parents 6e0e0c0 + 519c4d4 commit 64b8d3b

File tree

5 files changed

+176
-5
lines changed

5 files changed

+176
-5
lines changed

src/driver/connection.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,16 @@ use crate::{
99
exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError},
1010
format_helpers::quote_ident,
1111
query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult},
12-
runtime::{rustdriver_future, tokio_runtime},
12+
runtime::tokio_runtime,
1313
};
1414

1515
use super::{
1616
common_options::{LoadBalanceHosts, SslMode, TargetSessionAttrs},
17-
connection_pool::{connect_pool, ConnectionPool},
17+
connection_pool::connect_pool,
1818
cursor::Cursor,
1919
inner_connection::PsqlpyConnection,
2020
transaction::Transaction,
2121
transaction_options::{IsolationLevel, ReadVariant, SynchronousCommit},
22-
utils::build_connection_config,
2322
};
2423

2524
/// Make new connection pool.

src/driver/inner_connection.rs

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use bytes::Buf;
2-
use deadpool_postgres::Object;
2+
use deadpool_postgres::{Object, Transaction};
33
use postgres_types::{ToSql, Type};
4-
use pyo3::{Py, PyAny, Python};
4+
use pyo3::{pyclass, Py, PyAny, Python};
55
use std::vec;
66
use tokio_postgres::{Client, CopyInSink, Row, Statement, ToStatement};
77

@@ -18,6 +18,11 @@ pub enum PsqlpyConnection {
1818
SingleConn(Client),
1919
}
2020

21+
// #[pyclass]
22+
// struct Portal {
23+
// trans: Transaction<'static>,
24+
// }
25+
2126
impl PsqlpyConnection {
2227
/// Prepare cached statement.
2328
///
@@ -38,6 +43,25 @@ impl PsqlpyConnection {
3843
}
3944
}
4045

46+
// pub async fn transaction(&mut self) -> Portal {
47+
// match self {
48+
// PsqlpyConnection::PoolConn(pconn, _) => {
49+
// let b = unsafe {
50+
// std::mem::transmute::<Transaction<'_>, Transaction<'static>>(pconn.transaction().await.unwrap())
51+
// };
52+
// Portal {trans: b}
53+
// // let c = b.bind("SELECT 1", &[]).await.unwrap();
54+
// // b.query_portal(&c, 1).await;
55+
// }
56+
// PsqlpyConnection::SingleConn(sconn) => {
57+
// let b = unsafe {
58+
// std::mem::transmute::<Transaction<'_>, Transaction<'static>>(sconn.transaction().await.unwrap())
59+
// };
60+
// Portal {trans: b}
61+
// },
62+
// }
63+
// }
64+
4165
/// Delete prepared statement.
4266
///
4367
/// # Errors

src/driver/inner_transaction.rs

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
use deadpool_postgres::Transaction as dp_Transaction;
2+
use postgres_types::ToSql;
3+
use tokio_postgres::{Portal, Row, ToStatement, Transaction as tp_Transaction};
4+
5+
use crate::exceptions::rust_errors::PSQLPyResult;
6+
7+
pub enum PsqlpyTransaction {
8+
PoolTrans(dp_Transaction<'static>),
9+
SingleConnTrans(tp_Transaction<'static>),
10+
}
11+
12+
impl PsqlpyTransaction {
13+
async fn commit(self) -> PSQLPyResult<()> {
14+
match self {
15+
PsqlpyTransaction::PoolTrans(p_txid) => Ok(p_txid.commit().await?),
16+
PsqlpyTransaction::SingleConnTrans(s_txid) => Ok(s_txid.commit().await?),
17+
}
18+
}
19+
20+
async fn rollback(self) -> PSQLPyResult<()> {
21+
match self {
22+
PsqlpyTransaction::PoolTrans(p_txid) => Ok(p_txid.rollback().await?),
23+
PsqlpyTransaction::SingleConnTrans(s_txid) => Ok(s_txid.rollback().await?),
24+
}
25+
}
26+
27+
async fn savepoint(&mut self, sp_name: &str) -> PSQLPyResult<()> {
28+
match self {
29+
PsqlpyTransaction::PoolTrans(p_txid) => {
30+
p_txid.savepoint(sp_name).await?;
31+
Ok(())
32+
}
33+
PsqlpyTransaction::SingleConnTrans(s_txid) => {
34+
s_txid.savepoint(sp_name).await?;
35+
Ok(())
36+
}
37+
}
38+
}
39+
40+
async fn release_savepoint(&self, sp_name: &str) -> PSQLPyResult<()> {
41+
match self {
42+
PsqlpyTransaction::PoolTrans(p_txid) => {
43+
p_txid
44+
.batch_execute(format!("RELEASE SAVEPOINT {sp_name}").as_str())
45+
.await?;
46+
Ok(())
47+
}
48+
PsqlpyTransaction::SingleConnTrans(s_txid) => {
49+
s_txid
50+
.batch_execute(format!("RELEASE SAVEPOINT {sp_name}").as_str())
51+
.await?;
52+
Ok(())
53+
}
54+
}
55+
}
56+
57+
async fn rollback_savepoint(&self, sp_name: &str) -> PSQLPyResult<()> {
58+
match self {
59+
PsqlpyTransaction::PoolTrans(p_txid) => {
60+
p_txid
61+
.batch_execute(format!("ROLLBACK TO SAVEPOINT {sp_name}").as_str())
62+
.await?;
63+
Ok(())
64+
}
65+
PsqlpyTransaction::SingleConnTrans(s_txid) => {
66+
s_txid
67+
.batch_execute(format!("ROLLBACK TO SAVEPOINT {sp_name}").as_str())
68+
.await?;
69+
Ok(())
70+
}
71+
}
72+
}
73+
74+
async fn bind<T>(&self, statement: &T, params: &[&(dyn ToSql + Sync)]) -> PSQLPyResult<Portal>
75+
where
76+
T: ?Sized + ToStatement,
77+
{
78+
match self {
79+
PsqlpyTransaction::PoolTrans(p_txid) => Ok(p_txid.bind(statement, params).await?),
80+
PsqlpyTransaction::SingleConnTrans(s_txid) => {
81+
Ok(s_txid.bind(statement, params).await?)
82+
}
83+
}
84+
}
85+
86+
pub async fn query_portal(&self, portal: &Portal, size: i32) -> PSQLPyResult<Vec<Row>> {
87+
match self {
88+
PsqlpyTransaction::PoolTrans(p_txid) => Ok(p_txid.query_portal(portal, size).await?),
89+
PsqlpyTransaction::SingleConnTrans(s_txid) => {
90+
Ok(s_txid.query_portal(portal, size).await?)
91+
}
92+
}
93+
}
94+
}

src/driver/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ pub mod connection_pool;
44
pub mod connection_pool_builder;
55
pub mod cursor;
66
pub mod inner_connection;
7+
pub mod inner_transaction;
78
pub mod listener;
9+
pub mod portal;
810
pub mod transaction;
911
pub mod transaction_options;
1012
pub mod utils;

src/driver/portal.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
use std::sync::Arc;
2+
3+
use pyo3::{pyclass, pymethods};
4+
use tokio_postgres::Portal as tp_Portal;
5+
6+
use crate::{exceptions::rust_errors::PSQLPyResult, query_result::PSQLDriverPyQueryResult};
7+
8+
use super::inner_transaction::PsqlpyTransaction;
9+
10+
#[pyclass]
11+
struct Portal {
12+
transaction: Arc<PsqlpyTransaction>,
13+
inner: tp_Portal,
14+
array_size: i32,
15+
}
16+
17+
impl Portal {
18+
async fn query_portal(&self, size: i32) -> PSQLPyResult<PSQLDriverPyQueryResult> {
19+
let result = self.transaction.query_portal(&self.inner, size).await?;
20+
Ok(PSQLDriverPyQueryResult::new(result))
21+
}
22+
}
23+
24+
#[pymethods]
25+
impl Portal {
26+
#[getter]
27+
fn get_array_size(&self) -> i32 {
28+
self.array_size
29+
}
30+
31+
#[setter]
32+
fn set_array_size(&mut self, value: i32) {
33+
self.array_size = value;
34+
}
35+
36+
async fn fetch_one(&self) -> PSQLPyResult<PSQLDriverPyQueryResult> {
37+
self.query_portal(1).await
38+
}
39+
40+
#[pyo3(signature = (size=None))]
41+
async fn fetch_many(&self, size: Option<i32>) -> PSQLPyResult<PSQLDriverPyQueryResult> {
42+
self.query_portal(size.unwrap_or(self.array_size)).await
43+
}
44+
45+
async fn fetch_all(&self) -> PSQLPyResult<PSQLDriverPyQueryResult> {
46+
self.query_portal(-1).await
47+
}
48+
49+
async fn close(&mut self) {
50+
let _ = Arc::downgrade(&self.transaction);
51+
}
52+
}

0 commit comments

Comments
 (0)
0