@@ -84,8 +84,13 @@ public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT
84
84
// var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape);
85
85
Tensor zeros = dtype switch
86
86
{
87
+ TF_DataType . TF_BOOL => constant ( false ) ,
87
88
TF_DataType . TF_DOUBLE => constant ( 0d ) ,
88
89
TF_DataType . TF_FLOAT => constant ( 0f ) ,
90
+ TF_DataType . TF_INT64 => constant ( 0L ) ,
91
+ TF_DataType . TF_UINT64 => constant ( ( ulong ) 0 ) ,
92
+ TF_DataType . TF_INT32 => constant ( 0 ) ,
93
+ TF_DataType . TF_UINT32 => constant ( ( uint ) 0 ) ,
89
94
TF_DataType . TF_INT8 => constant ( ( sbyte ) 0 ) ,
90
95
TF_DataType . TF_UINT8 => constant ( ( byte ) 0 ) ,
91
96
_ => constant ( 0 )
@@ -108,9 +113,15 @@ public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT
108
113
return _constant_if_small ( 0.0F , shape , dtype , name ) ;
109
114
case TF_DataType . TF_INT64 :
110
115
return _constant_if_small ( 0L , shape , dtype , name ) ;
116
+ case TF_DataType . TF_UINT64 :
117
+ return _constant_if_small < ulong > ( 0 , shape , dtype , name ) ;
111
118
case TF_DataType . TF_INT32 :
112
119
return _constant_if_small ( 0 , shape , dtype , name ) ;
120
+ case TF_DataType . TF_UINT32 :
121
+ return _constant_if_small < uint > ( 0 , shape , dtype , name ) ;
113
122
case TF_DataType . TF_INT8 :
123
+ return _constant_if_small < sbyte > ( 0 , shape , dtype , name ) ;
124
+ case TF_DataType . TF_UINT8 :
114
125
return _constant_if_small < byte > ( 0 , shape , dtype , name ) ;
115
126
default :
116
127
throw new TypeError ( "can't find type for zeros" ) ;
0 commit comments