diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py index ba7105e1ad6039..8a0575b136ef03 100644 --- a/Lib/unittest/loader.py +++ b/Lib/unittest/loader.py @@ -285,19 +285,29 @@ def discover(self, start_dir, pattern='test*.py', top_level_dir=None): sys.path.insert(0, top_level_dir) self._top_level_dir = top_level_dir - is_not_importable = False is_namespace = False tests = [] if os.path.isdir(os.path.abspath(start_dir)): - start_dir = os.path.abspath(start_dir) - if start_dir != top_level_dir: - is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py')) + if os.path.abspath(start_dir) != top_level_dir: + try: + # Convert to dot import + dot_import = '.'.join( + os.path.relpath(start_dir, top_level_dir).split(os.sep) + ) + __import__(dot_import) + except ImportError: + raise ImportError('Start directory is not importable: %r' % start_dir) + + if not os.path.isfile(os.path.join(start_dir, '__init__.py')): + is_namespace = True + + tests = list(self._find_tests(start_dir, pattern, is_namespace)) else: # support for discovery from dotted module names try: __import__(start_dir) except ImportError: - is_not_importable = True + raise ImportError('Start directory is not importable: %r' % start_dir) else: the_module = sys.modules[start_dir] top_part = start_dir.split('.')[0] @@ -341,12 +351,9 @@ def discover(self, start_dir, pattern='test*.py', top_level_dir=None): sys.path.remove(top_level_dir) else: sys.path.remove(top_level_dir) + if not is_namespace: + tests = list(self._find_tests(start_dir, pattern)) - if is_not_importable: - raise ImportError('Start directory is not importable: %r' % start_dir) - - if not is_namespace: - tests = list(self._find_tests(start_dir, pattern)) return self.suiteClass(tests) def _get_directory_containing_module(self, module_name):