8000 Merge pull request #306 from nouiz/fill_diagonal · githubmlai/numpy@a8f1612 · GitHub
[go: up one dir, main page]

Skip to content

Commit a8f1612

Browse files
committed
Merge pull request numpy#306 from nouiz/fill_diagonal
fix the wrapping problem of fill_diagonal with tall matrix.
2 parents 637fa62 + e909e4e commit a8f1612

File tree

2 files changed

+74
-4
lines changed

2 files changed

+74
-4
lines changed

numpy/lib/index_tricks.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -658,9 +658,8 @@ def __getitem__(self, item):
658658
# The following functions complement those in twodim_base, but are
659659
# applicable to N-dimensions.
660660

661-
def fill_diagonal(a, val):
662-
"""
663-
Fill the main diagonal of the given array of any dimensionality.
661+
def fill_diagonal(a, val, wrap=False):
662+
"""Fill the main diagonal of the given array of any dimensionality.
664663
665664
For an array `a` with ``a.ndim > 2``, the diagonal is the list of
666665
locations with indices ``a[i, i, ..., i]`` all identical. This function
@@ -675,6 +674,10 @@ def fill_diagonal(a, val):
675674
Value to be written on the diagonal, its type must be compatible with
676675
that of the array a.
677676
677+
wrap: bool For tall matrices in NumPy version up to 1.6.2, the
678+
diagonal "wrapped" after N columns. You can have this behavior
679+
with this option. This affect only tall matrices.
680+
678681
See also
679682
--------
680683
diag_indices, diag_indices_from
@@ -716,13 +719,42 @@ def fill_diagonal(a, val):
716719
[0, 0, 0],
717720
[0, 0, 4]])
718721
722+
# tall matrices no wrap
723+
>>> a = np.zeros((5, 3),int)
724+
>>> fill_diagonal(a, 4)
725+
array([[4, 0, 0],
726+
[0, 4, 0],
727+
[0, 0, 4],
728+
[0, 0, 0],
729+
[0, 0, 0]])
730+
731+
# tall matrices wrap
732+
>>> a = np.zeros((5, 3),int)
733+
>>> fill_diagonal(a, 4)
734+
array([[4, 0, 0],
735+
[0, 4, 0],
736+
[0, 0, 4],
737+
[0, 0, 0],
738+
[4, 0, 0]])
739+
740+
# wide matrices
741+
>>> a = np.zeros((3, 5),int)
742+
>>> fill_diagonal(a, 4)
743+
array([[4, 0, 0, 0, 0],
744+
[0, 4, 0, 0, 0],
745+
[0, 0, 4, 0, 0]])
746+
719747
"""
720748
if a.ndim < 2:
721749
raise ValueError("array must be at least 2-d")
750+
end = None
722751
if a.ndim == 2:
723752
# Explicit, fast formula for the common case. For 2-d arrays, we
724753
# accept rectangular ones.
725754
step = a.shape[1] + 1
755+
#This is needed to don't have tall matrix have the diagonal wrap.
756+
if not wrap:
757+
end = a.shape[1] * a.shape[1]
726758
else:
727759
# For more than d=2, the strided formula is only valid for arrays with
728760
# all dimensions equal, so we check first.
@@ -731,7 +763,7 @@ def fill_diagonal(a, val):
731763
step = 1 + (cumprod(a.shape[:-1])).sum()
732764

733765
# Write the value out into the diagonal.
734-
a.flat[::step] = val
766+
a.flat[:end:step] = val
735767

736768

737769
def diag_indices(n, ndim=2):

numpy/lib/tests/test_index_tricks.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,44 @@ def test_fill_diagonal():
159159
[0, 5, 0],
160160
[0, 0, 5]]))
161161

162+
#Test tall matrix
163+
a = zeros((10, 3),int)
164+
fill_diagonal(a, 5)
165+
yield (assert_array_equal, a,
166+
array([[5, 0, 0],
167+
[0, 5, 0],
168+
[0, 0, 5],
169+
[0, 0, 0],
170+
[0, 0, 0],
171+
[0, 0, 0],
172+
[0, 0, 0],
173+
[0, 0, 0],
174+
[0, 0, 0],
175+
[0, 0, 0]]))
176+
177+
#Test tall matrix wrap
178+
a = zeros((10, 3),int)
179+
fill_diagonal(a, 5, True)
180+
yield (assert_array_equal, a,
181+
array([[5, 0, 0],
182+
[0, 5, 0],
183+
[0, 0, 5],
184+
[0, 0, 0],
185+
[5, 0, 0],
186+
[0, 5, 0],
187+
[0, 0, 5],
188+
[0, 0, 0],
189+
[5, 0, 0],
190+
[0, 5, 0]]))
191+
192+
#Test wide matrix
193+
a = zeros((3, 10),int)
194+
fill_diagonal(a, 5)
195+
yield (assert_array_equal, a,
196+
array([[5, 0, 0, 0, 0, 0, 0, 0, 0, 0],
197+
[0, 5, 0, 0, 0, 0, 0, 0, 0, 0],
198+
[0, 0, 5, 0, 0, 0, 0, 0, 0, 0]]))
199+
162200
# The same function can operate on a 4-d array:
163201
a = zeros((3, 3, 3, 3), int)
164202
fill_diagonal(a, 4)

0 commit comments

Comments
 (0)
0