@@ -308,3 +308,68 @@ def test_multi_node_weighted_large_sample_size_with_prefetcher(self, midpoint, s
308308 stop_criteria ,
309309 )
310310 run_test_save_load_state (self , node , midpoint )
311+
312+ def test_multi_node_weighted_sampler_tag_output_dict_items (self ) -> None :
313+ """Test MultiNodeWeightedSampler with tag_output=True for dictionary items"""
314+ node = MultiNodeWeightedSampler (
315+ self .datasets ,
316+ self .weights ,
317+ tag_output = True ,
318+ )
319+
320+ results = list (node )
321+
322+ # Verify that each result has a 'dataset_key' key with the correct dataset name
323+ for result in results :
324+ self .assertIn ("dataset_key" , result )
325+
326+ dataset_name = result ["dataset_key" ]
327+ self .assertIn (dataset_name , [f"ds{ i } " for i in range (self ._num_datasets )])
328+
329+ self .assertIn ("name" , result )
330+ self .assertIn ("test_tensor" , result )
331+
332+ self .assertEqual (dataset_name , result ["name" ])
333+
334+ def test_multi_node_weighted_sampler_tag_output_non_dict_items (self ) -> None :
335+ """Test MultiNodeWeightedSampler with tag_output=True for non-dictionary items"""
336+ non_dict_datasets = {
337+ f"ds{ i } " : IterableWrapper (range (i * 10 , (i + 1 ) * 10 ))
338+ for i in range (self ._num_datasets )
339+ }
340+
341+ node = MultiNodeWeightedSampler (
342+ non_dict_datasets ,
343+ self .weights ,
344+ tag_output = True ,
345+ )
346+
347+ results = list (node )
348+
349+ # Verify that each result is now a dictionary with 'data' and 'dataset_key' keys
350+ for result in results :
351+ self .assertIsInstance (result , dict )
352+
353+ self .assertIn ("data" , result )
354+ self .assertIn ("dataset_key" , result )
355+
356+ dataset_name = result ["dataset_key" ]
357+ self .assertIn (dataset_name , [f"ds{ i } " for i in range (self ._num_datasets )])
358+
359+ def test_multi_node_weighted_sampler_tag_output_false (self ) -> None :
360+ """Test MultiNodeWeightedSampler with tag_output=False (default behavior)"""
361+ node = MultiNodeWeightedSampler (
362+ self .datasets ,
363+ self .weights ,
364+ tag_output = False ,
365+ )
366+
367+ results = list (node )
368+
369+ # Verify that none of the results have a 'dataset' key
370+ for result in results :
371+ self .assertNotIn ("dataset" , result )
372+
373+ # Check that the original data is preserved
374+ self .assertIn ("name" , result )
375+ self .assertIn ("test_tensor" , result )
0 commit comments