8000 Add basic support for user-defined mypy plugins by JukkaL · Pull Request #3517 · python/mypy · GitHub
[go: up one dir, main page]

Skip to content

Add basic support for user-defined mypy plugins #3517

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 13, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 70 additions & 5 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from mypy.stats import dump_type_stats
from mypy.types import Type
from mypy.version import __version__
from mypy.plugin import DefaultPlugin
from mypy.plugin import Plugin, DefaultPlugin, ChainedPlugin


# We need to know the location of this file to load data, but
Expand Down Expand Up @@ -183,7 +183,9 @@ def build(sources: List[BuildSource],
reports=reports,
options=options,
version_id=__version__,
)
plugin=DefaultPlugin(options.python_version))

manager.plugin = load_custom_plugins(manager.plugin, options, manager.errors)

try:
graph = dispatch(sources, manager)
Expand Down Expand Up @@ -333,6 +335,67 @@ def import_priority(imp: ImportBase, toplevel_priority: int) -> int:
return toplevel_priority


def load_custom_plugins(default_plugin: Plugin, options: Options, errors: Errors) -> Plugin:
"""Load custom plugins if any are configured.

Return a plugin that chains all custom plugins (if any) and falls
back to default_plugin.
"""

def plugin_error(message: str) -> None:
errors.report(0, 0, message)
errors.raise_error()

custom_plugins = []
for plugin_path in options.plugins:
if options.config_file:
# Plugin paths are relative to the config file location.
plugin_path = os.path.join(os.path.dirname(options.config_file), plugin_path)
errors.set_file(plugin_path, None)

if not os.path.isfile(plugin_path):
plugin_error("Can't find plugin")
plugin_dir = os.path.dirname(plugin_path)
fnam = os.path.basename(plugin_path)
if not fnam.endswith('.py'):
plugin_error("Plugin must have .py extension")
module_name = fnam[:-3]
import importlib
sys.path.insert(0, plugin_dir)
try:
m = importlib.import_module(module_name)
except Exception:
print('Error importing plugin {}\n'.format(plugin_path))
raise # Propagate to display traceback
finally:
assert sys.path[0] == plugin_dir
del sys.path[0]
if not hasattr(m, 'plugin'):
plugin_error('Plugin does not define entry point function "plugin"')
try:
plugin_type = getattr(m, 'plugin')(__version__)
except Exception:
print('Error calling the plugin(version) entry point of {}\n'.format(plugin_path))
raise # Propagate to display traceback
if not isinstance(plugin_type, type):
plugin_error(
'Type object expected as the return value of "plugin" (got {!r})'.format(
plugin_type))
if not issubclass(plugin_type, Plugin):
plugin_error(
'Return value of "plugin" must be a subclass of "mypy.plugin.Plugin"')
try:
custom_plugins.append(plugin_type(options.python_version))
except Exception:
print('Error constructing plugin instance of {}\n'.format(plugin_type.__name__))
raise # Propagate to display traceback
if not custom_plugins:
return default_plugin
else:
# Custom plugins take precendence over built-in plugins.
return ChainedPlugin(options.python_version, custom_plugins + [default_plugin])


# TODO: Get rid of all_types. It's not used except for one log message.
# Maybe we could instead publish a map from module ID to its type_map.
class BuildManager:
Expand All @@ -356,6 +419,7 @@ class BuildManager:
missing_modules: Set of modules that could not be imported encountered so far
stale_modules: Set of modules that needed to be rechecked
version_id: The current mypy version (based on commit id when possible)
plugin: Active mypy plugin(s)
"""

def __init__(self, data_dir: str,
Expand All @@ -364,7 +428,8 @@ def __init__(self, data_dir: str,
source_set: BuildSourceSet,
reports: Reports,
options: Options,
version_id: str) -> None:
version_id: str,
plugin: Plugin) -> None:
self.start_time = time.time()
self.data_dir = data_dir
self.errors = Errors(options.show_error_context, options.show_column_numbers)
Expand All @@ -384,6 +449,7 @@ def __init__(self, data_dir: str,
self.indirection_detector = TypeIndirectionVisitor()
self.stale_modules = set() # type: Set[str]
self.rechecked_modules = set() # type: Set[str]
self.plugin = plugin

def maybe_swap_for_shadow_path(self, path: str) -> str:
if (self.options.shadow_file and
Expand Down Expand Up @@ -1506,9 +1572,8 @@ def type_check_first_pass(self) -> None:
if self.options.semantic_analysis_only:
return
with self.wrap_context():
plugin = DefaultPlugin(self.options.python_version)
self.type_checker = TypeChecker(manager.errors, manager.modules, self.options,
self.tree, self.xpath, plugin)
self.tree, self.xpath, manager.plugin)
self.type_checker.check_first_pass()

def type_check_second_pass(self) -> bool:
Expand Down
3 changes: 2 additions & 1 deletion mypy/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def disallow_any_argument_type(raw_options: str) -> List[str]:
parser.parse_args(args, dummy)
config_file = dummy.config_file
if config_file is not None and not os.path.exists(config_file):
parser.error("Cannot file config file '%s'" % config_file)
parser.error("Cannot find config file '%s'" % config_file)

# Parse config file first, so command line can override.
options = Options()
Expand Down Expand Up @@ -605,6 +605,7 @@ def get_init_file(dir: str) -> Optional[str]:
# These two are for backwards compatibility
'silent_imports': bool,
'almost_silent': bool,
'plugins': lambda s: [p.strip() for p in s.split(',')],
}

SHARED_CONFIG_FILES = ('setup.cfg',)
Expand Down
3 changes: 3 additions & 0 deletions mypy/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ def __init__(self) -> None:
self.debug_cache = False
self.quick_and_dirty = False

# Paths of user plugins
self.plugins = [] # type: List[str]

# Per-module options (raw)
self.per_module_options = {} # type: Dict[Pattern[str], Dict[str, object]]

Expand Down
46 changes: 42 additions & 4 deletions mypy/plugin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, List, Tuple, Optional, NamedTuple
from typing import Callable, List, Tuple, Optional, NamedTuple, TypeVar

from mypy.nodes import Expression, StrExpr, IntExpr, UnaryExpr, Context
from mypy.types import (
Expand Down Expand Up @@ -60,7 +60,7 @@


class Plugin:
"""Base class of type checker plugins.
"""Base class of all type checker plugins.

This defines a no-op plugin. Subclasses can override some methods to
provide some actual functionality.
Expand All @@ -69,8 +69,6 @@ class Plugin:
results might be cached).
"""

