@@ -1200,20 +1200,41 @@ def test_collate_datapipe(self):
1200
1200
arrs = [[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]]
1201
1201
input_dp = dp .iter .IterableWrapper (arrs )
1202
1202
1203
- def _collate_fn (batch ):
1204
- return torch .tensor (sum (batch ), dtype = torch . float )
1203
+ def _collate_fn (batch , default_type = torch . float ):
1204
+ return torch .tensor (sum (batch ), dtype = default_type )
1205
1205
1206
+ # Functional Test: defaults to the default collate function when a custom one is not specified
1207
+ collate_dp = input_dp .collate ()
1208
+ for x , y in zip (input_dp , collate_dp ):
1209
+ self .assertEqual (torch .tensor (x ), y )
1210
+
1211
+ # Functional Test: custom collate function
1206
1212
collate_dp = input_dp .collate (collate_fn = _collate_fn )
1213
+ for x , y in zip (input_dp , collate_dp ):
1214
+ self .assertEqual (torch .tensor (sum (x ), dtype = torch .float ), y )
1215
+
1216
+ # Functional Test: custom, partial collate function
1217
+ collate_dp = input_dp .collate (partial (_collate_fn , default_type = torch .int ))
1218
+ for x , y in zip (input_dp , collate_dp ):
1219
+ self .assertEqual (torch .tensor (sum (x ), dtype = torch .int ), y )
1220
+
1221
+ # Reset Test: reset the DataPipe and results are still correct
1222
+ n_elements_before_reset = 1
1223
+ res_before_reset , res_after_reset = reset_after_n_next_calls (collate_dp , n_elements_before_reset )
1224
+ self .assertEqual ([torch .tensor (6 , dtype = torch .int )], res_before_reset )
1225
+ for x , y in zip (input_dp , res_after_reset ):
1226
+ self .assertEqual (torch .tensor (sum (x ), dtype = torch .int ), y )
1227
+
1228
+ # __len__ Test: __len__ is inherited
1207
1229
self .assertEqual (len (input_dp ), len (collate_dp ))
1208
- for x , y in zip (collate_dp , input_dp ):
1209
- self .assertEqual (x , torch .tensor (sum (y ), dtype = torch .float ))
1210
1230
1231
+ # __len__ Test: verify that it has no valid __len__ when the source doesn't have it
1211
1232
input_dp_nl = IDP_NoLen (arrs )
1212
1233
collate_dp_nl = input_dp_nl .collate ()
1213
1234
with self .assertRaisesRegex (TypeError , r"instance doesn't have valid length$" ):
1214
1235
len (collate_dp_nl )
1215
- for x , y in zip (collate_dp_nl , input_dp_nl ):
1216
- self .assertEqual (x , torch .tensor (y ) )
1236
+ for x , y in zip (input_dp_nl , collate_dp_nl ):
1237
+ self .assertEqual (torch .tensor (x ), y )
1217
1238
1218
1239
def test_batch_datapipe (self ):
1219
1240
arrs = list (range (10 ))
0 commit comments