8000 [dynamo] support custom __getattr__ on torch.nn.Modules by davidberard98 · Pull Request #94658 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[dynamo] support custom __getattr__ on torch.nn.Modules #94658

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed

Conversation

davidberard98
Copy link
Contributor
@davidberard98 davidberard98 commented Feb 11, 2023

Stack from ghstack:

Summary: torch.nn.Module implementations previously did not support custom implementations of __getattr__; if a torch.nn.Module subclass implemented __getattr__ and we tried to access an attribute that was expected to be present in __getattr__, dynamo would not check __getattr__ and would error out with an AttributeError. This PR copies the functionality from UserDefinedObjectVariable into torch.nn.Module so that it also supports __getattr__

Example of a module which previously would fail:

class MyMod(torch.nn.Module):
		def __init__(self):
				super().__init__()
				self.custom_dict = {"queue": [torch.rand((2, 2)) for _ in range(3)]}
				self.other_attr = torch.rand((2, 2))

		def __getattr__(self, name):
				custom_dict = self.custom_dict
				if name in custom_dict:
						return custom_dict[name]
				return super().__getattr__(name)

		def forward(self, x):
				return x @ self.other_attr + self.queue[-1]

cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire

**Summary**: torch.nn.Module implementations previously did not support custom implementations of `__getattr__`; if a torch.nn.Module subclass implemented `__getattr__` and we tried to access an attribute that was expected to be present in `__getattr__`, dynamo would not check `__getattr__` and would error out with an AttributeError. This PR copies the functionality from UserDefinedObjectVariable into torch.nn.Module so that it also supports `__getattr__`

Example of a module which previously would fail:

```python
class MyMod(torch.nn.Module):
		def __init__(self):
				super().__init__()
				self.custom_dict = {"queue": [torch.rand((2, 2)) for _ in range(3)]}
				self.other_attr = torch.rand((2, 2))

		def __getattr__(self, name):
				custom_dict = self.custom_dict
				if name in custom_dict:
						return custom_dict[name]
				return super().__getattr__(name)

		def forward(self, x):
				return x @ self.other_attr + self.queue[-1]
```

[ghstack-poisoned]
@pytorch-bot
Copy link
pytorch-bot bot commented Feb 11, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/94658

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 Failures

As of commit 1c78962:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

davidberard98 added a commit that referenced this pull request Feb 11, 2023
**Summary**: torch.nn.Module implementations previously did not support custom implementations of `__getattr__`; if a torch.nn.Module subclass implemented `__getattr__` and we tried to access an attribute that was expected to be present in `__getattr__`, dynamo would not check `__getattr__` and would error out with an AttributeError. This PR copies the functionality from UserDefinedObjectVariable into torch.nn.Module so that it also supports `__getattr__`

Example of a module which previously would fail:

```python
class MyMod(torch.nn.Module):
		def __init__(self):
				super().__init__()
				self.custom_dict = {"queue": [torch.rand((2, 2)) for _ in range(3)]}
				self.other_attr = torch.rand((2, 2))

		def __getattr__(self, name):
				custom_dict = self.custom_dict
				if name in custom_dict:
						return custom_dict[name]
				return super().__getattr__(name)

		def forward(self, x):
				return x @ self.other_attr + self.queue[-1]
```

ghstack-source-id: da0a0d8
Pull Request resolved: #94658
**Summary**: torch.nn.Module implementations previously did not support custom implementations of `__getattr__`; if a torch.nn.Module subclass implemented `__getattr__` and we tried to access an attribute that was expected to be present in `__getattr__`, dynamo would not check `__getattr__` and would error out with an AttributeError. This PR copies the functionality from UserDefinedObjectVariable into torch.nn.Module so that it also supports `__getattr__`

Example of a module which previously would fail:

