8000 Adding better indexing support · arrayfire/arrayfire-fortran@def9526 · GitHub
[go: up one dir, main page]

Skip to content

Commit def9526

Browse files
committed
Adding better indexing support
REF: #2
1 parent 01d17f0 commit def9526

File tree

3 files changed

+357
-30
lines changed

3 files changed

+357
-30
lines changed

examples/indexing.f95

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ program template
88
! 1D indexing
99
A1 = randu(5, 1)
1010
A2 = constant(0, 5, 1)
11-
tmp = get(A1, seq(3,5)) ! Get elements 3 through 5
12-
call set(A2, tmp, seq(1,3)) ! Set elements 1 through 3 with values from tmp
11+
tmp = get(A1, (/3,5/)) ! Get elements 3 through 5
12+
call set(A2, tmp, (/1,3/)) ! Set elements 1 through 3 with values from tmp
1313
call print(A1,"A1")
1414
call print(tmp, "tmp")
1515
call print(A2,"A2")
@@ -30,8 +30,8 @@ program template
3030
A2 = constant(1,3,3,2)
3131
I1 = (/ 1, 3 /)
3232
I2 = (/ 2, 3 /)
33-
tmp = get(A1, idx(I1), seq(1,3,2), 1) ! Get rows 1 and 3 for columns 1 and 3, tile 1
34-
call set(A2, tmp, idx(I2), seq(1,2), 2) ! Set rows 2 and 3 for columns 1 and 2, tile 2 with tmp
33+
tmp = get(A1, idx(I1), (/1,3,2/), (/1/)) ! Get rows 1 and 3 for columns 1 and 3, tile 1
34+
call set(A2, tmp, idx(I2), (/1,2/), (/2/)) ! Set rows 2 and 3 for columns 1 and 2, tile 2 with tmp
3535
call print(A1, "A1")
3636
call print(tmp, "tmp")
3737
call print(A2, "A2")

src/arrayfire.f95

Lines changed: 206 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ module arrayfire
154154
!> @param[in] d4 integer denoting the index of the 4th dimension. Optional.
155155
!> @returns subarry of in referenced by d1,d2,d3,d4
156156
interface get
157-
module procedure array_get
157+
module procedure array_get, array_get2, array_get_seq
158158
end interface get
159159
!> @}
160160

@@ -166,7 +166,7 @@ module arrayfire
166166
!> @param[in] d3 integer denoting the index of the 3rd dimension. Optional.
167167
!> @param[in] d4 integer denoting the index of the 4th dimension. Optional.
168168
interface set
169-
module procedure array_set
169+
module procedure array_set, array_set2, array_set_seq
170170
end interface set
171171
!> @}
172172

@@ -1216,6 +1216,38 @@ function elements(A) result(num)
12161216
num = product(A%shape)
12171217
end function elements
12181218

1219+
function safeidx(d) result(idx)
1220+
integer, dimension(:), intent(in) :: d
1221+
integer, dimension(3) :: idx
1222+
integer, allocatable, dimension(:) :: S
1223+
integer :: f
1224+
integer :: l
1225+
integer :: st
1226+
1227+
S = shape(d)
1228+
1229+
if (S(1) == 1) then
1230+
f = d(1)
1231+
l = d(1)
1232+
st = 1
1233+
end if
1234+
1235+
if (S(1) == 2) then
1236+
f = d(1)
1237+
l = d(2)
1238+
st = 1
1239+
end if
1240+
1241+
if (S(1) == 3) then
1242+
f = d(1)
1243+
l = d(2)
1244+
st = d(3)
1245+
end if
1246+
1247+
idx = (/ f-1, l-1, st /)
1248+
1249+
end function safeidx
1250+
12191251
subroutine init_1d(A, S)
12201252
type(array), intent(inout) :: A
12211253
integer, intent(in) :: S(1)
@@ -1302,43 +1334,199 @@ function array_get(in, d1, d2, d3, d4) result(R)
13021334
type(array), intent(in) :: in
13031335
type(array), intent(in) :: d1
13041336
type(array), intent(in), optional :: d2
1305-
integer, intent(in), optional :: d3
1306-
integer, intent(in), optional :: d4
1337+
integer, dimension(:), intent(in), optional :: d3
1338+
integer, dimension(:), intent(in), optional :: d4
1339+
integer :: dims
1340+
13071341
type(array) :: R
13081342
type(C_ptr) :: idx1 = C_NULL_ptr
13091343
type(C_ptr) :: idx2 = C_NULL_ptr
1310-
integer :: idx3 = 1
1311-
integer :: idx4 = 1
1344+
integer, dimension(3) :: idx3
1345+
integer, dimension(3) :: idx4
13121346

