8000 refactor: update py_obj_to_scalar_value to return a Result for better… · kosiew/datafusion-python@b89c695 · GitHub
[go: up one dir, main page]

Skip to content

Commit b89c695

Browse files
committed
refactor: update py_obj_to_scalar_value to return a Result for better error handling
1 parent d546f7a commit b89c695

File tree

3 files changed

+22
-24
lines changed

3 files changed

+22
-24
lines changed

src/config.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ impl PyConfig {
5959

6060
/// Set a configuration option
6161
pub fn set(&mut self, key: &str, value: PyObject, py: Python) -> PyDataFusionResult<()> {
62-
let scalar_value = py_obj_to_scalar_value(py, value);
62+
let scalar_value = py_obj_to_scalar_value(py, value)?;
6363
self.config.set(key, scalar_value.to_string().as_str())?;
6464
Ok(())
6565
}

src/dataframe.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,7 @@ impl PyDataFrame {
720720
columns: Option<Vec<PyBackedStr>>,
721721
py: Python,
722722
) -> PyDataFusionResult<Self> {
723-
let scalar_value = py_obj_to_scalar_value(py, value);
723+
let scalar_value = py_obj_to_scalar_value(py, value)?;
724724

725725
let cols = match columns {
726726
Some(col_names) => col_names.iter().map(|c| c.to_string()).collect(),

src/utils.rs

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -89,42 +89,40 @@ pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyRe
8989

9090
Ok(())
9191
}
92-
9392
/// Convert a Python object to ScalarValue using PyArrow
9493
///
9594
/// Args:
9695
/// py: Python interpreter
9796
/// obj: Python object to convert
9897
///
9998
/// Returns:
100-
/// ScalarValue representation of the Python object
99+
/// Result containing ScalarValue representation of the Python object
101100
///
102101
/// This function handles basic Python types directly and uses PyArrow
103102
/// for complex types like datetime.
104-
pub(crate) fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> ScalarValue {
103+
pub(crate) fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> PyResult<ScalarValue> {
105104
if let Ok(value) = obj.extract::<bool>(py) {
106-
ScalarValue::Boolean(Some(value))
105+
return Ok(ScalarValue::Boolean(Some(value)));
107106
} else if let Ok(value) = obj.extract::<i64>(py) {
108-
ScalarValue::Int64(Some(value))
107+
return Ok(ScalarValue::Int64(Some(value)));
109108
} else if let Ok(value) = obj.extract::<u64>(py) {
110-
ScalarValue::UInt64(Some(value))
109+
return Ok(ScalarValue::UInt64(Some(value)));
111110
} else if let Ok(value) = obj.extract::<f64>(py) {
112-
ScalarValue::Float64(Some(value))
111+
return Ok(ScalarValue::Float64(Some(value)));
113112
} else if let Ok(value) = obj.extract::<String>(py) {
114-
ScalarValue::Utf8(Some(value))
115-
} else {
116-
// For datetime and other complex types, convert via PyArrow
117-
let pa = py.import("pyarrow");
118-
let pa = pa.expect("Failed to import PyArrow");
119-
// Convert Python object to PyArrow scalar
120-
// This handles datetime objects by converting to PyArrow timestamp type
121-
let scalar = pa.call_method1("scalar", (obj,));
122-
let scalar = scalar.expect("Failed to convert Python object to PyArrow scalar");
123-
// Convert PyArrow scalar to PyScalarValue
124-
let py_scalar = PyScalarValue::extract_bound(scalar.as_ref());
125-
// Unwrap the result - this will panic if extraction failed
126-
let py_scalar = py_scalar.expect("Failed to extract PyScalarValue from PyArrow scalar");
127-
// Convert PyScalarValue to ScalarValue
128-
py_scalar.into()
113+
return Ok(ScalarValue::Utf8(Some(value)));
129114
}
115+
116+
// For datetime and other complex types, convert via PyArrow
117+
let pa = py.import("pyarrow")?;
118+
119+
// Convert Python object to PyArrow scalar
120+
let scalar = pa.call_method1("scalar", (obj,))?;
121+
122+
// Convert PyArrow scalar to PyScalarValue
123+
let py_scalar = PyScalarValue::extract_bound(scalar.as_ref())
124+
.map_err(|e| PyValueError::new_err(format!("Failed to extract PyScalarValue: {}", e)))?;
125+
126+
// Convert PyScalarValue to ScalarValue
127+
Ok(py_scalar.into())
130128
}

0 commit comments

Comments
 (0)
0