@@ -1300,24 +1300,30 @@ def fit(self, X, y=None):
1300
1300
# `_fit` will only raise an error when `self.handle_unknown="error"`
1301
1301
self ._fit (X , handle_unknown = self .handle_unknown , force_all_finite = "allow-nan" )
1302
1302
1303
- if self .handle_unknown == "use_encoded_value" :
1304
- for feature_cats in self .categories_ :
1305
- if 0 <= self .unknown_value < len (feature_cats ):
1306
- raise ValueError (
1307
- "The used value for unknown_value "
1308
- f"{ self .unknown_value } is one of the "
1309
- "values already used for encoding the "
1310
- "seen categories."
1311
- )
1303
+ cardinalities = [len (categories ) for categories in self .categories_ ]
1312
1304
1313
1305
# stores the missing indices per category
1314
1306
self ._missing_indices = {}
1315
1307
for cat_idx , categories_for_idx in enumerate (self .categories_ ):
1316
1308
for i , cat in enumerate (categories_for_idx ):
1317
1309
if is_scalar_nan (cat ):
1318
1310
self ._missing_indices [cat_idx ] = i
1311
+
1312
+ # missing values are not considered part of the cardinality
1313
+ # when considering unknown categories or encoded_missing_value
1314
+ cardinalities [cat_idx ] -= 1
1319
1315
continue
1320
1316
1317
+ if self .handle_unknown == "use_encoded_value" :
1318
+ for cardinality in cardinalities :
1319
+ if 0 <= self .unknown_value < cardinality :
1320
+ raise ValueError (
1321
+ "The used value for unknown_value "
1322
+ f"{ self .unknown_value } is one of the "
1323
+ "values already used for encoding the "
1324
+ "seen categories."
1325
+ )
1326
+
1321
1327
if self ._missing_indices :
1322
1328
if np .dtype (self .dtype ).kind != "f" and is_scalar_nan (
1323
1329
self .encoded_missing_value
@@ -1336,9 +1342,9 @@ def fit(self, X, y=None):
1336
1342
# known category
1337
1343
invalid_features = [
1338
1344
cat_idx
1339
- for cat_idx , categories_for_idx in enumerate (self . categories_ )
1345
+ for cat_idx , cardinality in enumerate (cardinalities )
1340
1346
if cat_idx in self ._missing_indices
1341
- and 0 <= self .encoded_missing_value < len ( categories_for_idx )
1347
+ and 0 <= self .encoded_missing_value < cardinality
1342
1348
]
1343
1349
1344
1350
if invalid_features :
0 commit comments