@@ -2065,6 +2065,69 @@ def test_roll_empty(self):
2065
2065
x = np .array ([])
2066
2066
assert_equal (np .roll (x , 1 ), np .array ([]))
2067
2067
2068
+
2069
+ class TestRollaxis (TestCase ):
2070
+
2071
+ # expected shape indexed by (axis, start) for array of
2072
+ # shape (1, 2, 3, 4)
2073
+ tgtshape = {(0 , 0 ): (1 , 2 , 3 , 4 ), (0 , 1 ): (1 , 2 , 3 , 4 ),
2074
+ (0 , 2 ): (2 , 1 , 3 , 4 ), (0 , 3 ): (2 , 3 , 1 , 4 ),
2075
+ (0 , 4 ): (2 , 3 , 4 , 1 ),
2076
+ (1 , 0 ): (2 , 1 , 3 , 4 ), (1 , 1 ): (1 , 2 , 3 , 4 ),
2077
+ (1 , 2 ): (1 , 2 , 3 , 4 ), (1 , 3 ): (1 , 3 , 2 , 4 ),
2078
+ (1 , 4 ): (1 , 3 , 4 , 2 ),
2079
+ (2 , 0 ): (3 , 1 , 2 , 4 ), (2 , 1 ): (1 , 3 , 2 , 4 ),
2080
+ (2 , 2 ): (1 , 2 , 3 , 4 ), (2 , 3 ): (1 , 2 , 3 , 4 ),
2081
+ (2 , 4 ): (1 , 2 , 4 , 3 ),
2082
+ (3 , 0 ): (4 , 1 , 2 , 3 ), (3 , 1 ): (1 , 4 , 2 , 3 ),
2083
+ (3 , 2 ): (1 , 2 , 4 , 3 ), (3 , 3 ): (1 , 2 , 3 , 4 ),
2084
+ (3 , 4 ): (1 , 2 , 3 , 4 )}
2085
+
2086
+ def test_exceptions (self ):
2087
+ a = arange (1 * 2 * 3 * 4 ).reshape (1 , 2 , 3 , 4 )
2088
+ assert_raises (ValueError , rollaxis , a , - 5 , 0 )
2089
+ assert_raises (ValueError , rollaxis , a , 0 , - 5 )
2090
+ assert_raises (ValueError , rollaxis , a , 4 , 0 )
2091
+ assert_raises (ValueError , rollaxis , a , 0 , 5 )
2092
+
2093
+ def test_results (self ):
2094
+ a = arange (1 * 2 * 3 * 4 ).reshape (1 , 2 , 3 , 4 ).copy ()
2095
+ aind = np .indices (a .shape )
2096
+ assert_ (a .flags ['OWNDATA' ])
2097
+ for (i , j ) in self .tgtshape :
2098
+ # positive axis, positive start
2099
+ res = rollaxis (a , axis = i , start = j )
2100
+ i0 , i1 , i2 , i3 = aind [np .array (res .shape ) - 1 ]
2101
+ assert_ (np .all (res [i0 , i1 , i2 , i3 ] == a ))
2102
+ assert_ (res .shape == self .tgtshape [(i , j )], str ((i ,j )))
2103
+ assert_ (not res .flags ['OWNDATA' ])
2104
+
2105
+ # negative axis, positive start
2106
+ ip = i + 1
2107
+ res = rollaxis (a , axis = - ip , start = j )
2108
+ i0 , i1 , i2 , i3 = aind [np .array (res .shape ) - 1 ]
2109
+ assert_ (np .all (res [i0 , i1 , i2 , i3 ] == a ))
2110
+ assert_ (res .shape == self .tgtshape [(4 - ip , j )])
2111
+ assert_ (not res .flags ['OWNDATA' ])
2112
+
2113
+ # positive axis, negative start
2114
+ jp = j + 1 if j < 4 else j
2115
+ res = rollaxis (a , axis = i , start = - jp )
2116
+ i0 , i1 , i2 , i3 = aind [np .array (res .shape ) - 1 ]
2117
+ assert_ (np .all (res [i0 , i1 , i2 , i3 ] == a ))
2118
+ assert_ (res .shape == self .tgtshape [(i , 4 - jp )])
2119
+ assert_ (not res .flags ['OWNDATA' ])
2120
+
2121
+ # negative axis, negative start
2122
+ ip = i + 1
2123
+ jp = j + 1 if j < 4 else j
2124
+ res = rollaxis (a , axis = - ip , start = - jp )
2125
+ i0 , i1 , i2 , i3 = aind [np .array (res .shape ) - 1 ]
2126
+ assert_ (np .all (res [i0 , i1 , i2 , i3 ] == a ))
2127
+ assert_ (res .shape == self .tgtshape [(4 - ip , 4 - jp )])
2128
+ assert_ (not res .flags ['OWNDATA' ])
2129
+
2130
+
2068
2131
class TestCross (TestCase ):
2069
2132
def test_2x2 (self ):
2070
2133
u = [1 , 2 ]
0 commit comments