@@ -783,21 +783,27 @@ class Grouper:
783783 """
784784
785785 def __init__ (self , init = ()):
786- self ._mapping = weakref .WeakKeyDictionary (
787- {x : weakref .WeakSet ([x ]) for x in init })
786+ self ._count = itertools .count ()
787+ # For each item, we store (order_in_which_item_was_seen, group_of_item), which
788+ # lets __iter__ and get_siblings return items in the order in which they have
789+ # been seen.
790+ self ._mapping = weakref .WeakKeyDictionary ()
791+ for x in init :
792+ if x not in self ._mapping :
793+ self ._mapping [x ] = (next (self ._count ), weakref .WeakSet ([x ]))
788794
789795 def __getstate__ (self ):
790796 return {
791797 ** vars (self ),
792798 # Convert weak refs to strong ones.
793- "_mapping" : {k : set (v ) for k , v in self ._mapping .items ()},
799+ "_mapping" : {k : ( i , set (v )) for k , ( i , v ) in self ._mapping .items ()},
794800 }
795801
796802 def __setstate__ (self , state ):
797803 vars (self ).update (state )
798804 # Convert strong refs to weak ones.
799805 self ._mapping = weakref .WeakKeyDictionary (
800- {k : weakref .WeakSet (v ) for k , v in self ._mapping .items ()})
806+ {k : ( i , weakref .WeakSet (v )) for k , ( i , v ) in self ._mapping .items ()})
801807
802808 def __contains__ (self , item ):
803809 return item in self ._mapping
@@ -810,25 +816,32 @@ def join(self, a, *args):
810816 """
811817 Join given arguments into the same set. Accepts one or more arguments.
812818 """
813- mapping = self ._mapping
814- set_a = mapping .setdefault (a , weakref .WeakSet ([a ]))
815-
816- for arg in args :
817- set_b = mapping .get (arg , weakref .WeakSet ([arg ]))
819+ m = self ._mapping
820+ try :
821+ _ , set_a = m [a ]
822+ except KeyError :
823+ _ , set_a = m [a ] = (next (self ._count ), weakref .WeakSet ([a ]))
824+ for b in args :
825+ try :
826+ _ , set_b = m [b ]
827+ except KeyError :
828+ _ , set_b = m [b ] = (next (self ._count ), weakref .WeakSet ([b ]))
818829 if set_b is not set_a :
819830 if len (set_b ) > len (set_a ):
820831 set_a , set_b = set_b , set_a
821832 set_a .update (set_b )
822833 for elem in set_b :
823- mapping [elem ] = set_a
834+ i , _ = m [elem ]
835+ m [elem ] = (i , set_a )
824836
825837 def joined (self , a , b ):
826838 """Return whether *a* and *b* are members of the same set."""
827- return (self ._mapping .get (a , object ()) is self ._mapping .get (b ))
839+ return (self ._mapping .get (a , (None , object ()))[1 ]
840+ is self ._mapping .get (b , (None , object ()))[1 ])
828841
829842 def remove (self , a ):
830843 """Remove *a* from the grouper, doing nothing if it is not there."""
831- set_a = self ._mapping .pop (a , None )
844+ _ , set_a = self ._mapping .pop (a , ( None , None ) )
832845 if set_a :
833846 set_a .remove (a )
834847
@@ -838,14 +851,14 @@ def __iter__(self):
838851
839852 The iterator is invalid if interleaved with calls to join().
840853 """
841- unique_groups = {id (group ): group for group in self ._mapping .values ()}
854+ unique_groups = {id (group ): group for _ , group in self ._mapping .values ()}
842855 for group in unique_groups .values ():
843- yield [ x for x in group ]
856+ yield sorted ( group , key = self . _mapping . __getitem__ )
844857
845858 def get_siblings (self , a ):
846859 """Return all of the items joined with *a*, including itself."""
847- siblings = self ._mapping .get (a , [a ])
848- return [ x for x in siblings ]
860+ _ , siblings = self ._mapping .get (a , ( None , [a ]) )
861+ return sorted ( siblings , key = self . _mapping . get )
849862
850863
851864class GrouperView :
0 commit comments