8000 Merge pull request #1116 from lingbai-kong/imdbfix · SciSharp/TensorFlow.NET@3de7b8e · GitHub
[go: up one dir, main page]

Skip to content

Commit 3de7b8e

Browse files
authored
Merge pull request #1116 from lingbai-kong/imdbfix
fix: type converting errors when loading imdb dataset
2 parents 8630438 + c23b246 commit 3de7b8e

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,15 @@ public unsafe static implicit operator double(NDArray nd)
107107
public static implicit operator NDArray(bool value)
108108
=> new NDArray(value);
109109

110+
public static implicit operator NDArray(byte value)
111+
=> new NDArray(value);
112+
110113
public static implicit operator NDArray(int value)
111114
=> new NDArray(value);
112115

116+
public static implicit operator NDArray(long value)
117+
=> new NDArray(value);
118+
113119
public static implicit operator NDArray(float value)
114120
=> new NDArray(value);
115121

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,13 @@ public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT
8484
// var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape);
8585
Tensor zeros = dtype switch
8686
{
87+
TF_DataType.TF_BOOL => constant(false),
8788
TF_DataType.TF_DOUBLE => constant(0d),
8889
TF_DataType.TF_FLOAT => constant(0f),
90+
TF_DataType.TF_INT64 => constant(0L),
91+
TF_DataType.TF_UINT64 => constant((ulong)0),
92+
TF_DataType.TF_INT32 => constant(0),
93+
TF_DataType.TF_UINT32 => constant((uint)0),
8994
TF_DataType.TF_INT8 => constant((sbyte)0),
9095
TF_DataType.TF_UINT8 => constant((byte)0),
9196
_ => constant(0)
@@ -108,9 +113,15 @@ public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT
108113
return _constant_if_small(0.0F, shape, dtype, name);
109114
case TF_DataType.TF_INT64:
110115
return _constant_if_small(0L, shape, dtype, name);
116+
case TF_DataType.TF_UINT64:
117+
return _constant_if_small<ulong>(0, shape, dtype, name);
111118
case TF_DataType.TF_INT32:
112119
return _constant_if_small(0, shape, dtype, name);
120+
case TF_DataType.TF_UINT32:
121+
return _constant_if_small<uint>(0, shape, dtype, name);
113122
case TF_DataType.TF_INT8:
123+
return _constant_if_small<sbyte>(0, shape, dtype, name);
124+
case TF_DataType.TF_UINT8:
114125
return _constant_if_small<byte>(0, shape, dtype, name);
115126
default:
116127
throw new TypeError("can't find type for zeros");

0 commit comments

Comments
 (0)
0