8000 Merge pull request #24701 from charris/backport-23282 · numpy/numpy@2e84b15 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2e84b15

Browse files
authored
Merge pull request #24701 from charris/backport-23282
BUG: Fix data stmt handling for complex values in f2py
2 parents 5a8d648 + cb3ffca commit 2e84b15

File tree

4 files changed

+79
-13
lines changed

numpy/f2py/auxfuncs.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
'applyrules', 'debugcapi', 'dictappend', 'errmess', 'gentitle',
2929
'getargs2', 'getcallprotoargument', 'getcallstatement',
3030
'getfortranname', 'getpymethoddef', 'getrestdoc', 'getusercode',
31-
'getusercode1', 'hasbody', 'hascallstatement', 'hascommon',
31+
'getusercode1', 'getdimension', 'hasbody', 'hascallstatement', 'hascommon',
3232
'hasexternals', 'hasinitvalue', 'hasnote', 'hasresultnote',
3333
'isallocatable', 'isarray', 'isarrayofstrings',
3434
'ischaracter', 'ischaracterarray', 'ischaracter_or_characterarray',
@@ -420,6 +420,13 @@ def isexternal(var):
420420
return 'attrspec' in var and 'external' in var['attrspec']
421421

422422

423+
def getdimension(var):
424+
dimpattern = r"\((.*?)\)"
425+
if 'attrspec' in var.keys():
426+
if any('dimension' in s for s in var['attrspec']):
427+
return [re.findall(dimpattern, v) for v in var['attrspec']][0]
428+
429+
423430
def isrequired(var):
424431
return not isoptional(var) and isintent_nothide(var)
425432

numpy/f2py/crackfortran.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1437,10 +1437,10 @@ def analyzeline(m, case, line):
14371437
outmess(
14381438
'analyzeline: implied-DO list "%s" is not supported. Skipping.\n' % l[0])
14391439
continue
1440-
i = 0
1441-
j = 0
14421440
llen = len(l[1])
1443-
for v in rmbadname([x.strip() for x in markoutercomma(l[0]).split('@,@')]):
1441+
for idx, v in enumerate(rmbadname(
1442+
[x.strip() for x in markoutercomma(l[0]).split('@,@')])
1443+
):
14441444
if v[0] == '(':
14451445
outmess(
14461446
'analyzeline: implied-DO list "%s" is not supported. Skipping.\n' % v)
@@ -1449,18 +1449,26 @@ def analyzeline(m, case, line):
14491449
# wrapping.
14501450
continue
14511451
fc = 0
1452-
while (i < llen) and (fc or not l[1][i] == ','):
1453-
if l[1][i] == "'":
1454-
fc = not fc
1455-
i = i + 1
1456-
i = i + 1
1452+
vtype = vars[v].get('typespec')
1453+
vdim = getdimension(vars[v])
1454+
1455+
if (vtype == 'complex'):
1456+
cmplxpat = r"\(.*?\)"
1457+
matches = re.findall(cmplxpat, l[1])
1458+
else:
1459+
matches = l[1].split(',')
1460+
14571461
if v not in vars:
14581462
vars[v] = {}
1459-
if '=' in vars[v] and not vars[v]['='] == l[1][j:i - 1]:
1463+
if '=' in vars[v] and not vars[v]['='] == matches[idx]:
14601464
outmess('analyzeline: changing init expression of "%s" ("%s") to "%s"\n' % (
1461-
v, vars[v]['='], l[1][j:i - 1]))
1462-
vars[v]['='] = l[1][j:i - 1]
1463-
j = i
1465+
v, vars[v]['='], matches[idx]))
1466+
1467+
if vdim is not None:
1468+
# Need to assign multiple values to one variable
1469+
vars[v]['='] = "(/{}/)".format(", ".join(matches))
1470+
else:
1471+
vars[v]['='] = matches[idx]
14641472
last_name = v
14651473
groupcache[groupcounter]['vars'] = vars
14661474
if last_name is not None:
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
! gh-23276
2+
module cmplxdat
3+
implicit none
4+
integer :: i, j
5+
real :: x, y
6+
real, dimension(2) :: z
7+
complex(kind=8), target :: medium_ref_index
8+
complex(kind=8), target :: ref_index_one, ref_index_two
9+
complex(kind=8), dimension(2) :: my_array
10+
real(kind=8), dimension(3) :: my_real_array = (/1.0d0, 2.0d0, 3.0d0/)
11+
12+
data i, j / 2, 3 /
13+
data x, y / 1.5, 2.0 /
14+
data z / 3.5, 7.0 /
15+
data medium_ref_index / (1.d0, 0.d0) /
16+
data ref_index_one, ref_index_two / (13.0d0, 21.0d0), (-30.0d0, 43.0d0) /
17+
data my_array / (1.0d0, 2.0d0), (-3.0d0, 4.0d0) /
18+
end module cmplxdat

numpy/f2py/tests/test_data.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import os
2+
import pytest
3+
import numpy as np
4+
5+
from . import util
6+
from numpy.f2py.crackfortran import crackfortran
7+
8+
9+
class TestData(util.F2PyTest):
10+
sources = [util.getpath("tests", "src", "crackfortran", "data_stmts.f90")]
11+
12+
# For gh-23276
13+
def test_data_stmts(self):
14+
assert self.module.cmplxdat.i == 2
15+
assert self.module.cmplxdat.j == 3
16+
assert self.module.cmplxdat.x == 1.5
17+
assert self.module.cmplxdat.y == 2.0
18+
assert self.module.cmplxdat.medium_ref_index == np.array(1.+0.j)
19+
assert np.all(self.module.cmplxdat.z == np.array([3.5, 7.0]))
20+
assert np.all(self.module.cmplxdat.my_array == np.array([ 1.+2.j, -3.+4.j]))
21+
assert np.all(self.module.cmplxdat.my_real_array == np.array([ 1., 2., 3.]))
22+
assert np.all(self.module.cmplxdat.ref_index_one == np.array([13.0 + 21.0j]))
23+
assert np.all(self.module.cmplxdat.ref_index_two == np.array([-30.0 + 43.0j]))
24+
25+
def test_crackedlines(self):
26+
mod = 6118 crackfortran(self.sources)
27+
assert mod[0]['vars']['x']['='] == '1.5'
28+
assert mod[0]['vars']['y']['='] == '2.0'
29+
assert mod[0]['vars']['my_real_array']['='] == '(/1.0d0, 2.0d0, 3.0d0/)'
30+
assert mod[0]['vars']['ref_index_one']['='] == '(13.0d0, 21.0d0)'
31+
assert mod[0]['vars']['ref_index_two']['='] == '(-30.0d0, 43.0d0)'
32+
assert mod[0]['vars']['my_array']['='] == '(/(1.0d0, 2.0d0), (-3.0d0, 4.0d0)/)'
33+
assert mod[0]['vars']['z']['='] == '(/3.5, 7.0/)'

0 commit comments

Comments
 (0)
0