13131347
idx1 = d1%ptr
1314-
if (present(d2)) idx2 = d2%ptr
1315-
if (present(d3)) idx3 = d3
1316-
if (present(d4)) idx4 = d4
1348+
dims = 1
1349+
1350+
if (present(d2)) then
1351+
idx2 = d2%ptr
1352+
dims = 2
1353+
end if
1354+
1355+
if (present(d3)) then
1356+
idx3 = safeidx(d3)
F438
1357+
dims = 3
1358+
end if
1359+
1360+
if (present(d4)) then
1361+
idx4 = safeidx(d4)
1362+
dims = 4
1363+
end if
13171364

1318-
call af_arr_get(R%ptr, in%ptr, idx1, idx2, idx3, idx4, err)
1365+
call af_arr_get(R%ptr, in%ptr, idx1, idx2, idx3, idx4, dims, err)
13191366
call init_post(R%ptr, R%shape, R%rank)
13201367

13211368
end function array_get
13221369

1370+
function array_get2(in, d1, d2, d3) result(R)
1371+
type(array), intent(in) :: in
1372+
type(array), intent(in) :: d1
1373+
integer, dimension(:), intent(in) :: d2
1374+
integer, dimension(:), intent(in), optional :: d3
1375+
integer :: dims
1376+
1377+
type(array) :: R
1378+
type(C_ptr) :: idx1 = C_NULL_ptr
1379+
integer, dimension(3) :: idx2
1380+
integer, dimension(3) :: idx3
1381+
1382+
idx1 = d1%ptr
1383+
idx2 = safeidx(d2)
1384+
dims = 2
1385+
1386+
if (present(d3)) then
1387+
idx3 = safeidx(d3)
1388+
dims = 3
1389+
end if
1390+
1391+
call af_arr_get2(R%ptr, in%ptr, idx1, idx2, idx3, dims, err)
1392+
call init_post(R%ptr, R%shape, R%rank)
1393+
1394+
end function array_get2
1395+
1396+
function array_get_seq(in, d1, d2, d3, d4) result(R)
1397+
type(array), intent(in) :: in
1398+
integer, intent(in) :: d1(:)
1399+
integer, intent(in), optional :: d2(:)
1400+
integer, intent(in), optional :: d3(:)
1401+
integer, intent(in), optional :: d4(:)
1402+
type(array) :: R
1403+
1404+
integer, dimension(3) :: idx1
1405+
integer, dimension(3) :: idx2
1406+
integer, dimension(3) :: idx3
1407+
integer, dimension(3) :: idx4
1408+
integer :: dims = 1
1409+
1410+
idx1 = safeidx(d1)
1411+
idx2 = safeidx(d1)
1412+
idx3 = safeidx(d1)
1413+
idx4 = safeidx(d1)
1414+
1415+
if (present(d2)) then
1416+
idx2 = safeidx(d2)
1417+
dims = 2
1418+
end if
1419+
1420+
if (present(d3)) then
1421+
idx3 = safeidx(d3)
1422+
dims = 3
1423+
end if
1424+
1425+
if (present(d4)) then
1426+
idx4 = safeidx(d4)
1427+
dims = 4
1428+
end if
1429+
1430+
call af_arr_get_seq(R%ptr, in%ptr, idx1, idx2, idx3, idx4, dims, err)
1431+
call init_post(R%ptr, R%shape, R%rank)
1432+
end function array_get_seq
1433+
13231434
subroutine array_set(lhs, rhs, d1, d2, d3, d4)
13241435
type(array), intent(in) :: lhs
13251436
type(array), intent(inout) :: rhs
13261437
type(array), intent(in) :: d1
13271438
type(array), intent(in), optional :: d2
1328-
integer, intent(in), optional :: d3
1329-
integer, intent(in), optional :: d4
1439+
integer, dimension(:), intent(in), optional :: d3
1440+
integer, dimension(:), intent(in), optional :: d4
1441+
13301442
type(C_ptr) :: idx1 = C_NULL_ptr
13311443
type(C_ptr) :: idx2 = C_NULL_ptr
1332-
integer :: idx3 = 1
1333-
integer :: idx4 = 1
1444+
integer, dimension(3) :: idx3
1445+
integer, dimension(3) :: idx4
1446+
integer :: dims
13341447

