@@ -379,83 +379,11 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
379
379
"""
380
380
(This docstring should be overwritten)
381
381
"""
382
- arr = array (arr , copy = False , subok = True )
383
- nd = arr .ndim
384
- axis = normalize_axis_index (axis , nd )
385
- ind = [0 ] * (nd - 1 )
386
- i = np .zeros (nd , 'O' )
387
- indlist = list (range (nd ))
388
- indlist .remove (axis )
389
- i [axis ] = slice (None , None )
390
- outshape = np .asarray (arr .shape ).take (indlist )
391
- i .put (indlist , ind )
392
- j = i .copy ()
393
- res = func1d (arr [tuple (i .tolist ())], * args , ** kwargs )
394
- # if res is a number, then we have a smaller output array
395
- asscalar = np .isscalar (res )
396
- if not asscalar :
397
- try :
398
- len (res )
399
- except TypeError :
400
- asscalar = True
401
- # Note: we shouldn't set the dtype of the output from the first result
402
- # so we force the type to object, and build a list of dtypes. We'll
403
- # just take the largest, to avoid some downcasting
404
- dtypes = []
405
- if asscalar :
406
- dtypes .append (np .asarray (res ).dtype )
407
- outarr = zeros (outshape , object )
408
- outarr [tuple (ind )] = res
409
- Ntot = np .product (outshape )
410
- k = 1
411
- while k < Ntot :
412
- # increment the index
413
- ind [- 1 ] += 1
414
- n = - 1
415
- while (ind [n ] >= outshape [n ]) and (n > (1 - nd )):
416
- ind [n - 1 ] += 1
417
- ind [n ] = 0
418
- n -= 1
419
- i .put (indlist , ind )
420
- res = func1d (arr [tuple (i .tolist ())], * args , ** kwargs )
421
- outarr [tuple (ind )] = res
422
- dtypes .append (asarray (res ).dtype )
423
- k += 1
424
- else :
425
- res = array (res , copy = False , subok = True )
426
- j = i .copy ()
427
- j [axis ] = ([slice (None , None )] * res .ndim )
428
- j .put (indlist , ind )
429
- Ntot = np .product (outshape )
430
- holdshape = outshape
431
- outshape = list (arr .shape )
432
- outshape [axis ] = res .shape
433
- dtypes .append (asarray (res ).dtype )
434
- outshape = flatten_inplace (outshape )
435
- outarr = zeros (outshape , object )
436
- outarr [tuple (flatten_inplace (j .tolist ()))] = res
437
- k = 1
438
- while k < Ntot :
439
- # increment the index
440
- ind [- 1 ] += 1
441
- n = - 1
442
- while (ind [n ] >= holdshape [n ]) and (n > (1 - nd )):
443
- ind [n - 1 ] += 1
444
- ind [n ] = 0
445
- n -= 1
446
- i .put (indlist , ind )
447
- j .put (indlist , ind )
448
- res = func1d (arr [tuple (i .tolist ())], * args , ** kwargs )
449
- outarr [tuple (flatten_inplace (j .tolist ()))] = res
450
- dtypes .append (asarray (res ).dtype )
451
- k += 1
452
- max_dtypes = np .dtype (np .asarray (dtypes ).max ())
453
- if not hasattr (arr , '_mask' ):
454
- result = np .asarray (outarr , dtype = max_dtypes )
455
- else :
456
- result = asarray (outarr , dtype = max_dtypes )
457
- result .fill_value = ma .default_fill_value (result )
458
- return result
382
+ def wrapped_func (a , * args , ** kwargs ):
383
+ res = func1d (a , * args , ** kwargs )
384
+ return np .asanyarray (res ).view (masked_array )
385
+
386
+ return np .apply_along_axis (wrapped_func , axis , arr , * args , ** kwargs )
459
387
apply_along_axis .__doc__ = np .apply_along_axis .__doc__
460
388
461
389
0 commit comments