8000 Added better error handling for module imports and · shader/python-fire@5311d21 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5311d21

Browse files
jacobaustin123copybara-github
authored andcommitted
Added better error handling for module imports and
PiperOrigin-RevId: 339695282 Change-Id: I28e98d51a9d1ee20aeb69d1cb667980e3c7607cf
1 parent 878b8d8 commit 5311d21

File tree

3 files changed

+137
- 8000 8
lines changed

3 files changed

+137
-8
lines changed

fire/__main__.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,79 @@
1919
"""
2020

2121
import importlib
22+
import os
2223
import sys
2324

2425
import fire
2526

27+
cli_string = """usage: python -m fire [module] [arg] ..."
2628
27-
def main(args):
28-
module_name = args[1]
29+
Python Fire is a library for creating CLIs from absolutely any Python
30+
object or program. To run Python Fire from the command line on an
31+
existing Python file, it can be invoked with "python -m fire [module]"
32+
and passed a Python module using module notation:
33+
34+
"python -m fire packageA.packageB.module"
35+
36+
or with a file path:
37+
38+
"python -m fire packageA/packageB/module.py" """
39+
40+
41+
def import_from_file_path(path):
42+
"""Performs a module import given the filename."""
43+
module_name = os.path.basename(path)
44+
45+
if sys.version_info.major == 3:
46+
from importlib import util # pylint: disable=g-import-not-at-top
47+
spec = util.spec_from_file_location(module_name, path)
48+
49+
if spec is None:
50+
raise IOError('Unable to load module from specified path.')
51+
52+
module = util.module_from_spec(spec)
53+
spec.loader.exec_module(module) # pytype: disable=attribute-error
54+
else:
55+
import imp # pylint: disable=g-import-not-at-top
56+
module = imp.load_source(module_name, path)
57+
58+
return module, module_name
59+
60+
61+
def import_from_module_name(module_name):
2962
module = importlib.import_module(module_name)
63+
return module, module_name
64+
65+
66+
def import_module(module_or_filename):
67+
"""Imports a given module or filename."""
68+
69+
if os.path.exists(module_or_filename):
70+
# importlib.util.spec_from_file_location requires .py
71+
if not module_or_filename.endswith('.py'):
72+
try: # try as module instead
73+
return import_from_module_name(module_or_filename)
74+
except ImportError:
75+
raise ValueError('Fire can only be called on .py files.')
76+
77+
return import_from_file_path(module_or_filename)
78+
79+
if os.path.sep in module_or_filename: # Use / to detect if it was a filename.
80+
raise IOError('Fire was passed a filename which could not be found.')
81+
82+
return import_from_module_name(module_or_filename) # Assume it's a module.
83+
84+
85+
def main(args):
86+
"""Entrypoint for fire when invoked as a module with python -m fire."""
87+
88+
if len(args) < 2:
89+
print(cli_string)
90+
exit(1)
91+
92+
module_or_filename = args[1]
93+
module, module_name = import_module(module_or_filename)
94+
3095
fire.Fire(module, name=module_name, command=args[2:])
3196

3297

fire/main_test.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Test using Fire via `python -m fire`."""
1616

1717
import os
18+
import tempfile
1819

