8000 Load PyInfo from rules_python · IBMZ-Linux-OSS-Python/tensorflow@ae59c9f · GitHub
[go: up one dir, main page]

Skip to content

Commit ae59c9f

Browse files
rickeylevtensorflower-gardener
authored andcommitted
Load PyInfo from rules_python
This is to facilitate upgrading of rules_python and other parts of Bazel. More specifically, this makes the code compatible with both PyInfo from rules_python and the Bazel builtin PyInfo. It does this by making rules accept and produce both providers. This allows the code to be compatible with a mixture of Bazel builtin `py_*` rules and rules_python `py_*` rules, which make upgrading rules_python and Bazel easier in the future. While I'd like to avoid this logic, I saw some errors about the wrong PyInfo being accepted/produced. I wasn't able to track down the particular edge, nor figure out if it was one we could even fix if found. Making it compatible for both cases avoids the problem entirely. Additional logic is present to handle two cases that will occur as part of upgrading to Bazel 8 and higher: * The builtin PyInfo can be an alias to rules_python PyInfo. This occurs when Bazel's autoloading is enabled. Because accepting/producing the same provider is an error, additional logic is needed to only use one. * The builtin PyInfo can be None or a stub provider. This occurs occurs when Bazel's autoloading is disabled and builtin providers are also disabled. Because None or a stub provider would be an error, additional logic is needed to use only one. PiperOrigin-RevId: 766268147
1 parent 180f95f commit ae59c9f

File tree

3 files changed

+103
-44
lines changed

3 files changed

+103
-44
lines changed

tensorflow/python/tools/api/generator2/generate_api.bzl

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,34 @@
11
"""Rules to generate the TensorFlow public API from annotated files."""
22

3-
# Placeholder: load PyInfo
43
load("@bazel_skylib//lib:paths.bzl", "paths")
4+
load("@rules_python//python:py_info.bzl", RulesPythonPyInfo = "PyInfo")
5+
load("@rules_python//python/api:api.bzl", "py_common")
56
load("//tensorflow/python/tools/api/generator:api_init_files.bzl", "TENSORFLOW_API_INIT_FILES")
67
load(":apis.bzl", _APIS = "APIS")
78
load(":patterns.bzl", "any_match")
89

10+
def _get_builtin_py_info():
11+
# May be None in Bazel 8+
12+
if PyInfo == None:
13+
return None
14+
15+
# Bazel 8's autoloading may make them the same
16+
if PyInfo == RulesPythonPyInfo:
17+
return None
18+
19+
# Within Google, it is aliased to a stub provider
20+
if "unimplemented" in str(PyInfo):
21+
return None
22+
return PyInfo
23+
24+
_BuiltinPyInfo = _get_builtin_py_info()
25+
_py_info_providers = [[RulesPythonPyInfo]] + (
26+
[[_BuiltinPyInfo]] if _BuiltinPyInfo else []
27+
)
28+
_py_info_provides = [RulesPythonPyInfo] + (
29+
[_BuiltinPyInfo] if _BuiltinPyInfo else []
30+
)
31+
932
APIS = _APIS.keys()
1033

1134
_MODULE_PREFIX = ""
@@ -28,30 +51,15 @@ def _py_files(f):
2851
return f.path
2952
return None
3053

31-
def _merge_py_info(
32-
deps,
33-
direct_sources = None,
34-
direct_imports = None,
35-
has_py2_only_sources = False,
36-
has_py3_only_sources = False,
37-
uses_shared_libraries = False):
38-
transitive_sources = []
39-
transitive_imports = []
40-
for dep in deps:
41-
if PyInfo in dep:
42-
transitive_sources.append(dep[PyInfo].transitive_sources)
43-
transitive_imports.append(dep[PyInfo].imports)
44-
has_py2_only_sources = has_py2_only_sources or dep[PyInfo].has_py2_only_sources
45-
has_py3_only_sources = has_py3_only_sources or dep[PyInfo].has_py3_only_sources
46-
uses_shared_libraries = uses_shared_libraries or dep[PyInfo].uses_shared_libraries
47-
48-
return PyInfo(
49-
transitive_sources = depset(direct = direct_sources, transitive = transitive_sources),
50-
imports = depset(direct = direct_imports, transitive = transitive_imports),
51-
has_py2_only_sources = has_py2_only_sources,
52-
has_py3_only_sources = has_py3_only_sources,
53-
uses_shared_libraries = uses_shared_libraries,
54-
)
54+
def _merge_py_info(ctx, deps):
55+
py_api = py_common.get(ctx)
56+
builder = py_api.PyInfoBuilder()
57+
builder.merge_targets(deps)
58+
infos = [builder.build()]
59+
builtin_info = builder.build_builtin_py_info()
60+
if builtin_info:
61+
infos.append(builtin_info)
62+
return infos
5563

