8000 Merge pull request #1172 from novikov-alexander/alnovi/cached_session · SciSharp/TensorFlow.NET@9ddff69 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9ddff69

Browse files
authored
Merge pull request #1172 from novikov-alexander/alnovi/cached_session
cached_session for graph tests
2 parents a91e358 + 9d71cad commit 9ddff69

File tree

3 files changed

+156
-16
lines changed

3 files changed

+156
-16
lines changed

test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/WhileContextTestCase.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
22
using System;
3+
using System.Linq;
34
using Tensorflow;
45
using static Tensorflow.Binding;
56

@@ -23,7 +24,7 @@ public void SimpleWhileLoop()
2324
private void _testWhileContextHelper(int maximum_iterations)
2425
{
2526
// TODO: implement missing code dependencies
26-
var sess = this.cached_session();
27+
using var sess = this.cached_session();
2728
var i = constant_op.constant(0, name: "i");
2829
var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, ops.convert_to_tensor(10), name: "c"));
2930
var b = new Func<Tensor, Tensor>(x => math_ops.add(x, 1, name: "c"));

test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -388,22 +388,21 @@ public void testBoundaryStop()
388388

389389
}
390390

391-
[Ignore("TODO")]
392391
[TestMethod]
393392
public void testBoundaryContinue()
394393
{
395-
//@test_util.run_v1_only("b/120545219")
396-
//def testBoundaryContinue(self):
397-
// # Test that we differentiate both 'x' and 'y' correctly when x is a
398-
// # predecessor of y.
399-
// with self.cached_session():
400-
// x = constant(1.0)
401-
// y = x * 2.0
402-
// z = y * 3.0
403-
// grads = gradients.gradients(z, [x, y])
404-
// self.assertTrue(all(x is not None for x in grads))
405-
// self.assertEqual(6.0, grads[0].eval())
394+
// Test that we differentiate both 'x' and 'y' correctly when x is a
395+
// predecessor of y.
406396

397+
using (self.cached_session())
398+
{
399+
var x = tf.constant(1.0);
400+
var y = x * 2.0;
401+
var z = y * 3.0;
402+
var grads = tf.gradients(z, new[] { x, y });
403+
self.assertTrue(all(grads.Select(x => x != null)));
404+
self.assertEqual(6.0, grads[0].eval());
405+
}
407406
}
408407

409408
[Ignore("TODO")]

test/TensorFlowNET.Graph.UnitTest/PythonTest.cs

Lines changed: 143 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
using System.Linq;
77
using Tensorflow;
88
using static Tensorflow.Binding;
9+
using OneOf.Types;
10+
using System.Collections.Generic;
911

1012
namespace TensorFlowNET.UnitTest
1113
{
@@ -139,6 +141,21 @@ public void assertProtoEquals(object toProto, object o)
139141

140142
#region tensor evaluation and test session
141143

144+
private Session _cached_session = null;
145+
private Graph _cached_graph = null;
146+
private object _cached_config = null;
147+
private bool _cached_force_gpu = false;
148+
149+
private void _ClearCachedSession()
150+
{
151+
if (self._cached_session != null)
152+
{
153+
self._cached_session.Dispose();
154+
self._cached_session = null;
155+
}
156+
}
157+
158+
142159
//protected object _eval_helper(Tensor[] tensors)
143160
//{
144161
// if (tensors == null)
@@ -203,10 +220,56 @@ public T evaluate<T>(Tensor tensor)
203220
}
204221
}
205222

