8000 Ensure ufunc.at, reduce, reduceat, accumulate, outer work · astropy/astropy@69f42ce · GitHub
[go: up one dir, main page]

Skip to content

Commit 69f42ce

Browse files
committed
Ensure ufunc.at, reduce, reduceat, accumulate, outer work
1 parent c5bc1c7 commit 69f42ce

File tree

2 files changed

+251
-11
lines changed

2 files changed

+251
-11
lines changed

astropy/units/quantity_helper.py

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def scales_and_result_unit(function, method, *args):
366366
raise TypeError("Cannot use function '{0}' with quantities"
367367
.format(function.__name__))
368368

369-
if method == '__call__':
369+
if method == '__call__' or (method == 'outer' and function.nin == 2):
370370
# Find out the units of the arguments passed to the ufunc; usually,
371371
# at least one is a quantity, but for two-argument ufuncs, the second
372372
# could also be a Numpy array, etc. These are given unit=None.
@@ -397,7 +397,7 @@ def scales_and_result_unit(function, method, *args):
397397
raise UnitsError("Can only apply '{0}' function to "
398398
"dimensionless quantities when other "
399399
"argument is not a quantity (unless the "
400-
"latter is all zero/infinity/nan)"
400+
"latter is all zero/infinity/nan)."
401401
.format(function.__name__))
402402

403403
# In the case of np.power, the unit itself needs to be modified by an
@@ -413,16 +413,53 @@ def scales_and_result_unit(function, method, *args):
413413

414414
result_unit = result_unit ** validate_power(p)
415415

416-
else:
417-
# methods other than __call__, e.g., reduce, accumulate; these make
418-
# sense only for two-argument functions that leave the unit intact.
419-
if UFUNC_HELPERS[function] is helper_twoarg_invariant:
420-
scales = [1.]
421-
result_unit = getattr(args[0], 'unit', None)
416+
else: # methods for which the unit should stay the same
417+
if method == 'at':
418+
unit = getattr(args[0], 'unit', None)
419+
units = [unit]
420+
if function.nin == 2:
421+
units.append(getattr(args[2], 'unit', None))
422+
423+
scales, result_unit = UFUNC_HELPERS[function](function, *units)
424+
425+
# add 'scale' for indices (2nd argument)
426+
if function.nin == 1:
427+
scales += [None]
428+
else:
429+
scales = [scales[0], None, scales[1]]
430+
431+
elif (method in ('reduce', 'accumulate', 'reduceat') and
432+
function.nin == 2):
433+
unit = getattr(args[0], 'unit', None)
434+
scales, result_unit = UFUNC_HELPERS[function](function, unit, unit)
435+
if method == 'reduceat':
436+
# add 'scale' for indices (2nd argument)
437+
scales = [scales[0], None]
438+
else:
439+
scales = [scales[0]]
440+
422441
else:
423-
raise TypeError("Unknown ufunc {0} for method {1}. "
424-
"If this should work, please raise an issue on "
442+
if method in ('reduce', 'accumulate', 'reduceat',
443+
'outer') and function.nin != 2:
444+
raise ValueError("{0} only supported for binary functions"
445+
.format(method))
446+
447+
raise TypeError("Unexpected ufunc method {0}. If this should "
448+
"work, please raise an issue on"
425449
"https://github.com/astropy/astropy"
426-
.format(function.__name__, method))
450+
.format(method))
451+
452+
# for all but __call__ method, scaling is not allowed
453+
if unit is not None and result_unit is None:
454+
raise TypeError("Cannot use '{1}' method on ufunc {0} with a "
455+
"Quantity instance as the result is not a "
456+
"Quantity.".format(function.__name__, method))
457+
458+
if scales[0] != 1. or (unit is not None and
459+
(not result_unit.is_equivalent(unit) or
460+
result_unit.to(unit) != 1.)):
461+
raise UnitsError("Cannot use '{1}' method on ufunc {0} with a "
462+
"Quantity instance as it would change the unit."
463+
.format(function.__name__, method))
427464

428465
return scales, result_unit

