From b5b4ba3608b5af89aa1898033154bdb15d073c1b Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 13 Nov 2023 18:27:04 +0100 Subject: [PATCH 1/2] Rewrites len(..) == 0 into not .. --- _unittests/ut_light_api/test_light_api.py | 2 +- _unittests/ut_validation/test_f8.py | 2 +- .../ut_xrun_doc/test_documentation_examples.py | 4 ++-- onnx_array_api/ext_test_case.py | 4 ++-- onnx_array_api/light_api/emitter.py | 2 +- onnx_array_api/light_api/translate.py | 2 +- onnx_array_api/npx/npx_graph_builder.py | 2 +- onnx_array_api/npx/npx_helper.py | 4 ++-- onnx_array_api/npx/npx_jit_eager.py | 12 ++++++------ onnx_array_api/npx/npx_numpy_tensors.py | 6 +++--- onnx_array_api/npx/npx_var.py | 8 ++++---- onnx_array_api/plotting/_helper.py | 5 +---- onnx_array_api/plotting/dot_plot.py | 6 +++--- onnx_array_api/plotting/text_plot.py | 14 +++++++------- onnx_array_api/profiling.py | 4 ++-- .../reference/ops/op_constant_of_shape.py | 2 +- onnx_array_api/validation/tools.py | 2 +- setup.py | 4 ++-- 18 files changed, 41 insertions(+), 44 deletions(-) diff --git a/_unittests/ut_light_api/test_light_api.py b/_unittests/ut_light_api/test_light_api.py index 773819a..98dd64d 100644 --- a/_unittests/ut_light_api/test_light_api.py +++ b/_unittests/ut_light_api/test_light_api.py @@ -138,7 +138,7 @@ def list_ops_missing(self, n_inputs): methods.append("") new_missing.append(m) text = "\n".join(methods) - if len(new_missing) > 0: + if new_missing: raise AssertionError( f"n_inputs={n_inputs}: missing method for operators " f"{new_missing}\n{text}" diff --git a/_unittests/ut_validation/test_f8.py b/_unittests/ut_validation/test_f8.py index b44683f..85f27aa 100644 --- a/_unittests/ut_validation/test_f8.py +++ b/_unittests/ut_validation/test_f8.py @@ -344,7 +344,7 @@ def test_search_float32_into_fe5m2(self): add = value else: add = v - value - if len(w) > 0: + if w: raise AssertionError( f"A warning was thrown for v={v}, " f"value={value}, w={w[0]}." diff --git a/_unittests/ut_xrun_doc/test_documentation_examples.py b/_unittests/ut_xrun_doc/test_documentation_examples.py index e3f9206..170e82b 100644 --- a/_unittests/ut_xrun_doc/test_documentation_examples.py +++ b/_unittests/ut_xrun_doc/test_documentation_examples.py @@ -26,7 +26,7 @@ def import_source(module_file_path, module_name): class TestDocumentationExamples(ExtTestCase): def run_test(self, fold: str, name: str, verbose=0) -> int: ppath = os.environ.get("PYTHONPATH", "") - if len(ppath) == 0: + if not ppath: os.environ["PYTHONPATH"] = ROOT elif ROOT not in ppath: sep = ";" if is_windows() else ":" @@ -42,7 +42,7 @@ def run_test(self, fold: str, name: str, verbose=0) -> int: res = p.communicate() out, err = res st = err.decode("ascii", errors="ignore") - if len(st) > 0 and "Traceback" in st: + if st and "Traceback" in st: if '"dot" not found in path.' in st: # dot not installed, this part # is tested in onnx framework diff --git a/onnx_array_api/ext_test_case.py b/onnx_array_api/ext_test_case.py index c8aec35..6e412b3 100644 --- a/onnx_array_api/ext_test_case.py +++ b/onnx_array_api/ext_test_case.py @@ -228,7 +228,7 @@ def assertRaise(self, fct: Callable, exc_type: Exception): def assertEmpty(self, value: Any): if value is None: return - if len(value) == 0: + if value: return raise AssertionError(f"value is not empty: {value!r}.") @@ -240,7 +240,7 @@ def assertNotEmpty(self, value: Any): if value is None: raise AssertionError(f"value is empty: {value!r}.") if isinstance(value, (list, dict, tuple, set)): - if len(value) == 0: + if value: raise AssertionError(f"value is empty: {value!r}.") def assertStartsWith(self, prefix: str, full: str): diff --git a/onnx_array_api/light_api/emitter.py b/onnx_array_api/light_api/emitter.py index 4457c55..c52acfc 100644 --- a/onnx_array_api/light_api/emitter.py +++ b/onnx_array_api/light_api/emitter.py @@ -85,7 +85,7 @@ def render_attribute_value(self, value: Any) -> Tuple[List[str], str]: if isinstance(v, str): return [], f"{v!r}" if isinstance(v, np.ndarray): - if len(v.shape) == 0: + if not v.shape: return [], str(v) if len(v.shape) == 1: if value[0].type in ( diff --git a/onnx_array_api/light_api/translate.py b/onnx_array_api/light_api/translate.py index 7932693..a61ce24 100644 --- a/onnx_array_api/light_api/translate.py +++ b/onnx_array_api/light_api/translate.py @@ -51,7 +51,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: else: raise ValueError(f"Unexpected type {type(self.proto_)} for proto.") - if len(sparse_initializers) != 0: + if sparse_initializers: raise NotImplementedError("Sparse initializer not supported yet.") rows.extend( diff --git a/onnx_array_api/npx/npx_graph_builder.py b/onnx_array_api/npx/npx_graph_builder.py index 3dd842c..4496d79 100644 --- a/onnx_array_api/npx/npx_graph_builder.py +++ b/onnx_array_api/npx/npx_graph_builder.py @@ -919,7 +919,7 @@ def to_onnx( [(var, i, None) for i in range(var.n_var_outputs)] ) - if len(possible_types) > 0: + if possible_types: # converts possibles types into a dictionary map_types = {} for var, i, dt in possible_types: diff --git a/onnx_array_api/npx/npx_helper.py b/onnx_array_api/npx/npx_helper.py index f86aadc..34d9af3 100644 --- a/onnx_array_api/npx/npx_helper.py +++ b/onnx_array_api/npx/npx_helper.py @@ -47,7 +47,7 @@ def _process_attributes(attributes): nodes = [] modified = False for node in graph.node: - if len(set(node.input) & set_rep) == 0: + if not (set(node.input) & set_rep): modified = True new_inputs = [replacements.get(i, i) for i in node.input] atts = _process_attributes(node.attribute) or node.attribute @@ -66,7 +66,7 @@ def _process_attributes(attributes): if not modified: return None - if len(set(i.name for i in graph.input) & set_rep) == 0: + if not (set(i.name for i in graph.input) & set_rep): return make_graph(nodes, graph.name, graph.input, graph.output) new_inputs = [] diff --git a/onnx_array_api/npx/npx_jit_eager.py b/onnx_array_api/npx/npx_jit_eager.py index ef24af7..20becbd 100644 --- a/onnx_array_api/npx/npx_jit_eager.py +++ b/onnx_array_api/npx/npx_jit_eager.py @@ -253,7 +253,7 @@ def to_jit(self, *values, **kwargs): """ self.info("+", "to_jit", args=values, kwargs=kwargs) annotations = self.f.__annotations__ - if len(annotations) > 0: + if annotations: input_to_kwargs = {} kwargs_to_input = {} names = list(annotations.keys()) @@ -352,10 +352,10 @@ def to_jit(self, *values, **kwargs): if iname in constraints ] names = [i.name for i in inputs] - if len(new_kwargs) > 0: + if new_kwargs: # An attribute is not named in the numpy API # but is the ONNX definition. - if len(kwargs) == 0: + if not kwargs: kwargs = new_kwargs else: kwargs = kwargs.copy() @@ -375,13 +375,13 @@ def to_jit(self, *values, **kwargs): target_opsets=self.target_opsets, ir_version=self.ir_version, ) - if len(values) > 0 and len(values[0].shape) == 0: + if values and not values[0].shape: inps = onx.graph.input[0] shape = [] for d in inps.type.tensor_type.shape.dim: v = d.dim_value if d.dim_value > 0 else d.dim_param shape.append(v) - if len(shape) != 0: + if shape: raise RuntimeError( f"Shape mismatch, values[0]={values[0]} " f"and inputs={onx.graph.input}." @@ -441,7 +441,7 @@ def move_input_to_kwargs( f"self.input_to_kwargs_ is not initialized for function {self.f} " f"from module {self.f.__module__!r}." ) - if len(self.input_to_kwargs_) == 0: + if not self.input_to_kwargs_: return values, kwargs new_values = [] new_kwargs = kwargs.copy() diff --git a/onnx_array_api/npx/npx_numpy_tensors.py b/onnx_array_api/npx/npx_numpy_tensors.py index a106b95..68a4da7 100644 --- a/onnx_array_api/npx/npx_numpy_tensors.py +++ b/onnx_array_api/npx/npx_numpy_tensors.py @@ -220,7 +220,7 @@ def __bool__(self): ) if self.shape == (0,): return False - if len(self.shape) != 0: + if self.shape: warnings.warn( f"Conversion to bool only works for scalar, not for {self!r}, " f"bool(...)={bool(self._tensor)}." @@ -233,7 +233,7 @@ def __bool__(self): def __int__(self): "Implicit conversion to int." - if len(self.shape) != 0: + if self.shape: raise ValueError( f"Conversion to bool only works for scalar, not for {self!r}." ) @@ -255,7 +255,7 @@ def __int__(self): def __float__(self): "Implicit conversion to float." - if len(self.shape) != 0: + if self.shape: raise ValueError( f"Conversion to bool only works for scalar, not for {self!r}." ) diff --git a/onnx_array_api/npx/npx_var.py b/onnx_array_api/npx/npx_var.py index 27f5455..ca8af0d 100644 --- a/onnx_array_api/npx/npx_var.py +++ b/onnx_array_api/npx/npx_var.py @@ -174,7 +174,7 @@ def to_onnx( f"Mismatch number of outputs, expecting {len(outputs)}, " f"got ({len(onx.output)})." ) - if len(g.functions_) > 0: + if g.functions_: return [g.functions_, onx] return onx @@ -1020,7 +1020,7 @@ def __getitem__(self, index: Any) -> "Var": if not isinstance(index, tuple): index = (index,) - elif len(index) == 0: + elif not index: # The array contains a scalar and it needs to be returned. return var(self, op="Identity") @@ -1091,7 +1091,7 @@ def __getitem__(self, index: Any) -> "Var": starts = np.array(starts, dtype=np.int64) axes = np.array(axes, dtype=np.int64) - if len(needs_shape) > 0: + if needs_shape: shape = self.shape conc = [] for e in ends: @@ -1116,7 +1116,7 @@ def __getitem__(self, index: Any) -> "Var": sliced_args.append(steps) sliced_args_cst = [v if isinstance(v, Var) else cst(v) for v in sliced_args] sliced = var(self.self_var, *sliced_args_cst, op="Slice") - if len(axis_squeeze) > 0: + if axis_squeeze: return var( sliced, cst(np.array(axis_squeeze, dtype=np.int64)), diff --git a/onnx_array_api/plotting/_helper.py b/onnx_array_api/plotting/_helper.py index 21179ab..3131177 100644 --- a/onnx_array_api/plotting/_helper.py +++ b/onnx_array_api/plotting/_helper.py @@ -120,10 +120,7 @@ def get_tensor_shape(obj): for d in obj.tensor_type.shape.dim: v = d.dim_value if d.dim_value > 0 else d.dim_param shape.append(v) - if len(shape) == 0: - shape = None - else: - shape = list(None if s == 0 else s for s in shape) + shape = None if not shape else list(None if s == 0 else s for s in shape) return shape diff --git a/onnx_array_api/plotting/dot_plot.py b/onnx_array_api/plotting/dot_plot.py index fd23f79..cff93f5 100644 --- a/onnx_array_api/plotting/dot_plot.py +++ b/onnx_array_api/plotting/dot_plot.py @@ -242,7 +242,7 @@ def dot_label(text): for node in nodes: exp.append("") for out in node.output: - if len(out) > 0 and out not in inter_vars: + if out and out not in inter_vars: inter_vars[out] = out sh = shapes.get(out, "") if sh: @@ -318,7 +318,7 @@ def dot_label(text): f"{dot_name(subprefix)}{dot_name(inp2.name)};" ) for out1, out2 in zip(body.output, node.output): - if len(out2) == 0: + if not out2: # Empty output, it cannot be used. continue exp.append( @@ -346,7 +346,7 @@ def dot_label(text): f"{dot_name(prefix)}{dot_name(node.name)};" ) for out in node.output: - if len(out) == 0: + if not out: # Empty output, it cannot be used. continue exp.append( diff --git a/onnx_array_api/plotting/text_plot.py b/onnx_array_api/plotting/text_plot.py index 8736d97..36f9feb 100644 --- a/onnx_array_api/plotting/text_plot.py +++ b/onnx_array_api/plotting/text_plot.py @@ -75,7 +75,7 @@ def append_target(self, tid, weight): def process_node(self): "node to string" if self.nodes_modes == "LEAF": - if len(self.targets) == 0: + if not self.targets: text = f"{self.true_false}f" elif len(self.targets) == 1: t = self.targets[0] @@ -264,7 +264,7 @@ def _append_succ_pred_s( unknown.add(i) for i in n.output: known[i] = n - if len(unknown) > 0: + if unknown: # These inputs are coming from the graph below. for name in unknown: successors[name].append(parent_node_name) @@ -402,7 +402,7 @@ def _find_sequence(node_name, known, done): % (k, ",".join(sequences[k]), list(sequences)) ) - if len(sequences) == 0: + if not sequences: raise RuntimeError( # pragma: no cover "Unexpected empty sequence (len(possibles)=%d, " "len(done)=%d, len(nodes)=%d). This is usually due to " @@ -417,7 +417,7 @@ def _find_sequence(node_name, known, done): # if the sequence of successors is longer best = k elif len(v) == len(sequences[best]): - if len(new_nodes) > 0: + if new_nodes: # then choose the next successor sharing input with # previous output so = set(new_nodes[-1].output) @@ -808,7 +808,7 @@ def str_node(indent, node): val = ".%d" % att.type atts.append(f"{att.name}={val}") inputs = list(node.input) - if len(atts) > 0: + if atts: inputs.extend(atts) if node.domain in ("", "ai.onnx.ml"): domain = "" @@ -917,7 +917,7 @@ def str_node(indent, node): indent = previous_indent else: inds = [indents.get(i, 0) for i in node.input if i not in init_names] - if len(inds) == 0: + if not inds: indent = 0 else: mi = min(inds) @@ -929,7 +929,7 @@ def str_node(indent, node): ) add_break = True if not add_break and previous_out is not None: - if len(set(node.input) & previous_out) == 0: + if not (set(node.input) & previous_out): if verbose: print(f"[onnx_simple_text_plot] break3 {node.op_type}") add_break = True diff --git a/onnx_array_api/profiling.py b/onnx_array_api/profiling.py index 51d5ad7..52c464a 100644 --- a/onnx_array_api/profiling.py +++ b/onnx_array_api/profiling.py @@ -71,7 +71,7 @@ def get_root(self): def _get_root(node, stor=None): if stor is not None: stor.append(node) - if len(node.called_by) == 0: + if not node.called_by: return node if len(node.called_by) == 1: return _get_root(node.called_by[0], stor=stor) @@ -383,7 +383,7 @@ def walk(node, roots_keys, indent=0): continue child[key] = walk(n, roots_key, indent + 1) - if len(child) > 0: + if child: mx = max(_[0] for _ in child) dg = int(math.log(mx) / math.log(10) + 1.5) form = f"%-{dg}d-%s" diff --git a/onnx_array_api/reference/ops/op_constant_of_shape.py b/onnx_array_api/reference/ops/op_constant_of_shape.py index 33308af..00c6989 100644 --- a/onnx_array_api/reference/ops/op_constant_of_shape.py +++ b/onnx_array_api/reference/ops/op_constant_of_shape.py @@ -7,7 +7,7 @@ class ConstantOfShape(OpRun): def _process(value): cst = value[0] if isinstance(value, np.ndarray) and value.size > 0 else value if isinstance(value, np.ndarray): - if len(value.shape) == 0: + if not value.shape: cst = value elif value.size > 0: cst = value.ravel()[0] diff --git a/onnx_array_api/validation/tools.py b/onnx_array_api/validation/tools.py index f4628db..6cd1da3 100644 --- a/onnx_array_api/validation/tools.py +++ b/onnx_array_api/validation/tools.py @@ -49,7 +49,7 @@ def randomize_proto( doc_string=onx.doc_string, opset_imports=list(onx.opset_import), ) - if len(onx.metadata_props) > 0: + if onx.metadata_props: values = {p.key: p.value for p in onx.metadata_props} set_model_props(onnx_model, values) return onnx_model diff --git a/setup.py b/setup.py index 928f93f..bc4e87e 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ requirements = f.read().strip(" \n\r\t").split("\n") except FileNotFoundError: requirements = [] -if len(requirements) == 0 or requirements == [""]: +if not requirements or requirements == [""]: requirements = ["numpy", "scipy", "onnx"] try: @@ -34,7 +34,7 @@ for _ in [_.strip("\r\n ") for _ in f.readlines()] if _.startswith("__version__") ] - if len(line) > 0: + if line: version_str = line[0].split("=")[1].strip('" ') From d78bb39358480ba6f37b2d8074e918b42dccad50 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 13 Nov 2023 18:40:39 +0100 Subject: [PATCH 2/2] fix bug --- onnx_array_api/ext_test_case.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/onnx_array_api/ext_test_case.py b/onnx_array_api/ext_test_case.py index 6e412b3..1068bda 100644 --- a/onnx_array_api/ext_test_case.py +++ b/onnx_array_api/ext_test_case.py @@ -226,9 +226,7 @@ def assertRaise(self, fct: Callable, exc_type: Exception): raise AssertionError("No exception was raised.") def assertEmpty(self, value: Any): - if value is None: - return - if value: + if not value: return raise AssertionError(f"value is not empty: {value!r}.")