8000 RustPython/stdlib/src/sqlite.rs at main · MichaReiser/RustPython · GitHub
[go: up one dir, main page]

Skip to content

Latest commit

 

History

History
2998 lines (2659 loc) · 101 KB

File metadata and controls

2998 lines (2659 loc) · 101 KB
// spell-checker:ignore libsqlite3 threadsafety PYSQLITE decltypes colnames collseq cantinit dirtywal
// spell-checker:ignore corruptfs narg setinputsizes setoutputsize lastrowid arraysize executemany
// spell-checker:ignore blobopen executescript iterdump getlimit setlimit errorcode errorname
// spell-checker:ignore rowid rowcount fetchone fetchmany fetchall errcode errname vtable pagecount
// spell-checker:ignore autocommit libversion toobig errmsg nomem threadsafe longlong vdbe reindex
// spell-checker:ignore savepoint cantopen ioerr nolfs nomem notadb notfound fullpath notempdir vtab
// spell-checker:ignore checkreservedlock noent fstat rdlock shmlock shmmap shmopen shmsize sharedcache
// spell-checker:ignore cantlock commithook foreignkey notnull primarykey gettemppath autoindex convpath
// spell-checker:ignore dbmoved vnode nbytes
use rustpython_vm::{builtins::PyModule, AsObject, PyRef, VirtualMachine};
// pub(crate) use _sqlite::make_module;
pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef<PyModule> {
    // TODO: sqlite version check
    // Build the macro-generated module first, then let `setup_module` attach the
    // pieces that are created at runtime (error codes, exception types, registries).
    let sqlite_module = _sqlite::make_module(vm);
    _sqlite::setup_module(sqlite_module.as_object(), vm);
    sqlite_module
}
#[pymodule]
mod _sqlite {
use libsqlite3_sys::{
sqlite3, sqlite3_aggregate_context, sqlite3_backup_finish, sqlite3_backup_init,
sqlite3_backup_pagecount, sqlite3_backup_remaining, sqlite3_backup_step, sqlite3_bind_blob,
sqlite3_bind_double, sqlite3_bind_int64, sqlite3_bind_null, sqlite3_bind_parameter_count,
sqlite3_bind_parameter_name, sqlite3_bind_text, sqlite3_blob, sqlite3_blob_bytes,
sqlite3_blob_close, sqlite3_blob_open, sqlite3_blob_read, sqlite3_blob_write,
sqlite3_busy_timeout, sqlite3_changes, sqlite3_close_v2, sqlite3_column_blob,
sqlite3_column_bytes, sqlite3_column_count, sqlite3_column_decltype, sqlite3_column_double,
sqlite3_column_int64, sqlite3_column_name, sqlite3_column_text, sqlite3_column_type,
sqlite3_complete, sqlite3_context, sqlite3_context_db_handle, sqlite3_create_collation_v2,
sqlite3_create_function_v2, sqlite3_create_window_function, sqlite3_data_count,
sqlite3_db_handle, sqlite3_errcode, sqlite3_errmsg, sqlite3_exec, sqlite3_expanded_sql,
sqlite3_extended_errcode, sqlite3_finalize, sqlite3_get_autocommit, sqlite3_interrupt,
sqlite3_last_insert_rowid, sqlite3_libversion, sqlite3_limit, sqlite3_open_v2,
sqlite3_prepare_v2, sqlite3_progress_handler, sqlite3_reset, sqlite3_result_blob,
sqlite3_result_double, sqlite3_result_error, sqlite3_result_error_nomem,
sqlite3_result_error_toobig, sqlite3_result_int64, sqlite3_result_null,
sqlite3_result_text, sqlite3_set_authorizer, sqlite3_sleep, sqlite3_step, sqlite3_stmt,
sqlite3_stmt_busy, sqlite3_stmt_readonly, sqlite3_threadsafe, sqlite3_total_changes,
sqlite3_trace_v2, sqlite3_user_data, sqlite3_value, sqlite3_value_blob,
sqlite3_value_bytes, sqlite3_value_double, sqlite3_value_int64, sqlite3_value_text,
sqlite3_value_type, SQLITE_BLOB, SQLITE_DETERMINISTIC, SQLITE_FLOAT, SQLITE_INTEGER,
SQLITE_NULL, SQLITE_OPEN_CREATE, SQLITE_OPEN_READWRITE, SQLITE_OPEN_URI, SQLITE_TEXT,
SQLITE_TRACE_STMT, SQLITE_TRANSIENT, SQLITE_UTF8,
};
use rustpython_common::{
atomic::{Ordering, PyAtomic, Radium},
hash::PyHash,
lock::{PyMappedMutexGuard, PyMutex, PyMutexGuard},
static_cell,
};
use rustpython_vm::{
atomic_func,
builtins::{
PyBaseException, PyBaseExceptionRef, PyByteArray, PyBytes, PyDict, PyDictRef, PyFloat,
PyInt, PyIntRef, PySlice, PyStr, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef,
},
convert::IntoObject,
function::{ArgCallable, ArgIterable, FsPath, FuncArgs, OptionalArg, PyComparisonValue},
protocol::{PyBuffer, PyIterReturn, PyMappingMethods, PySequence, PySequenceMethods},
sliceable::{SaturatedSliceIter, SliceableSequenceOp},
types::{
AsMapping, AsSequence, Callable, Comparable, Constructor, Hashable, IterNext, Iterable,
PyComparisonOp, SelfIter,
},
utils::ToCString,
AsObject, Py, PyAtomicRef, PyObject, PyObjectRef, PyPayload, PyRef, PyResult,
TryFromBorrowedObject, VirtualMachine,
__exports::paste,
object::{Traverse, TraverseFn},
};
use std::{
ffi::{c_int, c_longlong, c_uint, c_void, CStr},
fmt::Debug,
ops::Deref,
ptr::{addr_of_mut, null, null_mut},
thread::ThreadId,
};
// Declares the module's exception types from a list of `(Name, base)` pairs.
//
// For every pair the macro generates:
//   * a `static_cell!` slot named after the exception in SCREAMING_SNAKE_CASE,
//   * `new_<name>(vm, msg)`, which builds an exception instance carrying `msg`,
//   * `<name>_type()`, which returns the stored type and panics if it has not
//     been initialized yet.
// It also generates a single `setup_module_exceptions(module, vm)` that creates each
// type (or reuses the already-stored one) and sets it as an attribute on `module`.
//
// `$base` must be an expression callable with `&VirtualMachine` that yields the base
// type. Types are created in the order they are listed, so a base that is itself
// declared here has to appear before the types deriving from it.
macro_rules! exceptions {
($(($x:ident, $base:expr)),*) => {
paste::paste! {
static_cell! {
$(
static [<$x:snake:upper>]: PyTypeRef;
)*
}
$(
#[allow(dead_code)]
fn [<new_ $x:snake>](vm: &VirtualMachine, msg: String) -> PyBaseExceptionRef {
vm.new_exception_msg([<$x:snake _type>]().to_owned(), msg)
}
fn [<$x:snake _type>]() -> &'static Py<PyType> {
[<$x:snake:upper>].get().expect("exception type not initialize")
}
)*
fn setup_module_exceptions(module: &PyObject, vm: &VirtualMachine) {
$(
// `get_or_init` keeps the first created type if setup runs more than once.
let exception = [<$x:snake:upper>].get_or_init(
|| vm.ctx.new_exception_type("_sqlite3", stringify!($x), Some(vec![$base(vm).to_owned()])));
module.set_attr(stringify!($x), exception.clone().into_object(), vm).unwrap();
)*
}
}
};
}
// Exception hierarchy exposed by the module:
//   Warning and Error derive from the VM's builtin `Exception` type,
//   InterfaceError and DatabaseError derive from Error,
//   and the remaining types derive from DatabaseError.
// Bases are listed before their subclasses because the types are created in this order.
exceptions!(
(Warning, |vm: &VirtualMachine| vm
.ctx
.exceptions
.exception_type),
(Error, |vm: &VirtualMachine| vm
.ctx
.exceptions
.exception_type),
(InterfaceError, |_| error_type()),
(DatabaseError, |_| error_type()),
(DataError, |_| database_error_type()),
(OperationalError, |_| database_error_type()),
(IntegrityError, |_| database_error_type()),
(InternalError, |_| database_error_type()),
(ProgrammingError, |_| database_error_type()),
(NotSupportedError, |_| database_error_type())
);
#[pyattr]
fn sqlite_version(vm: &VirtualMachine) -> String {
    // Version string reported by the linked SQLite library, copied into an owned String.
    let version_ptr = unsafe { sqlite3_libversion() };
    let version = ptr_to_str(version_ptr, vm).unwrap();
    version.to_owned()
}
#[pyattr]
fn threadsafety(_: &VirtualMachine) -> c_int {
    // Lookup table indexed by the value of `sqlite3_threadsafe()`; the entry is the
    // value published as the module's `threadsafety` attribute.
    const LEVEL_BY_MODE: [c_int; 3] = [0, 3, 1];
    let mode = unsafe { sqlite3_threadsafe() };
    usize::try_from(mode)
        .ok()
        .and_then(|index| LEVEL_BY_MODE.get(index).copied())
        .unwrap_or_else(|| panic!("Unable to interpret SQLite threadsafety mode"))
}
// Exposed to Python under the attribute name `_deprecated_version`.
#[pyattr(name = "_deprecated_version")]
const PYSQLITE_VERSION: &str = "2.6.0";
// Integer flags exported as module attributes.
// NOTE(review): presumably meant to be OR-ed into `detect_types` on connect — confirm
// against the code that reads `Connection::detect_types`.
#[pyattr]
const PARSE_DECLTYPES: c_int = 1;
#[pyattr]
const PARSE_COLNAMES: c_int = 2;
#[pyattr]
use libsqlite3_sys::{
SQLITE_ALTER_TABLE, SQLITE_ANALYZE, SQLITE_ATTACH, SQLITE_CREATE_INDEX,
SQLITE_CREATE_TABLE, SQLITE_CREATE_TEMP_INDEX, SQLITE_CREATE_TEMP_TABLE,
SQLITE_CREATE_TEMP_TRIGGER, SQLITE_CREATE_TEMP_VIEW, SQLITE_CREATE_TRIGGER,
SQLITE_CREATE_VIEW, SQLITE_CREATE_VTABLE, SQLITE_DELETE, SQLITE_DENY, SQLITE_DETACH,
SQLITE_DROP_INDEX, SQLITE_DROP_TABLE, SQLITE_DROP_TEMP_INDEX, SQLITE_DROP_TEMP_TABLE,
SQLITE_DROP_TEMP_TRIGGER, SQLITE_DROP_TEMP_VIEW, SQLITE_DROP_TRIGGER, SQLITE_DROP_VIEW,
SQLITE_DROP_VTABLE, SQLITE_FUNCTION, SQLITE_IGNORE, SQLITE_INSERT, SQLITE_LIMIT_ATTACHED,
SQLITE_LIMIT_COLUMN, SQLITE_LIMIT_COMPOUND_SELECT, SQLITE_LIMIT_EXPR_DEPTH,
SQLITE_LIMIT_FUNCTION_ARG, SQLITE_LIMIT_LENGTH, SQLITE_LIMIT_LIKE_PATTERN_LENGTH,
SQLITE_LIMIT_SQL_LENGTH, SQLITE_LIMIT_TRIGGER_DEPTH, SQLITE_LIMIT_VARIABLE_NUMBER,
SQLITE_LIMIT_VDBE_OP, SQLITE_LIMIT_WORKER_THREADS, SQLITE_PRAGMA, SQLITE_READ,
SQLITE_RECURSIVE, SQLITE_REINDEX, SQLITE_SAVEPOINT, SQLITE_SELECT, SQLITE_TRANSACTION,
SQLITE_UPDATE,
};
// Imports each listed `libsqlite3_sys` constant and collects all of them into the
// `ERROR_CODES` table of `(name, value)` pairs, where `name` is the constant's
// identifier as a string. `setup_module` walks this table to publish every entry as
// a module attribute.
macro_rules! error_codes {
($($x:ident),*) => {
$(
// Some of the constants are only needed through the table below.
#[allow(unused_imports)]
use libsqlite3_sys::$x;
)*
static ERROR_CODES: &[(&str, c_int)] = &[
$(
(stringify!($x), libsqlite3_sys::$x),
)*
];
};
}
// Primary and extended SQLite result codes published as module-level integer
// attributes (see `ERROR_CODES` / `setup_module`).
error_codes!(
    SQLITE_ABORT,
    SQLITE_AUTH,
    SQLITE_BUSY,
    SQLITE_CANTOPEN,
    SQLITE_CONSTRAINT,
    SQLITE_CORRUPT,
    SQLITE_DONE,
    SQLITE_EMPTY,
    SQLITE_ERROR,
    SQLITE_FORMAT,
    SQLITE_FULL,
    SQLITE_INTERNAL,
    SQLITE_INTERRUPT,
    SQLITE_IOERR,
    SQLITE_LOCKED,
    SQLITE_MISMATCH,
    SQLITE_MISUSE,
    SQLITE_NOLFS,
    SQLITE_NOMEM,
    SQLITE_NOTADB,
    SQLITE_NOTFOUND,
    SQLITE_OK,
    SQLITE_PERM,
    SQLITE_PROTOCOL,
    SQLITE_RANGE,
    SQLITE_READONLY,
    SQLITE_ROW,
    SQLITE_SCHEMA,
    SQLITE_TOOBIG,
    SQLITE_NOTICE,
    SQLITE_WARNING,
    SQLITE_ABORT_ROLLBACK,
    SQLITE_BUSY_RECOVERY,
    SQLITE_CANTOPEN_FULLPATH,
    SQLITE_CANTOPEN_ISDIR,
    SQLITE_CANTOPEN_NOTEMPDIR,
    SQLITE_CORRUPT_VTAB,
    SQLITE_IOERR_ACCESS,
    SQLITE_IOERR_BLOCKED,
    SQLITE_IOERR_CHECKRESERVEDLOCK,
    SQLITE_IOERR_CLOSE,
    SQLITE_IOERR_DELETE,
    SQLITE_IOERR_DELETE_NOENT,
    SQLITE_IOERR_DIR_CLOSE,
    SQLITE_IOERR_DIR_FSYNC,
    SQLITE_IOERR_FSTAT,
    SQLITE_IOERR_FSYNC,
    SQLITE_IOERR_LOCK,
    SQLITE_IOERR_NOMEM,
    SQLITE_IOERR_RDLOCK,
    SQLITE_IOERR_READ,
    SQLITE_IOERR_SEEK,
    SQLITE_IOERR_SHMLOCK,
    SQLITE_IOERR_SHMMAP,
    SQLITE_IOERR_SHMOPEN,
    SQLITE_IOERR_SHMSIZE,
    SQLITE_IOERR_SHORT_READ,
    SQLITE_IOERR_TRUNCATE,
    SQLITE_IOERR_UNLOCK,
    SQLITE_IOERR_WRITE,
    SQLITE_LOCKED_SHAREDCACHE,
    SQLITE_READONLY_CANTLOCK,
    SQLITE_READONLY_RECOVERY,
    SQLITE_CONSTRAINT_CHECK,
    SQLITE_CONSTRAINT_COMMITHOOK,
    SQLITE_CONSTRAINT_FOREIGNKEY,
    SQLITE_CONSTRAINT_FUNCTION,
    SQLITE_CONSTRAINT_NOTNULL,
    SQLITE_CONSTRAINT_PRIMARYKEY,
    SQLITE_CONSTRAINT_TRIGGER,
    SQLITE_CONSTRAINT_UNIQUE,
    SQLITE_CONSTRAINT_VTAB,
    SQLITE_READONLY_ROLLBACK,
    SQLITE_IOERR_MMAP,
    SQLITE_NOTICE_RECOVER_ROLLBACK,
    SQLITE_NOTICE_RECOVER_WAL,
    SQLITE_BUSY_SNAPSHOT,
    SQLITE_IOERR_GETTEMPPATH,
    SQLITE_WARNING_AUTOINDEX,
    SQLITE_CANTOPEN_CONVPATH,
    SQLITE_IOERR_CONVPATH,
    SQLITE_CONSTRAINT_ROWID,
    SQLITE_READONLY_DBMOVED,
    SQLITE_AUTH_USER,
    SQLITE_OK_LOAD_PERMANENTLY,
    SQLITE_IOERR_VNODE,
    SQLITE_IOERR_AUTH,
    SQLITE_IOERR_BEGIN_ATOMIC,
    SQLITE_IOERR_COMMIT_ATOMIC,
    SQLITE_IOERR_ROLLBACK_ATOMIC,
    SQLITE_ERROR_MISSING_COLLSEQ,
    SQLITE_ERROR_RETRY,
    SQLITE_READONLY_CANTINIT,
    SQLITE_READONLY_DIRECTORY,
    SQLITE_CORRUPT_SEQUENCE,
    SQLITE_LOCKED_VTAB,
    SQLITE_CANTOPEN_DIRTYWAL,
    SQLITE_ERROR_SNAPSHOT,
    SQLITE_CANTOPEN_SYMLINK,
    SQLITE_CONSTRAINT_PINNED,
    SQLITE_OK_SYMLINK,
    SQLITE_BUSY_TIMEOUT,
    SQLITE_CORRUPT_INDEX,
    SQLITE_IOERR_DATA,
    SQLITE_IOERR_CORRUPTFS
);
// Arguments accepted by `connect()` and by the `Connection` constructor.
#[derive(FromArgs)]
struct ConnectArgs {
// Database location; converted to a C string and handed to `SqliteRaw::open`.
#[pyarg(any)]
database: FsPath,
// Multiplied by 1000 and passed to `busy_timeout` when the connection is created.
#[pyarg(any, default = "5.0")]
timeout: f64,
// Stored unchanged on the connection.
#[pyarg(any, default = "0")]
detect_types: c_int,
// Defaults to the empty string; `None` is also accepted. A non-`None` value is
// validated when the connection is created.
#[pyarg(any, default = "Some(vm.ctx.empty_str.to_owned())")]
isolation_level: Option<PyStrRef>,
// Stored unchanged on the connection.
#[pyarg(any, default = "true")]
check_same_thread: bool,
// Type that `connect()` instantiates; defaults to `Connection` itself.
#[pyarg(any, default = "Connection::class(&vm.ctx).to_owned()")]
factory: PyTypeRef,
// TODO: cache statements
// Accepted so callers can pass it, but not read anywhere yet.
#[allow(dead_code)]
#[pyarg(any, default = "0")]
cached_statements: c_int,
// Forwarded to `SqliteRaw::open` together with the path.
#[pyarg(any, default = "false")]
uri: bool,
}
// Reports the Python references held by `ConnectArgs` to the tracer.
// NOTE(review): only `isolation_level` and `factory` are visited; `database` is not —
// confirm whether `FsPath` can hold an object reference that should be traversed too.
unsafe impl Traverse for ConnectArgs {
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.isolation_level.traverse(tracer_fn);
self.factory.traverse(tracer_fn);
}
}
// Argument bundle for a backup operation. The code that consumes it is outside this
// part of the file, so only the calling convention is described here.
#[derive(FromArgs)]
struct BackupArgs {
// Destination connection; may be passed positionally or by keyword.
#[pyarg(any)]
target: PyRef<Connection>,
// Keyword-only, defaults to -1.
#[pyarg(named, default = "-1")]
pages: c_int,
// Keyword-only optional callable.
#[pyarg(named, optional)]
progress: Option<ArgCallable>,
// Keyword-only optional string.
#[pyarg(named, optional)]
name: Option<PyStrRef>,
// Keyword-only, defaults to 0.250.
#[pyarg(named, default = "0.250")]
sleep: f64,
}
// Reports the Python references held by `BackupArgs` to the tracer.
// NOTE(review): `target` is a `PyRef<Connection>` but is not visited here — confirm
// that skipping it is intentional.
unsafe impl Traverse for BackupArgs {
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.progress.traverse(tracer_fn);
self.name.traverse(tracer_fn);
}
}
// Arguments for registering a user-defined SQL function.
#[derive(FromArgs)]
struct CreateFunctionArgs {
// SQL-visible function name.
#[pyarg(any)]
name: PyStrRef,
// Number of arguments the function is registered with.
#[pyarg(any)]
narg: c_int,
// Python object registered for the function.
#[pyarg(any)]
func: PyObjectRef,
// Keyword-only; defaults to `false`.
#[pyarg(named, default)]
deterministic: bool,
}
// Arguments for registering a user-defined aggregate.
#[derive(FromArgs)]
struct CreateAggregateArgs {
// SQL-visible aggregate name; positional or keyword.
#[pyarg(any)]
name: PyStrRef,
// Positional-only argument count.
#[pyarg(positional)]
narg: c_int,
// Positional-only object; `step_callback` calls it with no arguments to create the
// per-aggregation instance.
#[pyarg(positional)]
aggregate_class: PyObjectRef,
}
// Arguments of `Connection.blobopen`; all of them are forwarded to `sqlite3_blob_open`.
#[derive(FromArgs)]
struct BlobOpenArgs {
// Table containing the blob (positional-only).
#[pyarg(positional)]
table: PyStrRef,
// Column containing the blob (positional-only).
#[pyarg(positional)]
column: PyStrRef,
// Row identifier passed to SQLite (positional-only).
#[pyarg(positional)]
row: i64,
// Keyword-only; defaults to `false`, i.e. the blob is opened for writing as well.
#[pyarg(named, default)]
readonly: bool,
// Keyword-only database name; defaults to the string "main".
#[pyarg(named, default = "vm.ctx.new_str(stringify!(main))")]
name: PyStrRef,
}
// State handed to SQLite as the opaque user-data pointer of a registered callback.
struct CallbackData {
// Raw pointer produced by `PyObjectRef::into_raw`; the reference it represents is
// given back in `Drop`.
obj: *const PyObject,
// VM that was active when the callback was registered; dereferenced in `retrieve`.
vm: *const VirtualMachine,
}
impl CallbackData {
// Wraps `obj` for use as callback user data; Python `None` yields no callback data.
fn new(obj: PyObjectRef, vm: &VirtualMachine) -> Option<Self> {
    if vm.is_none(&obj) {
        return None;
    }
    Some(Self {
        obj: obj.into_raw(),
        vm,
    })
}
// Borrows the stored Python object and VM back from their raw pointers.
fn retrieve(&self) -> (&PyObject, &VirtualMachine) {
    let obj = unsafe { &*self.obj };
    let vm = unsafe { &*self.vm };
    (obj, vm)
}
// Destructor callback for SQLite: takes back ownership of the boxed `CallbackData`
// behind `data` and drops it.
unsafe extern "C" fn destructor(data: *mut c_void) {
    let owned: Box<Self> = Box::from_raw(data.cast());
    drop(owned);
}
// Scalar-function trampoline: converts the SQLite arguments to Python objects, calls
// the registered callable and stores its return value (or the error) on the context.
unsafe extern "C" fn func_callback(
    context: *mut sqlite3_context,
    argc: c_int,
    argv: *mut *mut sqlite3_value,
) {
    let context = SqliteContext::from(context);
    let (func, vm) = (*context.user_data::<Self>()).retrieve();
    let raw_args = std::slice::from_raw_parts(argv, argc as usize);
    let invoke = || -> PyResult<()> {
        let db = context.db_handle();
        let mut py_args: Vec<PyObjectRef> = Vec::with_capacity(raw_args.len());
        for &raw in raw_args {
            // Stop at the first argument that cannot be converted.
            py_args.push(value_to_object(raw, db, vm)?);
        }
        let ret = func.call(py_args, vm)?;
        context.result_from_object(&ret, vm)
    };
    match invoke() {
        Ok(()) => {}
        Err(exc) => context.result_exception(vm, exc, "user-defined function raised exception\0"),
    }
}
// Aggregate `step` trampoline. The aggregate context holds one pointer slot; on the
// first call the slot is empty, so the registered class is instantiated and the new
// instance is stored there. Every call then forwards the row to `instance.step(...)`.
unsafe extern "C" fn step_callback(
    context: *mut sqlite3_context,
    argc: c_int,
    argv: *mut *mut sqlite3_value,
) {
    let context = SqliteContext::from(context);
    let (cls, vm) = (*context.user_data::<Self>()).retrieve();
    let raw_args = std::slice::from_raw_parts(argv, argc as usize);
    let slot = context.aggregate_context::<*const PyObject>();
    if (*slot).is_null() {
        let created = match cls.call((), vm) {
            Ok(obj) => obj,
            Err(exc) => {
                context.result_exception(
                    vm,
                    exc,
                    "user-defined aggregate's '__init__' method raised error\0",
                );
                return;
            }
        };
        *slot = created.into_raw();
    }
    Self::call_method_with_args(context, &**slot, "step", raw_args, vm);
}
unsafe extern "C" fn finalize_callback(context: *mut sqlite3_context) {
let context = SqliteContext::from(context);
let (_, vm) = (*context.user_data::<Self>()).retrieve();
let instance = context.aggregate_context::<*const PyObject>();
let Some(instance) = (*instance).as_ref() else { return; };
Self::callback_result_from_method(context, instance, "finalize", vm);
}
unsafe extern "C" fn collation_callback(
data: *mut c_void,
a_len: c_int,
a_ptr: *const c_void,
b_len: c_int,
b_ptr: *const c_void,
) -> c_int {
let (callable, vm) = (*data.cast::<Self>()).retrieve();
let f = || -> PyResult<c_int> {
let text1 = ptr_to_string(a_ptr.cast(), a_len, null_mut(), vm)?;
5276 let text1 = vm.ctx.new_str(text1);
let text2 = ptr_to_string(b_ptr.cast(), b_len, null_mut(), vm)?;
let text2 = vm.ctx.new_str(text2);
let val = callable.call((text1, text2), vm)?;
let Some(val) = val.to_number().index(vm) else {
return Ok(0);
};
let val = match val?.as_bigint().sign() {
num_bigint::Sign::Plus => 1,
num_bigint::Sign::Minus => -1,
num_bigint::Sign::NoSign => 0,
};
Ok(val)
};
f().unwrap_or(0)
}
// Window-function `value` trampoline: reports `instance.value()` as the current result.
//
// The instance slot is filled by `step_callback`. If it is still empty there is no
// object to call, so the callback returns without setting a result instead of
// building a reference from a null pointer (same guard as `finalize_callback`).
unsafe extern "C" fn value_callback(context: *mut sqlite3_context) {
    let context = SqliteContext::from(context);
    let (_, vm) = (*context.user_data::<Self>()).retrieve();
    let slot = context.aggregate_context::<*const PyObject>();
    let Some(instance) = (*slot).as_ref() else {
        return;
    };
    Self::callback_result_from_method(context, instance, "value", vm);
}
// Window-function `inverse` trampoline: forwards the row leaving the window to
// `instance.inverse(...)`.
//
// The instance slot is filled by `step_callback`. If it is still empty there is no
// object to call, so the callback returns instead of building a reference from a
// null pointer (same guard as `finalize_callback`).
unsafe extern "C" fn inverse_callback(
    context: *mut sqlite3_context,
    argc: c_int,
    argv: *mut *mut sqlite3_value,
) {
    let context = SqliteContext::from(context);
    let (_, vm) = (*context.user_data::<Self>()).retrieve();
    let args = std::slice::from_raw_parts(argv, argc as usize);
    let slot = context.aggregate_context::<*const PyObject>();
    let Some(instance) = (*slot).as_ref() else {
        return;
    };
    Self::call_method_with_args(context, instance, "inverse", args, vm);
}
unsafe extern "C" fn authorizer_callback(
data: *mut c_void,
action: c_int,
arg1: *const libc::c_char,
arg2: *const libc::c_char,
db_name: *const libc::c_char,
access: *const libc::c_char,
) -> c_int {
let (callable, vm) = (*data.cast::<Self>()).retrieve();
let f = || -> PyResult<c_int> {
let arg1 = ptr_to_str(arg1, vm)?;
let arg2 = ptr_to_str(arg2, vm)?;
let db_name = ptr_to_str(db_name, vm)?;
let access = ptr_to_str(access, vm)?;
let val = callable.call((action, arg1, arg2, db_name, access), vm)?;
let Some(val) = val.payload::<PyInt>() else {
return Ok(SQLITE_DENY);
};
val.try_to_primitive::<c_int>(vm)
};
f().unwrap_or(SQLITE_DENY)
}
unsafe extern "C" fn trace_callback(
_typ: c_uint,
data: *mut c_void,
stmt: *mut c_void,
sql: *mut c_void,
) -> c_int {
let (callable, vm) = (*data.cast::<Self>()).retrieve();
let expanded = sqlite3_expanded_sql(stmt.cast());
let f = || -> PyResult<()> {
let stmt = ptr_to_str(expanded, vm).or_else(|_| ptr_to_str(sql.cast(), vm))?;
callable.call((stmt,), vm)?;
Ok(())
};
let _ = f();
0
}
// Progress-handler trampoline: returns the truthiness of the callable's result as
// 0 or 1, or -1 when the call or the truth test raises.
unsafe extern "C" fn progress_callback(data: *mut c_void) -> c_int {
    let (callable, vm) = (*data.cast::<Self>()).retrieve();
    let verdict = callable.call((), vm).and_then(|ret| ret.is_true(vm));
    match verdict {
        Ok(flag) => flag as c_int,
        Err(_) => -1,
    }
}
// Calls `instance.<name>()` without arguments and stores the return value as the SQL
// result. On failure the exception is reported on the context, with a message that
// distinguishes a missing method (AttributeError) from a method that raised.
fn callback_result_from_method(
    context: SqliteContext,
    instance: &PyObject,
    name: &str,
    vm: &VirtualMachine,
) {
    let outcome = vm
        .call_method(instance, name, ())
        .and_then(|ret| context.result_from_object(&ret, vm));
    let Err(exc) = outcome else {
        return;
    };
    let message = if exc.fast_isinstance(vm.ctx.exceptions.attribute_error) {
        format!("user-defined aggregate's '{name}' method not defined\0")
    } else {
        format!("user-defined aggregate's '{name}' method raised error\0")
    };
    context.result_exception(vm, exc, &message)
}
// Converts the SQLite values in `args` to Python objects and calls
// `instance.<name>(*args)`, discarding the return value. On failure the exception is
// reported on the context, with a message that distinguishes a missing method
// (AttributeError) from a method that raised.
fn call_method_with_args(
    context: SqliteContext,
    instance: &PyObject,
    name: &str,
    args: &[*mut sqlite3_value],
    vm: &VirtualMachine,
) {
    let invoke = || -> PyResult<()> {
        let db = context.db_handle();
        let mut py_args: Vec<PyObjectRef> = Vec::with_capacity(args.len());
        for &raw in args {
            py_args.push(value_to_object(raw, db, vm)?);
        }
        vm.call_method(instance, name, py_args)?;
        Ok(())
    };
    let Err(exc) = invoke() else {
        return;
    };
    let message = if exc.fast_isinstance(vm.ctx.exceptions.attribute_error) {
        format!("user-defined aggregate's '{name}' method not defined\0")
    } else {
        format!("user-defined aggregate's '{name}' method raised error\0")
    };
    context.result_exception(vm, exc, &message)
}
}
impl Drop for CallbackData {
    // Gives back the reference that `CallbackData::new` took with `into_raw`.
    fn drop(&mut self) {
        let reclaimed = unsafe { PyObjectRef::from_raw(self.obj) };
        drop(reclaimed);
    }
}
// Module-level `connect()`: instantiates the requested factory type with the
// connection arguments.
#[pyfunction]
fn connect(args: ConnectArgs, vm: &VirtualMachine) -> PyResult {
    let factory = args.factory.clone();
    Connection::py_new(factory, args, vm)
}
// Returns whether `sqlite3_complete` reports 1 for the given text.
#[pyfunction]
fn complete_statement(statement: PyStrRef, vm: &VirtualMachine) -> PyResult<bool> {
    let sql = statement.to_cstring(vm)?;
    let is_complete = unsafe { sqlite3_complete(sql.as_ptr()) } == 1;
    Ok(is_complete)
}
// Stores `flag` in the module-wide traceback switch.
#[pyfunction]
fn enable_callback_tracebacks(flag: bool) {
    let switch = enable_traceback();
    switch.store(flag, Ordering::Relaxed);
}
// Registers `adapter` for `typ` under the key `(typ, PrepareProtocol)`.
#[pyfunction]
fn register_adapter(typ: PyTypeRef, adapter: ArgCallable, vm: &VirtualMachine) -> PyResult<()> {
    let ctx = &vm.ctx;
    let base_types = [
        PyInt::class(ctx),
        PyFloat::class(ctx),
        PyStr::class(ctx),
        PyByteArray::class(ctx),
    ];
    if base_types.iter().any(|base| typ.is(*base)) {
        // Once one of these types has an adapter, `need_adapt` must stop
        // short-circuiting them. Setting the cell twice is harmless.
        let _ = BASE_TYPE_ADAPTED.set(());
    }
    let protocol = PrepareProtocol::class(ctx).to_owned();
    let key = ctx.new_tuple(vec![typ.into(), protocol.into()]);
    adapters().set_item(key.as_object(), adapter.into(), vm)
}
// Registers `converter` under the upper-cased `typename`.
#[pyfunction]
fn register_converter(
    typename: PyStrRef,
    converter: ArgCallable,
    vm: &VirtualMachine,
) -> PyResult<()> {
    let registry = converters();
    let upper_name = typename.as_str().to_uppercase();
    registry.set_item(&upper_name, converter.into(), vm)
}
// Adapts `obj` to `proto`, trying in order:
//   1. an adapter registered under `(type(obj), proto)`,
//   2. `proto.__adapt__(obj)`,
//   3. `obj.__conform__(proto)`,
//   4. the caller-supplied fallback `alt`.
// A TypeError raised by step 2 or 3 means "not handled" and moves on to the next
// step; any other exception is propagated.
fn _adapt<F>(obj: &PyObject, proto: PyTypeRef, alt: F, vm: &VirtualMachine) -> PyResult
where
    F: FnOnce(&PyObject) -> PyResult,
{
    let proto = proto.into_object();
    let registry_key = vm
        .ctx
        .new_tuple(vec![obj.class().to_owned().into(), proto.clone()]);
    if let Some(registered) = adapters().get_item_opt(registry_key.as_object(), vm)? {
        return registered.call((obj,), vm);
    }
    if let Ok(adapt_hook) = proto.get_attr("__adapt__", vm) {
        match adapt_hook.call((obj,), vm) {
            Err(exc) if exc.fast_isinstance(vm.ctx.exceptions.type_error) => {}
            outcome => return outcome,
        }
    }
    if let Ok(conform_hook) = obj.get_attr("__conform__", vm) {
        match conform_hook.call((proto,), vm) {
            Err(exc) if exc.fast_isinstance(vm.ctx.exceptions.type_error) => {}
            outcome => return outcome,
        }
    }
    alt(obj)
}
// Python-visible `adapt(obj[, proto[, alt]])`. A missing or `None` protocol falls back
// to `PrepareProtocol`. When no adaptation applies, `alt` is returned if it was given;
// otherwise a ProgrammingError is raised.
#[pyfunction]
fn adapt(
    obj: PyObjectRef,
    proto: OptionalArg<Option<PyTypeRef>>,
    alt: OptionalArg<PyObjectRef>,
    vm: &VirtualMachine,
) -> PyResult {
    // TODO: None proto
    let proto = match proto.flatten() {
        Some(explicit) => explicit,
        None => PrepareProtocol::class(&vm.ctx).to_owned(),
    };
    _adapt(
        &obj,
        proto,
        |_| match alt {
            OptionalArg::Present(fallback) => Ok(fallback),
            OptionalArg::Missing => Err(new_programming_error(vm, "can't adapt".to_owned())),
        },
        vm,
    )
}
// Decides whether `obj` has to go through the adapter machinery. Exact instances of
// int, float, str and bytearray skip it unless an adapter has been registered for one
// of those types.
fn need_adapt(obj: &PyObject, vm: &VirtualMachine) -> bool {
    if BASE_TYPE_ADAPTED.get().is_some() {
        return true;
    }
    let cls = obj.class();
    let types = &vm.ctx.types;
    let is_plain_base_type = cls.is(types.int_type)
        || cls.is(types.float_type)
        || cls.is(types.str_type)
        || cls.is(types.bytearray_type);
    !is_plain_base_type
}
// Module-wide state. Everything except `BASE_TYPE_ADAPTED` is initialized in
// `setup_module`.
static_cell! {
// Registry filled by `register_converter`, keyed by upper-cased type name.
static CONVERTERS: PyDictRef;
// Registry filled by `register_adapter`, keyed by `(type, protocol)` tuples.
static ADAPTERS: PyDictRef;
// Set once an adapter is registered for int, float, str or bytearray.
static BASE_TYPE_ADAPTED: ();
// Initialized to `None` in `setup_module`.
// NOTE(review): presumably holds the last exception raised inside a user callback — confirm.
static USER_FUNCTION_EXCEPTION: PyAtomicRef<Option<PyBaseException>>;
// Toggled by `enable_callback_tracebacks`; starts as `false`.
static ENABLE_TRACEBACK: PyAtomic<bool>;
}
// Returns the converter registry; panics if `setup_module` has not run yet.
fn converters() -> &'static Py<PyDict> {
CONVERTERS.get().expect("converters not initialize")
}
// Returns the adapter registry; panics if `setup_module` has not run yet.
fn adapters() -> &'static Py<PyDict> {
ADAPTERS.get().expect("adapters not initialize")
}
// Returns the shared exception slot; panics if `setup_module` has not run yet.
fn user_function_exception() -> &'static PyAtomicRef<Option<PyBaseException>> {
USER_FUNCTION_EXCEPTION
.get()
.expect("user function exception not initialize")
}
// Returns the traceback switch; panics if `setup_module` has not run yet.
fn enable_traceback() -> &'static PyAtomic<bool> {
ENABLE_TRACEBACK
.get()
.expect("enable traceback not initialize")
}
// Runtime part of module initialization: publishes the SQLite result codes, creates
// the exception types, initializes the module-wide cells and exposes the converter
// and adapter registries as module attributes.
pub(super) fn setup_module(module: &PyObject, vm: &VirtualMachine) {
    for &(code_name, code_value) in ERROR_CODES {
        let attr_name = vm.ctx.intern_str(code_name);
        let attr_value = vm.new_pyobj(code_value);
        module.set_attr(attr_name, attr_value, vm).unwrap();
    }

    setup_module_exceptions(module, vm);

    // `set` fails if a cell is already filled; the first value is kept in that case.
    let _ = CONVERTERS.set(vm.ctx.new_dict());
    let _ = ADAPTERS.set(vm.ctx.new_dict());
    let _ = USER_FUNCTION_EXCEPTION.set(PyAtomicRef::from(None));
    let _ = ENABLE_TRACEBACK.set(Radium::new(false));

    let converter_registry = converters().to_owned();
    module.set_attr("converters", converter_registry, vm).unwrap();
    let adapter_registry = adapters().to_owned();
    module.set_attr("adapters", adapter_registry, vm).unwrap();
}
// Python-visible connection object.
#[pyattr]
#[pyclass(name)]
#[derive(PyPayload)]
struct Connection {
// Underlying database handle; `None` once `close()` has been called.
db: PyMutex<Option<Sqlite>>,
// Copied from the `detect_types` connect argument.
detect_types: c_int,
// Copied from the `isolation_level` connect argument.
isolation_level: PyAtomicRef<Option<PyStr>>,
// Copied from the `check_same_thread` connect argument.
check_same_thread: bool,
// Thread that created the connection.
thread_ident: ThreadId,
// Starts as `None`; handed to every cursor created from this connection.
row_factory: PyAtomicRef<Option<PyObject>>,
// Starts as the `str` type.
text_factory: PyAtomicRef<PyObject>,
}
impl Debug for Connection {
    // Fixed label; no connection state is included in the debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        f.write_str("Sqlite3 Connection")
    }
}
impl Constructor for Connection {
    type Args = ConnectArgs;

