8000 Adding dimensional reductions to Array.java · wcork/arrayfire-java@045247e · GitHub
[go: up one dir, main page]

Skip to content

Commit 045247e

Browse files
committed
Adding dimensional reductions to Array.java
1 parent 7364976 commit 045247e

File tree

3 files changed

+192
-28
lines changed

3 files changed

+192
-28
lines changed

com/arrayfire/Array.java

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ public class Array implements AutoCloseable {
2323
private native static long createArrayFromInt(int[] dims, int[] elems);
2424
private native static long createArrayFromBoolean(int[] dims, boolean[] elems);
2525

26+
private native static long createRanduArray(int[] dims, int type);
27+
private native static long createRandnArray(int[] dims, int type);
28+
private native static long createConstantsArray(double val, int[] dims, int type);
29+
30+
2631
private native static void destroyArray(long ref);
2732
private native static int[] getDims(long ref);
2833
private native static int getType(long ref);
@@ -66,9 +71,13 @@ public class Array implements AutoCloseable {
6671
private native static long sqrt (long a);
6772

6873
// Scalar return operations
69-
private native static float sum(long a);
70-
private native static float max(long a);
71-
private native static float min(long a);
74+
private native static double sumAll(long a);
75+
private native static double maxAll(long a);
76+
private native static double minAll(long a);
77+
78+
private native static long sum(long a, int dim);
79+
private native static long max(long a, int dim);
80+
private native static long min(long a, int dim);
7281

7382
// Scalar operations
7483
private native static long addf(long a, float b);
@@ -115,7 +124,7 @@ public String typeName(int ty) throws Exception {
115124
throw new Exception("Unknown type");
116125
}
117126

118-
private int[] dim4(int[] dims) throws Exception {
127+
private static int[] dim4(int[] dims) throws Exception {
119128

120129
if( dims == null ) {
121130
throw new Exception("Null dimensions object provided");
@@ -280,6 +289,37 @@ public boolean[] getBooleanArray() throws Exception {
280289
}
281290

282291
// Binary operations
292+
293+
public static Array randu(int[] dims, int type) throws Exception {
294+
int[] adims = dim4(dims);
295+
long ref = createRanduArray(adims, type);
296+
if (ref == 0) throw new Exception("Failed to create Array");
297+
298+
Array ret_val = new Array();
299+
ret_val.ref = ref;
300+
return ret_val;
301+
}
302+
303+
public static Array randn(int[] dims, int type) throws Exception {
304+
int[] adims = dim4(dims);
305+
long ref = createRandnArray(adims, type);
306+
if (ref == 0) throw new Exception("Failed to create Array");
307+
308+
Array ret_val = new Array();
309+
ret_val.ref = ref;
310+
return ret_val;
311+
}
312+
313+
public static Array constant(double val, int[] dims, int type) throws Exception {
314+
int[] adims = dim4(dims);
315+
long ref = createConstantsArray(val, adims, type);
316+
if (ref == 0) throw new Exception("Failed to create Array");
317+
318+
Array ret_val = new Array();
319+
ret_val.ref = ref;
320+
return ret_val;
321+
}
322+
283323
public static Array add(Array a, Array b) throws Exception {
284324
Array ret_val = new Array();
285325
ret_val.ref = add(a.ref,b.ref);
@@ -432,11 +472,39 @@ public static Array sqrt(Array a) throws Exception {
432472
}
433473

434474
// Scalar return operations
435-
public static float sum(Array a) throws Exception { return sum(a.ref); }
475+
public static double sumAll(Array a) throws Exception { return sumAll(a.ref); }
476+
public static double maxAll(Array a) throws Exception { return maxAll(a.ref); }
477+
public static double minAll(Array a) throws Exception { return minAll(a.ref); }
478+
479+
public static Array sum(Array a, int dim) throws Exception {
480+
Array ret_val = new Array();
481+
ret_val.ref = sum(a.ref, dim);
482+
return ret_val;
483+
}
484+
485+
public static Array max(Array a, int dim) throws Exception {
486+
Array ret_val = new Array();
487+
ret_val.ref = max(a.ref, dim);
488+
return ret_val;
489+
}
490+
491+
public static Array min(Array a, int dim) throws Exception {
492+
Array ret_val = new Array();
493+
ret_val.ref = min(a.ref, dim);
494+
return ret_val;
495+
}
436496

437-
public static float max(Array a) throws Exception { return max(a.ref); }
497+
public static Array sum(Array a) throws Exception {
498+
return sum(a, -1);
499+
}
438500

439-
public static float min(Array a) throws Exception { return min(a.ref); }
501+
public static Array max(Array a) throws Exception {
502+
return max(a, -1);
503+
}
504+
505+
public static Array min(Array a) throws Exception {
506+
return min(a, -1);
507+
}
440508

441509
// Scalar operations
442510
public static Array add(Array a, float b) throws Exception {

src/java_wrapper.cpp

Lines changed: 100 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,61 @@ JNIEXPORT void JNICALL Java_com_arrayfire_Array_info(JNIEnv *env, jclass clazz)
2626
}
2727
}
2828

29+
JNIEXPORT jlong JNICALL Java_com_arrayfire_Array_createRanduArray(JNIEnv *env, jclass clazz, jintArray dims, jint type)
30+
{
31+
jlong ret;
32+
try{
33+
jint* dimptr = env->GetIntArrayElements(dims,0);
34+
af::dtype ty = (af::dtype)(type);
35+
af::array *A = new af::array(dimptr[0],dimptr[1],dimptr[2],dimptr[3], ty);
36+
*A = af::randu(dimptr[0],dimptr[1],dimptr[2], ty);
37+
ret = (jlong)(A);
38+
env->ReleaseIntArrayElements(dims,dimptr,0);
39+
} catch(af::exception& e) {
40+
ret = 0;
41+
} catch(std::exception& e) {
42+
ret = 0;
43+
}
44+
return ret;
45+
}
46+
47+
JNIEXPORT jlong JNICALL Java_com_arrayfire_Array_createRandnArray(JNIEnv *env, jclass clazz, jintArray dims, jint type)
48+
{
49+
jlong ret;
50+
try{
51+
jint* dimptr = env->GetIntArrayElements(dims,0);
52+
af::dtype ty = (af::dtype)(type);
53+
af::array *A = new af::array(dimptr[0],dimptr[1],dimptr[2],dimptr[3], ty);
54+
*A = af::randn(dimptr[0],dimptr[1],dimptr[2], ty);
55+
ret = (jlong)(A);
56+
env->ReleaseIntArrayElements(dims,dimptr,0);
57+
} catch(af::exception& e) {
58+
ret = 0;
59+
} catch(std::exception& e) {
60+
ret = 0;
61+
}
62+
return ret;
63+
}
64+
65+
66+
JNIEXPORT jlong JNICALL Java_com_arrayfire_Array_createConstantsArray(JNIEnv *env, jclass clazz, jdouble val, jintArray dims, jint type)
67+
{
68+
jlong ret;
69+
try{
70+
jint* dimptr = env->GetIntArrayElements(dims,0);
71+
af::dtype ty = (af::dtype)(type);
72+
af::array *A = new af::array(dimptr[0],dimptr[1],dimptr[2],dimptr[3], ty);
73+
*A = af::constant(val, dimptr[0],dimptr[1],dimptr[2], ty);
74+
ret = (jlong)(A);
75+
env->ReleaseIntArrayElements(dims,dimptr,0);
76+
} catch(af::exception& e) {
77+
ret = 0;
78+
} catch(std::exception& e) {
79+
ret = 0;
80+
}
81+
return ret;
82+
}
83+
2984
JNIEXPORT jlong JNICALL Java_com_arrayfire_Array_createEmptyArray(JNIEnv *env, jclass clazz, jintArray dims, jint type)
3085
{
3186
jlong ret;
@@ -380,24 +435,53 @@ UNARY_OP_DEF(log)
380435
UNARY_OP_DEF(abs)
381436
UNARY_OP_DEF(sqrt)
382437

383-
#define SCALAR_RET_OP_DEF(func) \
384-
JNIEXPORT jfloat JNICALL Java_com_arrayfire_Array_##func(JNIEnv *env, jclass clazz, jlong a) \
385-
{ \
386-
jfloat ret \
387-
try { \
388-
af::array *A = (af::array*)(a); \
389-
ret = af::func<float>( (*A) ); \
390-
} catch(af::exception& e) { \
391-
ret = 0; \
392-
} catch(std::exception& e) { \
393-
ret = 0; \
394-
} \
395-
return ret; \
438+
#define SCALAR_RET_OP_DEF(func) \
439+
JNIEXPORT jdouble JNICALL Java_com_arrayfire_Array_##func##All \
440+
(JNIEnv *env, jclass clazz, jlong a) \
441+
{ \
442+
try { \
443+
af::array *A = (af::array*)(a); \
444+
if (A->type() == af::f32) \
445+
return af::func<float>( (*A) ); \
446+
if (A->type() == af::s32) \
447+
return af::func<int>( (*A) ); \
448+
if (A->type() == af::f64) \
449+
return af::func<double>( (*A) ); \
450+
if (A->type() == af::b8) \
451+
return af::func<float>( (*A) ); \
452+
return af::NaN; \
453+
} catch(af::exception& e) { \
454+
return af::NaN; \
455+
} catch(std::exception& e) { \
456+
return af::NaN; \
457+
} \
458+
}
459+
460+
SCALAR_RET_OP_DEF(sum)
461+
SCALAR_RET_OP_DEF(max)
462+
SCALAR_RET_OP_DEF(min)
463+
464+
#define ARRAY_RET_OP_DEF(func) \
465+
JNIEXPORT jlong JNICALL Java_com_arrayfire_Array_##func \
466+
(JNIEnv *env, jclass clazz, jlong a, jint dim) \
467+
{ \
468+
jlong ret = 0; \
469+
try { \
470+
af::array *A = (af::array*)(a); \
471+
af::array *res = new af::array(); \
472+
*res = af::func((*A), dim); \
473+
ret = (jlong)res; \
474+
} catch(af::exception& e) { \
475+
return 0; \
476+
} catch(std::exception& e) { \
477+
return 0; \
478+
} \
479+
return ret; \
396480
}
397481

398-
SCALAR_RET_OP(sum)
399-
SCALAR_RET_OP(max)
400-
SCALAR_RET_OP(min)
482+
ARRAY_RET_OP_DEF(sum)
483+
ARRAY_RET_OP_DEF(max)
484+
ARRAY_RET_OP_DEF(min)
401485

402486
#define SCALAR_OP1_DEF(func,operation) \
403487
JNIEXPORT jlong JNICALL Java_com_arrayfire_Array_##func(JNIEnv *env, jclass clazz, jlong a, jfloat b) \

src/java_wrapper.h

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ JNIEXPORT void JNICALL Java_com_arrayfire_Array_info(JNIEnv *env, jclass clazz);
1212

1313
// Loader methods
1414
JNIEXPORT jlong JNICALL Java_com_arrayfire_Array_createEmptyArray(JNIEnv *env, jclass clazz, jintArray dims, jint type);
15-
1615
JNIEXPORT jlong JNICALL Java_com_arrayfire_Array_createArrayFromFloat(JNIEnv *env, jclass clazz, jintArray dims, jfloatArray elems);
1716
JNIEXPORT jlong JNICALL Java_com_arrayfire_Array_createArrayFromDouble(JNIEnv *env, jclass clazz, jintArray dims, jdoubleArray elems);
1817
JNIEXPORT jlong JNICALL Java_com_arrayfire_Array_createArrayFromInt(JNIEnv *env, jclass clazz, jintArray dims, jintArray elems);
@@ -22,6 +21,12 @@ JNIEXPORT jlong JNICALL Java_com_arrayfire_Array_createArrayFromDoubleComplex(JN
2221
// Unloader methods
2322
JNIEXPORT void JNICALL Java_com_arrayfire_Array_destroyArray(JNIEnv *env, jclass clazz, jlong ref);
2423

24+
JNIEXPORT jlong JNICALL Java_com_arrayfire_Array_createRanduArray(JNIEnv *env, jclass clazz, jintArray dims, jint type);
25+
JNIEXPORT jlong JNICALL Java_com_arrayfire_Array_createRandnArray(JNIEnv *env, jclass clazz, jintArray dims, jint type);
26+
JNIEXPORT jlong JNICALL Java_com_arrayfire_Array_createConstantsArray(JNIEnv *env, jclass clazz, jdouble val, jintArray dims, jint type);
27+
28+
29+
2530
// Data pull back methods
2631
JNIEXPORT jfloatArray JNICALL Java_com_arrayfire_Array_getFloatFromArray(JNIEnv *env, jclass clazz, jlong ref);
2732
JNIEXPORT jdoubleArray JNICALL Java_com_arrayfire_Array_getDoubleFromArray(JNIEnv *env, jclass clazz, jlong ref);
@@ -70,11 +75,18 @@ UNARY_OP(abs)
7075
UNARY_OP(sqrt)
7176

7277
#define SCALAR_RET_OP(func) \
73-
JNIEXPORT jfloat JNICALL Java_com_arrayfire_Array_##func(JNIEnv *env, jclass clazz, jlong a);
78+
JNIEXPORT jdouble JNICALL Java_com_arrayfire_Array_##func(JNIEnv *env, jclass clazz, jlong a);
79+
80+
SCALAR_RET_OP(sumAll)
81+
SCALAR_RET_OP(maxAll)
82+
SCALAR_RET_OP(minAll)
83+
84+
#define ARRAY_RET_OP(func) \
85+
JNIEXPORT jlong JNICALL Java_com_arrayfire_Array_##func(JNIEnv *env, jclass clazz, jlong a, jint dim);
7486

75-
SCALAR_RET_OP(sum)
76-
SCALAR_RET_OP(max)
77-
SCALAR_RET_OP(min)
87+
ARRAY_RET_OP(sum)
88+
ARRAY_RET_OP(max)
89+
ARRAY_RET_OP(min)
7890

7991
#define SCALAR_OP1(func) \
8092
JNIEXPORT jlong JNICALL Java_com_arrayfire_Array_##func(JNIEnv *env, jclass clazz, jlong a, jfloat b);

0 commit comments

Comments
 (0)
0