diff --git a/pyscriptjs/src/main.ts b/pyscriptjs/src/main.ts index 96051869e9a..c6ee5c99c9c 100644 --- a/pyscriptjs/src/main.ts +++ b/pyscriptjs/src/main.ts @@ -16,6 +16,7 @@ import { type Stdio, StdioMultiplexer, DEFAULT_STDIO } from './stdio'; import { PyTerminalPlugin } from './plugins/pyterminal'; import { SplashscreenPlugin } from './plugins/splashscreen'; import { ImportmapPlugin } from './plugins/importmap'; +import { AutomaticImportPlugin } from './plugins/automaticImport'; // eslint-disable-next-line // @ts-ignore import pyscript from './python/pyscript.py'; @@ -66,7 +67,12 @@ export class PyScriptApp { constructor() { // initialize the builtin plugins this.plugins = new PluginManager(); - this.plugins.add(new SplashscreenPlugin(), new PyTerminalPlugin(this), new ImportmapPlugin()); + this.plugins.add( + new SplashscreenPlugin(), + new PyTerminalPlugin(this), + new ImportmapPlugin(), + new AutomaticImportPlugin(), + ); this._stdioMultiplexer = new StdioMultiplexer(); this._stdioMultiplexer.addListener(DEFAULT_STDIO); diff --git a/pyscriptjs/src/plugins/automaticImport.ts b/pyscriptjs/src/plugins/automaticImport.ts new file mode 100644 index 00000000000..1f9cad65760 --- /dev/null +++ b/pyscriptjs/src/plugins/automaticImport.ts @@ -0,0 +1,107 @@ +import type { AppConfig } from '../pyconfig'; +import { robustFetch } from '../fetch'; +import { Plugin } from '../plugin'; +import { getLogger } from '../logger'; + +const logger = getLogger('automatic-import-plugin'); + +export class AutomaticImportPlugin extends Plugin { + pyscriptTag = document.querySelector('py-script'); + + /** + * Trim whitespace from the beginning and end of a string. + * The text that we are getting has new lines and possibly + * whitespace at the beginning and end of the line. Since + * we want to match only import statements, we need to trim + * the whitespace on each line, join them together and then + * match each beginning of the line so the regex rule will + * only match what we perceive as import statements. + * + * NOTE: This is not a perfect solution and I'm a bit worried + * about the performance of this. I'm sure there is a better + * way to do this. + * + * @param text - The text to trim + */ + _trimWhiteSpace(text: string): string { + const lines = text.split('\n'); + return lines.map(line => line.trim()).join('\n'); + } + /** + * + * Use regex magic to capture import statements and add the + * dependency(s) to the packages list in the config. + * + * @param text - The text to search for import statements + * @param config - The config object to add the dependencies to + * + */ + _addImportToPackagesList(text: string, config: AppConfig) { + // Regex encantation to capture the imported dependency into a + // named group called "dependency". The rule will match any text + // that contains 'import' or 'from' at the beginning of the line + // or '\nimport', '\nfrom' which is the case when we invoke _trimWhiteSpace + const importRegexRule = /^(?:\\n)?(?:import|from)\s+(?[a-zA-Z0-9_]+)?/gm; + + text = this._trimWhiteSpace(text); + + const matches = text.matchAll(importRegexRule); + // Regex matches full match and groups, let's just push the group. + for (const match of matches) { + const dependency = match.groups.dependency; + if (dependency) { + logger.info(`Found import statement for ${dependency}, adding to packages list.`); + config.packages.push(dependency); + } + } + } + + /** + * + * In this initial lifecycle hook, we will look for any imports + * in the tag inner HTML and add them to the packages list. + * + * We are skipping looking into the src attribute to not delay the + * preliminary initialization phase. + * + */ + configure(config: AppConfig) { + if (config.autoImports ?? true) { + // config.packages should already be a list, but + // let's be defensive just in case. + if (!config.packages) { + config.packages = []; + } + + this._addImportToPackagesList(this.pyscriptTag.innerHTML, config); + } + } + + /** + * In this lifecycle hook, we will to see if the user has specified a + * src attribute in the tag and fetch the script if so. Then + * we will look for any imports in the fetched script and add them to the + * packages list. + * + * NOTE: If we are fetching the file from a URL and not from the local file + * system, this will delay the download of the python interpreter. Perhaps + * we should throw an error if the src attribute is a URL? + * + */ + beforeLaunch(config: AppConfig) { + if (config.autoImports ?? true) { + const srcAttribute = this.pyscriptTag.getAttribute('src'); + if (srcAttribute) { + logger.info(`Found src attribute in tag, fetching ${srcAttribute}...`); + robustFetch(srcAttribute) + .then(response => response.text()) + .then(text => { + this._addImportToPackagesList(text, config); + }) + .catch(error => { + logger.error(`Failed to fetch ${srcAttribute}: ${(error as Error).message}`); + }); + } + } + } +} diff --git a/pyscriptjs/tests/integration/test_assets/automatic_import.py b/pyscriptjs/tests/integration/test_assets/automatic_import.py new file mode 100644 index 00000000000..f9962cb9ebf --- /dev/null +++ b/pyscriptjs/tests/integration/test_assets/automatic_import.py @@ -0,0 +1,5 @@ +import numpy as np +import pandas as pd + +print(np.__version__) +print(pd.__version__) diff --git a/pyscriptjs/tests/integration/test_plugins.py b/pyscriptjs/tests/integration/test_plugins.py index 3299ca2d49b..623fcd2ba37 100644 --- a/pyscriptjs/tests/integration/test_plugins.py +++ b/pyscriptjs/tests/integration/test_plugins.py @@ -1,3 +1,5 @@ +import re + from .support import PyScriptTest # Source code of a simple plugin that creates a Custom Element for testing purposes @@ -186,3 +188,108 @@ def test_no_plugin_attribute_error(self): ) # EXPECT an error for the missing attribute assert error_msg in self.console.error.lines + + +class TestCorePlugins(PyScriptTest): + def test_core_plugin_automatic_import(self): + """Test that packages get installed automatically""" + self.pyscript_run( + """ + + import numpy as np + print(np.__version__) + + """ + ) + + py_terminal = self.page.locator("py-terminal") + assert re.match(r"\d+.\d+.\d+", py_terminal.inner_text()) + + def test_core_plugin_automatic_import_from(self): + """Test that packages get installed automatically""" + self.pyscript_run( + """ + + from numpy import __version__ as np_version + print(np_version) + + """ + ) + + py_terminal = self.page.locator("py-terminal") + assert re.match(r"\d+.\d+.\d+", py_terminal.inner_text()) + + def test_core_plugin_automatic_multiple_imports(self): + """Test that packages get installed automatically""" + + self.pyscript_run( + """ + + import numpy as np + from pandas import __version__ as pandas_version + print(np.__version__) + print(pandas_version) + + """ + ) + + py_terminal = self.page.locator("py-terminal") + assert re.match(r"\d+.\d+.\d+\n\d+.\d+.\d+", py_terminal.inner_text()) + + def test_core_plugin_automatic_import_python_file(self): + """Test that packages get installed automatically""" + py_file = ( + "import numpy as np\n" + "from pandas import __version__ as pandas_version\n" + "\n" + "print(np.__version__)\n" + "print(pandas_version)\n" + ) + + self.writefile("automatic_import.py", py_file) + self.pyscript_run( + """ + + + """ + ) + + py_terminal = self.page.locator("py-terminal") + assert re.match(r"\d+.\d+.\d+\n\d+.\d+.\d+", py_terminal.inner_text()) + + def test_core_plugin_automatic_import_disabled(self): + """Test that packages get installed automatically""" + self.pyscript_run( + """ + + autoImports = false + + + from numpy import __version__ as np_version + print(np._version) + + """, + wait_for_pyscript=False, + ) + + alert_banner = self.page.locator(".py-error") + assert ( + "ModuleNotFoundError: No module named 'numpy'" in alert_banner.inner_text() + ) + + def test_core_plugin_automatic_import_doesnt_match_text(self): + """Test that packages get installed automatically""" + self.pyscript_run( + """ + + from numpy import __version__ as np_version + print(np_version) + + def test_import_regex_logic(): + "from docstrings?" + print("import my packages") + + """ + ) + py_terminal = self.page.locator("py-terminal") + assert re.match(r"\d+.\d+.\d+", py_terminal.inner_text())