8000 Add zero-copy flag · data-apis/dataframe-api@5773efc · GitHub
[go: up one dir, main page]

Skip to content

Commit 5773efc

Browse files
committed
Add zero-copy flag
1 parent 1f0286b commit 5773efc

File tree

2 files changed

+45
-23
lines changed

2 files changed

+45
-23
lines changed

protocol/dataframe_protocol.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,16 +335,22 @@ class DataFrame:
335335
``__dataframe__`` method of a public data frame class in a library adhering
336336
to the dataframe interchange protocol specification.
337337
"""
338-
def __dataframe__(self, nan_as_null : bool = False) -> dict:
338+
def __dataframe__(self, nan_as_null : bool = False,
339+
allow_zero_copy : bool = True) -> dict:
339340
"""
340341
Produces a dictionary object following the dataframe protocol spec
341342
342343
``nan_as_null`` is a keyword intended for the consumer to tell the
343344
producer to overwrite null values in the data with ``NaN`` (or ``NaT``).
344345
It is intended for cases where the consumer does not support the bit
345346
mask or byte mask that is the producer's native representation.
347+
348+
``allow_zero_copy`` is a keyword that defines if the given implementation
349+
is going to support striding buffers. It is optional, and the libraries
350+
do not need to implement it.
346351
"""
347352
self._nan_as_null = nan_as_null
353+
self._allow_zero_zopy = allow_zero_copy
348354
return {
349355
"dataframe": self, # DataFrame object adhering to the protocol
350356
"version": 0 # Version number of the protocol

protocol/pandas_implementation.py

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@
3535
ColumnObject = Any
3636

3737

38-
def from_dataframe(df : DataFrameObject) -> pd.DataFrame:
38+
def from_dataframe(df : DataFrameObject,
39+
allow_zero_copy : bool = False) -> pd.DataFrame:
3940
"""
4041
Construct a pandas DataFrame from ``df`` if it supports ``__dataframe__``
4142
"""
@@ -46,7 +47,7 @@ def from_dataframe(df : DataFrameObject) -> pd.DataFrame:
4647
if not hasattr(df, '__dataframe__'):
4748
raise ValueError("`df` does not support __dataframe__")
4849

49-
return _from_dataframe(df.__dataframe__())
50+
return _from_dataframe(df.__dataframe__(allow_zero_copy=allow_zero_copy))
5051

5152

5253
def _from_dataframe(df : DataFrameObject) -> pd.DataFrame:
@@ -160,7 +161,8 @@ def convert_categorical_column(col : ColumnObject) -> pd.Series:
160161
return series
161162

162163

163-
def __dataframe__(cls, nan_as_null : bool = False) -> dict:
164+
def __dataframe__(cls, nan_as_null : bool = False,
165+
allow_zero_copy : bool = False) -> dict:
164166
"""
165167
The public method to attach to pd.DataFrame
166168
@@ -171,8 +173,14 @@ def __dataframe__(cls, nan_as_null : bool = False) -> dict:
171173
producer to overwrite null values in the data with ``NaN`` (or ``NaT``).
172174
This currently has no effect; once support for nullable extension
173175
dtypes is added, this value should be propagated to columns.
176+
177+
``allow_zero_copy`` is a keyword that defines if the given implementation
178+
is going to support striding buffers. It is optional, and the libraries
179+
do not need to implement it. Currently, if the flag is set to ``True`` it
180+
will raise a ``RuntimeError``.
174181
"""
175-
return _PandasDataFrame(cls, nan_as_null=nan_as_null)
182+
return _PandasDataFrame(
183+
cls, nan_as_null=nan_as_null, allow_zero_copy=allow_zero_copy)
176184

177185

178186
# Monkeypatch the Pandas DataFrame class to support the interchange protocol
@@ -187,16 +195,16 @@ class _PandasBuffer:
187195
Data in the buffer is guaranteed to be contiguous in memory.
188196
"""
189197

190-
def __init__(self, x : np.ndarray) -> None:
198+
def __init__(self, x : np.ndarray, allow_zero_copy : bool = False) -> None:
191199
"""
192200
Handle only regular columns (= numpy arrays) for now.
193201
"""
194-
if not x.strides == (x.dtype.itemsize,):
195-
# Array is not contiguous - this is possible to get in Pandas,
196-
# there was some discussion on whether to support it. Som extra
197-
# complexity for libraries that don't support it (e.g. Arrow),
198-
# but would help with numpy-based libraries like Pandas.
199-
raise RuntimeError("Design needs fixing - non-contiguous buffer")
202+
if allow_zero_copy:
203+
# Array is not contiguous and strided buffers do not need to be
204+
# supported. It brings some extra complexity for libraries that
205+
# don't support it (e.g. Arrow).
206+
raise RuntimeError(
207+
"Exports cannot be zero-copy in the case of a non-contiguous buffer")
200208