# TODO: Way of chaining multiple plugins

def __init__(self, python_version: Tuple[int, int]) -> None:
self.python_version = python_version

Expand All @@ -86,6 +84,46 @@ def get_method_hook(self, fullname: str) -> Optional[MethodHook]:
# TODO: metaclass / class decorator hook


T = TypeVar('T')


class ChainedPlugin(Plugin):
"""A plugin that represents a sequence of chained plugins.

Each lookup method returns the hook for the first plugin that
reports a match.

This class should not be subclassed -- use Plugin as the base class
for all plugins.
"""

# TODO: Support caching of lookup results (through a LRU cache, for example).

def __init__(self, python_version: Tuple[int, int], plugins: List[Plugin]) -> None:
"""Initialize chained plugin.

Assume that the child plugins aren't mutated (results may be cached).
"""
super().__init__(python_version)
self._plugins = plugins

def get_function_hook(self, fullname: str) -> Optional[FunctionHook]:
return self._find_hook(lambda plugin: plugin.get_function_hook(fullname))

def get_method_signature_hook(self, fullname: str) -> Optional[MethodSignatureHook]:
return self._find_hook(lambda plugin: plugin.get_method_signature_hook(fullname))

def get_method_hook(self, fullname: str) -> Optional[MethodHook]:
return self._find_hook(lambda plugin: plugin.get_method_hook(fullname))

def _find_hook(self, lookup: Callable[[Plugin], T]) -> Optional[T]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The result of this should probably be cached based on hook-type and fullname.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Created an issue about caching (#3533). This may need some analysis or experimentation to decide a caching strategy (e.g. unlimited cache size vs bounded cache size; maximum size of the cache) so I feel that it's better to do it separately.

for plugin in self._plugins:
hook = lookup(plugin)
if hook:
return hook
return None


class DefaultPlugin(Plugin):
"""Type checker plugin that is enabled by default."""

Expand Down
22 changes: 17 additions & 5 deletions mypy/test/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from mypy.myunit import TestCase, SkipTestCaseException


root_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), '..', '..'))


