Feature: Batch matmul #88
Conversation
Transpose is missing tests in TestTensor.fs.
I would like to implement the derivative tests to convince myself of the correctness. Let's add a TODO comment in the TestDerivative.fs file to remember the missing test for the time being.
Also it would be great to have some Python reference code.
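For example, a minimal PyTorch sketch of the kind of reference that could back the transpose derivative test (the shapes and values below are illustrative, not the ones used in the test suite):

```python
import torch

# Reverse-mode reference for transpose: seed the output with v,
# then the input gradient should be the transposed seed.
a = torch.randn(3, 4, requires_grad=True)
b = a.t()                      # forward: shape [4, 3]
v = torch.randn(4, 3)          # reverse-mode seed (cotangent)
b.backward(v)
print(torch.equal(a.grad, v.t()))  # True
```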
@gbaydin Can we get this in, do you think?
I've merged this with dev.
Codecov Report

```diff
@@            Coverage Diff             @@
##              dev      #88      +/-   ##
==========================================
+ Coverage   67.11%   72.47%   +5.36%
==========================================
  Files          18       18
  Lines        5269     5714     +445
  Branches     1296     1325      +29
==========================================
+ Hits         3536     4141     +605
+ Misses       1017      837     -180
- Partials      716      736      +20
```
@gbaydin I've merged this with dev, it should now be ready.
I think the behavior of PyTorch's transpose is as follows:

```python
a = torch.randn([])
b = torch.randn([10])
c = torch.randn([10, 20])
d = torch.randn([10, 20, 30])
e = torch.randn([10, 20, 30, 40])

at = torch.t(a)  # Gives shape []
bt = torch.t(b)  # Gives shape [10]
ct = torch.t(c)  # Gives shape [20, 10]
dt = torch.t(d)  # Fails because d.dim > 2
et = torch.t(e)  # Fails because e.dim > 2
```

DiffSharp behavior in this branch is as follows:

```fsharp
let a = dsharp.randn([])
let b = dsharp.randn([10])
let c = dsharp.randn([10; 20])
let d = dsharp.randn([10; 20; 30])
let e = dsharp.randn([10; 20; 30; 40])

let at = a.transpose() // Fails because a.dim < 2
let bt = b.transpose() // Fails because b.dim < 2
let ct = c.transpose() // Gives shape [20; 10]
let dt = d.transpose() // Gives shape [10; 30; 20]
let et = e.transpose() // Gives shape [10; 20; 40; 30]
```

I believe there is no need for a "batch-transpose" operation in general, and batch transposition can be achieved once we have the general transpose operation. I think this batch-transpose behavior was needed mainly for the reverse mode of batch matrix multiplication. We can either:
I would go with the second option because it improves the API. In both cases we can have the regular
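For reference, a quick PyTorch sketch of how batch transposition falls out of a general dimension-pair transpose (illustrative only; the eventual DiffSharp API may look different):

```python
import torch

d = torch.randn(10, 20, 30)
dt = d.transpose(-2, -1)   # swap the last two dims -> shape [10, 30, 20]
print(dt.shape)

# torch.t(d) would fail here because d.dim() > 2, but the general
# dimension-pair transpose covers the batched case directly.
```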
Yup agreed. There's also
This might be the simplest option for now, to get the batch matmul implementation and tests consolidated? Then deal with the second issue?
OK, agreed. So then let's go with renaming the current
These are the default settings and I agree they are too strict. I looked into relaxing them before but it was surprisingly difficult to find the information. Let me look a bit better and fix it. :)
If you like, give me admin rights? I can poke around.
I think you already have it. Codecov uses GitHub permissions: https://codecov.io/gh/DiffSharp/DiffSharp/ Please let me know if it doesn't work.
I think we need to set the project/patch targets to custom values using https://github.com/DiffSharp/DiffSharp/blob/dev/codecov.yml. See here: https://docs.codecov.io/docs/commit-status

Edit: I did make some changes in the threshold values. It should now allow coverage to fall by 10% in a PR before failing. I don't know if it will work as intended. I guess we will see.
Inspecting this further, I have the following concerns about the PR:
I think if we can get

I had other comments about the change of

Another thing that we need to think about is whether to introduce
Cool, makes sense.
Some relevant discussions here: pytorch/pytorch#18027 and here: tensorflow/tensorflow#5523
I have made and tested this adjustment based on the PyTorch docs for
I haven't done these parts.
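For reference, a sketch of torch.matmul's batching and broadcasting behaviour (assuming the docs mentioned above are those for torch.matmul; the shapes are illustrative):

```python
import torch

a = torch.randn(10, 3, 4)
b = torch.randn(10, 4, 5)
print(torch.matmul(a, b).shape)   # torch.Size([10, 3, 5])    batched matmul

c = torch.randn(4, 5)
print(torch.matmul(a, c).shape)   # torch.Size([10, 3, 5])    c broadcast over the batch dim

d = torch.randn(2, 1, 3, 4)
e = torch.randn(10, 4, 5)
print(torch.matmul(d, e).shape)   # torch.Size([2, 10, 3, 5])  batch dims broadcast together
```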
Thank you, I will have a look and merge soon.
@gbaydin ping :)
Finally merged this (with a lot of delay!)
Something went wrong with the overall diff for this in Tensor.fs, which is showing 6000+ line changes (nearly all whitespace). This merge resulted in a complete rewrite of the file, I'm not sure why: 3c94672

I'll try to undo and force-push a simpler diff.
Adjust MatMul to support batching and broadcasting, in both Tensor and the reference implementation
Builds on #85