201209
# Store the numpy array in which the data resides as a private
202210
# attribute, so we can use it to retrieve the public attributes
@@ -251,7 +259,8 @@ class _PandasColumn:
251259
252260
"""
253261

254-
def __init__(self, column : pd.Series) -> None:
262+
def __init__(self, column : pd.Series,
263+
allow_zero_copy : bool = False) -> None:
255264
"""
256265
Note: doesn't deal with extension arrays yet, just assume a regular
257266
Series/ndarray for now.
@@ -262,6 +271,7 @@ def __init__(self, column : pd.Series) -> None:
262271

263272
# Store the column as a private attribute
264273
self._col = column
274+
self._allow_zero_copy = allow_zero_copy
265275

266276
@property
267277
def size(self) -> int:
@@ -446,11 +456,13 @@ def get_data_buffer(self) -> Tuple[_PandasBuffer, Any]: # Any is for self.dtype
446456
"""
447457
_k = _DtypeKind
448458
if self.dtype[0] in (_k.INT, _k.UINT, _k.FLOAT, _k.BOOL):
449-
buffer = _PandasBuffer(self._col.to_numpy())
459+
buffer = _PandasBuffer(
460+
self._col.to_numpy(), allow_zero_copy=self._allow_zero_copy)
450461
dtype = self.dtype
451462
elif self.dtype[0] == _k.CATEGORICAL:
452463
codes = self._col.values.codes
453-
buffer = _PandasBuffer(codes)
464+
buffer = _PandasBuffer(
465+
codes, allow_zero_copy=self._allow_zero_copy)
454466
dtype = self._dtype_from_pandasdtype(codes.dtype)
455467
else:
456468
raise NotImplementedError(f"Data type {self._col.dtype} not handled yet")
@@ -483,7 +495,8 @@ class _PandasDataFrame:
483495
``pd.DataFrame.__dataframe__`` as objects with the methods and
484496
attributes defined on this class.
485497
"""
486-
def __init__(self, df : pd.DataFrame, nan_as_null : bool = False) -> None:
498+
def __init__(self, df : pd.DataFrame, nan_as_null : bool = False,
499+
allow_zero_copy : bool = False) -> None:
487500
"""
488501
Constructor - an instance of this (private) class is returned from
489502
`pd.DataFrame.__dataframe__`.
@@ -494,6 +507,7 @@ def __init__(self, df : pd.DataFrame, nan_as_null : bool = False) -> None:
494507
# This currently has no effect; once support for nullable extension
495508
# dtypes is added, this value should be propagated to columns.
496509
self._nan_as_null = nan_as_null
510+
self._allow_zero_copy = allow_zero_copy
497511

498512
def num_columns(self) -> int:
499513
return len(self._df.columns)
@@ -508,13 +522,16 @@ def column_names(self) -> Iterable[str]:
508522
return self._df.columns.tolist()
509523

510524
def get_column(self, i: int) -> _PandasColumn:
511-
return _PandasColumn(self._df.iloc[:, i])
525+
return _PandasColumn(
526+
self._df.iloc[:, i], allow_zero_copy=self._allow_zero_copy)
512527

513528
def get_column_by_name(self, name: str) -> _PandasColumn:
514-
return _PandasColumn(self._df[name])
529+
return _PandasColumn(
530+
self._df[name], allow_zero_copy=self._allow_zero_copy)
515531

516532
def get_columns(self) -> Iterable[_PandasColumn]:
517-
return [_PandasColumn(self._df[name]) for name in self._df.columns]
533+
return [_PandasColumn(self._df[name], allow_zero_copy=self._allow_zero_copy)
534+
for name in self._df.columns]
518535

519536
def select_columns(self, indices: Sequence[int]) -> '_PandasDataFrame':
520537
if not isinstance(indices, collections.Sequence):
@@ -552,13 +569,12 @@ def test_mixed_intfloat():
552569

553570

554571
def test_noncontiguous_columns():
555-
# Currently raises: TBD whether it should work or not, see code comment
556-
# where the RuntimeError is raised.
572+
# Currently raises if the flag of allow zero copy is True.
557573
arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
558574
df = pd.DataFrame(arr)
559575
assert df[0].to_numpy().strides == (24,)
560-
pytest.raises(RuntimeError, from_dataframe, df)
561-
#df2 = from_dataframe(df)
576+
with pytest.raises(RuntimeError):
577+
df2 = from_dataframe(df, allow_zero_copy=True)
562578

563579

564580
def test_categorical_dtype():

0 commit comments

Comments
 (0)
0