8000 Add basic support for user-defined mypy plugins (#3517) · python/mypy@fd0a416 · GitHub
[go: up one dir, main page]

Skip to content

Commit fd0a416

Browse files
authored
Add basic support for user-defined mypy plugins (#3517)
Configure them through "plugins=path/plugin.py, ..." in the ini file. The paths are relative to the configuration file. This is an almost minimal implementation and some features are missing: * Plugins installed through pip aren't properly supported. * Plugins within packages aren't properly supported. * Incremental mode doesn't invalidate cache files when plugins change. Also change path normalization in test cases in Windows. Previously we sometimes normalized to Windows paths and sometimes to Linux paths. Now switching to always use Linux paths.
1 parent 7d630b7 commit fd0a416

16 files changed

+253
-18
lines changed

mypy/build.py

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from mypy.stats import dump_type_stats
4343
from mypy.types import Type
4444
from mypy.version import __version__
45-
from mypy.plugin import DefaultPlugin
45+
from mypy.plugin import Plugin, DefaultPlugin, ChainedPlugin
4646

4747

4848
# We need to know the location of this file to load data, but
@@ -183,7 +183,9 @@ def build(sources: List[BuildSource],
183183
reports=reports,
184184
options=options,
185185
version_id=__version__,
186-
)
186+
plugin=DefaultPlugin(options.python_version))
187+
188+
manager.plugin = load_custom_plugins(manager.plugin, options, manager.errors)
187189

188190
try:
189191
graph = dispatch(sources, manager)
@@ -334,6 +336,67 @@ def import_priority(imp: ImportBase, toplevel_priority: int) -> int:
334336
return toplevel_priority
335337

336338

339+
def load_custom_plugins(default_plugin: Plugin, options: Options, errors: Errors) -> Plugin:
340+
"""Load custom plugins if any are configured.
341+
342+
Return a plugin that chains all custom plugins (if any) and falls
343+
back to default_plugin.
344+
"""
345+
346+
def plugin_error(message: str) -> None:
347+
errors.report(0, 0, message)
348+
errors.raise_error()
349+
350+
custom_plugins = []
351+
for plugin_path in options.plugins:
352+
if options.config_file:
353+
# Plugin paths are relative to the config file location.
354+
plugin_path = os.path.join(os.path.dirname(options.config_file), plugin_path)
355+
errors.set_file(plugin_path, None)
356+
357+
if not os.path.isfile(plugin_path):
358+
plugin_error("Can't find plugin")
359+
plugin_dir = os.path.dirname(plugin_path)
360+
fnam = os.path.basename(plugin_path)
361+
if not fnam.endswith('.py'):
362+
plugin_error("Plugin must have .py extension")
363+
module_name = fnam[:-3]
364+
import importlib
365+
sys.path.insert(0, plugin_dir)
366+
try:
367+
m = importlib.import_module(module_name)
368+
except Exception:
369+
print('Error importing plugin {}\n'.format(plugin_path))
370+
raise # Propagate to display traceback
371+
finally:
372+
assert sys.path[0] == plugin_dir
373+
del sys.path[0]
374+
if not hasattr(m, 'plugin'):
375+
plugin_error('Plugin does not define entry point function "plugin"')
376+
try:
377+
plugin_type = getattr(m, 'plugin')(__version__)
378+
except Exception:
379+
print('Error calling the plugin(version) entry point of {}\n'.format(plugin_path))
380+
raise # Propagate to display traceback
381+
if not isinstance(plugin_type, type):
382+
plugin_error(
383+
'Type object expected as the return value of "plugin" (got {!r})'.format(
384+
plugin_type))
385+
if not issubclass(plugin_type, Plugin):
386+
plugin_error(
387+
'Return value of "plugin" must be a subclass of "mypy.plugin.Plugin"')
388+
try:
389+
custom_plugins.append(plugin_type(options.python_version))
390+
except Exception:
391+
print('Error constructing plugin instance of {}\n'.format(plugin_type.__name__))
392+
raise # Propagate to display traceback
393+
if not custom_plugins:
394+
return default_plugin
395+
else:
396+
# Custom plugins take precendence over built-in plugins.
397+
return ChainedPlugin(options.python_version, custom_plugins + [default_plugin])
398+
399+
337400
# TODO: Get rid of all_types. It's not used except for one log message.
338401
# Maybe we could instead publish a map from module ID to its type_map.
339402
class BuildManager:
@@ -357,6 +420,7 @@ class BuildManager:
357420
missing_modules: Set of modules that could not be imported encountered so far
358421
stale_modules: Set of modules that needed to be rechecked
359422
version_id: The current mypy version (based on commit id when possible)
423+
plugin: Active mypy plugin(s)
360424
"""
361425

362426
def __init__(self, data_dir: str,
@@ -365,7 +429,8 @@ def __init__(self, data_dir: str,
365429
source_set: BuildSourceSet,
366430
reports: Reports,
367431
options: Options,
368-
version_id: str) -> None:
432+
version_id: str,
433+
plugin: Plugin) -> None:
369434
self.start_time = time.time()
370435
self.data_dir = data_dir
371436
self.errors = Errors(options.show_error_context, options.show_column_numbers)
@@ -385,6 +450,7 @@ def __init__(self, data_dir: str,
385450
self.indirection_detector = TypeIndirectionVisitor()
386451
self.stale_modules = set() # type: Set[str]
387452
self.rechecked_modules = set() # type: Set[str]
453+
self.plugin = plugin
388454

389455
def maybe_swap_for_shadow_path(self, path: str) -> str:
390456
if (self.options.shadow_file and
@@ -1549,9 +1615,8 @@ def type_check_first_pass(self) -> None:
15491615
if self.options.semantic_analysis_only:
15501616
return
15511617
with self.wrap_context():
1552-
plugin = DefaultPlugin(self.options.python_version)
15531618
self.type_checker = TypeChecker(manager.errors, manager.modules, self.options,
1554-
self.tree, self.xpath, plugin)
1619+
self.tree, self.xpath, manager.plugin)
15551620
self.type_checker.check_first_pass()
15561621

15571622
def type_check_second_pass(self) -> bool:

mypy/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def add_invertible_flag(flag: str,
379379
parser.parse_args(args, dummy)
380380
config_file = dummy.config_file
381381
if config_file is not None and not os.path.exists(config_file):
382-
parser.error("Cannot file config file '%s'" % config_file)
382+
parser.error("Cannot find config file '%s'" % config_file)
383383

384384
# Parse config file first, so command line can override.
385385
options = Options()
@@ -613,6 +613,7 @@ def get_init_file(dir: str) -> Optional[str]:
613613
# These two are for backwards compatibility
614614
'silent_imports': bool,
615615
'almost_silent': bool,
616+
'plugins': lambda s: [p.strip() for p in s.split(',')],
616617
}
617618

618619
SHARED_CONFIG_FILES = ('setup.cfg',)

mypy/options.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ def __init__(self) -> None:
113113
self.debug_cache = False
114114
self.quick_and_dirty = False
115115

116+
# Paths of user plugins
117+
self.plugins = [] # type: List[str]
118+
116119
# Per-module options (raw)
117120
self.per_module_options = {} # type: Dict[Pattern[str], Dict[str, object]]
118121

mypy/plugin.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, List, Tuple, Optional, NamedTuple
1+
from typing import Callable, List, Tuple, Optional, NamedTuple, TypeVar
22

33
from mypy.nodes import Expression, StrExpr, IntExpr, UnaryExpr, Context
44
from mypy.types import (
@@ -60,7 +60,7 @@
6060

6161

6262
class Plugin:
63-
"""Base class of type checker plugins.
63+
"""Base class of all type checker plugins.
6464
6565
This defines a no-op plugin. Subclasses can override some methods to
6666
provide some actual functionality.
@@ -69,8 +69,6 @@ class Plugin:
6969
results might be cached).
7070
"""
7171

72-
# TODO: Way of chaining multiple plugins
73-
7472
def __init__(self, python_version: Tuple[int, int]) -> None:
7573
self.python_version = python_version
7674

@@ -86,6 +84,46 @@ def get_method_hook(self, fullname: str) -> Optional[MethodHook]:
8684
# TODO: metaclass / class decorator hook
8785

8886

87+
T = TypeVar('T')
88+
89+
90+
class ChainedPlugin(Plugin):
91+
"""A plugin that represents a sequence of chained plugins.
92+
93+
Each lookup method returns the hook for the first plugin that
94+
reports a match.
95+
96+
This class should not be subclassed -- use Plugin as the base class
97+
for all plugins.
98+
"""
99+
100+
# TODO: Support caching of lookup results (through a LRU cache, for example).
101+
102+
def __init__(self, python_version: Tuple[int, int], plugins: List[Plugin]) -> None:
103+
"""Initialize chained plugin.
104+
105+
Assume that the child plugins aren't mutated (results may be cached).
106+
"""
107+
super().__init__(python_version)
108+
self._plugins = plugins
109+
110+
def get_function_hook(self, fullname: str) -> Optional[FunctionHook]:
111+
return self._find_hook(lambda plugin: plugin.get_function_hook(fullname))
112+
113+
def get_method_signature_hook(self, fullname: str) -> Optional[MethodSignatureHook]:
114+
return self._find_hook(lambda plugin: plugin.get_method_signature_hook(fullname))
115+
116+
def get_method_hook(self, fullname: str) -> Optional[MethodHook]:
117+
return self._find_hook(lambda plugin: plugin.get_method_hook(fullname))
118+
119+
def _find_hook(self, lookup: Callable[[Plugin], T]) -> Optional[T]:
120+
for plugin in self._plugins:
121+
hook = lookup(plugin)
122+
if hook:
123+
return hook
124+
return None
125+
126+
89127
class DefaultPlugin(Plugin):
90128
"""Type checker plugin that is enabled by default."""
91129

mypy/test/data.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
from mypy.myunit import TestCase, SkipTestCaseException
1414

1515

16+
root_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), '..', '..'))
17+
18+
1619
def parse_test_cases(
1720
path: str,
1821
perform: Optional[Callable[['DataDrivenTestCase'], None]],
@@ -62,7 +65,9 @@ def parse_test_cases(
6265
# Record an extra file needed for the test case.
6366
arg = p[i].arg
6467
assert arg is not None
65-
file_entry = (join(base_path, arg), '\n'.join(p[i].data))
68+
contents = '\n'.join(p[i].data)
69+
contents = expand_variables(contents)
70+
file_entry = (join(base_path, arg), contents)
6671
if p[i].id == 'file':
6772
files.append(file_entry)
6873
elif p[i].id == 'outfile':
@@ -119,13 +124,15 @@ def parse_test_cases(
119124
deleted_paths.setdefault(num, set()).add(full)
120125
elif p[i].id == 'out' or p[i].id == 'out1':
121126
tcout = p[i].data
122-
if native_sep and os.path.sep == '\\':
127+
tcout = [expand_variables(line) for line in tcout]
128+
if os.path.sep == '\\':
123129
tcout = [fix_win_path(line) for line in tcout]
124130
ok = True
125131
elif re.match(r'out[0-9]*$', p[i].id):
126132
D96B passnum = int(p[i].id[3:])
127133
assert passnum > 1
128134
output = p[i].data
135+
output = [expand_variables(line) for line in output]
129136
if native_sep and os.path.sep == '\\':
130137
output = [fix_win_path(line) for line in output]
131138
tcout2[passnum] = output
@@ -415,6 +422,10 @@ def expand_includes(a: List[str], base_path: str) -> List[str]:
415422
return res
416423

417424

425+
def expand_variables(s: str) -> str:
426+
return s.replace('<ROOT>', root_dir)
427+
428+
418429
def expand_errors(input: List[str], output: List[str], fnam: str) -> None:
419430
"""Transform comments such as '# E: message' or
420431
'# E:3: message' in input.
@@ -445,16 +456,17 @@ def expand_errors(input: List[str], output: List[str], fnam: str) -> None:
445456

446457

447458
def fix_win_path(line: str) -> str:
448-
r"""Changes paths to Windows paths in error messages.
459+
r"""Changes Windows paths to Linux paths in error messages.
449460
450-
E.g. foo/bar.py -> foo\bar.py.
461+
E.g. foo\bar.py -> foo/bar.py.
451462
"""
463+
line = line.replace(root_dir, root_dir.replace('\\', '/'))
452464
m = re.match(r'^([\S/]+):(\d+:)?(\s+.*)', line)
453465
if not m:
454466
return line
455467
else:
456468
filename, lineno, message = m.groups()
457-
return '{}:{}{}'.format(filename.replace('/', '\\'),
469+
return '{}:{}{}'.format(filename.replace('\\', '/'),
458470
lineno or '', message)
459471

460472

mypy/test/testcheck.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
'check-classvar.test',
7777
'check-enum.test',
7878
'check-incomplete-fixture.test',
79+
'check-custom-plugin.test',
7980
]
8081

8182

@@ -261,7 +262,8 @@ def find_error_paths(self, a: List[str]) -> Set[str]:
261262
for line in a:
262263
m = re.match(r'([^\s:]+):\d+: error:', line)
263264
if m:
264-
p = m.group(1).replace('/', os.path.sep)
265+
# Normalize to Linux paths.
266+
p = m.group(1).replace(os.path.sep, '/')
265267
hits.add(p)
266268
return hits
267269

mypy/test/testcmdline.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from mypy.test.config import test_data_prefix, test_temp_dir
1616
from mypy.test.data import fix_cobertura_filename
1717
from mypy.test.data import parse_test_cases, DataDrivenTestCase
18-
from mypy.test.helpers import assert_string_arrays_equal
18+
from mypy.test.helpers import assert_string_arrays_equal, normalize_error_messages
1919
from mypy.version import __version__, base_version
2020

2121
# Path to Python 3 interpreter
@@ -71,10 +71,12 @@ def test_python_evaluation(testcase: DataDrivenTestCase) -> None:
7171
os.path.abspath(test_temp_dir))
7272
if testcase.native_sep and os.path.sep == '\\':
7373
normalized_output = [fix_cobertura_filename(line) for line in normalized_output]
74+
normalized_output = normalize_error_messages(normalized_output)
7475
assert_string_arrays_equal(expected_content.splitlines(), normalized_output,
7576
'Output file {} did not match its expected output'.format(
7677
path))
7778
else:
79+
out = normalize_error_messages(out)
7880
assert_string_arrays_equal(testcase.output, out,
7981
'Invalid output ({}, line {})'.format(
8082
testcase.file, testcase.line))

mypy/test/testgraph.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from mypy.version import __version__
99
from mypy.options import Options
1010
from mypy.report import Reports
11+
from mypy.plugin import Plugin
12+
from mypy import defaults
1113

1214

1315
class GraphSuite(Suite):
@@ -42,6 +44,7 @@ def _make_manager(self) -> BuildManager:
4244
reports=Reports('', {}),
4345
options=Options(),
4446
version_id=__version__,
47+
plugin=Plugin(defaults.PYTHON3_VERSION),
4548
)
4649
return manager
4750

mypy/test/testsemanal.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def test_semanal(testcase: DataDrivenTestCase) -> None:
8989
a += str(f).split('\n')
9090
except CompileError as e:
9191
a = e.messages
92+
a = normalize_error_messages(a)
9293
assert_string_arrays_equal(
9394
testcase.output, a,
9495
'Invalid semantic analyzer output ({}, line {})'.format(testcase.file,

mypy/test/testtransform.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
from mypy import build
88
from mypy.build import BuildSource
99
from mypy.myunit import Suite
10-
from mypy.test.helpers import assert_string_arrays_equal, testfile_pyversion
10+
from mypy.test.helpers import (
11+
assert_string_arrays_equal, testfile_pyversion, normalize_error_messages
12+
)
1113
from mypy.test.data import parse_test_cases, DataDrivenTestCase
1214
from mypy.test.config import test_data_prefix, test_temp_dir
1315
from mypy.errors import CompileError
@@ -73,6 +75,7 @@ def test_transform(testcase: DataDrivenTestCase) -> None:
7375
a += str(f).split('\n')
7476
except CompileError as e:
7577
a = e.messages
78+
a = normalize_error_messages(a)
7679
assert_string_arrays_equal(
7780
testcase.output, a,
7881
'Invalid semantic analyzer output ({}, line {})'.format(testcase.file,

0 commit comments

Comments
 (0)
0