Description
scikit-learn's API lets you pass arbitrary training-sample metadata to custom transformers during `fit` using `**fit_params`, and this works as expected within `Pipeline` and in cross-validation tools such as `GridSearchCV`. But for some common transformation use cases, such as batch effect correction, you also need to pass test-sample metadata to `transform`, so that the test data can be transformed using the parameters learned from the training data in `fit`.
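For reference, here is a minimal sketch of the `fit`-time routing that already works. It assumes scikit-learn's default behaviour with metadata routing disabled: `Pipeline.fit` forwards keyword arguments named `<step>__<param>` to that step, and `GridSearchCV` slices sample-aligned fit parameters down to each training fold. The transformer `RecordsFitMetadata` and the `batch` parameter name are hypothetical and exist only for illustration.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline


class RecordsFitMetadata(TransformerMixin, BaseEstimator):
    """Hypothetical transformer whose fit() consumes one batch label per sample."""

    def fit(self, X, y=None, batch=None):
        X = np.asarray(X, dtype=float)
        if batch is None or len(batch) != len(X):
            raise ValueError("expected one batch label per training sample")
        # Remember which batches were seen during fit.
        self.fit_batches_ = np.unique(batch)
        return self

    def transform(self, X):
        # transform() gets X only: no per-sample metadata reaches it.
        return np.asarray(X, dtype=float)


rng = np.random.default_rng(0)
X = rng.normal(size=(60, 3))
y = np.tile([0, 1], 30)
batch = np.repeat(["a", "b", "c"], 20)

pipe = Pipeline([("meta", RecordsFitMetadata()), ("clf", LogisticRegression())])
# The "meta__" prefix routes `batch` to the fit() of the step named "meta".
pipe.fit(X, y, meta__batch=batch)
print(pipe.named_steps["meta"].fit_batches_)  # ['a' 'b' 'c']

# GridSearchCV subsets `batch` with the same train indices as X in each fold.
search = GridSearchCV(pipe, {"clf__C": [0.1, 1.0]}, cv=3)
search.fit(X, y, meta__batch=batch)
```

Scoring and prediction then call `transform(X)` with no batch argument, which is the gap this issue is about.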
For example, one can fit a linear model on the training data using each training sample's batch label and learn the coefficient vector associated with each batch. When transforming the training data, one can then subtract the component due to batch effects, again using the sample batch labels to apply the right coefficients. But when transforming the test data, one needs to know which batch each test sample belongs to in order to select the right coefficients learned from training.
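A rough NumPy sketch of the correction described above, assuming the simplest linear model: each feature is regressed on one-hot batch indicators with no intercept. With that design, least squares gives each batch's per-feature mean, so the correction reduces to per-batch centering. The class `LinearBatchCorrector` is hypothetical, not a scikit-learn API; the point is that its `transform` cannot work without the batch labels of the samples being transformed.

```python
import numpy as np


class LinearBatchCorrector:
    """Hypothetical batch-effect corrector (sketch, not a scikit-learn API)."""

    def fit(self, X, batch):
        X = np.asarray(X, dtype=float)
        batch = np.asarray(batch)
        self.batches_ = np.unique(batch)  # sorted unique batch labels
        # Design matrix: one indicator column per training batch, no intercept.
        design = (batch[:, None] == self.batches_[None, :]).astype(float)
        # coef_ has shape (n_batches, n_features); row k is batch k's effect.
        self.coef_, *_ = np.linalg.lstsq(design, X, rcond=None)
        return self

    def transform(self, X, batch):
        X = np.asarray(X, dtype=float)
        batch = np.asarray(batch)
        # Map each sample's batch label to its row in coef_.
        idx = np.searchsorted(self.batches_, batch)
        idx = np.clip(idx, 0, len(self.batches_) - 1)
        if not np.all(self.batches_[idx] == batch):
            raise ValueError("transform received a batch label not seen in fit")
        # Subtract the training-learned effect of each sample's own batch.
        return X - self.coef_[idx]


X_train = np.array([[1.0, 10.0], [3.0, 12.0], [11.0, 20.0], [13.0, 22.0]])
batch_train = np.array(["b1", "b1", "b2", "b2"])
X_test = np.array([[2.0, 11.0], [12.0, 21.0]])
batch_test = np.array(["b1", "b2"])

corrector = LinearBatchCorrector().fit(X_train, batch_train)
X_test_corrected = corrector.transform(X_test, batch_test)
```

Here the learned coefficients are approximately `[2, 11]` for `b1` and `[12, 21]` for `b2`, so both test rows are corrected to roughly zero. Inside a `Pipeline`, though, `transform` would be called without `batch_test`.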
Is there a way to do this with the current API? Am I missing a newer scikit-learn feature that accomplishes this, or is there a workaround?