@@ -23,6 +23,11 @@ public class Array implements AutoCloseable {
23
23
private native static long createArrayFromInt (int [] dims , int [] elems );
24
24
private native static long createArrayFromBoolean (int [] dims , boolean [] elems );
25
25
26
+ private native static long createRanduArray (int [] dims , int type );
27
+ private native static long createRandnArray (int [] dims , int type );
28
+ private native static long createConstantsArray (double val , int [] dims , int type );
29
+
30
+
26
31
private native static void destroyArray (long ref );
27
32
private native static int [] getDims (long ref );
28
33
private native static int getType (long ref );
@@ -66,9 +71,13 @@ public class Array implements AutoCloseable {
66
71
private native static long sqrt (long a );
67
72
68
73
// Scalar return operations
69
- private native static float sum (long a );
70
- private native static float max (long a );
71
- private native static float min (long a );
74
+ private native static double sumAll (long a );
75
+ private native static double maxAll (long a );
76
+ private native static double minAll (long a );
77
+
78
+ private native static long sum (long a , int dim );
79
+ private native static long max (long a , int dim );
80
+ private native static long min (long a , int dim );
72
81
73
82
// Scalar operations
74
83
private native static long addf (long a , float b );
@@ -115,7 +124,7 @@ public String typeName(int ty) throws Exception {
115
124
throw new Exception ("Unknown type" );
116
125
}
117
126
118
- private int [] dim4 (int [] dims ) throws Exception {
127
+ private static int [] dim4 (int [] dims ) throws Exception {
119
128
120
129
if ( dims == null ) {
121
130
throw new Exception ("Null dimensions object provided" );
@@ -280,6 +289,37 @@ public boolean[] getBooleanArray() throws Exception {
280
289
}
281
290
282
291
// Binary operations
292
+
293
+ public static Array randu (int [] dims , int type ) throws Exception {
294
+ int [] adims = dim4 (dims );
295
+ long ref = createRanduArray (adims , type );
296
+ if (ref == 0 ) throw new Exception ("Failed to create Array" );
297
+
298
+ Array ret_val = new Array ();
299
+ ret_val .ref = ref ;
300
+ return ret_val ;
301
+ }
302
+
303
+ public static Array randn (int [] dims , int type ) throws Exception {
304
+ int [] adims = dim4 (dims );
305
+ long ref = createRandnArray (adims , type );
306
+ if (ref == 0 ) throw new Exception ("Failed to create Array" );
307
+
308
+ Array ret_val = new Array ();
309
+ ret_val .ref = ref ;
310
+ return ret_val ;
311
+ }
312
+
313
+ public static Array constant (double val , int [] dims , int type ) throws Exception {
314
+ int [] adims = dim4 (dims );
315
+ long ref = createConstantsArray (val , adims , type );
316
+ if (ref == 0 ) throw new Exception ("Failed to create Array" );
317
+
318
+ Array ret_val = new Array ();
319
+ ret_val .ref = ref ;
320
+ return ret_val ;
321
+ }
322
+
283
323
public static Array add (Array a , Array b ) throws Exception {
284
324
Array ret_val = new Array ();
285
325
ret_val .ref = add (a .ref ,b .ref );
@@ -432,11 +472,39 @@ public static Array sqrt(Array a) throws Exception {
432
472
}
433
473
434
474
// Scalar return operations
435
- public static float sum (Array a ) throws Exception { return sum (a .ref ); }
475
+ public static double sumAll (Array a ) throws Exception { return sumAll (a .ref ); }
476
+ public static double maxAll (Array a ) throws Exception { return maxAll (a .ref ); }
477
+ public static double minAll (Array a ) throws Exception { return minAll (a .ref ); }
478
+
479
+ public static Array sum (Array a , int dim ) throws Exception {
480
+ Array ret_val = new Array ();
481
+ ret_val .ref = sum (a .ref , dim );
482
+ return ret_val ;
483
+ }
484
+
485
+ public static Array max (Array a , int dim ) throws Exception {
486
+ Array ret_val = new Array ();
487
+ ret_val .ref = max (a .ref , dim );
488
+ return ret_val ;
489
+ }
490
+
491
+ public static Array min (Array a , int dim ) throws Exception {
492
+ Array ret_val = new Array ();
493
+ ret_val .ref = min (a .ref , dim );
494
+ return ret_val ;
495
+ }
436
496
437
- public static float max (Array a ) throws Exception { return max (a .ref ); }
497
+ public static Array sum (Array a ) throws Exception {
498
+ return sum (a , -1 );
499
+ }
438
500
439
- public static float min (Array a ) throws Exception { return min (a .ref ); }
501
+ public static Array max (Array a ) throws Exception {
502
+ return max (a , -1 );
503
+ }
504
+
505
+ public static Array min (Array a ) throws Exception {
506
+ return min (a , -1 );
507
+ }
440
508
441
509
// Scalar operations
442
510
public static Array add (Array a , float b ) throws Exception {
0 commit comments