206-
207-
public Session cached_session()
223+
///Returns a TensorFlow Session for use in executing tests.
224+
public Session cached_session(
225+
Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false)
208226
{
209-
throw new NotImplementedException();
227+
// This method behaves differently than self.session(): for performance reasons
228+
// `cached_session` will by default reuse the same session within the same
229+
// test.The session returned by this function will only be closed at the end
230+
// of the test(in the TearDown function).
231+
232+
// Use the `use_gpu` and `force_gpu` options to control where ops are run.If
233+
// `force_gpu` is True, all ops are pinned to `/ device:GPU:0`. Otherwise, if
234+
// `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
235+
// possible.If both `force_gpu and `use_gpu` are False, all ops are pinned to
236+
// the CPU.
237+
238+
// Example:
239+
// python
240+
// class MyOperatorTest(test_util.TensorFlowTestCase) :
241+
// def testMyOperator(self):
242+
// with self.cached_session() as sess:
243+
// valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
244+
// result = MyOperator(valid_input).eval()
245+
// self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
246+
// invalid_input = [-1.0, 2.0, 7.0]
247+
// with self.assertRaisesOpError("negative input not supported"):
248+
// MyOperator(invalid_input).eval()
249+
250+
251+
// Args:
252+
// graph: Optional graph to use during the returned session.
253+
// config: An optional config_pb2.ConfigProto to use to configure the
254+
// session.
255+
// use_gpu: If True, attempt to run as many ops as possible on GPU.
256+
// force_gpu: If True, pin all ops to `/device:GPU:0`.
257+
258+
// Yields:
259+
// A Session object that should be used as a context manager to surround
260+
// the graph building and execution code in a test case.
261+
262+
263+
// TODO:
264+
// if context.executing_eagerly():
265+
// return self._eval_helper(tensors)
266+
// else:
267+
{
268+
var sess = self._get_cached_session(
269+
graph, config, force_gpu, crash_if_inconsistent_args: true);
270+
using var cached = self._constrain_devices_and_set_default(sess, use_gpu, force_gpu);
271+
return cached;
272+
}
210273
}
211274

212275
//Returns a TensorFlow Session for use in executing tests.
@@ -254,6 +317,39 @@ public Session session(Graph graph = null, object config = null, bool use_gpu =
254317
return s.as_default();
255318
}
256319

320+
private Session _constrain_devices_and_set_default(Session sess, bool use_gpu, bool force_gpu)
321+
{
322+
// Set the session and its graph to global default and constrain devices."""
323+
if (tf.executing_eagerly())
324+
return null;
325+
else {
326+
sess.graph.as_default();
327+
sess.as_default();
328+
{
329+
if (force_gpu)
330+
{
331+
// TODO:
332+
333+
// Use the name of an actual device if one is detected, or
334+
// '/device:GPU:0' otherwise
335+
/* var gpu_name = gpu_device_name();
336+
if (!gpu_name)
337+
gpu_name = "/device:GPU:0"
338+
using (sess.graph.device(gpu_name)) {
339+
yield return sess;
340+
}*/
341+
return sess;
342+
}
343+
else if (use_gpu)
344+
return sess;
345+
else
346+
using (sess.graph.device("/device:CPU:0"))
347+
return sess;
348+
}
349+
350+
}
351+
}
352+
257353
// See session() for details.
258354
private Session _create_session(Graph graph, object cfg, bool forceGpu)
259355
{
@@ -298,6 +394,50 @@ private Session _create_session(Graph graph, object cfg, bool forceGpu)
298394
return new Session(graph);//, config = prepare_config(config))
299395
}
300396

397+
private Session _get_cached_session(
398+
Graph graph = null,
399+
object config = null,
400+
bool force_gpu = false,
401+
bool crash_if_inconsistent_args = true)
402+
{
403+
// See cached_session() for documentation.
404+
if (self._cached_session == null)
405+
{
406+
var sess = self._create_session(graph, config, force_gpu);
407+
self._cached_session = sess;
408+
self._cached_graph = graph;
409+
self._cached_config = config;
410+
self._cached_force_gpu = force_gpu;
411+
return sess;
412+
} else {
413+
414+
if (crash_if_inconsistent_args && !self._cached_graph.Equals(graph))
415+
throw new ValueError(@"The graph used to get the cached session is
416+
different than the one that was used to create the
417+
session. Maybe create a new session with
418+
self.session()");
419+
if (crash_if_inconsistent_args && !self._cached_config.Equals(config)) {
420+
throw new ValueError(@"The config used to get the cached session is
421+
different than the one that was used to create the
422+
session. Maybe create a new session with
423+
self.session()");
424+
}
425+
if (crash_if_inconsistent_args && !self._cached_force_gpu.Equals(force_gpu)) {
426+
throw new ValueError(@"The force_gpu value used to get the cached session is
427+
different than the one that was used to create the
428+
session. Maybe create a new session with
429+
self.session()");
430+
}
431+
return _cached_session;
432+
}
433+
}
434+
435+
[TestCleanup]
436+
public void Cleanup()
437+
{
438+
_ClearCachedSession();
439+
}
440+
301441
#endregion
302442

303443
public void AssetSequenceEqual<T>(T[] a, T[] b)

0 commit comments

Comments
 (0)
0