1920
from fire import __main__
2021
from fire import testutils
@@ -31,11 +32,60 @@ def testNameSetting(self):
3132
def testArgPassing(self):
3233
expected = os.path.join('part1', 'part2', 'part3')
3334
with self.assertOutputMatches('%s\n' % expected):
34-
__main__.main(['__main__.py', 'os.path', 'join', 'part1', 'part2',
35-
'part3'])
35+
__main__.main(
36+
['__main__.py', 'os.path', 'join', 'part1', 'part2', 'part3'])
3637
with self.assertOutputMatches('%s\n' % expected):
37-
__main__.main(['__main__.py', 'os', 'path', '-', 'join', 'part1',
38-
'part2', 'part3'])
38+
__main__.main(
39+
['__main__.py', 'os', 'path', '-', 'join', 'part1', 'part2', 'part3'])
40+
41+
42+
class MainModuleFileTest(testutils.BaseTestCase):
43+
"""Tests to verify correct import behavior for file executables."""
44+
45+
def setUp(self):
46+
super(MainModuleFileTest, self).setUp()
47+
self.file = tempfile.NamedTemporaryFile(suffix='.py')
48+
self.file.write(b'class Foo:\n def double(self, n):\n return 2 * n\n')
49+
self.file.flush()
50+
51+
self.file2 = tempfile.NamedTemporaryFile()
52+
53+
def testFileNameFire(self):
54+
# Confirm that the file is correctly imported and doubles the number.
55+
with self.assertOutputMatches('4'):
56+
__main__.main(
57+
['__main__.py', self.file.name, 'Foo', 'double', '--n', '2'])
58+
59+
def testFileNameFailure(self):
60+
# Confirm that an existing file without a .py suffix raises a ValueError.
61+
with self.assertRaises(ValueError):
62+
__main__.main(
63+
['__main__.py', self.file2.name, 'Foo', 'double', '--n', '2'])
64+
65+
def testFileNameModuleDuplication(self):
66+
# Confirm that a file that masks a module still loads the module.
67+
with self.assertOutputMatches('gettempdir'):
68+
file = self.create_tempfile('tempfile')
69+
70+
with testutils.ChangeDirectory(os.path.dirname(file.full_path)):
71+
__main__.main([
72+
'__main__.py',
73+
'tempfile',
74+
])
75+
76+
def testFileNameModuleFileFailure(self):
77+
# Confirm that an invalid file that masks a non-existent module fails.
78+
with self.assertRaisesWithLiteralMatch(
79+
ValueError, 'Fire can only be called on .py files.'):
80+
file = self.create_tempfile('foobar')
81+
82+
with testutils.ChangeDirectory(os.path.dirname(file.full_path)):
83+
assert os.path.exists('foobar')
84+
85+
__main__.main([
86+
'__main__.py',
87+
'foobar',
88+
])
3989

4090

4191
if __name__ == '__main__':

fire/testutils.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from __future__ import print_function
2020

2121
import contextlib
22+
import os
2223
import re
2324
import sys
2425
import unittest
@@ -44,6 +45,7 @@ def assertOutputMatches(self, stdout='.*', stderr='.*', capture=True):
4445
stdout: (str) regexp to match against stdout (None will check no stdout)
4546
stderr: (str) regexp to match against stderr (None will check no stderr)
4647
capture: (bool, default True) do not bubble up stdout or stderr
48+
4749
Yields:
4850
Yields to the wrapped context.
4951
"""
@@ -80,6 +82,7 @@ def assertRaisesFireExit(self, code, regexp='.*'):
8082
Args:
8183
code: The status code that the FireExit should contain.
8284
regexp: stdout must match this regex.
85+
8386
Yields:
8487
Yields to the wrapped context.
8588
"""
@@ -89,12 +92,23 @@ def assertRaisesFireExit(self, code, regexp='.*'):
8992
yield
9093
except core.FireExit as exc:
9194
if exc.code != code:
92-
raise AssertionError('Incorrect exit code: %r != %r' % (exc.code,
93-
code))
95+
raise AssertionError('Incorrect exit code: %r != %r' %
96+
(exc.code, code))
9497
self.assertIsInstance(exc.trace, trace.FireTrace)
9598
raise
9699

97100

101+
@contextlib.contextmanager
102+
def ChangeDirectory(directory):
103+
cwdir = os.getcwd()
104+
os.chdir(directory)
105+
106+
try:
107+
yield directory
108+
finally:
109+
os.chdir(cwdir)
110+
111+
98112
# pylint: disable=invalid-name
99113
main = unittest.main
100114
skip = unittest.skip

0 commit comments

Comments
 (0)
0