From f719269205eb1d16d300bd51b0119f808f54b763 Mon Sep 17 00:00:00 2001 From: Roman Yurchak Date: Thu, 26 Dec 2019 16:15:32 +0100 Subject: [PATCH 01/11] ENH TF-IGM feature weighting (initial implementation) --- .../plot_tfigm_weighting_text.py | 27 +++++++ sklearn_extra/feature_weighting/__init__.py | 3 + sklearn_extra/feature_weighting/_text.py | 70 +++++++++++++++++++ .../feature_weighting/tests/test_text.py | 15 ++++ 4 files changed, 115 insertions(+) create mode 100644 examples/feature_weighting/plot_tfigm_weighting_text.py create mode 100644 sklearn_extra/feature_weighting/__init__.py create mode 100644 sklearn_extra/feature_weighting/_text.py create mode 100644 sklearn_extra/feature_weighting/tests/test_text.py diff --git a/examples/feature_weighting/plot_tfigm_weighting_text.py b/examples/feature_weighting/plot_tfigm_weighting_text.py new file mode 100644 index 00000000..18a47c8f --- /dev/null +++ b/examples/feature_weighting/plot_tfigm_weighting_text.py @@ -0,0 +1,27 @@ +from sklearn.linear_model import LogisticRegression +from sklearn.preprocessing import Normalizer +from sklearn.pipeline import make_pipeline +from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer +from sklearn.datasets import fetch_20newsgroups +from sklearn.model_selection import cross_val_score +from sklearn.metrics import f1_score + +from sklearn_extra.feature_weighting import TfigmTransformer + + +X, y = fetch_20newsgroups(return_X_y=True) + +for scaler in [TfidfTransformer(), TfigmTransformer(alpha=9)]: + pipe = make_pipeline( + CountVectorizer(min_df=5, stop_words="english"), + scaler, + Normalizer() + ) + X_tr = pipe.fit_transform(X, y) + est = LogisticRegression(random_state=2, solver="liblinear") + scores = cross_val_score( + est, X_tr, y, verbose=1, + scoring=lambda est, X, y: f1_score(y, est.predict(X), average="macro"), + ) + print(f"{scaler.__class__.__name__} F1-macro score: " + f"{scores.mean():.3f}+-{scores.std():.3f}") diff --git a/sklearn_extra/feature_weighting/__init__.py b/sklearn_extra/feature_weighting/__init__.py new file mode 100644 index 00000000..ff7809f1 --- /dev/null +++ b/sklearn_extra/feature_weighting/__init__.py @@ -0,0 +1,3 @@ +from ._text import TfigmTransformer + +__all__ = ["TfigmTransformer"] diff --git a/sklearn_extra/feature_weighting/_text.py b/sklearn_extra/feature_weighting/_text.py new file mode 100644 index 00000000..dbee04d0 --- /dev/null +++ b/sklearn_extra/feature_weighting/_text.py @@ -0,0 +1,70 @@ +import numpy as np +import scipy.sparse as sp + +from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.utils.validation import check_array, check_X_y +from sklearn.preprocessing import LabelEncoder + + +class TfigmTransformer(BaseEstimator, TransformerMixin): + """Apply TF-IGM feature weighting + + Parameters + ---------- + alpha : float, default=7 + regularization parameter + + References + ---------- + Chen, Kewen, et al. "Turning from TF-IDF to TF-IGM for term weighting + in text classification." Expert Systems with Applications 66 (2016): + 245-260. 
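+
+    Notes
+    -----
+    A sketch of the weighting computed below (informal notation; see the
+    reference for the exact definition): for a term with per-class
+    frequencies sorted in decreasing order, f_1 >= f_2 >= ... >= f_K,
+    the weight is::
+
+        w = 1 + alpha * f_1 / (1 * f_1 + 2 * f_2 + ... + K * f_K)
+
+    so a term concentrated in a single class gets a weight close to
+    ``1 + alpha``, while a term spread evenly over all classes stays
+    close to 1.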
+ """ + + def __init__(self, alpha=7.0): + self.alpha = alpha + + def _fit(self, X, y): + self._le = LabelEncoder().fit(y) + class_freq = np.zeros((len(self._le.classes_), X.shape[1])) + + X_nz = X != 0 + if sp.issparse(X_nz): + X_nz = X_nz.asformat("csr", copy=False) + + for idx, class_label in enumerate(self._le.classes_): + y_mask = y == class_label + n_samples = y_mask.sum() + class_freq[idx, :] = X_nz[y_mask].sum(axis=0) / n_samples + + self._class_freq = class_freq + self._class_rank = np.argsort(-self._class_freq, axis=0) + f1 = self._class_freq[ + self._class_rank[0, :], np.arange(self._class_freq.shape[1]) + ] + fk = (self._class_freq * (self._class_rank + 1)).sum(axis=0) + self.coef_ = 1 + self.alpha * (f1 / fk) + return self + + def fit(self, X, y): + X, y = check_X_y(X, y, accept_sparse=["csr", "csc"]) + self._fit(X, y) + return self + + def _transform(self, X): + if sp.issparse(X): + X_tr = X @ sp.diags(self.coef_) + else: + X_tr = X * self.coef_[None, :] + return X_tr + + def transform(self, X): + X = check_array(X, accept_sparse=["csr", "csc"]) + X_tr = self._transform(X) + return X_tr + + def fit_transform(self, X, y): + X, y = check_X_y(X, y, accept_sparse=["csr", "csc"]) + self._fit(X, y) + X_tr = self._transform(X) + return X_tr diff --git a/sklearn_extra/feature_weighting/tests/test_text.py b/sklearn_extra/feature_weighting/tests/test_text.py new file mode 100644 index 00000000..c1d6d9c6 --- /dev/null +++ b/sklearn_extra/feature_weighting/tests/test_text.py @@ -0,0 +1,15 @@ +import numpy as np +from numpy.testing import assert_allclose +import scipy.sparse as sp +from sklearn_extra.feature_weighting import TfigmTransformer + + +def test_tfigm_transform(): + X = np.array([[0, 1, 1], [1, 0, 1], [0, 0, 1], [1, 1, 1]]) + X = sp.csr_matrix(X) + y = np.array(["a", "b", "a", "c"]) + + est = TfigmTransformer() + X_tr = est.fit_transform(X, y) + assert X_tr.shape == X.shape + assert_allclose(est.coef_, [4.5, 2.75, 1.777778], rtol=1e-4) From ff8e06515514a822bd3454741843407a62aaa121 Mon Sep 17 00:00:00 2001 From: Roman Yurchak Date: Thu, 26 Dec 2019 17:11:21 +0100 Subject: [PATCH 02/11] Improve example --- .gitignore | 1 + .../plot_tfigm_weighting_text.py | 43 +++++++++++++++---- 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index ecdba127..0df2f23e 100644 --- a/.gitignore +++ b/.gitignore @@ -57,6 +57,7 @@ nosetests.xml coverage.xml *,cover .hypothesis/ +.swp # Translations *.mo diff --git a/examples/feature_weighting/plot_tfigm_weighting_text.py b/examples/feature_weighting/plot_tfigm_weighting_text.py index 18a47c8f..f90c6a6f 100644 --- a/examples/feature_weighting/plot_tfigm_weighting_text.py +++ b/examples/feature_weighting/plot_tfigm_weighting_text.py @@ -1,9 +1,14 @@ +import pandas as pd +import numpy as np +from tqdm import tqdm + from sklearn.linear_model import LogisticRegression -from sklearn.preprocessing import Normalizer +from sklearn.svm import LinearSVC +from sklearn.preprocessing import Normalizer, FunctionTransformer from sklearn.pipeline import make_pipeline from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer from sklearn.datasets import fetch_20newsgroups -from sklearn.model_selection import cross_val_score +from sklearn.model_selection import cross_validate from sklearn.metrics import f1_score from sklearn_extra.feature_weighting import TfigmTransformer @@ -11,7 +16,16 @@ X, y = fetch_20newsgroups(return_X_y=True) -for scaler in [TfidfTransformer(), TfigmTransformer(alpha=9)]: 
+#print('classes:', pd.Series(y).value_counts()) +res = [] + +for scaler_label, scaler in tqdm([ + ("identity", FunctionTransformer(lambda x: x)), + ("TF-IDF", TfidfTransformer()), + #("TF-IDF(smooth_idf=True, sublinear_tf=False)", TfidfTransformer()), + #("TF-IDF(smooth_idf=False, sublinear_tf=False)", TfidfTransformer(smooth_idf=False)), + #("TF-IDF(smooth_idf=True, sublinear_tf=True)", TfidfTransformer(sublinear_tf=True)), + ("TF-IGM", TfigmTransformer(alpha=7))]): pipe = make_pipeline( CountVectorizer(min_df=5, stop_words="english"), scaler, @@ -19,9 +33,22 @@ ) X_tr = pipe.fit_transform(X, y) est = LogisticRegression(random_state=2, solver="liblinear") - scores = cross_val_score( - est, X_tr, y, verbose=1, - scoring=lambda est, X, y: f1_score(y, est.predict(X), average="macro"), + #est = LinearSVC() + scoring={ + 'F1-macro': lambda est, X, y: f1_score(y, est.predict(X), average="macro"), + 'balanced_accuracy': "balanced_accuracy" + } + scores = cross_validate( + est, X_tr, y, verbose=0, + n_jobs=6, + scoring=scoring, + return_train_score=True ) - print(f"{scaler.__class__.__name__} F1-macro score: " - f"{scores.mean():.3f}+-{scores.std():.3f}") + res.extend([{'metric': "_".join(key.split('_')[1:]), + 'subset': key.split('_')[0], + "preprocessing": scaler_label, + "score": f"{val.mean():.3f}+-{val.std():.3f}"} + for key, val in scores.items() if not key.endswith('_time')]) +scores = pd.DataFrame(res).set_index(["preprocessing", "metric", 'subset'])['score'].unstack(-1) +scores = scores['test'].unstack(-1).sort_values("F1-macro", ascending=False) +print(scores) From 92486f52e427f4582dd9692fafa6ecfad5b2a78f Mon Sep 17 00:00:00 2001 From: Roman Yurchak Date: Thu, 26 Dec 2019 17:26:16 +0100 Subject: [PATCH 03/11] Fix bug in TF-IGM --- .gitignore | 2 +- .../plot_tfigm_weighting_text.py | 13 +++++----- sklearn_extra/feature_weighting/_text.py | 24 +++++++++++++------ .../feature_weighting/tests/test_text.py | 2 +- 4 files changed, 26 insertions(+), 15 deletions(-) diff --git a/.gitignore b/.gitignore index 0df2f23e..b23663b6 100644 --- a/.gitignore +++ b/.gitignore @@ -57,7 +57,7 @@ nosetests.xml coverage.xml *,cover .hypothesis/ -.swp +*.swp # Translations *.mo diff --git a/examples/feature_weighting/plot_tfigm_weighting_text.py b/examples/feature_weighting/plot_tfigm_weighting_text.py index f90c6a6f..9881af6d 100644 --- a/examples/feature_weighting/plot_tfigm_weighting_text.py +++ b/examples/feature_weighting/plot_tfigm_weighting_text.py @@ -21,11 +21,12 @@ for scaler_label, scaler in tqdm([ ("identity", FunctionTransformer(lambda x: x)), - ("TF-IDF", TfidfTransformer()), - #("TF-IDF(smooth_idf=True, sublinear_tf=False)", TfidfTransformer()), - #("TF-IDF(smooth_idf=False, sublinear_tf=False)", TfidfTransformer(smooth_idf=False)), - #("TF-IDF(smooth_idf=True, sublinear_tf=True)", TfidfTransformer(sublinear_tf=True)), - ("TF-IGM", TfigmTransformer(alpha=7))]): + ("TF-IDF(sublinar_tf=False)", TfidfTransformer()), + ("TF-IDF(sublinear_tf=True)", TfidfTransformer(sublinear_tf=True)), + ("TF-IGM(tf_scale=None)", TfigmTransformer(alpha=7)), + ("TF-IGM(tf_scale='sqrt')", TfigmTransformer(alpha=7, tf_scale="sqrt")), + ("TF-IGM(tf_scale='log1p')", TfigmTransformer(alpha=7, tf_scale="log1p")), + ]): pipe = make_pipeline( CountVectorizer(min_df=5, stop_words="english"), scaler, @@ -33,7 +34,7 @@ ) X_tr = pipe.fit_transform(X, y) est = LogisticRegression(random_state=2, solver="liblinear") - #est = LinearSVC() + est = LinearSVC() scoring={ 'F1-macro': lambda est, X, y: f1_score(y, est.predict(X), 
average="macro"), 'balanced_accuracy': "balanced_accuracy" diff --git a/sklearn_extra/feature_weighting/_text.py b/sklearn_extra/feature_weighting/_text.py index dbee04d0..747a379c 100644 --- a/sklearn_extra/feature_weighting/_text.py +++ b/sklearn_extra/feature_weighting/_text.py @@ -21,12 +21,14 @@ class TfigmTransformer(BaseEstimator, TransformerMixin): 245-260. """ - def __init__(self, alpha=7.0): + def __init__(self, alpha=7.0, tf_scale=None): self.alpha = alpha + self.tf_scale = tf_scale def _fit(self, X, y): self._le = LabelEncoder().fit(y) - class_freq = np.zeros((len(self._le.classes_), X.shape[1])) + n_class = len(self._le.classes_) + class_freq = np.zeros((n_class, X.shape[1])) X_nz = X != 0 if sp.issparse(X_nz): @@ -38,11 +40,10 @@ def _fit(self, X, y): class_freq[idx, :] = X_nz[y_mask].sum(axis=0) / n_samples self._class_freq = class_freq - self._class_rank = np.argsort(-self._class_freq, axis=0) - f1 = self._class_freq[ - self._class_rank[0, :], np.arange(self._class_freq.shape[1]) - ] - fk = (self._class_freq * (self._class_rank + 1)).sum(axis=0) + class_freq_sort = np.sort(self._class_freq, axis=0) + f1 = class_freq_sort[-1, :] + + fk = (class_freq_sort * np.arange(n_class, 0, -1)[:, None]).sum(axis=0) self.coef_ = 1 + self.alpha * (f1 / fk) return self @@ -52,6 +53,15 @@ def fit(self, X, y): return self def _transform(self, X): + if self.tf_scale is None: + pass + elif self.tf_scale == 'sqrt': + X = np.sqrt(X) + elif self.tf_scale == 'log1p': + X = np.log1p(X) + else: + raise ValueError + if sp.issparse(X): X_tr = X @ sp.diags(self.coef_) else: diff --git a/sklearn_extra/feature_weighting/tests/test_text.py b/sklearn_extra/feature_weighting/tests/test_text.py index c1d6d9c6..c088e058 100644 --- a/sklearn_extra/feature_weighting/tests/test_text.py +++ b/sklearn_extra/feature_weighting/tests/test_text.py @@ -12,4 +12,4 @@ def test_tfigm_transform(): est = TfigmTransformer() X_tr = est.fit_transform(X, y) assert X_tr.shape == X.shape - assert_allclose(est.coef_, [4.5, 2.75, 1.777778], rtol=1e-4) + assert_allclose(est.coef_, [3.333333, 4.5, 2.166667], rtol=1e-5) From 04fd530d11fed438d5619d4a71902eaa4f273c45 Mon Sep 17 00:00:00 2001 From: Roman Yurchak Date: Thu, 26 Dec 2019 19:57:13 +0100 Subject: [PATCH 04/11] Improve example --- examples/feature_weighting/plot_tfigm_text.py | 73 +++++++++++++++++++ .../plot_tfigm_weighting_text.py | 55 -------------- sklearn_extra/feature_weighting/_text.py | 5 +- sklearn_extra/tests/test_common.py | 9 ++- 4 files changed, 85 insertions(+), 57 deletions(-) create mode 100644 examples/feature_weighting/plot_tfigm_text.py delete mode 100644 examples/feature_weighting/plot_tfigm_weighting_text.py diff --git a/examples/feature_weighting/plot_tfigm_text.py b/examples/feature_weighting/plot_tfigm_text.py new file mode 100644 index 00000000..c014f6b3 --- /dev/null +++ b/examples/feature_weighting/plot_tfigm_text.py @@ -0,0 +1,73 @@ +import os + +import pandas as pd +import numpy as np + +from sklearn.svm import LinearSVC +from sklearn.preprocessing import Normalizer, FunctionTransformer +from sklearn.pipeline import make_pipeline +from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer +from sklearn.datasets import fetch_20newsgroups +from sklearn.model_selection import cross_validate +from sklearn.metrics import f1_score + +from sklearn_extra.feature_weighting import TfigmTransformer + +if "CI" in os.environ: + # make this example run faster in CI + categories = ["sci.crypt", "comp.graphics", 'comp.sys.mac.hardware'] +else: + 
categories = None + +docs, y = fetch_20newsgroups(return_X_y=True, categories=categories) + + +vect = CountVectorizer(min_df=5, stop_words="english", ngram_range=(1, 2)) +X = vect.fit_transform(docs) + +res = [] + +for scaler_label, scaler in [ + ("identity", FunctionTransformer(lambda x: x)), + ("TF-IDF(sublinear_tf=False)", TfidfTransformer()), + ("TF-IDF(sublinear_tf=True)", TfidfTransformer(sublinear_tf=True)), + ("TF-IGM(tf_scale=None)", TfigmTransformer(alpha=5)), + ( + "TF-IGM(tf_scale='sqrt')", + TfigmTransformer(alpha=5, tf_scale="sqrt"), + ), + ( + "TF-IGM(tf_scale='log1p')", + TfigmTransformer(alpha=5, tf_scale="log1p"), + ), +]: + pipe = make_pipeline(scaler, Normalizer()) + X_tr = pipe.fit_transform(X, y) + est = LinearSVC() + scoring = { + "F1-macro": lambda est, X, y: f1_score( + y, est.predict(X), average="macro" + ), + "balanced_accuracy": "balanced_accuracy", + } + scores = cross_validate( + est, + X_tr, + y, + scoring=scoring, + ) + for key, val in scores.items(): + if not key.endswith("_time"): + res.append({ + "metric": "_".join(key.split("_")[1:]), + "subset": key.split("_")[0], + "preprocessing": scaler_label, + "score": f"{val.mean():.3f}+-{val.std():.3f}", + }) +scores = ( + pd.DataFrame(res) + .set_index(["preprocessing", "metric", "subset"])["score"] + .unstack(-1) +) +scores = scores["test"].unstack(-1).sort_values("F1-macro", ascending=False) +print(scores) diff --git a/examples/feature_weighting/plot_tfigm_weighting_text.py b/examples/feature_weighting/plot_tfigm_weighting_text.py deleted file mode 100644 index 9881af6d..00000000 --- a/examples/feature_weighting/plot_tfigm_weighting_text.py +++ /dev/null @@ -1,55 +0,0 @@ -import pandas as pd -import numpy as np -from tqdm import tqdm - -from sklearn.linear_model import LogisticRegression -from sklearn.svm import LinearSVC -from sklearn.preprocessing import Normalizer, FunctionTransformer -from sklearn.pipeline import make_pipeline -from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer -from sklearn.datasets import fetch_20newsgroups -from sklearn.model_selection import cross_validate -from sklearn.metrics import f1_score - -from sklearn_extra.feature_weighting import TfigmTransformer - - -X, y = fetch_20newsgroups(return_X_y=True) - -#print('classes:', pd.Series(y).value_counts()) -res = [] - -for scaler_label, scaler in tqdm([ - ("identity", FunctionTransformer(lambda x: x)), - ("TF-IDF(sublinar_tf=False)", TfidfTransformer()), - ("TF-IDF(sublinear_tf=True)", TfidfTransformer(sublinear_tf=True)), - ("TF-IGM(tf_scale=None)", TfigmTransformer(alpha=7)), - ("TF-IGM(tf_scale='sqrt')", TfigmTransformer(alpha=7, tf_scale="sqrt")), - ("TF-IGM(tf_scale='log1p')", TfigmTransformer(alpha=7, tf_scale="log1p")), - ]): - pipe = make_pipeline( - CountVectorizer(min_df=5, stop_words="english"), - scaler, - Normalizer() - ) - X_tr = pipe.fit_transform(X, y) - est = LogisticRegression(random_state=2, solver="liblinear") - est = LinearSVC() - scoring={ - 'F1-macro': lambda est, X, y: f1_score(y, est.predict(X), average="macro"), - 'balanced_accuracy': "balanced_accuracy" - } - scores = cross_validate( - est, X_tr, y, verbose=0, - n_jobs=6, - scoring=scoring, - return_train_score=True - ) - res.extend([{'metric': "_".join(key.split('_')[1:]), - 'subset': key.split('_')[0], - "preprocessing": scaler_label, - "score": f"{val.mean():.3f}+-{val.std():.3f}"} - for key, val in scores.items() if not key.endswith('_time')]) -scores = pd.DataFrame(res).set_index(["preprocessing", "metric", 
'subset'])['score'].unstack(-1) -scores = scores['test'].unstack(-1).sort_values("F1-macro", ascending=False) -print(scores) diff --git a/sklearn_extra/feature_weighting/_text.py b/sklearn_extra/feature_weighting/_text.py index 747a379c..d247c387 100644 --- a/sklearn_extra/feature_weighting/_text.py +++ b/sklearn_extra/feature_weighting/_text.py @@ -44,7 +44,10 @@ def _fit(self, X, y): f1 = class_freq_sort[-1, :] fk = (class_freq_sort * np.arange(n_class, 0, -1)[:, None]).sum(axis=0) - self.coef_ = 1 + self.alpha * (f1 / fk) + weight = f1 / fk + # scale weights to [0, 1] + weight = ((1 + n_class)*n_class*weight - 2) / ((1 + n_class)*n_class - 2) + self.coef_ = 1 + self.alpha * weight return self def fit(self, X, y): diff --git a/sklearn_extra/tests/test_common.py b/sklearn_extra/tests/test_common.py index 587b8249..7cd67610 100644 --- a/sklearn_extra/tests/test_common.py +++ b/sklearn_extra/tests/test_common.py @@ -4,8 +4,15 @@ from sklearn_extra.kernel_approximation import Fastfood from sklearn_extra.kernel_methods import EigenProClassifier, EigenProRegressor from sklearn_extra.cluster import KMedoids +from sklearn_extra.feature_weighting import TfigmTransformer -ALL_ESTIMATORS = [Fastfood, KMedoids, EigenProClassifier, EigenProRegressor] +ALL_ESTIMATORS = [ + Fastfood, + KMedoids, + EigenProClassifier, + EigenProRegressor, + TfigmTransformer, +] if hasattr(estimator_checks, "parametrize_with_checks"): # Common tests are only run on scikit-learn 0.22+ From bcff83e0c7160c31251bc51a3f0210c0058d3df5 Mon Sep 17 00:00:00 2001 From: Roman Yurchak Date: Thu, 26 Dec 2019 20:33:23 +0100 Subject: [PATCH 05/11] Improve docstrings --- examples/feature_weighting/plot_tfigm_text.py | 41 +++--- sklearn_extra/feature_weighting/_text.py | 119 ++++++++++++++++-- 2 files changed, 125 insertions(+), 35 deletions(-) diff --git a/examples/feature_weighting/plot_tfigm_text.py b/examples/feature_weighting/plot_tfigm_text.py index c014f6b3..ddfa237d 100644 --- a/examples/feature_weighting/plot_tfigm_text.py +++ b/examples/feature_weighting/plot_tfigm_text.py @@ -15,7 +15,7 @@ if "CI" in os.environ: # make this example run faster in CI - categories = ["sci.crypt", "comp.graphics", 'comp.sys.mac.hardware'] + categories = ["sci.crypt", "comp.graphics", "comp.sys.mac.hardware"] else: categories = None @@ -28,18 +28,12 @@ res = [] for scaler_label, scaler in [ - ("identity", FunctionTransformer(lambda x: x)), - ("TF-IDF(sublinear_tf=False)", TfidfTransformer()), - ("TF-IDF(sublinear_tf=True)", TfidfTransformer(sublinear_tf=True)), - ("TF-IGM(tf_scale=None)", TfigmTransformer(alpha=5)), - ( - "TF-IGM(tf_scale='sqrt')", - TfigmTransformer(alpha=5, tf_scale="sqrt"), - ), - ( - "TF-IGM(tf_scale='log1p')", - TfigmTransformer(alpha=5, tf_scale="log1p"), - ), + ("identity", FunctionTransformer(lambda x: x)), + ("TF-IDF(sublinear_tf=False)", TfidfTransformer()), + ("TF-IDF(sublinear_tf=True)", TfidfTransformer(sublinear_tf=True)), + ("TF-IGM(tf_scale=None)", TfigmTransformer()), + ("TF-IGM(tf_scale='sqrt')", TfigmTransformer(tf_scale="sqrt"),), + ("TF-IGM(tf_scale='log1p')", TfigmTransformer(tf_scale="log1p"),), ]: pipe = make_pipeline(scaler, Normalizer()) X_tr = pipe.fit_transform(X, y) @@ -50,20 +44,17 @@ ), "balanced_accuracy": "balanced_accuracy", } - scores = cross_validate( - est, - X_tr, - y, - scoring=scoring, - ) + scores = cross_validate(est, X_tr, y, scoring=scoring,) for key, val in scores.items(): if not key.endswith("_time"): - res.append({ - "metric": "_".join(key.split("_")[1:]), - "subset": 
key.split("_")[0], - "preprocessing": scaler_label, - "score": f"{val.mean():.3f}+-{val.std():.3f}", - }) + res.append( + { + "metric": "_".join(key.split("_")[1:]), + "subset": key.split("_")[0], + "preprocessing": scaler_label, + "score": f"{val.mean():.3f}+-{val.std():.3f}", + } + ) scores = ( pd.DataFrame(res) .set_index(["preprocessing", "metric", "subset"])["score"] diff --git a/sklearn_extra/feature_weighting/_text.py b/sklearn_extra/feature_weighting/_text.py index d247c387..fa7aa239 100644 --- a/sklearn_extra/feature_weighting/_text.py +++ b/sklearn_extra/feature_weighting/_text.py @@ -7,12 +7,55 @@ class TfigmTransformer(BaseEstimator, TransformerMixin): - """Apply TF-IGM feature weighting + """TF-IGM feature weighting. + + TF-IGM (Inverse Gravity Momentum) is a supervised + feature weighting scheme for classification tasks that measures + class distinguishing power. + + See User Guide for mode details. Parameters ---------- - alpha : float, default=7 - regularization parameter + alpha : float, default=0.15 + regularization parameter. Known good default values are 0.14 - 0.20. + tf_scale : {"sqrt", "log1p"}, default=None + if not None, scaling applied to term frequency. Possible scaling values are, + - "sqrt": square root scaling + - "log1p": ``log(1 + tf)`` scaling. This option corresponds to + ``sublinear_tf=True`` parameter in + :class:`~sklearn.feature_extraction.text.TfidfTransformer`. + + Attributes + ---------- + igm_ : array of shape (n_features,) + The Inverse Gravity Moment (IGM) weight. + coef_ : array of shape (n_features,) + Regularized IGM weight corresponding to ``alpha + igm_`` + + Examples + -------- + >>> from sklearn.feature_extraction.text import CountVectorizer + >>> from sklearn.pipeline import Pipeline + >>> from sklearn_extra.feature_weighting import TfigmTransformer + >>> corpus = ['this is the first document', + ... 'this document is the second document', + ... 'and this is the third one', + ... 'is this the first document'] + >>> y = ['1', '2', '1', '2'] + >>> pipe = Pipeline([('count', CountVectorizer()), + ... ('tfigm', TfigmTransformer())]).fit(corpus, y) + >>> pipe['count'].transform(corpus).toarray() + array([[0, 1, 1, 1, 0, 0, 1, 0, 1], + [0, 2, 0, 1, 0, 1, 1, 0, 1], + [1, 0, 0, 1, 1, 0, 1, 1, 1], + [0, 1, 1, 1, 0, 0, 1, 0, 1]]) + >>> pipe['tfigm'].igm_ + array([1. , 0.25, 0. , 0. , 1. , 1. , 0. , 1. , 0. ]) + >>> pipe['tfigm'].coef_ + array([1.15, 0.4 , 0.15, 0.15, 1.15, 1.15, 0.15, 1.15, 0.15]) + >>> pipe.transform(corpus).shape + (4, 9) References ---------- @@ -20,12 +63,20 @@ class TfigmTransformer(BaseEstimator, TransformerMixin): in text classification." Expert Systems with Applications 66 (2016): 245-260. 
""" - - def __init__(self, alpha=7.0, tf_scale=None): + def __init__(self, alpha=0.15, tf_scale=None): self.alpha = alpha self.tf_scale = tf_scale def _fit(self, X, y): + """Learn the igm vector (global term weights) + + Parameters + ---------- + X : {array-like, sparse matrix} of (n_samples, n_features) + a matrix of term/token counts + y : array-like of shape (n_samples,) + target classes + """ self._le = LabelEncoder().fit(y) n_class = len(self._le.classes_) class_freq = np.zeros((n_class, X.shape[1])) @@ -44,23 +95,47 @@ def _fit(self, X, y): f1 = class_freq_sort[-1, :] fk = (class_freq_sort * np.arange(n_class, 0, -1)[:, None]).sum(axis=0) - weight = f1 / fk + # avoid division by zero + igm = np.divide(f1, fk, out=np.zeros_like(f1), where=(fk != 0)) # scale weights to [0, 1] - weight = ((1 + n_class)*n_class*weight - 2) / ((1 + n_class)*n_class - 2) - self.coef_ = 1 + self.alpha * weight + self.igm_ = ((1 + n_class) * n_class * igm - 2) / ( + (1 + n_class) * n_class - 2 + ) + self.coef_ = self.alpha + self.igm_ return self def fit(self, X, y): + """Learn the igm vector (global term weights) + + Parameters + ---------- + X : {array-like, sparse matrix} of (n_samples, n_features) + a matrix of term/token counts + y : array-like of shape (n_samples,) + target classes + """ X, y = check_X_y(X, y, accept_sparse=["csr", "csc"]) self._fit(X, y) return self def _transform(self, X): + """Transform a count matrix to a TF-IGM representation + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + a matrix of term/token counts + + Returns + ------- + vectors : {ndarray, sparse matrix} of shape (n_samples, n_features) + transformed matrix + """ if self.tf_scale is None: pass - elif self.tf_scale == 'sqrt': + elif self.tf_scale == "sqrt": X = np.sqrt(X) - elif self.tf_scale == 'log1p': + elif self.tf_scale == "log1p": X = np.log1p(X) else: raise ValueError @@ -72,11 +147,35 @@ def _transform(self, X): return X_tr def transform(self, X): + """Transform a count matrix to a TF-IGM representation + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + a matrix of term/token counts + + Returns + ------- + vectors : {ndarray, sparse matrix} of shape (n_samples, n_features) + transformed matrix + """ X = check_array(X, accept_sparse=["csr", "csc"]) X_tr = self._transform(X) return X_tr def fit_transform(self, X, y): + """Transform a count matrix to a TF-IGM representation + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + a matrix of term/token counts + + Returns + ------- + vectors : {ndarray, sparse matrix} of shape (n_samples, n_features) + transformed matrix + """ X, y = check_X_y(X, y, accept_sparse=["csr", "csc"]) self._fit(X, y) X_tr = self._transform(X) From 0c08feaa97803e34cc64abf0ede64c8545865664 Mon Sep 17 00:00:00 2001 From: Roman Yurchak Date: Thu, 26 Dec 2019 23:33:27 +0100 Subject: [PATCH 06/11] TST Additional tests --- sklearn_extra/feature_weighting/_text.py | 12 ++- .../feature_weighting/tests/test_text.py | 76 +++++++++++++++++-- 2 files changed, 79 insertions(+), 9 deletions(-) diff --git a/sklearn_extra/feature_weighting/_text.py b/sklearn_extra/feature_weighting/_text.py index fa7aa239..3a0cae5f 100644 --- a/sklearn_extra/feature_weighting/_text.py +++ b/sklearn_extra/feature_weighting/_text.py @@ -63,6 +63,7 @@ class distinguishing power. in text classification." Expert Systems with Applications 66 (2016): 245-260. 
""" + def __init__(self, alpha=0.15, tf_scale=None): self.alpha = alpha self.tf_scale = tf_scale @@ -97,10 +98,13 @@ def _fit(self, X, y): fk = (class_freq_sort * np.arange(n_class, 0, -1)[:, None]).sum(axis=0) # avoid division by zero igm = np.divide(f1, fk, out=np.zeros_like(f1), where=(fk != 0)) - # scale weights to [0, 1] - self.igm_ = ((1 + n_class) * n_class * igm - 2) / ( - (1 + n_class) * n_class - 2 - ) + if n_class > 1: + # scale weights to [0, 1] + self.igm_ = ((1 + n_class) * n_class * igm - 2) / ( + (1 + n_class) * n_class - 2 + ) + else: + self.igm_ = igm self.coef_ = self.alpha + self.igm_ return self diff --git a/sklearn_extra/feature_weighting/tests/test_text.py b/sklearn_extra/feature_weighting/tests/test_text.py index c088e058..fe60c0d0 100644 --- a/sklearn_extra/feature_weighting/tests/test_text.py +++ b/sklearn_extra/feature_weighting/tests/test_text.py @@ -1,15 +1,81 @@ import numpy as np -from numpy.testing import assert_allclose +from numpy.testing import assert_allclose, assert_array_less import scipy.sparse as sp + +import pytest + from sklearn_extra.feature_weighting import TfigmTransformer +from sklearn.datasets import make_classification -def test_tfigm_transform(): +@pytest.mark.parametrize("array_format", ["dense", "csr", "csc", "coo"]) +def test_tfigm_transform(array_format): X = np.array([[0, 1, 1], [1, 0, 1], [0, 0, 1], [1, 1, 1]]) - X = sp.csr_matrix(X) + if array_format != "dense": + X = sp.csr_matrix(X).asformat(array_format) y = np.array(["a", "b", "a", "c"]) - est = TfigmTransformer() + alpha = 0.2 + est = TfigmTransformer(alpha=alpha) X_tr = est.fit_transform(X, y) + + assert_allclose(est.igm_, [0.20, 0.40, 0.0]) + assert_allclose(est.igm_ + alpha, est.coef_) + assert X_tr.shape == X.shape - assert_allclose(est.coef_, [3.333333, 4.5, 2.166667], rtol=1e-5) + assert sp.issparse(X_tr) is (array_format != "dense") + + if array_format == "dense": + assert_allclose(X * est.coef_[None, :], X_tr) + else: + assert_allclose(X.A * est.coef_[None, :], X_tr.A) + + +def test_tfigm_synthetic(): + X, y = make_classification( + n_samples=100, + n_features=10, + n_informative=5, + n_redundant=0, + random_state=0, + n_classes=5, + shuffle=False, + ) + X = (X > 0).astype(np.float) + + est = TfigmTransformer() + est.fit(X, y) + # informative features have higher IGM weights than noisy ones. + # (athough here we lose a lot of information due to thresholding of X). 
+    assert est.igm_[:5].mean() / est.igm_[5:].mean() > 3
+
+
+@pytest.mark.parametrize("n_class", [2, 5])
+def test_tfigm_random_distribution(n_class):
+    rng = np.random.RandomState(0)
+    n_samples, n_features = 500, 4
+    X = rng.randint(2, size=(n_samples, n_features))
+    y = rng.randint(n_class, size=(n_samples,))
+
+    est = TfigmTransformer()
+    X_tr = est.fit_transform(X, y)
+
+    # all weights are strictly positive
+    assert_array_less(0, est.igm_)
+    # and close to zero, since none of the features are discriminative
+    assert_array_less(est.igm_, 0.05)
+
+
+def test_tfigm_valid_target():
+    X = np.array([[0, 1, 1], [1, 0, 1], [0, 0, 1], [1, 1, 1]])
+    y = None
+
+    est = TfigmTransformer()
+    with pytest.raises(ValueError, match="y cannot be None"):
+        est.fit(X, y)
+
+    # check asymptotic behaviour for 1 class
+    y = [1, 1, 1, 1]
+    est = TfigmTransformer()
+    est.fit(X, y)
+    assert_allclose(est.igm_[0], np.ones(3))

From 7a8ce1f7f2980aca0d0b132bec50932276bf2768 Mon Sep 17 00:00:00 2001
From: Roman Yurchak
Date: Fri, 27 Dec 2019 00:11:07 +0100
Subject: [PATCH 07/11] Style improvements

---
 examples/feature_weighting/plot_tfigm_text.py      | 10 ++++++----
 sklearn_extra/feature_weighting/__init__.py        |  2 ++
 sklearn_extra/feature_weighting/_text.py           |  4 ++++
 sklearn_extra/feature_weighting/tests/test_text.py |  2 ++
 4 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/examples/feature_weighting/plot_tfigm_text.py b/examples/feature_weighting/plot_tfigm_text.py
index ddfa237d..4f1f4b2f 100644
--- a/examples/feature_weighting/plot_tfigm_text.py
+++ b/examples/feature_weighting/plot_tfigm_text.py
@@ -1,4 +1,6 @@
-import os
+# License: BSD 3 clause
+#
+# Authors: Roman Yurchak
 
 import pandas as pd
 import numpy as np
@@ -22,13 +24,13 @@
 docs, y = fetch_20newsgroups(return_X_y=True, categories=categories)
 
 
-vect = CountVectorizer(min_df=5, stop_words="english", ngram_range=(1, 2))
+vect = CountVectorizer(min_df=5, stop_words="english", ngram_range=(1, 1))
 X = vect.fit_transform(docs)
 
 res = []
 
 for scaler_label, scaler in [
-    ("identity", FunctionTransformer(lambda x: x)),
+    ("TF", FunctionTransformer(lambda x: x)),
     ("TF-IDF(sublinear_tf=False)", TfidfTransformer()),
     ("TF-IDF(sublinear_tf=True)", TfidfTransformer(sublinear_tf=True)),
     ("TF-IGM(tf_scale=None)", TfigmTransformer()),
@@ -52,7 +54,7 @@
                     "metric": "_".join(key.split("_")[1:]),
                     "subset": key.split("_")[0],
                     "preprocessing": scaler_label,
-                    "score": f"{val.mean():.3f}+-{val.std():.3f}",
+                    "score": f"{val.mean():.3f}±{val.std():.3f}",
                 }
             )
diff --git a/sklearn_extra/feature_weighting/__init__.py b/sklearn_extra/feature_weighting/__init__.py
index ff7809f1..a87491c5 100644
--- a/sklearn_extra/feature_weighting/__init__.py
+++ b/sklearn_extra/feature_weighting/__init__.py
@@ -1,3 +1,5 @@
+# License: BSD 3 clause
+
 from ._text import TfigmTransformer
 
 __all__ = ["TfigmTransformer"]
diff --git a/sklearn_extra/feature_weighting/_text.py b/sklearn_extra/feature_weighting/_text.py
index 3a0cae5f..cabe8f6e 100644
--- a/sklearn_extra/feature_weighting/_text.py
+++ b/sklearn_extra/feature_weighting/_text.py
@@ -1,3 +1,7 @@
+# License: BSD 3 clause
+#
+# Authors: Roman Yurchak
+
 import numpy as np
 import scipy.sparse as sp
 
diff --git a/sklearn_extra/feature_weighting/tests/test_text.py b/sklearn_extra/feature_weighting/tests/test_text.py
index fe60c0d0..fa43517e 100644
--- a/sklearn_extra/feature_weighting/tests/test_text.py
+++ b/sklearn_extra/feature_weighting/tests/test_text.py
@@ -1,3 +1,5 @@
+# License: BSD 3 clause
+
 import numpy as np
 from numpy.testing import assert_allclose, assert_array_less
 import scipy.sparse as sp

From 690a0850d47e97c3f505049acbb189c3a3c18f23 Mon Sep 17 00:00:00 2001
From: Roman Yurchak
Date: Fri, 27 Dec 2019 00:15:25 +0100
Subject: [PATCH 08/11] flake8

---
 examples/feature_weighting/plot_tfigm_text.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/feature_weighting/plot_tfigm_text.py b/examples/feature_weighting/plot_tfigm_text.py
index 4f1f4b2f..111da066 100644
--- a/examples/feature_weighting/plot_tfigm_text.py
+++ b/examples/feature_weighting/plot_tfigm_text.py
@@ -3,7 +3,6 @@
 # Authors: Roman Yurchak
 
 import pandas as pd
-import numpy as np
 
 from sklearn.svm import LinearSVC
 from sklearn.preprocessing import Normalizer, FunctionTransformer

From 4f988169b2d2a163f3dd58190d33f7169be7d422 Mon Sep 17 00:00:00 2001
From: Roman Yurchak
Date: Fri, 27 Dec 2019 00:45:19 +0100
Subject: [PATCH 09/11] Better parameter validation

---
 examples/feature_weighting/plot_tfigm_text.py |  2 +-
 sklearn_extra/feature_weighting/_text.py      | 31 +++++++++++++------
 .../feature_weighting/tests/test_text.py      | 16 +++++++++-
 3 files changed, 38 insertions(+), 11 deletions(-)

diff --git a/examples/feature_weighting/plot_tfigm_text.py b/examples/feature_weighting/plot_tfigm_text.py
index 111da066..e32034fa 100644
--- a/examples/feature_weighting/plot_tfigm_text.py
+++ b/examples/feature_weighting/plot_tfigm_text.py
@@ -53,7 +53,7 @@
                     "metric": "_".join(key.split("_")[1:]),
                     "subset": key.split("_")[0],
                     "preprocessing": scaler_label,
-                    "score": f"{val.mean():.3f}±{val.std():.3f}",
+                    "score": "{:.3f}±{:.3f}".format(val.mean(), val.std()),
                 }
             )
diff --git a/sklearn_extra/feature_weighting/_text.py b/sklearn_extra/feature_weighting/_text.py
index cabe8f6e..029d38e4 100644
--- a/sklearn_extra/feature_weighting/_text.py
+++ b/sklearn_extra/feature_weighting/_text.py
@@ -82,6 +82,21 @@ def _fit(self, X, y):
         y : array-like of shape (n_samples,)
             target classes
         """
+        tf_scale_map = {None: None, "sqrt": np.sqrt, "log1p": np.log1p}
+
+        if self.tf_scale not in tf_scale_map:
+            raise ValueError(
+                "tf_scale={} should be one of {}.".format(
+                    self.tf_scale, list(tf_scale_map)
+                )
+            )
+        self._tf_scale_func = tf_scale_map[self.tf_scale]
+
+        if not isinstance(self.alpha, (int, float)) or self.alpha < 0:
+            raise ValueError(
+                "alpha={} must be a positive number.".format(self.alpha)
+            )
+
         self._le = LabelEncoder().fit(y)
         n_class = len(self._le.classes_)
         class_freq = np.zeros((n_class, X.shape[1]))
@@ -103,12 +118,16 @@ def _fit(self, X, y):
         # avoid division by zero
         igm = np.divide(f1, fk, out=np.zeros_like(f1), where=(fk != 0))
         if n_class > 1:
-            # scale weights to [0, 1]
+            # although the Chen et al. paper states that it is not mandatory,
+            # we always re-scale weights to [0, 1], otherwise with 2 classes
+            # we would get a minimal IGM value of 1/3.
             self.igm_ = ((1 + n_class) * n_class * igm - 2) / (
                 (1 + n_class) * n_class - 2
             )
         else:
             self.igm_ = igm
+        # In the Chen et al. paper the regularization parameter is defined
+        # as 1/alpha.
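+        # Equivalently, alpha + igm is proportional to 1 + igm / alpha, so
+        # the default alpha=0.15 corresponds roughly to lambda = 1/0.15 ~ 7
+        # in the paper's w = 1 + lambda * igm formulation (the two only
+        # differ by a constant factor).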
         self.coef_ = self.alpha + self.igm_
         return self
 
@@ -139,14 +158,8 @@ def _transform(self, X):
         vectors : {ndarray, sparse matrix} of shape (n_samples, n_features)
             transformed matrix
         """
-        if self.tf_scale is None:
-            pass
-        elif self.tf_scale == "sqrt":
-            X = np.sqrt(X)
-        elif self.tf_scale == "log1p":
-            X = np.log1p(X)
-        else:
-            raise ValueError
+        if self._tf_scale_func is not None:
+            X = self._tf_scale_func(X)
 
         if sp.issparse(X):
             X_tr = X @ sp.diags(self.coef_)
         else:
diff --git a/sklearn_extra/feature_weighting/tests/test_text.py b/sklearn_extra/feature_weighting/tests/test_text.py
index fa43517e..412a67c1 100644
--- a/sklearn_extra/feature_weighting/tests/test_text.py
+++ b/sklearn_extra/feature_weighting/tests/test_text.py
@@ -80,4 +80,18 @@ def test_tfigm_valid_target():
     y = [1, 1, 1, 1]
     est = TfigmTransformer()
     est.fit(X, y)
-    assert_allclose(est.igm_[0], np.ones(3))
+    assert_allclose(est.igm_, np.ones(3))
+
+
+def test_tfigm_invalid_parameters():
+    X = np.array([[0, 1, 1], [1, 0, 1], [0, 0, 1], [1, 1, 1]])
+    y = [1, 1, 2, 2]
+
+    est = TfigmTransformer(alpha=-1)
+    with pytest.raises(ValueError, match="alpha=-1 must be a positive number"):
+        est.fit(X, y)
+
+    est = TfigmTransformer(tf_scale="unknown")
+    msg = r"tf_scale=unknown should be one of \[None, 'sqrt'"
+    with pytest.raises(ValueError, match=msg):
+        est.fit(X, y)

From ddbf828b8b55f882ad74fa5c2ec15d2b5301a622 Mon Sep 17 00:00:00 2001
From: Roman Yurchak
Date: Fri, 27 Dec 2019 01:01:37 +0100
Subject: [PATCH 10/11] FIX for legacy scipy

---
 sklearn_extra/feature_weighting/_text.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sklearn_extra/feature_weighting/_text.py b/sklearn_extra/feature_weighting/_text.py
index 029d38e4..1223577a 100644
--- a/sklearn_extra/feature_weighting/_text.py
+++ b/sklearn_extra/feature_weighting/_text.py
@@ -102,8 +102,8 @@ def _fit(self, X, y):
         class_freq = np.zeros((n_class, X.shape[1]))
 
         X_nz = X != 0
-        if sp.issparse(X_nz):
-            X_nz = X_nz.asformat("csr", copy=False)
+        if sp.issparse(X_nz) and X_nz.getformat() != 'csr':
+            X_nz = X_nz.asformat("csr")
 
         for idx, class_label in enumerate(self._le.classes_):

From f4159066508bb7e916324e691dbdd2e735e12ef9 Mon Sep 17 00:00:00 2001
From: Roman Yurchak
Date: Fri, 27 Dec 2019 10:50:56 +0100
Subject: [PATCH 11/11] DOC Add to reference API

---
 doc/api.rst                                   | 11 +++++++++++
 examples/feature_weighting/plot_tfigm_text.py |  1 +
 sklearn_extra/feature_weighting/_text.py      |  2 +-
 3 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/doc/api.rst b/doc/api.rst
index 86b8d333..c2b8215f 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -4,6 +4,17 @@ scikit-learn-extra API
 
 .. currentmodule:: sklearn_extra
 
+
+Feature weighting
+=================
+
+.. autosummary::
+   :toctree: generated/
+   :template: class.rst
+
+   feature_weighting.TfigmTransformer
+
+
 Kernel approximation
 ====================
 
diff --git a/examples/feature_weighting/plot_tfigm_text.py b/examples/feature_weighting/plot_tfigm_text.py
index e32034fa..5d7c9c44 100644
--- a/examples/feature_weighting/plot_tfigm_text.py
+++ b/examples/feature_weighting/plot_tfigm_text.py
@@ -1,6 +1,7 @@
 # License: BSD 3 clause
 #
 # Authors: Roman Yurchak
+import os
 
 import pandas as pd
 
diff --git a/sklearn_extra/feature_weighting/_text.py b/sklearn_extra/feature_weighting/_text.py
index 1223577a..58a019e4 100644
--- a/sklearn_extra/feature_weighting/_text.py
+++ b/sklearn_extra/feature_weighting/_text.py
@@ -102,7 +102,7 @@ def _fit(self, X, y):
         class_freq = np.zeros((n_class, X.shape[1]))
 
         X_nz = X != 0
-        if sp.issparse(X_nz) and X_nz.getformat() != 'csr':
+        if sp.issparse(X_nz) and X_nz.getformat() != "csr":
            X_nz = X_nz.asformat("csr")
 
        for idx, class_label in enumerate(self._le.classes_):
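A minimal end-to-end sketch of the transformer added by this series, as it
stands after PATCH 11 (the corpus, labels, and printed values below are
illustrative assumptions, not taken from the patches):

    from sklearn.feature_extraction.text import CountVectorizer

    from sklearn_extra.feature_weighting import TfigmTransformer

    corpus = [
        "the apple banana",   # class "fruit"
        "the banana orange",  # class "fruit"
        "the car engine",     # class "cars"
        "the engine wheel",   # class "cars"
    ]
    y = ["fruit", "fruit", "cars", "cars"]

    vect = CountVectorizer()
    counts = vect.fit_transform(corpus)

    tfigm = TfigmTransformer(alpha=0.15)
    X_tr = tfigm.fit_transform(counts, y)

    # terms occurring in a single class ("apple", "engine", ...) get
    # igm_ = 1.0, while "the", present in every document of both classes,
    # gets igm_ = 0.0; the applied weights are coef_ = alpha + igm_
    print(sorted(vect.vocabulary_))  # ['apple', 'banana', 'car', ...]
    print(tfigm.igm_)                # 1.0 everywhere except 0.0 for "the"
    print(X_tr.shape)                # (4, 7)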