8000 Simpler extension API by gbaydin · Pull Request #311 · DiffSharp/DiffSharp · GitHub
[go: up one dir, main page]

Skip to content

Simpler extension API #311

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Apr 2, 2021
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
pass precomputed values
  • Loading branch information
gbaydin committed Mar 30, 2021
commit b9bbe7e09c60404b95842d044bfb6d89c0b602f8
38 changes: 19 additions & 19 deletions src/DiffSharp.Core/Tensor.fs
Original file line number Diff line number Diff line change
Expand Up @@ -2804,10 +2804,10 @@ type Tensor =
| AcosT(a) -> push (check(-td / Tensor.Sqrt(1. - a.primal*a.primal), a) :: tt)
| AtanT(a) -> push (check(td / (1. + a.primal*a.primal), a) :: tt)
| NewT -> push tt
| OpUnaryT(a, rev) -> push (check(rev(td, a.primal), a) :: tt)
| OpBinaryTT(a, b, rev) -> let ad, bd = rev(td, a.primal, b.primal) in push (check(ad, a) :: check(bd, b) :: tt)
| OpBinaryTC(a, b, rev) -> let ad = rev(td, a.primal, b) in push (check(ad, a) :: tt)
| OpBinaryCT(a, b, rev) -> let bd = rev(td, a, b.primal) in push (check(bd, b) :: tt)
| OpUnaryT(a, rev) -> push (check(rev(t.primal, td, a.primal), a) :: tt)
8000 | OpBinaryTT(a, b, rev) -> let ad, bd = rev(t.primal, td, a.primal, b.primal) in push (check(ad, a) :: check(bd, b) :: tt)
| OpBinaryTC(a, b, rev) -> let ad = rev(t.primal, td, a.primal, b) in push (check(ad, a) :: tt)
| OpBinaryCT(a, b, rev) -> let bd = rev(t.primal, td, a, b.primal) in push (check(bd, b) :: tt)
else push tt
| _ -> push tt
push [(value, t)]
Expand Down Expand Up @@ -2916,41 +2916,41 @@ and TensorOp =
| AcosT of Tensor
| AtanT of Tensor
| NewT
| OpUnaryT of Tensor*(Tensor*Tensor->Tensor)
| OpBinaryTT of Tensor*Tensor*(Tensor*Tensor*Tensor->Tensor*Tensor)
| OpBinaryTC of Tensor*Tensor*(Tensor*Tensor*Tensor->Tensor)
| OpBinaryCT of Tensor*Tensor*(Tensor*Tensor*Tensor->Tensor)
| OpUnaryT of Tensor*(Tensor*Tensor*Tensor->Tensor)
| OpBinaryTT of Tensor*Tensor*(Tensor*Tensor*Tensor*Tensor->Tensor*Tensor)
| OpBinaryTC of Tensor*Tensor*(Tensor*Tensor*Tensor*Tensor->Tensor)
| OpBinaryCT of Tensor*Tensor*(Tensor*Tensor*Tensor*Tensor->Tensor)


[<AbstractClass>]
type UnaryOp() =
abstract fRaw: RawTensor->RawTensor
abstract df_da: Tensor->Tensor
abstract df_da: Tensor*Tensor->Tensor

[<AbstractClass>]
type BinaryOp() =
abstract fRaw: RawTensor*RawTensor->RawTensor
abstract df_da: Tensor*Tensor->Tensor
abstract df_db: Tensor*Tensor->Tensor
abstract df_da: Tensor*Tensor*Tensor->Tensor
abstract df_db: Tensor*Tensor*Tensor->Tensor


type Tensor with
static member Op(ext: UnaryOp) =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add /// comments

fun a ->
let fRaw = ext.fRaw
let fTensor = Tensor.Op ext
let dfTensorFwd(cp,ap,ad) = ad*ext.df_da(ap)
let dfTensorRev(a) = OpUnaryT(a, (fun (cd,ap) -> cd*ext.df_da(ap)))
let dfTensorFwd(cp,ap,ad) = ad*ext.df_da(ap,cp)
let dfTensorRev(a) = OpUnaryT(a, (fun (cp,cd,ap) -> cd*ext.df_da(ap,cp)))
Tensor.OpUnary(a, fRaw, fTensor, dfTensorFwd, dfTensorRev)

static member Op(ext: BinaryOp) =
fun (a, b) ->
let fRaw = ext.fRaw
let fTensor = Tensor.Op ext
let dfTensorFwdTT(cp,ap,ad,bp,bd) = ad*ext.df_da(ap,bp) + bd*ext.df_db(ap,bp)
let dfTensorFwdTC(cp,ap,ad) = ad*ext.df_da(ap,b)
let dfTensorFwdCT(cp,bp,bd) = bd*ext.df_db(a,bp)
let dfTensorRevTT(a,b) = OpBinaryTT(a, b, (fun (cd,ap,bp) -> (cd*ext.df_da(ap,bp)), (cd*ext.df_db(ap,bp))))
let dfTensorRevTC(a,b) = OpBinaryTC(a, b, (fun (cd,ap,b) -> (cd*ext.df_da(ap,b))))
let dfTensorRevCT(a,b) = OpBinaryCT(a, b, (fun (cd,a,bp) -> (cd*ext.df_db(a,bp))))
let dfTensorFwdTT(cp,ap,ad,bp,bd) = ad*ext.df_da(ap,bp,cp) + bd*ext.df_db(ap,bp,cp)
let dfTensorFwdTC(cp,ap,ad) = ad*ext.df_da(ap,b,cp)
let dfTensorFwdCT(cp,bp,bd) = bd*ext.df_db(a,bp,cp)
let dfTensorRevTT(a,b) = OpBinaryTT(a, b, (fun (cp,cd,ap,bp) -> (cd*ext.df_da(ap,bp,cp)), (cd*ext.df_db(ap,bp,cp))))
let dfTensorRevTC(a,b) = OpBinaryTC(a, b, (fun (cp,cd,ap,b) -> (cd*ext.df_da(ap,b,cp))))
let dfTensorRevCT(a,b) = OpBinaryCT(a, b, (fun (cp,cd,a,bp) -> (cd*ext.df_db(a,bp,cp))))
Tensor.OpBinary(a, b, fRaw, fTensor, dfTensorFwdTT, dfTensorFwdTC, dfTensorFwdCT, dfTensorRevTT, dfTensorRevTC, dfTensorRevCT)
0