diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a301583..48b9a4e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -48,6 +48,8 @@ jobs: git clone --recurse-submodules https://github.com/adafruit/CircuitPython_Community_Bundle.git cd CircuitPython_Community_Bundle circuitpython-build-bundles --filename_prefix test-bundle --library_location libraries --library_depth 2 + - name: Munge tests + run: pytest tests - name: Build Python package run: | pip install --upgrade setuptools wheel twine readme_renderer testresources diff --git a/.gitignore b/.gitignore index 1c50e6e..08aafad 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ version.py .env/* .DS_Store .idea/* +testcases/*.out diff --git a/circuitpython_build_tools/build.py b/circuitpython_build_tools/build.py index ab28b72..271dc6e 100644 --- a/circuitpython_build_tools/build.py +++ b/circuitpython_build_tools/build.py @@ -56,6 +56,8 @@ def git_filter_arg(): else: return [] +from .munge import munge + # pyproject.toml `py_modules` values that are incorrect. These should all have PRs filed! # and should be removed when the fixed version is incorporated in its respective bundle. @@ -182,16 +184,6 @@ def mpy_cross(version, quiet=False): shutil.copy(mpy_built, mpy_cross_filename) return mpy_cross_filename -def _munge_to_temp(original_path, temp_file, library_version): - with open(original_path, "r", encoding="utf-8") as original_file: - for line in original_file: - line = line.strip("\n") - if line.startswith("__version__"): - line = line.replace("0.0.0-auto.0", library_version) - line = line.replace("0.0.0+auto.0", library_version) - print(line, file=temp_file) - temp_file.flush() - def get_package_info(library_path, package_folder_prefix): lib_path = pathlib.Path(library_path) parent_idx = len(lib_path.parts) @@ -301,25 +293,22 @@ def library(library_path, output_directory, package_folder_prefix, full_path = os.path.join(library_path, filename) output_file = output_directory / filename.relative_to(library_path) if filename.suffix == ".py": - with tempfile.NamedTemporaryFile(delete=False, mode="w+") as temp_file: - temp_file_name = temp_file.name - try: - _munge_to_temp(full_path, temp_file, library_version) - temp_file.close() - if mpy_cross and os.stat(temp_file.name).st_size != 0: - output_file = output_file.with_suffix(".mpy") - mpy_success = subprocess.call([ - mpy_cross, - "-o", output_file, - "-s", str(filename.relative_to(library_path)), - temp_file.name - ]) - if mpy_success != 0: - raise RuntimeError("mpy-cross failed on", full_path) - else: - shutil.copyfile(temp_file_name, output_file) - finally: - os.remove(temp_file_name) + content = munge(full_path, library_version) + if mpy_cross and content: + # TODO: Once 8.x bundles are no longer built, switch to + # sending mpy-cross the code on stdin instead of via + # temporary file (supports the "-" input argument) + with tempfile.NamedTemporaryFile(delete=False, mode="w+") as temp_file: + temp_file.write(content) + temp_file.flush() + subprocess.check_output([ + mpy_cross, + "-o", output_file.with_suffix(".mpy"), + "-s", str(filename.relative_to(library_path)), + temp_file.name + ], input=content.encode('utf-8')) + else: + output_file.write_text(content, encoding="utf-8") else: shutil.copyfile(full_path, output_file) diff --git a/circuitpython_build_tools/munge.py b/circuitpython_build_tools/munge.py new file mode 100644 index 0000000..4026efc --- /dev/null +++ b/circuitpython_build_tools/munge.py @@ -0,0 +1,117 @@ +# The MIT License (MIT) +# +# Copyright (c) 2024 Jeff Epler for Adafruit Industries +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +# Filter program removes some code patterns introduced by type checking, +# to move towards zero overhead static typing in circuitpython libraries +# +# Recognized: +# from __future__ import ... -- eliminated +# try: import typing -- eliminated, but first except: preserved +# try: from typing import ... -- eliminated, but first except: preserved +# if STATIC_TYPING: -- transformed to 'if 0:' +# if sys.implementation_name... -- transformed to unconditional if +# __version__ = ... -- set to library version string +# +# mpy-cross does constant propagation and dead branch elimination of +# 'if 0:' and 'if 1:' +# +# Depends on the file being black-formatted! + +import pathlib +import sys +import ast + +VERBOSE = 0 + +# The canonical spelling of this test... +sys_implementation_is_circuitpython = ast.unparse(ast.parse('sys.implementation.name == "circuitpython"')) +sys_implementation_not_circuitpython = ast.unparse(ast.parse('not sys.implementation.name == "circuitpython"')) +sys_implementation_not_circuitpython2 = ast.unparse(ast.parse('sys.implementation.name != "circuitpython"')) + +def munge(src: pathlib.Path|str, version_str: str) -> str: + path = pathlib.Path(src) + replacements = {} + + def replace(line, new): + if VERBOSE: + replacements[line] = f"{new:<40s} ### {lines[line]}" + else: + replacements[line] = new + + def blank_range(node): + for i in range(node.lineno, node.end_lineno+1): + replace(i, "") + + def unblank_range(node): + for i in range(node.lineno, node.end_lineno+1): + replacements.pop(i, None) + + def imports_from_typing(node): + if isinstance(node, ast.Import) and node.names[0].name == 'typing': + return True + if isinstance(node, ast.ImportFrom) and node.module == 'typing': + return True + return False + + def process_statement(node): + # filter out 'from future import...' + if isinstance(node, ast.ImportFrom): + if node.module == '__future__': + blank_range(node) + # filter out 'try: import typing...' + # but preserve the first 'except:' or 'except ImportError' + elif isinstance(node, ast.Try): + b = node.body[0] + if imports_from_typing(node.body[0]): + blank_range(node) + for h in node.handlers: + if h.type is None or ast.unparse(h.type) == 'ImportError' or ast.unparse(h.type) == 'Exception': + unblank_range(h) + replace(h.lineno, 'if 1:') + break + return + elif isinstance(node, ast.If): + node_test = ast.unparse(node.test) + # return the statements in the 'if' branch of 'if sys.implementation...: ...' + if node_test == sys_implementation_is_circuitpython: + replace(node.lineno, 'if 1:') + # return the statements in the 'else' branch of 'if sys.implementation...: ...' + elif node_test == sys_implementation_not_circuitpython or node_test == sys_implementation_not_circuitpython2: + replace(node.lineno, 'if 0:') + # return the statements in the else branch of 'if TYPE_CHECKING: ...' + elif node_test == 'TYPE_CHECKING': + replace(node.lineno, 'if 0:') + elif isinstance(node, ast.Assign) and isinstance(node.targets[0], ast.Name) and node.targets[0].id == '__version__': + replace(node.lineno, f"__version__ = \"{version_str}\"") + + content = pathlib.Path(path).read_text(encoding="utf-8") + # Insert a blank line 0 because ast line numbers are 1-based + lines = [''] + content.rstrip().split('\n') + a = ast.parse(content, path.name) + + for node in a.body: process_statement(node) + + result = [] + for i in range(1, len(lines)): + result.append(replacements.get(i, lines[i])) + + return "\n".join(result) + "\n" diff --git a/circuitpython_build_tools/scripts/munge.py b/circuitpython_build_tools/scripts/munge.py new file mode 100644 index 0000000..1ab59d8 --- /dev/null +++ b/circuitpython_build_tools/scripts/munge.py @@ -0,0 +1,27 @@ +import pathlib +from difflib import unified_diff +import click +from ..munge import munge + + +@click.command +@click.option("--diff/--no-diff", "show_diff", default=False) +@click.option("--munged-version", default="munged-version") +@click.argument("input", type=click.Path(exists=True)) +@click.argument("output", type=click.File("w", encoding="utf-8"), default="-") +def main(show_diff, munged_version, input, output): + input_path = pathlib.Path(input) + munged = munge(input, munged_version) + if show_diff: + old_lines = input_path.read_text(encoding="utf-8").splitlines(keepends=True) + new_lines = munged.splitlines(keepends=True) + output.writelines( + unified_diff( + old_lines, + new_lines, + fromfile=input, + tofile=str(input_path.with_suffix(".munged.py")), + ) + ) + else: + output.write(munged) diff --git a/requirements.txt b/requirements.txt index b11b4c5..8aac8a1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ Click +pytest requests semver -wheel tomli; python_version < "3.11" +wheel platformdirs diff --git a/setup.py b/setup.py index 4358300..9a3d4e2 100644 --- a/setup.py +++ b/setup.py @@ -18,6 +18,7 @@ entry_points=''' [console_scripts] circuitpython-build-bundles=circuitpython_build_tools.scripts.build_bundles:build_bundles + circuitpython-munge=circuitpython_build_tools.scripts.munge:main circuitpython-mpy-cross=circuitpython_build_tools.scripts.circuitpython_mpy_cross:main ''' ) diff --git a/testcases/test1.exp b/testcases/test1.exp new file mode 100644 index 0000000..02d7b82 --- /dev/null +++ b/testcases/test1.exp @@ -0,0 +1,33 @@ + + + + +if 1: + pass + + + +if 1: + pass + + + + +if 1: + pass + + + +if 1: + pass + +__version__ = "1.2.3" + +if 1: + print("is circuitpython") + +if 0: + print("not circuitpython (1)") + +if 0: + print("not circuitpython (2)") diff --git a/testcases/test1.py b/testcases/test1.py new file mode 100644 index 0000000..60f4e0f --- /dev/null +++ b/testcases/test1.py @@ -0,0 +1,33 @@ +from __future__ import annotation + +try: + from typing import TYPE_CHECKING +except ImportError: + pass + +try: + from typing import TYPE_CHECKING as T +except ImportError: + pass + + +try: + import typing +except: + pass + +try: + import typing as T +except: + pass + +__version__ = "0.0.0-auto" + +if sys.implementation.name == "circuitpython": + print("is circuitpython") + +if sys.implementation.name != "circuitpython": + print("not circuitpython (1)") + +if not sys.implementation.name == "circuitpython": + print("not circuitpython (2)") diff --git a/tests/test_munge.py b/tests/test_munge.py new file mode 100644 index 0000000..48e95f2 --- /dev/null +++ b/tests/test_munge.py @@ -0,0 +1,22 @@ +import sys, pathlib +import pytest + +top = pathlib.Path(__file__).parent.parent +sys.path.insert(0, str(top)) + +from circuitpython_build_tools.munge import munge + +@pytest.mark.parametrize("test_path", top.glob("testcases/*.py")) +def test_munge(test_path): + result_path = test_path.with_suffix(".out") + result_path.unlink(missing_ok = True) + + result_content = munge(test_path, "1.2.3") + result_path.write_text(result_content, encoding="utf-8") + + expected_path = test_path.with_suffix(".exp") + expected_content = expected_path.read_text(encoding="utf-8") + + assert result_content == expected_content + + result_path.unlink()