@@ -1300,24 +1300,30 @@ def fit(self, X, y=None):
13001300 # `_fit` will only raise an error when `self.handle_unknown="error"`
13011301 self ._fit (X , handle_unknown = self .handle_unknown , force_all_finite = "allow-nan" )
13021302
1303- if self .handle_unknown == "use_encoded_value" :
1304- for feature_cats in self .categories_ :
1305- if 0 <= self .unknown_value < len (feature_cats ):
1306- raise ValueError (
1307- "The used value for unknown_value "
1308- f"{ self .unknown_value } is one of the "
1309- "values already used for encoding the "
1310- "seen categories."
1311- )
1303+ cardinalities = [len (categories ) for categories in self .categories_ ]
13121304
13131305 # stores the missing indices per category
13141306 self ._missing_indices = {}
13151307 for cat_idx , categories_for_idx in enumerate (self .categories_ ):
13161308 for i , cat in enumerate (categories_for_idx ):
13171309 if is_scalar_nan (cat ):
13181310 self ._missing_indices [cat_idx ] = i
1311+
1312+ # missing values are not considered part of the cardinality
1313+ # when considering unknown categories or encoded_missing_value
1314+ cardinalities [cat_idx ] -= 1
13191315 continue
13201316
1317+ if self .handle_unknown == "use_encoded_value" :
1318+ for cardinality in cardinalities :
1319+ if 0 <= self .unknown_value < cardinality :
1320+ raise ValueError (
1321+ "The used value for unknown_value "
1322+ f"{ self .unknown_value } is one of the "
1323+ "values already used for encoding the "
1324+ "seen categories."
1325+ )
1326+
13211327 if self ._missing_indices :
13221328 if np .dtype (self .dtype ).kind != "f" and is_scalar_nan (
13231329 self .encoded_missing_value
@@ -1336,9 +1342,9 @@ def fit(self, X, y=None):
13361342 # known category
13371343 invalid_features = [
13381344 cat_idx
1339- for cat_idx , categories_for_idx in enumerate (self . categories_ )
1345+ for cat_idx , cardinality in enumerate (cardinalities )
13401346 if cat_idx in self ._missing_indices
1341- and 0 <= self .encoded_missing_value < len ( categories_for_idx )
1347+ and 0 <= self .encoded_missing_value < cardinality
13421348 ]
13431349
13441350 if invalid_features :
0 commit comments