Wrap tabular data in a new dataclass to simplify ML pipelines · Issue #25126 · scikit-learn/scikit-learn · GitHub

Closed
zkurtz opened this issue Dec 7, 2022 · 1 comment
Labels
Needs Triage (issue requires triage), New Feature

Comments

@zkurtz
zkurtz commented Dec 7, 2022

Describe the workflow you want to enable

In my dreams, a new InferenceData class would simplify training and prediction to look more like this:

import ... as learner

data = InferenceData(
    df=..., # a data frame 
    meta=Meta(
        y_cols=..., # name(s) of output variable
        ... # additional metadata fields
    )
)
train_data, test_data = data.split(train_fraction=0.7, ...)
learner.fit(train_data)
predictions = learner.predict(test_data.x)

Note that this introduces just two data variables train_data and test_data instead of the current standard four (X_train, X_test, y_train, y_test).

In addition, InferenceData could easily be extended so that the pipeline above also handles related per-row metadata such as row weights, for example by replacing a step like learner.fit(train_data) with learner.fit(train_data, weights=train_data.w).

Describe your proposed solution

A solution could look something like this:

from dataclasses import dataclass, field
from typing import Iterable, Optional

import numpy as np
from pandas import DataFrame


@dataclass
class Meta:
    """Metadata for a Data class."""
    y_cols: Optional[list[str]] = None
    row_weights_col: Optional[str] = None

    @property
    def y(self) -> list[str]:
        """Output variable column names."""
        if not self.y_cols:
            return []
        return self.y_cols

    @property
    def columns(self) -> set[str]:
        """All metadata columns."""
        cols = set(self.y)
        if self.row_weights_col:
            cols.add(self.row_weights_col)
        return cols


@dataclass
class InferenceData:
    """A data frame container that includes metadata relevant for machine learning and inference."""

    df: DataFrame
    # default_factory gives each instance its own Meta; a bare Meta() default
    # would be shared across instances and is rejected by Python 3.11+.
    meta: Meta = field(default_factory=Meta)

    def __post_init__(self) -> None:
        """Parameter validation."""
        if (self.y is not None) and (self.n_rows != len(self.y)):
            raise ValueError("Expected y to have the same number of data points as x has.")
        # TODO: also validate that all columns referenced in self.meta exist in self.df etc

    @property
    def x(self) -> DataFrame:
        """The data frame of predictor variables, excluding output variables and weights, etc."""
        non_metadata_cols = [col for col in self.df if col not in self.meta.columns]
        return self.df[non_metadata_cols]

    @property
    def y(self) -> Optional[np.ndarray]:
        """Output features."""
        if not self.meta.y_cols:
            return None
        if len(self.meta.y_cols) == 1:
            y_col = self.meta.y_cols[0]
            return self.df[y_col].to_numpy()
        return self.df[self.meta.y_cols].to_numpy()

    @property
    def w(self) -> Optional[np.ndarray]:
        """Row weights, or None if no weights column is configured."""
        if not self.meta.row_weights_col:
            return None
        return self.df[self.meta.row_weights_col].to_numpy()

    def iloc(self, positional_indexes: Iterable) -> "InferenceData":
        """Return the subset of the data in self corresponding to the specified positional indices."""
        return InferenceData(df=self.df.iloc[positional_indexes], meta=self.meta)

    @property
    def n_rows(self) -> int:
        """Number of rows of data."""
        return self.df.shape[0]
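
The workflow at the top calls data.split(train_fraction=0.7, ...), which the class above does not define. One way it could be built on top of iloc is sketched below. This is only an illustration under assumptions: the seed parameter, the random-permutation strategy, and the trimmed stand-in class (with Meta flattened into a single y_cols field) are hypothetical and not part of the proposal.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np
import pandas as pd


@dataclass
class InferenceData:
    """Trimmed stand-in for the proposed class: just enough to show split()."""

    df: pd.DataFrame
    y_cols: Optional[list[str]] = None  # flattened from Meta to keep the sketch short

    def iloc(self, positional_indexes) -> "InferenceData":
        """Row subset by position, carrying the metadata along."""
        return InferenceData(df=self.df.iloc[positional_indexes], y_cols=self.y_cols)

    def split(
        self, train_fraction: float = 0.7, seed: Optional[int] = None
    ) -> tuple["InferenceData", "InferenceData"]:
        """Shuffle the rows and cut them into train and test parts (hypothetical)."""
        if not 0.0 < train_fraction < 1.0:
            raise ValueError("train_fraction must be strictly between 0 and 1.")
        n_rows = self.df.shape[0]
        # Random permutation of row positions; seed makes the split reproducible.
        order = np.random.default_rng(seed).permutation(n_rows)
        n_train = int(round(train_fraction * n_rows))
        return self.iloc(order[:n_train]), self.iloc(order[n_train:])


data = InferenceData(
    df=pd.DataFrame({"a": range(10), "b": range(10, 20), "target": range(20, 30)}),
    y_cols=["target"],
)
train_data, test_data = data.split(train_fraction=0.7, seed=0)
```

Because both halves are produced through iloc, the metadata travels with each subset, which is what lets the pipeline keep two variables instead of four.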

Of course, the usefulness of this solution will depend on wrapping existing machine learning algorithms to accept an InferenceData class as input.
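
As a rough sketch of what such wrapping might look like, the adapter below unpacks an InferenceData object into the X, y and weight arguments that an estimator with fit/predict methods expects. Everything here is an assumption made for illustration: InferenceDataLearner, the trimmed stand-in InferenceData (with Meta flattened into two fields), and the toy WeightedMeanRegressor are hypothetical names. The sample_weight keyword follows the convention scikit-learn uses for estimators that support weighting.

```python
from dataclasses import dataclass
from typing import Any, Optional

import numpy as np
import pandas as pd


@dataclass
class InferenceData:
    """Trimmed stand-in for the proposed class (Meta flattened into two fields)."""

    df: pd.DataFrame
    y_col: str
    row_weights_col: Optional[str] = None

    @property
    def x(self) -> pd.DataFrame:
        """Predictor columns: everything except the output and weight columns."""
        meta_cols = {self.y_col, self.row_weights_col}
        return self.df[[col for col in self.df if col not in meta_cols]]

    @property
    def y(self) -> np.ndarray:
        """Output variable as an array."""
        return self.df[self.y_col].to_numpy()

    @property
    def w(self) -> Optional[np.ndarray]:
        """Row weights, or None if no weights column is configured."""
        if self.row_weights_col is None:
            return None
        return self.df[self.row_weights_col].to_numpy()


class InferenceDataLearner:
    """Hypothetical adapter: unpacks InferenceData for an estimator with fit/predict."""

    def __init__(self, estimator: Any) -> None:
        self.estimator = estimator

    def fit(self, data: InferenceData) -> "InferenceDataLearner":
        # Only pass weights when the data carries them, so estimators
        # without a sample_weight argument still work on unweighted data.
        if data.w is None:
            self.estimator.fit(data.x, data.y)
        else:
            self.estimator.fit(data.x, data.y, sample_weight=data.w)
        return self

    def predict(self, x: pd.DataFrame) -> np.ndarray:
        return self.estimator.predict(x)


class WeightedMeanRegressor:
    """Toy estimator standing in for a real one; predicts the weighted mean of y."""

    def fit(self, X, y, sample_weight=None):
        self.mean_ = float(np.average(y, weights=sample_weight))
        return self

    def predict(self, X):
        return np.full(len(X), self.mean_)


train_data = InferenceData(
    df=pd.DataFrame({"a": [1.0, 2.0, 3.0], "target": [10.0, 20.0, 30.0], "w": [1.0, 1.0, 2.0]}),
    y_col="target",
    row_weights_col="w",
)
learner = InferenceDataLearner(WeightedMeanRegressor()).fit(train_data)
predictions = learner.predict(train_data.x)
```

The toy regressor could be swapped for any estimator whose fit accepts X, y and optionally sample_weight. An adapter like this leaves the estimator itself untouched, at the cost of one wrapper object per learner.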

Describe alternatives you've considered, if relevant

No response

Additional context

No response

@zkurtz zkurtz added Needs Triage Issue requires triage New Feature labels Dec 7, 2022
@adrinjalali
Member

Duplicate of #13123

@adrinjalali adrinjalali marked this as a duplicate of #13123 Dec 7, 2022