    // Opens the database and wraps the connection in an object of type `cls`.
    fn py_new(cls: PyTypeRef, args: Self::Args, vm: &VirtualMachine) -> PyResult {
        let connection = Self::new(args, vm)?;
        let connection = connection.into_ref_with_type(vm, cls)?;
        Ok(connection.into())
    }
}
impl Callable for Connection {
    type Args = (PyStrRef,);

    // Calling a connection with SQL text returns a statement object, or `None` when
    // `Statement::new` produces no statement for the text.
    fn call(zelf: &Py<Self>, args: Self::Args, vm: &VirtualMachine) -> PyResult {
        let (sql,) = args;
        let prepared = Statement::new(zelf, &sql, vm)?;
        Ok(match prepared {
            Some(stmt) => stmt.into_ref(&vm.ctx).into(),
            None => vm.ctx.none(),
        })
    }
}
#[pyclass(with(Constructor, Callable), flags(BASETYPE))]
impl Connection {
// Opens the database described by `args` and builds the connection state around it.
// The database is opened before the isolation level is validated.
fn new(args: ConnectArgs, vm: &VirtualMachine) -> PyResult<Self> {
    let path = args.database.to_cstring(vm)?;
    let raw = SqliteRaw::open(path.as_ptr(), args.uri, vm)?;
    let db = Sqlite::from(raw);
    // `timeout` is scaled by 1000 before being handed to `busy_timeout`.
    db.busy_timeout((args.timeout * 1000.0) as c_int);
    if let Some(level) = args.isolation_level.as_ref() {
        // Called only for validation; the returned value is not needed here.
        begin_statement_ptr_from_isolation_level(level, vm)?;
    }
    let default_text_factory = PyStr::class(&vm.ctx).to_owned().into_object();
    let connection = Self {
        db: PyMutex::new(Some(db)),
        detect_types: args.detect_types,
        isolation_level: PyAtomicRef::from(args.isolation_level),
        check_same_thread: args.check_same_thread,
        thread_ident: std::thread::current().id(),
        row_factory: PyAtomicRef::from(None),
        text_factory: PyAtomicRef::from(default_text_factory),
    };
    Ok(connection)
}
// Locks the database after the thread check has passed.
fn db_lock(&self, vm: &VirtualMachine) -> PyResult<PyMappedMutexGuard<Sqlite>> {
    self.check_thread(vm).and_then(|_| self._db_lock(vm))
}
// Locks the database without the thread check. Raises ProgrammingError when the
// connection has been closed.
fn _db_lock(&self, vm: &VirtualMachine) -> PyResult<PyMappedMutexGuard<Sqlite>> {
    let guard = self.db.lock();
    if guard.is_none() {
        return Err(new_programming_error(
            vm,
            "Cannot operate on a closed database.".to_owned(),
        ));
    }
    // The guard was just checked to hold `Some`, and it is still locked.
    let mapped = PyMutexGuard::map(guard, |slot| unsafe { slot.as_mut().unwrap_unchecked() });
    Ok(mapped)
}
// Creates a cursor for this connection. With a `factory`, the factory is called with
// the connection and must return a `Cursor`, which then receives the connection's
// row factory. Without one, a plain `Cursor` is built.
#[pymethod]
fn cursor(
    zelf: PyRef<Self>,
    factory: OptionalArg<ArgCallable>,
    vm: &VirtualMachine,
) -> PyResult<PyRef<Cursor>> {
    // Only used to fail early on a wrong thread or a closed database.
    zelf.db_lock(vm).map(drop)?;
    match factory {
        OptionalArg::Present(factory) => {
            let produced = factory.invoke((zelf.clone(),), vm)?;
            let cursor = produced.downcast::<Cursor>().map_err(|other| {
                vm.new_type_error(format!(
                    "factory must return a cursor, not {}",
                    other.class()
                ))
            })?;
            unsafe { cursor.row_factory.swap(zelf.row_factory.to_owned()) };
            Ok(cursor)
        }
        OptionalArg::Missing => {
            let row_factory = zelf.row_factory.to_owned();
            Ok(Cursor::new(zelf, row_factory, vm).into_ref(&vm.ctx))
        }
    }
}
// Opens a blob handle on `table.column` at row `args.row` of database `args.name` and
// wraps it in a `Blob` object positioned at offset 0.
//
// Raises the error produced by `db.check` when `sqlite3_blob_open` fails, and the
// usual thread/closed-database errors from `db_lock`.
#[pymethod]
fn blobopen(
    zelf: PyRef<Self>,
    args: BlobOpenArgs,
    vm: &VirtualMachine,
) -> PyResult<PyRef<Blob>> {
    let table = args.table.to_cstring(vm)?;
    let column = args.column.to_cstring(vm)?;
    let name = args.name.to_cstring(vm)?;
    let db = zelf.db_lock(vm)?;
    let mut blob = null_mut();
    let ret = unsafe {
        sqlite3_blob_open(
            db.db,
            name.as_ptr(),
            table.as_ptr(),
            column.as_ptr(),
            args.row,
            // SQLite's flag is "open for read-write", the inverse of `readonly`.
            (!args.readonly) as c_int,
            &mut blob,
        )
    };
    db.check(ret, vm)?;
    // Release the database lock before `zelf` is moved into the blob object.
    drop(db);
    let blob = SqliteBlob { blob };
    let blob = Blob {
        connection: zelf,
        inner: PyMutex::new(Some(BlobInner { blob, offset: 0 })),
    };
    Ok(blob.into_ref(&vm.ctx))
}
// Closes the connection by dropping the database handle. Closing an already closed
// connection is a no-op.
#[pymethod]
fn close(&self, vm: &VirtualMachine) -> PyResult<()> {
    self.check_thread(vm)?;
    let mut slot = self.db.lock();
    // Drop the handle while the lock is still held.
    drop(slot.take());
    Ok(())
}
// Delegates to the database handle's `implicit_commit`.
#[pymethod]
fn commit(&self, vm: &VirtualMachine) -> PyResult<()> {
    let db = self.db_lock(vm)?;
    db.implicit_commit(vm)
}
// Issues ROLLBACK unless the database reports autocommit mode.
#[pymethod]
fn rollback(&self, vm: &VirtualMachine) -> PyResult<()> {
    let db = self.db_lock(vm)?;
    if db.is_autocommit() {
        return Ok(());
    }
    db._exec(b"ROLLBACK\0", vm)
}
// Shortcut: creates a fresh cursor carrying the connection's row factory and runs
// `Cursor::execute` on it.
#[pymethod]
fn execute(
    zelf: PyRef<Self>,
    sql: PyStrRef,
    parameters: OptionalArg<PyObjectRef>,
    vm: &VirtualMachine,
) -> PyResult<PyRef<Cursor>> {
    let inherited_factory = zelf.row_factory.to_owned();
    let fresh_cursor = Cursor::new(zelf, inherited_factory, vm);
    Cursor::execute(fresh_cursor.into_ref(&vm.ctx), sql, parameters, vm)
}
// Shortcut: creates a fresh cursor carrying the connection's row factory and runs
// `Cursor::executemany` on it.
#[pymethod]
fn executemany(
    zelf: PyRef<Self>,
    sql: PyStrRef,
    seq_of_params: ArgIterable,
    vm: &VirtualMachine,
) -> PyResult<PyRef<Cursor>> {
    let inherited_factory = zelf.row_factory.to_owned();
    let fresh_cursor = Cursor::new(zelf, inherited_factory, vm);
    Cursor::executemany(fresh_cursor.into_ref(&vm.ctx), sql, seq_of_params, vm)
}
// Shortcut: creates a fresh cursor carrying the connection's row factory and runs
// `Cursor::executescript` on it.
#[pymethod]
fn executescript(
    zelf: PyRef<Self>,
    script: PyStrRef,
    vm: &VirtualMachine,
) -> PyResult<PyRef<Cursor>> {
    let inherited_factory = zelf.row_factory.to_owned();
    let fresh_cursor = Cursor::new(zelf, inherited_factory, vm).into_ref(&vm.ctx);
    Cursor::executescript(fresh_cursor, script, vm)
}
0