8000 feature/extension-ops: user-defined extensions by dsyme · Pull Request #89 · DiffSharp/DiffSharp · GitHub
[go: up one dir, main page]

Skip to content

feature/extension-ops: user-defined extensions #89

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 44 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
7667054
prototype extension ops
dsyme Feb 26, 2020
270e7d2
conv1d extension example
dsyme Feb 27, 2020
f125503
conv1d extension example
dsyme Feb 27, 2020
e135663
adjust naming
dsyme Feb 28, 2020
c3c067e
integrate dev
dsyme Mar 2, 2020
aab40f5
merge dev
dsyme May 5, 2020
42465db
Merge branch 'dev' of https://github.com/DiffSharp/DiffSharp into fea…
dsyme May 23, 2020
a03bf73
simplify binary extensions
dsyme May 23, 2020
6592c8e
simplify extension api
dsyme May 24, 2020
b00aec4
Merge branch 'dev' of https://github.com/DiffSharp/DiffSharp into fea…
dsyme May 24, 2020
5ab0e27
cleanup
dsyme May 24, 2020
5c2648b
Merge commit 'f8c432b048584947a3cb672e55723dd2170c0671' of https://gi…
dsyme Jun 1, 2020
154c66c
Merge commit '211c5f6ecdbb122f33afdb33196317bf6fcd8407' of https://gi…
dsyme Jun 1, 2020
c27a9ed
merge dev
dsyme Jun 1, 2020
cfd05d0
Merge branch 'dev' of https://github.com/DiffSharp/DiffSharp into fea…
dsyme Sep 7, 2020
70d34a4
integrate dev
Sep 15, 2020
e5e4e41
merge dev
Oct 6, 2020
b12cb1f
merge dev
Oct 15, 2020
3614412
Merge branch 'dev' of https://github.com/DiffSharp/DiffSharp into fea…
Nov 11, 2020
559f8b0
Merge branch 'dev' of https://github.com/DiffSharp/DiffSharp into fea…
Nov 11, 2020
184d664
Merge branch 'dev' of https://github.com/DiffSharp/DiffSharp into fea…
Nov 13, 2020
184164e
Merge branch 'dev' of https://github.com/DiffSharp/DiffSharp into fea…
Nov 13, 2020
be1c23d
Merge branch 'dev' of https://github.com/DiffSharp/DiffSharp into fea…
Nov 14, 2020
03c91d0
Merge branch 'dev' of https://github.com/DiffSharp/DiffSharp into fea…
Nov 16, 2020
2ece5dd
Merge branch 'dev' of https://github.com/DiffSharp/DiffSharp into fea…
Nov 20, 2020
72276fa
Merge branch 'dev' of https://github.com/DiffSharp/DiffSharp into fea…
Nov 25, 2020
be40af9
update docs
Nov 25, 2020
6c308a2
Merge branch 'dev' of https://github.com/DiffSharp/DiffSharp into fea…
Nov 25, 2020
98cd8c2
integrate dev
Nov 26, 2020
757d104
Merge branch 'tsplit1' into feature/extension-ops
Nov 26, 2020
7c02d48
Merge branch 'tsplit1' into feature/extension-ops
Nov 26, 2020
b3f874e
rename file
Nov 26, 2020
4966630
Merge branch 'dev' of https://github.com/DiffSharp/DiffSharp into fea…
Nov 26, 2020
37e90ff
allow reverse mode to have access to computed f(a)
Nov 27, 2020
8608ce7
fix comment
Nov 27, 2020
a38b116
fix comment
Nov 27, 2020
5a8ae8c
tidy up
Nov 28, 2020
57ded3e
Merge branch 'dev' of https://github.com/DiffSharp/DiffSharp into fea…
Nov 30, 2020
a383c5f
Merge branch 'dev' of https://github.com/DiffSharp/DiffSharp into fea…
Nov 30, 2020
331c14c
Merge branch 'dev' of https://github.com/DiffSharp/DiffSharp into fea…
Dec 3, 2020
7014b1f
integrate dev
Mar 4, 2021
e587849
Merge branch 'dev' of https://github.com/DiffSharp/DiffSharp into fea…
Mar 11, 2021
64dc91c
Merge branch 'dev' of https://github.com/DiffSharp/DiffSharp into fea…
Mar 11, 2021
e66f417
Merge branch 'dev' of https://github.com/DiffSharp/DiffSharp into fea…
Mar 26, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
conv1d extension example
  • Loading branch information
