Simpler extension API #311
Conversation
Codecov Report
```diff
@@            Coverage Diff             @@
##              dev     #311      +/-   ##
==========================================
- Coverage   66.23%   65.29%    -0.94%
==========================================
  Files           5       29      +24
  Lines        2393     6658    +4265
  Branches      638     1573     +935
==========================================
+ Hits         1585     4347    +2762
- Misses        512     1468     +956
- Partials      296      843     +547
```
Great to see this, it looks fine to me.
src/DiffSharp.Core/Tensor.fs (outdated)
```fsharp
| OpBinaryCT of Tensor*Tensor*(Tensor*Tensor*Tensor*Tensor->Tensor)

[<AbstractClass>]
type UnaryOp() =
```
Add comments and example
```fsharp
type UnaryOp() =
    abstract fRaw: a:RawTensor->RawTensor
    abstract ad_df_da: a:Tensor*ad:Tensor*f:Tensor->Tensor
    abstract fd_df_da: a:Tensor*f:Tensor*fd:Tensor->Tensor
```
Add comments to each of these
```fsharp
type Tensor with
    static member Op(ext: UnaryOp) =
```
Add /// comments
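For context, here is a minimal sketch of how a user-defined op might look against the members shown in these hunks, assuming the relevant DiffSharp namespaces are open. The `SinOp` name, the derivative formulas, and the `SinT`/`cos` calls are illustrative assumptions, not part of this PR, and the usage line is only a guess at how `Tensor.Op` would be applied:

```fsharp
/// Illustrative only: an elementwise sine op, assuming the members above mean
/// fRaw = raw forward computation, ad_df_da = forward-mode (tangent) rule,
/// fd_df_da = reverse-mode (adjoint) rule.
type SinOp() =
    inherit UnaryOp()
    // Raw forward pass; assumes RawTensor exposes a SinT() method.
    override _.fRaw(a) = a.SinT()
    // Forward mode: primal a, incoming tangent ad, forward value f -> tangent of the result.
    override _.ad_df_da(a, ad, f) = ad * a.cos()
    // Reverse mode: primal a, forward value f, incoming adjoint fd -> adjoint to propagate to a.
    override _.fd_df_da(a, f, fd) = fd * a.cos()

// Presumed usage through the new entry point (the body of Tensor.Op is not
// shown in this diff, so this call shape is a guess):
// let sin' (t: Tensor) = Tensor.Op(SinOp()) t
```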
Note I will replay #252 over this once it's in. Really glad to have this addressed; it's so crucial in the long term.
@dsyme, one thing I need advice on is the following. Currently the list of ops in src/DiffSharp.Core/Tensor.fs (line 2807 at 4c8df05) serves an important purpose: it keeps track of operation history in the computation graph of reverse mode. This is immensely useful and needed for debugging and for things like rendering dependency graphs. For example, see the code in src/DiffSharp.Core/Tensor.fs (line 338 at 4c8df05).
It's analogous to the PyTorch case below, for example:

```python
import torch
x = torch.tensor([1., 2, 3], requires_grad=True)
y = x.sum()
print(y)
```
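This prints something like `tensor(6., grad_fn=<SumBackward0>)`; the `grad_fn` attribute is exactly this kind of per-operation history record.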
When we introduce the extension API, we lose this and all operations will rely on the generic extension case. I'm thinking about adding a string to all cases under the op type so that the op name is still recorded.
Yes, I understand. TBH I think it's best just to add a string to these cases and implement a "ToString()" on the type which picks up either that string (for those cases) or the extension op's own name (for the generic case).

If we move various things out of Tensor.fs into extensions (e.g. move maxpool out, like avgpool) then we can add the strings at that point.
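As a rough sketch of that suggestion (the type, case names and fields below are illustrative, not the actual DiffSharp definitions, and `Tensor`/`UnaryOp` are assumed to be in scope):

```fsharp
// Illustrative sketch only: built-in op cases carry a display name, and the
// generic extension case falls back to the extension object's type name, so
// reverse-mode graph dumps keep readable op names either way.
type TensorOpSketch =
    | OpUnary of Tensor * (Tensor -> Tensor) * string
    | OpBinary of Tensor * Tensor * (Tensor * Tensor -> Tensor) * string
    | OpExtension of Tensor * UnaryOp
    override op.ToString() =
        match op with
        | OpUnary (_, _, name) -> name
        | OpBinary (_, _, _, name) -> name
        | OpExtension (_, ext) -> ext.GetType().Name
```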
This is ready to go. I'm quite happy with how simple and nice this turned out! I think it will be really useful for ongoing work. I'm merging this now, noting that other PRs need to be updated to incorporate this. I added some basic documentation here, and I will handle the complete documentation, including a tutorial on how to create new ops, in the next PR.
This is work in progress to deliver a much simplified extension API. It is based on the prototype in #89, which served as a great pathfinder for the design. I'm really happy with the simplicity of the elementwise case. A few things still need to be added and tested.