diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml
index 070c7cbd7..9fd34fc49 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -28,9 +28,9 @@ jobs:
- name: Test CPU version
run: dotnet test --no-build --verbosity normal
- name: uninstall redist cpu for unit tests
- run: dotnet remove helpers/Tensorflow.UnitTest.RedistHolder package SciSharp.TensorFlow.Redist
+ run: dotnet remove tools/Tensorflow.UnitTest.RedistHolder package SciSharp.TensorFlow.Redist
- name: install redist gpu for unit tests
- run: dotnet add helpers/Tensorflow.UnitTest.RedistHolder package SciSharp.TensorFlow.Redist-Windows-GPU
+ run: dotnet add tools/Tensorflow.UnitTest.RedistHolder package SciSharp.TensorFlow.Redist-Windows-GPU
- name: Restore dependencies
run: dotnet restore
- name: Build GPU version
@@ -52,12 +52,12 @@ jobs:
run: dotnet restore
- name: Build CPU version
run: dotnet build --no-restore
- # - name: Test CPU version
- # run: dotnet test --no-build --verbosity normal
+ - name: Test CPU version
+ run: dotnet test --no-build --verbosity normal
- name: uninstall redist cpu for unit tests
- run: dotnet remove helpers/Tensorflow.UnitTest.RedistHolder package SciSharp.TensorFlow.Redist
+ run: dotnet remove tools/Tensorflow.UnitTest.RedistHolder package SciSharp.TensorFlow.Redist
- name: install redist gpu for unit tests
- run: dotnet add helpers/Tensorflow.UnitTest.RedistHolder package SciSharp.TensorFlow.Redist-Linux-GPU
+ run: dotnet add tools/Tensorflow.UnitTest.RedistHolder package SciSharp.TensorFlow.Redist-Linux-GPU
- name: Restore dependencies
run: dotnet restore
- name: Build GPU version
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 8f862e329..02601764c 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -53,7 +53,7 @@ jobs:
}
- name: Upload packages artifacts
- uses: actions/upload-artifact@v1.0.0
+ uses: actions/upload-artifact@v4.0.0
with:
name: "drop-ci-packages"
path: './packages'
diff --git a/README.md b/README.md
index c3ffdbaa5..75cad0aa7 100644
--- a/README.md
+++ b/README.md
@@ -2,16 +2,27 @@
**TensorFlow.NET** (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in C# which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework. TensorFlow.NET has built-in Keras high-level interface and is released as an independent package [TensorFlow.Keras](https://www.nuget.org/packages/TensorFlow.Keras/).
+[](https://discord.gg/qRVm82fKTS)
+[](http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=sN9VVMwbWjs5L0ATpizKKxOcZdEPMrp8&authKey=RLDw41bLTrEyEgZZi%2FzT4pYk%2BwmEFgFcrhs8ZbkiVY7a4JFckzJefaYNW6Lk4yPX&noverify=0&group_code=985366726)
[](https://gitter.im/sci-sharp/community)
[](https://github.com/SciSharp/TensorFlow.NET/actions/workflows/build_and_test.yml)
-[](https://www.nuget.org/packages/TensorFlow.NET)
-[](https://www.myget.org/feed/scisharp/package/nuget/Tensorflow.NET)
[](https://tensorflownet.readthedocs.io/en/latest/?badge=latest)
+[](https://www.nuget.org/packages/TensorFlow.NET)
+[](https://www.nuget.org/packages/TensorFlow.Keras)
+[](https://www.myget.org/feed/scisharp/package/nuget/Tensorflow.NET)
[](https://996.icu/#/en_US)
[](https://mybinder.org/v2/gh/javiercp/BinderTF.NET/master?urlpath=lab)
English | [中文](docs/README-CN.md)
+> [!IMPORTANT]
+> We're happy that our work on tensorflow.net has attracted many users. However, at this time, none of the main maintainers of this repo is available for new features and bug fix. We won't refuse PRs and will help to review them.
+>
+> If you would like to be a contributor or maintainer of tensorflow.net, we'd like to help you to start up.
+>
+> We feel sorry for that and we'll resume the maintaining for this project once one of us has bandwidth for it.
+>
+
*master branch and v0.100.x is corresponding to tensorflow v2.10, v0.6x branch is from tensorflow v2.6, v0.15-tensorflow1.15 is from tensorflow1.15. Please add `https://www.myget.org/F/scisharp/api/v3/index.json` to nuget source to use nightly release.*
@@ -58,9 +69,12 @@ PM> Install-Package TensorFlow.Keras
The second part is the computing support part. Only one of the following packages is needed, depending on your device and system.
```
-### CPU version for Windows, Linux and Mac
+### CPU version for Windows and Linux
PM> Install-Package SciSharp.TensorFlow.Redist
+### CPU version for MacOS
+PM> Install-Package SciSharp.TensorFlow.Redist-OSX
+
### GPU version for Windows (CUDA and cuDNN are required)
PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU
@@ -238,9 +252,9 @@ Buy our book to make open source project be sustainable [TensorFlow.NET实战](h
### Contact
-Follow us on [Twitter](https://twitter.com/ScisharpStack), [Facebook](https://www.facebook.com/scisharp.stack.9), [Medium](https://medium.com/scisharp), [LinkedIn](https://www.linkedin.com/company/scisharp-stack/).
+Join our chat on [Discord](https://discord.gg/qRVm82fKTS) or [Gitter](https://gitter.im/sci-sharp/community).
-Join our chat on [Gitter](https://gitter.im/sci-sharp/community).
+Follow us on [Twitter](https://twitter.com/ScisharpStack), [Facebook](https://www.facebook.com/scisharp.stack.9), [Medium](https://medium.com/scisharp), [LinkedIn](https://www.linkedin.com/company/scisharp-stack/).
TensorFlow.NET is a part of [SciSharp STACK](https://scisharp.github.io/SciSharp/)
diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln
index 0c7d6e3c2..e0c273568 100644
--- a/TensorFlow.NET.sln
+++ b/TensorFlow.NET.sln
@@ -5,12 +5,8 @@ VisualStudioVersion = 17.4.33213.308
MinimumVisualStudioVersion = 10.0.40219.1
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Binding", "src\TensorFlowNET.Core\Tensorflow.Binding.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}"
EndProject
-Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Benchmark", "src\TensorFlowNet.Benchmarks\Tensorflow.Benchmark.csproj", "{3A6EB896-604F-4E25-B677-B8103BCF3D2E}"
-EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Binding.UnitTest", "test\TensorFlowNET.UnitTest\Tensorflow.Binding.UnitTest.csproj", "{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}"
EndProject
-Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Console", "src\TensorFlowNET.Console\Tensorflow.Console.csproj", "{03F06299-3F4B-4449-A709-3A647657BC0C}"
-EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras", "src\TensorFlowNET.Keras\Tensorflow.Keras.csproj", "{49D71826-C03D-4FA7-9BAC-22C1327E65CF}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Text", "src\TensorFlowNET.Text\Tensorflow.Text.csproj", "{1AB8108D-4FFE-4A16-88E7-328EAF686370}"
@@ -31,9 +27,21 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{01A1787F-A9B
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "test", "test", "{1B0918B9-65AD-4F34-A287-AF4597B27DBD}"
EndProject
-Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "helpers", "helpers", "{E1A5D2B7-10AF-4876-85C0-7714EF274214}"
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tools", "tools", "{E1A5D2B7-10AF-4876-85C0-7714EF274214}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.CodeGen", "tools\Tensorflow.CodeGen\Tensorflow.CodeGen.csproj", "{3D92142F-EEDB-469B-B03C-4E38728BFE4C}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Redist.NativeLibrarySplitter", "tools\Tensorflow.Redist.NativeLibrarySplitter\Tensorflow.Redist.NativeLibrarySplitter.csproj", "{AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.UnitTest.RedistHolder", "tools\Tensorflow.UnitTest.RedistHolder\Tensorflow.UnitTest.RedistHolder.csproj", "{D24FCAA5-548C-4251-B226-A1B6535D0845}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Benchmark", "tools\TensorFlowNET.Benchmarks\Tensorflow.Benchmark.csproj", "{C23563DB-FE21-48E7-A411-87A109E4A899}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Console", "tools\TensorFlowNET.Console\Tensorflow.Console.csproj", "{1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlow.Kernel.UnitTest", "test\TensorFlow.Kernel.UnitTest\TensorFlow.Kernel.UnitTest.csproj", "{654A027D-1364-4729-880B-144DFE1FF5BB}"
EndProject
-Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.UnitTest.RedistHolder", "helpers\Tensorflow.UnitTest.RedistHolder\Tensorflow.UnitTest.RedistHolder.csproj", "{62D543A2-8846-45A3-829B-5754B094A8E2}"
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Tensorflow.UnitTest", "test\Tensorflow.UnitTest\Tensorflow.UnitTest.csproj", "{A73DF5A6-866E-4AED-9017-AA2EE86368C4}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
@@ -66,24 +74,6 @@ Global
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x64.Build.0 = Release|x64
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x86.ActiveCfg = Release|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x86.Build.0 = Release|Any CPU
- {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
- {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|Any CPU.Build.0 = Debug|Any CPU
- {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x64.ActiveCfg = Debug|x64
- {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x64.Build.0 = Debug|x64
- {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x86.ActiveCfg = Debug|Any CPU
- {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x86.Build.0 = Debug|Any CPU
- {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.GPU|Any CPU.ActiveCfg = Release|Any CPU
- {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.GPU|Any CPU.Build.0 = Release|Any CPU
- {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.GPU|x64.ActiveCfg = Release|x64
- {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.GPU|x64.Build.0 = Release|x64
- {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.GPU|x86.ActiveCfg = Release|Any CPU
- {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.GPU|x86.Build.0 = Release|Any CPU
- {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.ActiveCfg = Release|Any CPU
- {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.Build.0 = Release|Any CPU
- {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|x64.ActiveCfg = Release|x64
- {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|x64.Build.0 = Release|x64
- {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|x86.ActiveCfg = Release|Any CPU
- {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|x86.Build.0 = Release|Any CPU
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|Any CPU.Build.0 = Debug|Any CPU
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.ActiveCfg = Debug|x64
@@ -102,24 +92,6 @@ Global
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x64.Build.0 = Release|x64
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x86.ActiveCfg = Release|Any CPU
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x86.Build.0 = Release|Any CPU
- {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
- {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|Any CPU.Build.0 = Debug|Any CPU
- {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.ActiveCfg = Debug|x64
- {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.Build.0 = Debug|x64
- {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x86.ActiveCfg = Debug|Any CPU
- {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x86.Build.0 = Debug|Any CPU
- {03F06299-3F4B-4449-A709-3A647657BC0C}.GPU|Any CPU.ActiveCfg = Release|Any CPU
- {03F06299-3F4B-4449-A709-3A647657BC0C}.GPU|Any CPU.Build.0 = Release|Any CPU
- {03F06299-3F4B-4449-A709-3A647657BC0C}.GPU|x64.ActiveCfg = Release|x64
- {03F06299-3F4B-4449-A709-3A647657BC0C}.GPU|x64.Build.0 = Release|x64
- {03F06299-3F4B-4449-A709-3A647657BC0C}.GPU|x86.ActiveCfg = Release|Any CPU
- {03F06299-3F4B-4449-A709-3A647657BC0C}.GPU|x86.Build.0 = Release|Any CPU
- {03F06299-3F4B-4449-A709-3A647657BC0C}.Release|Any CPU.ActiveCfg = Release|Any CPU
- {03F06299-3F4B-4449-A709-3A647657BC0C}.Release|Any CPU.Build.0 = Release|Any CPU
- {03F06299-3F4B-4449-A709-3A647657BC0C}.Release|x64.ActiveCfg = Release|x64
- {03F06299-3F4B-4449-A709-3A647657BC0C}.Release|x64.Build.0 = Release|x64
- {03F06299-3F4B-4449-A709-3A647657BC0C}.Release|x86.ActiveCfg = Release|Any CPU
- {03F06299-3F4B-4449-A709-3A647657BC0C}.Release|x86.Build.0 = Release|Any CPU
{49D71826-C03D-4FA7-9BAC-22C1327E65CF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{49D71826-C03D-4FA7-9BAC-22C1327E65CF}.Debug|Any CPU.Build.0 = Debug|Any CPU
{49D71826-C03D-4FA7-9BAC-22C1327E65CF}.Debug|x64.ActiveCfg = Debug|x64
@@ -264,33 +236,139 @@ Global
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x64.Build.0 = Release|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x86.ActiveCfg = Release|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x86.Build.0 = Release|Any CPU
- {62D543A2-8846-45A3-829B-5754B094A8E2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
- {62D543A2-8846-45A3-829B-5754B094A8E2}.Debug|Any CPU.Build.0 = Debug|Any CPU
- {62D543A2-8846-45A3-829B-5754B094A8E2}.Debug|x64.ActiveCfg = Debug|Any CPU
- {62D543A2-8846-45A3-829B-5754B094A8E2}.Debug|x64.Build.0 = Debug|Any CPU
- {62D543A2-8846-45A3-829B-5754B094A8E2}.Debug|x86.ActiveCfg = Debug|Any CPU
- {62D543A2-8846-45A3-829B-5754B094A8E2}.Debug|x86.Build.0 = Debug|Any CPU
- {62D543A2-8846-45A3-829B-5754B094A8E2}.GPU|Any CPU.ActiveCfg = Debug|Any CPU
- {62D543A2-8846-45A3-829B-5754B094A8E2}.GPU|Any CPU.Build.0 = Debug|Any CPU
- {62D543A2-8846-45A3-829B-5754B094A8E2}.GPU|x64.ActiveCfg = Debug|Any CPU
- {62D543A2-8846-45A3-829B-5754B094A8E2}.GPU|x64.Build.0 = Debug|Any CPU
- {62D543A2-8846-45A3-829B-5754B094A8E2}.GPU|x86.ActiveCfg = Debug|Any CPU
- {62D543A2-8846-45A3-829B-5754B094A8E2}.GPU|x86.Build.0 = Debug|Any CPU
- {62D543A2-8846-45A3-829B-5754B094A8E2}.Release|Any CPU.ActiveCfg = Release|Any CPU
- {62D543A2-8846-45A3-829B-5754B094A8E2}.Release|Any CPU.Build.0 = Release|Any CPU
- {62D543A2-8846-45A3-829B-5754B094A8E2}.Release|x64.ActiveCfg = Release|Any CPU
- {62D543A2-8846-45A3-829B-5754B094A8E2}.Release|x64.Build.0 = Release|Any CPU
- {62D543A2-8846-45A3-829B-5754B094A8E2}.Release|x86.ActiveCfg = Release|Any CPU
- {62D543A2-8846-45A3-829B-5754B094A8E2}.Release|x86.Build.0 = Release|Any CPU
+ {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.Debug|x64.ActiveCfg = Debug|Any CPU
+ {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.Debug|x64.Build.0 = Debug|Any CPU
+ {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.Debug|x86.ActiveCfg = Debug|Any CPU
+ {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.Debug|x86.Build.0 = Debug|Any CPU
+ {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.GPU|Any CPU.ActiveCfg = Debug|Any CPU
+ {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.GPU|Any CPU.Build.0 = Debug|Any CPU
+ {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.GPU|x64.ActiveCfg = Debug|Any CPU
+ {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.GPU|x64.Build.0 = Debug|Any CPU
+ {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.GPU|x86.ActiveCfg = Debug|Any CPU
+ {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.GPU|x86.Build.0 = Debug|Any CPU
+ {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.Release|Any CPU.Build.0 = Release|Any CPU
+ {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.Release|x64.ActiveCfg = Release|Any CPU
+ {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.Release|x64.Build.0 = Release|Any CPU
+ {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.Release|x86.ActiveCfg = Release|Any CPU
+ {3D92142F-EEDB-469B-B03C-4E38728BFE4C}.Release|x86.Build.0 = Release|Any CPU
+ {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.Debug|x64.ActiveCfg = Debug|Any CPU
+ {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.Debug|x64.Build.0 = Debug|Any CPU
+ {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.Debug|x86.ActiveCfg = Debug|Any CPU
+ {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.Debug|x86.Build.0 = Debug|Any CPU
+ {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.GPU|Any CPU.ActiveCfg = Debug|Any CPU
+ {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.GPU|Any CPU.Build.0 = Debug|Any CPU
+ {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.GPU|x64.ActiveCfg = Debug|Any CPU
+ {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.GPU|x64.Build.0 = Debug|Any CPU
+ {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.GPU|x86.ActiveCfg = Debug|Any CPU
+ {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.GPU|x86.Build.0 = Debug|Any CPU
+ {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.Release|Any CPU.Build.0 = Release|Any CPU
+ {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.Release|x64.ActiveCfg = Release|Any CPU
+ {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.Release|x64.Build.0 = Release|Any CPU
+ {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.Release|x86.ActiveCfg = Release|Any CPU
+ {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C}.Release|x86.Build.0 = Release|Any CPU
+ {D24FCAA5-548C-4251-B226-A1B6535D0845}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {D24FCAA5-548C-4251-B226-A1B6535D0845}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {D24FCAA5-548C-4251-B226-A1B6535D0845}.Debug|x64.ActiveCfg = Debug|Any CPU
+ {D24FCAA5-548C-4251-B226-A1B6535D0845}.Debug|x64.Build.0 = Debug|Any CPU
+ {D24FCAA5-548C-4251-B226-A1B6535D0845}.Debug|x86.ActiveCfg = Debug|Any CPU
+ {D24FCAA5-548C-4251-B226-A1B6535D0845}.Debug|x86.Build.0 = Debug|Any CPU
+ {D24FCAA5-548C-4251-B226-A1B6535D0845}.GPU|Any CPU.ActiveCfg = Debug|Any CPU
+ {D24FCAA5-548C-4251-B226-A1B6535D0845}.GPU|Any CPU.Build.0 = Debug|Any CPU
+ {D24FCAA5-548C-4251-B226-A1B6535D0845}.GPU|x64.ActiveCfg = Debug|Any CPU
+ {D24FCAA5-548C-4251-B226-A1B6535D0845}.GPU|x64.Build.0 = Debug|Any CPU
+ {D24FCAA5-548C-4251-B226-A1B6535D0845}.GPU|x86.ActiveCfg = Debug|Any CPU
+ {D24FCAA5-548C-4251-B226-A1B6535D0845}.GPU|x86.Build.0 = Debug|Any CPU
+ {D24FCAA5-548C-4251-B226-A1B6535D0845}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {D24FCAA5-548C-4251-B226-A1B6535D0845}.Release|Any CPU.Build.0 = Release|Any CPU
+ {D24FCAA5-548C-4251-B226-A1B6535D0845}.Release|x64.ActiveCfg = Release|Any CPU
+ {D24FCAA5-548C-4251-B226-A1B6535D0845}.Release|x64.Build.0 = Release|Any CPU
+ {D24FCAA5-548C-4251-B226-A1B6535D0845}.Release|x86.ActiveCfg = Release|Any CPU
+ {D24FCAA5-548C-4251-B226-A1B6535D0845}.Release|x86.Build.0 = Release|Any CPU
+ {C23563DB-FE21-48E7-A411-87A109E4A899}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {C23563DB-FE21-48E7-A411-87A109E4A899}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {C23563DB-FE21-48E7-A411-87A109E4A899}.Debug|x64.ActiveCfg = Debug|x64
+ {C23563DB-FE21-48E7-A411-87A109E4A899}.Debug|x64.Build.0 = Debug|x64
+ {C23563DB-FE21-48E7-A411-87A109E4A899}.Debug|x86.ActiveCfg = Debug|Any CPU
+ {C23563DB-FE21-48E7-A411-87A109E4A899}.Debug|x86.Build.0 = Debug|Any CPU
+ {C23563DB-FE21-48E7-A411-87A109E4A899}.GPU|Any CPU.ActiveCfg = Debug|Any CPU
+ {C23563DB-FE21-48E7-A411-87A109E4A899}.GPU|Any CPU.Build.0 = Debug|Any CPU
+ {C23563DB-FE21-48E7-A411-87A109E4A899}.GPU|x64.ActiveCfg = Debug|x64
+ {C23563DB-FE21-48E7-A411-87A109E4A899}.GPU|x64.Build.0 = Debug|x64
+ {C23563DB-FE21-48E7-A411-87A109E4A899}.GPU|x86.ActiveCfg = Debug|Any CPU
+ {C23563DB-FE21-48E7-A411-87A109E4A899}.GPU|x86.Build.0 = Debug|Any CPU
+ {C23563DB-FE21-48E7-A411-87A109E4A899}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {C23563DB-FE21-48E7-A411-87A109E4A899}.Release|Any CPU.Build.0 = Release|Any CPU
+ {C23563DB-FE21-48E7-A411-87A109E4A899}.Release|x64.ActiveCfg = Release|x64
+ {C23563DB-FE21-48E7-A411-87A109E4A899}.Release|x64.Build.0 = Release|x64
+ {C23563DB-FE21-48E7-A411-87A109E4A899}.Release|x86.ActiveCfg = Release|Any CPU
+ {C23563DB-FE21-48E7-A411-87A109E4A899}.Release|x86.Build.0 = Release|Any CPU
+ {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Debug|x64.ActiveCfg = Debug|x64
+ {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Debug|x64.Build.0 = Debug|x64
+ {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Debug|x86.ActiveCfg = Debug|Any CPU
+ {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Debug|x86.Build.0 = Debug|Any CPU
+ {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.GPU|Any CPU.ActiveCfg = Debug|Any CPU
+ {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.GPU|Any CPU.Build.0 = Debug|Any CPU
+ {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.GPU|x64.ActiveCfg = Debug|x64
+ {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.GPU|x64.Build.0 = Debug|x64
+ {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.GPU|x86.ActiveCfg = Debug|Any CPU
+ {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.GPU|x86.Build.0 = Debug|Any CPU
+ {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Release|Any CPU.Build.0 = Release|Any CPU
+ {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Release|x64.ActiveCfg = Release|x64
+ {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Release|x64.Build.0 = Release|x64
+ {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Release|x86.ActiveCfg = Release|Any CPU
+ {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Release|x86.Build.0 = Release|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.Debug|x64.ActiveCfg = Debug|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.Debug|x64.Build.0 = Debug|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.Debug|x86.ActiveCfg = Debug|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.Debug|x86.Build.0 = Debug|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.GPU|Any CPU.ActiveCfg = Debug|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.GPU|Any CPU.Build.0 = Debug|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.GPU|x64.ActiveCfg = Debug|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.GPU|x64.Build.0 = Debug|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.GPU|x86.ActiveCfg = Debug|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.GPU|x86.Build.0 = Debug|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.Release|Any CPU.Build.0 = Release|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.Release|x64.ActiveCfg = Release|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.Release|x64.Build.0 = Release|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.Release|x86.ActiveCfg = Release|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.Release|x86.Build.0 = Release|Any CPU
+ {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.Debug|x64.ActiveCfg = Debug|Any CPU
+ {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.Debug|x64.Build.0 = Debug|Any CPU
+ {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.Debug|x86.ActiveCfg = Debug|Any CPU
+ {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.Debug|x86.Build.0 = Debug|Any CPU
+ {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.GPU|Any CPU.ActiveCfg = Debug|Any CPU
+ {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.GPU|Any CPU.Build.0 = Debug|Any CPU
+ {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.GPU|x64.ActiveCfg = Debug|Any CPU
+ {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.GPU|x64.Build.0 = Debug|Any CPU
+ {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.GPU|x86.ActiveCfg = Debug|Any CPU
+ {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.GPU|x86.Build.0 = Debug|Any CPU
+ {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.Release|Any CPU.Build.0 = Release|Any CPU
+ {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.Release|x64.ActiveCfg = Release|Any CPU
+ {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.Release|x64.Build.0 = Release|Any CPU
+ {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.Release|x86.ActiveCfg = Release|Any CPU
+ {A73DF5A6-866E-4AED-9017-AA2EE86368C4}.Release|x86.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
EndGlobalSection
GlobalSection(NestedProjects) = preSolution
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144} = {01A1787F-A9BE-4221-84E8-6360DD010AB6}
- {3A6EB896-604F-4E25-B677-B8103BCF3D2E} = {E1A5D2B7-10AF-4876-85C0-7714EF274214}
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32} = {1B0918B9-65AD-4F34-A287-AF4597B27DBD}
- {03F06299-3F4B-4449-A709-3A647657BC0C} = {E1A5D2B7-10AF-4876-85C0-7714EF274214}
{49D71826-C03D-4FA7-9BAC-22C1327E65CF} = {01A1787F-A9BE-4221-84E8-6360DD010AB6}
{1AB8108D-4FFE-4A16-88E7-328EAF686370} = {01A1787F-A9BE-4221-84E8-6360DD010AB6}
{F17AAECB-960A-4E18-A270-BAD776F0E55B} = {01A1787F-A9BE-4221-84E8-6360DD010AB6}
@@ -299,7 +377,13 @@ Global
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3} = {1B0918B9-65AD-4F34-A287-AF4597B27DBD}
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18} = {01A1787F-A9BE-4221-84E8-6360DD010AB6}
{7DEA8760-E401-4872-81F3-405F185A13A0} = {1B0918B9-65AD-4F34-A287-AF4597B27DBD}
- {62D543A2-8846-45A3-829B-5754B094A8E2} = {E1A5D2B7-10AF-4876-85C0-7714EF274214}
+ {3D92142F-EEDB-469B-B03C-4E38728BFE4C} = {E1A5D2B7-10AF-4876-85C0-7714EF274214}
+ {AB131FA7-B7C3-4ABF-ABDE-E059C72A613C} = {E1A5D2B7-10AF-4876-85C0-7714EF274214}
+ {D24FCAA5-548C-4251-B226-A1B6535D0845} = {E1A5D2B7-10AF-4876-85C0-7714EF274214}
+ {C23563DB-FE21-48E7-A411-87A109E4A899} = {E1A5D2B7-10AF-4876-85C0-7714EF274214}
+ {1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0} = {E1A5D2B7-10AF-4876-85C0-7714EF274214}
+ {654A027D-1364-4729-880B-144DFE1FF5BB} = {1B0918B9-65AD-4F34-A287-AF4597B27DBD}
+ {A73DF5A6-866E-4AED-9017-AA2EE86368C4} = {1B0918B9-65AD-4F34-A287-AF4597B27DBD}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {2DEAD3CC-486B-4918-A607-50B0DE7B114A}
diff --git a/data/img001.bmp b/data/img001.bmp
new file mode 100644
index 000000000..d149d76f1
Binary files /dev/null and b/data/img001.bmp differ
diff --git a/src/TensorFlowNET.Core/APIs/c_api.cs b/src/TensorFlowNET.Core/APIs/c_api.cs
index 10f678e0a..a91b86827 100644
--- a/src/TensorFlowNET.Core/APIs/c_api.cs
+++ b/src/TensorFlowNET.Core/APIs/c_api.cs
@@ -16,6 +16,7 @@ limitations under the License.
using System;
using System.Runtime.InteropServices;
+using static Tensorflow.CppShapeInferenceResult.Types;
namespace Tensorflow
{
@@ -50,6 +51,35 @@ public static string StringPiece(IntPtr handle)
return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle);
}
+ public unsafe static byte[] ByteStringPiece(Buffer? handle)
+ {
+ if (handle is null)
+ {
+ return new byte[0];
+ }
+ var data = handle.ToArray();
+ return data;
+ }
+
+ public unsafe static byte[] ByteStringPieceFromNativeString(IntPtr handle)
+ {
+ if (handle == IntPtr.Zero)
+ {
+ return new byte[0];
+ }
+
+ byte* str_data = (byte*)handle.ToPointer();
+ List bytes = new List();
+ byte current = 255;
+ while (current != ((byte)'\0'))
+ {
+ current = *(str_data++);
+ bytes.Add(current);
+ }
+ var data = bytes.ToArray();
+ return data;
+ }
+
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate void Deallocator(IntPtr data, IntPtr size, ref DeallocatorArgs args);
diff --git a/src/TensorFlowNET.Core/APIs/c_api.customize.cs b/src/TensorFlowNET.Core/APIs/c_api.customize.cs
index d2aab9ac0..bee4897ee 100644
--- a/src/TensorFlowNET.Core/APIs/c_api.customize.cs
+++ b/src/TensorFlowNET.Core/APIs/c_api.customize.cs
@@ -8,10 +8,10 @@ namespace Tensorflow
public partial class c_api
{
[DllImport(TensorFlowLibName)]
- public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status);
+ public static extern void TF_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
- public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
+ public static extern SafeBufferHandle TF_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
[DllImport(TensorFlowLibName)]
- public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status);
+ public static extern void TF_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status);
}
}
diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs
index a2c91983e..b529cd319 100644
--- a/src/TensorFlowNET.Core/APIs/tf.array.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.array.cs
@@ -44,7 +44,8 @@ public partial class tensorflow
///
///
public Tensor batch_to_space_nd(T input, int[] block_shape, int[,] crops, string name = null)
- => gen_array_ops.batch_to_space_nd(input, block_shape, crops, name: name);
+ => gen_array_ops.batch_to_space_nd(ops.convert_to_tensor(input), ops.convert_to_tensor(block_shape),
+ ops.convert_to_tensor(crops), name: name);
///
/// Apply boolean mask to tensor.
@@ -90,8 +91,7 @@ public Tensor concat(IEnumerable values, int axis, string name = "concat
return identity(values.First(), name: scope);
});
}
-
- return gen_array_ops.concat_v2(values.ToArray(), axis, name: name);
+ return array_ops.concat(values.ToArray(), axis, name: name);
}
///
@@ -115,7 +115,7 @@ public Tensor expand_dims(Tensor input, int axis = -1, string name = null)
///
///
public Tensor fill(Tensor dims, T value, string name = null)
- => gen_array_ops.fill(dims, value, name: name);
+ => gen_array_ops.fill(dims, ops.convert_to_tensor(value), name: name);
public Tensor fill(Shape dims, T value, string name = null)
=> array_ops.fill(dims, value, name: name);
@@ -138,7 +138,17 @@ public Tensor identity(Tensor input, string name = null)
///
///
public Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0)
- => array_ops.gather(@params, indices, name: name, axis: axis);
+ => array_ops.gather(@params, indices, name: name, axis: ops.convert_to_tensor(axis));
+
+ ///
+ /// Gather slices from `params` into a Tensor with shape specified by `indices`.
+ ///
+ ///
+ ///
+ ///
+ ///
+ public Tensor gather_nd(Tensor @params, Tensor indices, string name = null)
+ => gen_array_ops.gather_nd(@params, indices, name: name);
///
/// Return the elements, either from `x` or `y`, depending on the `condition`.
@@ -162,14 +172,17 @@ public Tensor transpose(T1 a, Axis perm = null, string name = "transpose", b
/// Reverses specific dimensions of a tensor.
///
///
- ///
+ /// The indices of the dimensions to reverse. Must be in the range [-rank(tensor), rank(tensor)).
///
///
- public Tensor reverse(Tensor tensor, int[] axis, string name = null)
- => gen_array_ops.reverse(tensor, axis, name: name);
-
- public Tensor reverse(Tensor tensor, Tensor axis, string name = null)
- => gen_array_ops.reverse(tensor, axis, name: name);
+ public Tensor reverse(Tensor tensor, Axis axis, string name = null)
+ {
+ if (axis.IsScalar)
+ {
+ axis = new Axis(axis.axis);
+ }
+ return array_ops.reverse(tensor, axis, name: name);
+ }
///
/// Returns the rank of a tensor.
@@ -189,7 +202,8 @@ public Tensor rank(Tensor input, string name = null)
/// A name for the operation (optional).
/// A `Tensor` the same type as `input`.
public Tensor slice(Tensor input, Tb[] begin, Ts[] size, string name = null)
- => array_ops.slice(input, begin, size, name: name);
+ => array_ops.slice(input, begin.Select(x => ops.convert_to_tensor(x)).ToArray(),
+ size.Select(x => ops.convert_to_tensor(x)).ToArray(), name: name);
public Tensor squeeze(Tensor input, int axis, string name = null, int squeeze_dims = -1)
=> array_ops.squeeze(input, new[] { axis }, name);
@@ -255,7 +269,7 @@ public Tensor pad(Tensor tensor, Tensor paddings, string mode = "CONSTANT", stri
/// A name for the operation (optional).
/// A `Tensor`. Has the same type as `input`.
public Tensor placeholder_with_default(T input, int[] shape, string name = null)
- => gen_array_ops.placeholder_with_default(input, shape, name: name);
+ => gen_array_ops.placeholder_with_default(ops.convert_to_tensor(input), shape, name: name);
///
/// Returns the shape of a tensor.
diff --git a/src/TensorFlowNET.Core/APIs/tf.control_flow.cs b/src/TensorFlowNET.Core/APIs/tf.control_flow.cs
index 239487e05..cd5a71e50 100644
--- a/src/TensorFlowNET.Core/APIs/tf.control_flow.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.control_flow.cs
@@ -46,10 +46,10 @@ public Tensor while_loop(Func cond,
Tensor loop_vars,
int parallel_iterations = 10)
{
- Func cond1 = x
+ Func cond1 = x
=> cond(x[0]);
- Func body1 = x
+ Func body1 = x
=> new[] { body(x[0]) };
var results = control_flow_ops.while_loop(cond1,
@@ -58,9 +58,9 @@ public Tensor while_loop(Func cond,
return results[0];
}
- public Tensor[] while_loop(Func cond,
- Func body,
- Tensor[] loop_vars,
+ public Tensor[] while_loop(Func cond,
+ Func body,
+ Tensors loop_vars,
int parallel_iterations = 10,
string name = null)
=> control_flow_ops.while_loop(cond, body, loop_vars,
diff --git a/src/TensorFlowNET.Core/APIs/tf.image.cs b/src/TensorFlowNET.Core/APIs/tf.image.cs
index 9230b50dc..41ef52967 100644
--- a/src/TensorFlowNET.Core/APIs/tf.image.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.image.cs
@@ -14,6 +14,10 @@ You may obtain a copy of the License at
limitations under the License.
******************************************************************************/
+using OneOf.Types;
+using System;
+using System.Buffers.Text;
+using Tensorflow.Contexts;
using static Tensorflow.Binding;
namespace Tensorflow
@@ -162,17 +166,108 @@ public Tensor ssim_multiscale(Tensor img1, Tensor img2, float max_val, float[] p
public Tensor sobel_edges(Tensor image)
=> image_ops_impl.sobel_edges(image);
- public Tensor decode_jpeg(Tensor contents,
- int channels = 0,
- int ratio = 1,
- bool fancy_upscaling = true,
- bool try_recover_truncated = false,
- int acceptable_fraction = 1,
- string dct_method = "",
- string name = null)
- => gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio,
- fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated,
- acceptable_fraction: acceptable_fraction, dct_method: dct_method);
+ ///
+ /// Adjust contrast of RGB or grayscale images.
+ ///
+ /// Images to adjust. At least 3-D.
+ ///
+ /// A float multiplier for adjusting contrast.
+ /// The contrast-adjusted image or images.
+ public Tensor adjust_contrast(Tensor images, float contrast_factor, string name = null)
+ => gen_image_ops.adjust_contrastv2(images, contrast_factor, name);
+
+ ///
+ /// Adjust hue of RGB images.
+ ///
+ /// RGB image or images. The size of the last dimension must be 3.
+ /// float. How much to add to the hue channel.
+ /// A name for this operation (optional).
+ /// Adjusted image(s), same shape and DType as `image`.
+ /// if `delta` is not in the interval of `[-1, 1]`.
+ public Tensor adjust_hue(Tensor images, float delta, string name = null)
+ {
+ if (tf.Context.executing_eagerly())
+ {
+ if (delta < -1f || delta > 1f)
+ throw new ValueError("delta must be in the interval [-1, 1]");
+ }
+ return gen_image_ops.adjust_hue(images, delta, name: name);
+ }
+
+ ///
+ /// Adjust saturation of RGB images.
+ ///
+ /// RGB image or images. The size of the last dimension must be 3.
+ /// float. Factor to multiply the saturation by.
+ /// A name for this operation (optional).
+ /// Adjusted image(s), same shape and DType as `image`.
+ public Tensor adjust_saturation(Tensor image, float saturation_factor, string name = null)
+ => gen_image_ops.adjust_saturation(image, saturation_factor, name);
+
+ ///
+ /// Greedily selects a subset of bounding boxes in descending order of score.
+ ///
+ ///
+ /// A 4-D float `Tensor` of shape `[batch_size, num_boxes, q, 4]`. If `q`
+ /// is 1 then same boxes are used for all classes otherwise, if `q` is equal
+ /// to number of classes, class-specific boxes are used.
+ ///
+ ///
+ /// A 3-D float `Tensor` of shape `[batch_size, num_boxes, num_classes]`
+ /// representing a single score corresponding to each box(each row of boxes).
+ ///
+ ///
+ /// A scalar integer `Tensor` representing the
+ /// maximum number of boxes to be selected by non-max suppression per class
+ ///
+ ///
+ /// A int32 scalar representing maximum number of boxes retained
+ /// over all classes.Note that setting this value to a large number may
+ /// result in OOM error depending on the system workload.
+ ///
+ ///
+ /// A float representing the threshold for deciding whether boxes
+ /// overlap too much with respect to IOU.
+ ///
+ ///
+ /// A float representing the threshold for deciding when to
+ /// remove boxes based on score.
+ ///
+ ///
+ /// If false, the output nmsed boxes, scores and classes are
+ /// padded/clipped to `max_total_size`. If true, the output nmsed boxes, scores and classes are padded to be of length `max_size_per_class`*`num_classes`,
+ /// unless it exceeds `max_total_size` in which case it is clipped to `max_total_size`. Defaults to false.
+ ///
+ ///
+ /// If true, the coordinates of output nmsed boxes will be clipped
+ /// to[0, 1]. If false, output the box coordinates as it is. Defaults to true.
+ ///
+ ///
+ /// 'nmsed_boxes': A [batch_size, max_detections, 4] float32 tensor containing the non-max suppressed boxes.
+ /// 'nmsed_scores': A [batch_size, max_detections] float32 tensor containing the scores for the boxes.
+ /// 'nmsed_classes': A [batch_size, max_detections] float32 tensor containing the class for boxes.
+ /// 'valid_detections': A [batch_size] int32 tensor indicating the number of
+ /// valid detections per batch item. Only the top valid_detections[i] entries
+ /// in nms_boxes[i], nms_scores[i] and nms_class[i] are valid. The rest of the
+ /// entries are zero paddings.
+ ///
+ public (Tensor, Tensor, Tensor, Tensor) combined_non_max_suppression(
+ Tensor boxes,
+ Tensor scores,
+ int max_output_size_per_class,
+ int max_total_size,
+ float iou_threshold,
+ float score_threshold,
+ bool pad_per_class = false,
+ bool clip_boxes = true)
+ {
+ var iou_threshold_t = ops.convert_to_tensor(iou_threshold, TF_DataType.TF_FLOAT, name: "iou_threshold");
+ var score_threshold_t = ops.convert_to_tensor(score_threshold, TF_DataType.TF_FLOAT, name: "score_threshold");
+ var max_total_size_t = ops.convert_to_tensor(max_total_size);
+ var max_output_size_per_class_t = ops.convert_to_tensor(max_output_size_per_class);
+ return gen_image_ops.combined_non_max_suppression(boxes, scores, max_output_size_per_class_t, max_total_size_t,
+ iou_threshold_t, score_threshold_t, pad_per_class, clip_boxes);
+ }
///
/// Extracts crops from the input image tensor and resizes them using bilinear sampling or nearest neighbor sampling (possibly with aspect ratio change) to a common output size specified by crop_size. This is more general than the crop_to_bounding_box op which extracts a fixed size slice from the input image and does not allow resizing or aspect ratio change.
@@ -187,7 +282,19 @@ public Tensor decode_jpeg(Tensor contents,
/// A name for the operation (optional).
/// A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth].
public Tensor crop_and_resize(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method = "bilinear", float extrapolation_value = 0f, string name = null) =>
- image_ops_impl.crop_and_resize(image, boxes, box_ind, crop_size, method, extrapolation_value, name);
+ gen_image_ops.crop_and_resize(image, boxes, box_ind, crop_size, method, extrapolation_value, name);
+
+ public Tensor decode_jpeg(Tensor contents,
+ int channels = 0,
+ int ratio = 1,
+ bool fancy_upscaling = true,
+ bool try_recover_truncated = false,
+ int acceptable_fraction = 1,
+ string dct_method = "",
+ string name = null)
+ => gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio,
+ fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated,
+ acceptable_fraction: acceptable_fraction, dct_method: dct_method);
public Tensor extract_glimpse(Tensor input, Tensor size, Tensor offsets, bool centered = true, bool normalized = true,
bool uniform_noise = true, string name = null)
@@ -232,6 +339,13 @@ public Tensor decode_image(Tensor contents, int channels = 0, TF_DataType dtype
=> image_ops_impl.decode_image(contents, channels: channels, dtype: dtype,
name: name, expand_animations: expand_animations);
+ public Tensor encode_png(Tensor contents, string name = null)
+ => image_ops_impl.encode_png(contents, name: name);
+
+ public Tensor encode_jpeg(Tensor contents, string name = null)
+ => image_ops_impl.encode_jpeg(contents, name: name);
+
+
///
/// Convenience function to check if the 'contents' encodes a JPEG image.
///
diff --git a/src/TensorFlowNET.Core/APIs/tf.io.cs b/src/TensorFlowNET.Core/APIs/tf.io.cs
index be1e86e6c..ea1e44b28 100644
--- a/src/TensorFlowNET.Core/APIs/tf.io.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.io.cs
@@ -16,6 +16,7 @@ limitations under the License.
using System.Collections.Generic;
using Tensorflow.IO;
+using Tensorflow.Operations;
namespace Tensorflow
{
@@ -46,6 +47,12 @@ public Operation save_v2(Tensor prefix, string[] tensor_names,
public Tensor[] restore_v2(Tensor prefix, string[] tensor_names,
string[] shape_and_slices, TF_DataType[] dtypes, string name = null)
=> ops.restore_v2(prefix, tensor_names, shape_and_slices, dtypes, name: name);
+
+ public Operation write_file(string filename, Tensor conentes, string name = null)
+ => write_file(Tensorflow.ops.convert_to_tensor(filename, TF_DataType.TF_STRING), conentes, name);
+
+ public Operation write_file(Tensor filename, Tensor conentes, string name = null)
+ => gen_ops.write_file(filename, conentes, name);
}
public GFile gfile = new GFile();
diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs
index 83653c8bb..da54a9dd7 100644
--- a/src/TensorFlowNET.Core/APIs/tf.math.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.math.cs
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
limitations under the License.
******************************************************************************/
+using Tensorflow.NumPy;
using Tensorflow.Operations;
namespace Tensorflow
@@ -42,10 +43,20 @@ public Tensor erf(Tensor x, string name = null)
public Tensor multiply(Tensor x, Tensor y, string name = null)
=> math_ops.multiply(x, y, name: name);
-
public Tensor divide_no_nan(Tensor a, Tensor b, string name = null)
=> math_ops.div_no_nan(a, b);
+ ///
+ /// Computes the Euclidean norm of elements across dimensions of a tensor.
+ ///
+ /// The tensor to reduce. Should have numeric type.
+ /// The dimensions to reduce. If `None` (the default), reduces all dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`
+ /// If true, retains reduced dimensions with length 1.
+ /// A name for the operation (optional).
+ /// The reduced tensor, of the same dtype as the input_tensor.
+ public Tensor reduce_euclidean_norm(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
+ => math_ops.reduce_euclidean_norm(input_tensor, axis: axis, keepdims: keepdims, name);
+
public Tensor square(Tensor x, string name = null)
=> math_ops.square(x, name: name);
@@ -130,7 +141,7 @@ public Tensor add(Tensor a, Tensor b, string name = null)
=> gen_math_ops.add(a, b, name: name);
public Tensor add(Tx a, Ty b, string name = null)
- => gen_math_ops.add(a, b, name: name);
+ => gen_math_ops.add(ops.convert_to_tensor(a), ops.convert_to_tensor(b), name: name);
///
/// Adds all input tensors element-wise.
@@ -151,10 +162,10 @@ public Tensor atan(Tensor x, string name = null)
=> gen_math_ops.atan(x, name);
public Tensor arg_max(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null)
- => gen_math_ops.arg_max(input, dimension, output_type: output_type, name: name);
+ => gen_math_ops.arg_max(input, ops.convert_to_tensor(dimension), output_type: output_type, name: name);
public Tensor arg_min(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null)
- => gen_math_ops.arg_min(input, dimension, output_type: output_type, name: name);
+ => gen_math_ops.arg_min(input, ops.convert_to_tensor(dimension), output_type: output_type, name: name);
public Tensor is_finite(Tensor input, string name = null)
=> gen_math_ops.is_finite(input, name);
@@ -199,7 +210,7 @@ public Tensor cos(Tensor x, string name = null)
=> gen_math_ops.cos(x, name);
public Tensor cos(float x, string name = null)
- => gen_math_ops.cos(x, name);
+ => gen_math_ops.cos(ops.convert_to_tensor(x), name);
///
/// Computes hyperbolic cosine of x element-wise.
@@ -235,7 +246,7 @@ public Tensor floor(Tensor x, string name = null)
///
///
public Tensor greater(Tx x, Ty y, string name = null)
- => gen_math_ops.greater(x, y, name);
+ => gen_math_ops.greater(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name);
///
/// Returns the truth value of (x >= y) element-wise.
@@ -247,7 +258,7 @@ public Tensor greater(Tx x, Ty y, string name = null)
///
///
public Tensor greater_equal(Tx x, Ty y, string name = null)
- => gen_math_ops.greater_equal(x, y, name);
+ => gen_math_ops.greater_equal(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name);
///
/// Returns the truth value of (x < y) element-wise.
@@ -259,7 +270,7 @@ public Tensor greater_equal(Tx x, Ty y, string name = null)
///
///
public Tensor less(Tx x, Ty y, string name = null)
- => gen_math_ops.less(x, y, name);
+ => gen_math_ops.less(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name);
///
/// Computes the log of the absolute value of `Gamma(x)` element-wise.
@@ -280,7 +291,7 @@ public Tensor lgamma(Tensor x, string name = null)
///
///
public Tensor less_equal(Tx x, Ty y, string name = null)
- => gen_math_ops.less_equal(x, y, name);
+ => gen_math_ops.less_equal(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name);
///
/// Computes natural logarithm of (1 + x) element-wise.
@@ -292,7 +303,7 @@ public Tensor log1p(Tensor x, string name = null)
=> gen_math_ops.log1p(x, name);
public Tensor logical_and(T x, T y, string name = null)
- => gen_math_ops.logical_and(x, y, name);
+ => gen_math_ops.logical_and(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name);
public Tensor logical_not(Tensor x, string name = null)
=> gen_math_ops.logical_not(x, name);
@@ -301,7 +312,10 @@ public Tensor logical_or(Tensor x, Tensor y, string name = null)
=> gen_math_ops.logical_or(x, y, name);
public Tensor logical_xor(Tensor x, Tensor y, string name = "LogicalXor")
- => gen_math_ops.logical_xor(x, y, name);
+ {
+ return gen_math_ops.logical_and(gen_math_ops.logical_or(x, y),
+ gen_math_ops.logical_not(gen_math_ops.logical_and(x, y)), name);
+ }
///
/// Clips tensor values to a specified min and max.
@@ -312,7 +326,7 @@ public Tensor logical_xor(Tensor x, Tensor y, string name = "LogicalXor")
///
///
public Tensor _clip_by_value(Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = null)
- => gen_math_ops._clip_by_value(t, clip_value_min, clip_value_max);
+ => gen_math_ops.clip_by_value(t, clip_value_min, clip_value_max);
///
/// Clips tensor values to a specified min and max.
@@ -345,13 +359,13 @@ public Tensor clip_by_value(Tensor t, T1 clip_value_min, T2 clip_value_m
=> clip_ops.clip_by_value(t, clip_value_min, clip_value_max, name);
public Tensor sub(Tx a, Ty b, string name = null)
- => gen_math_ops.sub(a, b, name: name);
+ => gen_math_ops.sub(ops.convert_to_tensor(a), ops.convert_to_tensor(b), name: name);
public Tensor divide(Tensor a, Tensor b)
=> a / b;
public Tensor sqrt(Tensor a, string name = null)
- => gen_math_ops.sqrt(a, name);
+ => math_ops.sqrt(a, name);
public Tensor sign(Tensor a, string name = null)
=> gen_math_ops.sign(a, name);
@@ -396,7 +410,7 @@ public Tensor atan2(Tensor y, Tensor x, string name = null)
///
///
public Tensor max(Tx input, Ty axis, bool keep_dims = false, string name = null)
- => gen_math_ops._max(input, axis, keep_dims: keep_dims, name: name);
+ => gen_math_ops.max(ops.convert_to_tensor(input), ops.convert_to_tensor(axis), keep_dims: keep_dims, name: name);
///
/// Computes the minimum of elements across dimensions of a tensor.
@@ -409,7 +423,7 @@ public Tensor max(Tx input, Ty axis, bool keep_dims = false, string name
///
///
public Tensor min(Tx input, Ty axis, bool keep_dims = false, string name = null)
- => gen_math_ops._min(input, axis, keep_dims: keep_dims, name: name);
+ => gen_math_ops.min(ops.convert_to_tensor(input), ops.convert_to_tensor(axis), keep_dims: keep_dims, name: name);
///
/// Returns the max of x and y (i.e. x > y ? x : y) element-wise.
@@ -421,7 +435,7 @@ public Tensor min(Tx input, Ty axis, bool keep_dims = false, string name
///
///
public Tensor maximum(T1 x, T2 y, string name = null)
- => gen_math_ops.maximum(x, y, name: name);
+ => gen_math_ops.maximum(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name);
///
/// Returns the min of x and y (i.e. x < y ? x : y) element-wise.
@@ -433,7 +447,7 @@ public Tensor maximum(T1 x, T2 y, string name = null)
///
///
public Tensor minimum(T1 x, T2 y, string name = null)
- => gen_math_ops.minimum(x, y, name: name);
+ => gen_math_ops.minimum(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name);
public Tensor multiply(Tensor x, Tensor y, string name = null)
=> gen_math_ops.mul(x, y, name: name);
@@ -448,8 +462,19 @@ public Tensor multiply(Tensor x, Tensor y, string name = null)
///
///
public Tensor multiply(Tx x, Ty y, string name = null)
- => gen_math_ops.mul(x, y, name: name);
-
+ => gen_math_ops.mul(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name);
+ ///
+ /// return scalar product
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public Tensor dot_prod(Tx x, Ty y, NDArray axes, string name = null)
+ => math_ops.tensordot(convert_to_tensor(x), convert_to_tensor(y), axes, name: name);
public Tensor negative(Tensor x, string name = null)
=> gen_math_ops.neg(x, name);
@@ -577,7 +602,7 @@ public Tensor sigmoid(T x, string name = null)
=> math_ops.sigmoid(x, name: name);
public Tensor sum(Tensor input, int axis, bool keep_dims = false, string name = null)
- => gen_math_ops._sum(input, axis, keep_dims: keep_dims, name: name);
+ => gen_math_ops.sum(input, ops.convert_to_tensor(axis), keep_dims: keep_dims, name: name);
public Tensor reduce_mean(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
=> math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices);
@@ -597,5 +622,7 @@ public Tensor squared_difference(Tensor x, Tensor y, string name = null)
=> gen_math_ops.squared_difference(x: x, y: y, name: name);
public Tensor complex(Tensor real, Tensor imag, Tensorflow.TF_DataType? dtype = null,
string name = null) => gen_ops.complex(real, imag, dtype, name);
+ public Tensor exp(Tensor x,
+ string name = null) => gen_math_ops.exp(x, name);
}
}
diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index 1595e52fc..112c48628 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
limitations under the License.
******************************************************************************/
+using System.Xml.Linq;
using Tensorflow.Operations;
using Tensorflow.Operations.Activation;
using static Tensorflow.Binding;
@@ -29,21 +30,8 @@ public class nn_internal
public Tensor conv2d(Tensor input, Tensor filter, int[] strides, string padding, bool use_cudnn_on_gpu = true,
string data_format = "NHWC", int[] dilations = null, string name = null)
{
- var parameters = new Conv2dParams
- {
- Input = input,
- Filter = filter,
- Strides = strides,
- Padding = padding,
- UseCudnnOnGpu = use_cudnn_on_gpu,
- DataFormat = data_format,
- Name = name
- };
-
- if (dilations != null)
- parameters.Dilations = dilations;
-
- return gen_nn_ops.conv2d(parameters);
+ return gen_nn_ops.conv2d(input, filter, strides, padding, use_cudnn_on_gpu,
+ data_format: data_format, dilations: dilations, name: name);
}
public Tensor[] ctc_greedy_decoder(Tensor inputs, Tensor sequence_length, bool merge_repeated = true, string name = null)
@@ -113,16 +101,21 @@ public Tensor embedding_lookup(Tensor @params,
name: name);
public IActivation relu() => new relu();
+
+
public IActivation swish() => new swish();
public IActivation tanh() => new tanh();
public IActivation softmax() => new softmax();
public Tensor tanh(Tensor x, string name = null)
- => gen_nn_ops.tanh(x, name);
+ => gen_math_ops.tanh(x, name);
public Tensor relu(Tensor features, string name = null)
=> gen_nn_ops.relu(features, name);
+ public Tensor relu6(Tensor features, string name = null)
+ => gen_nn_ops.relu6(features, name);
+
public Tensor[] fused_batch_norm(Tensor x,
Tensor scale,
Tensor offset,
@@ -139,6 +132,26 @@ public Tensor[] fused_batch_norm(Tensor x,
name: name,
exponential_avg_factor: exponential_avg_factor);
+ ///
+ /// Normalizes a tensor by `mean` and `variance`, and applies (optionally) a`scale` \\(\gamma\\) to it, as well as an `offset` \\(\beta\\).
+ ///
+ /// A floating point tensor.
+ /// A mean `Tensor`.
+ /// A variance `Tensor`.
+ /// An offset `Tensor`, often denoted \\(\beta\\) in equations, or NULL. If present, will be added to the normalized tensor.
+ /// A scale `Tensor`, often denoted \\(\gamma\\) in equations, or NULL. If present, the scale is applied to the normalized tensor.
+ /// A small float number to avoid dividing by 0.
+ /// A name for this operation.
+ /// the normalized, scaled, offset tensor.
+ public Tensor batch_normalization(Tensor x,
+ Tensor mean,
+ Tensor variance,
+ Tensor offset,
+ Tensor scale,
+ float variance_epsilon,
+ string name = null) => nn_impl.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name);
+
+
public Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null)
=> nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name);
@@ -146,14 +159,14 @@ public Tensor in_top_k(Tensor predictions, Tensor targets, int k, string name =
=> nn_ops.in_top_k(predictions, targets, k, name);
public Tensor[] top_k(Tensor input, int k = 1, bool sorted = true, string name = null)
- => gen_nn_ops.top_kv2(input, k: k, sorted: sorted, name: name);
+ => gen_nn_ops.top_kv2(input, k: ops.convert_to_tensor(k), sorted: sorted, name: name);
public Tensor bias_add(Tensor value, IVariableV1 bias, string data_format = null, string name = null)
{
return tf_with(ops.name_scope(name, "BiasAdd", new { value, bias }), scope =>
{
name = scope;
- return gen_nn_ops.bias_add(value, bias, data_format: data_format, name: name);
+ return gen_nn_ops.bias_add(value, ops.convert_to_tensor(bias), data_format: data_format, name: name);
});
}
@@ -172,7 +185,7 @@ public Tensor l2_loss(Tensor t, string name = null)
///
public Tensor lrn(Tensor input, int depth_radius = 5, int bias = 1,
int alpha = 1, float beta = 0.5f, string name = null)
- => gen_nn_ops.local_response_normalization(input, depth_radius: depth_radius, bias: bias,
+ => gen_nn_ops.lrn(input, depth_radius: depth_radius, bias: bias,
alpha: alpha, beta: beta, name: name);
public Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null)
diff --git a/src/TensorFlowNET.Core/APIs/tf.reshape.cs b/src/TensorFlowNET.Core/APIs/tf.reshape.cs
index cdd5194a2..102a81323 100644
--- a/src/TensorFlowNET.Core/APIs/tf.reshape.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.reshape.cs
@@ -31,6 +31,6 @@ public Tensor reshape(Tensor tensor,
public Tensor reshape(Tensor tensor,
object[] shape,
string name = null)
- => gen_array_ops.reshape(tensor, shape, name);
+ => array_ops.reshape(tensor, shape, name);
}
}
diff --git a/src/TensorFlowNET.Core/APIs/tf.tensor.cs b/src/TensorFlowNET.Core/APIs/tf.tensor.cs
index 35efde06b..b03168ab3 100644
--- a/src/TensorFlowNET.Core/APIs/tf.tensor.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.tensor.cs
@@ -46,10 +46,10 @@ public Tensor strided_slice(Tensor input, T[] begin, T[] end, T[] strides = n
int ellipsis_mask = 0,
int new_axis_mask = 0,
int shrink_axis_mask = 0,
- string name = null) => gen_array_ops.strided_slice(input: input,
- begin: begin,
- end: end,
- strides: strides,
+ string name = null) => array_ops.strided_slice(input,
+ begin: ops.convert_to_tensor(begin),
+ end: ops.convert_to_tensor(end),
+ strides: ops.convert_to_tensor(strides),
begin_mask: begin_mask,
end_mask: end_mask,
ellipsis_mask: ellipsis_mask,
@@ -68,20 +68,27 @@ public Tensor strided_slice(Tensor input, T[] begin, T[] end, T[] strides = n
/// A name for the operation (optional)
/// if num_or_size_splits is a scalar returns num_or_size_splits Tensor objects;
/// if num_or_size_splits is a 1-D Tensor returns num_or_size_splits.get_shape[0] Tensor objects resulting from splitting value.
- public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null)
+ public Tensor[] split(Tensor value, int num_split, Axis axis, string name = null)
=> array_ops.split(
value: value,
- num_split: num_split,
+ num_or_size_splits: num_split,
axis: axis,
name: name);
- public Tensor[] split(Tensor value, int num_split, int axis, string name = null)
+ public Tensor[] split(Tensor value, int[] num_split, Axis axis, string name = null)
=> array_ops.split(
value: value,
- num_split: num_split,
+ num_or_size_splits: num_split,
axis: axis,
name: name);
+ //public Tensor[] split(Tensor value, int num_split, Axis axis, string name = null)
+ // => array_ops.split(
+ // value: value,
+ // num_or_size_splits: num_split,
+ // axis: axis,
+ // name: name);
+
public Tensor ensure_shape(Tensor x, Shape shape, string name = null)
{
return gen_ops.ensure_shape(x, shape, name);
diff --git a/src/TensorFlowNET.Core/APIs/tf.tile.cs b/src/TensorFlowNET.Core/APIs/tf.tile.cs
index be03e453c..a3b497e8a 100644
--- a/src/TensorFlowNET.Core/APIs/tf.tile.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.tile.cs
@@ -23,7 +23,7 @@ public Tensor tile(Tensor input, Tensor multiples, string name = null)
=> gen_array_ops.tile(input, multiples, name);
public Tensor tile(Tensor input, object[] multiples, string name = null)
- => gen_array_ops.tile(input, multiples, name);
+ => array_ops.tile(input, constant_op.constant(shape_utils.from_object_array(multiples).dims), name);
public Tensor tile(Tensor input, Shape multiples, string name = null)
{
diff --git a/src/TensorFlowNET.Core/Attributes/c_api.ops.cs b/src/TensorFlowNET.Core/Attributes/c_api.ops.cs
index 2a22413b0..ba6f653a1 100644
--- a/src/TensorFlowNET.Core/Attributes/c_api.ops.cs
+++ b/src/TensorFlowNET.Core/Attributes/c_api.ops.cs
@@ -57,6 +57,21 @@ public partial class c_api
[DllImport(TensorFlowLibName)]
public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, SafeBufferHandle output_attr_value, SafeStatusHandle status);
+ [DllImport(TensorFlowLibName)]
+ public static extern void TF_OperationGetAttrType(IntPtr oper, string attr_name, IntPtr value, SafeStatusHandle status);
+
+ [DllImport(TensorFlowLibName)]
+ public static extern void TF_OperationGetAttrInt(IntPtr oper, string attr_name, IntPtr value, SafeStatusHandle status);
+
+ [DllImport(TensorFlowLibName)]
+ public static extern void TF_OperationGetAttrFloat(IntPtr oper, string attr_name, IntPtr value, SafeStatusHandle status);
+
+ [DllImport(TensorFlowLibName)]
+ public static extern void TF_OperationGetAttrBool(IntPtr oper, string attr_name, IntPtr value, SafeStatusHandle status);
+
+ [DllImport(TensorFlowLibName)]
+ public static extern void TF_OperationGetAttrShape(IntPtr oper, string attr_name, long[] value, int num_dims, SafeStatusHandle status);
+
[DllImport(TensorFlowLibName)]
public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value);
diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs
index 8df39334a..99ed5c1f3 100644
--- a/src/TensorFlowNET.Core/Binding.Util.cs
+++ b/src/TensorFlowNET.Core/Binding.Util.cs
@@ -486,7 +486,28 @@ public static Shape GetShape(this object data)
throw new NotImplementedException("");
}
}
-
+ public static NDArray GetFlattenArray(NDArray x)
+ {
+ switch (x.GetDataType())
+ {
+ case TF_DataType.TF_FLOAT:
+ x = x.ToArray();
+ break;
+ case TF_DataType.TF_DOUBLE:
+ x = x.ToArray();
+ break;
+ case TF_DataType.TF_INT16:
+ case TF_DataType.TF_INT32:
+ x = x.ToArray();
+ break;
+ case TF_DataType.TF_INT64:
+ x = x.ToArray();
+ break;
+ default:
+ break;
+ }
+ return x;
+ }
public static TF_DataType GetDataType(this object data)
{
var type = data.GetType();
@@ -503,7 +524,7 @@ public static TF_DataType GetDataType(this object data)
case Tensors tensors:
return tensors.dtype;
case IEnumerable tensors:
- return tensors.First().dtype;
+ return tensors.Where(x => x is not null).First().dtype;
case RefVariable variable:
return variable.dtype;
case ResourceVariable variable:
diff --git a/src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs b/src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs
index adb26ef29..1b295fcfd 100644
--- a/src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs
+++ b/src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs
@@ -88,7 +88,7 @@ private Tensor _initialize()
public Tensor op()
{
- var x = control_flow_ops.cond(gen_math_ops.equal(_num_remaining, 0),
+ var x = control_flow_ops.cond(gen_math_ops.equal(_num_remaining, ops.convert_to_tensor(0)),
() =>
{
return check_ops.assert_equal(_cluster_centers_initialized, true);
diff --git a/src/TensorFlowNET.Core/Extensions/DictionaryExtension.cs b/src/TensorFlowNET.Core/Common/Extensions/DictionaryExtension.cs
similarity index 100%
rename from src/TensorFlowNET.Core/Extensions/DictionaryExtension.cs
rename to src/TensorFlowNET.Core/Common/Extensions/DictionaryExtension.cs
diff --git a/src/TensorFlowNET.Core/Extensions/JObjectExtensions.cs b/src/TensorFlowNET.Core/Common/Extensions/JObjectExtensions.cs
similarity index 80%
rename from src/TensorFlowNET.Core/Extensions/JObjectExtensions.cs
rename to src/TensorFlowNET.Core/Common/Extensions/JObjectExtensions.cs
index 2e758dbf1..6ceba445a 100644
--- a/src/TensorFlowNET.Core/Extensions/JObjectExtensions.cs
+++ b/src/TensorFlowNET.Core/Common/Extensions/JObjectExtensions.cs
@@ -3,16 +3,16 @@
using System.Collections.Generic;
using System.Text;
-namespace Tensorflow.Extensions
+namespace Tensorflow.Common.Extensions
{
public static class JObjectExtensions
{
public static T? TryGetOrReturnNull(this JObject obj, string key)
{
var res = obj[key];
- if(res is null)
+ if (res is null)
{
- return default(T);
+ return default;
}
else
{
diff --git a/src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs b/src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs
new file mode 100644
index 000000000..287b48cc3
--- /dev/null
+++ b/src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs
@@ -0,0 +1,38 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Tensorflow.Common.Extensions
+{
+ public static class LinqExtensions
+ {
+#if NETSTANDARD2_0
+ public static IEnumerable TakeLast(this IEnumerable sequence, int count)
+ {
+ return sequence.Skip(sequence.Count() - count);
+ }
+
+ public static IEnumerable SkipLast(this IEnumerable sequence, int count)
+ {
+ return sequence.Take(sequence.Count() - count);
+ }
+#endif
+ public static Tensors ToTensors(this Tensor[] tensors)
+ {
+ return new Tensors(tensors);
+ }
+
+ public static Tensors ToTensors(this IList tensors)
+ {
+ return new Tensors(tensors);
+ }
+
+ public static void Deconstruct(this (T1, T2, T3) values, out T1 first, out T2 second, out T3 third)
+ {
+ first = values.Item1;
+ second = values.Item2;
+ third = values.Item3;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Common/Extensions/NestExtensions.cs b/src/TensorFlowNET.Core/Common/Extensions/NestExtensions.cs
new file mode 100644
index 000000000..76bdd6133
--- /dev/null
+++ b/src/TensorFlowNET.Core/Common/Extensions/NestExtensions.cs
@@ -0,0 +1,33 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Common.Types;
+
+namespace Tensorflow.Common.Extensions
+{
+ public static class NestExtensions
+ {
+ public static Tensors ToTensors(this INestable tensors)
+ {
+ return new Tensors(tensors.AsNest());
+ }
+
+ public static Tensors? ToTensors(this Nest tensors)
+ {
+ return Tensors.FromNest(tensors);
+ }
+
+ ///
+ /// If the nested object is already a nested type, this function could reduce it.
+ /// For example, `Nest[Nest[T]]` can be reduced to `Nest[T]`.
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static Nest ReduceTo(this INestStructure input) where TIn: INestStructure
+ {
+ return Nest.ReduceFrom(input);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Extensions/OneofExtension.cs b/src/TensorFlowNET.Core/Common/Extensions/OneofExtension.cs
similarity index 100%
rename from src/TensorFlowNET.Core/Extensions/OneofExtension.cs
rename to src/TensorFlowNET.Core/Common/Extensions/OneofExtension.cs
diff --git a/src/TensorFlowNET.Core/Common/Types/FakeTensorByTensorArray.cs b/src/TensorFlowNET.Core/Common/Types/FakeTensorByTensorArray.cs
new file mode 100644
index 000000000..d0c35ee70
--- /dev/null
+++ b/src/TensorFlowNET.Core/Common/Types/FakeTensorByTensorArray.cs
@@ -0,0 +1,20 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Common.Types
+{
+ ///
+ /// This is a temp solution, which should be removed after refactoring `Tensors`
+ ///
+ [Obsolete]
+ public class FakeTensorByTensorArray: Tensor
+ {
+ public TensorArray TensorArray { get; set; }
+
+ public FakeTensorByTensorArray(TensorArray array)
+ {
+ TensorArray = array;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs b/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs
new file mode 100644
index 000000000..986136f4d
--- /dev/null
+++ b/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs
@@ -0,0 +1,69 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Text;
+
+namespace Tensorflow.Common.Types
+{
+ public class GeneralizedTensorShape: Nest
+ {
+ public GeneralizedTensorShape(Shape value, string? name = null)
+ {
+ NodeValue = value;
+ NestType = NestType.Node;
+ }
+
+ public GeneralizedTensorShape(IEnumerable values, string? name = null)
+ {
+ ListValue = values.Select(s => new Nest(s) as INestStructure).ToList();
+ Name = name;
+ NestType = NestType.List;
+ }
+
+ public GeneralizedTensorShape(Dictionary value, string? name = null)
+ {
+ DictValue = value.ToDictionary(x => x.Key, x => new Nest(x.Value) as INestStructure);
+ Name = name;
+ NestType = NestType.Dictionary;
+ }
+
+ public GeneralizedTensorShape(Nest other)
+ {
+ NestType = other.NestType;
+ NodeValue = other.NodeValue;
+ DictValue = other.DictValue;
+ ListValue = other.ListValue;
+ Name = other.Name;
+ }
+
+ public Shape ToSingleShape()
+ {
+ var shapes = Flatten().ToList();
+ if (shapes.Count != 1)
+ {
+ throw new ValueError("The generalized shape contains more than 1 dim.");
+ }
+ return shapes[0];
+ }
+
+ public long ToNumber()
+ {
+ var shapes = Flatten().ToList();
+ if (shapes.Count != 1 || shapes[0].ndim != 1)
+ {
+ throw new ValueError("The generalized shape contains more than 1 dim.");
+ }
+ return shapes[0].dims[0];
+ }
+
+ public INestStructure ToTensorShapeConfigs()
+ {
+ return MapStructure(s => new TensorShapeConfig() { Items = s.dims.Select(x => x == -1 ? null : x).ToArray() });
+ }
+
+ public static implicit operator GeneralizedTensorShape(Shape shape)
+ {
+ return new GeneralizedTensorShape(shape);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Common/Types/INestStructure.cs b/src/TensorFlowNET.Core/Common/Types/INestStructure.cs
new file mode 100644
index 000000000..32b662937
--- /dev/null
+++ b/src/TensorFlowNET.Core/Common/Types/INestStructure.cs
@@ -0,0 +1,40 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Common.Types
+{
+ ///
+ /// This interface indicates that a class may have a nested structure and provide
+ /// methods to manipulate with the structure.
+ ///
+ public interface INestStructure: INestable
+ {
+ NestType NestType { get; }
+
+ ///
+ /// The item count of depth 1 of the nested structure.
+ /// For example, [1, 2, [3, 4, 5]] has ShallowNestedCount = 3.
+ ///
+ int ShallowNestedCount { get; }
+ ///
+ /// The total item count of depth 1 of the nested structure.
+ /// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5.
+ ///
+ int TotalNestedCount { get; }
+
+ ///
+ /// Flatten the Nestable object. Node that if the object contains only one value,
+ /// it will be flattened to an enumerable with one element.
+ ///
+ ///
+ IEnumerable Flatten();
+ ///
+ /// Construct a new object with the same nested structure.
+ ///
+ ///
+ ///
+ ///
+ INestStructure MapStructure(Func func);
+ }
+}
diff --git a/src/TensorFlowNET.Core/Common/Types/INestable.cs b/src/TensorFlowNET.Core/Common/Types/INestable.cs
new file mode 100644
index 000000000..7ce49f85a
--- /dev/null
+++ b/src/TensorFlowNET.Core/Common/Types/INestable.cs
@@ -0,0 +1,11 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Common.Types
+{
+ public interface INestable
+ {
+ Nest AsNest();
+ }
+}
diff --git a/src/TensorFlowNET.Core/Common/Types/IOptionalArgs.cs b/src/TensorFlowNET.Core/Common/Types/IOptionalArgs.cs
new file mode 100644
index 000000000..427e71aaa
--- /dev/null
+++ b/src/TensorFlowNET.Core/Common/Types/IOptionalArgs.cs
@@ -0,0 +1,21 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Common.Types
+{
+ ///
+ /// This interface is used when some corresponding python methods have optional args.
+ /// For example, `Keras.Layer.Apply` generally takes three args as the inputs, while
+ /// `Keras.Layer.RNN` takes more. Then when calling RNN, you should add `RnnOptionalArgs`
+ /// as the parameter of the method.
+ ///
+ public interface IOptionalArgs
+ {
+ ///
+ /// The identifier of the class. It is not an argument but only something to
+ /// separate different OptionalArgs.
+ ///
+ string Identifier { get; }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Extensions/NamedTuple.cs b/src/TensorFlowNET.Core/Common/Types/NamedTuple.cs
similarity index 100%
rename from src/TensorFlowNET.Core/Extensions/NamedTuple.cs
rename to src/TensorFlowNET.Core/Common/Types/NamedTuple.cs
diff --git a/src/TensorFlowNET.Core/Common/Types/Nest.Static.cs b/src/TensorFlowNET.Core/Common/Types/Nest.Static.cs
new file mode 100644
index 000000000..dc7fd3a1f
--- /dev/null
+++ b/src/TensorFlowNET.Core/Common/Types/Nest.Static.cs
@@ -0,0 +1,62 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Common.Types
+{
+ public static class Nest
+ {
+ ///
+ /// Pack the flat items to a nested sequence by the template.
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static Nest PackSequenceAs(INestable template, TOut[] flatItems)
+ {
+ return template.AsNest().PackSequence(flatItems);
+ }
+
+ ///
+ /// Pack the flat items to a nested sequence by the template.
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static Nest PackSequenceAs(INestable template, List flatItems)
+ {
+ return template.AsNest().PackSequence(flatItems.ToArray());
+ }
+
+ ///
+ /// Flatten the nested object.
+ ///
+ ///
+ ///
+ ///
+ public static IEnumerable Flatten(INestable nestedObject)
+ {
+ return nestedObject.AsNest().Flatten();
+ }
+
+ ///
+ /// Map the structure with specified function.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static INestStructure MapStructure(Func func, INestable nestedObject)
+ {
+ return nestedObject.AsNest().MapStructure(func);
+ }
+
+ public static bool IsNested(INestable obj)
+ {
+ return obj.AsNest().IsNested();
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Common/Types/Nest.cs b/src/TensorFlowNET.Core/Common/Types/Nest.cs
new file mode 100644
index 000000000..89ce29f2f
--- /dev/null
+++ b/src/TensorFlowNET.Core/Common/Types/Nest.cs
@@ -0,0 +1,485 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Common.Extensions;
+
+namespace Tensorflow.Common.Types
+{
+ public enum NestType
+ {
+ Empty,
+ Node,
+ List,
+ Dictionary
+ }
+
+ ///
+ /// A nested structure which may inclulde value, list and dictionary.
+ /// Note that dictionary does not ensure the data order. When using it as IEnumerable,
+ /// its order is depth-first.
+ ///
+ ///
+ public class Nest : INestStructure, IEnumerable
+ {
+ private static readonly Nest _empty = new Nest()
+ {
+ NestType = NestType.Empty,
+ };
+ public static Nest Empty => _empty;
+ public NestType NestType { get; protected set; }
+ public string? Name { get; set; }
+ public T? NodeValue { get; protected set; }
+ public List>? ListValue { get; protected set; }
+ public Dictionary>? DictValue { get; protected set; }
+
+ public int ShallowNestedCount
+ {
+ get
+ {
+ if (NestType == NestType.Empty)
+ {
+ return 0;
+ }
+ else if (NestType == NestType.Node)
+ {
+ return 1;
+ }
+ else if (NestType == NestType.List)
+ {
+ return ListValue!.Count;
+ }
+ else // dict
+ {
+ return DictValue!.Count;
+ }
+ }
+ }
+
+ public int TotalNestedCount
+ {
+ get
+ {
+ return Flatten().Count();
+ }
+ }
+
+ protected Nest() { }
+
+ public Nest(T value, string? name = null)
+ {
+ NodeValue = value;
+ Name = name;
+ NestType = NestType.Node;
+ }
+
+ public Nest(IEnumerable> values, string? name = null)
+ {
+ ListValue = values.ToList();
+ Name = name;
+ NestType = NestType.List;
+ }
+
+ public Nest(Dictionary> value, string? name = null)
+ {
+ DictValue = value;
+ Name = name;
+ NestType = NestType.Dictionary;
+ }
+
+ public Nest(Nest other)
+ {
+ NestType = other.NestType;
+ NodeValue = other.NodeValue;
+ DictValue = other.DictValue;
+ ListValue = other.ListValue;
+ Name = other.Name;
+ }
+
+ public virtual IEnumerable Flatten()
+ {
+ return FlattenInternal(this);
+ }
+ public virtual INestStructure MapStructure(Func func)
+ {
+ return MapStructureInternal(func);
+ }
+
+ ///
+ /// Pack the flat items to a nested sequence by the template.
+ ///
+ ///
+ ///
+ public virtual Nest PackSequence(TOut[] flatItems)
+ {
+ if(flatItems.Length == 0)
+ {
+ return Nest.Empty;
+ }
+ int index = 0;
+ return PackSequenceInternal(this, flatItems, ref index);
+ }
+
+ private static Nest PackSequenceInternal(Nest template, TOut[] flatItems, ref int index)
+ {
+ if(template.NestType == NestType.Node)
+ {
+ if(index >= flatItems.Length)
+ {
+ throw new InvalidArgumentError("The template and flat items are not matched.");
+ }
+ return new Nest(flatItems[index++]);
+ }
+ else if(template.NestType == NestType.List)
+ {
+ List> nestedObjects = new List>();
+ for (int i = 0; i < template.ListValue!.Count; i++)
+ {
+ nestedObjects.Add(PackSequenceInternal(template.ListValue![i].AsNest(), flatItems, ref index));
+ }
+ return new Nest(nestedObjects);
+ }
+ else if(template.NestType == NestType.Node)
+ {
+ Dictionary> dict = new Dictionary>();
+ foreach(var (key, value) in template.DictValue!)
+ {
+ dict[key] = PackSequenceInternal(value.AsNest(), flatItems, ref index);
+ }
+ return new Nest(dict);
+ }
+ // Consider Empty as invalid type.
+ throw new InvalidArgumentError("When using `PackSequenceAs`, the template cannot contain empty node.");
+ }
+
+ public virtual Nest AsNest()
+ {
+ return this;
+ }
+
+ public virtual Nest MergeWith(Nest? other)
+ {
+ if(other is null || other == Nest.Empty)
+ {
+ return this;
+ }
+ if(this == Nest.Empty)
+ {
+ return other;
+ }
+ if(NestType == NestType.Node && other.NestType == NestType.Node)
+ {
+ return new Nest(new Nest[] { this, other });
+ }
+ else if(NestType == NestType.List && other.NestType == NestType.List)
+ {
+ return new Nest(this.ListValue!.Concat(other.ListValue!));
+ }
+ else if(NestType == NestType.Dictionary && other.NestType == NestType.Dictionary)
+ {
+ return new Nest(this.DictValue!.Concat(other.DictValue!).ToDictionary(x => x.Key, x => x.Value));
+ }
+ else
+ {
+ return new Nest(new Nest[] { this, other });
+ }
+ }
+
+ ///
+ /// To see if the nested object is really nested. Despite being called `Nest`, sometimes it's actually not
+ /// nested. For example, [1, 2, 3] is not nested, while [1, [2, 3]] is nested.
+ ///
+ ///
+ public bool IsNested()
+ {
+ if(NestType is NestType.Empty or NestType.Node)
+ {
+ return false;
+ }
+ else if(NestType is NestType.List)
+ {
+ return ListValue!.Count > 0;
+ }
+ else
+ {
+ return DictValue!.Count > 0;
+ }
+ }
+
+ [Obsolete("The indexer of Tensors is not encouraged because it leads to unclear meanings.")]
+ public T this[int index]
+ {
+ get
+ {
+ bool success = FindInternal(this, index, out var result);
+ if (success)
+ {
+ return result;
+ }
+ else
+ {
+ throw new IndexOutOfRangeException();
+ }
+ }
+ set
+ {
+ bool success = SetInternal(this, index, value);
+ if (!success)
+ {
+ throw new IndexOutOfRangeException();
+ }
+ }
+ }
+
+ ///
+ /// If the existing nested structure if of type `Nest[INestStructure[T]]`, we can reduce it
+ /// to `Nest[T]`.
+ ///
+ ///
+ ///
+ ///
+ public static Nest ReduceFrom(INestStructure input) where TOut: INestStructure
+ {
+ var nested = input.AsNest();
+ return ReduceInternal(nested).AsNest();
+ }
+
+ private static INestStructure ReduceInternal(Nest node) where TOut : INestStructure
+ {
+ if(node.NestType == NestType.Empty)
+ {
+ return Nest.Empty;
+ }
+ else if(node.NestType == NestType.Node)
+ {
+ return node.NodeValue!.AsNest();
+ }
+ else if(node.NestType == NestType.List)
+ {
+ return new Nest(node.ListValue!.Select(x => ReduceInternal(x.AsNest())));
+ }
+ else // Dictionary type
+ {
+ return new Nest(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value.AsNest())));
+ }
+ }
+
+ private static bool FindInternal(Nest node, int index, out T? result)
+ {
+ if (node.NestType == NestType.Node)
+ {
+ if(index == 0)
+ {
+ result = node.NodeValue!;
+ return true;
+ }
+ result = default(T);
+ return false;
+ }
+ else if (node.NestType == NestType.List)
+ {
+ foreach (var item in node.ListValue!)
+ {
+ if(index == 0)
+ {
+ return FindInternal(item.AsNest(), index, out result);
+ }
+ index--;
+ }
+ result = default(T);
+ return false;
+ }
+ else if(node.NestType == NestType.Dictionary)
+ {
+ foreach (var item in node.DictValue!.Values)
+ {
+ if (index == 0)
+ {
+ return FindInternal(item.AsNest(), index, out result);
+ }
+ index--;
+ }
+ result = default(T);
+ return false;
+ }
+ else
+ {
+ result = default(T);
+ return false;
+ }
+ }
+
+ private static bool SetInternal(Nest node, int index, T newValue)
+ {
+ if (node.NestType == NestType.Node)
+ {
+ if (index == 0)
+ {
+ node.NodeValue = newValue;
+ return true;
+ }
+ return false;
+ }
+ else if (node.NestType == NestType.List)
+ {
+ foreach (var item in node.ListValue!)
+ {
+ if (index == 0)
+ {
+ return SetInternal(item.AsNest(), index, newValue);
+ }
+ index--;
+ }
+ return false;
+ }
+ else if (node.NestType == NestType.Dictionary)
+ {
+ foreach (var item in node.DictValue!.Values)
+ {
+ if (index == 0)
+ {
+ return SetInternal(item.AsNest(), index, newValue);
+ }
+ index--;
+ }
+ return false;
+ }
+ else
+ {
+ return false;
+ }
+ }
+
+ private static IEnumerable FlattenInternal(Nest node)
+ {
+ if (node.NestType == NestType.Node)
+ {
+ yield return node.NodeValue!;
+ }
+ else if (node.NestType == NestType.List)
+ {
+ foreach (var item in node.ListValue!)
+ {
+ foreach(var val in FlattenInternal(item.AsNest()))
+ {
+ yield return val;
+ }
+ }
+ }
+ else if (node.NestType == NestType.Dictionary)
+ {
+ foreach (var item in node.DictValue!.Values)
+ {
+ foreach (var val in FlattenInternal(item.AsNest()))
+ {
+ yield return val;
+ }
+ }
+ }
+ }
+
+ private Nest MapStructureInternal(Func func)
+ {
+ if (NestType == NestType.Node)
+ {
+ return new Nest(func(NodeValue!));
+ }
+ else if (NestType == NestType.List)
+ {
+ List> outs = new List>();
+ foreach (var item in ListValue!)
+ {
+ outs.Add(item.AsNest().MapStructureInternal(func));
+ }
+ return new Nest(outs);
+ }
+ else if (NestType == NestType.Dictionary)
+ {
+ Dictionary> outs = new Dictionary>();
+ foreach (var (key, value) in DictValue!)
+ {
+ outs.Add(key, value.AsNest().MapStructureInternal(func));
+ }
+ return new Nest(outs);
+ }
+ else
+ {
+ return Nest.Empty;
+ }
+ }
+
+ public IEnumerator GetEnumerator()
+ {
+ return Flatten().GetEnumerator();
+ }
+
+ IEnumerator IEnumerable.GetEnumerator()
+ {
+ return GetEnumerator();
+ }
+
+ public override string ToString()
+ {
+ StringBuilder sb = new StringBuilder();
+ sb.Append("(");
+ WriteString(this, sb);
+ sb.Append(")");
+ return sb.ToString();
+ }
+
+ private static void WriteString(Nest node, StringBuilder sb)
+ {
+ if (!string.IsNullOrEmpty(node.Name))
+ {
+ sb.Append($"{node.Name}: ");
+ }
+ if (node.NestType == NestType.Node)
+ {
+ sb.Append(node.NodeValue!.ToString());
+ }
+ else if (node.NestType == NestType.List)
+ {
+ sb.Append("[");
+ for(int i = 0; i < node.ListValue!.Count; i++)
+ {
+ WriteString(node.ListValue![i].AsNest(), sb);
+ if(i != node.ListValue!.Count - 1)
+ {
+ sb.Append(", ");
+ }
+ }
+ sb.Append("]");
+ }
+ else if (node.NestType == NestType.Dictionary)
+ {
+ sb.Append("{");
+ int count = node.DictValue!.Count;
+ int i = 0;
+ foreach (var (key, value) in node.DictValue!)
+ {
+ sb.Append($"{key}: ");
+ WriteString(value.AsNest(), sb);
+ if (i != count - 1)
+ {
+ sb.Append(", ");
+ }
+ i++;
+ }
+ sb.Append("}");
+ }
+ else
+ {
+ sb.Append("");
+ }
+ }
+
+ public static implicit operator Nest((INestStructure, INestStructure) inputs)
+ {
+ return new Nest(new INestStructure[] { inputs.Item1, inputs.Item2 });
+ }
+
+ public static implicit operator Nest((INestStructure, INestStructure, INestStructure) inputs)
+ {
+ return new Nest(new INestStructure[] { inputs.Item1, inputs.Item2, inputs.Item3 });
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Common/Types/NestDictionary.cs b/src/TensorFlowNET.Core/Common/Types/NestDictionary.cs
new file mode 100644
index 000000000..cf1994554
--- /dev/null
+++ b/src/TensorFlowNET.Core/Common/Types/NestDictionary.cs
@@ -0,0 +1,103 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Common.Types
+{
+ public class NestDictionary : INestStructure, IDictionary where TKey : notnull
+ {
+ public NestType NestType => NestType.Dictionary;
+ public IDictionary Value { get; set; }
+ public int ShallowNestedCount => Values.Count;
+
+ public int TotalNestedCount => Values.Count;
+ public NestDictionary(IDictionary dict)
+ {
+ Value = dict;
+ }
+ public IEnumerable Flatten()
+ {
+ return Value.Select(x => x.Value);
+ }
+ public INestStructure MapStructure(Func func)
+ {
+ return new NestList(Value.Select(x => func(x.Value)));
+ }
+
+ public Nest AsNest()
+ {
+ return new Nest(Value.Values.Select(x => new Nest(x)));
+ }
+
+ // Required IDictionary members
+ public int Count => Value.Count;
+
+ public bool IsReadOnly => Value.IsReadOnly;
+
+ public ICollection Keys => Value.Keys;
+
+ public ICollection Values => Value.Values;
+
+ public void Add(TKey key, TValue value)
+ {
+ Value.Add(key, value);
+ }
+
+ public void Add(KeyValuePair item)
+ {
+ Value.Add(item);
+ }
+
+ public void Clear()
+ {
+ Value.Clear();
+ }
+
+ public bool Contains(KeyValuePair item)
+ {
+ return Value.Contains(item);
+ }
+
+ public bool ContainsKey(TKey key)
+ {
+ return Value.ContainsKey(key);
+ }
+
+ public void CopyTo(KeyValuePair[] array, int arrayIndex)
+ {
+ Value.CopyTo(array, arrayIndex);
+ }
+
+ public IEnumerator> GetEnumerator()
+ {
+ return Value.GetEnumerator();
+ }
+
+ IEnumerator IEnumerable.GetEnumerator()
+ {
+ return GetEnumerator();
+ }
+
+ public bool Remove(TKey key)
+ {
+ return Value.Remove(key);
+ }
+
+ public bool Remove(KeyValuePair item)
+ {
+ return Value.Remove(item);
+ }
+
+ public bool TryGetValue(TKey key, out TValue value)
+ {
+ return Value.TryGetValue(key, out value);
+ }
+
+ // Optional IDictionary members
+ public TValue this[TKey key]
+ {
+ get => Value[key];
+ set => Value[key] = value;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Common/Types/NestList.cs b/src/TensorFlowNET.Core/Common/Types/NestList.cs
new file mode 100644
index 000000000..1e0d272b7
--- /dev/null
+++ b/src/TensorFlowNET.Core/Common/Types/NestList.cs
@@ -0,0 +1,53 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Common.Types
+{
+ ///
+ /// The implementation of a list that support nest structure, in which the depth is 1.
+ ///
+ ///
+ public sealed class NestList : INestStructure, IEnumerable
+ {
+ public NestType NestType => NestType.List;
+ public List Values { get; set; }
+ public int ShallowNestedCount => Values.Count;
+
+ public int TotalNestedCount => Values.Count;
+
+ public NestList(params T[] values)
+ {
+ Values = new List(values);
+ }
+
+ public NestList(IEnumerable values)
+ {
+ Values = new List(values);
+ }
+ public IEnumerable Flatten()
+ {
+ return Values;
+ }
+ public INestStructure MapStructure(Func func)
+ {
+ return new NestList(Values.Select(x => func(x)));
+ }
+
+ public Nest AsNest()
+ {
+ return new Nest(Values.Select(x => new Nest(x)));
+ }
+
+ // Enumerator implementation
+ public IEnumerator GetEnumerator()
+ {
+ return Values.GetEnumerator();
+ }
+
+ IEnumerator IEnumerable.GetEnumerator()
+ {
+ return GetEnumerator();
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Common/Types/NestNode.cs b/src/TensorFlowNET.Core/Common/Types/NestNode.cs
new file mode 100644
index 000000000..701aade9a
--- /dev/null
+++ b/src/TensorFlowNET.Core/Common/Types/NestNode.cs
@@ -0,0 +1,36 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Common.Types
+{
+ ///
+ /// A nested structure with only one element.
+ ///
+ ///
+ public class NestNode : INestStructure
+ {
+ public NestType NestType => NestType.Node;
+ public T Value { get; set; }
+ public int ShallowNestedCount => 1;
+
+ public int TotalNestedCount => 1;
+ public NestNode(T value)
+ {
+ Value = value;
+ }
+ public IEnumerable Flatten()
+ {
+ yield return Value;
+ }
+ public INestStructure MapStructure(Func func)
+ {
+ return new NestNode(func(Value));
+ }
+
+ public Nest AsNest()
+ {
+ return new Nest(Value);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs b/src/TensorFlowNET.Core/Common/Types/TensorShapeConfig.cs
similarity index 95%
rename from src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs
rename to src/TensorFlowNET.Core/Common/Types/TensorShapeConfig.cs
index 7abcfde26..a36930eca 100644
--- a/src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs
+++ b/src/TensorFlowNET.Core/Common/Types/TensorShapeConfig.cs
@@ -3,7 +3,7 @@
using System.Collections.Generic;
using System.Linq;
-namespace Tensorflow.Keras.Saving
+namespace Tensorflow.Common.Types
{
public class TensorShapeConfig
{
diff --git a/src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs b/src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs
index ac1cd8660..f6e0911ca 100644
--- a/src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs
+++ b/src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs
@@ -49,7 +49,7 @@ Tensors ExecGraphAction(string OpType, string Name, ExecuteOpArgs args)
Tensors ExecEagerAction(string OpType, string Name, ExecuteOpArgs args)
{
- var opExecInfo = new FastPathOpExecInfo(OpType, Name, args.OpInputArgs)
+ var opExecInfo = new FastPathOpExecInfo(tf.Context, OpType, Name, args.OpInputArgs)
{
attrs = args.OpAttrs
};
diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs
index 324d7e834..c1762d670 100644
--- a/src/TensorFlowNET.Core/Data/DatasetV2.cs
+++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs
@@ -161,8 +161,8 @@ public override string ToString()
break;
}
- yield return (new Tensors(results.Take(FirstInputTensorCount)), results.Length == FirstInputTensorCount ?
- null : new Tensors(results.Skip(FirstInputTensorCount)));
+ yield return (new Tensors(results.Take(FirstInputTensorCount).ToArray()), results.Length == FirstInputTensorCount ?
+ null : new Tensors(results.Skip(FirstInputTensorCount).ToArray()));
}
}
diff --git a/src/TensorFlowNET.Core/Device/DeviceSpec.cs b/src/TensorFlowNET.Core/Device/DeviceSpec.cs
index f4ea8cf05..255191cb5 100644
--- a/src/TensorFlowNET.Core/Device/DeviceSpec.cs
+++ b/src/TensorFlowNET.Core/Device/DeviceSpec.cs
@@ -1,4 +1,5 @@
using System;
+using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Text;
using System.Threading.Tasks;
@@ -7,8 +8,8 @@ namespace Tensorflow.Device
{
public class DeviceSpec
{
- private static Dictionary _STRING_TO_COMPONENTS_CACHE = new();
- private static Dictionary _COMPONENTS_TO_STRING_CACHE = new();
+ private static ConcurrentDictionary _STRING_TO_COMPONENTS_CACHE = new();
+ private static ConcurrentDictionary _COMPONENTS_TO_STRING_CACHE = new();
private string _job;
private int _replica;
private int _task;
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
index 59d5fd030..2bdd65f5b 100644
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
@@ -80,6 +80,11 @@ BackwardFunction GetGradientFunction(string op_name,
Tensor[] op_outputs)
=> (out_grads, unneeded_gradients) =>
{
+ if(!ops.gradientFunctions.ContainsKey(op_name))
+ {
+ throw new Exception($"gradientFunctions not find op_name: {op_name}");
+ }
+
if (ops.gradientFunctions[op_name] == null)
return new Tensor[op_inputs.Length];
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
index fedc02cb9..0ce55841b 100644
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
@@ -68,7 +68,8 @@ public Tensor[] TFE_FastPathExecute(FastPathOpExecInfo op_exec_info)
var input_arg = op_def.InputArg[i];
if (!string.IsNullOrEmpty(input_arg.NumberAttr))
{
- int len = (input as object[]).Length;
+ var fast_input_array = input is Tensors tensors ? (object[])tensors : (object[])input;
+ int len = fast_input_array.Length;
c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, len);
if (op_exec_info.run_callbacks)
{
@@ -79,7 +80,6 @@ public Tensor[] TFE_FastPathExecute(FastPathOpExecInfo op_exec_info)
if (len > 0)
{
- var fast_input_array = (object[])op_exec_info.args[i];
// First item adds the type attr.
if (!AddInputToOp(fast_input_array[i], true, input_arg, flattened_attrs, flattened_inputs, op, status))
return null;
@@ -352,13 +352,19 @@ bool SetOpAttrScalar(Context ctx, SafeEagerOpHandle op,
c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value));
break;
case TF_AttrType.TF_ATTR_SHAPE:
- var dims = (value as long[]).ToArray();
+ long[] dims;
+ if (value is Shape shape) dims = shape.dims.ToArray();
+ else if (value is long[] longs) dims = longs.ToArray();
+ else if (value is int[] ints) dims = ints.Select(x => (long)x).ToArray();
+ else dims = ((long[])value).ToArray();
c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status);
status.Check(true);
break;
case TF_AttrType.TF_ATTR_FUNC:
if (value is ConcreteFunction func)
c_api.TFE_OpSetAttrFunctionName(op, key, func.func_graph.FuncName, func.func_graph.FuncName.Length);
+ else if(value is string str)
+ c_api.TFE_OpSetAttrFunctionName(op, key, str, str.Length);
else
throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC");
break;
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs
index 1f7b3ae64..3515fed83 100644
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs
@@ -65,7 +65,7 @@ public Tensor[] TFE_TapeGradient(ITape tape,
{
outgrad_vec = output_gradients.ToList();
}
- var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false);
+ var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, true);
bool unconnected_gradients_zero = unconnected_gradients == "zero";
@@ -137,7 +137,6 @@ TapeTensor TapeTensorFromTensor(Tensor tensor)
{
dims[i] = c_api.TFE_TensorHandleDim(handle, i, status);
}
- Shape tensor_shape = new(dims);
if(status.Code != TF_Code.TF_OK)
{
@@ -145,6 +144,7 @@ TapeTensor TapeTensorFromTensor(Tensor tensor)
}
else
{
+ Shape tensor_shape = new(dims);
return new TapeTensor(id, dtype, tensor_shape);
}
}
@@ -173,8 +173,12 @@ bool DTypeNeedsHandleData(TF_DataType dtype)
return dtype == dtypes.variant || dtype == dtypes.resource;
}
- bool ListContainNone(long[] list)
+ bool ListContainNone(long[]? list)
{
+ if(list is null)
+ {
+ return true;
+ }
int len = list.Length;
if(len == 0)
{
diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.ToString.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.ToString.cs
index ce3c983b5..71b3075aa 100644
--- a/src/TensorFlowNET.Core/Eager/EagerTensor.ToString.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerTensor.ToString.cs
@@ -10,6 +10,11 @@ public override string ToString()
var str = NDArrayRender.ToString(nd);
return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}";
}
-
+ public string ToString(int maxLength)
+ {
+ var nd = new NDArray(this);
+ var str = NDArrayRender.ToString(nd, maxLength);
+ return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}";
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Eager/FastPathOpExecInfo.cs b/src/TensorFlowNET.Core/Eager/FastPathOpExecInfo.cs
index 2cdf025a1..307ca2ce4 100644
--- a/src/TensorFlowNET.Core/Eager/FastPathOpExecInfo.cs
+++ b/src/TensorFlowNET.Core/Eager/FastPathOpExecInfo.cs
@@ -17,8 +17,9 @@ public class FastPathOpExecInfo
public bool run_callbacks { get; set; }
public Action callbacks { get; set; }
- public FastPathOpExecInfo(string opName, string name, params object[] inputArgs)
+ public FastPathOpExecInfo(Context ctx, string opName, string name, params object[] inputArgs)
{
+ this.ctx = ctx;
this.op_name = opName;
this.name = name;
this.args = inputArgs;
diff --git a/src/TensorFlowNET.Core/Eager/GraphOnlyOps.cs b/src/TensorFlowNET.Core/Eager/GraphOnlyOps.cs
new file mode 100644
index 000000000..2c20cfe9b
--- /dev/null
+++ b/src/TensorFlowNET.Core/Eager/GraphOnlyOps.cs
@@ -0,0 +1,25 @@
+using Tensorflow;
+
+internal static class GraphOnlyOps
+{
+ ///
+ /// Graph-only version of tf.compat.v1.placeholder(), for internal use only.
+ ///
+ ///
+ ///
+ ///
+ ///
+ internal static Tensor graph_placeholder(TF_DataType dtype, Shape shape, string? name = null)
+ {
+ var dtype_value = new AttrValue() { Type = dtype.as_datatype_enum() };
+ var shape_value = new AttrValue() { Shape = shape.as_proto() };
+ var g = ops.get_default_graph();
+ Dictionary attrs = new();
+ attrs["dtype"] = dtype_value;
+ attrs["shape"] = shape_value;
+ var op = g.create_op("Placeholder", new Tensor[0], new TF_DataType[] { dtype },
+ new TF_DataType[0], attrs: attrs, name: name);
+ var result = op.outputs[0];
+ return result;
+ }
+}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Core/Eager/execute.cs b/src/TensorFlowNET.Core/Eager/execute.cs
index 1804992ac..e981c6c51 100644
--- a/src/TensorFlowNET.Core/Eager/execute.cs
+++ b/src/TensorFlowNET.Core/Eager/execute.cs
@@ -7,10 +7,11 @@
using static Tensorflow.ApiDef.Types;
using static Tensorflow.CostGraphDef.Types;
using static Tensorflow.Binding;
+using Tensorflow.Gradients;
namespace Tensorflow.Eager
{
- internal static class execute
+ internal static class _execute
{
public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx)
{
@@ -18,7 +19,7 @@ public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] valu
var types = v.Select(t => t.dtype.as_datatype_enum());
return (types.ToArray(), v.ToArray());
}
- public static Tensor[] executes(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null)
+ public static Tensor[] execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null)
{
return quick_execute(op_name, num_outputs, inputs, attrs, ctx, name);
}
@@ -33,7 +34,12 @@ public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] i
}
public static bool must_record_gradient()
{
- return false;
+ return tf.GetTapeSet().Count != 0;
+ }
+
+ public static bool record_gradient(string op_name, Tensor[] inputs, object[] attrs, Tensor[] results)
+ {
+ return tf.Runner.RecordGradient(op_name, inputs, attrs, results);
}
}
}
diff --git a/src/TensorFlowNET.Core/Exceptions/NotOkStatusException.cs b/src/TensorFlowNET.Core/Exceptions/NotOkStatusException.cs
new file mode 100644
index 000000000..c283c1a45
--- /dev/null
+++ b/src/TensorFlowNET.Core/Exceptions/NotOkStatusException.cs
@@ -0,0 +1,19 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Exceptions
+{
+ public class NotOkStatusException : TensorflowException
+ {
+ public NotOkStatusException() : base()
+ {
+
+ }
+
+ public NotOkStatusException(string message) : base(message)
+ {
+
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Framework/IndexedSlices.cs b/src/TensorFlowNET.Core/Framework/IndexedSlices.cs
index 24d356fbb..bac5e6fb1 100644
--- a/src/TensorFlowNET.Core/Framework/IndexedSlices.cs
+++ b/src/TensorFlowNET.Core/Framework/IndexedSlices.cs
@@ -49,12 +49,25 @@ public IndexedSlices(Tensor values, Tensor indices, Tensor dense_shape = null)
public static implicit operator Tensor(IndexedSlices indexedSlices)
{
- return indexedSlices.values;
+ return _indexed_slices_to_tensor(indexedSlices);
}
public static implicit operator IndexedSlices(Tensor tensor)
{
return tensor.Tag as IndexedSlices;
}
+
+ ///
+ /// Converts an IndexedSlices object `value` to a Tensor.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static Tensor _indexed_slices_to_tensor(IndexedSlices indexedSlices, TF_DataType dtype = TF_DataType.DtInvalid, String name = "", bool as_ref = false)
+ {
+ return gen_math_ops.unsorted_segment_sum(indexedSlices.values, indexedSlices.indices, indexedSlices.dense_shape.slice(0));
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs b/src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs
index 083d4813a..ac099ae2b 100644
--- a/src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs
+++ b/src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs
@@ -1,4 +1,5 @@
using System.Linq;
+using Tensorflow.Eager;
namespace Tensorflow.Framework.Models
{
@@ -24,5 +25,17 @@ public TensorSpec _batch(int dim = -1)
shapes.Insert(0, dim);
return new TensorSpec(shapes.ToArray(), _dtype);
}
+
+ public static TensorSpec FromTensor(Tensor tensor, string? name = null)
+ {
+ if(tensor is EagerTensor)
+ {
+ return new TensorSpec(tensor.shape, tensor.dtype, name);
+ }
+ else
+ {
+ return new TensorSpec(tensor.shape, tensor.dtype, name ?? tensor.name);
+ }
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Framework/auto_control_deps_utils.cs b/src/TensorFlowNET.Core/Framework/auto_control_deps_utils.cs
new file mode 100644
index 000000000..28d9e5008
--- /dev/null
+++ b/src/TensorFlowNET.Core/Framework/auto_control_deps_utils.cs
@@ -0,0 +1,89 @@
+using Tensorflow.Graphs;
+
+namespace Tensorflow.Framework
+{
+ internal static class auto_control_deps_utils
+ {
+ public static readonly string READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs";
+ public static List get_read_only_resource_input_indices_graph(FuncGraph func_graph)
+ {
+ List result = new List();
+ // A cache to store the read only resource inputs of an Op.
+ // Operation -> ObjectIdentitySet of resource handles.
+ Dictionary> opReadOnlyResourceInputs =
+ new Dictionary>();
+
+ for (int inputIndex = 0; inputIndex < func_graph.Inputs.Length; inputIndex++)
+ {
+ Tensor t = func_graph.Inputs[inputIndex];
+ if (t.dtype != dtypes.resource)
+ continue;
+
+ bool readOnly = true;
+ foreach (var op in t.consumers())
+ {
+ if (opReadOnlyResourceInputs.ContainsKey(op))
+ {
+ if (!opReadOnlyResourceInputs[op].Contains(t))
+ {
+ readOnly = false;
+ break;
+ }
+ }
+ else
+ {
+ List indices = _get_read_only_resource_input_indices_op(op);
+ opReadOnlyResourceInputs[op] = new HashSet(
+ indices.Select(i => op.inputs[i]));
+ if (!opReadOnlyResourceInputs[op].Contains(t))
+ {
+ readOnly = false;
+ break;
+ }
+ }
+ }
+
+ if (readOnly)
+ result.Add(inputIndex);
+ }
+
+ return result;
+ }
+
+ private static List _get_read_only_resource_input_indices_op(Operation op)
+ {
+ // ignore the RESOURCE_READ_OPS
+
+ int[] read_only_input_indices;
+
+ try
+ {
+ read_only_input_indices = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR);
+ }
+ catch (InvalidArgumentError)
+ {
+ return new List();
+ }
+
+ int read_only_index = 0;
+ List result = new();
+ for (int i = 0; i < op.inputs.Length; i++)
+ {
+ if (read_only_index >= read_only_input_indices.Length)
+ {
+ break;
+ }
+ if (op.inputs[i].dtype != dtypes.resource)
+ {
+ continue;
+ }
+ if (read_only_index < read_only_input_indices.Length && i == read_only_input_indices[read_only_index])
+ {
+ result.Add(i);
+ read_only_index++;
+ }
+ }
+ return result;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Framework/function_def_lib.cs b/src/TensorFlowNET.Core/Framework/function_def_lib.cs
index 67f8d324e..488c6b654 100644
--- a/src/TensorFlowNET.Core/Framework/function_def_lib.cs
+++ b/src/TensorFlowNET.Core/Framework/function_def_lib.cs
@@ -42,10 +42,10 @@ public static FuncGraph function_def_to_graph(FunctionDef fdef, object? structur
func_graph.as_default();
importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false);
var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]);
- func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x)));
+ func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray());
var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]);
- func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x)));
+ func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray());
// TODO(Rinne): func_graph.ControlOutputs
_set_handle_data(func_graph, fdef);
diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
index 88dce7d98..8742e4535 100644
--- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
+++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
@@ -8,6 +8,7 @@
using Tensorflow.Graphs;
using Tensorflow.Train;
using Tensorflow.Util;
+using Tensorflow.Common.Extensions;
using static Tensorflow.Binding;
namespace Tensorflow.Functions
@@ -40,6 +41,18 @@ public class ConcreteFunction: Trackable
public Tensor[] FlatStructuredOutputs => func_graph.FlatStructuredOutputs;
public IEnumerable Variables => func_graph.Variables;
public IEnumerable