diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..ba1b7f19 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,31 @@ +# Basic dependabot.yml file with minimum configuration for two package managers + +version: 2 +updates: + # Enable version updates for python + - package-ecosystem: "pip" + directory: ".github/scripts/" + schedule: + interval: "monthly" + labels: ["dependabot"] + pull-request-branch-name: + separator: "-" + open-pull-requests-limit: 5 + reviewers: + - "dbieber" + + # Enable version updates for GitHub Actions + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "monthly" + groups: + gh-actions: + patterns: + - "*" # Check all dependencies + labels: ["dependabot"] + pull-request-branch-name: + separator: "-" + open-pull-requests-limit: 5 + reviewers: + - "dbieber" diff --git a/.github/scripts/build.sh b/.github/scripts/build.sh new file mode 100755 index 00000000..111257ae --- /dev/null +++ b/.github/scripts/build.sh @@ -0,0 +1,32 @@ +# Copyright (C) 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env bash + +# Exit when any command fails. +set -e + +PYTHON_VERSION=${PYTHON_VERSION:-3.7} + +pip install -U -r .github/scripts/requirements.txt +python setup.py develop +python -m pytest # Run the tests without IPython. +pip install ipython +python -m pytest # Now run the tests with IPython. +pylint fire --ignore=test_components_py3.py,parser_fuzz_test.py,console +if [[ ${PYTHON_VERSION} == 3.7 ]]; then + # Run type-checking. + pip install pytype; + pytype -x fire/test_components_py3.py; +fi diff --git a/.github/scripts/requirements.txt b/.github/scripts/requirements.txt new file mode 100644 index 00000000..613c4da0 --- /dev/null +++ b/.github/scripts/requirements.txt @@ -0,0 +1,9 @@ +setuptools <=78.1.0 +pip +pylint <3.3.7 +pytest <=8.3.3 +pytest-pylint <=1.1.2 +pytest-runner <7.0.0 +termcolor <2.6.0 +hypothesis <6.133.0 +levenshtein <=0.26.1 diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 00000000..75a687f3 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,38 @@ +name: Python Fire + +on: + push: + branches: ["master"] + pull_request: + branches: ["master"] + +defaults: + run: + shell: bash + +jobs: + build: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: ["macos-latest", "ubuntu-latest"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13.0-rc.2"] + include: + - {os: "ubuntu-22.04", python-version: "3.7"} + + steps: + # Checkout the repo. + - name: Checkout Python Fire repository + uses: actions/checkout@v4 + + # Set up Python environment. + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + # Build Python Fire using the build.sh script. + - name: Run build script + run: ./.github/scripts/build.sh + env: + PYTHON_VERSION: ${{ matrix.python-version }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..a2166684 --- /dev/null +++ b/.gitignore @@ -0,0 +1,103 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*,cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# dotenv +.env + +# virtualenv +.venv +venv/ +ENV/ + +# Spyder project settings +.spyderproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# PyCharm IDE +.idea/ + +# Type-checking +.pytype/ diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index bfefd2a8..00000000 --- a/.travis.yml +++ /dev/null @@ -1,11 +0,0 @@ -language: python -python: - - "2.7" - - "3.4" - - "3.5" - - "3.6" -before_install: - - pip install --upgrade setuptools pip -install: - - python setup.py develop -script: nosetests --ignore-files=parser_fuzz_test.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 0786fdf4..b5d67c96 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -3,6 +3,14 @@ We'd love to accept your patches and contributions to this project. There are just a few small guidelines you need to follow. +First, read these guidelines. +Before you begin making changes, state your intent to do so in an Issue. +Then, fork the project. Make changes in your copy of the repository. +Then open a pull request once your changes are ready. +If this is your first contribution, sign the Contributor License Agreement. +A discussion about your change will follow, and if accepted your contribution +will be incorporated into the Python Fire codebase. + ## Contributor License Agreement Contributions to this project must be accompanied by a Contributor License @@ -17,8 +25,35 @@ again. ## Code reviews -All submissions, including submissions by project members, require review. We -use GitHub pull requests for this purpose. Consult [GitHub Help] for more -information on using pull requests. +All submissions, including submissions by project members, require review. +For changes introduced by non-Googlers, we use GitHub pull requests for this +purpose. Consult [GitHub Help] for more information on using pull requests. [GitHub Help]: https://help.github.com/articles/about-pull-requests/ + +## Code style + +In general, Python Fire follows the guidelines in the +[Google Python Style Guide]. + +In addition, the project follows a convention of: +- Maximum line length: 80 characters +- Indentation: 2 spaces (4 for line continuation) +- PascalCase for function and method names. +- Single quotes around strings, three double quotes around docstrings. + +[Google Python Style Guide]: http://google.github.io/styleguide/pyguide.html + +## Testing + +Python Fire uses [GitHub Actions](https://github.com/google/python-fire/actions) to run tests on each pull request. You can run +these tests yourself as well. To do this, first install the test dependencies +listed in setup.py (e.g. pytest, mock, termcolor, and hypothesis). +Then run the tests by running `pytest` in the root directory of the repository. + +## Linting + +Please run lint on your pull requests to make accepting the requests easier. +To do this, run `pylint fire` in the root directory of the repository. +Note that even if lint is passing, additional style changes to your submission +may be made during merging. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..1aba38f6 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +include LICENSE diff --git a/README.md b/README.md index 8f53debd..1482d56d 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,28 @@ -# Python Fire -_Python Fire is a library for creating command line interfaces (CLIs) from -absolutely any Python object._ - -- Python Fire is a simple way to create a CLI in Python. [[1]](doc/benefits.md#simple-cli) -- Python Fire is a helpful tool for developing and debugging Python code. [[2]](doc/benefits.md#debugging) -- Python Fire helps with exploring existing code or turning other people's code -into a CLI. [[3]](doc/benefits.md#exploring) -- Python Fire makes transitioning between Bash and Python easier. [[4]](doc/benefits.md#bash) -- Python Fire makes using a Python REPL easier by setting up the REPL with the -modules and variables you'll need already imported and created. [[5]](doc/benefits.md#repl) +# Python Fire [![PyPI](https://img.shields.io/pypi/pyversions/fire.svg?style=plastic)](https://github.com/google/python-fire) + +_Python Fire is a library for automatically generating command line interfaces +(CLIs) from absolutely any Python object._ + +- Python Fire is a simple way to create a CLI in Python. + [[1]](docs/benefits.md#simple-cli) +- Python Fire is a helpful tool for developing and debugging Python code. + [[2]](docs/benefits.md#debugging) +- Python Fire helps with exploring existing code or turning other people's + code into a CLI. [[3]](docs/benefits.md#exploring) +- Python Fire makes transitioning between Bash and Python easier. + [[4]](docs/benefits.md#bash) +- Python Fire makes using a Python REPL easier by setting up the REPL with the + modules and variables you'll need already imported and created. + [[5]](docs/benefits.md#repl) ## Installation -`pip install fire` +To install Python Fire with pip, run: `pip install fire` + +To install Python Fire with conda, run: `conda install fire -c conda-forge` + +To install Python Fire from source, first clone the repository and then run: +`python setup.py install` ## Basic Usage @@ -20,7 +30,27 @@ You can call `Fire` on any Python object:
functions, classes, modules, objects, dictionaries, lists, tuples, etc. They all work! -Here's a simple example. +Here's an example of calling Fire on a function. + +```python +import fire + +def hello(name="World"): + return "Hello %s!" % name + +if __name__ == '__main__': + fire.Fire(hello) +``` + +Then, from the command line, you can run: + +```bash +python hello.py # Hello World! +python hello.py --name=David # Hello David! +python hello.py --help # Shows usage information. +``` + +Here's an example of calling Fire on a class. ```python import fire @@ -43,13 +73,17 @@ python calculator.py double --number=15 # 30 ``` To learn how Fire behaves on functions, objects, dicts, lists, etc, and to learn -about Fire's other features, see the [Using a Fire CLI page](doc/using-cli.md). +about Fire's other features, see the [Using a Fire CLI page](docs/using-cli.md). +For additional examples, see [The Python Fire Guide](docs/guide.md). ## Why is it called Fire? When you call `Fire`, it fires off (executes) your command. +## Where can I learn more? + +Please see [The Python Fire Guide](docs/guide.md). ## Reference @@ -63,16 +97,21 @@ When you call `Fire`, it fires off (executes) your command. | Call | `fire.Fire()` | Turns the current module into a Fire CLI. | Call | `fire.Fire(component)` | Turns `component` into a Fire CLI. -| Using a CLI | Command | Notes -| :------------- | :------------------------- | :--------- -| [Help](doc/using-cli.md#help-flag) | `command -- --help` | -| [REPL](doc/using-cli.md#interactive-flag) | `command -- --interactive` | Enters interactive mode. -| [Separator](doc/using-cli.md#separator-flag) | `command -- --separator=X` | This sets the separator to `X`. The default separator is `-`. -| [Completion](doc/using-cli.md#completion-flag) | `command -- --completion` | Generate a completion script for the CLI. -| [Trace](doc/using-cli.md#trace-flag) | `command -- --trace` | Gets a Fire trace for the command. -| [Verbose](doc/using-cli.md#verbose-flag) | `command -- --verbose` | -_Note that flags are separated from the Fire command by an isolated `--` arg._ +| Using a CLI | Command | Notes +| :---------------------------------------------- | :-------------------------------------- | :---- +| [Help](docs/using-cli.md#help-flag) | `command --help` or `command -- --help` | +| [REPL](docs/using-cli.md#interactive-flag) | `command -- --interactive` | Enters interactive mode. +| [Separator](docs/using-cli.md#separator-flag) | `command -- --separator=X` | Sets the separator to `X`. The default separator is `-`. +| [Completion](docs/using-cli.md#completion-flag) | `command -- --completion [shell]` | Generates a completion script for the CLI. +| [Trace](docs/using-cli.md#trace-flag) | `command -- --trace` | Gets a Fire trace for the command. +| [Verbose](docs/using-cli.md#verbose-flag) | `command -- --verbose` | + +_Note that these flags are separated from the Fire command by an isolated `--`._ + +## License +Licensed under the +[Apache 2.0](https://github.com/google/python-fire/blob/master/LICENSE) License. ## Disclaimer diff --git a/docs/api.md b/docs/api.md new file mode 100644 index 00000000..aae92cd6 --- /dev/null +++ b/docs/api.md @@ -0,0 +1,46 @@ +## Python Fire Quick Reference + +| Setup | Command | Notes +| ------- | ------------------- | ---------- +| install | `pip install fire` | Installs fire from pypi + +| Creating a CLI | Command | Notes +| ---------------| ---------------------- | ---------- +| import | `import fire` | +| Call | `fire.Fire()` | Turns the current module into a Fire CLI. +| Call | `fire.Fire(component)` | Turns `component` into a Fire CLI. + +| Using a CLI | Command | Notes | +| ------------------------------------------ | ----------------- | -------------- | +| [Help](using-cli.md#help-flag) | `command --help` | Show the help screen. | +| [REPL](using-cli.md#interactive-flag) | `command -- --interactive` | Enters interactive mode. | +| [Separator](using-cli.md#separator-flag) | `command -- --separator=X` | This sets the separator to `X`. The default separator is `-`. | +| [Completion](using-cli.md#completion-flag) | `command -- --completion [shell]` | Generate a completion script for the CLI. | +| [Trace](using-cli.md#trace-flag) | `command -- --trace` | Gets a Fire trace for the command. | +| [Verbose](using-cli.md#verbose-flag) | `command -- --verbose` | | + +_Note that flags are separated from the Fire command by an isolated `--` arg. +Help is an exception; the isolated `--` is optional for getting help._ + +## Arguments for Calling fire.Fire() + +| Argument | Usage | Notes | +| --------- | ------------------------- | ------------------------------------ | +| component | `fire.Fire(component)` | If omitted, defaults to a dict of all locals and globals. | +| command | `fire.Fire(command='hello --name=5')` | Either a string or a list of arguments. If a string is provided, it is split to determine the arguments. If a list or tuple is provided, they are the arguments. If `command` is omitted, then `sys.argv[1:]` (the arguments from the command line) are used by default. | +| name | `fire.Fire(name='tool')` | The name of the CLI, ideally the name users will enter to run the CLI. This name will be used in the CLI's help screens. If the argument is omitted, it will be inferred automatically.| +| serialize | `fire.Fire(serialize=custom_serializer)` | If omitted, simple types are serialized via their builtin str method, and any objects that define a custom `__str__` method are serialized with that. If specified, all objects are serialized to text via the provided method. | + +## Using a Fire CLI without modifying any code + +You can use Python Fire on a module without modifying the code of the module. +The syntax for this is: + +`python -m fire ` + +or + +`python -m fire ` + +For example, `python -m fire calendar -h` will treat the built in `calendar` +module as a CLI and provide its help. diff --git a/doc/benefits.md b/docs/benefits.md similarity index 81% rename from doc/benefits.md rename to docs/benefits.md index 1ba95513..ac09f0be 100644 --- a/doc/benefits.md +++ b/docs/benefits.md @@ -1,13 +1,14 @@ # Benefits of Python Fire -## Python Fire is a simple way to create a CLI in Python. + +## Create CLIs in Python It's dead simple. Simply write the functionality you want exposed at the command line as a function / module / class, and then call Fire. With this addition of a single-line call to Fire, your CLI is ready to go. - -## Python Fire is a helpful tool for developing and debugging Python code. + +## Develop and debug Python code When you're writing a Python library, you probably want to try it out as you go. You could write a main method to check the functionality you're interested in, @@ -23,8 +24,8 @@ a main method. And if you use the `--interactive` flag to enter an IPython REPL then you don't need to load the imports or create your variables; they'll already be ready for use as soon as you start the REPL. - -## Python Fire helps with exploring existing code or turning other people's code into a CLI. + +## Explore existing code; turn other people's code into a CLI You can take an existing module, maybe even one that you don't have access to the source code for, and call `Fire` on it. This lets you easily see what @@ -40,8 +41,8 @@ The auto-generated help strings that Fire provides when you run a Fire CLI allow you to see all the functionality these modules provide in a concise manner. - -## Python Fire makes transitioning between Bash and Python easier. + +## Transition between Bash and Python Using Fire lets you call Python directly from Bash. So you can mix your Python functions with the unix tools you know and love, like `grep`, `xargs`, `wc`, @@ -51,8 +52,8 @@ Additionally since writing CLIs in Python requires only a single call to Fire, it is now easy to write even one-off scripts that would previously have been in Bash, in Python. - -## Python Fire makes using a Python REPL easier by setting up the REPL with the modules and variables you'll need already imported and created. + +## Explore code in a Python REPL When you use the `--interactive` flag to enter an IPython REPL, it starts with variables and modules already defined for you. You don't need to waste time diff --git a/docs/guide.md b/docs/guide.md new file mode 100644 index 00000000..444a76ff --- /dev/null +++ b/docs/guide.md @@ -0,0 +1,780 @@ +## The Python Fire Guide + +### Introduction + +Welcome to the Python Fire guide! Python Fire is a Python library that will turn +any Python component into a command line interface with just a single call to +`Fire`. + +Let's get started! + +### Installation + +To install Python Fire from pypi, run: + +`pip install fire` + +Alternatively, to install Python Fire from source, clone the source and run: + +`python setup.py install` + +### Hello World + +##### Version 1: `fire.Fire()` + +The easiest way to use Fire is to take any Python program, and then simply call +`fire.Fire()` at the end of the program. This will expose the full contents of +the program to the command line. + +```python +import fire + +def hello(name): + return f'Hello {name}!' + +if __name__ == '__main__': + fire.Fire() +``` + +Here's how we can run our program from the command line: + +```bash +$ python example.py hello World +Hello World! +``` + +##### Version 2: `fire.Fire()` + +Let's modify our program slightly to only expose the `hello` function to the +command line. + +```python +import fire + +def hello(name): + return f'Hello {name}!' + +if __name__ == '__main__': + fire.Fire(hello) +``` + +Here's how we can run this from the command line: + +```bash +$ python example.py World +Hello World! +``` + +Notice we no longer have to specify to run the `hello` function, because we +called `fire.Fire(hello)`. + +##### Version 3: Using a main + +We can alternatively write this program like this: + +```python +import fire + +def hello(name): + return f'Hello {name}!' + +def main(): + fire.Fire(hello) + +if __name__ == '__main__': + main() +``` + +Or if we're using +[entry points](https://setuptools.readthedocs.io/en/latest/pkg_resources.html#entry-points), +then simply this: + +```python +import fire + +def hello(name): + return f'Hello {name}!' + +def main(): + fire.Fire(hello) +``` + +##### Version 4: Fire Without Code Changes + +If you have a file `example.py` that doesn't even import fire: + +```python +def hello(name): + return f'Hello {name}!' +``` + +Then you can use it with Fire like this: + +```bash +$ python -m fire example hello --name=World +Hello World! +``` + +You can also specify the filepath of example.py rather than its module path, +like so: + +```bash +$ python -m fire example.py hello --name=World +Hello World! +``` + +### Exposing Multiple Commands + +In the previous example, we exposed a single function to the command line. Now +we'll look at ways of exposing multiple functions to the command line. + +##### Version 1: `fire.Fire()` + +The simplest way to expose multiple commands is to write multiple functions, and +then call Fire. + +```python +import fire + +def add(x, y): + return x + y + +def multiply(x, y): + return x * y + +if __name__ == '__main__': + fire.Fire() +``` + +We can use this like so: + +```bash +$ python example.py add 10 20 +30 +$ python example.py multiply 10 20 +200 +``` + +You'll notice that Fire correctly parsed `10` and `20` as numbers, rather than +as strings. Read more about [argument parsing here](#argument-parsing). + +##### Version 2: `fire.Fire()` + +In version 1 we exposed all the program's functionality to the command line. By +using a dict, we can selectively expose functions to the command line. + +```python +import fire + +def add(x, y): + return x + y + +def multiply(x, y): + return x * y + +if __name__ == '__main__': + fire.Fire({ + 'add': add, + 'multiply': multiply, + }) +``` + +We can use this in the same way as before: + +```bash +$ python example.py add 10 20 +30 +$ python example.py multiply 10 20 +200 +``` + +##### Version 3: `fire.Fire()` + +Fire also works on objects, as in this variant. This is a good way to expose +multiple commands. + +```python +import fire + +class Calculator(object): + + def add(self, x, y): + return x + y + + def multiply(self, x, y): + return x * y + +if __name__ == '__main__': + calculator = Calculator() + fire.Fire(calculator) +``` + +We can use this in the same way as before: + +```bash +$ python example.py add 10 20 +30 +$ python example.py multiply 10 20 +200 +``` + + +##### Version 4: `fire.Fire()` + +Fire also works on classes. This is another good way to expose multiple +commands. + +```python +import fire + +class Calculator(object): + + def add(self, x, y): + return x + y + + def multiply(self, x, y): + return x * y + +if __name__ == '__main__': + fire.Fire(Calculator) +``` + +We can use this in the same way as before: + +```bash +$ python example.py add 10 20 +30 +$ python example.py multiply 10 20 +200 +``` + +Why might you prefer a class over an object? One reason is that you can pass +arguments for constructing the class too, as in this broken calculator example. + +```python +import fire + +class BrokenCalculator(object): + + def __init__(self, offset=1): + self._offset = offset + + def add(self, x, y): + return x + y + self._offset + + def multiply(self, x, y): + return x * y + self._offset + +if __name__ == '__main__': + fire.Fire(BrokenCalculator) +``` + +When you use a broken calculator, you get wrong answers: + +```bash +$ python example.py add 10 20 +31 +$ python example.py multiply 10 20 +201 +``` + +But you can always fix it: + +```bash +$ python example.py add 10 20 --offset=0 +30 +$ python example.py multiply 10 20 --offset=0 +200 +``` + +Unlike calling ordinary functions, which can be done both with positional +arguments and named arguments (--flag syntax), arguments to \_\_init\_\_ +functions must be passed with the --flag syntax. See the section on +[calling functions](#calling-functions) for more. + +### Grouping Commands + +Here's an example of how you might make a command line interface with grouped +commands. + +```python +class IngestionStage(object): + + def run(self): + return 'Ingesting! Nom nom nom...' + +class DigestionStage(object): + + def run(self, volume=1): + return ' '.join(['Burp!'] * volume) + + def status(self): + return 'Satiated.' + +class Pipeline(object): + + def __init__(self): + self.ingestion = IngestionStage() + self.digestion = DigestionStage() + + def run(self): + ingestion_output = self.ingestion.run() + digestion_output = self.digestion.run() + return [ingestion_output, digestion_output] + +if __name__ == '__main__': + fire.Fire(Pipeline) +``` + +Here's how this looks at the command line: + +```bash +$ python example.py run +Ingesting! Nom nom nom... +Burp! +$ python example.py ingestion run +Ingesting! Nom nom nom... +$ python example.py digestion run +Burp! +$ python example.py digestion status +Satiated. +``` + +You can nest your commands in arbitrarily complex ways, if you're feeling grumpy +or adventurous. + + +### Accessing Properties + +In the examples we've looked at so far, our invocations of `python example.py` +have all run some function from the example program. In this example, we simply +access a property. + +```python +from airports import airports + +import fire + +class Airport(object): + + def __init__(self, code): + self.code = code + self.name = dict(airports).get(self.code) + self.city = self.name.split(',')[0] if self.name else None + +if __name__ == '__main__': + fire.Fire(Airport) +``` + +Now we can use this program to learn about airport codes! + +```bash +$ python example.py --code=JFK code +JFK +$ python example.py --code=SJC name +San Jose-Sunnyvale-Santa Clara, CA - Norman Y. Mineta San Jose International (SJC) +$ python example.py --code=ALB city +Albany-Schenectady-Troy +``` + +By the way, you can find this +[airports module here](https://github.com/trendct-data/airports.py). + +### Chaining Function Calls + +When you run a Fire CLI, you can take all the same actions on the _result_ of +the call to Fire that you can take on the original object passed in. + +For example, we can use our Airport CLI from the previous example like this: + +```bash +$ python example.py --code=ALB city upper +ALBANY-SCHENECTADY-TROY +``` + +This works since `upper` is a method on all strings. + +So, if you want to set up your functions to chain nicely, all you have to do is +have a class whose methods return self. Here's an example. + +```python +import fire + +class BinaryCanvas(object): + """A canvas with which to make binary art, one bit at a time.""" + + def __init__(self, size=10): + self.pixels = [[0] * size for _ in range(size)] + self._size = size + self._row = 0 # The row of the cursor. + self._col = 0 # The column of the cursor. + + def __str__(self): + return '\n'.join(' '.join(str(pixel) for pixel in row) for row in self.pixels) + + def show(self): + print(self) + return self + + def move(self, row, col): + self._row = row % self._size + self._col = col % self._size + return self + + def on(self): + return self.set(1) + + def off(self): + return self.set(0) + + def set(self, value): + self.pixels[self._row][self._col] = value + return self + +if __name__ == '__main__': + fire.Fire(BinaryCanvas) +``` + +Now we can draw stuff :). + +```bash +$ python example.py move 3 3 on move 3 6 on move 6 3 on move 6 6 on move 7 4 on move 7 5 on +0 0 0 0 0 0 0 0 0 0 +0 0 0 0 0 0 0 0 0 0 +0 0 0 0 0 0 0 0 0 0 +0 0 0 1 0 0 1 0 0 0 +0 0 0 0 0 0 0 0 0 0 +0 0 0 0 0 0 0 0 0 0 +0 0 0 1 0 0 1 0 0 0 +0 0 0 0 1 1 0 0 0 0 +0 0 0 0 0 0 0 0 0 0 +0 0 0 0 0 0 0 0 0 0 +``` + +It's supposed to be a smiley face. + +### Custom Serialization + +You'll notice in the BinaryCanvas example, the canvas with the smiley face was +printed to the screen. You can determine how a component will be serialized by +defining its `__str__` method. + +If a custom `__str__` method is present on the final component, the object is +serialized and printed. If there's no custom `__str__` method, then the help +screen for the object is shown instead. + +### Can we make an even simpler example than Hello World? + +Yes, this program is even simpler than our original Hello World example. + +```python +import fire +english = 'Hello World' +spanish = 'Hola Mundo' +fire.Fire() +``` + +You can use it like this: + +```bash +$ python example.py english +Hello World +$ python example.py spanish +Hola Mundo +``` + +### Calling Functions + +Arguments to a constructor are passed by name using flag syntax `--name=value`. + +For example, consider this simple class: + +```python +import fire + +class Building(object): + + def __init__(self, name, stories=1): + self.name = name + self.stories = stories + + def climb_stairs(self, stairs_per_story=10): + for story in range(self.stories): + for stair in range(1, stairs_per_story): + yield stair + yield 'Phew!' + yield 'Done!' + +if __name__ == '__main__': + fire.Fire(Building) +``` + +We can instantiate it as follows: `python example.py --name="Sherrerd Hall"` + +Arguments to other functions may be passed positionally or by name using flag +syntax. + +To instantiate a `Building` and then run the `climb_stairs` function, the +following commands are all valid: + +```bash +$ python example.py --name="Sherrerd Hall" --stories=3 climb_stairs 10 +$ python example.py --name="Sherrerd Hall" climb_stairs --stairs_per_story=10 +$ python example.py --name="Sherrerd Hall" climb_stairs --stairs-per-story 10 +$ python example.py climb-stairs --stairs-per-story 10 --name="Sherrerd Hall" +``` + +You'll notice that hyphens and underscores (`-` and `_`) are interchangeable in +member names and flag names. + +You'll also notice that the constructor's arguments can come after the +function's arguments or before the function. + +You'll also notice that the equal sign between the flag name and its value is +optional. + +##### Functions with `*varargs` and `**kwargs` + +Fire supports functions that take \*varargs or \*\*kwargs. Here's an example: + +```python +import fire + +def order_by_length(*items): + """Orders items by length, breaking ties alphabetically.""" + sorted_items = sorted(items, key=lambda item: (len(str(item)), str(item))) + return ' '.join(sorted_items) + +if __name__ == '__main__': + fire.Fire(order_by_length) +``` + +To use it, we run: + +```bash +$ python example.py dog cat elephant +cat dog elephant +``` + +You can use a separator to indicate that you're done providing arguments to a +function. All arguments after the separator will be used to process the result +of the function, rather than being passed to the function itself. The default +separator is the hyphen `-`. + +Here's an example where we use a separator. + +```bash +$ python example.py dog cat elephant - upper +CAT DOG ELEPHANT +``` + +Without the separator, upper would have been treated as another argument. + +```bash +$ python example.py dog cat elephant upper +cat dog upper elephant +``` + +You can change the separator with the `--separator` flag. Flags are always +separated from your Fire command by an isolated `--`. Here's an example where we +change the separator. + +```bash +$ python example.py dog cat elephant X upper -- --separator=X +CAT DOG ELEPHANT +``` + +Separators can be useful when a function accepts \*varargs, \*\*kwargs, or +default values that you don't want to specify. It is also important to remember +to change the separator if you want to pass `-` as an argument. + + +##### Async Functions + +Fire supports calling async functions too. Here's a simple example. + +```python +import asyncio + +async def count_to_ten(): + for i in range(1, 11): + await asyncio.sleep(1) + print(i) + +if __name__ == '__main__': + fire.Fire(count_to_ten) +``` + +Whenever fire encounters a coroutine function, it runs it, blocking until it completes. + + +### Argument Parsing + +The types of the arguments are determined by their values, rather than by the +function signature where they're used. You can pass any Python literal from the +command line: numbers, strings, tuples, lists, dictionaries, (sets are only +supported in some versions of Python). You can also nest the collections +arbitrarily as long as they only contain literals. + +To demonstrate this, we'll make a small example program that tells us the type +of any argument we give it: + +```python +import fire +fire.Fire(lambda obj: type(obj).__name__) +``` + +And we'll use it like so: + +```bash +$ python example.py 10 +int +$ python example.py 10.0 +float +$ python example.py hello +str +$ python example.py '(1,2)' +tuple +$ python example.py [1,2] +list +$ python example.py True +bool +$ python example.py {name:David} +dict +``` + +You'll notice in that last example that bare-words are automatically replaced +with strings. + +Be careful with your quotes! If you want to pass the string `"10"`, rather than +the int `10`, you'll need to either escape or quote your quotes. Otherwise Bash +will eat your quotes and pass an unquoted `10` to your Python program, where +Fire will interpret it as a number. + + +```bash +$ python example.py 10 +int +$ python example.py "10" +int +$ python example.py '"10"' +str +$ python example.py "'10'" +str +$ python example.py \"10\" +str +``` + +Be careful with your quotes! Remember that Bash processes your arguments first, +and then Fire parses the result of that. +If you wanted to pass the dict `{"name": "David Bieber"}` to your program, you +might try this: + +```bash +$ python example.py '{"name": "David Bieber"}' # Good! Do this. +dict +$ python example.py {"name":'"David Bieber"'} # Okay. +dict +$ python example.py {"name":"David Bieber"} # Wrong. This is parsed as a string. +str +$ python example.py {"name": "David Bieber"} # Wrong. This isn't even treated as a single argument. + +$ python example.py '{"name": "Justin Bieber"}' # Wrong. This is not the Bieber you're looking for. (The syntax is fine though :)) +dict +``` + +##### Boolean Arguments + +The tokens `True` and `False` are parsed as boolean values. + +You may also specify booleans via flag syntax `--name` and `--noname`, which set +`name` to `True` and `False` respectively. + +Continuing the previous example, we could run any of the following: + +```bash +$ python example.py --obj=True +bool +$ python example.py --obj=False +bool +$ python example.py --obj +bool +$ python example.py --noobj +bool +``` + +Be careful with boolean flags! If a token other than another flag immediately +follows a flag that's supposed to be a boolean, the flag will take on the value +of the token rather than the boolean value. You can resolve this: by putting a +separator after your last flag, by explicitly stating the value of the boolean +flag (as in `--obj=True`), or by making sure there's another flag after any +boolean flag argument. + + +### Using Fire Flags + +Fire CLIs all come with a number of flags. These flags should be separated from +the Fire command by an isolated `--`. If there is at least one isolated `--` +argument, then arguments after the final isolated `--` are treated as flags, +whereas all arguments before the final isolated `--` are considered part of the +Fire command. + +One useful flag is the `--interactive` flag. Use the `--interactive` flag on any +CLI to enter a Python REPL with all the modules and variables used in the +context where `Fire` was called already available to you for use. Other useful +variables, such as the result of the Fire command will also be available. Use +this feature like this: `python example.py -- --interactive`. + +You can add the help flag to any command to see help and usage information. Fire +incorporates your docstrings into the help and usage information that it +generates. Fire will try to provide help even if you omit the isolated `--` +separating the flags from the Fire command, but may not always be able to, since +`help` is a valid argument name. Use this feature like this: `python +example.py -- --help` or `python example.py --help` (or even `python example.py +-h`). + +The complete set of flags available is shown below, in the reference section. + + +### Reference + +| Setup | Command | Notes +| :------ | :------------------ | :--------- +| install | `pip install fire` | + +##### Creating a CLI + +| Creating a CLI | Command | Notes +| :--------------| :--------------------- | :--------- +| import | `import fire` | +| Call | `fire.Fire()` | Turns the current module into a Fire CLI. +| Call | `fire.Fire(component)` | Turns `component` into a Fire CLI. + +##### Flags + +| Using a CLI | Command | Notes +| :------------- | :------------------------- | :--------- +| [Help](using-cli.md#help-flag) | `command -- --help` | Show help and usage information for the command. +| [REPL](using-cli.md#interactive-flag) | `command -- --interactive` | Enter interactive mode. +| [Separator](using-cli.md#separator-flag) | `command -- --separator=X` | This sets the separator to `X`. The default separator is `-`. +| [Completion](using-cli.md#completion-flag) | `command -- --completion [shell]` | Generate a completion script for the CLI. +| [Trace](using-cli.md#trace-flag) | `command -- --trace` | Gets a Fire trace for the command. +| [Verbose](using-cli.md#verbose-flag) | `command -- --verbose` | Include private members in the output. + +_Note that flags are separated from the Fire command by an isolated `--` arg. +Help is an exception; the isolated `--` is optional for getting help._ + + +##### Arguments for Calling fire.Fire() + +| Argument | Usage | Notes | +| --------- | ------------------------- | ------------------------------------ | +| component | `fire.Fire(component)` | If omitted, defaults to a dict of all locals and globals. | +| command | `fire.Fire(command='hello --name=5')` | Either a string or a list of arguments. If a string is provided, it is split to determine the arguments. If a list or tuple is provided, they are the arguments. If `command` is omitted, then `sys.argv[1:]` (the arguments from the command line) are used by default. | +| name | `fire.Fire(name='tool')` | The name of the CLI, ideally the name users will enter to run the CLI. This name will be used in the CLI's help screens. If the argument is omitted, it will be inferred automatically.| +| serialize | `fire.Fire(serialize=custom_serializer)` | If omitted, simple types are serialized via their builtin str method, and any objects that define a custom `__str__` method are serialized with that. If specified, all objects are serialized to text via the provided method. | + + +### Disclaimer + +Python Fire is not an official Google product. diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 00000000..8dcc5db6 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,119 @@ +# Python Fire [![PyPI](https://img.shields.io/pypi/pyversions/fire.svg?style=plastic)](https://github.com/google/python-fire) + +_Python Fire is a library for automatically generating command line interfaces +(CLIs) from absolutely any Python object._ + +- Python Fire is a simple way to create a CLI in Python. + [[1]](benefits.md#simple-cli) +- Python Fire is a helpful tool for developing and debugging Python code. + [[2]](benefits.md#debugging) +- Python Fire helps with exploring existing code or turning other people's + code into a CLI. [[3]](benefits.md#exploring) +- Python Fire makes transitioning between Bash and Python easier. + [[4]](benefits.md#bash) +- Python Fire makes using a Python REPL easier by setting up the REPL with the + modules and variables you'll need already imported and created. + [[5]](benefits.md#repl) + +## Installation + +To install Python Fire with pip, run: `pip install fire` + +To install Python Fire with conda, run: `conda install fire -c conda-forge` + +To install Python Fire from source, first clone the repository and then run: +`python setup.py install` + +## Basic Usage + +You can call `Fire` on any Python object:
+functions, classes, modules, objects, dictionaries, lists, tuples, etc. +They all work! + +Here's an example of calling Fire on a function. + +```python +import fire + +def hello(name="World"): + return "Hello %s!" % name + +if __name__ == '__main__': + fire.Fire(hello) +``` + +Then, from the command line, you can run: + +```bash +python hello.py # Hello World! +python hello.py --name=David # Hello David! +python hello.py --help # Shows usage information. +``` + +Here's an example of calling Fire on a class. + +```python +import fire + +class Calculator(object): + """A simple calculator class.""" + + def double(self, number): + return 2 * number + +if __name__ == '__main__': + fire.Fire(Calculator) +``` + +Then, from the command line, you can run: + +```bash +python calculator.py double 10 # 20 +python calculator.py double --number=15 # 30 +``` + +To learn how Fire behaves on functions, objects, dicts, lists, etc, and to learn +about Fire's other features, see the [Using a Fire CLI page](using-cli.md). + +For additional examples, see [The Python Fire Guide](guide.md). + +## Why is it called Fire? + +When you call `Fire`, it fires off (executes) your command. + +## Where can I learn more? + +Please see [The Python Fire Guide](guide.md). + +## Reference + +| Setup | Command | Notes +| :------ | :------------------ | :--------- +| install | `pip install fire` | + +| Creating a CLI | Command | Notes +| :--------------| :--------------------- | :--------- +| import | `import fire` | +| Call | `fire.Fire()` | Turns the current module into a Fire CLI. +| Call | `fire.Fire(component)` | Turns `component` into a Fire CLI. + +| Using a CLI | Command | Notes +| :---------------------------------------------- | :-------------------------------------- | :---- +| [Help](using-cli.md#help-flag) | `command --help` or `command -- --help` | +| [REPL](using-cli.md#interactive-flag) | `command -- --interactive` | Enters interactive mode. +| [Separator](using-cli.md#separator-flag) | `command -- --separator=X` | Sets the separator to `X`. The default separator is `-`. +| [Completion](using-cli.md#completion-flag) | `command -- --completion [shell]` | Generates a completion script for the CLI. +| [Trace](using-cli.md#trace-flag) | `command -- --trace` | Gets a Fire trace for the command. +| [Verbose](using-cli.md#verbose-flag) | `command -- --verbose` | + +_Note that flags are separated from the Fire command by an isolated `--` arg. +Help is an exception; the isolated `--` is optional for getting help._ + +## License + +Licensed under the +[Apache 2.0](https://github.com/google/python-fire/blob/master/LICENSE) License. + +## Disclaimer + +This is not an official Google product. diff --git a/docs/installation.md b/docs/installation.md new file mode 100644 index 00000000..7e4cccb8 --- /dev/null +++ b/docs/installation.md @@ -0,0 +1,8 @@ +# Installation + +To install Python Fire with pip, run: `pip install fire` + +To install Python Fire with conda, run: `conda install fire -c conda-forge` + +To install Python Fire from source, first clone the repository and then run +`python setup.py install`. To install from source for development, instead run `python setup.py develop`. diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md new file mode 100644 index 00000000..3ef6b548 --- /dev/null +++ b/docs/troubleshooting.md @@ -0,0 +1,13 @@ +# Troubleshooting + +This page describes known issues that users of Python Fire have run into. If you +have an issue not resolved here, consider opening a +[GitHub Issue](https://github.com/google/python-fire/issues). + +### Issue [#19](https://github.com/google/python-fire/issues/19): Don't name your module "cmd" + +If you have a module name that conflicts with the name of a builtin module, then +when Fire goes to import the builtin module, it will import your module instead. +This will result in an error, possibly an `AttributeError`. Specifically, do not +name your module any of the following: +sys, linecache, cmd, bdb, repr, os, re, pprint, traceback diff --git a/doc/using-cli.md b/docs/using-cli.md similarity index 87% rename from doc/using-cli.md rename to docs/using-cli.md index ab4050bc..bdfcb7db 100644 --- a/doc/using-cli.md +++ b/docs/using-cli.md @@ -1,7 +1,5 @@ # Using a Fire CLI - - ## Basic usage Every Fire command corresponds to a Python component. @@ -11,10 +9,13 @@ arguments. This command corresponds to the Python component you called the `Fire` function on. If you did not supply an object in the call to `Fire`, then the context in which `Fire` was called will be used as the Python component. -You can append `-- --help` to any command to see what Python component it +You can append `--help` or `-h` to a command to see what Python component it corresponds to, as well as the various ways in which you can extend the command. -Flags are always separated from the Fire command by an isolated `--` in order -to distinguish between flags and named arguments. + +Flags to Fire should be separated from the Fire command by an isolated `--` in +order to distinguish between flags and named arguments. So, for example, to +enter interactive mode append `-- -i` or `-- --interactive` to any command. To +use Fire in verbose mode, append `-- --verbose`. Given a Fire command that corresponds to a Python object, you can extend that command to access a member of that object, call it with arguments if it is a @@ -56,7 +57,7 @@ If your command corresponds to a list or tuple, you can extend your command by adding the index of an element of the component to your command as an argument. For example, `widget function-that-returns-list 2` will correspond to item 2 of -the result of function_that_returns_list. +the result of `function_that_returns_list`. ### Calling a function @@ -89,10 +90,12 @@ See also the section on [Changing the Separator](#separator-flag). ### Instantiating a class If your command corresponds to a class, you can extend your command by adding -the arguments of the class's \_\_init\_\_ function. Arguments must be specified +the arguments of the class's `__init__` function. Arguments must be specified by name, using the flags syntax. See the section on [calling a function](#calling-a-function) for more details. +Similarly, when passing arguments to a callable object (an object with a custom +`__call__` function), those arguments must be passed using flags syntax. ## Using Flags with Fire CLIs @@ -102,8 +105,8 @@ after the final standalone `--` argument. (If there is no `--` argument, then no arguments are used for flags.) For example, to set the alsologtostderr flag, you could run the command: -`widget bang --noise=boom -- --alsologtostderr`. The --noise argument is -consumed by Fire, but the --alsologtostderr argument is treated as a normal +`widget bang --noise=boom -- --alsologtostderr`. The `--noise` argument is +consumed by Fire, but the `--alsologtostderr` argument is treated as a normal Flag. All CLIs built with Python Fire share some flags, as described in the next @@ -134,13 +137,19 @@ will put you in an IPython REPL, with the variable `widget` already defined. You can then explore the Python object that `widget` corresponds to interactively using Python. +Note: if you want fire to start the IPython REPL instead of the regular Python one, +the `ipython` package needs to be installed in your environment. + ### `--completion`: Generating a completion script Call `widget -- --completion` to generate a completion script for the Fire CLI `widget`. To save the completion script to your home directory, you could e.g. run `widget -- --completion > ~/.widget-completion`. You should then source this -file; to get permanent completion, source this file from your .bashrc file. +file; to get permanent completion, source this file from your `.bashrc` file. + +Call `widget -- --completion fish` to generate a completion script for the Fish +shell. Source this file from your fish.config. If the commands available in the Fire CLI change, you'll have to regenerate the completion script and source it again. @@ -168,7 +177,7 @@ corresponds to, as well as usage information for how to extend that command. ### `--trace`: Getting a Fire trace In order to understand what is happening when you call Python Fire, it can be -useful to request a trace. This is done via the --trace flag, e.g. +useful to request a trace. This is done via the `--trace` flag, e.g. `widget whack 5 -- --trace`. A trace provides step by step information about how the Fire command was diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/cipher/__init__.py b/examples/cipher/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/cipher/cipher.py b/examples/cipher/cipher.py index 26a07756..83610a5d 100644 --- a/examples/cipher/cipher.py +++ b/examples/cipher/cipher.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/cipher/cipher_test.py b/examples/cipher/cipher_test.py index 904eef14..d2fb5c5f 100644 --- a/examples/cipher/cipher_test.py +++ b/examples/cipher/cipher_test.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from fire.examples.cipher import cipher +"""Tests for the cipher module.""" -import unittest +from fire import testutils +from examples.cipher import cipher -class CipherTest(unittest.TestCase): + +class CipherTest(testutils.BaseTestCase): def testCipher(self): self.assertEqual(cipher.rot13('Hello world!'), 'Uryyb jbeyq!') @@ -29,4 +31,4 @@ def testCipher(self): if __name__ == '__main__': - unittest.main() + testutils.main() diff --git a/examples/diff/__init__.py b/examples/diff/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/diff/diff.py b/examples/diff/diff.py index be5b8dab..f99e525e 100644 --- a/examples/diff/diff.py +++ b/examples/diff/diff.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -74,8 +74,10 @@ def __init__(self, fromfile, tofile): self.fromdate = time.ctime(os.stat(fromfile).st_mtime) self.todate = time.ctime(os.stat(tofile).st_mtime) - self.fromlines = open(fromfile, 'U').readlines() - self.tolines = open(tofile, 'U').readlines() + with open(fromfile) as f: + self.fromlines = f.readlines() + with open(tofile) as f: + self.tolines = f.readlines() def unified_diff(self, lines=3): return difflib.unified_diff( diff --git a/examples/diff/diff_test.py b/examples/diff/diff_test.py index f6b980ee..81a513c3 100644 --- a/examples/diff/diff_test.py +++ b/examples/diff/diff_test.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,15 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Tests for the diff and difffull modules.""" + import tempfile -from fire.examples.diff import diff -from fire.examples.diff import difffull +from fire import testutils -import unittest +from examples.diff import diff +from examples.diff import difffull -class DiffTest(unittest.TestCase): +class DiffTest(testutils.BaseTestCase): """The purpose of these tests is to ensure the difflib wrappers works. It is not the goal of these tests to exhaustively test difflib functionality. @@ -30,8 +32,8 @@ def setUp(self): self.file1 = file1 = tempfile.NamedTemporaryFile() self.file2 = file2 = tempfile.NamedTemporaryFile() - file1.write('test\ntest1\n') - file2.write('test\ntest2\nextraline\n') + file1.write(b'test\ntest1\n') + file2.write(b'test\ntest2\nextraline\n') file1.flush() file2.flush() @@ -90,4 +92,4 @@ def testDiffFull(self): if __name__ == '__main__': - unittest.main() + testutils.main() diff --git a/examples/diff/difffull.py b/examples/diff/difffull.py index 38918dee..b765e3f2 100644 --- a/examples/diff/difffull.py +++ b/examples/diff/difffull.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/identity/__init__.py b/examples/identity/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/identity/identity.py b/examples/identity/identity.py index 1dec9032..45143883 100644 --- a/examples/identity/identity.py +++ b/examples/identity/identity.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/widget/__init__.py b/examples/widget/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/widget/collector.py b/examples/widget/collector.py index 8886bb9d..f37ecc7a 100644 --- a/examples/widget/collector.py +++ b/examples/widget/collector.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ import fire -from fire.examples.widget import widget +from examples.widget import widget class Collector(object): @@ -28,7 +28,7 @@ def __init__(self): def collect_widgets(self): """Returns all the widgets the Collector wants.""" - return [widget.Widget() for _ in xrange(self.desired_widget_count)] + return [widget.Widget() for _ in range(self.desired_widget_count)] def main(): diff --git a/examples/widget/collector_test.py b/examples/widget/collector_test.py index 1232eb00..274cf382 100644 --- a/examples/widget/collector_test.py +++ b/examples/widget/collector_test.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,13 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from fire.examples.widget import collector -from fire.examples.widget import widget +"""Tests for the collector module.""" -import unittest +from fire import testutils +from examples.widget import collector +from examples.widget import widget -class CollectorTest(unittest.TestCase): + +class CollectorTest(testutils.BaseTestCase): def testCollectorHasWidget(self): col = collector.Collector() @@ -26,12 +28,12 @@ def testCollectorHasWidget(self): def testCollectorWantsMoreWidgets(self): col = collector.Collector() - self.assertEquals(col.desired_widget_count, 10) + self.assertEqual(col.desired_widget_count, 10) def testCollectorGetsWantedWidgets(self): col = collector.Collector() - self.assertEquals(len(col.collect_widgets()), 10) + self.assertEqual(len(col.collect_widgets()), 10) if __name__ == '__main__': - unittest.main() + testutils.main() diff --git a/examples/widget/widget.py b/examples/widget/widget.py index 4fd28fc5..9092ad75 100644 --- a/examples/widget/widget.py +++ b/examples/widget/widget.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,11 +21,11 @@ class Widget(object): def whack(self, n=1): """Prints "whack!" n times.""" - return ' '.join('whack!' for _ in xrange(n)) + return ' '.join('whack!' for _ in range(n)) def bang(self, noise='bang'): """Makes a loud noise.""" - return '{noise} bang!'.format(noise=noise) + return f'{noise} bang!' def main(): diff --git a/examples/widget/widget_test.py b/examples/widget/widget_test.py index d92f3993..a5cd7188 100644 --- a/examples/widget/widget_test.py +++ b/examples/widget/widget_test.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from fire.examples.widget import widget +"""Tests for the widget module.""" -import unittest +from fire import testutils +from examples.widget import widget -class WidgetTest(unittest.TestCase): + +class WidgetTest(testutils.BaseTestCase): def testWidgetWhack(self): toy = widget.Widget() @@ -31,4 +33,4 @@ def testWidgetBang(self): if __name__ == '__main__': - unittest.main() + testutils.main() diff --git a/fire/__init__.py b/fire/__init__.py index 39b2f8bb..9ff696d3 100644 --- a/fire/__init__.py +++ b/fire/__init__.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Python Fire module for third_party.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +"""The Python Fire module.""" from fire.core import Fire __all__ = ['Fire'] +__version__ = '0.7.0' diff --git a/fire/__main__.py b/fire/__main__.py new file mode 100644 index 00000000..140b4a76 --- /dev/null +++ b/fire/__main__.py @@ -0,0 +1,126 @@ +# Copyright (C) 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=invalid-name +"""Enables use of Python Fire as a "main" function (i.e. "python -m fire"). + +This allows using Fire with third-party libraries without modifying their code. +""" + +import importlib +from importlib import util +import os +import sys + +import fire + +cli_string = """usage: python -m fire [module] [arg] ..." + +Python Fire is a library for creating CLIs from absolutely any Python +object or program. To run Python Fire from the command line on an +existing Python file, it can be invoked with "python -m fire [module]" +and passed a Python module using module notation: + +"python -m fire packageA.packageB.module" + +or with a file path: + +"python -m fire packageA/packageB/module.py" """ + + +def import_from_file_path(path): + """Performs a module import given the filename. + + Args: + path (str): the path to the file to be imported. + + Raises: + IOError: if the given file does not exist or importlib fails to load it. + + Returns: + Tuple[ModuleType, str]: returns the imported module and the module name, + usually extracted from the path itself. + """ + + if not os.path.exists(path): + raise OSError('Given file path does not exist.') + + module_name = os.path.basename(path) + + spec = util.spec_from_file_location(module_name, path) + + if spec is None: + raise OSError('Unable to load module from specified path.') + + module = util.module_from_spec(spec) # pylint: disable=no-member + spec.loader.exec_module(module) # pytype: disable=attribute-error + + return module, module_name + + +def import_from_module_name(module_name): + """Imports a module and returns it and its name.""" + module = importlib.import_module(module_name) + return module, module_name + + +def import_module(module_or_filename): + """Imports a given module or filename. + + If the module_or_filename exists in the file system and ends with .py, we + attempt to import it. If that import fails, try to import it as a module. + + Args: + module_or_filename (str): string name of path or module. + + Raises: + ValueError: if the given file is invalid. + IOError: if the file or module can not be found or imported. + + Returns: + Tuple[ModuleType, str]: returns the imported module and the module name, + usually extracted from the path itself. + """ + + if os.path.exists(module_or_filename): + # importlib.util.spec_from_file_location requires .py + if not module_or_filename.endswith('.py'): + try: # try as module instead + return import_from_module_name(module_or_filename) + except ImportError: + raise ValueError('Fire can only be called on .py files.') + + return import_from_file_path(module_or_filename) + + if os.path.sep in module_or_filename: # Use / to detect if it was a filename. + raise OSError('Fire was passed a filename which could not be found.') + + return import_from_module_name(module_or_filename) # Assume it's a module. + + +def main(args): + """Entrypoint for fire when invoked as a module with python -m fire.""" + + if len(args) < 2: + print(cli_string) + sys.exit(1) + + module_or_filename = args[1] + module, module_name = import_module(module_or_filename) + + fire.Fire(module, name=module_name, command=args[2:]) + + +if __name__ == '__main__': + main(sys.argv) diff --git a/fire/completion.py b/fire/completion.py index 07815e08..1597d464 100644 --- a/fire/completion.py +++ b/fire/completion.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,19 +18,20 @@ from __future__ import division from __future__ import print_function -from collections import defaultdict -from copy import copy +import collections +import copy import inspect from fire import inspectutils -import six -def Script(name, component, default_options=None): - return _Script(name, _Commands(component), default_options) +def Script(name, component, default_options=None, shell='bash'): + if shell == 'fish': + return _FishScript(name, _Commands(component), default_options) + return _BashScript(name, _Commands(component), default_options) -def _Script(name, commands, default_options=None): +def _BashScript(name, commands, default_options=None): """Returns a Bash script registering a completion function for the commands. Args: @@ -45,12 +46,9 @@ def _Script(name, commands, default_options=None): completion in Bash. """ default_options = default_options or set() - options_map = defaultdict(lambda: copy(default_options)) - for command in commands: - start = (name + ' ' + ' '.join(command[:-1])).strip() - completion = _FormatForCommand(command[-1]) - options_map[start].add(completion) - options_map[start.replace('_', '-')].add(completion) + global_options, options_map, subcommands_map = _GetMaps( + name, commands, default_options + ) bash_completion_template = """# bash completion support for {name} # DO NOT EDIT. @@ -58,54 +56,289 @@ def _Script(name, commands, default_options=None): _complete-{identifier}() {{ - local start cur opts + local cur prev opts lastcommand COMPREPLY=() - start="${{COMP_WORDS[@]:0:COMP_CWORD}}" + prev="${{COMP_WORDS[COMP_CWORD-1]}}" cur="${{COMP_WORDS[COMP_CWORD]}}" + lastcommand=$(get_lastcommand) opts="{default_options}" + GLOBAL_OPTIONS="{global_options}" -{start_checks} +{checks} COMPREPLY=( $(compgen -W "${{opts}}" -- ${{cur}}) ) return 0 }} +get_lastcommand() +{{ + local lastcommand i + + lastcommand= + for ((i=0; i < ${{#COMP_WORDS[@]}}; ++i)); do + if [[ ${{COMP_WORDS[i]}} != -* ]] && [[ -n ${{COMP_WORDS[i]}} ]] && [[ + ${{COMP_WORDS[i]}} != $cur ]]; then + lastcommand=${{COMP_WORDS[i]}} + fi + done + + echo $lastcommand +}} + +filter_options() +{{ + local opts + opts="" + for opt in "$@" + do + if ! option_already_entered $opt; then + opts="$opts $opt" + fi + done + + echo $opts +}} + +option_already_entered() +{{ + local opt + for opt in ${{COMP_WORDS[@]:0:$COMP_CWORD}} + do + if [ $1 == $opt ]; then + return 0 + fi + done + return 1 +}} + +is_prev_global() +{{ + local opt + for opt in $GLOBAL_OPTIONS + do + if [ $opt == $prev ]; then + return 0 + fi + done + return 1 +}} + complete -F _complete-{identifier} {command} """ - start_check_template = """ - if [[ "$start" == "{start}" ]] ; then - opts="{completions}" - fi""" - - start_checks = '\n'.join( - start_check_template.format( - start=start, - completions=' '.join(sorted(options_map[start])) - ) - for start in options_map + + check_wrapper = """ + case "${{lastcommand}}" in + {lastcommand_checks} + esac""" + + lastcommand_check_template = """ + {command}) + {opts_assignment} + opts=$(filter_options $opts) + ;;""" + + opts_assignment_subcommand_template = """ + if is_prev_global; then + opts="${{GLOBAL_OPTIONS}}" + else + opts="{options} ${{GLOBAL_OPTIONS}}" + fi""" + + opts_assignment_main_command_template = """ + opts="{options} ${{GLOBAL_OPTIONS}}" """ + + def _GetOptsAssignmentTemplate(command): + if command == name: + return opts_assignment_main_command_template + else: + return opts_assignment_subcommand_template + + lines = [] + commands_set = set() + commands_set.add(name) + commands_set = commands_set.union(set(subcommands_map.keys())) + commands_set = commands_set.union(set(options_map.keys())) + for command in commands_set: + opts_assignment = _GetOptsAssignmentTemplate(command).format( + options=' '.join( + sorted(options_map[command].union(subcommands_map[command])) + ), + ) + lines.append( + lastcommand_check_template.format( + command=command, + opts_assignment=opts_assignment) + ) + lastcommand_checks = '\n'.join(lines) + + checks = check_wrapper.format( + lastcommand_checks=lastcommand_checks, ) return ( bash_completion_template.format( name=name, command=name, - start_checks=start_checks, + checks=checks, default_options=' '.join(default_options), - identifier=name.replace('/', '').replace('.', '').replace(',', '') + identifier=name.replace('/', '').replace('.', '').replace(',', ''), + global_options=' '.join(global_options), + ) + ) + + +def _FishScript(name, commands, default_options=None): + """Returns a Fish script registering a completion function for the commands. + + Args: + name: The first token in the commands, also the name of the command. + commands: A list of all possible commands that tab completion can complete + to. Each command is a list or tuple of the string tokens that make up + that command. + default_options: A dict of options that can be used with any command. Use + this if there are flags that can always be appended to a command. + Returns: + A string which is the Fish script. Source the fish script to enable tab + completion in Fish. + """ + default_options = default_options or set() + global_options, options_map, subcommands_map = _GetMaps( + name, commands, default_options + ) + + fish_source = """function __fish_using_command + set cmd (commandline -opc) + for i in (seq (count $cmd) 1) + switch $cmd[$i] + case "-*" + case "*" + if [ $cmd[$i] = $argv[1] ] + return 0 + else + return 1 + end + end + end + return 1 +end + +function __option_entered_check + set cmd (commandline -opc) + for i in (seq (count $cmd)) + switch $cmd[$i] + case "-*" + if [ $cmd[$i] = $argv[1] ] + return 1 + end + end + end + return 0 +end + +function __is_prev_global + set cmd (commandline -opc) + set global_options {global_options} + set prev (count $cmd) + + for opt in $global_options + if [ "--$opt" = $cmd[$prev] ] + echo $prev + return 0 + end + end + return 1 +end + +""" + + subcommand_template = ("complete -c {name} -n '__fish_using_command " + "{command}' -f -a {subcommand}\n") + flag_template = ("complete -c {name} -n " + "'__fish_using_command {command};{prev_global_check} and " + "__option_entered_check --{option}' -l {option}\n") + + prev_global_check = ' and __is_prev_global;' + for command in set(subcommands_map.keys()).union(set(options_map.keys())): + for subcommand in subcommands_map[command]: + fish_source += subcommand_template.format( + name=name, + command=command, + subcommand=subcommand, ) + + for option in options_map[command].union(global_options): + check_needed = command != name + fish_source += flag_template.format( + name=name, + command=command, + prev_global_check=prev_global_check if check_needed else '', + option=option.lstrip('--'), + ) + + return fish_source.format( + global_options=' '.join(f'"{option}"' for option in global_options) ) -def _IncludeMember(name, verbose): +def MemberVisible(component, name, member, class_attrs=None, verbose=False): + """Returns whether a member should be included in auto-completion or help. + + Determines whether a member of an object with the specified name should be + included in auto-completion or help text(both usage and detailed help). + + If the member name starts with '__', it will always be excluded. If it + starts with only one '_', it will be included for all non-string types. If + verbose is True, the members, including the private members, are included. + + When not in verbose mode, some modules and functions are excluded as well. + + Args: + component: The component containing the member. + name: The name of the member. + member: The member itself. + class_attrs: (optional) If component is a class, provide this as: + inspectutils.GetClassAttrsDict(component). If not provided, it will be + computed. + verbose: Whether to include private members. + Returns + A boolean value indicating whether the member should be included. + """ + if isinstance(name, str) and name.startswith('__'): + return False if verbose: return True - if isinstance(name, six.string_types): - return name and name[0] != '_' + if (member is absolute_import + or member is division + or member is print_function): + return False + if isinstance(member, type(absolute_import)): + return False + # TODO(dbieber): Determine more generally which modules to hide. + modules_to_hide = [] + if inspect.ismodule(member) and member in modules_to_hide: + return False + if inspect.isclass(component): + # If class_attrs has not been provided, compute it. + if class_attrs is None: + class_attrs = inspectutils.GetClassAttrsDict(component) or {} + class_attr = class_attrs.get(name) + if class_attr: + # Methods and properties should only be accessible on instantiated + # objects, not on uninstantiated classes. + if class_attr.kind in ('method', 'property'): + return False + # Backward compatibility notes: Before Python 3.8, namedtuple attributes + # were properties. In Python 3.8, they have type tuplegetter. + tuplegetter = getattr(collections, '_tuplegetter', type(None)) + if isinstance(class_attr.object, tuplegetter): + return False + if isinstance(name, str): + return not name.startswith('_') return True # Default to including the member -def _Members(component, verbose=False): +def VisibleMembers(component, class_attrs=None, verbose=False): """Returns a list of the members of the given component. If verbose is True, then members starting with _ (normally ignored) are @@ -113,6 +346,12 @@ def _Members(component, verbose=False): Args: component: The component whose members to list. + class_attrs: (optional) If component is a class, you may provide this as: + inspectutils.GetClassAttrsDict(component). If not provided, it will be + computed. If provided, this determines how class members will be treated + for visibility. In particular, methods are generally hidden for + non-instantiated classes, but if you wish them to be shown (e.g. for + completion scripts) then pass in a different class_attr for them. verbose: Whether to include private members. Returns: A list of tuples (member_name, member) of all members of the component. @@ -122,10 +361,13 @@ def _Members(component, verbose=False): else: members = inspect.getmembers(component) + # If class_attrs has not been provided, compute it. + if class_attrs is None: + class_attrs = inspectutils.GetClassAttrsDict(component) return [ - (member_name, member) - for member_name, member in members - if _IncludeMember(member_name, verbose) + (member_name, member) for member_name, member in members + if MemberVisible(component, member_name, member, class_attrs=class_attrs, + verbose=verbose) ] @@ -140,7 +382,7 @@ def _CompletionsFromArgs(fn_args): completions = [] for arg in fn_args: arg = arg.replace('_', '-') - completions.append('--{arg}'.format(arg=arg)) + completions.append(f'--{arg}') return completions @@ -157,21 +399,20 @@ def Completions(component, verbose=False): A list of completions for a command that would so far return the component. """ if inspect.isroutine(component) or inspect.isclass(component): - fn_args = inspectutils.GetArgSpec(component).args - return _CompletionsFromArgs(fn_args) + spec = inspectutils.GetFullArgSpec(component) + return _CompletionsFromArgs(spec.args + spec.kwonlyargs) - elif isinstance(component, (tuple, list)): + if isinstance(component, (tuple, list)): return [str(index) for index in range(len(component))] - elif inspect.isgenerator(component): - # TODO: There are currently no commands available for generators. + if inspect.isgenerator(component): + # TODO(dbieber): There are currently no commands available for generators. return [] - else: - return [ - _FormatForCommand(member_name) - for member_name, unused_member in _Members(component, verbose) - ] + return [ + _FormatForCommand(member_name) + for member_name, _ in VisibleMembers(component, verbose=verbose) + ] def _FormatForCommand(token): @@ -187,13 +428,13 @@ def _FormatForCommand(token): Returns: The transformed token. """ - if not isinstance(token, six.string_types): + if not isinstance(token, str): token = str(token) if token.startswith('_'): return token - else: - return token.replace('_', '-') + + return token.replace('_', '-') def _Commands(component, depth=3): @@ -210,20 +451,68 @@ def _Commands(component, depth=3): Tuples, each tuple representing one possible command for this CLI. Only traverses the member DAG up to a depth of depth. """ + if inspect.isroutine(component) or inspect.isclass(component): + for completion in Completions(component, verbose=False): + yield (completion,) + if inspect.isroutine(component): + return # Don't descend into routines. + if depth < 1: return - for member_name, member in _Members(component): - # TODO: Also skip components we've already seen. + # By setting class_attrs={} we don't hide methods in completion. + for member_name, member in VisibleMembers(component, class_attrs={}, + verbose=False): + # TODO(dbieber): Also skip components we've already seen. member_name = _FormatForCommand(member_name) yield (member_name,) - if inspect.isroutine(member) or inspect.isclass(member): - for completion in Completions(member): - yield (member_name, completion) - continue # Don't descend into routines. - for command in _Commands(member, depth - 1): yield (member_name,) + command + +def _IsOption(arg): + return arg.startswith('-') + + +def _GetMaps(name, commands, default_options): + """Returns sets of subcommands and options for each command. + + Args: + name: The first token in the commands, also the name of the command. + commands: A list of all possible commands that tab completion can complete + to. Each command is a list or tuple of the string tokens that make up + that command. + default_options: A dict of options that can be used with any command. Use + this if there are flags that can always be appended to a command. + Returns: + global_options: A set of all options of the first token of the command. + subcommands_map: A dict storing set of subcommands for each + command/subcommand. + options_map: A dict storing set of options for each subcommand. + """ + global_options = copy.copy(default_options) + options_map = collections.defaultdict(lambda: copy.copy(default_options)) + subcommands_map = collections.defaultdict(set) + + for command in commands: + if len(command) == 1: + if _IsOption(command[0]): + global_options.add(command[0]) + else: + subcommands_map[name].add(command[0]) + + elif command: + subcommand = command[-2] + arg = _FormatForCommand(command[-1]) + + if _IsOption(arg): + args_map = options_map + else: + args_map = subcommands_map + + args_map[subcommand].add(arg) + args_map[subcommand.replace('_', '-')].add(arg) + + return global_options, options_map, subcommands_map diff --git a/fire/completion_test.py b/fire/completion_test.py index 129b1470..c0d5d24f 100644 --- a/fire/completion_test.py +++ b/fire/completion_test.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,32 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - - +"""Tests for the completion module.""" from fire import completion from fire import test_components as tc +from fire import testutils -import unittest +class TabCompletionTest(testutils.BaseTestCase): -class TabCompletionTest(unittest.TestCase): + def testCompletionBashScript(self): + # A sanity check test to make sure the bash completion script satisfies + # some basic assumptions. + commands = [ + ['run'], + ['halt'], + ['halt', '--now'], + ] + script = completion._BashScript(name='command', commands=commands) # pylint: disable=protected-access + self.assertIn('command', script) + self.assertIn('halt', script) - def testCompletionScript(self): - # A sanity check test to make sure the completion script satisfies some - # basic assumptions. + for last_command in ['command', 'halt']: + self.assertIn(f'{last_command})', script) + + def testCompletionFishScript(self): + # A sanity check test to make sure the fish completion script satisfies + # some basic assumptions. commands = [ ['run'], ['halt'], ['halt', '--now'], ] - script = completion._Script(name='command', commands=commands) + script = completion._FishScript(name='command', commands=commands) # pylint: disable=protected-access self.assertIn('command', script) self.assertIn('halt', script) - self.assertIn('"$start" == "command"', script) + self.assertIn('-l now', script) def testFnCompletions(self): def example(one, two, three): @@ -99,6 +109,44 @@ def testDeepDictScript(self): self.assertIn('level3', script) self.assertNotIn('level4', script) # The default depth is 3. + def testFnScript(self): + script = completion.Script('identity', tc.identity) + self.assertIn('--arg1', script) + self.assertIn('--arg2', script) + self.assertIn('--arg3', script) + self.assertIn('--arg4', script) + + def testClassScript(self): + script = completion.Script('', tc.MixedDefaults) + self.assertIn('ten', script) + self.assertIn('sum', script) + self.assertIn('identity', script) + self.assertIn('--alpha', script) + self.assertIn('--beta', script) + + def testDeepDictFishScript(self): + deepdict = {'level1': {'level2': {'level3': {'level4': {}}}}} + script = completion.Script('deepdict', deepdict, shell='fish') + self.assertIn('level1', script) + self.assertIn('level2', script) + self.assertIn('level3', script) + self.assertNotIn('level4', script) # The default depth is 3. + + def testFnFishScript(self): + script = completion.Script('identity', tc.identity, shell='fish') + self.assertIn('arg1', script) + self.assertIn('arg2', script) + self.assertIn('arg3', script) + self.assertIn('arg4', script) + + def testClassFishScript(self): + script = completion.Script('', tc.MixedDefaults, shell='fish') + self.assertIn('ten', script) + self.assertIn('sum', script) + self.assertIn('identity', script) + self.assertIn('alpha', script) + self.assertIn('beta', script) + def testNonStringDictCompletions(self): completions = completion.Completions({ 10: 'green', @@ -137,4 +185,4 @@ def testMethodCompletions(self): if __name__ == '__main__': - unittest.main() + testutils.main() diff --git a/fire/console/README.md b/fire/console/README.md new file mode 100644 index 00000000..b23401f7 --- /dev/null +++ b/fire/console/README.md @@ -0,0 +1,3 @@ +This is the console package from googlecloudsdk, as used by Python Fire. +Python Fire does not accept pull requests modifying the console package; rather, +changes to console should go through the upstream project googlecloudsdk. diff --git a/fire/console/__init__.py b/fire/console/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fire/console/console_attr.py b/fire/console/console_attr.py new file mode 100644 index 00000000..c0a3d784 --- /dev/null +++ b/fire/console/console_attr.py @@ -0,0 +1,811 @@ +# -*- coding: utf-8 -*- # + +# Copyright 2015 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""A module for console attributes, special characters and functions. + +The target architectures {linux, macos, windows} support inline encoding for +all attributes except color. Windows requires win32 calls to manipulate the +console color state. + +Usage: + + # Get the console attribute state. + out = log.out + con = console_attr.GetConsoleAttr(out=out) + + # Get the ISO 8879:1986//ENTITIES Box and Line Drawing characters. + box = con.GetBoxLineCharacters() + # Print an X inside a box. + out.write(box.dr) + out.write(box.h) + out.write(box.dl) + out.write('\n') + out.write(box.v) + out.write('X') + out.write(box.v) + out.write('\n') + out.write(box.ur) + out.write(box.h) + out.write(box.ul) + out.write('\n') + + # Print the bullet characters. + for c in con.GetBullets(): + out.write(c) + out.write('\n') + + # Print FAIL in red. + out.write('Epic ') + con.Colorize('FAIL', 'red') + out.write(', my first.') + + # Print italic and bold text. + bold = con.GetFontCode(bold=True) + italic = con.GetFontCode(italic=True) + normal = con.GetFontCode() + out.write('This is {bold}bold{normal}, this is {italic}italic{normal},' + ' and this is normal.\n'.format(bold=bold, italic=italic, + normal=normal)) + + # Read one character from stdin with echo disabled. + c = con.GetRawKey() + if c is None: + print 'EOF\n' + + # Return the display width of a string that may contain FontCode() chars. + display_width = con.DisplayWidth(string) + + # Reset the memoized state. + con = console_attr.ResetConsoleAttr() + + # Print the console width and height in characters. + width, height = con.GetTermSize() + print 'width={width}, height={height}'.format(width=width, height=height) + + # Colorize table data cells. + fail = console_attr.Colorizer('FAIL', 'red') + pass = console_attr.Colorizer('PASS', 'green') + cells = ['label', fail, 'more text', pass, 'end'] + for cell in cells; + if isinstance(cell, console_attr.Colorizer): + cell.Render() + else: + out.write(cell) +""" + + +from __future__ import absolute_import +from __future__ import division +from __future__ import unicode_literals + +import os +import sys +import unicodedata + +# from fire.console import properties +from fire.console import console_attr_os +from fire.console import encoding as encoding_util +from fire.console import text + + +# TODO: Unify this logic with console.style.mappings +class BoxLineCharacters(object): + """Box/line drawing characters. + + The element names are from ISO 8879:1986//ENTITIES Box and Line Drawing//EN: + http://www.w3.org/2003/entities/iso8879doc/isobox.html + """ + + +class BoxLineCharactersUnicode(BoxLineCharacters): + """unicode Box/line drawing characters (cp437 compatible unicode).""" + dl = '┐' + dr = '┌' + h = '─' + hd = '┬' + hu = '┴' + ul = '┘' + ur = '└' + v = '│' + vh = '┼' + vl = '┤' + vr = '├' + d_dl = '╗' + d_dr = '╔' + d_h = '═' + d_hd = '╦' + d_hu = '╩' + d_ul = '╝' + d_ur = '╚' + d_v = '║' + d_vh = '╬' + d_vl = '╣' + d_vr = '╠' + + +class BoxLineCharactersAscii(BoxLineCharacters): + """ASCII Box/line drawing characters.""" + dl = '+' + dr = '+' + h = '-' + hd = '+' + hu = '+' + ul = '+' + ur = '+' + v = '|' + vh = '+' + vl = '+' + vr = '+' + d_dl = '#' + d_dr = '#' + d_h = '=' + d_hd = '#' + d_hu = '#' + d_ul = '#' + d_ur = '#' + d_v = '#' + d_vh = '#' + d_vl = '#' + d_vr = '#' + + +class BoxLineCharactersScreenReader(BoxLineCharactersAscii): + dl = ' ' + dr = ' ' + hd = ' ' + hu = ' ' + ul = ' ' + ur = ' ' + vh = ' ' + vl = ' ' + vr = ' ' + + +class ProgressTrackerSymbols(object): + """Characters used by progress trackers.""" + + +class ProgressTrackerSymbolsUnicode(ProgressTrackerSymbols): + """Characters used by progress trackers.""" + + @property + def spin_marks(self): + return ['⠏', '⠛', '⠹', '⠼', '⠶', '⠧'] + + success = text.TypedText(['✓'], text_type=text.TextTypes.PT_SUCCESS) + failed = text.TypedText(['X'], text_type=text.TextTypes.PT_FAILURE) + interrupted = '-' + not_started = '.' + prefix_length = 2 + + +class ProgressTrackerSymbolsAscii(ProgressTrackerSymbols): + """Characters used by progress trackers.""" + + @property + def spin_marks(self): + return ['|', '/', '-', '\\',] + + success = 'OK' + failed = 'X' + interrupted = '-' + not_started = '.' + prefix_length = 3 + + +class ConsoleAttr(object): + """Console attribute and special drawing characters and functions accessor. + + Use GetConsoleAttr() to get a global ConsoleAttr object shared by all callers. + Use ConsoleAttr() for abstracting multiple consoles. + + If _out is not associated with a console, or if the console properties cannot + be determined, the default behavior is ASCII art with no attributes. + + Attributes: + _ANSI_COLOR: The ANSI color control sequence dict. + _ANSI_COLOR_RESET: The ANSI color reset control sequence string. + _csi: The ANSI Control Sequence indicator string, '' if not supported. + _encoding: The character encoding. + ascii: ASCII art. This is the default. + utf8: UTF-8 unicode. + win: Windows code page 437. + _font_bold: The ANSI bold font embellishment code string. + _font_italic: The ANSI italic font embellishment code string. + _get_raw_key: A function that reads one keypress from stdin with no echo. + _out: The console output file stream. + _term: TERM environment variable value. + _term_size: The terminal (x, y) dimensions in characters. + """ + + _CONSOLE_ATTR_STATE = None + + _ANSI_COLOR = { + 'red': '31;1m', + 'yellow': '33;1m', + 'green': '32m', + 'blue': '34;1m' + } + _ANSI_COLOR_RESET = '39;0m' + + _BULLETS_UNICODE = ('▪', '◆', '▸', '▫', '◇', '▹') + _BULLETS_WINDOWS = ('■', '≡', '∞', 'Φ', '·') # cp437 compatible unicode + _BULLETS_ASCII = ('o', '*', '+', '-') + + def __init__(self, encoding=None, suppress_output=False): + """Constructor. + + Args: + encoding: Encoding override. + ascii -- ASCII art. This is the default. + utf8 -- UTF-8 unicode. + win -- Windows code page 437. + suppress_output: True to create a ConsoleAttr that doesn't want to output + anything. + """ + # Normalize the encoding name. + if not encoding: + encoding = self._GetConsoleEncoding() + elif encoding == 'win': + encoding = 'cp437' + self._encoding = encoding or 'ascii' + self._term = '' if suppress_output else os.getenv('TERM', '').lower() + + # ANSI "standard" attributes. + if self.SupportsAnsi(): + # Select Graphic Rendition parameters from + # http://en.wikipedia.org/wiki/ANSI_escape_code#graphics + # Italic '3' would be nice here but its not widely supported. + self._csi = '\x1b[' + self._font_bold = '1' + self._font_italic = '4' + else: + self._csi = None + self._font_bold = '' + self._font_italic = '' + + # Encoded character attributes. + is_screen_reader = False + if self._encoding == 'utf8' and not is_screen_reader: + self._box_line_characters = BoxLineCharactersUnicode() + self._bullets = self._BULLETS_UNICODE + self._progress_tracker_symbols = ProgressTrackerSymbolsUnicode() + elif self._encoding == 'cp437' and not is_screen_reader: + self._box_line_characters = BoxLineCharactersUnicode() + self._bullets = self._BULLETS_WINDOWS + # Windows does not support the unicode characters used for the spinner. + self._progress_tracker_symbols = ProgressTrackerSymbolsAscii() + else: + self._box_line_characters = BoxLineCharactersAscii() + if is_screen_reader: + self._box_line_characters = BoxLineCharactersScreenReader() + self._bullets = self._BULLETS_ASCII + self._progress_tracker_symbols = ProgressTrackerSymbolsAscii() + + # OS specific attributes. + self._get_raw_key = [console_attr_os.GetRawKeyFunction()] + self._term_size = ( + (0, 0) if suppress_output else console_attr_os.GetTermSize()) + + self._display_width_cache = {} + + def _GetConsoleEncoding(self): + """Gets the encoding as declared by the stdout stream. + + Returns: + str, The encoding name or None if it could not be determined. + """ + console_encoding = getattr(sys.stdout, 'encoding', None) + if not console_encoding: + return None + console_encoding = console_encoding.lower() + if 'utf-8' in console_encoding: + return 'utf8' + elif 'cp437' in console_encoding: + return 'cp437' + return None + + def Colorize(self, string, color, justify=None): + """Generates a colorized string, optionally justified. + + Args: + string: The string to write. + color: The color name -- must be in _ANSI_COLOR. + justify: The justification function, no justification if None. For + example, justify=lambda s: s.center(10) + + Returns: + str, The colorized string that can be printed to the console. + """ + if justify: + string = justify(string) + if self._csi and color in self._ANSI_COLOR: + return '{csi}{color_code}{string}{csi}{reset_code}'.format( + csi=self._csi, + color_code=self._ANSI_COLOR[color], + reset_code=self._ANSI_COLOR_RESET, + string=string) + # TODO: Add elif self._encoding == 'cp437': code here. + return string + + def ConvertOutputToUnicode(self, buf): + """Converts a console output string buf to unicode. + + Mainly used for testing. Allows test comparisons in unicode while ensuring + that unicode => encoding => unicode works. + + Args: + buf: The console output string to convert. + + Returns: + The console output string buf converted to unicode. + """ + if isinstance(buf, str): + buf = buf.encode(self._encoding) + return str(buf, self._encoding, 'replace') + + def GetBoxLineCharacters(self): + """Returns the box/line drawing characters object. + + The element names are from ISO 8879:1986//ENTITIES Box and Line Drawing//EN: + http://www.w3.org/2003/entities/iso8879doc/isobox.html + + Returns: + A BoxLineCharacters object for the console output device. + """ + return self._box_line_characters + + def GetBullets(self): + """Returns the bullet characters list. + + Use the list elements in order for best appearance in nested bullet lists, + wrapping back to the first element for deep nesting. The list size depends + on the console implementation. + + Returns: + A tuple of bullet characters. + """ + return self._bullets + + def GetProgressTrackerSymbols(self): + """Returns the progress tracker characters object. + + Returns: + A ProgressTrackerSymbols object for the console output device. + """ + return self._progress_tracker_symbols + + def GetControlSequenceIndicator(self): + """Returns the control sequence indicator string. + + Returns: + The control sequence indicator string or None if control sequences are not + supported. + """ + return self._csi + + def GetControlSequenceLen(self, buf): + """Returns the control sequence length at the beginning of buf. + + Used in display width computations. Control sequences have display width 0. + + Args: + buf: The string to check for a control sequence. + + Returns: + The control sequence length at the beginning of buf or 0 if buf does not + start with a control sequence. + """ + if not self._csi or not buf.startswith(self._csi): + return 0 + n = 0 + for c in buf: + n += 1 + if c.isalpha(): + break + return n + + def GetEncoding(self): + """Returns the current encoding.""" + return self._encoding + + def GetFontCode(self, bold=False, italic=False): + """Returns a font code string for 0 or more embellishments. + + GetFontCode() with no args returns the default font code string. + + Args: + bold: True for bold embellishment. + italic: True for italic embellishment. + + Returns: + The font code string for the requested embellishments. Write this string + to the console output to control the font settings. + """ + if not self._csi: + return '' + codes = [] + if bold: + codes.append(self._font_bold) + if italic: + codes.append(self._font_italic) + return '{csi}{codes}m'.format(csi=self._csi, codes=';'.join(codes)) + + def GetRawKey(self): + """Reads one key press from stdin with no echo. + + Returns: + The key name, None for EOF, for function keys, otherwise a + character. + """ + return self._get_raw_key[0]() + + def GetTermIdentifier(self): + """Returns the TERM environment variable for the console. + + Returns: + str: A str that describes the console's text capabilities + """ + return self._term + + def GetTermSize(self): + """Returns the terminal (x, y) dimensions in characters. + + Returns: + (x, y): A tuple of the terminal x and y dimensions. + """ + return self._term_size + + def DisplayWidth(self, buf): + """Returns the display width of buf, handling unicode and ANSI controls. + + Args: + buf: The string to count from. + + Returns: + The display width of buf, handling unicode and ANSI controls. + """ + if not isinstance(buf, str): + # Handle non-string objects like Colorizer(). + return len(buf) + + cached = self._display_width_cache.get(buf, None) + if cached is not None: + return cached + + width = 0 + max_width = 0 + i = 0 + while i < len(buf): + if self._csi and buf[i:].startswith(self._csi): + i += self.GetControlSequenceLen(buf[i:]) + elif buf[i] == '\n': + # A newline incidates the start of a new line. + # Newline characters have 0 width. + max_width = max(width, max_width) + width = 0 + i += 1 + else: + width += GetCharacterDisplayWidth(buf[i]) + i += 1 + max_width = max(width, max_width) + + self._display_width_cache[buf] = max_width + return max_width + + def SplitIntoNormalAndControl(self, buf): + """Returns a list of (normal_string, control_sequence) tuples from buf. + + Args: + buf: The input string containing one or more control sequences + interspersed with normal strings. + + Returns: + A list of (normal_string, control_sequence) tuples. + """ + if not self._csi or not buf: + return [(buf, '')] + seq = [] + i = 0 + while i < len(buf): + c = buf.find(self._csi, i) + if c < 0: + seq.append((buf[i:], '')) + break + normal = buf[i:c] + i = c + self.GetControlSequenceLen(buf[c:]) + seq.append((normal, buf[c:i])) + return seq + + def SplitLine(self, line, width): + """Splits line into width length chunks. + + Args: + line: The line to split. + width: The width of each chunk except the last which could be smaller than + width. + + Returns: + A list of chunks, all but the last with display width == width. + """ + lines = [] + chunk = '' + w = 0 + keep = False + for normal, control in self.SplitIntoNormalAndControl(line): + keep = True + while True: + n = width - w + w += len(normal) + if w <= width: + break + lines.append(chunk + normal[:n]) + chunk = '' + keep = False + w = 0 + normal = normal[n:] + chunk += normal + control + if chunk or keep: + lines.append(chunk) + return lines + + def SupportsAnsi(self): + return (self._encoding != 'ascii' and + ('screen' in self._term or 'xterm' in self._term)) + + +class Colorizer(object): + """Resource string colorizer. + + Attributes: + _con: ConsoleAttr object. + _color: Color name. + _string: The string to colorize. + _justify: The justification function, no justification if None. For example, + justify=lambda s: s.center(10) + """ + + def __init__(self, string, color, justify=None): + """Constructor. + + Args: + string: The string to colorize. + color: Color name used to index ConsoleAttr._ANSI_COLOR. + justify: The justification function, no justification if None. For + example, justify=lambda s: s.center(10) + """ + self._con = GetConsoleAttr() + self._color = color + self._string = string + self._justify = justify + + def __eq__(self, other): + return self._string == str(other) + + def __ne__(self, other): + return not self == other + + def __gt__(self, other): + return self._string > str(other) + + def __lt__(self, other): + return self._string < str(other) + + def __ge__(self, other): + return not self < other + + def __le__(self, other): + return not self > other + + def __len__(self): + return self._con.DisplayWidth(self._string) + + def __str__(self): + return self._string + + def Render(self, stream, justify=None): + """Renders the string as self._color on the console. + + Args: + stream: The stream to render the string to. The stream given here *must* + have the same encoding as sys.stdout for this to work properly. + justify: The justification function, self._justify if None. + """ + stream.write( + self._con.Colorize(self._string, self._color, justify or self._justify)) + + +def GetConsoleAttr(encoding=None, reset=False): + """Gets the console attribute state. + + If this is the first call or reset is True or encoding is not None and does + not match the current encoding or out is not None and does not match the + current out then the state is (re)initialized. Otherwise the current state + is returned. + + This call associates the out file stream with the console. All console related + output should go to the same stream. + + Args: + encoding: Encoding override. + ascii -- ASCII. This is the default. + utf8 -- UTF-8 unicode. + win -- Windows code page 437. + reset: Force re-initialization if True. + + Returns: + The global ConsoleAttr state object. + """ + attr = ConsoleAttr._CONSOLE_ATTR_STATE # pylint: disable=protected-access + if not reset: + if not attr: + reset = True + elif encoding and encoding != attr.GetEncoding(): + reset = True + if reset: + attr = ConsoleAttr(encoding=encoding) + ConsoleAttr._CONSOLE_ATTR_STATE = attr # pylint: disable=protected-access + return attr + + +def ResetConsoleAttr(encoding=None): + """Resets the console attribute state to the console default. + + Args: + encoding: Reset to this encoding instead of the default. + ascii -- ASCII. This is the default. + utf8 -- UTF-8 unicode. + win -- Windows code page 437. + + Returns: + The global ConsoleAttr state object. + """ + return GetConsoleAttr(encoding=encoding, reset=True) + + +def GetCharacterDisplayWidth(char): + """Returns the monospaced terminal display width of char. + + Assumptions: + - monospaced display + - ambiguous or unknown chars default to width 1 + - ASCII control char width is 1 => don't use this for control chars + + Args: + char: The character to determine the display width of. + + Returns: + The monospaced terminal display width of char: either 0, 1, or 2. + """ + if not isinstance(char, str): + # Non-unicode chars have width 1. Don't use this function on control chars. + return 1 + + # Normalize to avoid special cases. + char = unicodedata.normalize('NFC', char) + + if unicodedata.combining(char) != 0: + # Modifies the previous character and does not move the cursor. + return 0 + elif unicodedata.category(char) == 'Cf': + # Unprintable formatting char. + return 0 + elif unicodedata.east_asian_width(char) in 'FW': + # Fullwidth or Wide chars take 2 character positions. + return 2 + else: + # Don't use this function on control chars. + return 1 + + +def SafeText(data, encoding=None, escape=True): + br"""Converts the data to a text string compatible with the given encoding. + + This works the same way as Decode() below except it guarantees that any + characters in the resulting text string can be re-encoded using the given + encoding (or GetConsoleAttr().GetEncoding() if None is given). This means + that the string will be safe to print to sys.stdout (for example) without + getting codec exceptions if the user's terminal doesn't support the encoding + used by the source of the text. + + Args: + data: Any bytes, string, or object that has str() or unicode() methods. + encoding: The encoding name to ensure compatibility with. Defaults to + GetConsoleAttr().GetEncoding(). + escape: Replace unencodable characters with a \uXXXX or \xXX equivalent if + True. Otherwise replace unencodable characters with an appropriate unknown + character, '?' for ASCII, and the unicode unknown replacement character + \uFFFE for unicode. + + Returns: + A text string representation of the data, but modified to remove any + characters that would result in an encoding exception with the target + encoding. In the worst case, with escape=False, it will contain only ? + characters. + """ + if data is None: + return 'None' + encoding = encoding or GetConsoleAttr().GetEncoding() + string = encoding_util.Decode(data, encoding=encoding) + + try: + # No change needed if the string encodes to the output encoding. + string.encode(encoding) + return string + except UnicodeError: + # The string does not encode to the output encoding. Encode it with error + # handling then convert it back into a text string (which will be + # guaranteed to only contain characters that can be encoded later. + return (string + .encode(encoding, 'backslashreplace' if escape else 'replace') + .decode(encoding)) + + +def EncodeToBytes(data): + r"""Encode data to bytes. + + The primary use case is for base64/mime style 7-bit ascii encoding where the + encoder input must be bytes. "safe" means that the conversion always returns + bytes and will not raise codec exceptions. + + If data is text then an 8-bit ascii encoding is attempted, then the console + encoding, and finally utf-8. + + Args: + data: Any bytes, string, or object that has str() or unicode() methods. + + Returns: + A bytes string representation of the data. + """ + if data is None: + return b'' + if isinstance(data, bytes): + # Already bytes - our work is done. + return data + + # Coerce to text that will be converted to bytes. + s = str(data) + + try: + # Assume the text can be directly converted to bytes (8-bit ascii). + return s.encode('iso-8859-1') + except UnicodeEncodeError: + pass + + try: + # Try the output encoding. + return s.encode(GetConsoleAttr().GetEncoding()) + except UnicodeEncodeError: + pass + + # Punt to utf-8. + return s.encode('utf-8') + + +def Decode(data, encoding=None): + """Converts the given string, bytes, or object to a text string. + + Args: + data: Any bytes, string, or object that has str() or unicode() methods. + encoding: A suggesting encoding used to decode. If this encoding doesn't + work, other defaults are tried. Defaults to + GetConsoleAttr().GetEncoding(). + + Returns: + A text string representation of the data. + """ + encoding = encoding or GetConsoleAttr().GetEncoding() + return encoding_util.Decode(data, encoding=encoding) diff --git a/fire/console/console_attr_os.py b/fire/console/console_attr_os.py new file mode 100644 index 00000000..869c5949 --- /dev/null +++ b/fire/console/console_attr_os.py @@ -0,0 +1,260 @@ +# -*- coding: utf-8 -*- # +# Copyright 2015 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OS specific console_attr helper functions.""" +# This file contains platform specific code which is not currently handled +# by pytype. +# pytype: skip-file + +from __future__ import absolute_import +from __future__ import division +from __future__ import unicode_literals + +import os +import sys + +from fire.console import encoding + + +def GetTermSize(): + """Gets the terminal x and y dimensions in characters. + + _GetTermSize*() helper functions taken from: + http://stackoverflow.com/questions/263890/ + + Returns: + (columns, lines): A tuple containing the terminal x and y dimensions. + """ + xy = None + # Believe the first helper that doesn't bail. + for get_terminal_size in (_GetTermSizePosix, + _GetTermSizeWindows, + _GetTermSizeEnvironment, + _GetTermSizeTput): + try: + xy = get_terminal_size() + if xy: + break + except: # pylint: disable=bare-except + pass + return xy or (80, 24) + + +def _GetTermSizePosix(): + """Returns the Posix terminal x and y dimensions.""" + # pylint: disable=g-import-not-at-top + import fcntl + # pylint: disable=g-import-not-at-top + import struct + # pylint: disable=g-import-not-at-top + import termios + + def _GetXY(fd): + """Returns the terminal (x,y) size for fd. + + Args: + fd: The terminal file descriptor. + + Returns: + The terminal (x,y) size for fd or None on error. + """ + try: + # This magic incantation converts a struct from ioctl(2) containing two + # binary shorts to a (rows, columns) int tuple. + rc = struct.unpack(b'hh', fcntl.ioctl(fd, termios.TIOCGWINSZ, 'junk')) + return (rc[1], rc[0]) if rc else None + except: # pylint: disable=bare-except + return None + + xy = _GetXY(0) or _GetXY(1) or _GetXY(2) + if not xy: + fd = None + try: + fd = os.open(os.ctermid(), os.O_RDONLY) + xy = _GetXY(fd) + except: # pylint: disable=bare-except + xy = None + finally: + if fd is not None: + os.close(fd) + return xy + + +def _GetTermSizeWindows(): + """Returns the Windows terminal x and y dimensions.""" + # pylint:disable=g-import-not-at-top + import struct + # pylint: disable=g-import-not-at-top + from ctypes import create_string_buffer + # pylint:disable=g-import-not-at-top + from ctypes import windll + + # stdin handle is -10 + # stdout handle is -11 + # stderr handle is -12 + + h = windll.kernel32.GetStdHandle(-12) + csbi = create_string_buffer(22) + if not windll.kernel32.GetConsoleScreenBufferInfo(h, csbi): + return None + (unused_bufx, unused_bufy, unused_curx, unused_cury, unused_wattr, + left, top, right, bottom, + unused_maxx, unused_maxy) = struct.unpack(b'hhhhHhhhhhh', csbi.raw) + x = right - left + 1 + y = bottom - top + 1 + return (x, y) + + +def _GetTermSizeEnvironment(): + """Returns the terminal x and y dimensions from the environment.""" + return (int(os.environ['COLUMNS']), int(os.environ['LINES'])) + + +def _GetTermSizeTput(): + """Returns the terminal x and y dimensions from tput(1).""" + import subprocess # pylint: disable=g-import-not-at-top + output = encoding.Decode(subprocess.check_output(['tput', 'cols'], + stderr=subprocess.STDOUT)) + cols = int(output) + output = encoding.Decode(subprocess.check_output(['tput', 'lines'], + stderr=subprocess.STDOUT)) + rows = int(output) + return (cols, rows) + + +_ANSI_CSI = '\x1b' # ANSI control sequence indicator (ESC) +_CONTROL_D = '\x04' # unix EOF (^D) +_CONTROL_Z = '\x1a' # Windows EOF (^Z) +_WINDOWS_CSI_1 = '\x00' # Windows control sequence indicator #1 +_WINDOWS_CSI_2 = '\xe0' # Windows control sequence indicator #2 + + +def GetRawKeyFunction(): + """Returns a function that reads one keypress from stdin with no echo. + + Returns: + A function that reads one keypress from stdin with no echo or a function + that always returns None if stdin does not support it. + """ + # Believe the first helper that doesn't bail. + for get_raw_key_function in (_GetRawKeyFunctionPosix, + _GetRawKeyFunctionWindows): + try: + return get_raw_key_function() + except: # pylint: disable=bare-except + pass + return lambda: None + + +def _GetRawKeyFunctionPosix(): + """_GetRawKeyFunction helper using Posix APIs.""" + # pylint: disable=g-import-not-at-top + import tty + # pylint: disable=g-import-not-at-top + import termios + + def _GetRawKeyPosix(): + """Reads and returns one keypress from stdin, no echo, using Posix APIs. + + Returns: + The key name, None for EOF, <*> for function keys, otherwise a + character. + """ + ansi_to_key = { + 'A': '', + 'B': '', + 'D': '', + 'C': '', + '5': '', + '6': '', + 'H': '', + 'F': '', + 'M': '', + 'S': '', + 'T': '', + } + + # Flush pending output. sys.stdin.read() would do this, but it's explicitly + # bypassed in _GetKeyChar(). + sys.stdout.flush() + + fd = sys.stdin.fileno() + + def _GetKeyChar(): + return encoding.Decode(os.read(fd, 1)) + + old_settings = termios.tcgetattr(fd) + try: + tty.setraw(fd) + c = _GetKeyChar() + if c == _ANSI_CSI: + c = _GetKeyChar() + while True: + if c == _ANSI_CSI: + return c + if c.isalpha(): + break + prev_c = c + c = _GetKeyChar() + if c == '~': + c = prev_c + break + return ansi_to_key.get(c, '') + except: # pylint:disable=bare-except + c = None + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + return None if c in (_CONTROL_D, _CONTROL_Z) else c + + return _GetRawKeyPosix + + +def _GetRawKeyFunctionWindows(): + """_GetRawKeyFunction helper using Windows APIs.""" + # pylint: disable=g-import-not-at-top + import msvcrt + + def _GetRawKeyWindows(): + """Reads and returns one keypress from stdin, no echo, using Windows APIs. + + Returns: + The key name, None for EOF, <*> for function keys, otherwise a + character. + """ + windows_to_key = { + 'H': '', + 'P': '', + 'K': '', + 'M': '', + 'I': '', + 'Q': '', + 'G': '', + 'O': '', + } + + # Flush pending output. sys.stdin.read() would do this it's explicitly + # bypassed in _GetKeyChar(). + sys.stdout.flush() + + def _GetKeyChar(): + return encoding.Decode(msvcrt.getch()) + + c = _GetKeyChar() + # Special function key is a two character sequence; return the second char. + if c in (_WINDOWS_CSI_1, _WINDOWS_CSI_2): + return windows_to_key.get(_GetKeyChar(), '') + return None if c in (_CONTROL_D, _CONTROL_Z) else c + + return _GetRawKeyWindows diff --git a/fire/console/console_io.py b/fire/console/console_io.py new file mode 100644 index 00000000..ec0858d9 --- /dev/null +++ b/fire/console/console_io.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- # +# Copyright 2013 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""General console printing utilities used by the Cloud SDK.""" + +import os +import signal +import subprocess +import sys + +from fire.console import console_attr +from fire.console import console_pager +from fire.console import encoding +from fire.console import files + + +def IsInteractive(output=False, error=False, heuristic=False): + """Determines if the current terminal session is interactive. + + sys.stdin must be a terminal input stream. + + Args: + output: If True then sys.stdout must also be a terminal output stream. + error: If True then sys.stderr must also be a terminal output stream. + heuristic: If True then we also do some additional heuristics to check if + we are in an interactive context. Checking home path for example. + + Returns: + True if the current terminal session is interactive. + """ + if not sys.stdin.isatty(): + return False + if output and not sys.stdout.isatty(): + return False + if error and not sys.stderr.isatty(): + return False + + if heuristic: + # Check the home path. Most startup scripts for example are executed by + # users that don't have a home path set. Home is OS dependent though, so + # check everything. + # *NIX OS usually sets the HOME env variable. It is usually '/home/user', + # but can also be '/root'. If it's just '/' we are most likely in an init + # script. + # Windows usually sets HOMEDRIVE and HOMEPATH. If they don't exist we are + # probably being run from a task scheduler context. HOMEPATH can be '\' + # when a user has a network mapped home directory. + # Cygwin has it all! Both Windows and Linux. Checking both is perfect. + home = os.getenv('HOME') + homepath = os.getenv('HOMEPATH') + if not homepath and (not home or home == '/'): + return False + return True + + +def More(contents, out, prompt=None, check_pager=True): + """Run a user specified pager or fall back to the internal pager. + + Args: + contents: The entire contents of the text lines to page. + out: The output stream. + prompt: The page break prompt. + check_pager: Checks the PAGER env var and uses it if True. + """ + if not IsInteractive(output=True): + out.write(contents) + return + if check_pager: + pager = encoding.GetEncodedValue(os.environ, 'PAGER', None) + if pager == '-': + # Use the fallback Pager. + pager = None + elif not pager: + # Search for a pager that handles ANSI escapes. + for command in ('less', 'pager'): + if files.FindExecutableOnPath(command): + pager = command + break + if pager: + # If the pager is less(1) then instruct it to display raw ANSI escape + # sequences to enable colors and font embellishments. + less_orig = encoding.GetEncodedValue(os.environ, 'LESS', None) + less = '-R' + (less_orig or '') + encoding.SetEncodedValue(os.environ, 'LESS', less) + # Ignore SIGINT while the pager is running. + # We don't want to terminate the parent while the child is still alive. + signal.signal(signal.SIGINT, signal.SIG_IGN) + p = subprocess.Popen(pager, stdin=subprocess.PIPE, shell=True) + enc = console_attr.GetConsoleAttr().GetEncoding() + p.communicate(input=contents.encode(enc)) + p.wait() + # Start using default signal handling for SIGINT again. + signal.signal(signal.SIGINT, signal.SIG_DFL) + if less_orig is None: + encoding.SetEncodedValue(os.environ, 'LESS', None) + return + # Fall back to the internal pager. + console_pager.Pager(contents, out, prompt).Run() diff --git a/fire/console/console_pager.py b/fire/console/console_pager.py new file mode 100644 index 00000000..565c7e1e --- /dev/null +++ b/fire/console/console_pager.py @@ -0,0 +1,299 @@ +# -*- coding: utf-8 -*- # +# Copyright 2015 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simple console pager.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import unicode_literals + +import re +import sys + +from fire.console import console_attr + + +class Pager(object): + """A simple console text pager. + + This pager requires the entire contents to be available. The contents are + written one page of lines at a time. The prompt is written after each page of + lines. A one character response is expected. See HELP_TEXT below for more + info. + + The contents are written as is. For example, ANSI control codes will be in + effect. This is different from pagers like more(1) which is ANSI control code + agnostic and miscalculates line lengths, and less(1) which displays control + character names by default. + + Attributes: + _attr: The current ConsoleAttr handle. + _clear: A string that clears the prompt when written to _out. + _contents: The entire contents of the text lines to page. + _height: The terminal height in characters. + _out: The output stream, log.out (effectively) if None. + _prompt: The page break prompt. + _search_direction: The search direction command, n:forward, N:reverse. + _search_pattern: The current forward/reverse search compiled RE. + _width: The termonal width in characters. + """ + + HELP_TEXT = """ + Simple pager commands: + + b, ^B, , + Back one page. + f, ^F, , , + Forward one page. Does not quit if there are no more lines. + g, + Back to the first page. + g + Go to lines from the top. + G, + Forward to the last page. + G + Go to lines from the bottom. + h + Print pager command help. + j, +, + Forward one line. + k, -, + Back one line. + /pattern + Forward search for pattern. + ?pattern + Backward search for pattern. + n + Repeat current search. + N + Repeat current search in the opposite direction. + q, Q, ^C, ^D, ^Z + Quit return to the caller. + any other character + Prompt again. + + Hit any key to continue:""" + + PREV_POS_NXT_REPRINT = -1, -1 + + def __init__(self, contents, out=None, prompt=None): + """Constructor. + + Args: + contents: The entire contents of the text lines to page. + out: The output stream, log.out (effectively) if None. + prompt: The page break prompt, a default prompt is used if None.. + """ + self._contents = contents + self._out = out or sys.stdout + self._search_pattern = None + self._search_direction = None + + # prev_pos, prev_next values to force reprint + self.prev_pos, self.prev_nxt = self.PREV_POS_NXT_REPRINT + # Initialize the console attributes. + self._attr = console_attr.GetConsoleAttr() + self._width, self._height = self._attr.GetTermSize() + + # Initialize the prompt and the prompt clear string. + if not prompt: + prompt = '{bold}--({{percent}}%)--{normal}'.format( + bold=self._attr.GetFontCode(bold=True), + normal=self._attr.GetFontCode()) + self._clear = '\r{0}\r'.format(' ' * (self._attr.DisplayWidth(prompt) - 6)) + self._prompt = prompt + + # Initialize a list of lines with long lines split into separate display + # lines. + self._lines = [] + for line in contents.splitlines(): + self._lines += self._attr.SplitLine(line, self._width) + + def _Write(self, s): + """Mockable helper that writes s to self._out.""" + self._out.write(s) + + def _GetSearchCommand(self, c): + """Consumes a search command and returns the equivalent pager command. + + The search pattern is an RE that is pre-compiled and cached for subsequent + /, ?, n, or N commands. + + Args: + c: The search command char. + + Returns: + The pager command char. + """ + self._Write(c) + buf = '' + while True: + p = self._attr.GetRawKey() + if p in (None, '\n', '\r') or len(p) != 1: + break + self._Write(p) + buf += p + self._Write('\r' + ' ' * len(buf) + '\r') + if buf: + try: + self._search_pattern = re.compile(buf) + except re.error: + # Silently ignore pattern errors. + self._search_pattern = None + return '' + self._search_direction = 'n' if c == '/' else 'N' + return 'n' + + def _Help(self): + """Print command help and wait for any character to continue.""" + clear = self._height - (len(self.HELP_TEXT) - + len(self.HELP_TEXT.replace('\n', ''))) + if clear > 0: + self._Write('\n' * clear) + self._Write(self.HELP_TEXT) + self._attr.GetRawKey() + self._Write('\n') + + def Run(self): + """Run the pager.""" + # No paging if the contents are small enough. + if len(self._lines) <= self._height: + self._Write(self._contents) + return + + # We will not always reset previous values. + reset_prev_values = True + # Save room for the prompt at the bottom of the page. + self._height -= 1 + + # Loop over all the pages. + pos = 0 + while pos < len(self._lines): + # Write a page of lines. + nxt = pos + self._height + if nxt > len(self._lines): + nxt = len(self._lines) + pos = nxt - self._height + # Checks if the starting position is in between the current printed lines + # so we don't need to reprint all the lines. + if self.prev_pos < pos < self.prev_nxt: + # we start where the previous page ended. + self._Write('\n'.join(self._lines[self.prev_nxt:nxt]) + '\n') + elif pos != self.prev_pos and nxt != self.prev_nxt: + self._Write('\n'.join(self._lines[pos:nxt]) + '\n') + + # Handle the prompt response. + percent = self._prompt.format(percent=100 * nxt // len(self._lines)) + digits = '' + while True: + # We want to reset prev values if we just exited out of the while loop + if reset_prev_values: + self.prev_pos, self.prev_nxt = pos, nxt + reset_prev_values = False + self._Write(percent) + c = self._attr.GetRawKey() + self._Write(self._clear) + + # Parse the command. + if c in (None, # EOF. + 'q', # Quit. + 'Q', # Quit. + '\x03', # ^C (unix & windows terminal interrupt) + '\x1b', # ESC. + ): + # Quit. + return + elif c in ('/', '?'): + c = self._GetSearchCommand(c) + elif c.isdigit(): + # Collect digits for operation count. + digits += c + continue + + # Set the optional command count. + if digits: + count = int(digits) + digits = '' + else: + count = 0 + + # Finally commit to command c. + if c in ('', '', 'b', '\x02'): + # Previous page. + nxt = pos - self._height + if nxt < 0: + nxt = 0 + elif c in ('', '', 'f', '\x06', ' '): + # Next page. + if nxt >= len(self._lines): + continue + nxt = pos + self._height + if nxt >= len(self._lines): + nxt = pos + elif c in ('', 'g'): + # First page. + nxt = count - 1 + if nxt > len(self._lines) - self._height: + nxt = len(self._lines) - self._height + if nxt < 0: + nxt = 0 + elif c in ('', 'G'): + # Last page. + nxt = len(self._lines) - count + if nxt > len(self._lines) - self._height: + nxt = len(self._lines) - self._height + if nxt < 0: + nxt = 0 + elif c == 'h': + self._Help() + # Special case when we want to reprint the previous display. + self.prev_pos, self.prev_nxt = self.PREV_POS_NXT_REPRINT + nxt = pos + break + elif c in ('', 'j', '+', '\n', '\r'): + # Next line. + if nxt >= len(self._lines): + continue + nxt = pos + 1 + if nxt >= len(self._lines): + nxt = pos + elif c in ('', 'k', '-'): + # Previous line. + nxt = pos - 1 + if nxt < 0: + nxt = 0 + elif c in ('n', 'N'): + # Next pattern match search. + if not self._search_pattern: + continue + nxt = pos + i = pos + direction = 1 if c == self._search_direction else -1 + while True: + i += direction + if i < 0 or i >= len(self._lines): + break + if self._search_pattern.search(self._lines[i]): + nxt = i + break + else: + # Silently ignore everything else. + continue + if nxt != pos: + # We will exit the while loop because position changed so we can reset + # prev values. + reset_prev_values = True + break + pos = nxt diff --git a/fire/console/encoding.py b/fire/console/encoding.py new file mode 100644 index 00000000..3ce30cb5 --- /dev/null +++ b/fire/console/encoding.py @@ -0,0 +1,189 @@ +# -*- coding: utf-8 -*- # + +# Copyright 2015 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A module for dealing with unknown string and environment encodings.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import unicode_literals + +import sys + + +def Encode(string, encoding=None): + """Encode the text string to a byte string. + + Args: + string: str, The text string to encode. + encoding: The suggested encoding if known. + + Returns: + str, The binary string. + """ + del encoding # Unused. + return string + + +def Decode(data, encoding=None): + """Returns string with non-ascii characters decoded to UNICODE. + + UTF-8, the suggested encoding, and the usual suspects will be attempted in + order. + + Args: + data: A string or object that has str() and unicode() methods that may + contain an encoding incompatible with the standard output encoding. + encoding: The suggested encoding if known. + + Returns: + A text string representing the decoded byte string. + """ + if data is None: + return None + + # First we are going to get the data object to be a text string. + if isinstance(data, str) or isinstance(data, bytes): + string = data + else: + # Some non-string type of object. + string = str(data) + + if isinstance(string, str): + # Our work is done here. + return string + + try: + # Just return the string if its pure ASCII. + return string.decode('ascii') # pytype: disable=attribute-error + except UnicodeError: + # The string is not ASCII encoded. + pass + + # Try the suggested encoding if specified. + if encoding: + try: + return string.decode(encoding) # pytype: disable=attribute-error + except UnicodeError: + # Bad suggestion. + pass + + # Try UTF-8 because the other encodings could be extended ASCII. It would + # be exceptional if a valid extended ascii encoding with extended chars + # were also a valid UITF-8 encoding. + try: + return string.decode('utf8') # pytype: disable=attribute-error + except UnicodeError: + # Not a UTF-8 encoding. + pass + + # Try the filesystem encoding. + try: + return string.decode(sys.getfilesystemencoding()) # pytype: disable=attribute-error + except UnicodeError: + # string is not encoded for filesystem paths. + pass + + # Try the system default encoding. + try: + return string.decode(sys.getdefaultencoding()) # pytype: disable=attribute-error + except UnicodeError: + # string is not encoded using the default encoding. + pass + + # We don't know the string encoding. + # This works around a Python str.encode() "feature" that throws + # an ASCII *decode* exception on str strings that contain 8th bit set + # bytes. For example, this sequence throws an exception: + # string = '\xdc' # iso-8859-1 'Ü' + # string = string.encode('ascii', 'backslashreplace') + # even though 'backslashreplace' is documented to handle encoding + # errors. We work around the problem by first decoding the str string + # from an 8-bit encoding to unicode, selecting any 8-bit encoding that + # uses all 256 bytes (such as ISO-8559-1): + # string = string.decode('iso-8859-1') + # Using this produces a sequence that works: + # string = '\xdc' + # string = string.decode('iso-8859-1') + # string = string.encode('ascii', 'backslashreplace') + return string.decode('iso-8859-1') # pytype: disable=attribute-error + + +def GetEncodedValue(env, name, default=None): + """Returns the decoded value of the env var name. + + Args: + env: {str: str}, The env dict. + name: str, The env var name. + default: The value to return if name is not in env. + + Returns: + The decoded value of the env var name. + """ + name = Encode(name) + value = env.get(name) + if value is None: + return default + # In Python 3, the environment sets and gets accept and return text strings + # only, and it handles the encoding itself so this is not necessary. + return Decode(value) + + +def SetEncodedValue(env, name, value, encoding=None): + """Sets the value of name in env to an encoded value. + + Args: + env: {str: str}, The env dict. + name: str, The env var name. + value: str or unicode, The value for name. If None then name is removed from + env. + encoding: str, The encoding to use or None to try to infer it. + """ + # Python 2 *and* 3 unicode support falls apart at filesystem/argv/environment + # boundaries. The encoding used for filesystem paths and environment variable + # names/values is under user control on most systems. With one of those values + # in hand there is no way to tell exactly how the value was encoded. We get + # some reasonable hints from sys.getfilesystemencoding() or + # sys.getdefaultencoding() and use them to encode values that the receiving + # process will have a chance at decoding. Leaving the values as unicode + # strings will cause os module Unicode exceptions. What good is a language + # unicode model when the module support could care less? + name = Encode(name, encoding=encoding) + if value is None: + env.pop(name, None) + return + env[name] = Encode(value, encoding=encoding) + + +def EncodeEnv(env, encoding=None): + """Encodes all the key value pairs in env in preparation for subprocess. + + Args: + env: {str: str}, The environment you are going to pass to subprocess. + encoding: str, The encoding to use or None to use the default. + + Returns: + {bytes: bytes}, The environment to pass to subprocess. + """ + encoding = encoding or _GetEncoding() + return { + Encode(k, encoding=encoding): Encode(v, encoding=encoding) + for k, v in env.items() + } + + +def _GetEncoding(): + """Gets the default encoding to use.""" + return sys.getfilesystemencoding() or sys.getdefaultencoding() diff --git a/fire/console/files.py b/fire/console/files.py new file mode 100644 index 00000000..97222c3d --- /dev/null +++ b/fire/console/files.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- # +# Copyright 2013 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Some general file utilities used that can be used by the Cloud SDK.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import unicode_literals + +import os + +from fire.console import encoding as encoding_util +from fire.console import platforms + + +def _GetSystemPath(): + """Returns properly encoded system PATH variable string.""" + return encoding_util.GetEncodedValue(os.environ, 'PATH') + + +def _FindExecutableOnPath(executable, path, pathext): + """Internal function to a find an executable. + + Args: + executable: The name of the executable to find. + path: A list of directories to search separated by 'os.pathsep'. + pathext: An iterable of file name extensions to use. + + Returns: + str, the path to a file on `path` with name `executable` + `p` for + `p` in `pathext`. + + Raises: + ValueError: invalid input. + """ + + if isinstance(pathext, str): + raise ValueError('_FindExecutableOnPath(..., pathext=\'{0}\') failed ' + 'because pathext must be an iterable of strings, but got ' + 'a string.'.format(pathext)) + + # Prioritize preferred extension over earlier in path. + for ext in pathext: + for directory in path.split(os.pathsep): + # Windows can have paths quoted. + directory = directory.strip('"') + full = os.path.normpath(os.path.join(directory, executable) + ext) + # On Windows os.access(full, os.X_OK) is always True. + if os.path.isfile(full) and os.access(full, os.X_OK): + return full + return None + + +def _PlatformExecutableExtensions(platform): + if platform == platforms.OperatingSystem.WINDOWS: + return ('.exe', '.cmd', '.bat', '.com', '.ps1') + else: + return ('', '.sh') + + +def FindExecutableOnPath(executable, path=None, pathext=None, + allow_extensions=False): + """Searches for `executable` in the directories listed in `path` or $PATH. + + Executable must not contain a directory or an extension. + + Args: + executable: The name of the executable to find. + path: A list of directories to search separated by 'os.pathsep'. If None + then the system PATH is used. + pathext: An iterable of file name extensions to use. If None then + platform specific extensions are used. + allow_extensions: A boolean flag indicating whether extensions in the + executable are allowed. + + Returns: + The path of 'executable' (possibly with a platform-specific extension) if + found and executable, None if not found. + + Raises: + ValueError: if executable has a path or an extension, and extensions are + not allowed, or if there's an internal error. + """ + + if not allow_extensions and os.path.splitext(executable)[1]: + raise ValueError('FindExecutableOnPath({0},...) failed because first ' + 'argument must not have an extension.'.format(executable)) + + if os.path.dirname(executable): + raise ValueError('FindExecutableOnPath({0},...) failed because first ' + 'argument must not have a path.'.format(executable)) + + if path is None: + effective_path = _GetSystemPath() + else: + effective_path = path + effective_pathext = (pathext if pathext is not None + else _PlatformExecutableExtensions( + platforms.OperatingSystem.Current())) + + return _FindExecutableOnPath(executable, effective_path, + effective_pathext) diff --git a/fire/console/platforms.py b/fire/console/platforms.py new file mode 100644 index 00000000..13fd8204 --- /dev/null +++ b/fire/console/platforms.py @@ -0,0 +1,483 @@ +# -*- coding: utf-8 -*- # +# Copyright 2013 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for determining the current platform and architecture.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import unicode_literals + +import os +import platform +import subprocess +import sys + + +class Error(Exception): + """Base class for exceptions in the platforms module.""" + pass + + +class InvalidEnumValue(Error): # pylint: disable=g-bad-exception-name + """Exception for when a string could not be parsed to a valid enum value.""" + + def __init__(self, given, enum_type, options): + """Constructs a new exception. + + Args: + given: str, The given string that could not be parsed. + enum_type: str, The human readable name of the enum you were trying to + parse. + options: list(str), The valid values for this enum. + """ + super(InvalidEnumValue, self).__init__( + 'Could not parse [{0}] into a valid {1}. Valid values are [{2}]' + .format(given, enum_type, ', '.join(options))) + + +class OperatingSystem(object): + """An enum representing the operating system you are running on.""" + + class _OS(object): + """A single operating system.""" + + # pylint: disable=redefined-builtin + def __init__(self, id, name, file_name): + self.id = id + self.name = name + self.file_name = file_name + + def __str__(self): + return self.id + + def __eq__(self, other): + return (isinstance(other, type(self)) and + self.id == other.id and + self.name == other.name and + self.file_name == other.file_name) + + def __hash__(self): + return hash(self.id) + hash(self.name) + hash(self.file_name) + + def __ne__(self, other): + return not self == other + + @classmethod + def _CmpHelper(cls, x, y): + """Just a helper equivalent to the cmp() function in Python 2.""" + return (x > y) - (x < y) + + def __lt__(self, other): + return self._CmpHelper( + (self.id, self.name, self.file_name), + (other.id, other.name, other.file_name)) < 0 + + def __gt__(self, other): + return self._CmpHelper( + (self.id, self.name, self.file_name), + (other.id, other.name, other.file_name)) > 0 + + def __le__(self, other): + return not self.__gt__(other) + + def __ge__(self, other): + return not self.__lt__(other) + + WINDOWS = _OS('WINDOWS', 'Windows', 'windows') + MACOSX = _OS('MACOSX', 'Mac OS X', 'darwin') + LINUX = _OS('LINUX', 'Linux', 'linux') + CYGWIN = _OS('CYGWIN', 'Cygwin', 'cygwin') + MSYS = _OS('MSYS', 'Msys', 'msys') + _ALL = [WINDOWS, MACOSX, LINUX, CYGWIN, MSYS] + + @staticmethod + def AllValues(): + """Gets all possible enum values. + + Returns: + list, All the enum values. + """ + return list(OperatingSystem._ALL) + + @staticmethod + def FromId(os_id, error_on_unknown=True): + """Gets the enum corresponding to the given operating system id. + + Args: + os_id: str, The operating system id to parse + error_on_unknown: bool, True to raise an exception if the id is unknown, + False to just return None. + + Raises: + InvalidEnumValue: If the given value cannot be parsed. + + Returns: + OperatingSystemTuple, One of the OperatingSystem constants or None if the + input is None. + """ + if not os_id: + return None + for operating_system in OperatingSystem._ALL: + if operating_system.id == os_id: + return operating_system + if error_on_unknown: + raise InvalidEnumValue(os_id, 'Operating System', + [value.id for value in OperatingSystem._ALL]) + return None + + @staticmethod + def Current(): + """Determines the current operating system. + + Returns: + OperatingSystemTuple, One of the OperatingSystem constants or None if it + cannot be determined. + """ + if os.name == 'nt': + return OperatingSystem.WINDOWS + elif 'linux' in sys.platform: + return OperatingSystem.LINUX + elif 'darwin' in sys.platform: + return OperatingSystem.MACOSX + elif 'cygwin' in sys.platform: + return OperatingSystem.CYGWIN + elif 'msys' in sys.platform: + return OperatingSystem.MSYS + return None + + @staticmethod + def IsWindows(): + """Returns True if the current operating system is Windows.""" + return OperatingSystem.Current() is OperatingSystem.WINDOWS + + +class Architecture(object): + """An enum representing the system architecture you are running on.""" + + class _ARCH(object): + """A single architecture.""" + + # pylint: disable=redefined-builtin + def __init__(self, id, name, file_name): + self.id = id + self.name = name + self.file_name = file_name + + def __str__(self): + return self.id + + def __eq__(self, other): + return (isinstance(other, type(self)) and + self.id == other.id and + self.name == other.name and + self.file_name == other.file_name) + + def __hash__(self): + return hash(self.id) + hash(self.name) + hash(self.file_name) + + def __ne__(self, other): + return not self == other + + @classmethod + def _CmpHelper(cls, x, y): + """Just a helper equivalent to the cmp() function in Python 2.""" + return (x > y) - (x < y) + + def __lt__(self, other): + return self._CmpHelper( + (self.id, self.name, self.file_name), + (other.id, other.name, other.file_name)) < 0 + + def __gt__(self, other): + return self._CmpHelper( + (self.id, self.name, self.file_name), + (other.id, other.name, other.file_name)) > 0 + + def __le__(self, other): + return not self.__gt__(other) + + def __ge__(self, other): + return not self.__lt__(other) + + x86 = _ARCH('x86', 'x86', 'x86') + x86_64 = _ARCH('x86_64', 'x86_64', 'x86_64') + ppc = _ARCH('PPC', 'PPC', 'ppc') + arm = _ARCH('arm', 'arm', 'arm') + _ALL = [x86, x86_64, ppc, arm] + + # Possible values for `uname -m` and what arch they map to. + # Examples of possible values: https://en.wikipedia.org/wiki/Uname + _MACHINE_TO_ARCHITECTURE = { + 'amd64': x86_64, 'x86_64': x86_64, 'i686-64': x86_64, + 'i386': x86, 'i686': x86, 'x86': x86, + 'ia64': x86, # Itanium is different x64 arch, treat it as the common x86. + 'powerpc': ppc, 'power macintosh': ppc, 'ppc64': ppc, + 'armv6': arm, 'armv6l': arm, 'arm64': arm, 'armv7': arm, 'armv7l': arm} + + @staticmethod + def AllValues(): + """Gets all possible enum values. + + Returns: + list, All the enum values. + """ + return list(Architecture._ALL) + + @staticmethod + def FromId(architecture_id, error_on_unknown=True): + """Gets the enum corresponding to the given architecture id. + + Args: + architecture_id: str, The architecture id to parse + error_on_unknown: bool, True to raise an exception if the id is unknown, + False to just return None. + + Raises: + InvalidEnumValue: If the given value cannot be parsed. + + Returns: + ArchitectureTuple, One of the Architecture constants or None if the input + is None. + """ + if not architecture_id: + return None + for arch in Architecture._ALL: + if arch.id == architecture_id: + return arch + if error_on_unknown: + raise InvalidEnumValue(architecture_id, 'Architecture', + [value.id for value in Architecture._ALL]) + return None + + @staticmethod + def Current(): + """Determines the current system architecture. + + Returns: + ArchitectureTuple, One of the Architecture constants or None if it cannot + be determined. + """ + return Architecture._MACHINE_TO_ARCHITECTURE.get(platform.machine().lower()) + + +class Platform(object): + """Holds an operating system and architecture.""" + + def __init__(self, operating_system, architecture): + """Constructs a new platform. + + Args: + operating_system: OperatingSystem, The OS + architecture: Architecture, The machine architecture. + """ + self.operating_system = operating_system + self.architecture = architecture + + def __str__(self): + return '{}-{}'.format(self.operating_system, self.architecture) + + @staticmethod + def Current(os_override=None, arch_override=None): + """Determines the current platform you are running on. + + Args: + os_override: OperatingSystem, A value to use instead of the current. + arch_override: Architecture, A value to use instead of the current. + + Returns: + Platform, The platform tuple of operating system and architecture. Either + can be None if it could not be determined. + """ + return Platform( + os_override if os_override else OperatingSystem.Current(), + arch_override if arch_override else Architecture.Current()) + + def UserAgentFragment(self): + """Generates the fragment of the User-Agent that represents the OS. + + Examples: + (Linux 3.2.5-gg1236) + (Windows NT 6.1.7601) + (Macintosh; PPC Mac OS X 12.4.0) + (Macintosh; Intel Mac OS X 12.4.0) + + Returns: + str, The fragment of the User-Agent string. + """ + # Below, there are examples of the value of platform.uname() per platform. + # platform.release() is uname[2], platform.version() is uname[3]. + if self.operating_system == OperatingSystem.LINUX: + # ('Linux', '', '3.2.5-gg1236', + # '#1 SMP Tue May 21 02:35:06 PDT 2013', 'x86_64', 'x86_64') + return '({name} {version})'.format( + name=self.operating_system.name, version=platform.release()) + elif self.operating_system == OperatingSystem.WINDOWS: + # ('Windows', '', '7', '6.1.7601', 'AMD64', + # 'Intel64 Family 6 Model 45 Stepping 7, GenuineIntel') + return '({name} NT {version})'.format( + name=self.operating_system.name, version=platform.version()) + elif self.operating_system == OperatingSystem.MACOSX: + # ('Darwin', '', '12.4.0', + # 'Darwin Kernel Version 12.4.0: Wed May 1 17:57:12 PDT 2013; + # root:xnu-2050.24.15~1/RELEASE_X86_64', 'x86_64', 'i386') + format_string = '(Macintosh; {name} Mac OS X {version})' + arch_string = (self.architecture.name + if self.architecture == Architecture.ppc else 'Intel') + return format_string.format( + name=arch_string, version=platform.release()) + else: + return '()' + + def AsyncPopenArgs(self): + """Returns the args for spawning an async process using Popen on this OS. + + Make sure the main process does not wait for the new process. On windows + this means setting the 0x8 creation flag to detach the process. + + Killing a group leader kills the whole group. Setting creation flag 0x200 on + Windows or running setsid on *nix makes sure the new process is in a new + session with the new process the group leader. This means it can't be killed + if the parent is killed. + + Finally, all file descriptors (FD) need to be closed so that waiting for the + output of the main process does not inadvertently wait for the output of the + new process, which means waiting for the termination of the new process. + If the new process wants to write to a file, it can open new FDs. + + Returns: + {str:}, The args for spawning an async process using Popen on this OS. + """ + args = {} + if self.operating_system == OperatingSystem.WINDOWS: + args['close_fds'] = True # This is enough to close _all_ FDs on windows. + detached_process = 0x00000008 + create_new_process_group = 0x00000200 + # 0x008 | 0x200 == 0x208 + args['creationflags'] = detached_process | create_new_process_group + else: + # Killing a group leader kills the whole group. + # Create a new session with the new process the group leader. + args['preexec_fn'] = os.setsid + args['close_fds'] = True # This closes all FDs _except_ 0, 1, 2 on *nix. + args['stdin'] = subprocess.PIPE + args['stdout'] = subprocess.PIPE + args['stderr'] = subprocess.PIPE + return args + + +class PythonVersion(object): + """Class to validate the Python version we are using. + + The Cloud SDK officially supports Python 2.7. + + However, many commands do work with Python 2.6, so we don't error out when + users are using this (we consider it sometimes "compatible" but not + "supported"). + """ + + # See class docstring for descriptions of what these mean + MIN_REQUIRED_PY2_VERSION = (2, 6) + MIN_SUPPORTED_PY2_VERSION = (2, 7) + MIN_SUPPORTED_PY3_VERSION = (3, 4) + ENV_VAR_MESSAGE = """\ + +If you have a compatible Python interpreter installed, you can use it by setting +the CLOUDSDK_PYTHON environment variable to point to it. + +""" + + def __init__(self, version=None): + if version: + self.version = version + elif hasattr(sys, 'version_info'): + self.version = sys.version_info[:2] + else: + self.version = None + + def SupportedVersionMessage(self, allow_py3): + if allow_py3: + return 'Please use Python version {0}.{1}.x or {2}.{3} and up.'.format( + PythonVersion.MIN_SUPPORTED_PY2_VERSION[0], + PythonVersion.MIN_SUPPORTED_PY2_VERSION[1], + PythonVersion.MIN_SUPPORTED_PY3_VERSION[0], + PythonVersion.MIN_SUPPORTED_PY3_VERSION[1]) + else: + return 'Please use Python version {0}.{1}.x.'.format( + PythonVersion.MIN_SUPPORTED_PY2_VERSION[0], + PythonVersion.MIN_SUPPORTED_PY2_VERSION[1]) + + def IsCompatible(self, allow_py3=False, raise_exception=False): + """Ensure that the Python version we are using is compatible. + + This will print an error message if not compatible. + + Compatible versions are 2.6 and 2.7 and > 3.4 if allow_py3 is True. + We don't guarantee support for 2.6 so we want to warn about it. + + Args: + allow_py3: bool, True if we should allow a Python 3 interpreter to run + gcloud. If False, this returns an error for Python 3. + raise_exception: bool, True to raise an exception rather than printing + the error and exiting. + + Raises: + Error: If not compatible and raise_exception is True. + + Returns: + bool, True if the version is valid, False otherwise. + """ + error = None + if not self.version: + # We don't know the version, not a good sign. + error = ('ERROR: Your current version of Python is not compatible with ' + 'the Google Cloud SDK. {0}\n' + .format(self.SupportedVersionMessage(allow_py3))) + else: + if self.version[0] < 3: + # Python 2 Mode + if self.version < PythonVersion.MIN_REQUIRED_PY2_VERSION: + error = ('ERROR: Python {0}.{1} is not compatible with the Google ' + 'Cloud SDK. {2}\n' + .format(self.version[0], self.version[1], + self.SupportedVersionMessage(allow_py3))) + else: + # Python 3 Mode + if not allow_py3: + error = ('ERROR: Python 3 and later is not compatible with the ' + 'Google Cloud SDK. {0}\n' + .format(self.SupportedVersionMessage(allow_py3))) + elif self.version < PythonVersion.MIN_SUPPORTED_PY3_VERSION: + error = ('ERROR: Python {0}.{1} is not compatible with the Google ' + 'Cloud SDK. {2}\n' + .format(self.version[0], self.version[1], + self.SupportedVersionMessage(allow_py3))) + + if error: + if raise_exception: + raise Error(error) + sys.stderr.write(error) + sys.stderr.write(PythonVersion.ENV_VAR_MESSAGE) + return False + + # Warn that 2.6 might not work. + if (self.version >= self.MIN_REQUIRED_PY2_VERSION and + self.version < self.MIN_SUPPORTED_PY2_VERSION): + sys.stderr.write("""\ +WARNING: Python 2.6.x is no longer officially supported by the Google Cloud SDK +and may not function correctly. {0} +{1}""".format(self.SupportedVersionMessage(allow_py3), + PythonVersion.ENV_VAR_MESSAGE)) + + return True diff --git a/fire/console/text.py b/fire/console/text.py new file mode 100644 index 00000000..73e68488 --- /dev/null +++ b/fire/console/text.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- # +# Copyright 2018 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Semantic text objects that are used for styled outputting.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import unicode_literals + +import enum + + +class TextAttributes(object): + """Attributes to use to style text with.""" + + def __init__(self, format_str=None, color=None, attrs=None): + """Defines a set of attributes for a piece of text. + + Args: + format_str: (str), string that will be used to format the text + with. For example '[{}]', to enclose text in brackets. + color: (Colors), the color the text should be formatted with. + attrs: (Attrs), the attributes to apply to text. + """ + self._format_str = format_str + self._color = color + self._attrs = attrs or [] + + @property + def format_str(self): + return self._format_str + + @property + def color(self): + return self._color + + @property + def attrs(self): + return self._attrs + + +class TypedText(object): + """Text with a semantic type that will be used for styling.""" + + def __init__(self, texts, text_type=None): + """String of text and a corresponding type to use to style that text. + + Args: + texts: (list[str]), list of strs or TypedText objects + that should be styled using text_type. + text_type: (TextTypes), the semantic type of the text that + will be used to style text. + """ + self.texts = texts + self.text_type = text_type + + def __len__(self): + length = 0 + for text in self.texts: + length += len(text) + return length + + def __add__(self, other): + texts = [self, other] + return TypedText(texts) + + def __radd__(self, other): + texts = [other, self] + return TypedText(texts) + + +class _TextTypes(enum.Enum): + """Text types base class that defines base functionality.""" + + def __call__(self, *args): + """Returns a TypedText object using this style.""" + return TypedText(list(args), self) + + +# TODO: Add more types. +class TextTypes(_TextTypes): + """Defines text types that can be used for styling text.""" + RESOURCE_NAME = 1 + URL = 2 + USER_INPUT = 3 + COMMAND = 4 + INFO = 5 + URI = 6 + OUTPUT = 7 + PT_SUCCESS = 8 + PT_FAILURE = 9 + diff --git a/fire/core.py b/fire/core.py index 66746547..26a25753 100644 --- a/fire/core.py +++ b/fire/core.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ def main(argv): access a member of current component, call the current component (if it's a function), or instantiate the current component (if it's a class). The target component begins as Component, and at each operation the component becomes the -result of the preceeding operation. +result of the preceding operation. For example "command fn arg1 arg2" might access the "fn" property of the initial target component, and then call that function with arguments 'arg1' and 'arg2'. @@ -44,32 +44,34 @@ def main(argv): -h --help: Provide help and usage information for the command. -i --interactive: Drop into a Python REPL after running the command. --completion: Write the Bash completion script for the tool to stdout. + --completion fish: Write the Fish completion script for the tool to stdout. --separator SEPARATOR: Use SEPARATOR in place of the default separator, '-'. --trace: Get the Fire Trace for the command. """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - +import asyncio import inspect import json -import pipes +import os +import re import shlex import sys +import types from fire import completion from fire import decorators -from fire import helputils +from fire import formatting +from fire import helptext from fire import inspectutils from fire import interact from fire import parser from fire import trace -import six +from fire import value_types +from fire.console import console_io -def Fire(component=None, command=None, name=None): - """This function, Fire, is the main entrypoint for Fire. +def Fire(component=None, command=None, name=None, serialize=None): + """This function, Fire, is the main entrypoint for Python Fire. Executes a command either from the `command` argument or from sys.argv by recursively traversing the target object `component`'s members consuming @@ -80,9 +82,12 @@ def Fire(component=None, command=None, name=None): Args: component: The initial target component. command: Optional. If supplied, this is the command executed. If not - supplied, then the command is taken from sys.argv instead. + supplied, then the command is taken from sys.argv instead. This can be + a string or a list of strings; a list of strings is preferred. name: Optional. The name of the command as entered at the command line. Used in interactive mode and for generating the completion script. + serialize: Optional. If supplied, all objects are serialized to text via + the provided callable. Returns: The result of executing the Fire command. Execution begins with the initial target component. The component is updated by using the command arguments @@ -91,65 +96,82 @@ def Fire(component=None, command=None, name=None): it's a class). When all arguments are consumed and there's no function left to call or class left to instantiate, the resulting current component is the final result. - If a Fire error is encountered, the Fire Trace is displayed to stdout and - None is returned. - If the trace command line argument is supplied, the FireTrace is returned. + Raises: + ValueError: If the command argument is supplied, but not a string or a + sequence of arguments. + FireExit: When Fire encounters a FireError, Fire will raise a FireExit with + code 2. When used with the help or trace flags, Fire will raise a + FireExit with code 0 if successful. """ + name = name or os.path.basename(sys.argv[0]) + # Get args as a list. - if command is None: + if isinstance(command, str): + args = shlex.split(command) + elif isinstance(command, (list, tuple)): + args = command + elif command is None: # Use the command line args by default if no command is specified. - name = name or sys.argv[0] args = sys.argv[1:] else: - # Otherwise use the specified command. - args = shlex.split(command) + raise ValueError('The command argument must be a string or a sequence of ' + 'arguments.') + + args, flag_args = parser.SeparateFlagArgs(args) + + argparser = parser.CreateParser() + parsed_flag_args, unused_args = argparser.parse_known_args(flag_args) - # Determine the calling context. - caller = inspect.stack()[1] - caller_frame = caller[0] - caller_globals = caller_frame.f_globals - caller_locals = caller_frame.f_locals context = {} - context.update(caller_globals) - context.update(caller_locals) + if parsed_flag_args.interactive or component is None: + # Determine the calling context. + caller = inspect.stack()[1] + caller_frame = caller[0] + caller_globals = caller_frame.f_globals + caller_locals = caller_frame.f_locals + context.update(caller_globals) + context.update(caller_locals) - component_trace = _Fire(component, args, context, name) + component_trace = _Fire(component, args, parsed_flag_args, context, name) if component_trace.HasError(): - for help_flag in ['-h', '--help']: - if help_flag in component_trace.elements[-1].args: - command = '{cmd} -- --help'.format(cmd=component_trace.GetCommand()) - print(('WARNING: The proper way to show help is {cmd}.\n' - 'Showing help anyway.\n').format(cmd=pipes.quote(command))) - - print('Fire trace:\n{trace}\n'.format(trace=component_trace)) + _DisplayError(component_trace) + raise FireExit(2, component_trace) + if component_trace.show_trace and component_trace.show_help: + output = [f'Fire trace:\n{component_trace}\n'] result = component_trace.GetResult() - print( - helputils.HelpString(result, component_trace, component_trace.verbose)) - return None - elif component_trace.show_trace and component_trace.show_help: - print('Fire trace:\n{trace}\n'.format(trace=component_trace)) + help_text = helptext.HelpText( + result, trace=component_trace, verbose=component_trace.verbose) + output.append(help_text) + Display(output, out=sys.stderr) + raise FireExit(0, component_trace) + if component_trace.show_trace: + output = [f'Fire trace:\n{component_trace}'] + Display(output, out=sys.stderr) + raise FireExit(0, component_trace) + if component_trace.show_help: result = component_trace.GetResult() - print( - helputils.HelpString(result, component_trace, component_trace.verbose)) - return component_trace - elif component_trace.show_trace: - print('Fire trace:\n{trace}'.format(trace=component_trace)) - return component_trace - elif component_trace.show_help: - result = component_trace.GetResult() - print( - helputils.HelpString(result, component_trace, component_trace.verbose)) - return None - else: - _PrintResult(component_trace, verbose=component_trace.verbose) - result = component_trace.GetResult() - return result + help_text = helptext.HelpText( + result, trace=component_trace, verbose=component_trace.verbose) + output = [help_text] + Display(output, out=sys.stderr) + raise FireExit(0, component_trace) + + # The command succeeded normally; print the result. + _PrintResult( + component_trace, verbose=component_trace.verbose, serialize=serialize) + result = component_trace.GetResult() + return result + + +def Display(lines, out): + text = '\n'.join(lines) + '\n' + console_io.More(text, out=out) -def CompletionScript(name, component): - """Returns the text of the Bash completion script for a Fire CLI.""" - return completion.Script(name, component) +def CompletionScript(name, component, shell): + """Returns the text of the completion script for a Fire CLI.""" + return completion.Script(name, component, shell=shell) class FireError(Exception): @@ -160,32 +182,124 @@ class FireError(Exception): """ -def _PrintResult(component_trace, verbose=False): +class FireExit(SystemExit): # pylint: disable=g-bad-exception-name + """An exception raised by Fire to the client in the case of a FireError. + + The trace of the Fire program is available on the `trace` property. + + This exception inherits from SystemExit, so clients may explicitly catch it + with `except SystemExit` or `except FireExit`. If not caught, this exception + will cause the client program to exit without a stacktrace. + """ + + def __init__(self, code, component_trace): + """Constructs a FireExit exception. + + Args: + code: (int) Exit code for the Fire CLI. + component_trace: (FireTrace) The trace for the Fire command. + """ + super().__init__(code) + self.trace = component_trace + + +def _IsHelpShortcut(component_trace, remaining_args): + """Determines if the user is trying to access help without '--' separator. + + For example, mycmd.py --help instead of mycmd.py -- --help. + + Args: + component_trace: (FireTrace) The trace for the Fire command. + remaining_args: List of remaining args that haven't been consumed yet. + Returns: + True if help is requested, False otherwise. + """ + show_help = False + if remaining_args: + target = remaining_args[0] + if target in ('-h', '--help'): + # Check if --help would be consumed as a keyword argument, or is a member. + component = component_trace.GetResult() + if inspect.isclass(component) or inspect.isroutine(component): + fn_spec = inspectutils.GetFullArgSpec(component) + _, remaining_kwargs, _ = _ParseKeywordArgs(remaining_args, fn_spec) + show_help = target in remaining_kwargs + else: + members = dict(inspect.getmembers(component)) + show_help = target not in members + + if show_help: + component_trace.show_help = True + command = f'{component_trace.GetCommand()} -- --help' + print(f'INFO: Showing help with the command {shlex.quote(command)}.\n', + file=sys.stderr) + return show_help + + +def _PrintResult(component_trace, verbose=False, serialize=None): """Prints the result of the Fire call to stdout in a human readable way.""" - # TODO: Design human readable deserializable serialization method - # and move serialization to it's own module. + # TODO(dbieber): Design human readable deserializable serialization method + # and move serialization to its own module. result = component_trace.GetResult() - if isinstance(result, list): - for i in result: - print(_OneLineResult(i)) - elif isinstance(result, set): - for i in result: - print(_OneLineResult(i)) - elif inspect.isgenerator(result): + # Allow users to modify the return value of the component and provide + # custom formatting. + if serialize: + if not callable(serialize): + raise FireError( + 'The argument `serialize` must be empty or callable:', serialize) + result = serialize(result) + + if value_types.HasCustomStr(result): + # If the object has a custom __str__ method, rather than one inherited from + # object, then we use that to serialize the object. + print(str(result)) + return + + if isinstance(result, (list, set, frozenset, types.GeneratorType)): for i in result: print(_OneLineResult(i)) elif inspect.isgeneratorfunction(result): raise NotImplementedError - elif isinstance(result, dict): + elif isinstance(result, dict) and value_types.IsSimpleGroup(result): print(_DictAsString(result, verbose)) elif isinstance(result, tuple): print(_OneLineResult(result)) - elif isinstance(result, - (bool, six.string_types, six.integer_types, float, complex)): - print(result) - elif result is not None: - print(helputils.HelpString(result, component_trace, verbose)) + elif isinstance(result, value_types.VALUE_TYPES): + if result is not None: + print(result) + else: + help_text = helptext.HelpText( + result, trace=component_trace, verbose=verbose) + output = [help_text] + Display(output, out=sys.stdout) + + +def _DisplayError(component_trace): + """Prints the Fire trace and the error to stdout.""" + result = component_trace.GetResult() + + output = [] + show_help = False + for help_flag in ('-h', '--help'): + if help_flag in component_trace.elements[-1].args: + show_help = True + + if show_help: + command = f'{component_trace.GetCommand()} -- --help' + print(f'INFO: Showing help with the command {shlex.quote(command)}.\n', + file=sys.stderr) + help_text = helptext.HelpText(result, trace=component_trace, + verbose=component_trace.verbose) + output.append(help_text) + Display(output, out=sys.stderr) + else: + print(formatting.Error('ERROR: ') + + component_trace.elements[-1].ErrorAsStr(), + file=sys.stderr) + error_text = helptext.UsageText(result, trace=component_trace, + verbose=component_trace.verbose) + print(error_text, file=sys.stderr) def _DictAsString(result, verbose=False): @@ -197,42 +311,54 @@ def _DictAsString(result, verbose=False): Returns: A string representing the dict """ - longest_key = max( - len(str(key)) for key in result.keys() - if _ComponentVisible(key, verbose) - ) - format_string = '{{key:{padding}s}} {{value}}'.format(padding=longest_key + 1) + + # We need to do 2 iterations over the items in the result dict + # 1) Getting visible items and the longest key for output formatting + # 2) Actually construct the output lines + class_attrs = inspectutils.GetClassAttrsDict(result) + result_visible = { + key: value for key, value in result.items() + if completion.MemberVisible(result, key, value, + class_attrs=class_attrs, verbose=verbose) + } + + if not result_visible: + return '{}' + + longest_key = max(len(str(key)) for key in result_visible.keys()) + format_string = f'{{key:{longest_key + 1}s}} {{value}}' lines = [] for key, value in result.items(): - if _ComponentVisible(key, verbose): - line = format_string.format( - key=str(key) + ':', value=_OneLineResult(value)) + if completion.MemberVisible(result, key, value, class_attrs=class_attrs, + verbose=verbose): + line = format_string.format(key=f'{key}:', value=_OneLineResult(value)) lines.append(line) return '\n'.join(lines) -def _ComponentVisible(component, verbose=False): - """Returns whether a component should be visible in the output.""" - return ( - verbose - or not isinstance(component, six.string_types) - or not component.startswith('_')) - - def _OneLineResult(result): """Returns result serialized to a single line string.""" - # TODO: Ensure line is fewer than eg 120 characters. - if isinstance(result, six.string_types): + # TODO(dbieber): Ensure line is fewer than eg 120 characters. + if isinstance(result, str): return str(result).replace('\n', ' ') + # TODO(dbieber): Show a small amount of usage information about the function + # or module if it fits cleanly on the line. + if inspect.isfunction(result): + return f'' + + if inspect.ismodule(result): + return f'' + try: - return json.dumps(result) - except TypeError: + # Don't force conversion to ascii. + return json.dumps(result, ensure_ascii=False) + except (TypeError, ValueError): return str(result).replace('\n', ' ') -def _Fire(component, args, context, name=None): +def _Fire(component, args, parsed_flag_args, context, name=None): """Execute a Fire command on a target component using the args supplied. Arguments that come after a final isolated '--' are treated as Flags, eg for @@ -248,9 +374,15 @@ def _Fire(component, args, context, name=None): 2. Start with component as the current component. 2a. If the current component is a class, instantiate it using args from args. - 2b. If the current component is a routine, call it using args from args. - 2c. Otherwise access a member from component using an arg from args. - 2d. Repeat 2a-2c until no args remain. + 2b. If the component is a routine, call it using args from args. + 2c. If the component is a sequence, index into it using an arg from + args. + 2d. If possible, access a member from the component using an arg from args. + 2e. If the component is a callable object, call it using args from args. + 2f. Repeat 2a-2e until no args remain. + Note: Only the first applicable rule from 2a-2e is applied in each iteration. + After each iteration of step 2a-2e, the current component is updated to be the + result of the applied rule. 3a. Embed into ipython REPL if interactive mode is selected. 3b. Generate a completion script if that flag is provided. @@ -264,6 +396,8 @@ def _Fire(component, args, context, name=None): component: The target component for Fire. args: A list of args to consume in Firing on the component, usually from the command line. + parsed_flag_args: The values of the flag args (e.g. --verbose, --separator) + that are part of every Fire CLI. context: A dict with the local and global variables available at the call to Fire. name: Optional. The name of the command. Used in interactive mode and in @@ -275,10 +409,6 @@ def _Fire(component, args, context, name=None): ValueError: If there are arguments that cannot be consumed. ValueError: If --completion is specified but no name available. """ - args, flag_args = parser.SeparateFlagArgs(args) - - argparser = parser.CreateParser() - parsed_flag_args, unused_args = argparser.parse_known_args(flag_args) verbose = parsed_flag_args.verbose interactive = parsed_flag_args.interactive separator = parsed_flag_args.separator @@ -302,11 +432,15 @@ def _Fire(component, args, context, name=None): initial_args = remaining_args if not remaining_args and (show_help or interactive or show_trace - or show_completion): + or show_completion is not None): # Don't initialize the final class or call the final function unless # there's a separator after it, and instead process the current component. break + if _IsHelpShortcut(component_trace, remaining_args): + remaining_args = [] + break + saved_args = [] used_separator = False if separator in remaining_args: @@ -317,101 +451,124 @@ def _Fire(component, args, context, name=None): used_separator = True assert separator not in remaining_args - if inspect.isclass(component) or inspect.isroutine(component): + handled = False + candidate_errors = [] + + is_callable = inspect.isclass(component) or inspect.isroutine(component) + is_callable_object = callable(component) and not is_callable + is_sequence = isinstance(component, (list, tuple)) + is_map = isinstance(component, dict) or inspectutils.IsNamedTuple(component) + + if not handled and is_callable: # The component is a class or a routine; we'll try to initialize it or # call it. - isclass = inspect.isclass(component) + is_class = inspect.isclass(component) try: - target = component.__name__ - filename, lineno = _GetFileAndLine(component) - - component, consumed_args, remaining_args, capacity = _CallCallable( - component, remaining_args) - - # Update the trace. - if isclass: - component_trace.AddInstantiatedClass( - component, target, consumed_args, filename, lineno, capacity) - else: - component_trace.AddCalledRoutine( - component, target, consumed_args, filename, lineno, capacity) - + component, remaining_args = _CallAndUpdateTrace( + component, + remaining_args, + component_trace, + treatment='class' if is_class else 'routine', + target=component.__name__) + handled = True except FireError as error: - component_trace.AddError(error, initial_args) - return component_trace + candidate_errors.append((error, initial_args)) - if last_component == initial_component: + if handled and last_component is initial_component: # If the initial component is a class, keep an instance for use with -i. instance = component - elif isinstance(component, (list, tuple)) and remaining_args: + if not handled and is_sequence and remaining_args: # The component is a tuple or list; we'll try to access a member. arg = remaining_args[0] try: index = int(arg) component = component[index] + handled = True except (ValueError, IndexError): error = FireError( 'Unable to index into component with argument:', arg) - component_trace.AddError(error, initial_args) - return component_trace + candidate_errors.append((error, initial_args)) - remaining_args = remaining_args[1:] - filename = None - lineno = None - component_trace.AddAccessedProperty( - component, index, [arg], filename, lineno) + if handled: + remaining_args = remaining_args[1:] + filename = None + lineno = None + component_trace.AddAccessedProperty( + component, index, [arg], filename, lineno) - elif isinstance(component, dict) and remaining_args: - # The component is a dict; we'll try to access a member. + if not handled and is_map and remaining_args: + # The component is a dict or other key-value map; try to access a member. target = remaining_args[0] - if target in component: - component = component[target] - elif target.replace('-', '_') in component: - component = component[target.replace('-', '_')] + + # Treat namedtuples as dicts when handling them as a map. + if inspectutils.IsNamedTuple(component): + component_dict = component._asdict() # pytype: disable=attribute-error else: - # The target isn't present in the dict as a string, but maybe it is as - # another type. - # TODO: Consider alternatives for accessing non-string keys. - found_target = False - for key, value in component.items(): + component_dict = component + + if target in component_dict: + component = component_dict[target] + handled = True + elif target.replace('-', '_') in component_dict: + component = component_dict[target.replace('-', '_')] + handled = True + else: + # The target isn't present in the dict as a string key, but maybe it is + # a key as another type. + # TODO(dbieber): Consider alternatives for accessing non-string keys. + for key, value in ( + component_dict.items()): # pytype: disable=attribute-error if target == str(key): component = value - found_target = True + handled = True break - if not found_target: - error = FireError( - 'Cannot find target in dict', target, component) - component_trace.AddError(error, initial_args) - return component_trace - - remaining_args = remaining_args[1:] - filename = None - lineno = None - component_trace.AddAccessedProperty( - component, target, [target], filename, lineno) - - elif remaining_args: - # We'll try to access a member of the component. + + if handled: + remaining_args = remaining_args[1:] + filename = None + lineno = None + component_trace.AddAccessedProperty( + component, target, [target], filename, lineno) + else: + error = FireError('Cannot find key:', target) + candidate_errors.append((error, initial_args)) + + if not handled and remaining_args: + # Object handler. We'll try to access a member of the component. try: target = remaining_args[0] component, consumed_args, remaining_args = _GetMember( component, remaining_args) + handled = True - try: - filename, lineno = _GetFileAndLine(component) - except TypeError: - filename = None - lineno = None + filename, lineno = inspectutils.GetFileAndLine(component) component_trace.AddAccessedProperty( component, target, consumed_args, filename, lineno) except FireError as error: - component_trace.AddError(error, initial_args) - return component_trace + # Couldn't access member. + candidate_errors.append((error, initial_args)) + + if not handled and is_callable_object: + # The component is a callable object; we'll try to call it. + try: + component, remaining_args = _CallAndUpdateTrace( + component, + remaining_args, + component_trace, + treatment='callable') + handled = True + except FireError as error: + candidate_errors.append((error, initial_args)) + + if not handled and candidate_errors: + error, initial_args = candidate_errors[0] + component_trace.AddError(error, initial_args) + return component_trace if used_separator: # Add back in the arguments from after the separator. @@ -421,26 +578,26 @@ def _Fire(component, args, context, name=None): or inspect.isroutine(last_component)): remaining_args = saved_args component_trace.AddSeparator() - elif component != last_component: + elif component is not last_component: remaining_args = [separator] + saved_args else: # It was an unnecessary separator. remaining_args = saved_args - if component == last_component and remaining_args == initial_args: + if component is last_component and remaining_args == initial_args: # We're making no progress. break if remaining_args: component_trace.AddError( - FireError('Could not consume arguments', remaining_args), + FireError('Could not consume arguments:', remaining_args), initial_args) return component_trace - if show_completion: + if show_completion is not None: if name is None: raise ValueError('Cannot make completion script without command name') - script = CompletionScript(name, initial_component) + script = CompletionScript(name, initial_component, shell=show_completion) component_trace.AddCompletionScript(script) if interactive: @@ -462,32 +619,6 @@ def _Fire(component, args, context, name=None): return component_trace -def _GetFileAndLine(component): - """Returns the filename and line number of component. - - Args: - component: A component to find the source information for, usually a class - or routine. - Returns: - filename: The name of the file where component is defined. - lineno: The line number where component is defined. - Raises: - TypeError: If component is not a module, class, method, function, traceback, - frame, or code object then the inspect module will raise this error. - """ - if inspect.isbuiltin(component): - return None, None - - filename = inspect.getsourcefile(component) - try: - unused_code, lineindex = inspect.findsource(component) - lineno = lineindex + 1 - except IOError: - lineno = None - - return filename, lineno - - def _GetMember(component, args): """Returns a subcomponent of component by consuming an arg from args. @@ -504,7 +635,7 @@ def _GetMember(component, args): Raises: FireError: If we cannot consume an argument to get a member. """ - members = dict(inspect.getmembers(component)) + members = dir(component) arg = args[0] arg_names = [ arg, @@ -513,66 +644,100 @@ def _GetMember(component, args): for arg_name in arg_names: if arg_name in members: - return members[arg_name], [arg], args[1:] + return getattr(component, arg_name), [arg], args[1:] - raise FireError('Could not consume arg', arg) + raise FireError('Could not consume arg:', arg) -def _CallCallable(fn, args): - """Calls the function fn by consuming args from args. +def _CallAndUpdateTrace(component, args, component_trace, treatment='class', + target=None): + """Call the component by consuming args from args, and update the FireTrace. + + The component could be a class, a routine, or a callable object. This function + calls the component and adds the appropriate action to component_trace. Args: - fn: The function to call or class to instantiate. - args: Args from which to consume for calling the function. + component: The component to call + args: Args for calling the component + component_trace: FireTrace object that contains action trace + treatment: Type of treatment used. Indicating whether we treat the component + as a class, a routine, or a callable. + target: Target in FireTrace element, default is None. If the value is None, + the component itself will be used as target. Returns: - component: The object that is the result of the function call. - consumed_args: The args that were consumed for the function call. + component: The object that is the result of the callable call. remaining_args: The remaining args that haven't been consumed yet. - capacity: Whether the call could have taken additional args. """ - parse = _MakeParseFn(fn) + if not target: + target = component + filename, lineno = inspectutils.GetFileAndLine(component) + metadata = decorators.GetMetadata(component) + fn = component.__call__ if treatment == 'callable' else component + parse = _MakeParseFn(fn, metadata) (varargs, kwargs), consumed_args, remaining_args, capacity = parse(args) - result = fn(*varargs, **kwargs) - return result, consumed_args, remaining_args, capacity + # Call the function. + if inspectutils.IsCoroutineFunction(fn): + loop = asyncio.get_event_loop() + component = loop.run_until_complete(fn(*varargs, **kwargs)) + else: + component = fn(*varargs, **kwargs) + + if treatment == 'class': + action = trace.INSTANTIATED_CLASS + elif treatment == 'routine': + action = trace.CALLED_ROUTINE + else: + action = trace.CALLED_CALLABLE + component_trace.AddCalledComponent( + component, target, consumed_args, filename, lineno, capacity, + action=action) + + return component, remaining_args -def _MakeParseFn(fn): +def _MakeParseFn(fn, metadata): """Creates a parse function for fn. Args: fn: The function or class to create the parse function for. + metadata: Additional metadata about the component the parse function is for. Returns: A parse function for fn. The parse function accepts a list of arguments and returns (varargs, kwargs), remaining_args. The original function fn can then be called with fn(*varargs, **kwargs). The remaining_args are the leftover args from the arguments to the parse function. """ - fn_args, fn_varargs, fn_keywords, fn_defaults = inspectutils.GetArgSpec(fn) - metadata = decorators.GetMetadata(fn) + fn_spec = inspectutils.GetFullArgSpec(fn) - # Note: num_required_args is the number of arguments without default values. - # All of these arguments are required. - num_required_args = len(fn_args) - len(fn_defaults) + # Note: num_required_args is the number of positional arguments without + # default values. All of these arguments are required. + num_required_args = len(fn_spec.args) - len(fn_spec.defaults) + required_kwonly = set(fn_spec.kwonlyargs) - set(fn_spec.kwonlydefaults) def _ParseFn(args): """Parses the list of `args` into (varargs, kwargs), remaining_args.""" - kwargs, remaining_args = _ParseKeywordArgs(args, fn_args, fn_keywords) + kwargs, remaining_kwargs, remaining_args = _ParseKeywordArgs(args, fn_spec) # Note: _ParseArgs modifies kwargs. parsed_args, kwargs, remaining_args, capacity = _ParseArgs( - fn_args, fn_defaults, num_required_args, kwargs, remaining_args, - metadata) + fn_spec.args, fn_spec.defaults, num_required_args, kwargs, + remaining_args, metadata) - if fn_varargs or fn_keywords: + if fn_spec.varargs or fn_spec.varkw: # If we're allowed *varargs or **kwargs, there's always capacity. capacity = True - if fn_keywords is None and kwargs: - raise FireError('Unexpected kwargs present', kwargs) + extra_kw = set(kwargs) - set(fn_spec.kwonlyargs) + if fn_spec.varkw is None and extra_kw: + raise FireError('Unexpected kwargs present:', extra_kw) + + missing_kwonly = set(required_kwonly) - set(kwargs) + if missing_kwonly: + raise FireError('Missing required flags:', missing_kwonly) # If we accept *varargs, then use all remaining arguments for *varargs. - if fn_varargs is not None: + if fn_spec.varargs is not None: varargs, remaining_args = remaining_args, [] else: varargs = [] @@ -581,8 +746,9 @@ def _ParseFn(args): varargs[index] = _ParseValue(value, None, None, metadata) varargs = parsed_args + varargs + remaining_args += remaining_kwargs - consumed_args = args[:len(args)-len(remaining_args)] + consumed_args = args[:len(args) - len(remaining_args)] return (varargs, kwargs), consumed_args, remaining_args, capacity return _ParseFn @@ -612,7 +778,7 @@ def _ParseArgs(fn_args, fn_defaults, num_required_args, kwargs, remaining_args: A list of the supplied args that have not been used yet. capacity: Whether the call could have taken args in place of defaults. Raises: - FireError: if additional positional arguments are expected, but none are + FireError: If additional positional arguments are expected, but none are available. """ accepts_positional_args = metadata.get(decorators.ACCEPTS_POSITIONAL_ARGS) @@ -647,10 +813,10 @@ def _ParseArgs(fn_args, fn_defaults, num_required_args, kwargs, return parsed_args, kwargs, remaining_args, capacity -def _ParseKeywordArgs(args, fn_args, fn_keywords): +def _ParseKeywordArgs(args, fn_spec): """Parses the supplied arguments for keyword arguments. - Given a list of arguments, finds occurences of --name value, and uses 'name' + Given a list of arguments, finds occurrences of --name value, and uses 'name' as the keyword and 'value' as the value. Constructs and returns a dictionary of these keyword arguments, and returns a list of the remaining arguments. @@ -661,71 +827,133 @@ def _ParseKeywordArgs(args, fn_args, fn_keywords): _ParseArgs, which converts them to the appropriate type. Args: - args: A list of arguments - fn_args: A list of argument names that the target function accepts, - including positional and named arguments, but not the varargs or kwargs - names. - fn_keywords: The argument name for **kwargs, or None if **kwargs not used + args: A list of arguments. + fn_spec: The inspectutils.FullArgSpec describing the given callable. Returns: kwargs: A dictionary mapping keywords to values. + remaining_kwargs: A list of the unused kwargs from the original args. remaining_args: A list of the unused arguments from the original args. + Raises: + FireError: If a single-character flag is passed that could refer to multiple + possible args. """ kwargs = {} - if args: - remaining_args = [] - skip_argument = False - - for index, argument in enumerate(args): - if skip_argument: - skip_argument = False - continue - - arg_consumed = False - if argument.startswith('--'): - # This is a named argument; get its value from this arg or the next. + remaining_kwargs = [] + remaining_args = [] + fn_keywords = fn_spec.varkw + fn_args = fn_spec.args + fn_spec.kwonlyargs + + if not args: + return kwargs, remaining_kwargs, remaining_args + + skip_argument = False + + for index, argument in enumerate(args): + if skip_argument: + skip_argument = False + continue + + if _IsFlag(argument): + # This is a named argument. We get its value from this arg or the next. + + # Terminology: + # argument: A full token from the command line, e.g. '--alpha=10' + # stripped_argument: An argument without leading hyphens. + # key: The contents of the stripped argument up to the first equal sign. + # "shortcut flag": refers to an argument where the key is just the first + # letter of a longer keyword. + # keyword: The Python function argument being set by this argument. + # value: The unparsed value for that Python function argument. + contains_equals = '=' in argument + stripped_argument = argument.lstrip('-') + if contains_equals: + key, value = stripped_argument.split('=', 1) + else: + key = stripped_argument + value = None # value will be set later on. + + key = key.replace('-', '_') + is_bool_syntax = (not contains_equals and + (index + 1 == len(args) or _IsFlag(args[index + 1]))) + + # Determine the keyword. + keyword = '' # Indicates no valid keyword has been found yet. + if (key in fn_args + or (is_bool_syntax and key.startswith('no') and key[2:] in fn_args) + or fn_keywords): + keyword = key + elif len(key) == 1: + # This may be a shortcut flag. + matching_fn_args = [arg for arg in fn_args if arg[0] == key] + if len(matching_fn_args) == 1: + keyword = matching_fn_args[0] + elif len(matching_fn_args) > 1: + raise FireError( + f"The argument '{argument}' is ambiguous as it could " + f"refer to any of the following arguments: {matching_fn_args}" + ) + + # Determine the value. + if not keyword: got_argument = False - - keyword = argument[2:] - contains_equals = '=' in keyword - is_bool_syntax = ( - not contains_equals and - (index + 1 == len(args) or args[index + 1].startswith('--'))) - if contains_equals: - keyword, value = keyword.split('=', 1) - got_argument = True - elif is_bool_syntax: - # Since there's no next arg or the next arg is a Flag, we consider - # this flag to be a boolean. - got_argument = True - if keyword in fn_args: - value = 'True' - elif keyword.startswith('no'): - keyword = keyword[2:] - value = 'False' - else: - value = 'True' + elif contains_equals: + # Already got the value above. + got_argument = True + elif is_bool_syntax: + # There's no next arg or the next arg is a Flag, so we consider this + # flag to be a boolean. + got_argument = True + if keyword in fn_args: + value = 'True' + elif keyword.startswith('no'): + keyword = keyword[2:] + value = 'False' else: - if index + 1 < len(args): - value = args[index + 1] - got_argument = True - - keyword = keyword.replace('-', '_') - - # In order for us to consume the argument as a keyword arg, we either: - # Need to be explicitly expecting the keyword, or we need to be - # accepting **kwargs. - if got_argument and (keyword in fn_args or fn_keywords): - kwargs[keyword] = value - skip_argument = not contains_equals and not is_bool_syntax - arg_consumed = True - - if not arg_consumed: - # The argument was not consumed, so it is still a remaining argument. - remaining_args.append(argument) - else: - remaining_args = args + value = 'True' + else: + # The assert should pass. Otherwise either contains_equals or + # is_bool_syntax would have been True. + assert index + 1 < len(args) + value = args[index + 1] + got_argument = True + + # In order for us to consume the argument as a keyword arg, we either: + # Need to be explicitly expecting the keyword, or we need to be + # accepting **kwargs. + skip_argument = not contains_equals and not is_bool_syntax + if got_argument: + kwargs[keyword] = value + else: + remaining_kwargs.append(argument) + if skip_argument: + remaining_kwargs.append(args[index + 1]) + else: # not _IsFlag(argument) + remaining_args.append(argument) + + return kwargs, remaining_kwargs, remaining_args + + +def _IsFlag(argument): + """Determines if the argument is a flag argument. + + If it starts with a hyphen and isn't a negative number, it's a flag. + + Args: + argument: A command line argument that may or may not be a flag. + Returns: + A boolean indicating whether the argument is a flag. + """ + return _IsSingleCharFlag(argument) or _IsMultiCharFlag(argument) + + +def _IsSingleCharFlag(argument): + """Determines if the argument is a single char flag (e.g. '-a').""" + return re.match('^-[a-zA-Z]$', argument) or re.match('^-[a-zA-Z]=', argument) + - return kwargs, remaining_args +def _IsMultiCharFlag(argument): + """Determines if the argument is a multi char flag (e.g. '--alpha').""" + return argument.startswith('--') or re.match('^-[a-zA-Z]', argument) def _ParseValue(value, index, arg, metadata): diff --git a/fire/core_test.py b/fire/core_test.py index 210abb97..90b7f466 100644 --- a/fire/core_test.py +++ b/fire/core_test.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,41 +12,44 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +"""Tests for the core module.""" + +from unittest import mock from fire import core from fire import test_components as tc +from fire import testutils from fire import trace -import mock - -import unittest -class CoreTest(unittest.TestCase): +class CoreTest(testutils.BaseTestCase): def testOneLineResult(self): - self.assertEqual(core._OneLineResult(1), '1') - self.assertEqual(core._OneLineResult('hello'), 'hello') - self.assertEqual(core._OneLineResult({}), '{}') - self.assertEqual(core._OneLineResult({'x': 'y'}), '{"x": "y"}') + self.assertEqual(core._OneLineResult(1), '1') # pylint: disable=protected-access + self.assertEqual(core._OneLineResult('hello'), 'hello') # pylint: disable=protected-access + self.assertEqual(core._OneLineResult({}), '{}') # pylint: disable=protected-access + self.assertEqual(core._OneLineResult({'x': 'y'}), '{"x": "y"}') # pylint: disable=protected-access + + def testOneLineResultCircularRef(self): + circular_reference = tc.CircularReference() + self.assertEqual(core._OneLineResult(circular_reference.create()), # pylint: disable=protected-access + "{'y': {...}}") @mock.patch('fire.interact.Embed') def testInteractiveMode(self, mock_embed): - core.Fire(tc.TypedProperties, 'alpha') + core.Fire(tc.TypedProperties, command=['alpha']) self.assertFalse(mock_embed.called) - core.Fire(tc.TypedProperties, 'alpha -- -i') + core.Fire(tc.TypedProperties, command=['alpha', '--', '-i']) self.assertTrue(mock_embed.called) @mock.patch('fire.interact.Embed') def testInteractiveModeFullArgument(self, mock_embed): - core.Fire(tc.TypedProperties, 'alpha -- --interactive') + core.Fire(tc.TypedProperties, command=['alpha', '--', '--interactive']) self.assertTrue(mock_embed.called) @mock.patch('fire.interact.Embed') def testInteractiveModeVariables(self, mock_embed): - core.Fire(tc.WithDefaults, 'double 2 -- -i') + core.Fire(tc.WithDefaults, command=['double', '2', '--', '-i']) self.assertTrue(mock_embed.called) (variables, verbose), unused_kwargs = mock_embed.call_args self.assertFalse(verbose) @@ -56,7 +59,8 @@ def testInteractiveModeVariables(self, mock_embed): @mock.patch('fire.interact.Embed') def testInteractiveModeVariablesWithName(self, mock_embed): - core.Fire(tc.WithDefaults, 'double 2 -- -i -v', name='D') + core.Fire(tc.WithDefaults, + command=['double', '2', '--', '-i', '-v'], name='D') self.assertTrue(mock_embed.called) (variables, verbose), unused_kwargs = mock_embed.call_args self.assertTrue(verbose) @@ -65,14 +69,53 @@ def testInteractiveModeVariablesWithName(self, mock_embed): self.assertEqual(variables['D'], tc.WithDefaults) self.assertIsInstance(variables['trace'], trace.FireTrace) - def testImproperUseOfHelp(self): - # This should produce a warning and return None. - self.assertIsNone(core.Fire(tc.TypedProperties, 'alpha --help')) + # TODO(dbieber): Use parameterized tests to break up repetitive tests. + def testHelpWithClass(self): + with self.assertRaisesFireExit(0, 'SYNOPSIS.*ARG1'): + core.Fire(tc.InstanceVars, command=['--', '--help']) + with self.assertRaisesFireExit(0, 'INFO:.*SYNOPSIS.*ARG1'): + core.Fire(tc.InstanceVars, command=['--help']) + with self.assertRaisesFireExit(0, 'INFO:.*SYNOPSIS.*ARG1'): + core.Fire(tc.InstanceVars, command=['-h']) + + def testHelpWithMember(self): + with self.assertRaisesFireExit(0, 'SYNOPSIS.*capitalize'): + core.Fire(tc.TypedProperties, command=['gamma', '--', '--help']) + with self.assertRaisesFireExit(0, 'INFO:.*SYNOPSIS.*capitalize'): + core.Fire(tc.TypedProperties, command=['gamma', '--help']) + with self.assertRaisesFireExit(0, 'INFO:.*SYNOPSIS.*capitalize'): + core.Fire(tc.TypedProperties, command=['gamma', '-h']) + with self.assertRaisesFireExit(0, 'INFO:.*SYNOPSIS.*delta'): + core.Fire(tc.TypedProperties, command=['delta', '--help']) + with self.assertRaisesFireExit(0, 'INFO:.*SYNOPSIS.*echo'): + core.Fire(tc.TypedProperties, command=['echo', '--help']) + + def testHelpOnErrorInConstructor(self): + with self.assertRaisesFireExit(0, 'SYNOPSIS.*VALUE'): + core.Fire(tc.ErrorInConstructor, command=['--', '--help']) + with self.assertRaisesFireExit(0, 'INFO:.*SYNOPSIS.*VALUE'): + core.Fire(tc.ErrorInConstructor, command=['--help']) + + def testHelpWithNamespaceCollision(self): + # Tests cases when calling the help shortcut should not show help. + with self.assertOutputMatches(stdout='DESCRIPTION.*', stderr=None): + core.Fire(tc.WithHelpArg, command=['--help', 'False']) + with self.assertOutputMatches(stdout='help in a dict', stderr=None): + core.Fire(tc.WithHelpArg, command=['dictionary', '__help']) + with self.assertOutputMatches(stdout='{}', stderr=None): + core.Fire(tc.WithHelpArg, command=['dictionary', '--help']) + with self.assertOutputMatches(stdout='False', stderr=None): + core.Fire(tc.function_with_help, command=['False']) + + def testInvalidParameterRaisesFireExit(self): + with self.assertRaisesFireExit(2, 'runmisspelled'): + core.Fire(tc.Kwargs, command=['props', '--a=1', '--b=2', 'runmisspelled']) def testErrorRaising(self): # Errors in user code should not be caught; they should surface as normal. + # This will lead to exit status code 1 for the client program. with self.assertRaises(ValueError): - core.Fire(tc.ErrorRaiser, 'fail') + core.Fire(tc.ErrorRaiser, command=['fail']) def testFireError(self): error = core.FireError('Example error') @@ -82,5 +125,104 @@ def testFireErrorMultipleValues(self): error = core.FireError('Example error', 'value') self.assertIsNotNone(error) + def testPrintEmptyDict(self): + with self.assertOutputMatches(stdout='{}', stderr=None): + core.Fire(tc.EmptyDictOutput, command=['totally_empty']) + with self.assertOutputMatches(stdout='{}', stderr=None): + core.Fire(tc.EmptyDictOutput, command=['nothing_printable']) + + def testPrintOrderedDict(self): + with self.assertOutputMatches(stdout=r'A:\s+A\s+2:\s+2\s+', stderr=None): + core.Fire(tc.OrderedDictionary, command=['non_empty']) + with self.assertOutputMatches(stdout='{}'): + core.Fire(tc.OrderedDictionary, command=['empty']) + + def testPrintNamedTupleField(self): + with self.assertOutputMatches(stdout='11', stderr=None): + core.Fire(tc.NamedTuple, command=['point', 'x']) + + def testPrintNamedTupleFieldNameEqualsValue(self): + with self.assertOutputMatches(stdout='x', stderr=None): + core.Fire(tc.NamedTuple, command=['matching_names', 'x']) + + def testPrintNamedTupleIndex(self): + with self.assertOutputMatches(stdout='22', stderr=None): + core.Fire(tc.NamedTuple, command=['point', '1']) + + def testPrintSet(self): + with self.assertOutputMatches(stdout='.*three.*', stderr=None): + core.Fire(tc.simple_set(), command=[]) + + def testPrintFrozenSet(self): + with self.assertOutputMatches(stdout='.*three.*', stderr=None): + core.Fire(tc.simple_frozenset(), command=[]) + + def testPrintNamedTupleNegativeIndex(self): + with self.assertOutputMatches(stdout='11', stderr=None): + core.Fire(tc.NamedTuple, command=['point', '-2']) + + def testCallable(self): + with self.assertOutputMatches(stdout=r'foo:\s+foo\s+', stderr=None): + core.Fire(tc.CallableWithKeywordArgument(), command=['--foo=foo']) + with self.assertOutputMatches(stdout=r'foo\s+', stderr=None): + core.Fire(tc.CallableWithKeywordArgument(), command=['print_msg', 'foo']) + with self.assertOutputMatches(stdout=r'', stderr=None): + core.Fire(tc.CallableWithKeywordArgument(), command=[]) + + def testCallableWithPositionalArgs(self): + with self.assertRaisesFireExit(2, ''): + # This does not give 7 since positional args are disallowed for callable + # objects. + core.Fire(tc.CallableWithPositionalArgs(), command=['3', '4']) + + def testStaticMethod(self): + self.assertEqual( + core.Fire(tc.HasStaticAndClassMethods, + command=['static_fn', 'alpha']), + 'alpha', + ) + + def testClassMethod(self): + self.assertEqual( + core.Fire(tc.HasStaticAndClassMethods, + command=['class_fn', '6']), + 7, + ) + + def testCustomSerialize(self): + def serialize(x): + if isinstance(x, list): + return ', '.join(str(xi) for xi in x) + if isinstance(x, dict): + return ', '.join('{}={!r}'.format(k, v) for k, v in sorted(x.items())) + if x == 'special': + return ['SURPRISE!!', "I'm a list!"] + return x + + ident = lambda x: x + + with self.assertOutputMatches(stdout='a, b', stderr=None): + _ = core.Fire(ident, command=['[a,b]'], serialize=serialize) + with self.assertOutputMatches(stdout='a=5, b=6', stderr=None): + _ = core.Fire(ident, command=['{a:5,b:6}'], serialize=serialize) + with self.assertOutputMatches(stdout='asdf', stderr=None): + _ = core.Fire(ident, command=['asdf'], serialize=serialize) + with self.assertOutputMatches( + stdout="SURPRISE!!\nI'm a list!\n", stderr=None): + _ = core.Fire(ident, command=['special'], serialize=serialize) + with self.assertRaises(core.FireError): + core.Fire(ident, command=['asdf'], serialize=55) + + def testLruCacheDecoratorBoundArg(self): + self.assertEqual( + core.Fire(tc.py3.LruCacheDecoratedMethod, # pytype: disable=module-attr + command=['lru_cache_in_class', 'foo']), 'foo') + + def testLruCacheDecorator(self): + self.assertEqual( + core.Fire(tc.py3.lru_cache_decorated, # pytype: disable=module-attr + command=['foo']), 'foo') + + if __name__ == '__main__': - unittest.main() + testutils.main() diff --git a/fire/custom_descriptions.py b/fire/custom_descriptions.py new file mode 100644 index 00000000..768f0e23 --- /dev/null +++ b/fire/custom_descriptions.py @@ -0,0 +1,144 @@ +# Copyright (C) 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Custom descriptions and summaries for the builtin types. + +The docstrings for objects of primitive types reflect the type of the object, +rather than the object itself. For example, the docstring for any dict is this: + +> print({'key': 'value'}.__doc__) +dict() -> new empty dictionary +dict(mapping) -> new dictionary initialized from a mapping object's + (key, value) pairs +dict(iterable) -> new dictionary initialized as if via: + d = {} + for k, v in iterable: + d[k] = v +dict(**kwargs) -> new dictionary initialized with the name=value pairs + in the keyword argument list. For example: dict(one=1, two=2) + +As you can see, this docstring is more pertinent to the function `dict` and +would be suitable as the result of `dict.__doc__`, but is wholely unsuitable +as a description for the dict `{'key': 'value'}`. + +This modules aims to resolve that problem, providing custom summaries and +descriptions for primitive typed values. +""" + +from fire import formatting + +TWO_DOUBLE_QUOTES = '""' +STRING_DESC_PREFIX = 'The string ' + + +def NeedsCustomDescription(component): + """Whether the component should use a custom description and summary. + + Components of primitive type, such as ints, floats, dicts, lists, and others + have messy builtin docstrings. These are inappropriate for display as + descriptions and summaries in a CLI. This function determines whether the + provided component has one of these docstrings. + + Note that an object such as `int` has the same docstring as an int like `3`. + The docstring is OK for `int`, but is inappropriate as a docstring for `3`. + + Args: + component: The component of interest. + Returns: + Whether the component should use a custom description and summary. + """ + type_ = type(component) + if ( + type_ in (str, int, bytes) + or type_ in (float, complex, bool) + or type_ in (dict, tuple, list, set, frozenset) + ): + return True + return False + + +def GetStringTypeSummary(obj, available_space, line_length): + """Returns a custom summary for string type objects. + + This function constructs a summary for string type objects by double quoting + the string value. The double quoted string value will be potentially truncated + with ellipsis depending on whether it has enough space available to show the + full string value. + + Args: + obj: The object to generate summary for. + available_space: Number of character spaces available. + line_length: The full width of the terminal, default is 80. + + Returns: + A summary for the input object. + """ + if len(obj) + len(TWO_DOUBLE_QUOTES) <= available_space: + content = obj + else: + additional_len_needed = len(TWO_DOUBLE_QUOTES) + len(formatting.ELLIPSIS) + if available_space < additional_len_needed: + available_space = line_length + content = formatting.EllipsisTruncate( + obj, available_space - len(TWO_DOUBLE_QUOTES), line_length) + return formatting.DoubleQuote(content) + + +def GetStringTypeDescription(obj, available_space, line_length): + """Returns the predefined description for string obj. + + This function constructs a description for string type objects in the format + of 'The string ""'. could be potentially + truncated depending on whether it has enough space available to show the full + string value. + + Args: + obj: The object to generate description for. + available_space: Number of character spaces available. + line_length: The full width of the terminal, default if 80. + + Returns: + A description for input object. + """ + additional_len_needed = len(STRING_DESC_PREFIX) + len( + TWO_DOUBLE_QUOTES) + len(formatting.ELLIPSIS) + if available_space < additional_len_needed: + available_space = line_length + + return STRING_DESC_PREFIX + formatting.DoubleQuote( + formatting.EllipsisTruncate( + obj, available_space - len(STRING_DESC_PREFIX) - + len(TWO_DOUBLE_QUOTES), line_length)) + + +CUSTOM_DESC_SUM_FN_DICT = { + 'str': (GetStringTypeSummary, GetStringTypeDescription), + 'unicode': (GetStringTypeSummary, GetStringTypeDescription), +} + + +def GetSummary(obj, available_space, line_length): + obj_type_name = type(obj).__name__ + if obj_type_name in CUSTOM_DESC_SUM_FN_DICT: + return CUSTOM_DESC_SUM_FN_DICT.get(obj_type_name)[0](obj, available_space, + line_length) + return None + + +def GetDescription(obj, available_space, line_length): + obj_type_name = type(obj).__name__ + if obj_type_name in CUSTOM_DESC_SUM_FN_DICT: + return CUSTOM_DESC_SUM_FN_DICT.get(obj_type_name)[1](obj, available_space, + line_length) + return None diff --git a/fire/custom_descriptions_test.py b/fire/custom_descriptions_test.py new file mode 100644 index 00000000..6cff2d5d --- /dev/null +++ b/fire/custom_descriptions_test.py @@ -0,0 +1,69 @@ +# Copyright (C) 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for custom description module.""" + +from fire import custom_descriptions +from fire import testutils + +LINE_LENGTH = 80 + + +class CustomDescriptionTest(testutils.BaseTestCase): + + def test_string_type_summary_enough_space(self): + component = 'Test' + summary = custom_descriptions.GetSummary( + obj=component, available_space=80, line_length=LINE_LENGTH) + self.assertEqual(summary, '"Test"') + + def test_string_type_summary_not_enough_space_truncated(self): + component = 'Test' + summary = custom_descriptions.GetSummary( + obj=component, available_space=5, line_length=LINE_LENGTH) + self.assertEqual(summary, '"..."') + + def test_string_type_summary_not_enough_space_new_line(self): + component = 'Test' + summary = custom_descriptions.GetSummary( + obj=component, available_space=4, line_length=LINE_LENGTH) + self.assertEqual(summary, '"Test"') + + def test_string_type_summary_not_enough_space_long_truncated(self): + component = 'Lorem ipsum dolor sit amet' + summary = custom_descriptions.GetSummary( + obj=component, available_space=10, line_length=LINE_LENGTH) + self.assertEqual(summary, '"Lorem..."') + + def test_string_type_description_enough_space(self): + component = 'Test' + description = custom_descriptions.GetDescription( + obj=component, available_space=80, line_length=LINE_LENGTH) + self.assertEqual(description, 'The string "Test"') + + def test_string_type_description_not_enough_space_truncated(self): + component = 'Lorem ipsum dolor sit amet' + description = custom_descriptions.GetDescription( + obj=component, available_space=20, line_length=LINE_LENGTH) + self.assertEqual(description, 'The string "Lore..."') + + def test_string_type_description_not_enough_space_new_line(self): + component = 'Lorem ipsum dolor sit amet' + description = custom_descriptions.GetDescription( + obj=component, available_space=10, line_length=LINE_LENGTH) + self.assertEqual(description, 'The string "Lorem ipsum dolor sit amet"') + + +if __name__ == '__main__': + testutils.main() diff --git a/fire/decorators.py b/fire/decorators.py index 168312d4..914b1de6 100644 --- a/fire/decorators.py +++ b/fire/decorators.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,10 +18,7 @@ command line arguments to client code. """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - +from typing import Any, Dict import inspect FIRE_METADATA = 'FIRE_METADATA' @@ -60,7 +57,7 @@ def SetParseFns(*positional, **named): Python arguments with which to call the function. A parse function should accept a single string argument and return a value to - be used in it's place when calling the decorated function. + be used in its place when calling the decorated function. Args: *positional: The functions to be used for parsing positional arguments. @@ -71,7 +68,7 @@ def SetParseFns(*positional, **named): def _Decorator(fn): parse_fns = GetParseFns(fn) parse_fns['positional'] = positional - parse_fns['named'].update(named) + parse_fns['named'].update(named) # pytype: disable=attribute-error _SetMetadata(fn, FIRE_PARSE_FNS, parse_fns) return fn @@ -84,14 +81,30 @@ def _SetMetadata(fn, attribute, value): setattr(fn, FIRE_METADATA, metadata) -def GetMetadata(fn): +def GetMetadata(fn) -> Dict[str, Any]: + """Gets metadata attached to the function `fn` as an attribute. + + Args: + fn: The function from which to retrieve the function metadata. + Returns: + A dictionary mapping property strings to their value. + """ + # Class __init__ functions and object __call__ functions require flag style + # arguments. Other methods and functions may accept positional args. default = { - ACCEPTS_POSITIONAL_ARGS: not inspect.isclass(fn), + ACCEPTS_POSITIONAL_ARGS: inspect.isroutine(fn), } - return getattr(fn, FIRE_METADATA, default) + try: + metadata = getattr(fn, FIRE_METADATA, default) + if ACCEPTS_POSITIONAL_ARGS in metadata: + return metadata + else: + return default + except: # pylint: disable=bare-except + return default -def GetParseFns(fn): +def GetParseFns(fn) -> Dict[str, Any]: metadata = GetMetadata(fn) - default = dict(default=None, positional=[], named={}) + default = {'default': None, 'positional': [], 'named': {}} return metadata.get(FIRE_PARSE_FNS, default) diff --git a/fire/decorators_test.py b/fire/decorators_test.py index 15ec1ddd..9988743c 100644 --- a/fire/decorators_test.py +++ b/fire/decorators_test.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,17 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +"""Tests for the decorators module.""" from fire import core from fire import decorators +from fire import testutils -import unittest - -class A(object): +class NoDefaults: + """A class for testing decorated functions without default values.""" @decorators.SetParseFns(count=int) def double(self, count): @@ -42,7 +40,7 @@ def double(count): return 2 * count -class B(object): +class WithDefaults: @decorators.SetParseFns(float) def example1(self, arg1=10): @@ -53,14 +51,14 @@ def example2(self, arg1=10): return arg1, type(arg1) -class C(object): +class MixedArguments: @decorators.SetParseFns(float, arg2=str) def example3(self, arg1, arg2): return arg1, arg2 -class D(object): +class PartialParseFn: @decorators.SetParseFns(arg1=str) def example4(self, arg1, arg2): @@ -71,7 +69,7 @@ def example5(self, arg1, arg2): return arg1, arg2 -class E(object): +class WithKwargs: @decorators.SetParseFns(mode=str, count=int) def example6(self, **kwargs): @@ -81,70 +79,92 @@ def example6(self, **kwargs): ) -class F(object): +class WithVarArgs: @decorators.SetParseFn(str) - def example7(self, arg1, arg2=None, *varargs, **kwargs): + def example7(self, arg1, arg2=None, *varargs, **kwargs): # pylint: disable=keyword-arg-before-vararg return arg1, arg2, varargs, kwargs -class FireDecoratorsTest(unittest.TestCase): +class FireDecoratorsTest(testutils.BaseTestCase): def testSetParseFnsNamedArgs(self): - self.assertEqual(core.Fire(A, 'double 2'), 4) - self.assertEqual(core.Fire(A, 'triple 4'), 12.0) + self.assertEqual(core.Fire(NoDefaults, command=['double', '2']), 4) + self.assertEqual(core.Fire(NoDefaults, command=['triple', '4']), 12.0) def testSetParseFnsPositionalArgs(self): - self.assertEqual(core.Fire(A, 'quadruple 5'), 20) + self.assertEqual(core.Fire(NoDefaults, command=['quadruple', '5']), 20) def testSetParseFnsFnWithPositionalArgs(self): - self.assertEqual(core.Fire(double, '5'), 10) + self.assertEqual(core.Fire(double, command=['5']), 10) def testSetParseFnsDefaultsFromPython(self): # When called from Python, function should behave normally. - self.assertTupleEqual(B().example1(), (10, int)) - self.assertEqual(B().example1(5), (5, int)) - self.assertEqual(B().example1(12.0), (12, float)) + self.assertTupleEqual(WithDefaults().example1(), (10, int)) + self.assertEqual(WithDefaults().example1(5), (5, int)) + self.assertEqual(WithDefaults().example1(12.0), (12, float)) def testSetParseFnsDefaultsFromFire(self): # Fire should use the decorator to know how to parse string arguments. - self.assertEqual(core.Fire(B, 'example1'), (10, int)) - self.assertEqual(core.Fire(B, 'example1 10'), (10, float)) - self.assertEqual(core.Fire(B, 'example1 13'), (13, float)) - self.assertEqual(core.Fire(B, 'example1 14.0'), (14, float)) + self.assertEqual(core.Fire(WithDefaults, command=['example1']), (10, int)) + self.assertEqual(core.Fire(WithDefaults, command=['example1', '10']), + (10, float)) + self.assertEqual(core.Fire(WithDefaults, command=['example1', '13']), + (13, float)) + self.assertEqual(core.Fire(WithDefaults, command=['example1', '14.0']), + (14, float)) def testSetParseFnsNamedDefaultsFromPython(self): # When called from Python, function should behave normally. - self.assertTupleEqual(B().example2(), (10, int)) - self.assertEqual(B().example2(5), (5, int)) - self.assertEqual(B().example2(12.0), (12, float)) + self.assertTupleEqual(WithDefaults().example2(), (10, int)) + self.assertEqual(WithDefaults().example2(5), (5, int)) + self.assertEqual(WithDefaults().example2(12.0), (12, float)) def testSetParseFnsNamedDefaultsFromFire(self): # Fire should use the decorator to know how to parse string arguments. - self.assertEqual(core.Fire(B, 'example2'), (10, int)) - self.assertEqual(core.Fire(B, 'example2 10'), (10, float)) - self.assertEqual(core.Fire(B, 'example2 13'), (13, float)) - self.assertEqual(core.Fire(B, 'example2 14.0'), (14, float)) + self.assertEqual(core.Fire(WithDefaults, command=['example2']), (10, int)) + self.assertEqual(core.Fire(WithDefaults, command=['example2', '10']), + (10, float)) + self.assertEqual(core.Fire(WithDefaults, command=['example2', '13']), + (13, float)) + self.assertEqual(core.Fire(WithDefaults, command=['example2', '14.0']), + (14, float)) def testSetParseFnsPositionalAndNamed(self): - self.assertEqual(core.Fire(C, 'example3 10 10'), (10, '10')) + self.assertEqual(core.Fire(MixedArguments, ['example3', '10', '10']), + (10, '10')) def testSetParseFnsOnlySomeTypes(self): - self.assertEqual(core.Fire(D, 'example4 10 10'), ('10', 10)) - self.assertEqual(core.Fire(D, 'example5 10 10'), (10, '10')) + self.assertEqual( + core.Fire(PartialParseFn, command=['example4', '10', '10']), ('10', 10)) + self.assertEqual( + core.Fire(PartialParseFn, command=['example5', '10', '10']), (10, '10')) def testSetParseFnsForKeywordArgs(self): - self.assertEqual(core.Fire(E, 'example6'), ('default', 0)) - self.assertEqual(core.Fire(E, 'example6 --herring "red"'), ('default', 0)) - self.assertEqual(core.Fire(E, 'example6 --mode train'), ('train', 0)) - self.assertEqual(core.Fire(E, 'example6 --mode 3'), ('3', 0)) - self.assertEqual(core.Fire(E, 'example6 --mode -1 --count 10'), ('-1', 10)) - self.assertEqual(core.Fire(E, 'example6 --count -2'), ('default', -2)) + self.assertEqual( + core.Fire(WithKwargs, command=['example6']), ('default', 0)) + self.assertEqual( + core.Fire(WithKwargs, command=['example6', '--herring', '"red"']), + ('default', 0)) + self.assertEqual( + core.Fire(WithKwargs, command=['example6', '--mode', 'train']), + ('train', 0)) + self.assertEqual(core.Fire(WithKwargs, command=['example6', '--mode', '3']), + ('3', 0)) + self.assertEqual( + core.Fire(WithKwargs, + command=['example6', '--mode', '-1', '--count', '10']), + ('-1', 10)) + self.assertEqual( + core.Fire(WithKwargs, command=['example6', '--count', '-2']), + ('default', -2)) def testSetParseFn(self): - self.assertEqual(core.Fire(F, 'example7 1 --arg2=2 3 4 --kwarg=5'), - ('1', '2', ('3', '4'), {'kwarg': '5'})) + self.assertEqual( + core.Fire(WithVarArgs, + command=['example7', '1', '--arg2=2', '3', '4', '--kwarg=5']), + ('1', '2', ('3', '4'), {'kwarg': '5'})) if __name__ == '__main__': - unittest.main() + testutils.main() diff --git a/fire/docstrings.py b/fire/docstrings.py new file mode 100644 index 00000000..2d7c7e63 --- /dev/null +++ b/fire/docstrings.py @@ -0,0 +1,774 @@ +# Copyright (C) 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Docstring parsing module for Python Fire. + +The following features of docstrings are not supported. +TODO(dbieber): Support these features. +- numpy docstrings may begin with the function signature. +- whitespace may be important for proper structuring of a docstring +- I've seen `argname` (with single backticks) as a style of documenting + arguments. The `argname` appears on one line, and the description on the next. +- .. Sphinx directives such as .. note:: are not understood. +- After a section ends, future contents may be included in the section. E.g. + :returns: This is what is returned. + Example: An example goes here. +- @param is sometimes used. E.g. + @param argname (type) Description + @return (type) Description +- The true signature of a function is not used by the docstring parser. It could + be useful for determining whether something is a section header or an argument + for example. +- This example confuses types as part of the docstrings. + Parameters + argname : argtype + Arg description +- If there's no blank line after the summary, the description will be slurped + up into the summary. +- "Examples" should be its own section type. aka "Usage". +- "Notes" should be a section type. +- Some people put parenthesis around their types in RST format, e.g. + :param (type) paramname: +- :rtype: directive (return type) +- Also ":rtype str" with no closing ":" has come up. +- Return types are not supported. +- "# Returns" as a section title style +- ":raises ExceptionType: Description" ignores the ExceptionType currently. +- "Defaults to X" occurs sometimes. +- "True | False" indicates bool type. +""" + +import collections +import enum +import re +import textwrap + + +class DocstringInfo( + collections.namedtuple( + 'DocstringInfo', + ('summary', 'description', 'args', 'returns', 'yields', 'raises'))): + pass +DocstringInfo.__new__.__defaults__ = (None,) * len(DocstringInfo._fields) + + +class ArgInfo( + collections.namedtuple( + 'ArgInfo', + ('name', 'type', 'description'))): + pass +ArgInfo.__new__.__defaults__ = (None,) * len(ArgInfo._fields) + + +class KwargInfo(ArgInfo): + pass +KwargInfo.__new__.__defaults__ = (None,) * len(KwargInfo._fields) + + +class Namespace(dict): + """A dict with attribute (dot-notation) access enabled.""" + + def __getattr__(self, key): + if key not in self: + self[key] = Namespace() + return self[key] + + def __setattr__(self, key, value): + self[key] = value + + def __delattr__(self, key): + if key in self: + del self[key] + + +class Sections(enum.Enum): + ARGS = 0 + RETURNS = 1 + YIELDS = 2 + RAISES = 3 + TYPE = 4 + + +class Formats(enum.Enum): + GOOGLE = 0 + NUMPY = 1 + RST = 2 + + +SECTION_TITLES = { + Sections.ARGS: ('argument', 'arg', 'parameter', 'param', 'key'), + Sections.RETURNS: ('return',), + Sections.YIELDS: ('yield',), + Sections.RAISES: ('raise', 'except', 'exception', 'throw', 'error', 'warn'), + Sections.TYPE: ('type',), # rst-only +} + + +def parse(docstring): + """Returns DocstringInfo about the given docstring. + + This parser aims to parse Google, numpy, and rst formatted docstrings. These + are the three most common docstring styles at the time of this writing. + + This parser aims to be permissive, working even when the docstring deviates + from the strict recommendations of these styles. + + This parser does not aim to fully extract all structured information from a + docstring, since there are simply too many ways to structure information in a + docstring. Sometimes content will remain as unstructured text and simply gets + included in the description. + + The Google docstring style guide is available at: + https://github.com/google/styleguide/blob/gh-pages/pyguide.md + + The numpy docstring style guide is available at: + https://numpydoc.readthedocs.io/en/latest/format.html + + Information about the rST docstring format is available at: + https://www.python.org/dev/peps/pep-0287/ + The full set of directives such as param and type for rST docstrings are at: + http://www.sphinx-doc.org/en/master/usage/restructuredtext/domains.html + + Note: This function does not claim to handle all docstrings well. A list of + limitations is available at the top of the file. It does aim to run without + crashing in O(n) time on all strings on length n. If you find a string that + causes this to crash or run unacceptably slowly, please consider submitting + a pull request. + + Args: + docstring: The docstring to parse. + + Returns: + A DocstringInfo containing information about the docstring. + """ + if docstring is None: + return DocstringInfo() + + lines = docstring.strip().split('\n') + lines_len = len(lines) + state = Namespace() # TODO(dbieber): Switch to an explicit class. + + # Variables in state include: + state.section.title = None + state.section.indentation = None + state.section.line1_indentation = None + state.section.format = None + state.summary.permitted = True + state.summary.lines = [] + state.description.lines = [] + state.args = [] + state.kwargs = [] + state.current_arg = None + state.returns.lines = [] + state.yields.lines = [] + state.raises.lines = [] + + for index, line in enumerate(lines): + has_next = index + 1 < lines_len + previous_line = lines[index - 1] if index > 0 else None + next_line = lines[index + 1] if has_next else None + line_info = _create_line_info(line, next_line, previous_line) + _consume_line(line_info, state) + + summary = ' '.join(state.summary.lines) if state.summary.lines else None + state.description.lines = _strip_blank_lines(state.description.lines) + description = textwrap.dedent('\n'.join(state.description.lines)) + if not description: + description = None + returns = _join_lines(state.returns.lines) + yields = _join_lines(state.yields.lines) + raises = _join_lines(state.raises.lines) + + args = [ArgInfo( + name=arg.name, type=_cast_to_known_type(_join_lines(arg.type.lines)), + description=_join_lines(arg.description.lines)) for arg in state.args] + + args.extend([KwargInfo( + name=arg.name, type=_cast_to_known_type(_join_lines(arg.type.lines)), + description=_join_lines(arg.description.lines)) for arg in state.kwargs]) + + return DocstringInfo( + summary=summary, + description=description, + args=args or None, + returns=returns, + raises=raises, + yields=yields, + ) + + +def _strip_blank_lines(lines): + """Removes lines containing only blank characters before and after the text. + + Args: + lines: A list of lines. + Returns: + A list of lines without trailing or leading blank lines. + """ + # Find the first non-blank line. + start = 0 + num_lines = len(lines) + while lines and start < num_lines and _is_blank(lines[start]): + start += 1 + + lines = lines[start:] + + # Remove trailing blank lines. + while lines and _is_blank(lines[-1]): + lines.pop() + + return lines + + +def _is_blank(line): + return not line or line.isspace() + + +def _join_lines(lines): + """Joins lines with the appropriate connective whitespace. + + This puts a single space between consecutive lines, unless there's a blank + line, in which case a full blank line is included. + + Args: + lines: A list of lines to join. + Returns: + A string, the lines joined together. + """ + # TODO(dbieber): Add parameters for variations in whitespace handling. + if not lines: + return None + + started = False + group_texts = [] # Full text of each section. + group_lines = [] # Lines within the current section. + for line in lines: + stripped_line = line.strip() + if stripped_line: + started = True + group_lines.append(stripped_line) + else: + if started: + group_text = ' '.join(group_lines) + group_texts.append(group_text) + group_lines = [] + + if group_lines: # Process the final group. + group_text = ' '.join(group_lines) + group_texts.append(group_text) + + return '\n\n'.join(group_texts) + + +def _get_or_create_arg_by_name(state, name, is_kwarg=False): + """Gets or creates a new Arg. + + These Arg objects (Namespaces) are turned into the ArgInfo namedtuples + returned by parse. Each Arg object is used to collect the name, type, and + description of a single argument to the docstring's function. + + Args: + state: The state of the parser. + name: The name of the arg to create. + is_kwarg: A boolean representing whether the argument is a keyword arg. + Returns: + The new Arg. + """ + for arg in state.args + state.kwargs: + if arg.name == name: + return arg + arg = Namespace() # TODO(dbieber): Switch to an explicit class. + arg.name = name + arg.type.lines = [] + arg.description.lines = [] + if is_kwarg: + state.kwargs.append(arg) + else: + state.args.append(arg) + return arg + + +def _is_arg_name(name): + """Returns whether name is a valid arg name. + + This is used to prevent multiple words (plaintext) from being misinterpreted + as an argument name. Any line that doesn't match the pattern for a valid + argument is treated as not being an argument. + + Args: + name: The name of the potential arg. + Returns: + True if name looks like an arg name, False otherwise. + """ + name = name.strip() + # arg_pattern is a letter or underscore followed by + # zero or more letters, numbers, or underscores. + arg_pattern = r'^[a-zA-Z_]\w*$' + re.match(arg_pattern, name) + return re.match(arg_pattern, name) is not None + + +def _as_arg_name_and_type(text): + """Returns text as a name and type, if text looks like an arg name and type. + + Example: + _as_arg_name_and_type("foo (int)") == "foo", "int" + + Args: + text: The text, which may or may not be an arg name and type. + Returns: + The arg name and type, if text looks like an arg name and type. + None otherwise. + """ + tokens = text.split() + if len(tokens) < 2: + return None + if _is_arg_name(tokens[0]): + type_token = ' '.join(tokens[1:]) + type_token = type_token.lstrip('{([').rstrip('])}') + return tokens[0], type_token + else: + return None + + +def _as_arg_names(names_str): + """Converts names_str to a list of arg names. + + Example: + _as_arg_names("a, b, c") == ["a", "b", "c"] + + Args: + names_str: A string with multiple space or comma separated arg names. + Returns: + A list of arg names, or None if names_str doesn't look like a list of arg + names. + """ + names = re.split(',| ', names_str) + names = [name.strip() for name in names if name.strip()] + for name in names: + if not _is_arg_name(name): + return None + if not names: + return None + return names + + +def _cast_to_known_type(name): + """Canonicalizes a string representing a type if possible. + + # TODO(dbieber): Support additional canonicalization, such as string/str, and + # boolean/bool. + + Example: + _cast_to_known_type("str.") == "str" + + Args: + name: A string representing a type, or None. + Returns: + A canonicalized version of the type string. + """ + if name is None: + return None + return name.rstrip('.') + + +def _consume_google_args_line(line_info, state): + """Consume a single line from a Google args section.""" + split_line = line_info.remaining.split(':', 1) + if len(split_line) > 1: + first, second = split_line # first is either the "arg" or "arg (type)" + if _is_arg_name(first.strip()): + arg = _get_or_create_arg_by_name(state, first.strip()) + arg.description.lines.append(second.strip()) + state.current_arg = arg + else: + arg_name_and_type = _as_arg_name_and_type(first) + if arg_name_and_type: + arg_name, type_str = arg_name_and_type + arg = _get_or_create_arg_by_name(state, arg_name) + arg.type.lines.append(type_str) + arg.description.lines.append(second.strip()) + state.current_arg = arg + else: + if state.current_arg: + state.current_arg.description.lines.append(split_line[0]) + else: + if state.current_arg: + state.current_arg.description.lines.append(split_line[0]) + + +def _consume_line(line_info, state): + """Consumes one line of text, updating the state accordingly. + + When _consume_line is called, part of the line may already have been processed + for header information. + + Args: + line_info: Information about the current and next line of the docstring. + state: The state of the docstring parser. + """ + _update_section_state(line_info, state) + + if state.section.title is None: + if state.summary.permitted: + if line_info.remaining: + state.summary.lines.append(line_info.remaining) + elif state.summary.lines: + state.summary.permitted = False + else: + # We're past the end of the summary. + # Additions now contribute to the description. + state.description.lines.append(line_info.remaining_raw) + else: + state.summary.permitted = False + + if state.section.new and state.section.format == Formats.RST: + # The current line starts with an RST directive, e.g. ":param arg:". + directive = _get_directive(line_info) + directive_tokens = directive.split() # pytype: disable=attribute-error + if state.section.title == Sections.ARGS: + name = directive_tokens[-1] + arg = _get_or_create_arg_by_name( + state, + name, + is_kwarg=directive_tokens[0] == 'key' + ) + if len(directive_tokens) == 3: + # A param directive of the form ":param type arg:". + arg.type.lines.append(directive_tokens[1]) + state.current_arg = arg + elif state.section.title == Sections.TYPE: + name = directive_tokens[-1] + arg = _get_or_create_arg_by_name(state, name) + state.current_arg = arg + + if (state.section.format == Formats.NUMPY and + _line_is_hyphens(line_info.remaining)): + # Skip this all-hyphens line, which is part of the numpy section header. + return + + if state.section.title == Sections.ARGS: + if state.section.format == Formats.GOOGLE: + _consume_google_args_line(line_info, state) + elif state.section.format == Formats.RST: + state.current_arg.description.lines.append(line_info.remaining.strip()) + elif state.section.format == Formats.NUMPY: + line_stripped = line_info.remaining.strip() + if _is_arg_name(line_stripped): + # Token on its own line can either be the last word of the description + # of the previous arg, or a new arg. TODO: Whitespace can distinguish. + arg = _get_or_create_arg_by_name(state, line_stripped) + state.current_arg = arg + elif _line_is_numpy_parameter_type(line_info): + possible_args, type_data = line_stripped.split(':', 1) + arg_names = _as_arg_names(possible_args) # re.split(' |,', s) + if arg_names: + for arg_name in arg_names: + arg = _get_or_create_arg_by_name(state, arg_name) + arg.type.lines.append(type_data) + state.current_arg = arg # TODO(dbieber): Multiple current args. + else: # Just an ordinary line. + if state.current_arg: + state.current_arg.description.lines.append( + line_info.remaining.strip()) + else: + # TODO(dbieber): If not a blank line, add it to the description. + pass + else: # Just an ordinary line. + if state.current_arg: + state.current_arg.description.lines.append( + line_info.remaining.strip()) + else: + # TODO(dbieber): If not a blank line, add it to the description. + pass + + elif state.section.title == Sections.RETURNS: + state.returns.lines.append(line_info.remaining.strip()) + elif state.section.title == Sections.YIELDS: + state.yields.lines.append(line_info.remaining.strip()) + elif state.section.title == Sections.RAISES: + state.raises.lines.append(line_info.remaining.strip()) + elif state.section.title == Sections.TYPE: + if state.section.format == Formats.RST: + assert state.current_arg is not None + state.current_arg.type.lines.append(line_info.remaining.strip()) + else: + pass + + +def _create_line_info(line, next_line, previous_line): + """Returns information about the current line and surrounding lines.""" + line_info = Namespace() # TODO(dbieber): Switch to an explicit class. + line_info.line = line + line_info.stripped = line.strip() + line_info.remaining_raw = line_info.line + line_info.remaining = line_info.stripped + line_info.indentation = len(line) - len(line.lstrip()) + # TODO(dbieber): If next_line is blank, use the next non-blank line. + line_info.next.line = next_line + next_line_exists = next_line is not None + line_info.next.stripped = next_line.strip() if next_line_exists else None + line_info.next.indentation = ( + len(next_line) - len(next_line.lstrip()) if next_line_exists else None) + line_info.previous.line = previous_line + previous_line_exists = previous_line is not None + line_info.previous.indentation = ( + len(previous_line) - + len(previous_line.lstrip()) if previous_line_exists else None) + # Note: This counts all whitespace equally. + return line_info + + +def _update_section_state(line_info, state): + """Uses line_info to determine the current section of the docstring. + + Updates state and line_info.remaining. + + Args: + line_info: Information about the current line. + state: The state of the parser. + """ + section_updated = False + + google_section_permitted = _google_section_permitted(line_info, state) + google_section = google_section_permitted and _google_section(line_info) + if google_section: + state.section.format = Formats.GOOGLE + state.section.title = google_section + line_info.remaining = _get_after_google_header(line_info) + line_info.remaining_raw = line_info.remaining + section_updated = True + + rst_section = _rst_section(line_info) + if rst_section: + state.section.format = Formats.RST + state.section.title = rst_section + line_info.remaining = _get_after_directive(line_info) + line_info.remaining_raw = line_info.remaining + section_updated = True + + numpy_section = _numpy_section(line_info) + if numpy_section: + state.section.format = Formats.NUMPY + state.section.title = numpy_section + line_info.remaining = '' + line_info.remaining_raw = line_info.remaining + section_updated = True + + if section_updated: + state.section.new = True + state.section.indentation = line_info.indentation + state.section.line1_indentation = line_info.next.indentation + else: + state.section.new = False + + +def _google_section_permitted(line_info, state): + """Returns whether a new google section is permitted to start here. + + Q: Why might a new Google section not be allowed? + A: If we're in the middle of a Google "Args" section, then lines that start + "param:" will usually be a new arg, rather than a new section. + We use whitespace to determine when the Args section has actually ended. + + A Google section ends when either: + - A new google section begins at either + - indentation less than indentation of line 1 of the previous section + - or <= indentation of the previous section + - Or the docstring terminates. + + Args: + line_info: Information about the current line. + state: The state of the parser. + Returns: + True or False, indicating whether a new Google section is permitted at the + current line. + """ + if state.section.indentation is None: # We're not in a section yet. + return True + return (line_info.indentation <= state.section.indentation + or line_info.indentation < state.section.line1_indentation) + + +def _matches_section_title(title, section_title): + """Returns whether title is a match for a specific section_title. + + Example: + _matches_section_title('Yields', 'yield') == True + + Args: + title: The title to check for matching. + section_title: A specific known section title to check against. + """ + title = title.lower() + section_title = section_title.lower() + return section_title in (title, title[:-1]) # Supports plurals / some typos. + + +def _matches_section(title, section): + """Returns whether title is a match any known title for a specific section. + + Example: + _matches_section_title('Yields', Sections.YIELDS) == True + _matches_section_title('param', Sections.Args) == True + + Args: + title: The title to check for matching. + section: A specific section to check all possible titles for. + Returns: + True or False, indicating whether title is a match for the specified + section. + """ + for section_title in SECTION_TITLES[section]: + if _matches_section_title(title, section_title): + return True + return False + + +def _section_from_possible_title(possible_title): + """Returns a section matched by the possible title, or None if none match. + + Args: + possible_title: A string that may be the title of a new section. + Returns: + A Section type if one matches, or None if no section type matches. + """ + for section in SECTION_TITLES: + if _matches_section(possible_title, section): + return section + return None + + +def _google_section(line_info): + """Checks whether the current line is the start of a new Google-style section. + + This docstring is a Google-style docstring. Google-style sections look like + this: + + Section Name: + section body goes here + + Args: + line_info: Information about the current line. + Returns: + A Section type if one matches, or None if no section type matches. + """ + colon_index = line_info.remaining.find(':') + possible_title = line_info.remaining[:colon_index] + return _section_from_possible_title(possible_title) + + +def _get_after_google_header(line_info): + """Gets the remainder of the line, after a Google header.""" + colon_index = line_info.remaining.find(':') + return line_info.remaining[colon_index + 1:] + + +def _get_directive(line_info): + """Gets a directive from the start of the line. + + If the line is ":param str foo: Description of foo", then + _get_directive(line_info) returns "param str foo". + + Args: + line_info: Information about the current line. + Returns: + The contents of a directive, or None if the line doesn't start with a + directive. + """ + if line_info.stripped.startswith(':'): + return line_info.stripped.split(':', 2)[1] + else: + return None + + +def _get_after_directive(line_info): + """Gets the remainder of the line, after a directive.""" + sections = line_info.stripped.split(':', 2) + if len(sections) > 2: + return sections[-1] + else: + return '' + + +def _rst_section(line_info): + """Checks whether the current line is the start of a new RST-style section. + + RST uses directives to specify information. An RST directive, which we refer + to as a section here, are surrounded with colons. For example, :param name:. + + Args: + line_info: Information about the current line. + Returns: + A Section type if one matches, or None if no section type matches. + """ + directive = _get_directive(line_info) + if directive: + possible_title = directive.split()[0] + return _section_from_possible_title(possible_title) + else: + return None + + +def _line_is_hyphens(line): + """Returns whether the line is entirely hyphens (and not blank).""" + return line and not line.strip('-') + + +def _numpy_section(line_info): + """Checks whether the current line is the start of a new numpy-style section. + + Numpy style sections are followed by a full line of hyphens, for example: + + Section Name + ------------ + Section body goes here. + + Args: + line_info: Information about the current line. + Returns: + A Section type if one matches, or None if no section type matches. + """ + next_line_is_hyphens = _line_is_hyphens(line_info.next.stripped) + if next_line_is_hyphens: + possible_title = line_info.remaining + return _section_from_possible_title(possible_title) + else: + return None + + +def _line_is_numpy_parameter_type(line_info): + """Returns whether the line contains a numpy style parameter type definition. + + We look for a line of the form: + x : type + + And we have to exclude false positives on argument descriptions containing a + colon by checking the indentation of the line above. + + Args: + line_info: Information about the current line. + Returns: + True if the line is a numpy parameter type definition, False otherwise. + """ + line_stripped = line_info.remaining.strip() + if ':' in line_stripped: + previous_indent = line_info.previous.indentation + current_indent = line_info.indentation + if ':' in line_info.previous.line and current_indent > previous_indent: + # The parameter type was the previous line; this is the description. + return False + else: + return True + return False diff --git a/fire/docstrings_fuzz_test.py b/fire/docstrings_fuzz_test.py new file mode 100644 index 00000000..66be8006 --- /dev/null +++ b/fire/docstrings_fuzz_test.py @@ -0,0 +1,36 @@ +# Copyright (C) 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fuzz tests for the docstring parser module.""" + +from fire import docstrings +from fire import testutils + +from hypothesis import example +from hypothesis import given +from hypothesis import settings +from hypothesis import strategies as st + + +class DocstringsFuzzTest(testutils.BaseTestCase): + + @settings(max_examples=1000, deadline=1000) + @given(st.text(min_size=1)) + @example('This is a one-line docstring.') + def test_fuzz_parse(self, value): + docstrings.parse(value) + + +if __name__ == '__main__': + testutils.main() diff --git a/fire/docstrings_test.py b/fire/docstrings_test.py new file mode 100644 index 00000000..ce516944 --- /dev/null +++ b/fire/docstrings_test.py @@ -0,0 +1,361 @@ +# Copyright (C) 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for fire docstrings module.""" + +from fire import docstrings +from fire import testutils + +# pylint: disable=invalid-name +DocstringInfo = docstrings.DocstringInfo +ArgInfo = docstrings.ArgInfo +KwargInfo = docstrings.KwargInfo +# pylint: enable=invalid-name + + +class DocstringsTest(testutils.BaseTestCase): + + def test_one_line_simple(self): + docstring = """A simple one line docstring.""" + docstring_info = docstrings.parse(docstring) + expected_docstring_info = DocstringInfo( + summary='A simple one line docstring.', + ) + self.assertEqual(expected_docstring_info, docstring_info) + + def test_one_line_simple_whitespace(self): + docstring = """ + A simple one line docstring. + """ + docstring_info = docstrings.parse(docstring) + expected_docstring_info = DocstringInfo( + summary='A simple one line docstring.', + ) + self.assertEqual(expected_docstring_info, docstring_info) + + def test_one_line_too_long(self): + # pylint: disable=line-too-long + docstring = """A one line docstring that is both a little too verbose and a little too long so it keeps going well beyond a reasonable length for a one-liner. + """ + # pylint: enable=line-too-long + docstring_info = docstrings.parse(docstring) + expected_docstring_info = DocstringInfo( + summary='A one line docstring that is both a little too verbose and ' + 'a little too long so it keeps going well beyond a reasonable length ' + 'for a one-liner.', + ) + self.assertEqual(expected_docstring_info, docstring_info) + + def test_one_line_runs_over(self): + # pylint: disable=line-too-long + docstring = """A one line docstring that is both a little too verbose and a little too long + so it runs onto a second line. + """ + # pylint: enable=line-too-long + docstring_info = docstrings.parse(docstring) + expected_docstring_info = DocstringInfo( + summary='A one line docstring that is both a little too verbose and ' + 'a little too long so it runs onto a second line.', + ) + self.assertEqual(expected_docstring_info, docstring_info) + + def test_one_line_runs_over_whitespace(self): + docstring = """ + A one line docstring that is both a little too verbose and a little too long + so it runs onto a second line. + """ + docstring_info = docstrings.parse(docstring) + expected_docstring_info = DocstringInfo( + summary='A one line docstring that is both a little too verbose and ' + 'a little too long so it runs onto a second line.', + ) + self.assertEqual(expected_docstring_info, docstring_info) + + def test_google_format_args_only(self): + docstring = """One line description. + + Args: + arg1: arg1_description + arg2: arg2_description + """ + docstring_info = docstrings.parse(docstring) + expected_docstring_info = DocstringInfo( + summary='One line description.', + args=[ + ArgInfo(name='arg1', description='arg1_description'), + ArgInfo(name='arg2', description='arg2_description'), + ] + ) + self.assertEqual(expected_docstring_info, docstring_info) + + def test_google_format_arg_named_args(self): + docstring = """ + Args: + args: arg_description + """ + docstring_info = docstrings.parse(docstring) + expected_docstring_info = DocstringInfo( + args=[ + ArgInfo(name='args', description='arg_description'), + ] + ) + self.assertEqual(expected_docstring_info, docstring_info) + + def test_google_format_typed_args_and_returns(self): + docstring = """Docstring summary. + + This is a longer description of the docstring. It spans multiple lines, as + is allowed. + + Args: + param1 (int): The first parameter. + param2 (str): The second parameter. + + Returns: + bool: The return value. True for success, False otherwise. + """ + docstring_info = docstrings.parse(docstring) + expected_docstring_info = DocstringInfo( + summary='Docstring summary.', + description='This is a longer description of the docstring. It spans ' + 'multiple lines, as\nis allowed.', + args=[ + ArgInfo(name='param1', type='int', + description='The first parameter.'), + ArgInfo(name='param2', type='str', + description='The second parameter.'), + ], + returns='bool: The return value. True for success, False otherwise.' + ) + self.assertEqual(expected_docstring_info, docstring_info) + + def test_google_format_multiline_arg_description(self): + docstring = """Docstring summary. + + This is a longer description of the docstring. It spans multiple lines, as + is allowed. + + Args: + param1 (int): The first parameter. + param2 (str): The second parameter. This has a lot of text, enough to + cover two lines. + """ + docstring_info = docstrings.parse(docstring) + expected_docstring_info = DocstringInfo( + summary='Docstring summary.', + description='This is a longer description of the docstring. It spans ' + 'multiple lines, as\nis allowed.', + args=[ + ArgInfo(name='param1', type='int', + description='The first parameter.'), + ArgInfo(name='param2', type='str', + description='The second parameter. This has a lot of text, ' + 'enough to cover two lines.'), + ], + ) + self.assertEqual(expected_docstring_info, docstring_info) + + def test_rst_format_typed_args_and_returns(self): + docstring = """Docstring summary. + + This is a longer description of the docstring. It spans across multiple + lines. + + :param arg1: Description of arg1. + :type arg1: str. + :param arg2: Description of arg2. + :type arg2: bool. + :returns: int -- description of the return value. + :raises: AttributeError, KeyError + """ + docstring_info = docstrings.parse(docstring) + expected_docstring_info = DocstringInfo( + summary='Docstring summary.', + description='This is a longer description of the docstring. It spans ' + 'across multiple\nlines.', + args=[ + ArgInfo(name='arg1', type='str', + description='Description of arg1.'), + ArgInfo(name='arg2', type='bool', + description='Description of arg2.'), + ], + returns='int -- description of the return value.', + raises='AttributeError, KeyError', + ) + self.assertEqual(expected_docstring_info, docstring_info) + + def test_numpy_format_typed_args_and_returns(self): + docstring = """Docstring summary. + + This is a longer description of the docstring. It spans across multiple + lines. + + Parameters + ---------- + param1 : int + The first parameter. + param2 : str + The second parameter. + + Returns + ------- + bool + True if successful, False otherwise. + """ + docstring_info = docstrings.parse(docstring) + expected_docstring_info = DocstringInfo( + summary='Docstring summary.', + description='This is a longer description of the docstring. It spans ' + 'across multiple\nlines.', + args=[ + ArgInfo(name='param1', type='int', + description='The first parameter.'), + ArgInfo(name='param2', type='str', + description='The second parameter.'), + ], + # TODO(dbieber): Support return type. + returns='bool True if successful, False otherwise.', + ) + self.assertEqual(expected_docstring_info, docstring_info) + + def test_numpy_format_multiline_arg_description(self): + docstring = """Docstring summary. + + This is a longer description of the docstring. It spans across multiple + lines. + + Parameters + ---------- + param1 : int + The first parameter. + param2 : str + The second parameter. This has a lot of text, enough to cover two + lines. + """ + docstring_info = docstrings.parse(docstring) + expected_docstring_info = DocstringInfo( + summary='Docstring summary.', + description='This is a longer description of the docstring. It spans ' + 'across multiple\nlines.', + args=[ + ArgInfo(name='param1', type='int', + description='The first parameter.'), + ArgInfo(name='param2', type='str', + description='The second parameter. This has a lot of text, ' + 'enough to cover two lines.'), + ], + ) + self.assertEqual(expected_docstring_info, docstring_info) + + def test_multisection_docstring(self): + docstring = """Docstring summary. + + This is the first section of a docstring description. + + This is the second section of a docstring description. This docstring + description has just two sections. + """ + docstring_info = docstrings.parse(docstring) + expected_docstring_info = DocstringInfo( + summary='Docstring summary.', + description='This is the first section of a docstring description.' + '\n\n' + 'This is the second section of a docstring description. This docstring' + '\n' + 'description has just two sections.', + ) + self.assertEqual(expected_docstring_info, docstring_info) + + def test_google_section_with_blank_first_line(self): + docstring = """Inspired by requests HTTPAdapter docstring. + + :param x: Simple param. + + Usage: + + >>> import requests + """ + docstring_info = docstrings.parse(docstring) + self.assertEqual('Inspired by requests HTTPAdapter docstring.', + docstring_info.summary) + + def test_ill_formed_docstring(self): + docstring = """Docstring summary. + + args: raises :: + : + pathological docstrings should not fail, and ideally should behave + reasonably. + """ + docstrings.parse(docstring) + + def test_strip_blank_lines(self): + lines = [' ', ' foo ', ' '] + expected_output = [' foo '] + + self.assertEqual(expected_output, docstrings._strip_blank_lines(lines)) # pylint: disable=protected-access + + def test_numpy_colon_in_description(self): + docstring = """ + Greets name. + + Arguments + --------- + name : str + name, default : World + arg2 : int + arg2, default:None + arg3 : bool + """ + docstring_info = docstrings.parse(docstring) + expected_docstring_info = DocstringInfo( + summary='Greets name.', + description=None, + args=[ + ArgInfo(name='name', type='str', + description='name, default : World'), + ArgInfo(name='arg2', type='int', + description='arg2, default:None'), + ArgInfo(name='arg3', type='bool', description=None), + ] + ) + self.assertEqual(expected_docstring_info, docstring_info) + + def test_rst_format_typed_args_and_kwargs(self): + docstring = """Docstring summary. + + :param arg1: Description of arg1. + :type arg1: str. + :key arg2: Description of arg2. + :type arg2: bool. + :key arg3: Description of arg3. + :type arg3: str. + """ + docstring_info = docstrings.parse(docstring) + expected_docstring_info = DocstringInfo( + summary='Docstring summary.', + args=[ + ArgInfo(name='arg1', type='str', + description='Description of arg1.'), + KwargInfo(name='arg2', type='bool', + description='Description of arg2.'), + KwargInfo(name='arg3', type='str', + description='Description of arg3.'), + ], + ) + self.assertEqual(expected_docstring_info, docstring_info) + + +if __name__ == '__main__': + testutils.main() diff --git a/fire/fire_import_test.py b/fire/fire_import_test.py index f2e64237..a6b4acc3 100644 --- a/fire/fire_import_test.py +++ b/fire/fire_import_test.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,16 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -import fire +"""Tests importing the fire module.""" + +import sys +from unittest import mock -import unittest +import fire +from fire import testutils -class FireImportTest(unittest.TestCase): +class FireImportTest(testutils.BaseTestCase): """Tests importing Fire.""" def testFire(self): - fire.Fire() + with mock.patch.object(sys, 'argv', ['commandname']): + fire.Fire() def testFireMethods(self): self.assertIsNotNone(fire.Fire) @@ -32,4 +37,4 @@ def testNoPrivateMethods(self): if __name__ == '__main__': - unittest.main() + testutils.main() diff --git a/fire/fire_test.py b/fire/fire_test.py index f70aaedf..99b4a7c6 100644 --- a/fire/fire_test.py +++ b/fire/fire_test.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,360 +12,710 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +"""Tests for the fire module.""" + +import os +import sys +from unittest import mock import fire from fire import test_components as tc -from fire import trace - -import unittest +from fire import testutils -class FireTest(unittest.TestCase): +class FireTest(testutils.BaseTestCase): def testFire(self): - fire.Fire(tc.Empty) - fire.Fire(tc.OldStyleEmpty) - fire.Fire(tc.WithInit) + with mock.patch.object(sys, 'argv', ['progname']): + fire.Fire(tc.Empty) + fire.Fire(tc.OldStyleEmpty) + fire.Fire(tc.WithInit) + # Test both passing command as a sequence and as a string. + self.assertEqual(fire.Fire(tc.NoDefaults, command='triple 4'), 12) + self.assertEqual(fire.Fire(tc.WithDefaults, command=('double', '2')), 4) + self.assertEqual(fire.Fire(tc.WithDefaults, command=['triple', '4']), 12) + self.assertEqual(fire.Fire(tc.OldStyleWithDefaults, + command=['double', '2']), 4) + self.assertEqual(fire.Fire(tc.OldStyleWithDefaults, + command=['triple', '4']), 12) + + def testFirePositionalCommand(self): + # Test passing command as a positional argument. self.assertEqual(fire.Fire(tc.NoDefaults, 'double 2'), 4) - self.assertEqual(fire.Fire(tc.NoDefaults, 'triple 4'), 12) - self.assertEqual(fire.Fire(tc.WithDefaults, 'double 2'), 4) - self.assertEqual(fire.Fire(tc.WithDefaults, 'triple 4'), 12) - self.assertEqual(fire.Fire(tc.OldStyleWithDefaults, 'double 2'), 4) - self.assertEqual(fire.Fire(tc.OldStyleWithDefaults, 'triple 4'), 12) + self.assertEqual(fire.Fire(tc.NoDefaults, ['double', '2']), 4) + + def testFireInvalidCommandArg(self): + with self.assertRaises(ValueError): + # This is not a valid command. + fire.Fire(tc.WithDefaults, command=10) + + def testFireDefaultName(self): + with mock.patch.object(sys, 'argv', + [os.path.join('python-fire', 'fire', + 'base_filename.py')]): + with self.assertOutputMatches(stdout='SYNOPSIS.*base_filename.py', + stderr=None): + fire.Fire(tc.Empty) def testFireNoArgs(self): - self.assertEqual(fire.Fire(tc.MixedDefaults, 'ten'), 10) + self.assertEqual(fire.Fire(tc.MixedDefaults, command=['ten']), 10) def testFireExceptions(self): - # Exceptions of Fire are printed to stderr and None is returned. - self.assertIsNone(fire.Fire(tc.Empty, 'nomethod')) # Member doesn't exist. - self.assertIsNone(fire.Fire(tc.NoDefaults, 'double')) # Missing argument. - self.assertIsNone(fire.Fire(tc.TypedProperties, 'delta x')) # Missing key. + # Exceptions of Fire are printed to stderr and a FireExit is raised. + with self.assertRaisesFireExit(2): + fire.Fire(tc.Empty, command=['nomethod']) # Member doesn't exist. + with self.assertRaisesFireExit(2): + fire.Fire(tc.NoDefaults, command=['double']) # Missing argument. + with self.assertRaisesFireExit(2): + fire.Fire(tc.TypedProperties, command=['delta', 'x']) # Missing key. # Exceptions of the target components are still raised. with self.assertRaises(ZeroDivisionError): - fire.Fire(tc.NumberDefaults, 'reciprocal 0.0') + fire.Fire(tc.NumberDefaults, command=['reciprocal', '0.0']) def testFireNamedArgs(self): - self.assertEqual(fire.Fire(tc.WithDefaults, 'double --count 5'), 10) - self.assertEqual(fire.Fire(tc.WithDefaults, 'triple --count 5'), 15) - self.assertEqual(fire.Fire(tc.OldStyleWithDefaults, 'double --count 5'), 10) - self.assertEqual(fire.Fire(tc.OldStyleWithDefaults, 'triple --count 5'), 15) + self.assertEqual(fire.Fire(tc.WithDefaults, + command=['double', '--count', '5']), 10) + self.assertEqual(fire.Fire(tc.WithDefaults, + command=['triple', '--count', '5']), 15) + self.assertEqual( + fire.Fire(tc.OldStyleWithDefaults, command=['double', '--count', '5']), + 10) + self.assertEqual( + fire.Fire(tc.OldStyleWithDefaults, command=['triple', '--count', '5']), + 15) + + def testFireNamedArgsSingleHyphen(self): + self.assertEqual(fire.Fire(tc.WithDefaults, + command=['double', '-count', '5']), 10) + self.assertEqual(fire.Fire(tc.WithDefaults, + command=['triple', '-count', '5']), 15) + self.assertEqual( + fire.Fire(tc.OldStyleWithDefaults, command=['double', '-count', '5']), + 10) + self.assertEqual( + fire.Fire(tc.OldStyleWithDefaults, command=['triple', '-count', '5']), + 15) def testFireNamedArgsWithEquals(self): - self.assertEqual(fire.Fire(tc.WithDefaults, 'double --count=5'), 10) - self.assertEqual(fire.Fire(tc.WithDefaults, 'triple --count=5'), 15) + self.assertEqual(fire.Fire(tc.WithDefaults, + command=['double', '--count=5']), 10) + self.assertEqual(fire.Fire(tc.WithDefaults, + command=['triple', '--count=5']), 15) + + def testFireNamedArgsWithEqualsSingleHyphen(self): + self.assertEqual(fire.Fire(tc.WithDefaults, + command=['double', '-count=5']), 10) + self.assertEqual(fire.Fire(tc.WithDefaults, + command=['triple', '-count=5']), 15) def testFireAllNamedArgs(self): - self.assertEqual(fire.Fire(tc.MixedDefaults, 'sum 1 2'), 5) - self.assertEqual(fire.Fire(tc.MixedDefaults, 'sum --alpha 1 2'), 5) - self.assertEqual(fire.Fire(tc.MixedDefaults, 'sum --beta 1 2'), 4) - self.assertEqual(fire.Fire(tc.MixedDefaults, 'sum 1 --alpha 2'), 4) - self.assertEqual(fire.Fire(tc.MixedDefaults, 'sum 1 --beta 2'), 5) - self.assertEqual(fire.Fire(tc.MixedDefaults, 'sum --alpha 1 --beta 2'), 5) - self.assertEqual(fire.Fire(tc.MixedDefaults, 'sum --beta 1 --alpha 2'), 4) + self.assertEqual(fire.Fire(tc.MixedDefaults, command=['sum', '1', '2']), 5) + self.assertEqual(fire.Fire(tc.MixedDefaults, + command=['sum', '--alpha', '1', '2']), 5) + self.assertEqual(fire.Fire(tc.MixedDefaults, + command=['sum', '--beta', '1', '2']), 4) + self.assertEqual(fire.Fire(tc.MixedDefaults, + command=['sum', '1', '--alpha', '2']), 4) + self.assertEqual(fire.Fire(tc.MixedDefaults, + command=['sum', '1', '--beta', '2']), 5) + self.assertEqual( + fire.Fire(tc.MixedDefaults, + command=['sum', '--alpha', '1', '--beta', '2']), 5) + self.assertEqual( + fire.Fire(tc.MixedDefaults, + command=['sum', '--beta', '1', '--alpha', '2']), 4) def testFireAllNamedArgsOneMissing(self): - self.assertEqual(fire.Fire(tc.MixedDefaults, 'sum'), 0) - self.assertEqual(fire.Fire(tc.MixedDefaults, 'sum 1'), 1) - self.assertEqual(fire.Fire(tc.MixedDefaults, 'sum --alpha 1'), 1) - self.assertEqual(fire.Fire(tc.MixedDefaults, 'sum --beta 2'), 4) + self.assertEqual(fire.Fire(tc.MixedDefaults, command=['sum']), 0) + self.assertEqual(fire.Fire(tc.MixedDefaults, command=['sum', '1']), 1) + self.assertEqual(fire.Fire(tc.MixedDefaults, + command=['sum', '--alpha', '1']), 1) + self.assertEqual(fire.Fire(tc.MixedDefaults, + command=['sum', '--beta', '2']), 4) def testFirePartialNamedArgs(self): - self.assertEqual(fire.Fire(tc.MixedDefaults, 'identity 1 2'), (1, 2)) self.assertEqual( - fire.Fire(tc.MixedDefaults, 'identity --alpha 1 2'), (1, 2)) - self.assertEqual(fire.Fire(tc.MixedDefaults, 'identity --beta 1 2'), (2, 1)) + fire.Fire(tc.MixedDefaults, command=['identity', '1', '2']), (1, 2)) + self.assertEqual( + fire.Fire(tc.MixedDefaults, + command=['identity', '--alpha', '1', '2']), (1, 2)) + self.assertEqual( + fire.Fire(tc.MixedDefaults, + command=['identity', '--beta', '1', '2']), (2, 1)) + self.assertEqual( + fire.Fire(tc.MixedDefaults, + command=['identity', '1', '--alpha', '2']), (2, 1)) self.assertEqual( - fire.Fire(tc.MixedDefaults, 'identity 1 --alpha 2'), (2, 1)) - self.assertEqual(fire.Fire(tc.MixedDefaults, 'identity 1 --beta 2'), (1, 2)) + fire.Fire(tc.MixedDefaults, + command=['identity', '1', '--beta', '2']), (1, 2)) self.assertEqual( - fire.Fire(tc.MixedDefaults, 'identity --alpha 1 --beta 2'), (1, 2)) + fire.Fire( + tc.MixedDefaults, + command=['identity', '--alpha', '1', '--beta', '2']), (1, 2)) self.assertEqual( - fire.Fire(tc.MixedDefaults, 'identity --beta 1 --alpha 2'), (2, 1)) + fire.Fire( + tc.MixedDefaults, + command=['identity', '--beta', '1', '--alpha', '2']), (2, 1)) def testFirePartialNamedArgsOneMissing(self): - # By default, errors are written to standard out and None is returned. - self.assertIsNone( # Identity needs an arg. - fire.Fire(tc.MixedDefaults, 'identity')) - self.assertIsNone( # Identity needs a value for alpha. - fire.Fire(tc.MixedDefaults, 'identity --beta 2')) + # Errors are written to standard out and a FireExit is raised. + with self.assertRaisesFireExit(2): + fire.Fire(tc.MixedDefaults, + command=['identity']) # Identity needs an arg. - self.assertEqual(fire.Fire(tc.MixedDefaults, 'identity 1'), (1, '0')) + with self.assertRaisesFireExit(2): + # Identity needs a value for alpha. + fire.Fire(tc.MixedDefaults, command=['identity', '--beta', '2']) + + self.assertEqual( + fire.Fire(tc.MixedDefaults, command=['identity', '1']), (1, '0')) + self.assertEqual( + fire.Fire(tc.MixedDefaults, command=['identity', '--alpha', '1']), + (1, '0')) + + def testFireAnnotatedArgs(self): + self.assertEqual(fire.Fire(tc.Annotations, command=['double', '5']), 10) + self.assertEqual(fire.Fire(tc.Annotations, command=['triple', '5']), 15) + + def testFireKeywordOnlyArgs(self): + with self.assertRaisesFireExit(2): + # Keyword arguments must be passed with flag syntax. + fire.Fire(tc.py3.KeywordOnly, command=['double', '5']) + + self.assertEqual( + fire.Fire(tc.py3.KeywordOnly, command=['double', '--count', '5']), 10) self.assertEqual( - fire.Fire(tc.MixedDefaults, 'identity --alpha 1'), (1, '0')) + fire.Fire(tc.py3.KeywordOnly, command=['triple', '--count', '5']), 15) def testFireProperties(self): - self.assertEqual(fire.Fire(tc.TypedProperties, 'alpha'), True) - self.assertEqual(fire.Fire(tc.TypedProperties, 'beta'), (1, 2, 3)) + self.assertEqual(fire.Fire(tc.TypedProperties, command=['alpha']), True) + self.assertEqual(fire.Fire(tc.TypedProperties, command=['beta']), (1, 2, 3)) def testFireRecursion(self): self.assertEqual( - fire.Fire(tc.TypedProperties, 'charlie double hello'), 'hellohello') - self.assertEqual(fire.Fire(tc.TypedProperties, 'charlie triple w'), 'www') + fire.Fire(tc.TypedProperties, + command=['charlie', 'double', 'hello']), 'hellohello') + self.assertEqual(fire.Fire(tc.TypedProperties, + command=['charlie', 'triple', 'w']), 'www') def testFireVarArgs(self): self.assertEqual( - fire.Fire(tc.VarArgs, 'cumsums a b c d'), ['a', 'ab', 'abc', 'abcd']) - self.assertEqual(fire.Fire(tc.VarArgs, 'cumsums 1 2 3 4'), [1, 3, 6, 10]) + fire.Fire(tc.VarArgs, + command=['cumsums', 'a', 'b', 'c', 'd']), + ['a', 'ab', 'abc', 'abcd']) + self.assertEqual( + fire.Fire(tc.VarArgs, command=['cumsums', '1', '2', '3', '4']), + [1, 3, 6, 10]) def testFireVarArgsWithNamedArgs(self): - self.assertEqual(fire.Fire(tc.VarArgs, 'varchars 1 2 c d'), (1, 2, 'cd')) - self.assertEqual(fire.Fire(tc.VarArgs, 'varchars 3 4 c d e'), (3, 4, 'cde')) + self.assertEqual( + fire.Fire(tc.VarArgs, command=['varchars', '1', '2', 'c', 'd']), + (1, 2, 'cd')) + self.assertEqual( + fire.Fire(tc.VarArgs, command=['varchars', '3', '4', 'c', 'd', 'e']), + (3, 4, 'cde')) def testFireKeywordArgs(self): - self.assertEqual(fire.Fire(tc.Kwargs, 'props --name David --age 24'), - {'name': 'David', 'age': 24}) self.assertEqual( - fire.Fire(tc.Kwargs, - 'props --message "This is a message it has -- in it"'), + fire.Fire( + tc.Kwargs, + command=['props', '--name', 'David', '--age', '24']), + {'name': 'David', 'age': 24}) + # Run this test both with a list command and a string command. + self.assertEqual( + fire.Fire( + tc.Kwargs, + command=['props', '--message', + '"This is a message it has -- in it"']), # Quotes stripped + {'message': 'This is a message it has -- in it'}) + self.assertEqual( + fire.Fire( + tc.Kwargs, + command=['props', '--message', + 'This is a message it has -- in it']), {'message': 'This is a message it has -- in it'}) - self.assertEqual(fire.Fire(tc.Kwargs, 'upper --alpha A --beta B'), - 'ALPHA BETA') - self.assertEqual(fire.Fire(tc.Kwargs, 'upper --alpha A --beta B - lower'), - 'alpha beta') + self.assertEqual( + fire.Fire( + tc.Kwargs, + command='props --message "This is a message it has -- in it"'), + {'message': 'This is a message it has -- in it'}) + self.assertEqual( + fire.Fire(tc.Kwargs, + command=['upper', '--alpha', 'A', '--beta', 'B']), + 'ALPHA BETA') + self.assertEqual( + fire.Fire( + tc.Kwargs, + command=['upper', '--alpha', 'A', '--beta', 'B', '-', 'lower']), + 'alpha beta') def testFireKeywordArgsWithMissingPositionalArgs(self): - self.assertEqual(fire.Fire(tc.Kwargs, 'run Hello World --cell is'), - ('Hello', 'World', {'cell': 'is'})) - self.assertEqual(fire.Fire(tc.Kwargs, 'run Hello --cell ok'), - ('Hello', None, {'cell': 'ok'})) + self.assertEqual( + fire.Fire(tc.Kwargs, command=['run', 'Hello', 'World', '--cell', 'is']), + ('Hello', 'World', {'cell': 'is'})) + self.assertEqual( + fire.Fire(tc.Kwargs, command=['run', 'Hello', '--cell', 'ok']), + ('Hello', None, {'cell': 'ok'})) def testFireObject(self): - self.assertEqual(fire.Fire(tc.WithDefaults(), 'double --count 5'), 10) - self.assertEqual(fire.Fire(tc.WithDefaults(), 'triple --count 5'), 15) + self.assertEqual( + fire.Fire(tc.WithDefaults(), command=['double', '--count', '5']), 10) + self.assertEqual( + fire.Fire(tc.WithDefaults(), command=['triple', '--count', '5']), 15) def testFireDict(self): component = { 'double': lambda x=0: 2 * x, 'cheese': 'swiss', } - self.assertEqual(fire.Fire(component, 'double 5'), 10) - self.assertEqual(fire.Fire(component, 'cheese'), 'swiss') + self.assertEqual(fire.Fire(component, command=['double', '5']), 10) + self.assertEqual(fire.Fire(component, command=['cheese']), 'swiss') def testFireObjectWithDict(self): - self.assertEqual(fire.Fire(tc.TypedProperties, 'delta echo'), 'E') - self.assertEqual(fire.Fire(tc.TypedProperties, 'delta echo lower'), 'e') - self.assertIsInstance(fire.Fire(tc.TypedProperties, 'delta nest'), dict) - self.assertEqual(fire.Fire(tc.TypedProperties, 'delta nest 0'), 'a') + self.assertEqual( + fire.Fire(tc.TypedProperties, command=['delta', 'echo']), 'E') + self.assertEqual( + fire.Fire(tc.TypedProperties, command=['delta', 'echo', 'lower']), 'e') + self.assertIsInstance( + fire.Fire(tc.TypedProperties, command=['delta', 'nest']), dict) + self.assertEqual( + fire.Fire(tc.TypedProperties, command=['delta', 'nest', '0']), 'a') + + def testFireSet(self): + component = tc.simple_set() + result = fire.Fire(component, command=[]) + self.assertEqual(len(result), 3) + + def testFireFrozenset(self): + component = tc.simple_frozenset() + result = fire.Fire(component, command=[]) + self.assertEqual(len(result), 3) def testFireList(self): component = ['zero', 'one', 'two', 'three'] - self.assertEqual(fire.Fire(component, '2'), 'two') - self.assertEqual(fire.Fire(component, '3'), 'three') - self.assertEqual(fire.Fire(component, '-1'), 'three') + self.assertEqual(fire.Fire(component, command=['2']), 'two') + self.assertEqual(fire.Fire(component, command=['3']), 'three') + self.assertEqual(fire.Fire(component, command=['-1']), 'three') def testFireObjectWithList(self): - self.assertEqual(fire.Fire(tc.TypedProperties, 'echo 0'), 'alex') - self.assertEqual(fire.Fire(tc.TypedProperties, 'echo 1'), 'bethany') + self.assertEqual(fire.Fire(tc.TypedProperties, command=['echo', '0']), + 'alex') + self.assertEqual(fire.Fire(tc.TypedProperties, command=['echo', '1']), + 'bethany') def testFireObjectWithTuple(self): - self.assertEqual(fire.Fire(tc.TypedProperties, 'fox 0'), 'carry') - self.assertEqual(fire.Fire(tc.TypedProperties, 'fox 1'), 'divide') + self.assertEqual(fire.Fire(tc.TypedProperties, command=['fox', '0']), + 'carry') + self.assertEqual(fire.Fire(tc.TypedProperties, command=['fox', '1']), + 'divide') + + def testFireObjectWithListAsObject(self): + self.assertEqual( + fire.Fire(tc.TypedProperties, command=['echo', 'count', 'bethany']), + 1) + + def testFireObjectWithTupleAsObject(self): + self.assertEqual( + fire.Fire(tc.TypedProperties, command=['fox', 'count', 'divide']), + 1) def testFireNoComponent(self): - self.assertEqual(fire.Fire(command='tc WithDefaults double 10'), 20) + self.assertEqual(fire.Fire(command=['tc', 'WithDefaults', 'double', '10']), + 20) last_char = lambda text: text[-1] # pylint: disable=unused-variable - self.assertEqual(fire.Fire(command='last_char "Hello"'), 'o') - self.assertEqual(fire.Fire(command='last-char "World"'), 'd') + self.assertEqual(fire.Fire(command=['last_char', '"Hello"']), 'o') + self.assertEqual(fire.Fire(command=['last-char', '"World"']), 'd') rset = lambda count=0: set(range(count)) # pylint: disable=unused-variable - self.assertEqual(fire.Fire(command='rset 5'), {0, 1, 2, 3, 4}) + self.assertEqual(fire.Fire(command=['rset', '5']), {0, 1, 2, 3, 4}) def testFireUnderscores(self): self.assertEqual( - fire.Fire(tc.Underscores, 'underscore-example'), 'fish fingers') + fire.Fire(tc.Underscores, + command=['underscore-example']), 'fish fingers') self.assertEqual( - fire.Fire(tc.Underscores, 'underscore_example'), 'fish fingers') + fire.Fire(tc.Underscores, + command=['underscore_example']), 'fish fingers') def testFireUnderscoresInArg(self): self.assertEqual( - fire.Fire(tc.Underscores, 'underscore-function example'), 'example') + fire.Fire(tc.Underscores, + command=['underscore-function', 'example']), 'example') self.assertEqual( - fire.Fire(tc.Underscores, 'underscore_function --underscore-arg=score'), + fire.Fire(tc.Underscores, + command=['underscore_function', '--underscore-arg=score']), 'score') self.assertEqual( - fire.Fire(tc.Underscores, 'underscore_function --underscore_arg=score'), + fire.Fire(tc.Underscores, + command=['underscore_function', '--underscore_arg=score']), 'score') def testBoolParsing(self): - self.assertEqual(fire.Fire(tc.BoolConverter, 'as-bool True'), True) - self.assertEqual(fire.Fire(tc.BoolConverter, 'as-bool False'), False) - self.assertEqual(fire.Fire(tc.BoolConverter, 'as-bool --arg=True'), True) - self.assertEqual(fire.Fire(tc.BoolConverter, 'as-bool --arg=False'), False) - self.assertEqual(fire.Fire(tc.BoolConverter, 'as-bool --arg'), True) - self.assertEqual(fire.Fire(tc.BoolConverter, 'as-bool --noarg'), False) + self.assertEqual(fire.Fire(tc.BoolConverter, command=['as-bool', 'True']), + True) + self.assertEqual( + fire.Fire(tc.BoolConverter, command=['as-bool', 'False']), False) + self.assertEqual( + fire.Fire(tc.BoolConverter, command=['as-bool', '--arg=True']), True) + self.assertEqual( + fire.Fire(tc.BoolConverter, command=['as-bool', '--arg=False']), False) + self.assertEqual(fire.Fire(tc.BoolConverter, command=['as-bool', '--arg']), + True) + self.assertEqual( + fire.Fire(tc.BoolConverter, command=['as-bool', '--noarg']), False) def testBoolParsingContinued(self): self.assertEqual( - fire.Fire(tc.MixedDefaults, 'identity True False'), (True, False)) + fire.Fire(tc.MixedDefaults, + command=['identity', 'True', 'False']), (True, False)) + self.assertEqual( + fire.Fire(tc.MixedDefaults, + command=['identity', '--alpha=False', '10']), (False, 10)) + self.assertEqual( + fire.Fire(tc.MixedDefaults, + command=['identity', '--alpha', '--beta', '10']), (True, 10)) + self.assertEqual( + fire.Fire(tc.MixedDefaults, + command=['identity', '--alpha', '--beta=10']), (True, 10)) + self.assertEqual( + fire.Fire(tc.MixedDefaults, + command=['identity', '--noalpha', '--beta']), (False, True)) self.assertEqual( - fire.Fire(tc.MixedDefaults, 'identity --alpha=False 10'), (False, 10)) + fire.Fire(tc.MixedDefaults, command=['identity', '10', '--beta']), + (10, True)) + + def testBoolParsingSingleHyphen(self): + self.assertEqual( + fire.Fire(tc.MixedDefaults, + command=['identity', '-alpha=False', '10']), (False, 10)) self.assertEqual( - fire.Fire(tc.MixedDefaults, 'identity --alpha --beta 10'), (True, 10)) + fire.Fire(tc.MixedDefaults, + command=['identity', '-alpha', '-beta', '10']), (True, 10)) self.assertEqual( - fire.Fire(tc.MixedDefaults, 'identity --alpha --beta=10'), (True, 10)) + fire.Fire(tc.MixedDefaults, + command=['identity', '-alpha', '-beta=10']), (True, 10)) self.assertEqual( - fire.Fire(tc.MixedDefaults, 'identity --noalpha --beta'), (False, True)) + fire.Fire(tc.MixedDefaults, + command=['identity', '-noalpha', '-beta']), (False, True)) self.assertEqual( - fire.Fire(tc.MixedDefaults, 'identity 10 --beta'), (10, True)) + fire.Fire(tc.MixedDefaults, + command=['identity', '-alpha', '-10', '-beta']), (-10, True)) def testBoolParsingLessExpectedCases(self): # Note: Does not return (True, 10). self.assertEqual( - fire.Fire(tc.MixedDefaults, 'identity --alpha 10'), (10, '0')) + fire.Fire(tc.MixedDefaults, + command=['identity', '--alpha', '10']), (10, '0')) # To get (True, 10), use one of the following: self.assertEqual( - fire.Fire(tc.MixedDefaults, 'identity --alpha --beta=10'), (True, 10)) + fire.Fire(tc.MixedDefaults, + command=['identity', '--alpha', '--beta=10']), + (True, 10)) self.assertEqual( - fire.Fire(tc.MixedDefaults, 'identity True 10'), (True, 10)) + fire.Fire(tc.MixedDefaults, + command=['identity', 'True', '10']), (True, 10)) - # Note: Does not return ('--test', '0'). - self.assertEqual(fire.Fire(tc.MixedDefaults, 'identity --alpha --test'), - (True, '--test')) + # Note: Does not return (True, '--test') or ('--test', 0). + with self.assertRaisesFireExit(2): + fire.Fire(tc.MixedDefaults, command=['identity', '--alpha', '--test']) + + self.assertEqual( + fire.Fire( + tc.MixedDefaults, + command=['identity', '--alpha', 'True', '"--test"']), + (True, '--test')) # To get ('--test', '0'), use one of the following: - self.assertEqual(fire.Fire(tc.MixedDefaults, 'identity --alpha=--test'), + self.assertEqual(fire.Fire(tc.MixedDefaults, + command=['identity', '--alpha=--test']), ('--test', '0')) self.assertEqual( - fire.Fire(tc.MixedDefaults, r'identity --alpha \"--test\"'), + fire.Fire(tc.MixedDefaults, command=r'identity --alpha \"--test\"'), ('--test', '0')) + def testSingleCharFlagParsing(self): + self.assertEqual( + fire.Fire(tc.MixedDefaults, + command=['identity', '-a']), (True, '0')) + self.assertEqual( + fire.Fire(tc.MixedDefaults, + command=['identity', '-a', '--beta=10']), (True, 10)) + self.assertEqual( + fire.Fire(tc.MixedDefaults, + command=['identity', '-a', '-b']), (True, True)) + self.assertEqual( + fire.Fire(tc.MixedDefaults, + command=['identity', '-a', '42', '-b']), (42, True)) + self.assertEqual( + fire.Fire(tc.MixedDefaults, + command=['identity', '-a', '42', '-b', '10']), (42, 10)) + self.assertEqual( + fire.Fire(tc.MixedDefaults, + command=['identity', '--alpha', 'True', '-b', '10']), + (True, 10)) + with self.assertRaisesFireExit(2): + # This test attempts to use an ambiguous shortcut flag on a function with + # a naming conflict for the shortcut, triggering a FireError. + fire.Fire(tc.SimilarArgNames, command=['identity', '-b']) + + def testSingleCharFlagParsingEqualSign(self): + self.assertEqual( + fire.Fire(tc.MixedDefaults, + command=['identity', '-a=True']), (True, '0')) + self.assertEqual( + fire.Fire(tc.MixedDefaults, + command=['identity', '-a=3', '--beta=10']), (3, 10)) + self.assertEqual( + fire.Fire(tc.MixedDefaults, + command=['identity', '-a=False', '-b=15']), (False, 15)) + self.assertEqual( + fire.Fire(tc.MixedDefaults, + command=['identity', '-a', '42', '-b=12']), (42, 12)) + self.assertEqual( + fire.Fire(tc.MixedDefaults, + command=['identity', '-a=42', '-b', '10']), (42, 10)) + + def testSingleCharFlagParsingExactMatch(self): + self.assertEqual( + fire.Fire(tc.SimilarArgNames, + command=['identity2', '-a']), (True, None)) + self.assertEqual( + fire.Fire(tc.SimilarArgNames, + command=['identity2', '-a=10']), (10, None)) + self.assertEqual( + fire.Fire(tc.SimilarArgNames, + command=['identity2', '--a']), (True, None)) + self.assertEqual( + fire.Fire(tc.SimilarArgNames, + command=['identity2', '-alpha']), (None, True)) + self.assertEqual( + fire.Fire(tc.SimilarArgNames, + command=['identity2', '-a', '-alpha']), (True, True)) + + def testSingleCharFlagParsingCapitalLetter(self): + self.assertEqual( + fire.Fire(tc.CapitalizedArgNames, + command=['sum', '-D', '5', '-G', '10']), 15) + def testBoolParsingWithNo(self): # In these examples --nothing always refers to the nothing argument: def fn1(thing, nothing): return thing, nothing - self.assertEqual(fire.Fire(fn1, '--thing --nothing'), (True, True)) - self.assertEqual(fire.Fire(fn1, '--thing --nonothing'), (True, False)) + self.assertEqual(fire.Fire(fn1, command=['--thing', '--nothing']), + (True, True)) + self.assertEqual(fire.Fire(fn1, command=['--thing', '--nonothing']), + (True, False)) - # In the next example nothing=False (since rightmost setting of a flag gets - # precedence), but it errors because thing has no value. - self.assertEqual(fire.Fire(fn1, '--nothing --nonothing'), None) + with self.assertRaisesFireExit(2): + # In this case nothing=False (since rightmost setting of a flag gets + # precedence), but it errors because thing has no value. + fire.Fire(fn1, command=['--nothing', '--nonothing']) # In these examples, --nothing sets thing=False: def fn2(thing, **kwargs): return thing, kwargs - self.assertEqual(fire.Fire(fn2, '--thing'), (True, {})) - self.assertEqual(fire.Fire(fn2, '--nothing'), (False, {})) - # In the next one, nothing=True, but it errors because thing has no value. - self.assertEqual(fire.Fire(fn2, '--nothing=True'), None) - self.assertEqual(fire.Fire(fn2, '--nothing --nothing=True'), + self.assertEqual(fire.Fire(fn2, command=['--thing']), (True, {})) + self.assertEqual(fire.Fire(fn2, command=['--nothing']), (False, {})) + with self.assertRaisesFireExit(2): + # In this case, nothing=True, but it errors because thing has no value. + fire.Fire(fn2, command=['--nothing=True']) + self.assertEqual(fire.Fire(fn2, command=['--nothing', '--nothing=True']), (False, {'nothing': True})) def fn3(arg, **kwargs): return arg, kwargs - self.assertEqual(fire.Fire(fn3, '--arg=value --thing'), + self.assertEqual(fire.Fire(fn3, command=['--arg=value', '--thing']), ('value', {'thing': True})) - self.assertEqual(fire.Fire(fn3, '--arg=value --nothing'), + self.assertEqual(fire.Fire(fn3, command=['--arg=value', '--nothing']), ('value', {'thing': False})) - self.assertEqual(fire.Fire(fn3, '--arg=value --nonothing'), + self.assertEqual(fire.Fire(fn3, command=['--arg=value', '--nonothing']), ('value', {'nothing': False})) def testTraceFlag(self): - self.assertIsInstance( - fire.Fire(tc.BoolConverter, 'as-bool True -- --trace'), trace.FireTrace) - self.assertIsInstance( - fire.Fire(tc.BoolConverter, 'as-bool True -- -t'), trace.FireTrace) - self.assertIsInstance( - fire.Fire(tc.BoolConverter, '-- --trace'), trace.FireTrace) + with self.assertRaisesFireExit(0, 'Fire trace:\n'): + fire.Fire(tc.BoolConverter, command=['as-bool', 'True', '--', '--trace']) + with self.assertRaisesFireExit(0, 'Fire trace:\n'): + fire.Fire(tc.BoolConverter, command=['as-bool', 'True', '--', '-t']) + with self.assertRaisesFireExit(0, 'Fire trace:\n'): + fire.Fire(tc.BoolConverter, command=['--', '--trace']) def testHelpFlag(self): - self.assertIsNone(fire.Fire(tc.BoolConverter, 'as-bool True -- --help')) - self.assertIsNone(fire.Fire(tc.BoolConverter, 'as-bool True -- -h')) - self.assertIsNone(fire.Fire(tc.BoolConverter, '-- --help')) + with self.assertRaisesFireExit(0): + fire.Fire(tc.BoolConverter, command=['as-bool', 'True', '--', '--help']) + with self.assertRaisesFireExit(0): + fire.Fire(tc.BoolConverter, command=['as-bool', 'True', '--', '-h']) + with self.assertRaisesFireExit(0): + fire.Fire(tc.BoolConverter, command=['--', '--help']) def testHelpFlagAndTraceFlag(self): - self.assertIsInstance( - fire.Fire(tc.BoolConverter, 'as-bool True -- --help --trace'), - trace.FireTrace) - self.assertIsInstance( - fire.Fire(tc.BoolConverter, 'as-bool True -- -h -t'), trace.FireTrace) - self.assertIsInstance( - fire.Fire(tc.BoolConverter, '-- -h --trace'), trace.FireTrace) + with self.assertRaisesFireExit(0, 'Fire trace:\n.*SYNOPSIS'): + fire.Fire(tc.BoolConverter, + command=['as-bool', 'True', '--', '--help', '--trace']) + with self.assertRaisesFireExit(0, 'Fire trace:\n.*SYNOPSIS'): + fire.Fire(tc.BoolConverter, command=['as-bool', 'True', '--', '-h', '-t']) + with self.assertRaisesFireExit(0, 'Fire trace:\n.*SYNOPSIS'): + fire.Fire(tc.BoolConverter, command=['--', '-h', '--trace']) def testTabCompletionNoName(self): - with self.assertRaises(ValueError): - fire.Fire(tc.NoDefaults, '-- --completion') + completion_script = fire.Fire(tc.NoDefaults, command=['--', '--completion']) + self.assertIn('double', completion_script) + self.assertIn('triple', completion_script) def testTabCompletion(self): - completion_script = fire.Fire(tc.NoDefaults, '-- --completion', name='c') + completion_script = fire.Fire( + tc.NoDefaults, command=['--', '--completion'], name='c') self.assertIn('double', completion_script) self.assertIn('triple', completion_script) def testTabCompletionWithDict(self): actions = {'multiply': lambda a, b: a * b} - completion_script = fire.Fire(actions, '-- --completion', name='actCLI') + completion_script = fire.Fire( + actions, command=['--', '--completion'], name='actCLI') self.assertIn('actCLI', completion_script) self.assertIn('multiply', completion_script) def testBasicSeparator(self): # '-' is the default separator. - self.assertEqual(fire.Fire(tc.MixedDefaults, 'identity + _'), ('+', '_')) - self.assertEqual(fire.Fire(tc.MixedDefaults, 'identity _ + -'), ('_', '+')) + self.assertEqual( + fire.Fire(tc.MixedDefaults, + command=['identity', '+', '_']), ('+', '_')) + self.assertEqual( + fire.Fire(tc.MixedDefaults, + command=['identity', '_', '+', '-']), ('_', '+')) # If we change the separator we can use '-' as an argument. self.assertEqual( - fire.Fire(tc.MixedDefaults, 'identity - _ -- --separator &'), + fire.Fire(tc.MixedDefaults, + command=['identity', '-', '_', '--', '--separator', '&']), ('-', '_')) # The separator triggers a function call, but there aren't enough arguments. - self.assertEqual(fire.Fire(tc.MixedDefaults, 'identity - _ +'), None) + with self.assertRaisesFireExit(2): + fire.Fire(tc.MixedDefaults, command=['identity', '-', '_', '+']) + + def testNonComparable(self): + """Fire should work with classes that disallow comparisons.""" + # Make sure this test passes both with a string command or a list command. + self.assertIsInstance( + fire.Fire(tc.NonComparable, command=''), tc.NonComparable) + self.assertIsInstance( + fire.Fire(tc.NonComparable, command=[]), tc.NonComparable) + + # The first separator instantiates the NonComparable object. + # The second separator causes Fire to check if the separator was necessary. + self.assertIsInstance( + fire.Fire(tc.NonComparable, command=['-', '-']), tc.NonComparable) def testExtraSeparators(self): self.assertEqual( - fire.Fire(tc.ReturnsObj, 'get-obj arg1 arg2 - - as-bool True'), True) + fire.Fire( + tc.ReturnsObj, + command=['get-obj', 'arg1', 'arg2', '-', '-', 'as-bool', 'True']), + True) self.assertEqual( - fire.Fire(tc.ReturnsObj, 'get-obj arg1 arg2 - - - as-bool True'), True) + fire.Fire( + tc.ReturnsObj, + command=['get-obj', 'arg1', 'arg2', '-', '-', '-', 'as-bool', + 'True']), + True) def testSeparatorForChaining(self): # Without a separator all args are consumed by get_obj. self.assertIsInstance( - fire.Fire(tc.ReturnsObj, 'get-obj arg1 arg2 as-bool True'), + fire.Fire(tc.ReturnsObj, + command=['get-obj', 'arg1', 'arg2', 'as-bool', 'True']), tc.BoolConverter) - # With a separator only the preceeding args are consumed by get_obj. + # With a separator only the preceding args are consumed by get_obj. self.assertEqual( - fire.Fire(tc.ReturnsObj, 'get-obj arg1 arg2 - as-bool True'), True) + fire.Fire( + tc.ReturnsObj, + command=['get-obj', 'arg1', 'arg2', '-', 'as-bool', 'True']), True) self.assertEqual( fire.Fire(tc.ReturnsObj, - 'get-obj arg1 arg2 & as-bool True -- --separator &'), + command=['get-obj', 'arg1', 'arg2', '&', 'as-bool', 'True', + '--', '--separator', '&']), True) self.assertEqual( fire.Fire(tc.ReturnsObj, - 'get-obj arg1 $$ as-bool True -- --separator $$'), + command=['get-obj', 'arg1', '$$', 'as-bool', 'True', '--', + '--separator', '$$']), True) + def testNegativeNumbers(self): + self.assertEqual( + fire.Fire(tc.MixedDefaults, + command=['sum', '--alpha', '-3', '--beta', '-4']), -11) + def testFloatForExpectedInt(self): self.assertEqual( - fire.Fire(tc.MixedDefaults, 'sum --alpha 2.2 --beta 3.0'), 8.2) + fire.Fire(tc.MixedDefaults, + command=['sum', '--alpha', '2.2', '--beta', '3.0']), 8.2) self.assertEqual( - fire.Fire(tc.NumberDefaults, 'integer_reciprocal --divisor 5.0'), 0.2) + fire.Fire( + tc.NumberDefaults, + command=['integer_reciprocal', '--divisor', '5.0']), 0.2) self.assertEqual( - fire.Fire(tc.NumberDefaults, 'integer_reciprocal 4.0'), 0.25) + fire.Fire(tc.NumberDefaults, command=['integer_reciprocal', '4.0']), + 0.25) def testClassInstantiation(self): - self.assertIsInstance(fire.Fire(tc.InstanceVars, '--arg1=a1 --arg2=a2'), + self.assertIsInstance(fire.Fire(tc.InstanceVars, + command=['--arg1=a1', '--arg2=a2']), tc.InstanceVars) - # Cannot instantiate a class with positional args by default. - self.assertIsNone(fire.Fire(tc.InstanceVars, 'a1 a2')) + with self.assertRaisesFireExit(2): + # Cannot instantiate a class with positional args. + fire.Fire(tc.InstanceVars, command=['a1', 'a2']) def testTraceErrors(self): # Class needs additional value but runs out of args. - self.assertIsNone(fire.Fire(tc.InstanceVars, 'a1')) - self.assertIsNone(fire.Fire(tc.InstanceVars, '--arg1=a1')) + with self.assertRaisesFireExit(2): + fire.Fire(tc.InstanceVars, command=['a1']) + with self.assertRaisesFireExit(2): + fire.Fire(tc.InstanceVars, command=['--arg1=a1']) + # Routine needs additional value but runs out of args. - self.assertIsNone(fire.Fire(tc.InstanceVars, 'a1 a2 - run b1')) - self.assertIsNone( - fire.Fire(tc.InstanceVars, '--arg1=a1 --arg2=a2 - run b1')) + with self.assertRaisesFireExit(2): + fire.Fire(tc.InstanceVars, command=['a1', 'a2', '-', 'run', 'b1']) + with self.assertRaisesFireExit(2): + fire.Fire(tc.InstanceVars, + command=['--arg1=a1', '--arg2=a2', '-', 'run b1']) + # Extra args cannot be consumed. - self.assertIsNone(fire.Fire(tc.InstanceVars, 'a1 a2 - run b1 b2 b3')) - self.assertIsNone( - fire.Fire(tc.InstanceVars, '--arg1=a1 --arg2=a2 - run b1 b2 b3')) + with self.assertRaisesFireExit(2): + fire.Fire(tc.InstanceVars, + command=['a1', 'a2', '-', 'run', 'b1', 'b2', 'b3']) + with self.assertRaisesFireExit(2): + fire.Fire( + tc.InstanceVars, + command=['--arg1=a1', '--arg2=a2', '-', 'run', 'b1', 'b2', 'b3']) + # Cannot find member to access. - self.assertIsNone(fire.Fire(tc.InstanceVars, 'a1 a2 - jog')) - self.assertIsNone(fire.Fire(tc.InstanceVars, '--arg1=a1 --arg2=a2 - jog')) + with self.assertRaisesFireExit(2): + fire.Fire(tc.InstanceVars, command=['a1', 'a2', '-', 'jog']) + with self.assertRaisesFireExit(2): + fire.Fire(tc.InstanceVars, command=['--arg1=a1', '--arg2=a2', '-', 'jog']) + + def testClassWithDefaultMethod(self): + self.assertEqual( + fire.Fire(tc.DefaultMethod, command=['double', '10']), 20 + ) + + def testClassWithInvalidProperty(self): + self.assertEqual( + fire.Fire(tc.InvalidProperty, command=['double', '10']), 20 + ) + + def testHelpKwargsDecorator(self): + # Issue #190, follow the wrapped method instead of crashing. + with self.assertRaisesFireExit(0): + fire.Fire(tc.decorated_method, command=['-h']) + with self.assertRaisesFireExit(0): + fire.Fire(tc.decorated_method, command=['--help']) + + def testFireAsyncio(self): + self.assertEqual(fire.Fire(tc.py3.WithAsyncio, + command=['double', '--count', '10']), 20) if __name__ == '__main__': - unittest.main() + testutils.main() diff --git a/fire/formatting.py b/fire/formatting.py new file mode 100644 index 00000000..68484c27 --- /dev/null +++ b/fire/formatting.py @@ -0,0 +1,93 @@ +# Copyright (C) 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Formatting utilities for use in creating help text.""" + +from fire import formatting_windows # pylint: disable=unused-import +import termcolor + + +ELLIPSIS = '...' + + +def Indent(text, spaces=2): + lines = text.split('\n') + return '\n'.join( + ' ' * spaces + line if line else line + for line in lines) + + +def Bold(text): + return termcolor.colored(text, attrs=['bold']) + + +def Underline(text): + return termcolor.colored(text, attrs=['underline']) + + +def BoldUnderline(text): + return Bold(Underline(text)) + + +def WrappedJoin(items, separator=' | ', width=80): + """Joins the items by the separator, wrapping lines at the given width.""" + lines = [] + current_line = '' + for index, item in enumerate(items): + is_final_item = index == len(items) - 1 + if is_final_item: + if len(current_line) + len(item) <= width: + current_line += item + else: + lines.append(current_line.rstrip()) + current_line = item + else: + if len(current_line) + len(item) + len(separator) <= width: + current_line += item + separator + else: + lines.append(current_line.rstrip()) + current_line = item + separator + + lines.append(current_line) + return lines + + +def Error(text): + return termcolor.colored(text, color='red', attrs=['bold']) + + +def EllipsisTruncate(text, available_space, line_length): + """Truncate text from the end with ellipsis.""" + if available_space < len(ELLIPSIS): + available_space = line_length + # No need to truncate + if len(text) <= available_space: + return text + return text[:available_space - len(ELLIPSIS)] + ELLIPSIS + + +def EllipsisMiddleTruncate(text, available_space, line_length): + """Truncates text from the middle with ellipsis.""" + if available_space < len(ELLIPSIS): + available_space = line_length + if len(text) < available_space: + return text + available_string_len = available_space - len(ELLIPSIS) + first_half_len = int(available_string_len / 2) # start from middle + second_half_len = available_string_len - first_half_len + return text[:first_half_len] + ELLIPSIS + text[-second_half_len:] + + +def DoubleQuote(text): + return '"%s"' % text diff --git a/fire/formatting_test.py b/fire/formatting_test.py new file mode 100644 index 00000000..e0f6699d --- /dev/null +++ b/fire/formatting_test.py @@ -0,0 +1,78 @@ +# Copyright (C) 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for formatting.py.""" + +from fire import formatting +from fire import testutils + +LINE_LENGTH = 80 + + +class FormattingTest(testutils.BaseTestCase): + + def test_bold(self): + text = formatting.Bold('hello') + self.assertIn(text, ['hello', '\x1b[1mhello\x1b[0m']) + + def test_underline(self): + text = formatting.Underline('hello') + self.assertIn(text, ['hello', '\x1b[4mhello\x1b[0m']) + + def test_indent(self): + text = formatting.Indent('hello', spaces=2) + self.assertEqual(' hello', text) + + def test_indent_multiple_lines(self): + text = formatting.Indent('hello\nworld', spaces=2) + self.assertEqual(' hello\n world', text) + + def test_wrap_one_item(self): + lines = formatting.WrappedJoin(['rice']) + self.assertEqual(['rice'], lines) + + def test_wrap_multiple_items(self): + lines = formatting.WrappedJoin(['rice', 'beans', 'chicken', 'cheese'], + width=15) + self.assertEqual(['rice | beans |', + 'chicken |', + 'cheese'], lines) + + def test_ellipsis_truncate(self): + text = 'This is a string' + truncated_text = formatting.EllipsisTruncate( + text=text, available_space=10, line_length=LINE_LENGTH) + self.assertEqual('This is...', truncated_text) + + def test_ellipsis_truncate_not_enough_space(self): + text = 'This is a string' + truncated_text = formatting.EllipsisTruncate( + text=text, available_space=2, line_length=LINE_LENGTH) + self.assertEqual('This is a string', truncated_text) + + def test_ellipsis_middle_truncate(self): + text = '1000000000L' + truncated_text = formatting.EllipsisMiddleTruncate( + text=text, available_space=7, line_length=LINE_LENGTH) + self.assertEqual('10...0L', truncated_text) + + def test_ellipsis_middle_truncate_not_enough_space(self): + text = '1000000000L' + truncated_text = formatting.EllipsisMiddleTruncate( + text=text, available_space=2, line_length=LINE_LENGTH) + self.assertEqual('1000000000L', truncated_text) + + +if __name__ == '__main__': + testutils.main() diff --git a/fire/formatting_windows.py b/fire/formatting_windows.py new file mode 100644 index 00000000..cee6f393 --- /dev/null +++ b/fire/formatting_windows.py @@ -0,0 +1,58 @@ +# Copyright (C) 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module is used for enabling formatting on Windows.""" + +import ctypes +import os +import platform +import subprocess +import sys + +try: + import colorama # pylint: disable=g-import-not-at-top, # pytype: disable=import-error + HAS_COLORAMA = True +except ImportError: + HAS_COLORAMA = False + + +def initialize_or_disable(): + """Enables ANSI processing on Windows or disables it as needed.""" + if HAS_COLORAMA: + wrap = True + if (hasattr(sys.stdout, 'isatty') + and sys.stdout.isatty() + and platform.release() == '10'): + # Enables native ANSI sequences in console. + # Windows 10, 2016, and 2019 only. + + wrap = False + kernel32 = ctypes.windll.kernel32 # pytype: disable=module-attr + enable_virtual_terminal_processing = 0x04 + out_handle = kernel32.GetStdHandle(subprocess.STD_OUTPUT_HANDLE) # pylint: disable=line-too-long, # pytype: disable=module-attr + # GetConsoleMode fails if the terminal isn't native. + mode = ctypes.wintypes.DWORD() + if kernel32.GetConsoleMode(out_handle, ctypes.byref(mode)) == 0: + wrap = True + if not mode.value & enable_virtual_terminal_processing: + if kernel32.SetConsoleMode( + out_handle, mode.value | enable_virtual_terminal_processing) == 0: + # kernel32.SetConsoleMode to enable ANSI sequences failed + wrap = True + colorama.init(wrap=wrap) + else: + os.environ['ANSI_COLORS_DISABLED'] = '1' + +if sys.platform.startswith('win'): + initialize_or_disable() diff --git a/fire/helptext.py b/fire/helptext.py new file mode 100644 index 00000000..318d6276 --- /dev/null +++ b/fire/helptext.py @@ -0,0 +1,784 @@ +# Copyright (C) 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for producing help strings for use in Fire CLIs. + +Can produce help strings suitable for display in Fire CLIs for any type of +Python object, module, class, or function. + +There are two types of informative strings: Usage and Help screens. + +Usage screens are shown when the user accesses a group or accesses a command +without calling it. A Usage screen shows information about how to use that group +or command. Usage screens are typically short and show the minimal information +necessary for the user to determine how to proceed. + +Help screens are shown when the user requests help with the help flag (--help). +Help screens are shown in a less-style console view, and contain detailed help +information. +""" + +import collections +import itertools + +from fire import completion +from fire import custom_descriptions +from fire import decorators +from fire import docstrings +from fire import formatting +from fire import inspectutils +from fire import value_types + +LINE_LENGTH = 80 +SECTION_INDENTATION = 4 +SUBSECTION_INDENTATION = 4 + + +def HelpText(component, trace=None, verbose=False): + """Gets the help string for the current component, suitable for a help screen. + + Args: + component: The component to construct the help string for. + trace: The Fire trace of the command so far. The command executed so far + can be extracted from this trace. + verbose: Whether to include private members in the help screen. + + Returns: + The full help screen as a string. + """ + # Preprocessing needed to create the sections: + info = inspectutils.Info(component) + actions_grouped_by_kind = _GetActionsGroupedByKind(component, verbose=verbose) + spec = inspectutils.GetFullArgSpec(component) + metadata = decorators.GetMetadata(component) + + # Sections: + name_section = _NameSection(component, info, trace=trace, verbose=verbose) + synopsis_section = _SynopsisSection( + component, actions_grouped_by_kind, spec, metadata, trace=trace) + description_section = _DescriptionSection(component, info) + # TODO(dbieber): Add returns and raises sections for functions. + + if callable(component): + args_and_flags_sections, notes_sections = _ArgsAndFlagsSections( + info, spec, metadata) + else: + args_and_flags_sections = [] + notes_sections = [] + usage_details_sections = _UsageDetailsSections(component, + actions_grouped_by_kind) + + sections = ( + [name_section, synopsis_section, description_section] + + args_and_flags_sections + + usage_details_sections + + notes_sections + ) + return '\n\n'.join( + _CreateOutputSection(*section) + for section in sections if section is not None + ) + + +def _NameSection(component, info, trace=None, verbose=False): + """The "Name" section of the help string.""" + + # Only include separators in the name in verbose mode. + current_command = _GetCurrentCommand(trace, include_separators=verbose) + summary = _GetSummary(info) + + # If the docstring is one of the messy builtin docstrings, show custom one. + if custom_descriptions.NeedsCustomDescription(component): + available_space = LINE_LENGTH - SECTION_INDENTATION - len(current_command + + ' - ') + summary = custom_descriptions.GetSummary(component, available_space, + LINE_LENGTH) + + if summary: + text = f'{current_command} - {summary}' + else: + text = current_command + return ('NAME', text) + + +def _SynopsisSection(component, actions_grouped_by_kind, spec, metadata, + trace=None): + """The "Synopsis" section of the help string.""" + current_command = _GetCurrentCommand(trace=trace, include_separators=True) + + possible_actions = _GetPossibleActions(actions_grouped_by_kind) + + continuations = [] + if possible_actions: + continuations.append(_GetPossibleActionsString(possible_actions)) + if callable(component): + callable_continuation = _GetArgsAndFlagsString(spec, metadata) + if callable_continuation: + continuations.append(callable_continuation) + elif trace: + # This continuation might be blank if no args are needed. + # In this case, show a separator. + continuations.append(trace.separator) + continuation = ' | '.join(continuations) + + text = f'{current_command} {continuation}' + return ('SYNOPSIS', text) + + +def _DescriptionSection(component, info): + """The "Description" sections of the help string. + + Args: + component: The component to produce the description section for. + info: The info dict for the component of interest. + + Returns: + Returns the description if available. If not, returns the summary. + If neither are available, returns None. + """ + if custom_descriptions.NeedsCustomDescription(component): + available_space = LINE_LENGTH - SECTION_INDENTATION + description = custom_descriptions.GetDescription(component, available_space, + LINE_LENGTH) + summary = custom_descriptions.GetSummary(component, available_space, + LINE_LENGTH) + else: + description = _GetDescription(info) + summary = _GetSummary(info) + # Fall back to summary if description is not available. + text = description or summary or None + if text: + return ('DESCRIPTION', text) + else: + return None + + +def _CreateKeywordOnlyFlagItem(flag, docstring_info, spec, short_arg): + return _CreateFlagItem( + flag, docstring_info, spec, required=flag not in spec.kwonlydefaults, + short_arg=short_arg) + + +def _GetShortFlags(flags): + """Gets a list of single-character flags that uniquely identify a flag. + + Args: + flags: list of strings representing flags + + Returns: + List of single character short flags, + where the character occurred at the start of a flag once. + """ + short_flags = [f[0] for f in flags] + short_flag_counts = collections.Counter(short_flags) + return [v for v in short_flags if short_flag_counts[v] == 1] + + +def _ArgsAndFlagsSections(info, spec, metadata): + """The "Args and Flags" sections of the help string.""" + args_with_no_defaults = spec.args[:len(spec.args) - len(spec.defaults)] + args_with_defaults = spec.args[len(spec.args) - len(spec.defaults):] + + # Check if positional args are allowed. If not, require flag syntax for args. + accepts_positional_args = metadata.get(decorators.ACCEPTS_POSITIONAL_ARGS) + + args_and_flags_sections = [] + notes_sections = [] + + docstring_info = info['docstring_info'] + + arg_items = [ + _CreateArgItem(arg, docstring_info, spec) + for arg in args_with_no_defaults + ] + + if spec.varargs: + arg_items.append( + _CreateArgItem(spec.varargs, docstring_info, spec) + ) + + if arg_items: + title = 'POSITIONAL ARGUMENTS' if accepts_positional_args else 'ARGUMENTS' + arguments_section = (title, '\n'.join(arg_items).rstrip('\n')) + args_and_flags_sections.append(arguments_section) + if args_with_no_defaults and accepts_positional_args: + notes_sections.append( + ('NOTES', 'You can also use flags syntax for POSITIONAL ARGUMENTS') + ) + + unique_short_args = _GetShortFlags(args_with_defaults) + positional_flag_items = [ + _CreateFlagItem( + flag, docstring_info, spec, required=False, + short_arg=flag[0] in unique_short_args + ) + for flag in args_with_defaults + ] + + unique_short_kwonly_flags = _GetShortFlags(spec.kwonlyargs) + kwonly_flag_items = [ + _CreateKeywordOnlyFlagItem( + flag, docstring_info, spec, + short_arg=flag[0] in unique_short_kwonly_flags + ) + for flag in spec.kwonlyargs + ] + flag_items = positional_flag_items + kwonly_flag_items + + if spec.varkw: + # Include kwargs documented via :key param: + documented_kwargs = [] + + # add short flags if possible + flags = docstring_info.args or [] + flag_names = [f.name for f in flags] + unique_short_flags = _GetShortFlags(flag_names) + for flag in flags: + if isinstance(flag, docstrings.KwargInfo): + if flag.name[0] in unique_short_flags: + short_name = flag.name[0] + flag_string = f'-{short_name}, --{flag.name}' + else: + flag_string = f'--{flag.name}' + + flag_item = _CreateFlagItem( + flag.name, docstring_info, spec, + flag_string=flag_string) + documented_kwargs.append(flag_item) + if documented_kwargs: + # Separate documented kwargs from other flags using a message + if flag_items: + message = 'The following flags are also accepted.' + item = _CreateItem(message, None, indent=4) + flag_items.append(item) + flag_items.extend(documented_kwargs) + + description = _GetArgDescription(spec.varkw, docstring_info) + if documented_kwargs: + message = 'Additional undocumented flags may also be accepted.' + elif flag_items: + message = 'Additional flags are accepted.' + else: + message = 'Flags are accepted.' + item = _CreateItem(message, description, indent=4) + flag_items.append(item) + + if flag_items: + flags_section = ('FLAGS', '\n'.join(flag_items)) + args_and_flags_sections.append(flags_section) + + return args_and_flags_sections, notes_sections + + +def _UsageDetailsSections(component, actions_grouped_by_kind): + """The usage details sections of the help string.""" + groups, commands, values, indexes = actions_grouped_by_kind + + sections = [] + if groups.members: + sections.append(_MakeUsageDetailsSection(groups)) + if commands.members: + sections.append(_MakeUsageDetailsSection(commands)) + if values.members: + sections.append(_ValuesUsageDetailsSection(component, values)) + if indexes.members: + sections.append(('INDEXES', _NewChoicesSection('INDEX', indexes.names))) + + return sections + + +def _GetSummary(info): + docstring_info = info['docstring_info'] + return docstring_info.summary if docstring_info.summary else None + + +def _GetDescription(info): + docstring_info = info['docstring_info'] + return docstring_info.description if docstring_info.description else None + + +def _GetArgsAndFlagsString(spec, metadata): + """The args and flags string for showing how to call a function. + + If positional arguments are accepted, the args will be shown as positional. + E.g. "ARG1 ARG2 [--flag=FLAG]" + + If positional arguments are disallowed, the args will be shown with flags + syntax. + E.g. "--arg1=ARG1 [--flag=FLAG]" + + Args: + spec: The full arg spec for the component to construct the args and flags + string for. + metadata: Metadata for the component, including whether it accepts + positional arguments. + + Returns: + The constructed args and flags string. + """ + args_with_no_defaults = spec.args[:len(spec.args) - len(spec.defaults)] + args_with_defaults = spec.args[len(spec.args) - len(spec.defaults):] + + # Check if positional args are allowed. If not, require flag syntax for args. + accepts_positional_args = metadata.get(decorators.ACCEPTS_POSITIONAL_ARGS) + + arg_and_flag_strings = [] + if args_with_no_defaults: + if accepts_positional_args: + arg_strings = [formatting.Underline(arg.upper()) + for arg in args_with_no_defaults] + else: + arg_strings = [ + f'--{arg}={formatting.Underline(arg.upper())}' + for arg in args_with_no_defaults + ] + arg_and_flag_strings.extend(arg_strings) + + # If there are any arguments that are treated as flags: + if args_with_defaults or spec.kwonlyargs or spec.varkw: + arg_and_flag_strings.append('') + + if spec.varargs: + varargs_underlined = formatting.Underline(spec.varargs.upper()) + varargs_string = f'[{varargs_underlined}]...' + arg_and_flag_strings.append(varargs_string) + + return ' '.join(arg_and_flag_strings) + + +def _GetPossibleActions(actions_grouped_by_kind): + """The list of possible action kinds.""" + possible_actions = [] + for action_group in actions_grouped_by_kind: + if action_group.members: + possible_actions.append(action_group.name) + return possible_actions + + +def _GetPossibleActionsString(possible_actions): + """A help screen string listing the possible action kinds available.""" + return ' | '.join(formatting.Underline(action.upper()) + for action in possible_actions) + + +def _GetActionsGroupedByKind(component, verbose=False): + """Gets lists of available actions, grouped by action kind.""" + groups = ActionGroup(name='group', plural='groups') + commands = ActionGroup(name='command', plural='commands') + values = ActionGroup(name='value', plural='values') + indexes = ActionGroup(name='index', plural='indexes') + + members = completion.VisibleMembers(component, verbose=verbose) + for member_name, member in members: + member_name = str(member_name) + if value_types.IsGroup(member): + groups.Add(name=member_name, member=member) + if value_types.IsCommand(member): + commands.Add(name=member_name, member=member) + if value_types.IsValue(member): + values.Add(name=member_name, member=member) + + if isinstance(component, (list, tuple)) and component: + component_len = len(component) + if component_len < 10: + indexes.Add(name=', '.join(str(x) for x in range(component_len))) + else: + indexes.Add(name=f'0..{component_len-1}') + + return [groups, commands, values, indexes] + + +def _GetCurrentCommand(trace=None, include_separators=True): + """Returns current command for the purpose of generating help text.""" + if trace: + current_command = trace.GetCommand(include_separators=include_separators) + else: + current_command = '' + return current_command + + +def _CreateOutputSection(name, content): + return f"""{formatting.Bold(name)} +{formatting.Indent(content, SECTION_INDENTATION)}""" + + +def _CreateArgItem(arg, docstring_info, spec): + """Returns a string describing a positional argument. + + Args: + arg: The name of the positional argument. + docstring_info: A docstrings.DocstringInfo namedtuple with information about + the containing function's docstring. + spec: An instance of fire.inspectutils.FullArgSpec, containing type and + default information about the arguments to a callable. + + Returns: + A string to be used in constructing the help screen for the function. + """ + + # The help string is indented, so calculate the maximum permitted length + # before indentation to avoid exceeding the maximum line length. + max_str_length = LINE_LENGTH - SECTION_INDENTATION - SUBSECTION_INDENTATION + + description = _GetArgDescription(arg, docstring_info) + + arg_string = formatting.BoldUnderline(arg.upper()) + + arg_type = _GetArgType(arg, spec) + arg_type = f'Type: {arg_type}' if arg_type else '' + available_space = max_str_length - len(arg_type) + arg_type = ( + formatting.EllipsisTruncate(arg_type, available_space, max_str_length)) + + description = '\n'.join(part for part in (arg_type, description) if part) + + return _CreateItem(arg_string, description, indent=SUBSECTION_INDENTATION) + + +def _CreateFlagItem(flag, docstring_info, spec, required=False, + flag_string=None, short_arg=False): + """Returns a string describing a flag using docstring and FullArgSpec info. + + Args: + flag: The name of the flag. + docstring_info: A docstrings.DocstringInfo namedtuple with information about + the containing function's docstring. + spec: An instance of fire.inspectutils.FullArgSpec, containing type and + default information about the arguments to a callable. + required: Whether the flag is required. + flag_string: If provided, use this string for the flag, rather than + constructing one from the flag name. + short_arg: Whether the flag has a short variation or not. + Returns: + A string to be used in constructing the help screen for the function. + """ + # pylint: disable=g-bad-todo + # TODO(MichaelCG8): Get type and default information from docstrings if it is + # not available in FullArgSpec. This will require updating + # fire.docstrings.parser(). + + # The help string is indented, so calculate the maximum permitted length + # before indentation to avoid exceeding the maximum line length. + max_str_length = LINE_LENGTH - SECTION_INDENTATION - SUBSECTION_INDENTATION + + description = _GetArgDescription(flag, docstring_info) + + if not flag_string: + flag_name_upper = formatting.Underline(flag.upper()) + flag_string = f'--{flag}={flag_name_upper}' + if required: + flag_string += ' (required)' + if short_arg: + short_flag = flag[0] + flag_string = f'-{short_flag}, {flag_string}' + + arg_type = _GetArgType(flag, spec) + arg_default = _GetArgDefault(flag, spec) + + # We need to handle the case where there is a default of None, but otherwise + # the argument has another type. + if arg_default == 'None': + arg_type = f'Optional[{arg_type}]' + + arg_type = f'Type: {arg_type}' if arg_type else '' + available_space = max_str_length - len(arg_type) + arg_type = ( + formatting.EllipsisTruncate(arg_type, available_space, max_str_length)) + + arg_default = f'Default: {arg_default}' if arg_default else '' + available_space = max_str_length - len(arg_default) + arg_default = ( + formatting.EllipsisTruncate(arg_default, available_space, max_str_length)) + + description = '\n'.join( + part for part in (arg_type, arg_default, description) if part + ) + + return _CreateItem(flag_string, description, indent=SUBSECTION_INDENTATION) + + +def _GetArgType(arg, spec): + """Returns a string describing the type of an argument. + + Args: + arg: The name of the argument. + spec: An instance of fire.inspectutils.FullArgSpec, containing type and + default information about the arguments to a callable. + Returns: + A string to be used in constructing the help screen for the function, the + empty string if the argument type is not available. + """ + if arg in spec.annotations: + arg_type = spec.annotations[arg] + try: + return arg_type.__qualname__ + except AttributeError: + # Some typing objects, such as typing.Union do not have either a __name__ + # or __qualname__ attribute. + # repr(typing.Union[int, str]) will return ': typing.Union[int, str]' + return repr(arg_type) + return '' + + +def _GetArgDefault(flag, spec): + """Returns a string describing a flag's default value. + + Args: + flag: The name of the flag. + spec: An instance of fire.inspectutils.FullArgSpec, containing type and + default information about the arguments to a callable. + Returns: + A string to be used in constructing the help screen for the function, the + empty string if the flag does not have a default or the default is not + available. + """ + num_defaults = len(spec.defaults) + args_with_defaults = spec.args[-num_defaults:] + + for arg, default in zip(args_with_defaults, spec.defaults): + if arg == flag: + return repr(default) + if flag in spec.kwonlydefaults: + return repr(spec.kwonlydefaults[flag]) + return '' + + +def _CreateItem(name, description, indent=2): + if not description: + return name + description = formatting.Indent(description, indent) + return f"""{name} +{description}""" + + +def _GetArgDescription(name, docstring_info): + if docstring_info.args: + for arg_in_docstring in docstring_info.args: + if arg_in_docstring.name in (name, f'*{name}', f'**{name}'): + return arg_in_docstring.description + return None + + +def _MakeUsageDetailsSection(action_group): + """Creates a usage details section for the provided action group.""" + item_strings = [] + for name, member in action_group.GetItems(): + info = inspectutils.Info(member) + item = name + docstring_info = info.get('docstring_info') + if (docstring_info + and not custom_descriptions.NeedsCustomDescription(member)): + summary = docstring_info.summary + elif custom_descriptions.NeedsCustomDescription(member): + summary = custom_descriptions.GetSummary( + member, LINE_LENGTH - SECTION_INDENTATION, LINE_LENGTH) + else: + summary = None + item = _CreateItem(name, summary) + item_strings.append(item) + return (action_group.plural.upper(), + _NewChoicesSection(action_group.name.upper(), item_strings)) + + +def _ValuesUsageDetailsSection(component, values): + """Creates a section tuple for the values section of the usage details.""" + value_item_strings = [] + for value_name, value in values.GetItems(): + del value + init_info = inspectutils.Info(component.__class__.__init__) + value_item = None + if 'docstring_info' in init_info: + init_docstring_info = init_info['docstring_info'] + if init_docstring_info.args: + for arg_info in init_docstring_info.args: + if arg_info.name == value_name: + value_item = _CreateItem(value_name, arg_info.description) + if value_item is None: + value_item = str(value_name) + value_item_strings.append(value_item) + return ('VALUES', _NewChoicesSection('VALUE', value_item_strings)) + + +def _NewChoicesSection(name, choices): + name_formatted = formatting.Bold(formatting.Underline(name)) + return _CreateItem( + f'{name_formatted} is one of the following:', + '\n' + '\n\n'.join(choices), + indent=1) + + +def UsageText(component, trace=None, verbose=False): + """Returns usage text for the given component. + + Args: + component: The component to determine the usage text for. + trace: The Fire trace object containing all metadata of current execution. + verbose: Whether to display the usage text in verbose mode. + + Returns: + String suitable for display in an error screen. + """ + # Get the command so far: + if trace: + command = trace.GetCommand() + needs_separating_hyphen_hyphen = trace.NeedsSeparatingHyphenHyphen() + else: + command = None + needs_separating_hyphen_hyphen = False + + if not command: + command = '' + + # Build the continuations for the command: + continued_command = command + + spec = inspectutils.GetFullArgSpec(component) + metadata = decorators.GetMetadata(component) + + # Usage for objects. + actions_grouped_by_kind = _GetActionsGroupedByKind(component, verbose=verbose) + possible_actions = _GetPossibleActions(actions_grouped_by_kind) + + continuations = [] + if possible_actions: + continuations.append(_GetPossibleActionsUsageString(possible_actions)) + + availability_lines = _UsageAvailabilityLines(actions_grouped_by_kind) + + if callable(component): + callable_items = _GetCallableUsageItems(spec, metadata) + if callable_items: + continuations.append(' '.join(callable_items)) + elif trace: + continuations.append(trace.separator) + availability_lines.extend(_GetCallableAvailabilityLines(spec)) + + if continuations: + continued_command += ' ' + ' | '.join(continuations) + help_command = ( + command + + (' -- ' if needs_separating_hyphen_hyphen else ' ') + + '--help' + ) + + return f"""Usage: {continued_command} +{''.join(availability_lines)} +For detailed information on this command, run: + {help_command}""" + + +def _GetPossibleActionsUsageString(possible_actions): + if possible_actions: + actions_str = '|'.join(possible_actions) + return f'<{actions_str}>' + return None + + +def _UsageAvailabilityLines(actions_grouped_by_kind): + availability_lines = [] + for action_group in actions_grouped_by_kind: + if action_group.members: + availability_line = _CreateAvailabilityLine( + header=f'available {action_group.plural}:', + items=action_group.names + ) + availability_lines.append(availability_line) + return availability_lines + + +def _GetCallableUsageItems(spec, metadata): + """A list of elements that comprise the usage summary for a callable.""" + args_with_no_defaults = spec.args[:len(spec.args) - len(spec.defaults)] + args_with_defaults = spec.args[len(spec.args) - len(spec.defaults):] + + # Check if positional args are allowed. If not, show flag syntax for args. + accepts_positional_args = metadata.get(decorators.ACCEPTS_POSITIONAL_ARGS) + + if not accepts_positional_args: + items = [f'--{arg}={arg.upper()}' + for arg in args_with_no_defaults] + else: + items = [arg.upper() for arg in args_with_no_defaults] + + # If there are any arguments that are treated as flags: + if args_with_defaults or spec.kwonlyargs or spec.varkw: + items.append('') + + if spec.varargs: + items.append(f'[{spec.varargs.upper()}]...') + + return items + + +def _KeywordOnlyArguments(spec, required=True): + return (flag for flag in spec.kwonlyargs + if required != (flag in spec.kwonlydefaults)) + + +def _GetCallableAvailabilityLines(spec): + """The list of availability lines for a callable for use in a usage string.""" + args_with_defaults = spec.args[len(spec.args) - len(spec.defaults):] + + # TODO(dbieber): Handle args_with_no_defaults if not accepts_positional_args. + optional_flags = [f'--{flag}' for flag in itertools.chain( + args_with_defaults, _KeywordOnlyArguments(spec, required=False))] + required_flags = [ + f'--{flag}' for flag in _KeywordOnlyArguments(spec, required=True) + ] + + # Flags section: + availability_lines = [] + if optional_flags: + availability_lines.append( + _CreateAvailabilityLine(header='optional flags:', items=optional_flags, + header_indent=2)) + if required_flags: + availability_lines.append( + _CreateAvailabilityLine(header='required flags:', items=required_flags, + header_indent=2)) + if spec.varkw: + additional_flags = ('additional flags are accepted' + if optional_flags or required_flags else + 'flags are accepted') + availability_lines.append( + _CreateAvailabilityLine(header=additional_flags, items=[], + header_indent=2)) + return availability_lines + + +def _CreateAvailabilityLine(header, items, + header_indent=2, items_indent=25, + line_length=LINE_LENGTH): + items_width = line_length - items_indent + items_text = '\n'.join(formatting.WrappedJoin(items, width=items_width)) + indented_items_text = formatting.Indent(items_text, spaces=items_indent) + indented_header = formatting.Indent(header, spaces=header_indent) + return indented_header + indented_items_text[len(indented_header):] + '\n' + + +class ActionGroup: + """A group of actions of the same kind.""" + + def __init__(self, name, plural): + self.name = name + self.plural = plural + self.names = [] + self.members = [] + + def Add(self, name, member=None): + self.names.append(name) + self.members.append(member) + + def GetItems(self): + return zip(self.names, self.members) diff --git a/fire/helptext_test.py b/fire/helptext_test.py new file mode 100644 index 00000000..aeff5240 --- /dev/null +++ b/fire/helptext_test.py @@ -0,0 +1,596 @@ +# Copyright (C) 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the helptext module.""" + +import os +import textwrap + +from fire import formatting +from fire import helptext +from fire import test_components as tc +from fire import testutils +from fire import trace + + +class HelpTest(testutils.BaseTestCase): + + def setUp(self): + super().setUp() + os.environ['ANSI_COLORS_DISABLED'] = '1' + + def testHelpTextNoDefaults(self): + component = tc.NoDefaults + help_screen = helptext.HelpText( + component=component, + trace=trace.FireTrace(component, name='NoDefaults')) + self.assertIn('NAME\n NoDefaults', help_screen) + self.assertIn('SYNOPSIS\n NoDefaults', help_screen) + self.assertNotIn('DESCRIPTION', help_screen) + self.assertNotIn('NOTES', help_screen) + + def testHelpTextNoDefaultsObject(self): + component = tc.NoDefaults() + help_screen = helptext.HelpText( + component=component, + trace=trace.FireTrace(component, name='NoDefaults')) + self.assertIn('NAME\n NoDefaults', help_screen) + self.assertIn('SYNOPSIS\n NoDefaults COMMAND', help_screen) + self.assertNotIn('DESCRIPTION', help_screen) + self.assertIn('COMMANDS\n COMMAND is one of the following:', + help_screen) + self.assertIn('double', help_screen) + self.assertIn('triple', help_screen) + self.assertNotIn('NOTES', help_screen) + + def testHelpTextFunction(self): + component = tc.NoDefaults().double + help_screen = helptext.HelpText( + component=component, + trace=trace.FireTrace(component, name='double')) + self.assertIn('NAME\n double', help_screen) + self.assertIn('SYNOPSIS\n double COUNT', help_screen) + self.assertNotIn('DESCRIPTION', help_screen) + self.assertIn('POSITIONAL ARGUMENTS\n COUNT', help_screen) + self.assertIn( + 'NOTES\n You can also use flags syntax for POSITIONAL ARGUMENTS', + help_screen) + + def testHelpTextFunctionWithDefaults(self): + component = tc.WithDefaults().triple + help_screen = helptext.HelpText( + component=component, + trace=trace.FireTrace(component, name='triple')) + self.assertIn('NAME\n triple', help_screen) + self.assertIn('SYNOPSIS\n triple ', help_screen) + self.assertNotIn('DESCRIPTION', help_screen) + self.assertIn( + 'FLAGS\n -c, --count=COUNT\n Default: 0', + help_screen) + self.assertNotIn('NOTES', help_screen) + + def testHelpTextFunctionWithLongDefaults(self): + component = tc.WithDefaults().text + help_screen = helptext.HelpText( + component=component, + trace=trace.FireTrace(component, name='text')) + self.assertIn('NAME\n text', help_screen) + self.assertIn('SYNOPSIS\n text ', help_screen) + self.assertNotIn('DESCRIPTION', help_screen) + self.assertIn( + 'FLAGS\n -s, --string=STRING\n' + ' Default: \'0001020304050607080910' + '1112131415161718192021222324252627282...', + help_screen) + self.assertNotIn('NOTES', help_screen) + + def testHelpTextFunctionWithKwargs(self): + component = tc.fn_with_kwarg + help_screen = helptext.HelpText( + component=component, + trace=trace.FireTrace(component, name='text')) + self.assertIn('NAME\n text', help_screen) + self.assertIn('SYNOPSIS\n text ARG1 ARG2 ', help_screen) + self.assertIn('DESCRIPTION\n Function with kwarg', help_screen) + self.assertIn( + 'FLAGS\n --arg3\n Description of arg3.\n ' + 'Additional undocumented flags may also be accepted.', + help_screen) + + def testHelpTextFunctionWithKwargsAndDefaults(self): + component = tc.fn_with_kwarg_and_defaults + help_screen = helptext.HelpText( + component=component, + trace=trace.FireTrace(component, name='text')) + self.assertIn('NAME\n text', help_screen) + self.assertIn('SYNOPSIS\n text ARG1 ARG2 ', help_screen) + self.assertIn('DESCRIPTION\n Function with kwarg', help_screen) + self.assertIn( + 'FLAGS\n -o, --opt=OPT\n Default: True\n' + ' The following flags are also accepted.' + '\n --arg3\n Description of arg3.\n ' + 'Additional undocumented flags may also be accepted.', + help_screen) + + def testHelpTextFunctionWithDefaultsAndTypes(self): + component = ( + tc.py3.WithDefaultsAndTypes().double) # pytype: disable=module-attr + help_screen = helptext.HelpText( + component=component, + trace=trace.FireTrace(component, name='double')) + self.assertIn('NAME\n double', help_screen) + self.assertIn('SYNOPSIS\n double ', help_screen) + self.assertIn('DESCRIPTION', help_screen) + self.assertIn( + 'FLAGS\n -c, --count=COUNT\n Type: float\n Default: 0', + help_screen) + self.assertNotIn('NOTES', help_screen) + + def testHelpTextFunctionWithTypesAndDefaultNone(self): + component = ( + tc.py3.WithDefaultsAndTypes().get_int) # pytype: disable=module-attr + help_screen = helptext.HelpText( + component=component, + trace=trace.FireTrace(component, name='get_int')) + self.assertIn('NAME\n get_int', help_screen) + self.assertIn('SYNOPSIS\n get_int ', help_screen) + self.assertNotIn('DESCRIPTION', help_screen) + self.assertIn( + 'FLAGS\n -v, --value=VALUE\n' + ' Type: Optional[int]\n Default: None', + help_screen) + self.assertNotIn('NOTES', help_screen) + + def testHelpTextFunctionWithTypes(self): + component = tc.py3.WithTypes().double # pytype: disable=module-attr + help_screen = helptext.HelpText( + component=component, + trace=trace.FireTrace(component, name='double')) + self.assertIn('NAME\n double', help_screen) + self.assertIn('SYNOPSIS\n double COUNT', help_screen) + self.assertIn('DESCRIPTION', help_screen) + self.assertIn( + 'POSITIONAL ARGUMENTS\n COUNT\n Type: float', + help_screen) + self.assertIn( + 'NOTES\n You can also use flags syntax for POSITIONAL ARGUMENTS', + help_screen) + + def testHelpTextFunctionWithLongTypes(self): + component = tc.py3.WithTypes().long_type # pytype: disable=module-attr + help_screen = helptext.HelpText( + component=component, + trace=trace.FireTrace(component, name='long_type')) + self.assertIn('NAME\n long_type', help_screen) + self.assertIn('SYNOPSIS\n long_type LONG_OBJ', help_screen) + self.assertNotIn('DESCRIPTION', help_screen) + # TODO(dbieber): Assert type is displayed correctly. Type displayed + # differently in Travis vs in Google. + # self.assertIn( + # 'POSITIONAL ARGUMENTS\n LONG_OBJ\n' + # ' Type: typing.Tuple[typing.Tuple[' + # 'typing.Tuple[typing.Tuple[typing.Tupl...', + # help_screen) + self.assertIn( + 'NOTES\n You can also use flags syntax for POSITIONAL ARGUMENTS', + help_screen) + + def testHelpTextFunctionWithBuiltin(self): + component = 'test'.upper + help_screen = helptext.HelpText( + component=component, + trace=trace.FireTrace(component, 'upper')) + self.assertIn('NAME\n upper', help_screen) + self.assertIn('SYNOPSIS\n upper', help_screen) + # We don't check description content here since the content is python + # version dependent. + self.assertIn('DESCRIPTION\n', help_screen) + self.assertNotIn('NOTES', help_screen) + + def testHelpTextFunctionIntType(self): + component = int + help_screen = helptext.HelpText( + component=component, trace=trace.FireTrace(component, 'int')) + self.assertIn('NAME\n int', help_screen) + self.assertIn('SYNOPSIS\n int', help_screen) + # We don't check description content here since the content is python + # version dependent. + self.assertIn('DESCRIPTION\n', help_screen) + + def testHelpTextEmptyList(self): + component = [] + help_screen = helptext.HelpText( + component=component, + trace=trace.FireTrace(component, 'list')) + self.assertIn('NAME\n list', help_screen) + self.assertIn('SYNOPSIS\n list COMMAND', help_screen) + # TODO(zuhaochen): Change assertion after custom description is + # implemented for list type. + self.assertNotIn('DESCRIPTION', help_screen) + # We don't check the listed commands either since the list API could + # potentially change between Python versions. + self.assertIn('COMMANDS\n COMMAND is one of the following:\n', + help_screen) + + def testHelpTextShortList(self): + component = [10] + help_screen = helptext.HelpText( + component=component, + trace=trace.FireTrace(component, 'list')) + self.assertIn('NAME\n list', help_screen) + self.assertIn('SYNOPSIS\n list COMMAND', help_screen) + # TODO(zuhaochen): Change assertion after custom description is + # implemented for list type. + self.assertNotIn('DESCRIPTION', help_screen) + + # We don't check the listed commands comprehensively since the list API + # could potentially change between Python versions. Check a few + # functions(command) that we're confident likely remain available. + self.assertIn('COMMANDS\n COMMAND is one of the following:\n', + help_screen) + self.assertIn(' append\n', help_screen) + + def testHelpTextInt(self): + component = 7 + help_screen = helptext.HelpText( + component=component, trace=trace.FireTrace(component, '7')) + self.assertIn('NAME\n 7', help_screen) + self.assertIn('SYNOPSIS\n 7 COMMAND | VALUE', help_screen) + # TODO(zuhaochen): Change assertion after implementing custom + # description for int. + self.assertNotIn('DESCRIPTION', help_screen) + self.assertIn('COMMANDS\n COMMAND is one of the following:\n', + help_screen) + self.assertIn('VALUES\n VALUE is one of the following:\n', help_screen) + + def testHelpTextNoInit(self): + component = tc.OldStyleEmpty + help_screen = helptext.HelpText( + component=component, + trace=trace.FireTrace(component, 'OldStyleEmpty')) + self.assertIn('NAME\n OldStyleEmpty', help_screen) + self.assertIn('SYNOPSIS\n OldStyleEmpty', help_screen) + + def testHelpTextKeywordOnlyArgumentsWithDefault(self): + component = tc.py3.KeywordOnly.with_default # pytype: disable=module-attr + output = helptext.HelpText( + component=component, trace=trace.FireTrace(component, 'with_default')) + self.assertIn('NAME\n with_default', output) + self.assertIn('FLAGS\n -x, --x=X', output) + + def testHelpTextKeywordOnlyArgumentsWithoutDefault(self): + component = tc.py3.KeywordOnly.double # pytype: disable=module-attr + output = helptext.HelpText( + component=component, trace=trace.FireTrace(component, 'double')) + self.assertIn('NAME\n double', output) + self.assertIn('FLAGS\n -c, --count=COUNT (required)', output) + + def testHelpTextFunctionMixedDefaults(self): + component = tc.py3.HelpTextComponent().identity + t = trace.FireTrace(component, name='FunctionMixedDefaults') + output = helptext.HelpText(component, trace=t) + self.assertIn('NAME\n FunctionMixedDefaults', output) + self.assertIn('FunctionMixedDefaults ', output) + self.assertIn('--alpha=ALPHA (required)', output) + self.assertIn('--beta=BETA\n Default: \'0\'', output) + + def testHelpScreen(self): + component = tc.ClassWithDocstring() + t = trace.FireTrace(component, name='ClassWithDocstring') + help_output = helptext.HelpText(component, t) + expected_output = """ +NAME + ClassWithDocstring - Test class for testing help text output. + +SYNOPSIS + ClassWithDocstring COMMAND | VALUE + +DESCRIPTION + This is some detail description of this test class. + +COMMANDS + COMMAND is one of the following: + + print_msg + Prints a message. + +VALUES + VALUE is one of the following: + + message + The default message to print.""" + self.assertEqual(textwrap.dedent(expected_output).strip(), + help_output.strip()) + + def testHelpScreenForFunctionDocstringWithLineBreak(self): + component = tc.ClassWithMultilineDocstring.example_generator + t = trace.FireTrace(component, name='example_generator') + help_output = helptext.HelpText(component, t) + expected_output = """ + NAME + example_generator - Generators have a ``Yields`` section instead of a ``Returns`` section. + + SYNOPSIS + example_generator N + + DESCRIPTION + Generators have a ``Yields`` section instead of a ``Returns`` section. + + POSITIONAL ARGUMENTS + N + The upper limit of the range to generate, from 0 to `n` - 1. + + NOTES + You can also use flags syntax for POSITIONAL ARGUMENTS""" + self.assertEqual(textwrap.dedent(expected_output).strip(), + help_output.strip()) + + def testHelpScreenForFunctionFunctionWithDefaultArgs(self): + component = tc.WithDefaults().double + t = trace.FireTrace(component, name='double') + help_output = helptext.HelpText(component, t) + expected_output = """ + NAME + double - Returns the input multiplied by 2. + + SYNOPSIS + double + + DESCRIPTION + Returns the input multiplied by 2. + + FLAGS + -c, --count=COUNT + Default: 0 + Input number that you want to double.""" + self.assertEqual(textwrap.dedent(expected_output).strip(), + help_output.strip()) + + def testHelpTextUnderlineFlag(self): + component = tc.WithDefaults().triple + t = trace.FireTrace(component, name='triple') + help_screen = helptext.HelpText(component, t) + self.assertIn(formatting.Bold('NAME') + '\n triple', help_screen) + self.assertIn( + formatting.Bold('SYNOPSIS') + '\n triple ', + help_screen) + self.assertIn( + formatting.Bold('FLAGS') + '\n -c, --' + + formatting.Underline('count'), + help_screen) + + def testHelpTextBoldCommandName(self): + component = tc.ClassWithDocstring() + t = trace.FireTrace(component, name='ClassWithDocstring') + help_screen = helptext.HelpText(component, t) + self.assertIn( + formatting.Bold('NAME') + '\n ClassWithDocstring', help_screen) + self.assertIn(formatting.Bold('COMMANDS') + '\n', help_screen) + self.assertIn( + formatting.BoldUnderline('COMMAND') + ' is one of the following:\n', + help_screen) + self.assertIn(formatting.Bold('print_msg') + '\n', help_screen) + + def testHelpTextObjectWithGroupAndValues(self): + component = tc.TypedProperties() + t = trace.FireTrace(component, name='TypedProperties') + help_screen = helptext.HelpText( + component=component, trace=t, verbose=True) + print(help_screen) + self.assertIn('GROUPS', help_screen) + self.assertIn('GROUP is one of the following:', help_screen) + self.assertIn( + 'charlie\n Class with functions that have default arguments.', + help_screen) + self.assertIn('VALUES', help_screen) + self.assertIn('VALUE is one of the following:', help_screen) + self.assertIn('alpha', help_screen) + + def testHelpTextNameSectionCommandWithSeparator(self): + component = 9 + t = trace.FireTrace(component, name='int', separator='-') + t.AddSeparator() + help_screen = helptext.HelpText(component=component, trace=t, verbose=False) + self.assertIn('int -', help_screen) + self.assertNotIn('int - -', help_screen) + + def testHelpTextNameSectionCommandWithSeparatorVerbose(self): + component = tc.WithDefaults().double + t = trace.FireTrace(component, name='double', separator='-') + t.AddSeparator() + help_screen = helptext.HelpText(component=component, trace=t, verbose=True) + self.assertIn('double -', help_screen) + self.assertIn('double - -', help_screen) + + def testHelpTextMultipleKeywoardArgumentsWithShortArgs(self): + component = tc.fn_with_multiple_defaults + t = trace.FireTrace(component, name='shortargs') + help_screen = helptext.HelpText(component, t) + self.assertIn(formatting.Bold('NAME') + '\n shortargs', help_screen) + self.assertIn( + formatting.Bold('SYNOPSIS') + '\n shortargs ', + help_screen) + self.assertIn( + formatting.Bold('FLAGS') + '\n -f, --first', + help_screen) + self.assertIn('\n --last', help_screen) + self.assertIn('\n --late', help_screen) + + +class UsageTest(testutils.BaseTestCase): + + def testUsageOutput(self): + component = tc.NoDefaults() + t = trace.FireTrace(component, name='NoDefaults') + usage_output = helptext.UsageText(component, trace=t, verbose=False) + expected_output = """ + Usage: NoDefaults + available commands: double | triple + + For detailed information on this command, run: + NoDefaults --help""" + + self.assertEqual( + usage_output, + textwrap.dedent(expected_output).lstrip('\n')) + + def testUsageOutputVerbose(self): + component = tc.NoDefaults() + t = trace.FireTrace(component, name='NoDefaults') + usage_output = helptext.UsageText(component, trace=t, verbose=True) + expected_output = """ + Usage: NoDefaults + available commands: double | triple + + For detailed information on this command, run: + NoDefaults --help""" + self.assertEqual( + usage_output, + textwrap.dedent(expected_output).lstrip('\n')) + + def testUsageOutputMethod(self): + component = tc.NoDefaults().double + t = trace.FireTrace(component, name='NoDefaults') + t.AddAccessedProperty(component, 'double', ['double'], None, None) + usage_output = helptext.UsageText(component, trace=t, verbose=False) + expected_output = """ + Usage: NoDefaults double COUNT + + For detailed information on this command, run: + NoDefaults double --help""" + self.assertEqual( + usage_output, + textwrap.dedent(expected_output).lstrip('\n')) + + def testUsageOutputFunctionWithHelp(self): + component = tc.function_with_help + t = trace.FireTrace(component, name='function_with_help') + usage_output = helptext.UsageText(component, trace=t, verbose=False) + expected_output = """ + Usage: function_with_help + optional flags: --help + + For detailed information on this command, run: + function_with_help -- --help""" + self.assertEqual( + usage_output, + textwrap.dedent(expected_output).lstrip('\n')) + + def testUsageOutputFunctionWithDocstring(self): + component = tc.multiplier_with_docstring + t = trace.FireTrace(component, name='multiplier_with_docstring') + usage_output = helptext.UsageText(component, trace=t, verbose=False) + expected_output = """ + Usage: multiplier_with_docstring NUM + optional flags: --rate + + For detailed information on this command, run: + multiplier_with_docstring --help""" + self.assertEqual( + textwrap.dedent(expected_output).lstrip('\n'), + usage_output) + + def testUsageOutputFunctionMixedDefaults(self): + component = tc.py3.HelpTextComponent().identity + t = trace.FireTrace(component, name='FunctionMixedDefaults') + usage_output = helptext.UsageText(component, trace=t, verbose=False) + expected_output = """ + Usage: FunctionMixedDefaults + optional flags: --beta + required flags: --alpha + + For detailed information on this command, run: + FunctionMixedDefaults --help""" + expected_output = textwrap.dedent(expected_output).lstrip('\n') + self.assertEqual(expected_output, usage_output) + + def testUsageOutputCallable(self): + # This is both a group and a command. + component = tc.CallableWithKeywordArgument() + t = trace.FireTrace(component, name='CallableWithKeywordArgument', + separator='@') + usage_output = helptext.UsageText(component, trace=t, verbose=False) + expected_output = """ + Usage: CallableWithKeywordArgument | + available commands: print_msg + flags are accepted + + For detailed information on this command, run: + CallableWithKeywordArgument -- --help""" + self.assertEqual( + textwrap.dedent(expected_output).lstrip('\n'), + usage_output) + + def testUsageOutputConstructorWithParameter(self): + component = tc.InstanceVars + t = trace.FireTrace(component, name='InstanceVars') + usage_output = helptext.UsageText(component, trace=t, verbose=False) + expected_output = """ + Usage: InstanceVars --arg1=ARG1 --arg2=ARG2 + + For detailed information on this command, run: + InstanceVars --help""" + self.assertEqual( + textwrap.dedent(expected_output).lstrip('\n'), + usage_output) + + def testUsageOutputConstructorWithParameterVerbose(self): + component = tc.InstanceVars + t = trace.FireTrace(component, name='InstanceVars') + usage_output = helptext.UsageText(component, trace=t, verbose=True) + expected_output = """ + Usage: InstanceVars | --arg1=ARG1 --arg2=ARG2 + available commands: run + + For detailed information on this command, run: + InstanceVars --help""" + self.assertEqual( + textwrap.dedent(expected_output).lstrip('\n'), + usage_output) + + def testUsageOutputEmptyDict(self): + component = {} + t = trace.FireTrace(component, name='EmptyDict') + usage_output = helptext.UsageText(component, trace=t, verbose=True) + expected_output = """ + Usage: EmptyDict + + For detailed information on this command, run: + EmptyDict --help""" + self.assertEqual( + textwrap.dedent(expected_output).lstrip('\n'), + usage_output) + + def testUsageOutputNone(self): + component = None + t = trace.FireTrace(component, name='None') + usage_output = helptext.UsageText(component, trace=t, verbose=True) + expected_output = """ + Usage: None + + For detailed information on this command, run: + None --help""" + self.assertEqual( + textwrap.dedent(expected_output).lstrip('\n'), + usage_output) + + def testInitRequiresFlagSyntaxSubclassNamedTuple(self): + component = tc.SubPoint + t = trace.FireTrace(component, name='SubPoint') + usage_output = helptext.UsageText(component, trace=t, verbose=False) + expected_output = 'Usage: SubPoint --x=X --y=Y' + self.assertIn(expected_output, usage_output) + +if __name__ == '__main__': + testutils.main() diff --git a/fire/helputils.py b/fire/helputils.py deleted file mode 100644 index 43aff84d..00000000 --- a/fire/helputils.py +++ /dev/null @@ -1,197 +0,0 @@ -# Copyright (C) 2017 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utility for producing help strings for use in Fire CLIs. - -Can produce help strings suitable for display in Fire CLIs for any type of -Python object, module, class, or function. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import inspect - -from fire import completion -from fire import inspectutils - - -def _NormalizeField(field): - """Takes a field name and turns it into a human readable name for display. - - Args: - field: The field name, used to index into the inspection dict. - Returns: - The human readable name, suitable for display in a help string. - """ - if field == 'type_name': - field = 'type' - return (field[0].upper() + field[1:]).replace('_', ' ') - - -def _DisplayValue(info, field, padding): - """Gets the value of field from the dict info for display. - - Args: - info: The dict with information about the component. - field: The field to access for display. - padding: Number of spaces to indent text to line up with first-line text. - Returns: - The value of the field for display, or None if no value should be displayed. - """ - value = info.get(field) - - if value is None: - return - - skip_doc_types = ('dict', 'list', 'unicode', 'int', 'float', 'bool') - - if field == 'docstring': - if info.get('type_name') in skip_doc_types: - # Don't show the boring default docstrings for these types. - return None - elif value == '': - return None - - elif field == 'usage': - lines = [] - for index, line in enumerate(value.split('\n')): - if index > 0: - line = ' ' * padding + line - lines.append(line) - return '\n'.join(lines) - - return value - - -def HelpString(component, trace=None, verbose=False): - """Returns a help string for a supplied component. - - The component can be any Python class, object, function, module, etc. - - Args: - component: The component to determine the help string for. - trace: The Fire trace leading to this component. - verbose: Whether to include private members in the help string. - Returns: - String suitable for display giving information about the component. - """ - info = inspectutils.Info(component) - info['usage'] = UsageString(component, trace, verbose) - - fields = [ - 'type_name', - 'string_form', - 'file', - 'line', - - 'docstring', - 'init_docstring', - 'class_docstring', - 'call_docstring', - 'length', - - 'usage', - ] - - max_size = max( - len(_NormalizeField(field)) + 1 - for field in fields - if field in info and info[field]) - format_string = '{{field:{max_size}s}} {{value}}'.format(max_size=max_size) - - lines = [] - for field in fields: - value = _DisplayValue(info, field, padding=max_size + 1) - if value: - if lines and field == 'usage': - lines.append('') # Ensure a blank line before usage. - - lines.append(format_string.format( - field=_NormalizeField(field) + ':', - value=value, - )) - return '\n'.join(lines) - - -def _UsageStringFromFnDetails(command, args, varargs, keywords, defaults): - """Get a usage string from the function details for the given command. - - The strings look like: - command --arg ARG [--opt OPT] [VAR ...] [--KWARGS ...] - - Args: - command: The command leading up to the function. - args: The args accepted by the function. - varargs: If not None, a string naming the *varargs variable used by the fn. - keywords: If not None, a string naming the **kwargs varargs used by the fn. - defaults: The default values for args accepted by the function. - Returns: - The usage string for the function. - """ - num_required_args = len(args) - len(defaults) - - help_flags = [] - help_positional = [] - for index, arg in enumerate(args): - flag = arg.replace('_', '-') - if index < num_required_args: - help_flags.append('--{flag} {value}'.format(flag=flag, value=arg.upper())) - help_positional.append('{value}'.format(value=arg.upper())) - else: - help_flags.append('[--{flag} {value}]'.format( - flag=flag, value=arg.upper())) - help_positional.append('[{value}]'.format(value=arg.upper())) - - if varargs: - help_flags.append('[{var} ...]'.format(var=varargs.upper())) - help_positional.append('[{var} ...]'.format(var=varargs.upper())) - - if keywords: - help_flags.append('[--{kwarg} ...]'.format(kwarg=keywords.upper())) - help_positional.append('[--{kwarg} ...]'.format(kwarg=keywords.upper())) - - commands_flags = command + ' '.join(help_flags) - commands_positional = command + ' '.join(help_positional) - commands = [commands_positional] - - if commands_flags != commands_positional: - commands.append(commands_flags) - - return '\n'.join(commands) - - -def UsageString(component, trace=None, verbose=False): - """Returns a string showing how to use the component as a Fire command.""" - command = trace.GetCommand() + ' ' if trace else '' - - if inspect.isroutine(component) or inspect.isclass(component): - args, varargs, keywords, defaults = inspectutils.GetArgSpec(component) - return _UsageStringFromFnDetails(command, args, varargs, keywords, defaults) - - elif isinstance(component, (list, tuple)): - length = len(component) - if length == 0: - return command - elif length == 1: - return command + '[0]' - else: - return command + '[0..{cap}]'.format(cap=length - 1) - - else: - completions = completion.Completions(component, verbose) - if command: - completions = [''] + completions - return '\n'.join(command + end for end in completions) diff --git a/fire/helputils_test.py b/fire/helputils_test.py deleted file mode 100644 index 4bc8acc9..00000000 --- a/fire/helputils_test.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright (C) 2017 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from fire import helputils -from fire import test_components as tc -import six - -import unittest - - -class HelpUtilsTest(unittest.TestCase): - - def testHelpStringClass(self): - helpstring = helputils.HelpString(tc.NoDefaults) - self.assertIn('Type: type', helpstring) - self.assertIn("String form: ", - helpstring) - self.assertIn('test_components.py', helpstring) - self.assertIn('Line: ', helpstring) - self.assertNotIn('Usage', helpstring) - - def testHelpStringObject(self): - obj = tc.NoDefaults() - helpstring = helputils.HelpString(obj) - self.assertIn('Type: NoDefaults', helpstring) - self.assertIn('String form: ", helpstring) - else: - self.assertIn("String form: ", helpstring) - self.assertNotIn('Usage', helpstring) - - def testHelpStringEmptyList(self): - helpstring = helputils.HelpString([]) - self.assertIn('Type: list', helpstring) - self.assertIn('String form: []', helpstring) - self.assertIn('Length: 0', helpstring) - - def testHelpStringShortList(self): - helpstring = helputils.HelpString([10]) - self.assertIn('Type: list', helpstring) - self.assertIn('String form: [10]', helpstring) - self.assertIn('Length: 1', helpstring) - self.assertIn('Usage: [0]', helpstring) # [] denotes optional. - - def testHelpStringInt(self): - helpstring = helputils.HelpString(7) - self.assertIn('Type: int', helpstring) - self.assertIn('String form: 7', helpstring) - self.assertIn('Usage: bit-length\n' - ' conjugate\n' - ' denominator\n', helpstring) - - def testHelpClassNoInit(self): - helpstring = helputils.HelpString(tc.OldStyleEmpty) - if six.PY2: - self.assertIn('Type: classobj\n', helpstring) - else: - self.assertIn('Type: type\n', helpstring) - self.assertIn('String form: ', helpstring) - self.assertIn('fire.test_components.OldStyleEmpty', helpstring) - self.assertIn('fire/test_components.py\n', helpstring) - self.assertIn('Line: ', helpstring) - - -if __name__ == '__main__': - unittest.main() diff --git a/fire/inspectutils.py b/fire/inspectutils.py index c5e6c226..d9c62ca7 100644 --- a/fire/inspectutils.py +++ b/fire/inspectutils.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,17 +14,40 @@ """Inspection utility functions for Python Fire.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - +import asyncio import inspect +import sys +import types + +from fire import docstrings + + +class FullArgSpec: + """The arguments of a function, as in Python 3's inspect.FullArgSpec.""" + + def __init__(self, args=None, varargs=None, varkw=None, defaults=None, + kwonlyargs=None, kwonlydefaults=None, annotations=None): + """Constructs a FullArgSpec with each provided attribute, or the default. -import IPython -import six + Args: + args: A list of the argument names accepted by the function. + varargs: The name of the *varargs argument or None if there isn't one. + varkw: The name of the **kwargs argument or None if there isn't one. + defaults: A tuple of the defaults for the arguments that accept defaults. + kwonlyargs: A list of argument names that must be passed with a keyword. + kwonlydefaults: A dictionary of keyword only arguments and their defaults. + annotations: A dictionary of arguments and their annotated types. + """ + self.args = args or [] + self.varargs = varargs + self.varkw = varkw + self.defaults = defaults or () + self.kwonlyargs = kwonlyargs or [] + self.kwonlydefaults = kwonlydefaults or {} + self.annotations = annotations or {} -def _GetArgSpecFnInfo(fn): +def _GetArgSpecInfo(fn): """Gives information pertaining to computing the ArgSpec of fn. Determines if the first arg is supplied automatically when fn is called. @@ -39,62 +62,177 @@ class with an __init__ method. fn: The function or class of interest. Returns: A tuple with the following two items: - fn: The function to use for determing the arg spec of this function. + fn: The function to use for determining the arg spec of this function. skip_arg: Whether the first argument will be supplied automatically, and hence should be skipped when supplying args from a Fire command. """ skip_arg = False if inspect.isclass(fn): - # If the function is a class, we try to use it's init method. + # If the function is a class, we try to use its init method. skip_arg = True - if six.PY2 and hasattr(fn, '__init__'): - fn = fn.__init__ - else: + elif inspect.ismethod(fn): # If the function is a bound method, we skip the `self` argument. - is_method = inspect.ismethod(fn) - skip_arg = is_method and fn.__self__ is not None - + skip_arg = fn.__self__ is not None + elif inspect.isbuiltin(fn): + # If the function is a bound builtin, we skip the `self` argument, unless + # the function is from a standard library module in which case its __self__ + # attribute is that module. + if not isinstance(fn.__self__, types.ModuleType): + skip_arg = True + elif not inspect.isfunction(fn): + # The purpose of this else clause is to set skip_arg for callable objects. + skip_arg = True return fn, skip_arg -def GetArgSpec(fn): - """Returns information about the function signature. +def Py3GetFullArgSpec(fn): + """A alternative to the builtin getfullargspec. + + The builtin inspect.getfullargspec uses: + `skip_bound_args=False, follow_wrapped_chains=False` + in order to be backwards compatible. + + This function instead skips bound args (self) and follows wrapped chains. Args: - fn: The function to analyze. + fn: The function or class of interest. Returns: - A named tuple of type inspect.ArgSpec with the following fields: - args: A list of the argument names accepted by the function. - varargs: The name of the *varargs argument or None if there isn't one. - keywords: The name of the **kwargs argument or None if there isn't one. - defaults: A tuple of the defaults for the arguments that accept defaults. + An inspect.FullArgSpec namedtuple with the full arg spec of the function. """ - fn, skip_arg = _GetArgSpecFnInfo(fn) + # pylint: disable=no-member + # pytype: disable=module-attr + try: + sig = inspect._signature_from_callable( # pylint: disable=protected-access + fn, + skip_bound_arg=True, + follow_wrapper_chains=True, + sigcls=inspect.Signature) + except Exception: + # 'signature' can raise ValueError (most common), AttributeError, and + # possibly others. We catch all exceptions here, and reraise a TypeError. + raise TypeError('Unsupported callable.') + + args = [] + varargs = None + varkw = None + kwonlyargs = [] + defaults = () + annotations = {} + defaults = () + kwdefaults = {} + + if sig.return_annotation is not sig.empty: + annotations['return'] = sig.return_annotation + + for param in sig.parameters.values(): + kind = param.kind + name = param.name + + # pylint: disable=protected-access + if kind is inspect._POSITIONAL_ONLY: + args.append(name) + elif kind is inspect._POSITIONAL_OR_KEYWORD: + args.append(name) + if param.default is not param.empty: + defaults += (param.default,) + elif kind is inspect._VAR_POSITIONAL: + varargs = name + elif kind is inspect._KEYWORD_ONLY: + kwonlyargs.append(name) + if param.default is not param.empty: + kwdefaults[name] = param.default + elif kind is inspect._VAR_KEYWORD: + varkw = name + if param.annotation is not param.empty: + annotations[name] = param.annotation + # pylint: enable=protected-access + + if not kwdefaults: + # compatibility with 'func.__kwdefaults__' + kwdefaults = None + + if not defaults: + # compatibility with 'func.__defaults__' + defaults = None + return inspect.FullArgSpec(args, varargs, varkw, defaults, + kwonlyargs, kwdefaults, annotations) + # pylint: enable=no-member + # pytype: enable=module-attr + + +def GetFullArgSpec(fn): + """Returns a FullArgSpec describing the given callable.""" + original_fn = fn + fn, skip_arg = _GetArgSpecInfo(fn) try: - argspec = inspect.getargspec(fn) - args = argspec.args - defaults = argspec.defaults or () - varargs = argspec.varargs - keywords = argspec.keywords + if sys.version_info[0:2] >= (3, 5): + (args, varargs, varkw, defaults, + kwonlyargs, kwonlydefaults, annotations) = Py3GetFullArgSpec(fn) + else: # Specifically Python 3.4. + (args, varargs, varkw, defaults, + kwonlyargs, kwonlydefaults, annotations) = inspect.getfullargspec(fn) # pylint: disable=deprecated-method,no-member + except TypeError: - args = [] - defaults = () # If we can't get the argspec, how do we know if the fn should take args? # 1. If it's a builtin, it can take args. - # 2. If it's an implicit __init__ function (a 'slot wrapper'), take no args. - # Are there other cases? - varargs = 'vars' if inspect.isbuiltin(fn) else None - keywords = 'kwargs' if inspect.isbuiltin(fn) else None + # 2. If it's an implicit __init__ function (a 'slot wrapper'), that comes + # from a namedtuple, use _fields to determine the args. + # 3. If it's another slot wrapper (that comes from not subclassing object in + # Python 2), then there are no args. + # Are there other cases? We just don't know. + + # Case 1: Builtins accept args. + if inspect.isbuiltin(fn): + # TODO(dbieber): Try parsing the docstring, if available. + # TODO(dbieber): Use known argspecs, like set.add and namedtuple.count. + return FullArgSpec(varargs='vars', varkw='kwargs') + + # Case 2: namedtuples store their args in their _fields attribute. + # TODO(dbieber): Determine if there's a way to detect false positives. + # In Python 2, a class that does not subclass anything, does not define + # __init__, and has an attribute named _fields will cause Fire to think it + # expects args for its constructor when in fact it does not. + fields = getattr(original_fn, '_fields', None) + if fields is not None: + return FullArgSpec(args=list(fields)) + + # Case 3: Other known slot wrappers do not accept args. + return FullArgSpec() + + # In Python 3.5+ Py3GetFullArgSpec uses skip_bound_arg=True already. + skip_arg_required = sys.version_info[0:2] == (3, 4) + if skip_arg_required and skip_arg and args: + args.pop(0) # Remove 'self' or 'cls' from the list of arguments. + return FullArgSpec(args, varargs, varkw, defaults, + kwonlyargs, kwonlydefaults, annotations) + + +def GetFileAndLine(component): + """Returns the filename and line number of component. + + Args: + component: A component to find the source information for, usually a class + or routine. + Returns: + filename: The name of the file where component is defined. + lineno: The line number where component is defined. + """ + if inspect.isbuiltin(component): + return None, None + + try: + filename = inspect.getsourcefile(component) + except TypeError: + return None, None - if skip_arg: - args = args[1:] # Remove self. + try: + unused_code, lineindex = inspect.findsource(component) + lineno = lineindex + 1 + except (OSError, IndexError): + lineno = None - return inspect.ArgSpec( - args=args, - varargs=varargs, - keywords=keywords, - defaults=defaults) + return filename, lineno def Info(component): @@ -116,13 +254,98 @@ def Info(component): Returns: A dict with information about the component. """ - inspector = IPython.core.oinspect.Inspector() - info = inspector.info(component) + try: + from IPython.core import oinspect # pylint: disable=import-outside-toplevel,g-import-not-at-top + try: + inspector = oinspect.Inspector(theme_name="neutral") + except TypeError: # Only recent versions of IPython support theme_name. + inspector = oinspect.Inspector() + info = inspector.info(component) + + # IPython's oinspect.Inspector.info may return '' + if info['docstring'] == '': + info['docstring'] = None + except ImportError: + info = _InfoBackup(component) try: unused_code, lineindex = inspect.findsource(component) info['line'] = lineindex + 1 - except (TypeError, IOError): + except (TypeError, OSError): info['line'] = None + if 'docstring' in info: + info['docstring_info'] = docstrings.parse(info['docstring']) + return info + + +def _InfoBackup(component): + """Returns a dict with information about the given component. + + This function is to be called only in the case that IPython's + oinspect module is not available. The info dict it produces may + contain less information that contained in the info dict produced + by oinspect. + + Args: + component: The component to analyze. + Returns: + A dict with information about the component. + """ + info = {} + + info['type_name'] = type(component).__name__ + info['string_form'] = str(component) + + filename, lineno = GetFileAndLine(component) + info['file'] = filename + info['line'] = lineno + info['docstring'] = inspect.getdoc(component) + + try: + info['length'] = str(len(component)) + except (TypeError, AttributeError): + pass + + return info + + +def IsNamedTuple(component): + """Return true if the component is a namedtuple. + + Unfortunately, Python offers no native way to check for a namedtuple type. + Instead, we need to use a simple hack which should suffice for our case. + namedtuples are internally implemented as tuples, therefore we need to: + 1. Check if the component is an instance of tuple. + 2. Check if the component has a _fields attribute which regular tuples do + not have. + + Args: + component: The component to analyze. + Returns: + True if the component is a namedtuple or False otherwise. + """ + if not isinstance(component, tuple): + return False + + has_fields = bool(getattr(component, '_fields', None)) + return has_fields + + +def GetClassAttrsDict(component): + """Gets the attributes of the component class, as a dict with name keys.""" + if not inspect.isclass(component): + return None + class_attrs_list = inspect.classify_class_attrs(component) + return { + class_attr.name: class_attr + for class_attr in class_attrs_list + } + + +def IsCoroutineFunction(fn): + try: + return asyncio.iscoroutinefunction(fn) + except: # pylint: disable=bare-except + return False diff --git a/fire/inspectutils_test.py b/fire/inspectutils_test.py index 5cf61900..47de7e72 100644 --- a/fire/inspectutils_test.py +++ b/fire/inspectutils_test.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,65 +12,95 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +"""Tests for the inspectutils module.""" + +import os from fire import inspectutils from fire import test_components as tc -import six - -import unittest - - -class InspectUtilsTest(unittest.TestCase): - - def testGetArgSpecReturnType(self): - # Asserts that the named tuple returned by GetArgSpec has the appropriate - # fields. - argspec = inspectutils.GetArgSpec(tc.identity) - args, varargs, keywords, defaults = argspec - self.assertEqual(argspec.args, args) - self.assertEqual(argspec.defaults, defaults) - self.assertEqual(argspec.varargs, varargs) - self.assertEqual(argspec.keywords, keywords) - - def testGetArgSpec(self): - args, varargs, keywords, defaults = inspectutils.GetArgSpec(tc.identity) - self.assertEqual(args, ['arg1', 'arg2']) - self.assertEqual(defaults, (10,)) - self.assertEqual(varargs, 'arg3') - self.assertEqual(keywords, 'arg4') - - def testGetArgSpecBuiltin(self): - args, varargs, keywords, defaults = inspectutils.GetArgSpec('test'.upper) - self.assertEqual(args, []) - self.assertEqual(defaults, ()) - self.assertEqual(varargs, 'vars') - self.assertEqual(keywords, 'kwargs') - - def testGetArgSpecSlotWrapper(self): - args, varargs, keywords, defaults = inspectutils.GetArgSpec(tc.NoDefaults) - self.assertEqual(args, []) - self.assertEqual(defaults, ()) - self.assertEqual(varargs, None) - self.assertEqual(keywords, None) - - def testGetArgSpecClassNoInit(self): - args, varargs, keywords, defaults = inspectutils.GetArgSpec( - tc.OldStyleEmpty) - self.assertEqual(args, []) - self.assertEqual(defaults, ()) - self.assertEqual(varargs, None) - self.assertEqual(keywords, None) - - def testGetArgSpecMethod(self): - args, varargs, keywords, defaults = inspectutils.GetArgSpec( - tc.NoDefaults().double) - self.assertEqual(args, ['count']) - self.assertEqual(defaults, ()) - self.assertEqual(varargs, None) - self.assertEqual(keywords, None) +from fire import testutils + + +class InspectUtilsTest(testutils.BaseTestCase): + + def testGetFullArgSpec(self): + spec = inspectutils.GetFullArgSpec(tc.identity) + self.assertEqual(spec.args, ['arg1', 'arg2', 'arg3', 'arg4']) + self.assertEqual(spec.defaults, (10, 20)) + self.assertEqual(spec.varargs, 'arg5') + self.assertEqual(spec.varkw, 'arg6') + self.assertEqual(spec.kwonlyargs, []) + self.assertEqual(spec.kwonlydefaults, {}) + self.assertEqual(spec.annotations, {'arg2': int, 'arg4': int}) + + def testGetFullArgSpecPy3(self): + spec = inspectutils.GetFullArgSpec(tc.py3.identity) + self.assertEqual(spec.args, ['arg1', 'arg2', 'arg3', 'arg4']) + self.assertEqual(spec.defaults, (10, 20)) + self.assertEqual(spec.varargs, 'arg5') + self.assertEqual(spec.varkw, 'arg10') + self.assertEqual(spec.kwonlyargs, ['arg6', 'arg7', 'arg8', 'arg9']) + self.assertEqual(spec.kwonlydefaults, {'arg8': 30, 'arg9': 40}) + self.assertEqual(spec.annotations, + {'arg2': int, 'arg4': int, 'arg7': int, 'arg9': int}) + + def testGetFullArgSpecFromBuiltin(self): + spec = inspectutils.GetFullArgSpec('test'.upper) + self.assertEqual(spec.args, []) + self.assertEqual(spec.defaults, ()) + self.assertEqual(spec.kwonlyargs, []) + self.assertEqual(spec.kwonlydefaults, {}) + self.assertEqual(spec.annotations, {}) + + def testGetFullArgSpecFromSlotWrapper(self): + spec = inspectutils.GetFullArgSpec(tc.NoDefaults) + self.assertEqual(spec.args, []) + self.assertEqual(spec.defaults, ()) + self.assertEqual(spec.varargs, None) + self.assertEqual(spec.varkw, None) + self.assertEqual(spec.kwonlyargs, []) + self.assertEqual(spec.kwonlydefaults, {}) + self.assertEqual(spec.annotations, {}) + + def testGetFullArgSpecFromNamedTuple(self): + spec = inspectutils.GetFullArgSpec(tc.NamedTuplePoint) + self.assertEqual(spec.args, ['x', 'y']) + self.assertEqual(spec.defaults, ()) + self.assertEqual(spec.varargs, None) + self.assertEqual(spec.varkw, None) + self.assertEqual(spec.kwonlyargs, []) + self.assertEqual(spec.kwonlydefaults, {}) + self.assertEqual(spec.annotations, {}) + + def testGetFullArgSpecFromNamedTupleSubclass(self): + spec = inspectutils.GetFullArgSpec(tc.SubPoint) + self.assertEqual(spec.args, ['x', 'y']) + self.assertEqual(spec.defaults, ()) + self.assertEqual(spec.varargs, None) + self.assertEqual(spec.varkw, None) + self.assertEqual(spec.kwonlyargs, []) + self.assertEqual(spec.kwonlydefaults, {}) + self.assertEqual(spec.annotations, {}) + + def testGetFullArgSpecFromClassNoInit(self): + spec = inspectutils.GetFullArgSpec(tc.OldStyleEmpty) + self.assertEqual(spec.args, []) + self.assertEqual(spec.defaults, ()) + self.assertEqual(spec.varargs, None) + self.assertEqual(spec.varkw, None) + self.assertEqual(spec.kwonlyargs, []) + self.assertEqual(spec.kwonlydefaults, {}) + self.assertEqual(spec.annotations, {}) + + def testGetFullArgSpecFromMethod(self): + spec = inspectutils.GetFullArgSpec(tc.NoDefaults().double) + self.assertEqual(spec.args, ['count']) + self.assertEqual(spec.defaults, ()) + self.assertEqual(spec.varargs, None) + self.assertEqual(spec.varkw, None) + self.assertEqual(spec.kwonlyargs, []) + self.assertEqual(spec.kwonlydefaults, {}) + self.assertEqual(spec.annotations, {}) def testInfoOne(self): info = inspectutils.Info(1) @@ -82,18 +112,19 @@ def testInfoOne(self): def testInfoClass(self): info = inspectutils.Info(tc.NoDefaults) self.assertEqual(info.get('type_name'), 'type') - self.assertIn('fire/test_components.py', info.get('file')) + self.assertIn(os.path.join('fire', 'test_components.py'), info.get('file')) self.assertGreater(info.get('line'), 0) def testInfoClassNoInit(self): info = inspectutils.Info(tc.OldStyleEmpty) - if six.PY2: - self.assertEqual(info.get('type_name'), 'classobj') - else: - self.assertEqual(info.get('type_name'), 'type') - self.assertIn('fire/test_components.py', info.get('file')) + self.assertEqual(info.get('type_name'), 'type') + self.assertIn(os.path.join('fire', 'test_components.py'), info.get('file')) self.assertGreater(info.get('line'), 0) + def testInfoNoDocstring(self): + info = inspectutils.Info(tc.NoDefaults) + self.assertEqual(info['docstring'], None, 'Docstring should be None') + if __name__ == '__main__': - unittest.main() + testutils.main() diff --git a/fire/interact.py b/fire/interact.py index 4f470dca..eccd3990 100644 --- a/fire/interact.py +++ b/fire/interact.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,17 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""This module enables interactive mode in Python Fire.""" +"""This module enables interactive mode in Python Fire. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +It uses IPython as an optional dependency. When IPython is installed, the +interactive flag will use IPython's REPL. When IPython is not installed, the +interactive flag will start a Python REPL with the builtin `code` module's +InteractiveConsole class. +""" -import code import inspect -import IPython - def Embed(variables, verbose=False): """Drops into a Python REPL with variables available as local variables. @@ -33,7 +32,11 @@ def Embed(variables, verbose=False): verbose: Whether to include 'hidden' members, those keys starting with _. """ print(_AvailableString(variables, verbose)) - _EmbedIPython(variables) + + try: + _EmbedIPython(variables) + except ImportError: + _EmbedCode(variables) def _AvailableString(variables, verbose=False): @@ -62,16 +65,17 @@ def _AvailableString(variables, verbose=False): lists = [ ('Modules', modules), ('Objects', other)] - liststrs = [] + list_strs = [] for name, varlist in lists: if varlist: - liststrs.append( - '{name}: {items}'.format(name=name, items=', '.join(sorted(varlist)))) + items_str = ', '.join(sorted(varlist)) + list_strs.append(f'{name}: {items_str}') + lists_str = '\n'.join(list_strs) return ( 'Fire is starting a Python REPL with the following objects:\n' - '{liststrs}\n' - ).format(liststrs='\n'.join(liststrs)) + f'{lists_str}\n' + ) def _EmbedIPython(variables, argv=None): @@ -82,9 +86,11 @@ def _EmbedIPython(variables, argv=None): Values are variable values. argv: The argv to use for starting ipython. Defaults to an empty list. """ + import IPython # pylint: disable=import-outside-toplevel,g-import-not-at-top argv = argv or [] IPython.start_ipython(argv=argv, user_ns=variables) def _EmbedCode(variables): + import code # pylint: disable=import-outside-toplevel,g-import-not-at-top code.InteractiveConsole(variables).interact() diff --git a/fire/interact_test.py b/fire/interact_test.py index 9a2871a8..2f286824 100644 --- a/fire/interact_test.py +++ b/fire/interact_test.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,32 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +"""Tests for the interact module.""" + +from unittest import mock from fire import interact -import mock +from fire import testutils + -import unittest +try: + import IPython # pylint: disable=unused-import, g-import-not-at-top + INTERACT_METHOD = 'IPython.start_ipython' +except ImportError: + INTERACT_METHOD = 'code.InteractiveConsole' -class InteractTest(unittest.TestCase): +class InteractTest(testutils.BaseTestCase): - @mock.patch('IPython.start_ipython') - def testInteract(self, mock_ipython): - self.assertFalse(mock_ipython.called) + @mock.patch(INTERACT_METHOD) + def testInteract(self, mock_interact_method): + self.assertFalse(mock_interact_method.called) interact.Embed({}) - self.assertTrue(mock_ipython.called) + self.assertTrue(mock_interact_method.called) - @mock.patch('IPython.start_ipython') - def testInteractVariables(self, mock_ipython): - self.assertFalse(mock_ipython.called) + @mock.patch(INTERACT_METHOD) + def testInteractVariables(self, mock_interact_method): + self.assertFalse(mock_interact_method.called) interact.Embed({ 'count': 10, 'mock': mock, }) - self.assertTrue(mock_ipython.called) + self.assertTrue(mock_interact_method.called) if __name__ == '__main__': - unittest.main() + testutils.main() diff --git a/fire/main_test.py b/fire/main_test.py new file mode 100644 index 00000000..a2723347 --- /dev/null +++ b/fire/main_test.py @@ -0,0 +1,94 @@ +# Copyright (C) 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test using Fire via `python -m fire`.""" + +import os +import tempfile + +from fire import __main__ +from fire import testutils + + +class MainModuleTest(testutils.BaseTestCase): + """Tests to verify the behavior of __main__ (python -m fire).""" + + def testNameSetting(self): + # Confirm one of the usage lines has the gettempdir member. + with self.assertOutputMatches('gettempdir'): + __main__.main(['__main__.py', 'tempfile']) + + def testArgPassing(self): + expected = os.path.join('part1', 'part2', 'part3') + with self.assertOutputMatches('%s\n' % expected): + __main__.main( + ['__main__.py', 'os.path', 'join', 'part1', 'part2', 'part3']) + with self.assertOutputMatches('%s\n' % expected): + __main__.main( + ['__main__.py', 'os', 'path', '-', 'join', 'part1', 'part2', 'part3']) + + +class MainModuleFileTest(testutils.BaseTestCase): + """Tests to verify correct import behavior for file executables.""" + + def setUp(self): + super().setUp() + self.file = tempfile.NamedTemporaryFile(suffix='.py') # pylint: disable=consider-using-with + self.file.write(b'class Foo:\n def double(self, n):\n return 2 * n\n') + self.file.flush() + + self.file2 = tempfile.NamedTemporaryFile() # pylint: disable=consider-using-with + + def testFileNameFire(self): + # Confirm that the file is correctly imported and doubles the number. + with self.assertOutputMatches('4'): + __main__.main( + ['__main__.py', self.file.name, 'Foo', 'double', '--n', '2']) + + def testFileNameFailure(self): + # Confirm that an existing file without a .py suffix raises a ValueError. + with self.assertRaises(ValueError): + __main__.main( + ['__main__.py', self.file2.name, 'Foo', 'double', '--n', '2']) + + def testFileNameModuleDuplication(self): + # Confirm that a file that masks a module still loads the module. + with self.assertOutputMatches('gettempdir'): + dirname = os.path.dirname(self.file.name) + with testutils.ChangeDirectory(dirname): + with open('tempfile', 'w'): + __main__.main([ + '__main__.py', + 'tempfile', + ]) + + os.remove('tempfile') + + def testFileNameModuleFileFailure(self): + # Confirm that an invalid file that masks a non-existent module fails. + with self.assertRaisesRegex(ValueError, + r'Fire can only be called on \.py files\.'): # pylint: disable=line-too-long, # pytype: disable=attribute-error + dirname = os.path.dirname(self.file.name) + with testutils.ChangeDirectory(dirname): + with open('foobar', 'w'): + __main__.main([ + '__main__.py', + 'foobar', + ]) + + os.remove('foobar') + + +if __name__ == '__main__': + testutils.main() diff --git a/fire/parser.py b/fire/parser.py index a0cef6d1..d945b8ce 100644 --- a/fire/parser.py +++ b/fire/parser.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,12 +14,14 @@ """Provides parsing functionality used by Python Fire.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import argparse import ast +import sys + +if sys.version_info[0:2] < (3, 8): + _StrNode = ast.Str +else: + _StrNode = ast.Constant def CreateParser(): @@ -27,10 +29,10 @@ def CreateParser(): parser.add_argument('--verbose', '-v', action='store_true') parser.add_argument('--interactive', '-i', action='store_true') parser.add_argument('--separator', default='-') - parser.add_argument('--completion', action='store_true') + parser.add_argument('--completion', nargs='?', const='bash', type=str) parser.add_argument('--help', '-h', action='store_true') parser.add_argument('--trace', '-t', action='store_true') - # TODO: Consider allowing name to be passed as an argument. + # TODO(dbieber): Consider allowing name to be passed as an argument. return parser @@ -94,6 +96,9 @@ def _LiteralEval(value): SyntaxError: If the value string has a syntax error. """ root = ast.parse(value, mode='eval') + if isinstance(root.body, ast.BinOp): # pytype: disable=attribute-error + raise ValueError(value) + for node in ast.walk(root): for field, child in ast.iter_fields(node): if isinstance(child, list): @@ -103,7 +108,7 @@ def _LiteralEval(value): elif isinstance(child, ast.Name): replacement = _Replacement(child) - node.__setattr__(field, replacement) + setattr(node, field, replacement) # ast.literal_eval supports the following types: # strings, bytes, numbers, tuples, lists, dicts, sets, booleans, and None @@ -124,4 +129,4 @@ def _Replacement(node): # These are the only builtin constants supported by literal_eval. if value in ('True', 'False', 'None'): return node - return ast.Str(value) + return _StrNode(value) diff --git a/fire/parser_fuzz_test.py b/fire/parser_fuzz_test.py index de1d4f96..10f497cf 100644 --- a/fire/parser_fuzz_test.py +++ b/fire/parser_fuzz_test.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,24 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +"""Fuzz tests for the parser module.""" from fire import parser +from fire import testutils from hypothesis import example from hypothesis import given from hypothesis import settings from hypothesis import strategies as st import Levenshtein -import six -import unittest +class ParserFuzzTest(testutils.BaseTestCase): -class ParserFuzzTest(unittest.TestCase): - - @given(st.text(min_size=1), settings=settings.Settings(max_examples=10000)) + @settings(max_examples=10000) + @given(st.text(min_size=1)) @example('True') @example(r'"test\t\t\a\\a"') @example(r' "test\t\t\a\\a" ') @@ -56,7 +53,7 @@ def testDefaultParseValueFuzz(self, value): result = parser.DefaultParseValue(value) except TypeError: # It's OK to get a TypeError if the string has the null character. - if u'\x00' in value: + if '\x00' in value: return raise except MemoryError: @@ -66,8 +63,8 @@ def testDefaultParseValueFuzz(self, value): raise try: - uvalue = unicode(value) - uresult = unicode(result) + uvalue = str(value) + uresult = str(result) except UnicodeDecodeError: # This is not what we're testing. return @@ -84,7 +81,7 @@ def testDefaultParseValueFuzz(self, value): if '#' in value: max_distance += len(value) - value.index('#') - if not isinstance(result, six.string_types): + if not isinstance(result, str): max_distance += value.count('0') # Leading 0s are stripped. # Note: We don't check distance for dicts since item order can be changed. @@ -94,4 +91,4 @@ def testDefaultParseValueFuzz(self, value): if __name__ == '__main__': - unittest.main() + testutils.main() diff --git a/fire/parser_test.py b/fire/parser_test.py index dce4e7cc..a404eea2 100644 --- a/fire/parser_test.py +++ b/fire/parser_test.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,16 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +"""Tests for the parser module.""" from fire import parser +from fire import testutils -import unittest - -class ParserTest(unittest.TestCase): +class ParserTest(testutils.BaseTestCase): def testCreateParser(self): self.assertIsNotNone(parser.CreateParser()) @@ -68,10 +65,12 @@ def testDefaultParseValueSpecialStrings(self): def testDefaultParseValueNumbers(self): self.assertEqual(parser.DefaultParseValue('23'), 23) + self.assertEqual(parser.DefaultParseValue('-23'), -23) self.assertEqual(parser.DefaultParseValue('23.0'), 23.0) self.assertIsInstance(parser.DefaultParseValue('23'), int) self.assertIsInstance(parser.DefaultParseValue('23.0'), float) self.assertEqual(parser.DefaultParseValue('23.5'), 23.5) + self.assertEqual(parser.DefaultParseValue('-23.5'), -23.5) def testDefaultParseValueStringNumbers(self): self.assertEqual(parser.DefaultParseValue("'23'"), '23') @@ -114,8 +113,9 @@ def testDefaultParseValueBareWordsTuple(self): def testDefaultParseValueNestedContainers(self): self.assertEqual( - parser.DefaultParseValue('[(A, 2, "3"), 5, {alph: 10.2, beta: "cat"}]'), - [('A', 2, '3'), 5, {'alph': 10.2, 'beta': 'cat'}]) + parser.DefaultParseValue( + '[(A, 2, "3"), 5, {alpha: 10.2, beta: "cat"}]'), + [('A', 2, '3'), 5, {'alpha': 10.2, 'beta': 'cat'}]) def testDefaultParseValueComments(self): self.assertEqual(parser.DefaultParseValue('"0#comments"'), '0#comments') @@ -126,13 +126,15 @@ def testDefaultParseValueBadLiteral(self): # If it can't be parsed, we treat it as a string. This behavior may change. self.assertEqual( parser.DefaultParseValue('[(A, 2, "3"), 5'), '[(A, 2, "3"), 5') - self.assertEqual(parser.DefaultParseValue('x=10'), 'x=10') def testDefaultParseValueSyntaxError(self): # If it can't be parsed, we treat it as a string. self.assertEqual(parser.DefaultParseValue('"'), '"') + def testDefaultParseValueIgnoreBinOp(self): + self.assertEqual(parser.DefaultParseValue('2017-10-10'), '2017-10-10') + self.assertEqual(parser.DefaultParseValue('1+1'), '1+1') if __name__ == '__main__': - unittest.main() + testutils.main() diff --git a/fire/test_components.py b/fire/test_components.py index 0218bab7..887a0dc6 100644 --- a/fire/test_components.py +++ b/fire/test_components.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,32 +12,67 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Thie module has componenets that are used for testing Python Fire.""" +"""This module has components that are used for testing Python Fire.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +import collections +import enum +import functools +from fire import test_components_py3 as py3 # pylint: disable=unused-import,no-name-in-module,g-import-not-at-top -def identity(arg1, arg2=10, *arg3, **arg4): - return arg1, arg2, arg3, arg4 +def identity(arg1, arg2, arg3=10, arg4=20, *arg5, **arg6): # pylint: disable=keyword-arg-before-vararg + return arg1, arg2, arg3, arg4, arg5, arg6 -class Empty(object): +identity.__annotations__ = {'arg2': int, 'arg4': int} + + +def multiplier_with_docstring(num, rate=2): + """Multiplies num by rate. + + Args: + num (int): the num you want to multiply + rate (int): the rate for multiplication + Returns: + Multiplication of num by rate + """ + return num * rate + + +def function_with_help(help=True): # pylint: disable=redefined-builtin + return help + + +class Empty: pass -class OldStyleEmpty: # pylint: disable=old-style-class +class OldStyleEmpty: # pylint: disable=old-style-class,no-init pass -class WithInit(object): +class WithInit: def __init__(self): pass -class NoDefaults(object): +class ErrorInConstructor: + + def __init__(self, value='value'): + self.value = value + raise ValueError('Error in constructor') + + +class WithHelpArg: + """Test class for testing when class has a help= arg.""" + + def __init__(self, help=True): # pylint: disable=redefined-builtin + self.has_help = help + self.dictionary = {'__help': 'help in a dict'} + + +class NoDefaults: def double(self, count): return 2 * count @@ -46,16 +81,32 @@ def triple(self, count): return 3 * count -class WithDefaults(object): +class WithDefaults: + """Class with functions that have default arguments.""" def double(self, count=0): + """Returns the input multiplied by 2. + + Args: + count: Input number that you want to double. + + Returns: + A number that is the double of count. + """ return 2 * count def triple(self, count=0): return 3 * count + def text( + self, + string=('0001020304050607080910111213141516171819' + '2021222324252627282930313233343536373839') + ): + return string + -class OldStyleWithDefaults: # pylint: disable=old-style-class +class OldStyleWithDefaults: # pylint: disable=old-style-class,no-init def double(self, count=0): return 2 * count @@ -64,7 +115,7 @@ def triple(self, count=0): return 3 * count -class MixedDefaults(object): +class MixedDefaults: def ten(self): return 10 @@ -76,7 +127,35 @@ def identity(self, alpha, beta='0'): return alpha, beta -class TypedProperties(object): +class SimilarArgNames: + + def identity(self, bool_one=False, bool_two=False): + return bool_one, bool_two + + def identity2(self, a=None, alpha=None): + return a, alpha + + +class CapitalizedArgNames: + + def sum(self, Delta=1.0, Gamma=2.0): # pylint: disable=invalid-name + return Delta + Gamma + + +class Annotations: + + def double(self, count=0): + return 2 * count + + def triple(self, count=0): + return 3 * count + + double.__annotations__ = {'count': float} + triple.__annotations__ = {'count': float} + + +class TypedProperties: + """Test class for testing Python Fire with properties of various types.""" def __init__(self): self.alpha = True @@ -91,10 +170,11 @@ def __init__(self): } self.echo = ['alex', 'bethany'] self.fox = ('carry', 'divide') + self.gamma = 'myexcitingstring' -class VarArgs(object): - """Test class G for testing Python Fire.""" +class VarArgs: + """Test class for testing Python Fire with a property with varargs.""" def cumsums(self, *items): total = None @@ -107,11 +187,11 @@ def cumsums(self, *items): sums.append(total) return sums - def varchars(self, alpha=0, beta=0, *chars): + def varchars(self, alpha=0, beta=0, *chars): # pylint: disable=keyword-arg-before-vararg return alpha, beta, ''.join(chars) -class Underscores(object): +class Underscores: def __init__(self): self.underscore_example = 'fish fingers' @@ -120,20 +200,20 @@ def underscore_function(self, underscore_arg): return underscore_arg -class BoolConverter(object): +class BoolConverter: def as_bool(self, arg=False): - return arg + return bool(arg) -class ReturnsObj(object): +class ReturnsObj: def get_obj(self, *items): del items # Unused return BoolConverter() -class NumberDefaults(object): +class NumberDefaults: def reciprocal(self, divisor=10.0): return 1.0 / divisor @@ -142,7 +222,7 @@ def integer_reciprocal(self, divisor=10): return 1.0 / divisor -class InstanceVars(object): +class InstanceVars: def __init__(self, arg1, arg2): self.arg1 = arg1 @@ -152,7 +232,7 @@ def run(self, arg1, arg2): return (self.arg1, self.arg2, arg1, arg2) -class Kwargs(object): +class Kwargs: def props(self, **kwargs): return kwargs @@ -164,7 +244,325 @@ def run(self, positional, named=None, **kwargs): return (positional, named, kwargs) -class ErrorRaiser(object): +class ErrorRaiser: def fail(self): raise ValueError('This error is part of a test.') + + +class NonComparable: + + def __eq__(self, other): + raise ValueError('Instances of this class cannot be compared.') + + def __ne__(self, other): + raise ValueError('Instances of this class cannot be compared.') + + +class EmptyDictOutput: + + def totally_empty(self): + return {} + + def nothing_printable(self): + return {'__do_not_print_me': 1} + + +class CircularReference: + + def create(self): + x = {} + x['y'] = x + return x + + +class OrderedDictionary: + + def empty(self): + return collections.OrderedDict() + + def non_empty(self): + ordered_dict = collections.OrderedDict() + ordered_dict['A'] = 'A' + ordered_dict[2] = 2 + return ordered_dict + + +class NamedTuple: + """Functions returning named tuples used for testing.""" + + def point(self): + """Point example straight from Python docs.""" + # pylint: disable=invalid-name + Point = collections.namedtuple('Point', ['x', 'y']) + return Point(11, y=22) + + def matching_names(self): + """Field name equals value.""" + # pylint: disable=invalid-name + Point = collections.namedtuple('Point', ['x', 'y']) + return Point(x='x', y='y') + + +class CallableWithPositionalArgs: + """Test class for supporting callable.""" + + TEST = 1 + + def __call__(self, x, y): + return x + y + + def fn(self, x): + return x + 1 + + +NamedTuplePoint = collections.namedtuple('NamedTuplePoint', ['x', 'y']) + + +class SubPoint(NamedTuplePoint): + """Used for verifying subclasses of namedtuples behave as intended.""" + + def coordinate_sum(self): + return self.x + self.y + + +class CallableWithKeywordArgument: + """Test class for supporting callable.""" + + def __call__(self, **kwargs): + for key, value in kwargs.items(): + print('{}: {}'.format(key, value)) + + def print_msg(self, msg): + print(msg) + + +CALLABLE_WITH_KEYWORD_ARGUMENT = CallableWithKeywordArgument() + + +class ClassWithDocstring: + """Test class for testing help text output. + + This is some detail description of this test class. + """ + + def __init__(self, message='Hello!'): + """Constructor of the test class. + + Constructs a new ClassWithDocstring object. + + Args: + message: The default message to print. + """ + self.message = message + + def print_msg(self, msg=None): + """Prints a message.""" + if msg is None: + msg = self.message + print(msg) + + +class ClassWithMultilineDocstring: + """Test class for testing help text output with multiline docstring. + + This is a test class that has a long docstring description that spans across + multiple lines for testing line breaking in help text. + """ + + @staticmethod + def example_generator(n): + """Generators have a ``Yields`` section instead of a ``Returns`` section. + + Args: + n (int): The upper limit of the range to generate, from 0 to `n` - 1. + + Yields: + int: The next number in the range of 0 to `n` - 1. + + Examples: + Examples should be written in doctest format, and should illustrate how + to use the function. + + >>> print([i for i in example_generator(4)]) + [0, 1, 2, 3] + + """ + yield from range(n) + + +def simple_set(): + return {1, 2, 'three'} + + +def simple_frozenset(): + return frozenset({1, 2, 'three'}) + + +class Subdict(dict): + """A subclass of dict, for testing purposes.""" + + +# An example subdict. +SUBDICT = Subdict({1: 2, 'red': 'blue'}) + + +class Color(enum.Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + +class HasStaticAndClassMethods: + """A class with a static method and a class method.""" + + CLASS_STATE = 1 + + def __init__(self, instance_state): + self.instance_state = instance_state + + @staticmethod + def static_fn(args): + return args + + @classmethod + def class_fn(cls, args): + return args + cls.CLASS_STATE + + +def function_with_varargs(arg1, arg2, arg3=1, *varargs): # pylint: disable=keyword-arg-before-vararg + """Function with varargs. + + Args: + arg1: Position arg docstring. + arg2: Position arg docstring. + arg3: Flags docstring. + *varargs: Accepts unlimited positional args. + Returns: + The unlimited positional args. + """ + del arg1, arg2, arg3 # Unused. + return varargs + + +def function_with_keyword_arguments(arg1, arg2=3, **kwargs): + del arg2 # Unused. + return arg1, kwargs + + +def fn_with_code_in_docstring(): + """This has code in the docstring. + + + + Example: + x = fn_with_code_in_docstring() + indentation_matters = True + + + + Returns: + True. + """ + return True + + +class BinaryCanvas: + """A canvas with which to make binary art, one bit at a time.""" + + def __init__(self, size=10): + self.pixels = [[0] * size for _ in range(size)] + self._size = size + self._row = 0 # The row of the cursor. + self._col = 0 # The column of the cursor. + + def __str__(self): + return '\n'.join( + ' '.join(str(pixel) for pixel in row) for row in self.pixels) + + def show(self): + print(self) + return self + + def move(self, row, col): + self._row = row % self._size + self._col = col % self._size + return self + + def on(self): + return self.set(1) + + def off(self): + return self.set(0) + + def set(self, value): + self.pixels[self._row][self._col] = value + return self + + +class DefaultMethod: + + def double(self, number): + return 2 * number + + def __getattr__(self, name): + def _missing(): + return 'Undefined function' + return _missing + + +class InvalidProperty: + + def double(self, number): + return 2 * number + + @property + def prop(self): + raise ValueError('test') + + +def simple_decorator(f): + @functools.wraps(f) + def wrapper(*args, **kwargs): + return f(*args, **kwargs) + return wrapper + + +@simple_decorator +def decorated_method(name='World'): + return 'Hello %s' % name + + +# pylint: disable=g-doc-args,g-doc-return-or-yield +def fn_with_kwarg(arg1, arg2, **kwargs): + """Function with kwarg. + + :param arg1: Description of arg1. + :param arg2: Description of arg2. + :key arg3: Description of arg3. + """ + del arg1, arg2 + return kwargs.get('arg3') + + +def fn_with_kwarg_and_defaults(arg1, arg2, opt=True, **kwargs): + """Function with kwarg and defaults. + + :param arg1: Description of arg1. + :param arg2: Description of arg2. + :key arg3: Description of arg3. + """ + del arg1, arg2, opt + return kwargs.get('arg3') + + +def fn_with_multiple_defaults(first='first', last='last', late='late'): + """Function with kwarg and defaults. + + :key first: Description of first. + :key last: Description of last. + :key late: Description of late. + """ + del last, late + return first +# pylint: enable=g-doc-args,g-doc-return-or-yield diff --git a/fire/test_components_bin.py b/fire/test_components_bin.py new file mode 100644 index 00000000..62afdf11 --- /dev/null +++ b/fire/test_components_bin.py @@ -0,0 +1,28 @@ +# Copyright (C) 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Python Fire test components Fire CLI. + +This file is useful for replicating test results manually. +""" + +import fire +from fire import test_components + + +def main(): + fire.Fire(test_components) + +if __name__ == '__main__': + main() diff --git a/fire/test_components_py3.py b/fire/test_components_py3.py new file mode 100644 index 00000000..192302d3 --- /dev/null +++ b/fire/test_components_py3.py @@ -0,0 +1,101 @@ +# Copyright (C) 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module has components that use Python 3 specific syntax.""" + +import asyncio +import functools +from typing import Tuple + + +# pylint: disable=keyword-arg-before-vararg +def identity(arg1, arg2: int, arg3=10, arg4: int = 20, *arg5, + arg6, arg7: int, arg8=30, arg9: int = 40, **arg10): + return arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10 + + +class HelpTextComponent: + + def identity(self, *, alpha, beta='0'): + return alpha, beta + + +class KeywordOnly: + + def double(self, *, count): + return count * 2 + + def triple(self, *, count): + return count * 3 + + def with_default(self, *, x="x"): + print("x: " + x) + + +class LruCacheDecoratedMethod: + + @functools.lru_cache() + def lru_cache_in_class(self, arg1): + return arg1 + + +@functools.lru_cache() +def lru_cache_decorated(arg1): + return arg1 + + +class WithAsyncio: + + async def double(self, count=0): + return 2 * count + + +class WithTypes: + """Class with functions that have default arguments and types.""" + + def double(self, count: float) -> float: + """Returns the input multiplied by 2. + + Args: + count: Input number that you want to double. + + Returns: + A number that is the double of count. + """ + return 2 * count + + def long_type( + self, + long_obj: (Tuple[Tuple[Tuple[Tuple[Tuple[Tuple[Tuple[ + Tuple[Tuple[Tuple[Tuple[Tuple[int]]]]]]]]]]]]) + ): + return long_obj + + +class WithDefaultsAndTypes: + """Class with functions that have default arguments and types.""" + + def double(self, count: float = 0) -> float: + """Returns the input multiplied by 2. + + Args: + count: Input number that you want to double. + + Returns: + A number that is the double of count. + """ + return 2 * count + + def get_int(self, value: int = None): + return 0 if value is None else value diff --git a/fire/test_components_test.py b/fire/test_components_test.py index 3b6edefe..531f882c 100644 --- a/fire/test_components_test.py +++ b/fire/test_components_test.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,22 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +"""Tests for the test_components module.""" from fire import test_components as tc +from fire import testutils -import unittest - -class TestComponentsTest(unittest.TestCase): +class TestComponentsTest(testutils.BaseTestCase): """Tests to verify that the test components are importable and okay.""" def testTestComponents(self): self.assertIsNotNone(tc.Empty) self.assertIsNotNone(tc.OldStyleEmpty) + def testNonComparable(self): + with self.assertRaises(ValueError): + tc.NonComparable() != 2 # pylint: disable=expression-not-assigned + with self.assertRaises(ValueError): + tc.NonComparable() == 2 # pylint: disable=expression-not-assigned + if __name__ == '__main__': - unittest.main() + testutils.main() diff --git a/fire/testutils.py b/fire/testutils.py new file mode 100644 index 00000000..eca37f43 --- /dev/null +++ b/fire/testutils.py @@ -0,0 +1,112 @@ +# Copyright (C) 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for Python Fire's tests.""" + +import contextlib +import io +import os +import re +import sys +import unittest +from unittest import mock + +from fire import core +from fire import trace + + +class BaseTestCase(unittest.TestCase): + """Shared test case for Python Fire tests.""" + + @contextlib.contextmanager + def assertOutputMatches(self, stdout='.*', stderr='.*', capture=True): + """Asserts that the context generates stdout and stderr matching regexps. + + Note: If wrapped code raises an exception, stdout and stderr will not be + checked. + + Args: + stdout: (str) regexp to match against stdout (None will check no stdout) + stderr: (str) regexp to match against stderr (None will check no stderr) + capture: (bool, default True) do not bubble up stdout or stderr + + Yields: + Yields to the wrapped context. + """ + stdout_fp = io.StringIO() + stderr_fp = io.StringIO() + try: + with mock.patch.object(sys, 'stdout', stdout_fp): + with mock.patch.object(sys, 'stderr', stderr_fp): + yield + finally: + if not capture: + sys.stdout.write(stdout_fp.getvalue()) + sys.stderr.write(stderr_fp.getvalue()) + + for name, regexp, fp in [('stdout', stdout, stdout_fp), + ('stderr', stderr, stderr_fp)]: + value = fp.getvalue() + if regexp is None: + if value: + raise AssertionError('%s: Expected no output. Got: %r' % + (name, value)) + else: + if not re.search(regexp, value, re.DOTALL | re.MULTILINE): + raise AssertionError('%s: Expected %r to match %r' % + (name, value, regexp)) + + @contextlib.contextmanager + def assertRaisesFireExit(self, code, regexp='.*'): + """Asserts that a FireExit error is raised in the context. + + Allows tests to check that Fire's wrapper around SystemExit is raised + and that a regexp is matched in the output. + + Args: + code: The status code that the FireExit should contain. + regexp: stdout must match this regex. + + Yields: + Yields to the wrapped context. + """ + with self.assertOutputMatches(stderr=regexp): + with self.assertRaises(core.FireExit): + try: + yield + except core.FireExit as exc: + if exc.code != code: + raise AssertionError('Incorrect exit code: %r != %r' % + (exc.code, code)) + self.assertIsInstance(exc.trace, trace.FireTrace) + raise + + +@contextlib.contextmanager +def ChangeDirectory(directory): + """Context manager to mock a directory change and revert on exit.""" + cwdir = os.getcwd() + os.chdir(directory) + + try: + yield directory + finally: + os.chdir(cwdir) + + +# pylint: disable=invalid-name +main = unittest.main +skip = unittest.skip +skipIf = unittest.skipIf +# pylint: enable=invalid-name diff --git a/fire/testutils_test.py b/fire/testutils_test.py new file mode 100644 index 00000000..4cfc0937 --- /dev/null +++ b/fire/testutils_test.py @@ -0,0 +1,53 @@ +# Copyright (C) 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the test utilities for Fire's tests.""" + +import sys + +from fire import testutils + + +class TestTestUtils(testutils.BaseTestCase): + """Let's get meta.""" + + def testNoCheckOnException(self): + with self.assertRaises(ValueError): + with self.assertOutputMatches(stdout='blah'): + raise ValueError() + + def testCheckStdoutOrStderrNone(self): + with self.assertRaisesRegex(AssertionError, 'stdout:'): + with self.assertOutputMatches(stdout=None): + print('blah') + + with self.assertRaisesRegex(AssertionError, 'stderr:'): + with self.assertOutputMatches(stderr=None): + print('blah', file=sys.stderr) + + with self.assertRaisesRegex(AssertionError, 'stderr:'): + with self.assertOutputMatches(stdout='apple', stderr=None): + print('apple') + print('blah', file=sys.stderr) + + def testCorrectOrderingOfAssertRaises(self): + # Check to make sure FireExit tests are correct. + with self.assertOutputMatches(stdout='Yep.*first.*second'): + with self.assertRaises(ValueError): + print('Yep, this is the first line.\nThis is the second.') + raise ValueError() + + +if __name__ == '__main__': + testutils.main() diff --git a/fire/trace.py b/fire/trace.py index a901fbd1..4a6d4776 100644 --- a/fire/trace.py +++ b/fire/trace.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,21 +25,20 @@ component will be None. """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +import shlex -import pipes +from fire import inspectutils INITIAL_COMPONENT = 'Initial component' INSTANTIATED_CLASS = 'Instantiated class' CALLED_ROUTINE = 'Called routine' +CALLED_CALLABLE = 'Called callable' ACCESSED_PROPERTY = 'Accessed property' COMPLETION_SCRIPT = 'Generated completion script' INTERACTIVE_MODE = 'Entered interactive mode' -class FireTrace(object): +class FireTrace: """A FireTrace represents the steps taken during a single Fire execution. A FireTrace consists of a sequence of FireTraceElement objects. Each element @@ -63,7 +62,9 @@ def __init__(self, initial_component, name=None, separator='-', verbose=False, def GetResult(self): """Returns the component from the last element of the trace.""" + # pytype: disable=attribute-error return self.GetLastHealthyElement().component + # pytype: enable=attribute-error def GetLastHealthyElement(self): """Returns the last element of the trace that is not an error. @@ -76,6 +77,7 @@ def GetLastHealthyElement(self): for element in reversed(self.elements): if not element.HasError(): return element + return self.elements[0] # The initial element is always healthy. def HasError(self): """Returns whether the Fire execution encountered a Fire usage error.""" @@ -92,44 +94,24 @@ def AddAccessedProperty(self, component, target, args, filename, lineno): ) self.elements.append(element) - def AddCalledRoutine(self, component, target, args, filename, lineno, - capacity): - """Adds an element to the trace indicating that a routine was called. + def AddCalledComponent(self, component, target, args, filename, lineno, + capacity, action=CALLED_CALLABLE): + """Adds an element to the trace indicating that a component was called. - Args: - component: The result of calling the routine. - target: The name of the routine. - args: The args consumed in order to call this routine. - filename: The file in which the routine is defined, or None if N/A. - lineno: The line number on which the routine is defined, or None if N/A. - capacity: (bool) Whether the routine could have accepted additional args. - """ - element = FireTraceElement( - component=component, - action=CALLED_ROUTINE, - target=target, - args=args, - filename=filename, - lineno=lineno, - capacity=capacity, - ) - self.elements.append(element) - - def AddInstantiatedClass(self, component, target, args, filename, lineno, - capacity): - """Adds an element to the trace indicating that a class was instantiated. + Also applies to instantiating a class. Args: - component: The result of instantiating the class. - target: The name of the class. - args: The args consumed in order to instantiate the class. - filename: The file in which the class is defined, or None if N/A. - lineno: The line number on which the class is defined, or None if N/A. - capacity: (bool) Whether cls.__init__ could have accepted additional args. + component: The result of calling the callable. + target: The name of the callable. + args: The args consumed in order to call this callable. + filename: The file in which the callable is defined, or None if N/A. + lineno: The line number on which the callable is defined, or None if N/A. + capacity: (bool) Whether the callable could have accepted additional args. + action: The value to include as the action in the FireTraceElement. """ element = FireTraceElement( component=component, - action=INSTANTIATED_CLASS, + action=action, target=target, args=args, filename=filename, @@ -180,12 +162,15 @@ def display(arg1, arg2='!'): def _Quote(self, arg): if arg.startswith('--') and '=' in arg: prefix, value = arg.split('=', 1) - return pipes.quote(prefix) + '=' + pipes.quote(value) - return pipes.quote(arg) + return shlex.quote(prefix) + '=' + shlex.quote(value) + return shlex.quote(arg) - def GetCommand(self): + def GetCommand(self, include_separators=True): """Returns the command representing the trace up to this point. + Args: + include_separators: Whether or not to include separators in the command. + Returns: A string representing a Fire CLI command that would produce this trace. """ @@ -198,10 +183,10 @@ def GetCommand(self): continue if element.args: args.extend(element.args) - if element.HasSeparator(): + if element.HasSeparator() and include_separators: args.append(self.separator) - if self.NeedsSeparator(): + if self.NeedsSeparator() and include_separators: args.append(self.separator) return ' '.join(self._Quote(arg) for arg in args) @@ -225,16 +210,35 @@ def NeedsSeparator(self): return element.HasCapacity() and not element.HasSeparator() def __str__(self): - return '\n'.join( - '{index}. {trace_string}'.format( - index=index + 1, - trace_string=element, - ) - for index, element in enumerate(self.elements) - ) + lines = [] + for index, element in enumerate(self.elements): + line = f'{index + 1}. {element}' + lines.append(line) + return '\n'.join(lines) + + def NeedsSeparatingHyphenHyphen(self, flag='help'): + """Returns whether a the trace need '--' before '--help'. + '--' is needed when the component takes keyword arguments, when the value of + flag matches one of the argument of the component, or the component takes in + keyword-only arguments(e.g. argument with default value). + + Args: + flag: the flag available for the trace -class FireTraceElement(object): + Returns: + True for needed '--', False otherwise. + + """ + element = self.GetLastHealthyElement() + component = element.component + spec = inspectutils.GetFullArgSpec(component) + return (spec.varkw is not None + or flag in spec.args + or flag in spec.kwonlyargs) + + +class FireTraceElement: """A FireTraceElement represents a single step taken by a Fire execution. Examples of a FireTraceElement are the instantiation of a class or the @@ -254,7 +258,7 @@ def __init__(self, Args: component: The result of this element of the trace. - action: The type of action (eg instantiating a class) taking place. + action: The type of action (e.g. instantiating a class) taking place. target: (string) The name of the component being acted upon. args: The args consumed by the represented action. filename: The file in which the action is defined, or None if N/A. @@ -284,18 +288,21 @@ def HasSeparator(self): def AddSeparator(self): self._separator = True + def ErrorAsStr(self): + return ' '.join(str(arg) for arg in self._error.args) + def __str__(self): - if not self.HasError(): + if self.HasError(): + return self.ErrorAsStr() + else: # Format is: {action} "{target}" ({filename}:{lineno}) string = self._action if self._target is not None: - string += ' "{target}"'.format(target=self._target) + string += f' "{self._target}"' if self._filename is not None: path = self._filename if self._lineno is not None: - path += ':{lineno}'.format(lineno=self._lineno) + path += f':{self._lineno}' - string += ' ({path})'.format(path=path) + string += f' ({path})' return string - else: - return str(self._error) diff --git a/fire/trace_test.py b/fire/trace_test.py index f1f88273..1f858f5e 100644 --- a/fire/trace_test.py +++ b/fire/trace_test.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,16 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +"""Tests for the trace module.""" +from fire import testutils from fire import trace -import unittest - -class FireTraceTest(unittest.TestCase): +class FireTraceTest(testutils.BaseTestCase): def testFireTraceInitialization(self): t = trace.FireTrace(10) @@ -50,10 +47,20 @@ def testAddAccessedProperty(self): str(t), '1. Initial component\n2. Accessed property "prop" (sample.py:12)') + def testAddCalledCallable(self): + t = trace.FireTrace('initial object') + args = ('example', 'args') + t.AddCalledComponent('result', 'cell', args, 'sample.py', 10, False, + action=trace.CALLED_CALLABLE) + self.assertEqual( + str(t), + '1. Initial component\n2. Called callable "cell" (sample.py:10)') + def testAddCalledRoutine(self): t = trace.FireTrace('initial object') args = ('example', 'args') - t.AddCalledRoutine('result', 'run', args, 'sample.py', 12, False) + t.AddCalledComponent('result', 'run', args, 'sample.py', 12, False, + action=trace.CALLED_ROUTINE) self.assertEqual( str(t), '1. Initial component\n2. Called routine "run" (sample.py:12)') @@ -61,8 +68,9 @@ def testAddCalledRoutine(self): def testAddInstantiatedClass(self): t = trace.FireTrace('initial object') args = ('example', 'args') - t.AddInstantiatedClass( - 'Classname', 'classname', args, 'sample.py', 12, False) + t.AddCalledComponent( + 'Classname', 'classname', args, 'sample.py', 12, False, + action=trace.INSTANTIATED_CLASS) target = """1. Initial component 2. Instantiated class "classname" (sample.py:12)""" self.assertEqual(str(t), target) @@ -84,23 +92,26 @@ def testAddInteractiveMode(self): def testGetCommand(self): t = trace.FireTrace('initial object') args = ('example', 'args') - t.AddCalledRoutine('result', 'run', args, 'sample.py', 12, False) + t.AddCalledComponent('result', 'run', args, 'sample.py', 12, False, + action=trace.CALLED_ROUTINE) self.assertEqual(t.GetCommand(), 'example args') def testGetCommandWithQuotes(self): t = trace.FireTrace('initial object') args = ('example', 'spaced arg') - t.AddCalledRoutine('result', 'run', args, 'sample.py', 12, False) + t.AddCalledComponent('result', 'run', args, 'sample.py', 12, False, + action=trace.CALLED_ROUTINE) self.assertEqual(t.GetCommand(), "example 'spaced arg'") def testGetCommandWithFlagQuotes(self): t = trace.FireTrace('initial object') args = ('--example=spaced arg',) - t.AddCalledRoutine('result', 'run', args, 'sample.py', 12, False) + t.AddCalledComponent('result', 'run', args, 'sample.py', 12, False, + action=trace.CALLED_ROUTINE) self.assertEqual(t.GetCommand(), "--example='spaced arg'") -class FireTraceElementTest(unittest.TestCase): +class FireTraceElementTest(testutils.BaseTestCase): def testFireTraceElementHasError(self): el = trace.FireTraceElement() @@ -136,4 +147,4 @@ def testFireTraceElementAsStringWithTargetAndLineNo(self): if __name__ == '__main__': - unittest.main() + testutils.main() diff --git a/fire/value_types.py b/fire/value_types.py new file mode 100644 index 00000000..81308973 --- /dev/null +++ b/fire/value_types.py @@ -0,0 +1,80 @@ +# Copyright (C) 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Types of values.""" + +import inspect + +from fire import inspectutils + + +VALUE_TYPES = (bool, str, bytes, int, float, complex, + type(Ellipsis), type(None), type(NotImplemented)) + + +def IsGroup(component): + # TODO(dbieber): Check if there are any subcomponents. + return not IsCommand(component) and not IsValue(component) + + +def IsCommand(component): + return inspect.isroutine(component) or inspect.isclass(component) + + +def IsValue(component): + return isinstance(component, VALUE_TYPES) or HasCustomStr(component) + + +def IsSimpleGroup(component): + """If a group is simple enough, then we treat it as a value in PrintResult. + + Only if a group contains all value types do we consider it simple enough to + print as a value. + + Args: + component: The group to check for value-group status. + Returns: + A boolean indicating if the group should be treated as a value for printing + purposes. + """ + assert isinstance(component, dict) + for unused_key, value in component.items(): + if not IsValue(value) and not isinstance(value, (list, dict)): + return False + return True + + +def HasCustomStr(component): + """Determines if a component has a custom __str__ method. + + Uses inspect.classify_class_attrs to determine the origin of the object's + __str__ method, if one is present. If it defined by `object` itself, then + it is not considered custom. Otherwise it is. This means that the __str__ + methods of primitives like ints and floats are considered custom. + + Objects with custom __str__ methods are treated as values and can be + serialized in places where more complex objects would have their help screen + shown instead. + + Args: + component: The object to check for a custom __str__ method. + Returns: + Whether `component` has a custom __str__ method. + """ + if hasattr(component, '__str__'): + class_attrs = inspectutils.GetClassAttrsDict(type(component)) or {} + str_attr = class_attrs.get('__str__') + if str_attr and str_attr.defining_class is not object: + return True + return False diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 00000000..bbe1e848 --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,11 @@ +site_name: Python Fire +theme: readthedocs +markdown_extensions: [fenced_code] +nav: + - Overview: index.md + - Installation: installation.md + - Benefits: benefits.md + - The Python Fire Guide: guide.md + - Using a CLI: using-cli.md + - Troubleshooting: troubleshooting.md + - Reference: api.md diff --git a/pylintrc b/pylintrc new file mode 100644 index 00000000..8896bb5b --- /dev/null +++ b/pylintrc @@ -0,0 +1,210 @@ +[MASTER] + +# Specify a configuration file. +#rcfile= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Add to the black list. It should be a base name, not a +# path. You may set this option multiple times. +ignore= + +# Pickle collected data for later comparisons. +persistent=yes + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + + +[MESSAGES CONTROL] + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time. +enable=indexing-exception,old-raise-syntax + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifier separated by comma (,) or put this option +# multiple time. +disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,file-ignored,wrong-import-order,useless-object-inheritance,no-else-return,super-with-arguments,raise-missing-from,consider-using-f-string,unspecified-encoding,unnecessary-lambda-assignment,wrong-import-position,ungrouped-imports,deprecated-module + + +[REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html +output-format=text + +# Tells whether to display a full report or only the messages +reports=yes + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (R0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching names used for dummy variables (i.e. not used). +dummy-variables-rgx=\*{0,2}(_$|unused_|dummy_) + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + + +[BASIC] + +# Regular expression which should only match correct module names +module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ + +# Regular expression which should only match correct module level names +const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ + +# Regular expression which should only match correct class names +class-rgx=[A-Z_][a-zA-Z0-9]+$ + +# Regular expression which should only match correct function names +function-rgx=^(?:(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ + +# Regular expression which should only match correct method names +method-rgx=^(?:(?P__[a-z0-9_]+__|next)|(?P_{0,2}(?:test|assert)?[A-Z][a-zA-Z0-9]*)|(?:_{0,2}[a-z][a-z0-9_]*))$ + +# Regular expression which should only match correct instance attribute names +attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ + +# Regular expression which should only match correct argument names +argument-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression which should only match correct variable names +variable-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression which should only match correct list comprehension / +# generator expression variable names +inlinevar-rgx=^[a-z][a-z0-9_]*$ + +# Good variable names which should always be accepted, separated by a comma +good-names=i,j,k,ex,main,Run,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names=map,filter,apply,input,reduce,foo,bar,baz,toto,tutu,tata + +# Regular expression which should only match functions or classes name which do +# not require a docstring +no-docstring-rgx=(__.*__|main|test.*|.*Test) + +# Minimum length for a docstring +docstring-min-length=10 + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME,XXX,TODO + + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=80 + +# Maximum number of lines in a module +max-module-lines=99999 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + + +[TYPECHECK] + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# List of classes names for which member attributes should not be checked +# (useful for classes with attributes dynamically set). +ignored-classes= + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E0201 when accessed. +generated-members= + + +[DESIGN] + +# Maximum number of arguments for function / method +max-args=5 + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore +ignored-argument-names=_.* + +# Maximum number of locals for function / method body +max-locals=15 + +# Maximum number of return / yield for function / method body +max-returns=6 + +# Maximum number of branch for function / method body +max-branches=12 + +# Maximum number of statements in function / method body +max-statements=50 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub,string,TERMIOS,Bastion,rexec + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__,__new__,setUp diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..9c558e35 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +. diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..ed53d83b --- /dev/null +++ b/setup.cfg @@ -0,0 +1,10 @@ +[aliases] +test = pytest + +[tool:pytest] +addopts = --ignore=fire/test_components_py3.py + --ignore=fire/parser_fuzz_test.py + +[pytype] +inputs = . +output = .pytype diff --git a/setup.py b/setup.py index 83dd63a9..8d4a381b 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 Google Inc. +# Copyright (C) 2018 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ from setuptools import setup -long_description = """ +LONG_DESCRIPTION = """ Python Fire is a library for automatically generating command line interfaces (CLIs) with a single line of code. @@ -25,32 +25,30 @@ it fires off your command. """.strip() -short_description = """ -A library for automatically generating commane line interfaces.""".strip() +SHORT_DESCRIPTION = """ +A library for automatically generating command line interfaces.""".strip() -dependencies = [ - 'ipython', - 'six', +DEPENDENCIES = [ + 'termcolor', ] -test_dependencies = [ +TEST_DEPENDENCIES = [ 'hypothesis', - 'mock', - 'python-Levenshtein', + 'levenshtein', ] +VERSION = '0.7.0' +URL = 'https://github.com/google/python-fire' + setup( name='fire', - version='0.1.0', - - description=short_description, - long_description=long_description, - - url='https://github.com/google/python-fire', + version=VERSION, + description=SHORT_DESCRIPTION, + long_description=LONG_DESCRIPTION, + url=URL, author='David Bieber', author_email='dbieber@google.com', - license='Apache Software License', classifiers=[ @@ -62,9 +60,14 @@ 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', 'Operating System :: OS Independent', 'Operating System :: POSIX', @@ -74,8 +77,9 @@ keywords='command line interface cli python fire interactive bash tool', - packages=['fire'], + requires_python='>=3.7', + packages=['fire', 'fire.console'], - install_requires=dependencies, - tests_require=test_dependencies, + install_requires=DEPENDENCIES, + tests_require=TEST_DEPENDENCIES, )