13351448
idx1 = d1%ptr
1336-
if (present(d2)) idx2 = d2%ptr
1337-
if (present(d3)) idx3 = d3
1338-
if (present(d4)) idx4 = d4
1449+
dims = 1
1450+
1451+
if (present(d2)) then
1452+
idx2 = d2%ptr
1453+
dims = 2
1454+
end if
13391455

1340-
call af_arr_set(lhs%ptr, rhs%ptr, idx1, idx2, idx3, idx4, err)
1456+
if (present(d3)) then
1457+
idx3 = safeidx(d3)
1458+
dims = 3
1459+
end if
1460+
1461+
if (present(d4)) then
1462+
idx4 = safeidx(d4)
1463+
dims = 4
1464+
end if
1465+
1466+
call af_arr_set(lhs%ptr, rhs%ptr, idx1, idx2, idx3, idx4, dims, err)
13411467
end subroutine array_set
1468+
1469+
subroutine array_set2(lhs, rhs, d1, d2, d3)
1470+
type(array), intent(in) :: lhs
1471+
type(array), intent(inout) :: rhs
1472+
type(array), intent(in) :: d1
1473+
integer, dimension(:), intent(in) :: d2
1474+
integer, dimension(:), intent(in), optional :: d3
1475+
1476+
type(C_ptr) :: idx1 = C_NULL_ptr
1477+
integer, dimension(3) :: idx2
1478+
integer, dimension(3) :: idx3
1479+
integer :: dims
1480+
1481+
idx1 = d1%ptr
1482+
idx2 = safeidx(d2)
1483+
dims = 2
1484+
1485+
if (present(d3)) then
1486+
idx3 = safeidx(d3)
1487+
dims = 3
1488+
end if
1489+
1490+
call af_arr_set2(lhs%ptr, rhs%ptr, idx1, idx2, idx3, dims, err)
1491+
end subroutine array_set2
1492+
1493+
subroutine array_set_seq(R, in, d1, d2, d3, d4)
1494+
type(array), intent(in) :: in
1495+
integer, intent(in) :: d1(:)
1496+
integer, intent(in), optional :: d2(:)
1497+
integer, intent(in), optional :: d3(:)
1498+
integer, intent(in), optional :: d4(:)
1499+
type(array), intent(inout) :: R
1500+
1501+
integer, dimension(3) :: idx1
1502+
integer, dimension(3) :: idx2
1503+
integer, dimension(3) :: idx3
1504+
integer, dimension(3) :: idx4
1505+
integer :: dims = 1
1506+
1507+
idx1 = safeidx(d1)
1508+
idx2 = safeidx(d1)
1509+
idx3 = safeidx(d1)
1510+
idx4 = safeidx(d1)
1511+
1512+
if (present(d2)) then
1513+
idx2 = safeidx(d2)
1514+
dims = 2
1515+
end if
1516+
1517+
if (present(d3)) then
1518+
idx3 = safeidx(d3)
1519+
dims = 3
1520+
end if
1521+
1522+
if (present(d4)) then
1523+
idx4 = safeidx(d4)
1524+
dims = 4
1525+
end if
1526+
1527+
call af_arr_set_seq(R%ptr, in%ptr, idx1, idx2, idx3, idx4, dims, err)
1528+
call init_post(R%ptr, R%shape, R%rank)
1529+
end subroutine array_set_seq
13421530

13431531
!> Assigns data to array
13441532
subroutine assign(L, R)

0 commit comments

Comments
 (0)
0