dsyme committed Feb 27, 2020
commit 270e7d2220a7582fa6e13602b15c0ac9933d2cc9
35 changes: 24 additions & 11 deletions src/DiffSharp.Core/Tensor.fs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ type Tensor =
let cp = Tensor.Extend(ap, shape)
TensorR(cp, ref (a.Zero()), MakeTofT0(a), ref 0u, at)

member internal t.GetSlice(bounds:int[,]) =
member t.GetSlice(bounds:int[,]) =
// printfn "t.GetSlice bounds\n %A" bounds
if t.Dim = 0 then failwith "Cannot slice a scalar Tensor"
let fullBounds = Array2D.init t.Dim 3 (fun i j -> if j=0 then 0 elif j=1 then t.Shape.[i]-1 else 0)
Expand Down Expand Up @@ -3770,30 +3770,43 @@ type Tensor with
let bounds = array2D [[i0min; i0max; i0given]; [i1min; i1max; i1given]; [i2min; i2max; i2given]; [i3min; i3max; i3given]; [i4min; i4max; i4given]; [i5min; i5max; i5given]]
t.GetSlice(bounds)

/// Defines an extension implementing a unary function and its gradients
type UnaryExtension =
abstract Raw: ap: RawTensor -> RawTensor
abstract GradForward: fp: Tensor * a: Tensor * ad: Tensor -> Tensor

/// Compute the function f(a)
abstract Compute: a: RawTensor -> RawTensor

/// Compute the forward gradient of function.
abstract GradForward: fa: Tensor * a: Tensor * da: Tensor -> Tensor

/// Compute the reverse gradient (adjoint) of function.
abstract GradReverse: t: Tensor * a: Tensor -> Tensor

/// Defines an extension implementing a binary function and its gradients
type BinaryExtension =
abstract Raw: a: RawTensor * b: RawTensor -> RawTensor
abstract GradForwardTT: fp: Tensor * a: Tensor * ad: Tensor * b: Tensor * bd: Tensor -> Tensor
abstract GradForwardTC: fp: Tensor * a: Tensor * ad: Tensor * b: Tensor -> Tensor
abstract GradForwardCT: fp: Tensor * a: Tensor * b: Tensor * bd: Tensor -> Tensor
/// Compute the function on raw tensors
abstract Compute: a: RawTensor * b: RawTensor -> RawTensor

/// Compute the forward gradient of function.
abstract GradForwardTT: fab: Tensor * a: Tensor * da: Tensor * b: Tensor * db: Tensor -> Tensor
abstract GradForwardTC: fab: Tensor * a: Tensor * da: Tensor * b: Tensor -> Tensor
abstract GradForwardCT: fab: Tensor * a: Tensor * b: Tensor * db: Tensor -> Tensor

/// Compute the reverse gradient (adjoint) of function.
abstract GradReverseTT: t: Tensor * a: Tensor * b: Tensor -> Tensor * Tensor
abstract GradReverseTC: t: Tensor * a: Tensor * b: Tensor -> Tensor
abstract GradReverseCT: t: Tensor * a: Tensor * b: Tensor -> Tensor

type Tensor with
static member UnaryExtension(ext: UnaryExtension) =
static member Extension(ext: UnaryExtension) =
(fun a ->
Tensor.OpUnary(a, ext.Raw, Tensor.UnaryExtension ext, ext.GradForward,
Tensor.OpUnary(a, ext.Compute, Tensor.Extension ext, ext.GradForward,
(fun a -> OpExtensionT([a], (fun t -> [ext.GradReverse (t,a)])))
))

static member BinaryExtension(ext: BinaryExtension) =
static member Extension(ext: BinaryExtension) =
(fun (a, b) ->
Tensor.OpBinary(a, b, ext.Raw, Tensor.BinaryExtension ext,
Tensor.OpBinary(a, b, ext.Compute, Tensor.Extension ext,
ext.GradForwardTT,
(fun (cp,a,ad) -> ext.GradForwardTC(cp,a,ad,b)),
(fun (cp,b,bd) -> ext.GradForwardCT(cp,a,b,bd)),
Expand Down
Loading
0