astropy/units/tests/test_quantity_ufuncs.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,3 +640,206 @@ def test_ufunc_inplace_non_contiguous_data(self):
640640
s2 += 1. * u.cm
641641
assert np.all(s[::2] > s_copy[::2])
642642
assert np.all(s[1::2] == s_copy[1::2])
643+
644+
645+
@pytest.mark.xfail("NUMPY_LT_1P10")
646+
class TestUfuncAt(object):
647+
"""Test that 'at' method for ufuncs (calculates in-place at given indices)
648+
649+
For Quantities, since calculations are in-place, it makes sense only
650+
if the result is still a quantity, and if the unit does not have to change
651+
"""
652+
def test_one_argument_ufunc_at(self):
653+
q = np.arange(10.) * u.m
654+
i = np.array([1, 2])
655+
qv = q.value.copy()
656+
np.negative.at(q, i)
657+
np.negative.at(qv, i)
658+
assert np.all(q.value == qv)
659+
assert q.unit is u.m
660+
661+
# cannot change from quantity to bool array
662+
with pytest.raises(TypeError):
663+
np.isfinite.at(q, i)
664+
665+
# for selective in-place, cannot change the unit
666+
with pytest.raises(u.UnitsError):
667+
np.square.at(q, i)
668+
669+
# except if the unit does not change (i.e., dimensionless)
670+
d = np.arange(10.) * u.dimensionless_unscaled
671+
dv = d.value.copy()
672+
np.square.at(d, i)
673+
np.square.at(dv, i)
674+
assert np.all(d.value == dv)
675+
assert d.unit is u.dimensionless_unscaled
676+
677+
d = np.arange(10.) * u.dimensionless_unscaled
678+
dv = d.value.copy()
679+
np.log.at(d, i)
680+
np.log.at(dv, i)
681+
assert np.all(d.value == dv)
682+
assert d.unit is u.dimensionless_unscaled
683+
684+
# also for sine it doesn't work, even if given an angle
685+
a = np.arange(10.) * u.radian
686+
with pytest.raises(u.UnitsError):
687+
np.sin.at(a, i)
688+
689+
# except, for consistency, if we have made radian equivalent to
690+
# dimensionless (though hopefully it will never be needed)
691+
av = a.value.copy()
692+
with u.add_enabled_equivalencies(u.dimensionless_angles()):
693+
np.sin.at(a, i)
694+
np.sin.at(av, i)
695+
assert_allclose(a.value, av)
696+
697+
# but we won't do double conversion
698+
ad = np.arange(10.) * u.degree
699+
with pytest.raises(u.UnitsError):
700+
np.sin.at(ad, i)
701+
702+
def test_two_argument_ufunc_at(self):
703+
s = np.arange(10.) * u.m
704+
i = np.array([1, 2])
705+
check = s.value.copy()
706+
np.add.at(s, i, 1.*u.km)
707+
np.add.at(check, i, 1000.)
708+
assert np.all(s.value == check)
709+
assert s.unit is u.m
710+
711+
with pytest.raises(u.UnitsError):
712+
np.add.at(s, i, 1.*u.s)
713+
714+
# also raise UnitsError if unit would have to be changed
715+
with pytest.raises(u.UnitsError):
716+
np.multiply.at(s, i, 1*u.s)
717+
718+
# but be fine if it does not
719+
s = np.arange(10.) * u.m
720+
check = s.value.copy()
721+
np.multiply.at(s, i, 2.*u.dimensionless_unscaled)
722+
np.multiply.at(check, i, 2)
723+
assert np.all(s.value == check)
724+
s = np.arange(10.) * u.m
725+
np.multiply.at(s, i, 2.)
726+
assert np.all(s.value == check)
727+
728+
# of course cannot change class of data either
729+
with pytest.raises(TypeError):
730+
np.greater.at(s, i, 1.*u.km)
731+
732+
733+
@pytest.mark.xfail("NUMPY_LT_1P10")
734+
class TestUfuncReduceReduceatAccumulate(object):
735+
"""Test 'reduce', 'reduceat' and 'accumulate' methods for ufuncs
736+
737+
For Quantities, it makes sense only if the unit does not have to change
738+
"""
739+
def test_one_argument_ufunc_reduce_accumulate(self):
740+
# one argument cannot be used
741+
s = np.arange(10.) * u.radian
742+
i = np.array([0, 5, 1, 6])
743+
with pytest.raises(ValueError):
744+
np.sin.reduce(s)
745+
with pytest.raises(ValueError):
746+
np.sin.accumulate(s)
747+
with pytest.raises(ValueError):
748+
np.sin.reduceat(s, i)
749+
750+
def test_two_argument_ufunc_reduce_accumulate(self):
751+
s = np.arange(10.) * u.m
752+
i = np.array([0, 5, 1, 6])
753+
check = s.value.copy()
754+
s_add_reduce = np.add.reduce(s)
755+
check_add_reduce = np.add.reduce(check)
756+
assert s_add_reduce.value == check_add_reduce
757+
assert s_add_reduce.unit is u.m
758+
759+
s_add_accumulate = np.add.accumulate(s)
760+
check_add_accumulate = np.add.accumulate(check)
761+
assert np.all(s_add_accumulate.value == check_add_accumulate)
762+
assert s_add_accumulate.unit is u.m
763+
764+
s_add_reduceat = np.add.reduceat(s, i)
765+
check_add_reduceat = np.add.reduceat(check, i)
766+
assert np.all(s_add_reduceat.value == check_add_reduceat)
767+
assert s_add_reduceat.unit is u.m
768+
769+
# reduce(at) or accumulate on comparisons makes no sense,
770+
# as intermediate result is not even a Quantity
771+
with pytest.raises(TypeError):
772+
np.greater.reduce(s)
773+
774+
with pytest.raises(TypeError):
775+
np.greater.accumulate(s)
776+
777+
with pytest.raises(TypeError):
778+
np.greater.reduceat(s, i)
779+
780+
# raise UnitsError if unit would have to be changed
781+
with pytest.raises(u.UnitsError):
782+
np.multiply.reduce(s)
783+
784+
with pytest.raises(u.UnitsError):
785+
np.multiply.accumulate(s)
786+
787+
with pytest.raises(u.UnitsError):
788+
np.multiply.reduceat(s)
789+
790+
# but be fine if it does not
791+
s = np.arange(10.) * u.dimensionless_unscaled
792+
check = s.value.copy()
793+
s_multiply_reduce = np.multiply.reduce(s)
794+
check_multiply_reduce = np.multiply.reduce(check)
795+
assert s_multiply_reduce.value == check_multiply_reduce
796+
assert s_multiply_reduce.unit is u.dimensionless_unscaled
797+
s_multiply_accumulate = np.multiply.accumulate(s)
798+
check_multiply_accumulate = np.multiply.accumulate(check)
799+
assert np.all(s_multiply_accumulate.value == check_multiply_accumulate)
800+
assert s_multiply_accumulate.unit is u.dimensionless_unscaled
801+
s_multiply_reduceat = np.multiply.reduceat(s, i)
802+
check_multiply_reduceat = np.multiply.reduceat(check, i)
803+
assert np.all(s_multiply_reduceat.value == check_multiply_reduceat)
804+
assert s_multiply_reduceat.unit is u.dimensionless_unscaled
805+
806+
807+
@pytest.mark.xfail("NUMPY_LT_1P10")
808+
class TestUfuncOuter(object):
809+
"""Test 'outer' methods for ufuncs
810+
811+
Just a few spot checks, since it uses the same code as the regular
812+
ufunc call
813+
"""
814+
def test_one_argument_ufunc_outer(self):
815+
# one argument cannot be used
816+
s = np.arange(10.) * u.radian
817+
with pytest.raises(ValueError):
818+
np.sin.outer(s)
819+
820+
def test_two_argument_ufunc_outer(self):
821+
s1 = np.arange(10.) * u.m
822+
s2 = np.arange(2.) * u.s
823+
check1 = s1.value
824+
check2 = s2.value
825+
s12_multiply_outer = np.multiply.outer(s1, s2)
826+
check12_multiply_outer = np.multiply.outer(check1, check2)
827+
assert np.all(s12_multiply_outer.value == check12_multiply_outer)
828+
assert s12_multiply_outer.unit == s1.unit * s2.unit
829+
830+
# raise UnitsError if appropriate
831+
with pytest.raises(u.UnitsError):
832+
np.add.outer(s1, s2)
833+
834+
# but be fine if it does not
835+
s3 = np.arange(2.) * s1.unit
836+
check3 = s3.value
837+
s13_add_outer = np.add.outer(s1, s3)
838+
check13_add_outer = np.add.outer(check1, check3)
839+
assert np.all(s13_add_outer.value == check13_add_outer)
840+
assert s13_add_outer.unit is s1.unit
841+
842+
s13_greater_outer = np.greater.outer(s1, s3)
843+
check13_greater_outer = np.greater.outer(check1, check3)
844+
assert type(s13_greater_outer) is np.ndarray
845+
assert np.all(s13_greater_outer == check13_greater_outer)

0 commit comments

Comments
 (0)
0