def parse_test_cases(
path: str,
perform: Optional[Callable[['DataDrivenTestCase'], None]],
Expand Down Expand Up @@ -62,7 +65,9 @@ def parse_test_cases(
# Record an extra file needed for the test case.
arg = p[i].arg
assert arg is not None
file_entry = (join(base_path, arg), '\n'.join(p[i].data))
contents = '\n'.join(p[i].data)
contents = expand_variables(contents)
file_entry = (join(base_path, arg), contents)
if p[i].id == 'file':
files.append(file_entry)
elif p[i].id == 'outfile':
Expand Down Expand Up @@ -119,13 +124,15 @@ def parse_test_cases(
deleted_paths.setdefault(num, set()).add(full)
elif p[i].id == 'out' or p[i].id == 'out1':
tcout = p[i].data
if native_sep and os.path.sep == '\\':
tcout = [expand_variables(line) for line in tcout]
if os.path.sep == '\\':
tcout = [fix_win_path(line) for line in tcout]
ok = True
elif re.match(r'out[0-9]*$', p[i].id):
passnum = int(p[i].id[3:])
assert passnum > 1
output = p[i].data
output = [expand_variables(line) for line in output]
if native_sep and os.path.sep == '\\':
output = [fix_win_path(line) for line in output]
tcout2[passnum] = output
Expand Down Expand Up @@ -415,6 +422,10 @@ def expand_includes(a: List[str], base_path: str) -> List[str]:
return res


def expand_variables(s: str) -> str:
return s.replace('<ROOT>', root_dir)


def expand_errors(input: List[str], output: List[str], fnam: str) -> None:
"""Transform comments such as '# E: message' or
'# E:3: message' in input.
Expand Down Expand Up @@ -445,16 +456,17 @@ def expand_errors(input: List[str], output: List[str], fnam: str) -> None:


def fix_win_path(line: str) -> str:
r"""Changes paths to Windows paths in error messages.
r"""Changes Windows paths to Linux paths in error messages.

E.g. foo/bar.py -> foo\bar.py.
E.g. foo\bar.py -> foo/bar.py.
"""
line = line.replace(root_dir, root_dir.replace('\\', '/'))
m = re.match(r'^([\S/]+):(\d+:)?(\s+.*)', line)
if not m:
return line
else:
filename, lineno, message = m.groups()
return '{}:{}{}'.format(filename.replace('/', '\\'),
return '{}:{}{}'.format(filename.replace('\\', '/'),
lineno or '', message)


Expand Down
4 changes: 3 additions & 1 deletion mypy/test/testcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
'check-classvar.test',
'check-enum.test',
'check-incomplete-fixture.test',
'check-custom-plugin.test',
]


Expand Down Expand Up @@ -261,7 +262,8 @@ def find_error_paths(self, a: List[str]) -> Set[str]:
for line in a:
m = re.match(r'([^\s:]+):\d+: error:', line)
if m:
p = m.group(1).replace('/', os.path.sep)
# Normalize to Linux paths.
p = m.group(1).replace(os.path.sep, '/')
hits.add(p)
return hits

Expand Down
4 changes: 3 additions & 1 deletion mypy/test/testcmdline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from mypy.test.config import test_data_prefix, test_temp_dir
from mypy.test.data import fix_cobertura_filename
from mypy.test.data import parse_test_cases, DataDrivenTestCase
from mypy.test.helpers import assert_string_arrays_equal
from mypy.test.helpers import assert_string_arrays_equal, normalize_error_messages
from mypy.version import __version__, base_version

# Path to Python 3 interpreter
Expand Down Expand Up @@ -71,10 +71,12 @@ def test_python_evaluation(testcase: DataDrivenTestCase) -> None:
os.path.abspath(test_temp_dir))
if testcase.native_sep and os.path.sep == '\\':
normalized_output = [fix_cobertura_filename(line) for line in normalized_output]
normalized_output = normalize_error_messages(normalized_output)
assert_string_arrays_equal(expected_content.splitlines(), normalized_output,
'Output file {} did not match its expected output'.format(
path))
else:
out = normalize_error_messages(out)
assert_string_arrays_equal(testcase.output, out,
'Invalid output ({}, line {})'.format(
testcase.file, testcase.line))
Expand Down
3 changes: 3 additions & 0 deletions mypy/test/testgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from mypy.version import __version__
from mypy.options import Options
from mypy.report import Reports
from mypy.plugin import Plugin
from mypy import defaults


class GraphSuite(Suite):
Expand Down Expand Up @@ -42,6 +44,7 @@ def _make_manager(self) -> BuildManager:
reports=Reports('', {}),
options=Options(),
version_id=__version__,
plugin=Plugin(defaults.PYTHON3_VERSION),
)
return manager

Expand Down
1 change: 1 addition & 0 deletions mypy/test/testsemanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def test_semanal(testcase: DataDrivenTestCase) -> None:
a += str(f).split('\n')
except CompileError as e:
a = e.messages
a = normalize_error_messages(a)
assert_string_arrays_equal(
testcase.output, a,
'Invalid semantic analyzer output ({}, line {})'.format(testcase.file,
Expand Down
5 changes: 4 additions & 1 deletion mypy/test/testtransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from mypy import build
from mypy.build import BuildSource
from mypy.myunit import Suite
from mypy.test.helpers import assert_string_arrays_equal, testfile_pyversion
from mypy.test.helpers import (
assert_string_arrays_equal, testfile_pyversion, normalize_error_messages
)
from mypy.test.data import parse_test_cases, DataDrivenTestCase
from mypy.test.config import test_data_prefix, test_temp_dir
from mypy.errors import CompileError
Expand Down Expand Up @@ -73,6 +75,7 @@ def test_transform(testcase: DataDrivenTestCase) -> None:
a += str(f).split('\n')
except CompileError as e:
a = e.messages
a = normalize_error_messages(a)
assert_string_arrays_equal(
testcase.output, a,
'Invalid semantic analyzer output ({}, line {})'.format(testcase.file,
Expand Down
Loading
0