```python
class MyMod(torch.nn.Module):
		def __init__(self):
				super().__init__()
				self.custom_dict = {"queue": [torch.rand((2, 2)) for _ in range(3)]}
				self.other_attr = torch.rand((2, 2))

		def __getattr__(self, name):
				custom_dict = self.custom_dict
				if name in custom_dict:
						return custom_dict[name]
				return super().__getattr__(name)

		def forward(self, x):
				return x @ self.other_attr + self.queue[-1]
```

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 @EikanWang jgong5 @Guobing-Chen @XiaobingSuper zhuhaozhe blzheng @Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
@davidberard98 davidberard98 marked this pull request as ready for review February 13, 2023 21:19
**Summary**: torch.nn.Module implementations previously did not support custom implementations of `__getattr__`; if a torch.nn.Module subclass implemented `__getattr__` and we tried to access an attribute that was expected to be present in `__getattr__`, dynamo would not check `__getattr__` and would error out with an AttributeError. This PR copies the functionality from UserDefinedObjectVariable into torch.nn.Module so that it also supports `__getattr__`

Example of a module which previously would fail:

```python
class MyMod(torch.nn.Module):
		def __init__(self):
				super().__init__()
				self.custom_dict = {"queue": [torch.rand((2, 2)) for _ in range(3)]}
				self.other_attr = torch.rand((2, 2))

		def __getattr__(self, name):
				custom_dict = self.custom_dict
				if name in custom_dict:
						return custom_dict[name]
				return super().__getattr__(name)

		def forward(self, x):
				return x @ self.other_attr + self.queue[-1]
```

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 @EikanWang jgong5 @Guobing-Chen @XiaobingSuper zhuhaozhe blzheng @Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
davidberard98 added a commit that referenced this pull request Feb 13, 2023
**Summary**: torch.nn.Module implementations previously did not support custom implementations of `__getattr__`; if a torch.nn.Module subclass implemented `__getattr__` and we tried to access an attribute that was expected to be present in `__getattr__`, dynamo would not check `__getattr__` and would error out with an AttributeError. This PR copies the functionality from UserDefinedObjectVariable into torch.nn.Module so that it also supports `__getattr__`

Example of a module which previously would fail:

```python
class MyMod(torch.nn.Module):
		def __init__(self):
				super().__init__()
				self.custom_dict = {"queue": [torch.rand((2, 2)) for _ in range(3)]}
				self.other_attr = torch.rand((2, 2))

		def __getattr__(self, name):
				custom_dict = self.custom_dict
				if name in custom_dict:
						return custom_dict[name]
				return super().__getattr__(name)

		def forward(self, x):
				return x @ self.other_attr + self.queue[-1]
```

ghstack-source-id: 56ba12b
Pull Request resolved: #94658
@davidberard98 davidberard98 requested a review from mlazos February 13, 2023 21:27
Copy link
Contributor
@yanboliang yanboliang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, let's wait for @jansel's double check.

def _custom_getattr_fallback(self, base, tx, name, options):
"""Check for a __getattr__ and handle it specially if it is implemented"""
if object_has_getattribute(base):
unimplemented("torch.nn.Module with a custom __getattribute__ defined")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest to add a test for this case

**Summary**: torch.nn.Module implementations previously did not support custom implementations of `__getattr__`; if a torch.nn.Module subclass implemented `__getattr__` and we tried to access an attribute that was expected to be present in `__getattr__`, dynamo would not check `__getattr__` and would error out with an AttributeError. This PR copies the functionality from UserDefinedObjectVariable into torch.nn.Module so that it also supports `__getattr__`

Example of a module which previously would fail:

```python
class MyMod(torch.nn.Module):
		def __init__(self):
				super().__init__()
				self.custom_dict = {"queue": [torch.rand((2, 2)) for _ in range(3)]}
				self.other_attr = torch.rand((2, 2))

		def __getattr__(self, name):
				custom_dict = self.custom_dict
				if name in custom_dict:
						return custom_dict[name]
				return super().__getattr__(name)

		def forward(self, x):
				return x @ self.other_attr + self.queue[-1]
```

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 @EikanWang jgong5 @Guobing-Chen @XiaobingSuper zhuhaozhe blzheng @Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
**Summary**: torch.nn.Module implementations previously did not support custom implementations of `__getattr__`; if a torch.nn.Module subclass implemented `__getattr__` and we tried to access an attribute that was expected to be present in `__getattr__`, dynamo would not check `__getattr__` and would error out with an AttributeError. This PR copies the functionality from UserDefinedObjectVariable into torch.nn.Module so that it also supports `__getattr__`

Example of a module which previously would fail:

