@@ -697,16 +697,13 @@ def _get_args(function, varargs=False):
697
697
return args
698
698
699
699
700
- def _get_func_name (func , class_name = None ):
700
+ def _get_func_name (func ):
701
701
"""Get function full name
702
702
703
703
Parameters
704
704
----------
705
705
func : callable
706
706
The function object.
707
- class_name : string, optional (default: None)
708
- If ``func`` is a class method and the class name is known specify
709
- class_name for the error message.
710
707
711
708
Returns
712
709
-------
@@ -717,16 +714,16 @@ def _get_func_name(func, class_name=None):
717
714
module = inspect .getmodule (func )
718
715
if module :
719
716
parts .append (module .__name__ )
720
- if class_name is not None :
721
- parts . append ( class_name )
722
- elif hasattr ( func , 'im_class' ) :
723
- parts .append (func . im_class . __name__ )
717
+
718
+ qualname = func . __qualname__
719
+ if qualname != func . __name__ :
720
+ parts .append (qualname [: qualname . find ( '.' )] )
724
721
725
722
parts .append (func .__name__ )
726
723
return '.' .join (parts )
727
724
728
725
729
- def check_docstring_parameters (func , doc = None , ignore = None , class_name = None ):
726
+ def check_docstring_parameters (func , doc = None , ignore = None ):
730
727
"""Helper to check docstring
731
728
732
729
Parameters
@@ -737,9 +734,6 @@ def check_docstring_parameters(func, doc=None, ignore=None, class_name=None):
737
734
Docstring if it is passed manually to the test.
738
735
ignore : None | list
739
736
Parameters to ignore.
740
- class_name : string, optional (default: None)
741
- If ``func`` is a class method and the class name is known specify
742
- class_name for the error message.
743
737
744
738
Returns
745
739
-------
@@ -750,7 +744,7 @@ def check_docstring_parameters(func, doc=None, ignore=None, class_name=None):
750
744
incorrect = []
751
745
ignore = [] if ignore is None else ignore
752
746
753
- func_name = _get_func_name (func , class_name = class_name )
747
+ func_name = _get_func_name (func )
754
748
if (not func_name .startswith ('sklearn.' ) or
755
749
func_name .startswith ('sklearn.externals' )):
756
750
return incorrect
@@ -763,11 +757,13 @@ def check_docstring_parameters(func, doc=None, ignore=None, class_name=None):
763
757
# Dont check estimator_checks module
764
758
if func_name .split ('.' )[2 ] == 'estimator_checks' :
765
759
return incorrect
766
- args = list (filter (lambda x : x not in ignore , _get_args (func )))
760
+ # Get the arguments from the function signature
761
+ param_signature = list (filter (lambda x : x not in ignore , _get_args (func )))
767
762
# drop self
768
- if len (args ) > 0 and args [0 ] == 'self' :
769
- args .remove ('self' )
763
+ if len (param_signature ) > 0 and param_signature [0 ] == 'self' :
764
+
10000
param_signature .remove ('self' )
770
765
766
+ # Analyze function's docstring
771
767
if doc is None :
772
768
with warnings .catch_warnings (record = True ) as w :
773
769
try :
@@ -778,8 +774,9 @@ def check_docstring_parameters(func, doc=None, ignore=None, class_name=None):
778
774
if len (w ):
779
775
raise RuntimeError ('Error for %s:\n %s' % (func_name , w [0 ]))
780
776
781
- param_names = []
777
+ param_docs = []
782
778
for name , type_definition , param_doc in doc ['Parameters' ]:
779
+ # Type hints are empty only if parameter name ended with :
783
780
if not type_definition .strip ():
784
781
if ':' in name and name [:name .index (':' )][- 1 :].strip ():
785
782
incorrect += [func_name +
@@ -790,18 +787,65 @@ def check_docstring_parameters(func, doc=None, ignore=None, class_name=None):
790
787
' Parameter %r has an empty type spec. '
791
788
'Remove the colon' % (name .lstrip ())]
792
789
790
+ # Create a list of parameters to compare with the parameters gotten
791
+ # from the func signature
793
792
if '*' not in name :
794
- param_names .append (name .split (':' )[0 ].strip ('` ' ))
793
+ param_docs .append (name .split (':' )[0 ].strip ('` ' ))
795
794
796
- param_names = list (filter (lambda x : x not in ignore , param_names ))
795
+ # If one of the docstring's parameters had an error then return that
796
+ # incorrect message
797
+ if len (incorrect ) > 0 :
798
+ return incorrect
799
+
800
+ # Remove the parameters that should be ignored from list
801
+ param_docs = list (filter (lambda x : x not in ignore , param_docs ))
802
+
803
+ # The following is derived from pytest, Copyright (c) 2004-2017 Holger
804
+ # Krekel and others, Licensed under MIT License. See
805
+ # https://github.com/pytest-dev/pytest
806
+
807
+ message = []
808
+ for i in range (min (len (param_docs ), len (param_signature ))):
809
+ if param_signature [i ] != param_docs [i ]:
810
+ message += ["There's a parameter name mismatch in function"
811
+ " docstring w.r.t. function signature, at index %s"
812
+ " diff: %r != %r" %
813
+ (i , param_signature [i ], param_docs [i ])]
814
+ break
815
+ if len (param_signature ) > len (param_docs ):
816
+ message += ["Parameters in function docstring have less items w.r.t."
817
+ " function signature, first missing item: %s" %
818
+ param_signature [len (param_docs )]]
819
+
820
+ elif len (param_signature ) < len (param_docs ):
821
+ message += ["Parameters in function docstring have more items w.r.t."
822
+ " function signature, first extra item: %s" %
823
+ param_docs [len (param_signature )]]
824
+
825
+ # If there wasn't any difference in the parameters themselves between
826
+ # docstring and signature including having the same length then return
827
+ # empty list
828
+ if len (message ) == 0 :
829
+ return []
830
+
831
+ import difflib
832
+ import pprint
833
+
834
+ param_docs_formatted = pprint .pformat (param_docs ).splitlines ()
835
+ param_signature_formatted = pprint .pformat (param_signature ).splitlines ()
836
+
837
+ message += ["Full diff:" ]
838
+
839
+ message .extend (
840
+ line .strip () for line in difflib .ndiff (param_signature_formatted ,
841
+ param_docs_formatted )
842
+ )
843
+
844
+ incorrect .extend (message )
845
+
846
+ # Prepend function name
847
+ incorrect = ['In function: ' + func_name ] + incorrect
797
848
798
- if len (param_names ) != len (args ):
799
- bad = str (sorted (list (set (param_names ) ^ set (args ))))
800
- incorrect += [func_name + ' arg mismatch: ' + bad ]
801
- else :
802
- for n1 , n2 in zip (param_names , args ):
803
- if n1 != n2 :
804
- incorrect += [func_name + ' ' + n1 + ' != ' + n2 ]
805
849
return incorrect
806
850
807
851
0 commit comments