5664
def _merge_api_info(
5765
deps,
@@ -129,8 +137,7 @@ api_extractor = aspect(
129137
def _extract_api_impl(ctx):
130138
return [
131139
_merge_api_info(ctx.attr.deps),
132-
_merge_py_info(ctx.attr.deps),
133-
]
140+
] + _merge_py_info(ctx, ctx.attr.deps)
134141

135142
extract_api = rule(
136143
doc = "Extract Python API for all targets in transitive dependencies.",
@@ -140,16 +147,16 @@ extract_api = rule(
140147
doc = "Targets to extract API from.",
141148
allow_empty = False,
142149
aspects = [api_extractor],
143-
providers = [PyInfo],
150+
providers = _py_info_providers,
144151
mandatory = True,
145152
),
146153
"api": attr.string(
147154
doc = "API to extract from dependencies.",
148155
mandatory = True,
149156
values = APIS,
150157
),
151-
},
152-
provides = [ApiInfo, PyInfo],
158+
} | py_common.API_ATTRS,
159+
provides = [ApiInfo] + _py_info_provides,
153160
)
154161

155162
def _get_module_by_path(dir_path, output_dir):
@@ -239,7 +246,11 @@ generate_api = rule(
239246
"deps": attr.label_list(
240247
doc = "extract_api targets to generate API from.",
241248
allow_empty = True,
242-
providers = [ApiInfo, PyInfo],
249+
providers = [
250+
[ApiInfo, RulesPythonPyInfo],
251+
] + (
252+
[[ApiInfo, _BuiltinPyInfo]] if _BuiltinPyInfo else []
253+
),
243254
mandatory = True,
244255
),
245256
"root_init_template": attr.label(

third_party/xla/third_party/py/python_wheel.bzl

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,29 @@
11
""" Repository and build rules for Python wheels packaging utilities. """
22

3+
load("@rules_python//python:py_info.bzl", RulesPythonPyInfo = "PyInfo")
4+
load("@rules_python//python/api:api.bzl", "py_common")
5+
6+
def _get_builtin_py_info():
7+
# May be None in Bazel 8+
8+
if PyInfo == None:
9+
return None
10+
11+
# Bazel 8's autoloading may make them the same
12+
if PyInfo == RulesPythonPyInfo:
13+
return None
14+
15+
# Within Google, it is aliased to a stub provider
16+
if "unimplemented" in str(PyInfo):
17+
return None
18+
return PyInfo
19+
20+
_BuiltinPyInfo = _get_builtin_py_info()
21+
_py_info_providers = [
22+
[RulesPythonPyInfo],
23+
] + (
24+
[[_BuiltinPyInfo]] if _BuiltinPyInfo else []
25+
)
26+
327
def _get_host_environ(repository_ctx, name, default_value = None):
428
"""Returns the value of an environment variable on the host platform.
529
@@ -133,20 +157,19 @@ Examples:
133157
""" # buildifier: disable=no-effect
134158

135159
def _transitive_py_deps_impl(ctx):
136-
outputs = depset(
137-
[],
138-
transitive = [dep[PyInfo].transitive_sources for dep in ctx.attr.deps],
139-
)
140-
160+
py_api = py_common.get(ctx)
161+
info = py_api.PyInfoBuilder()
162+
info.merge_targets(ctx.attr.deps)
163+
outputs = info.transitive_sources.build()
141164
return DefaultInfo(files = outputs)
142165

143166
_transitive_py_deps = rule(
144167
attrs = {
145168
"deps": attr.label_list(
146169
allow_files = True,
147-
providers = [PyInfo],
170+
providers = _py_info_providers,
148171
),
149-
},
172+
} | py_common.API_ATTRS,
150173
implementation = _transitive_py_deps_impl,
151174
)
152175

@@ -156,7 +179,7 @@ def transitive_py_deps(name, deps = []):
156179

157180
"""Collects python files that a target depends on.
158181
159-
It traverses dependencies of provided targets, collect their direct and
182+
It traverses dependencies of provided targets, collect their direct and
160183
transitive python deps and then return a list of paths to files.
161184
""" # buildifier: disable=no-effect
162185

third_party/xla/third_party/py/rules_pywrap/pywrap.impl.bzl

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,23 @@
11
load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain", "use_cpp_toolchain")
2+
load("@rules_python//python:py_info.bzl", RulesPythonPyInfo = "PyInfo")
23
load("@rules_python//python:py_library.bzl", "py_library")
34

5+
def _get_builtin_py_info():
6+
# May be None in Bazel 8+
7+
if PyInfo == None:
8+
return None
9+
10+
# Bazel 8's autoloading may make them the same
11+
if PyInfo == RulesPythonPyInfo:
12+
return None
13+
14+
# Within Google, it is aliased to a stub provider
15+
if "unimplemented" in str(PyInfo):
16+
return None
17+
return PyInfo
18+
19+
_BuiltinPyInfo = _get_builtin_py_info()
20+
421
PywrapInfo = provider(
522
fields = {
623
"cc_info": "Wrapped CcInfo",
@@ -827,8 +844,8 @@ def _pywrap_info_wrapper_impl(ctx):
827844
ctx.attr.deps[0][DefaultInfo].default_runfiles,
828845
)
829846

830-
return [
831-
PyInfo(transitive_sources = depset()),
847+
providers = [
848+
RulesPythonPyInfo(transitive_sources = depset()),
832849
PywrapInfo(
833850
cc_info = ctx.attr.deps[0][CcInfo],
834851
default_runfiles = default_runfiles,
@@ -839,6 +856,9 @@ def _pywrap_info_wrapper_impl(ctx):
839856
starlark_only = ctx.attr.starlark_only,
840857
),
841858
]
859+
if _BuiltinPyInfo:
860+
providers.append(_BuiltinPyInfo(transitive_sources = depset()))
861+
return providers
842862

843863
_pywrap_info_wrapper = rule(
844864
attrs = {
@@ -859,8 +879,8 @@ def _cc_only_pywrap_info_wrapper_impl(ctx):
859879
ctx.attr.deps[0][DefaultInfo].default_runfiles,
860880
)
861881

862-
return [
863-
PyInfo(transitive_sources = depset()),
882+
providers = [
883+
RulesPythonPyInfo(transitive_sources = depset()),
864884
PywrapInfo(
865885
cc_info = wrapped_dep[CcInfo],
866886
owner = ctx.label,
@@ -871,6 +891,9 @@ def _cc_only_pywrap_info_wrapper_impl(ctx):
871891
starlark_only = False,
872892
),
873893
]
894+
if _BuiltinPyInfo:
895+
providers.append(_BuiltinPyInfo(transitive_sources = depset()))
896+
return providers
874897

875898
_cc_only_pywrap_info_wrapper = rule(
876899
attrs = {
@@ -975,7 +998,9 @@ collected_pywrap_infos = rule(
975998
attrs = {
976999
"deps": attr.label_list(
9771000
aspects = [_pywrap_info_collector_aspect],
978-
providers = [PyInfo],
1001+
providers = [[RulesPythonPyInfo]] + (
1002+
[[_BuiltinPyInfo]] if _BuiltinPyInfo else []
1003+
),
9791004
),
9801005
"pywrap_count": attr.int(mandatory = True, default = 1),
9811006
"starlark_only_pywrap_count": attr.int(mandatory = True, default = 0),

0 commit comments

Comments
 (0)
0