```python
class MyMod(torch.nn.Module):
		def __init__(self):
				super().__init__()
				self.custom_dict = {"queue": [torch.rand((2, 2)) for _ in range(3)]}
				self.other_attr = torch.rand((2, 2))

		def __getattr__(self, name):
				custom_dict = self.custom_dict
				if name in custom_dict:
						return custom_dict[name]
				return super().__getattr__(name)

		def forward(self, x):
				return x @ self.other_attr + self.queue[-1]
```

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 @EikanWang jgong5 @Guobing-Chen @XiaobingSuper zhuhaozhe blzheng @Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
**Summary**: torch.nn.Module implementations previously did not support custom implementations of `__getattr__`; if a torch.nn.Module subclass implemented `__getattr__` and we tried to access an attribute that was expected to be present in `__getattr__`, dynamo would not check `__getattr__` and would error out with an AttributeError. This PR copies the functionality from UserDefinedObjectVariable into torch.nn.Module so that it also supports `__getattr__`

Example of a module which previously would fail:

```python
class MyMod(torch.nn.Module):
		def __init__(self):
				super().__init__()
				self.custom_dict = {"queue": [torch.rand((2, 2)) for _ in range(3)]}
				self.other_attr = torch.rand((2, 2))

		def __getattr__(self, name):
				custom_dict = self.custom_dict
				if name in custom_dict:
						return custom_dict[name]
				return super().__getattr__(name)

		def forward(self, x):
				return x @ self.other_attr + self.queue[-1]
```

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 @EikanWang jgong5 @Guobing-Chen @XiaobingSuper zhuhaozhe blzheng @Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
davidberard98 added a commit that referenced this pull request Feb 14, 2023
**Summary**: torch.nn.Module implementations previously did not support custom implementations of `__getattr__`; if a torch.nn.Module subclass implemented `__getattr__` and we tried to access an attribute that was expected to be present in `__getattr__`, dynamo would not check `__getattr__` and would error out with an AttributeError. This PR copies the functionality from UserDefinedObjectVariable into torch.nn.Module so that it also supports `__getattr__`

Example of a module which previously would fail:

```python
class MyMod(torch.nn.Module):
		def __init__(self):
				super().__init__()
				self.custom_dict = {"queue": [torch.rand((2, 2)) for _ in range(3)]}
				self.other_attr = torch.rand((2, 2))

		def __getattr__(self, name):
				custom_dict = self.custom_dict
				if name in custom_dict:
						return custom_dict[name]
				return super().__getattr__(name)

		def forward(self, x):
				return x @ self.other_attr + self.queue[-1]
```

ghstack-source-id: e06c7c9
Pull Request resolved: #94658
@davidberard98
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 15, 2023
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: linux-binary-manywheel / manywheel-py3_8-cuda11_7-with-pypi-cudnn-test / test

Details for Dev Infra team Raised by workflow job

@davidberard98
Copy link
Contributor Author

@pytorchbot merge -f "multipy error reported in #94751 and shows many similar failures in other PRs; rocm failure is disabled in #93045"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pruthvistony added a commit to ROCm/pytorch that referenced this pull request May 2, 2023
@facebook-github-bot facebook-github-bot deleted the gh/davidberard98/168/head branch June 8, 2023 16:01
jhavukainen pushed a commit to kulinseth/pytorch that referenced this pull request Mar 15, 8000 2024
**Summary**: torch.nn.Module implementations previously did not support custom implementations of `__getattr__`; if a torch.nn.Module subclass implemented `__getattr__` and we tried to access an attribute that was expected to be present in `__getattr__`, dynamo would not check `__getattr__` and would error out with an AttributeError. This PR copies the functionality from UserDefinedObjectVariable into torch.nn.Module so that it also supports `__getattr__`

Example of a module which previously would fail:

```python
class MyMod(torch.nn.Module):
		def __init__(self):
				super().__init__()
				self.custom_dict = {"queue": [torch.rand((2, 2)) for _ in range(3)]}
				self.other_attr = torch.rand((2, 2))

		def __getattr__(self, name):
				custom_dict = self.custom_dict
				if name in custom_dict:
						return custom_dict[name]
				return super().__getattr__(name)

		def forward(self, x):
				return x @ self.other_attr + self.queue[-1]
```

Pull Request resolved: pytorch#94658
Approved by: https://github.com/yanboliang, https://github.com/jansel
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants
0