Expose and wrap the ParallelSorter · meilisearch/grenad@97896cd · GitHub
[go: up one dir, main page]

Skip to content

Commit 97896cd

Browse files
committed
Expose and wrap the ParallelSorter
1 parent eafb6ae commit 97896cd

File tree

2 files changed

+23
-12
lines changed

2 files changed

+23
-12
lines changed

src/lib.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,8 @@ pub use self::reader::{PrefixIter, RangeIter, Reader, ReaderCursor, RevPrefixIte
203203
#[cfg(feature = "tempfile")]
204204
pub use self::sorter::TempFileChunk;
205205
pub use self::sorter::{
206-
ChunkCreator, CursorVec, DefaultChunkCreator, SortAlgorithm, Sorter, SorterBuilder,
206+
ChunkCreator, CursorVec, DefaultChunkCreator, ParallelSorter, SortAlgorithm, Sorter,
207+
SorterBuilder,
207208
};
208209
pub use self::writer::{Writer, WriterBuilder};
209210

src/sorter.rs

+21-11
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ impl<MF, CC: ChunkCreator> SorterBuilder<MF, CC> {
218218
CC::Chunk: Send + 'static,
219219
{
220220
match number.get() {
221-
1 => ParallelSorter::Single(self.build()),
221+
1 | 2 => ParallelSorter(ParallelSorterInner::Single(self.build())),
222222
number => {
223223
let (senders, receivers): (Vec<Sender<(usize, Vec<u8>)>>, Vec<_>) =
224224
repeat_with(unbounded).take(number).unzip();
@@ -227,6 +227,7 @@ impl<MF, CC: ChunkCreator> SorterBuilder<MF, CC> {
227227
for receiver in receivers {
228228
let sorter_builder = self.clone();
229229
handles.push(thread::spawn(move || {
230+
// TODO make sure the max memory is divided by the number of threads
230231
let mut sorter = sorter_builder.build();
231232
for (key_length, data) in receiver {
232233
let (key, val) = data.split_at(key_length);
@@ -236,7 +237,11 @@ impl<MF, CC: ChunkCreator> SorterBuilder<MF, CC> {
236237
}));
237238
}
238239

239-
ParallelSorter::Multi { senders, handles, merge_function: self.merge }
240+
ParallelSorter(ParallelSorterInner::Multi {
241+
senders,
242+
handles,
243+
merge_function: self.merge,
244+
})
240245
}
241246
}
242247
}
@@ -712,14 +717,19 @@ where
712717
}
713718
}
714719

715-
// TODO Make this private by wrapping it
716-
pub enum ParallelSorter<MF, U, CC: ChunkCreator = DefaultChunkCreator>
720+
pub struct ParallelSorter<MF, U, CC: ChunkCreator = DefaultChunkCreator>(
721+
ParallelSorterInner<MF, U, CC>,
722+
)
723+
where
724+
MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>;
725+
726+
enum ParallelSorterInner<MF, U, CC: ChunkCreator = DefaultChunkCreator>
717727
where
718728
MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>,
719729
{
720730
Single(Sorter<MF, CC>),
721731
Multi {
722-
// Indicates the length of the key and the bytes assoicated to the key + the data.
732+
// Indicates the length of the key and the bytes associated to the key + the data.
723733
senders: Vec<Sender<(usize, Vec<u8>)>>,
724734
handles: Vec<JoinHandle<Result<Vec<ReaderCursor<CC::Chunk>>, Error<U>>>>,
725735
merge_function: MF,
@@ -740,9 +750,9 @@ where
740750
{
741751
let key = key.as_ref();
742752
let val = val.as_ref();
743-
match self {
744-
ParallelSorter::Single(sorter) => sorter.insert(key, val),
745-
ParallelSorter::Multi { senders, .. } => {
753+
match &mut self.0 {
754+
ParallelSorterInner::Single(sorter) => sorter.insert(key, val),
755+
ParallelSorterInner::Multi { senders, .. } => {
746756
let key_length = key.len();
747757
let key_hash = compute_hash(key);
748758

@@ -766,9 +776,9 @@ where
766776

767777
/// Consumes this [`ParallelSorter`] and outputs a stream of the merged entries in key-order.
768778
pub fn into_stream_merger_iter(self) -> Result<MergerIter<CC::Chunk, MF>, Error<U>> {
769-
match self {
770-
ParallelSorter::Single(sorter) => sorter.into_stream_merger_iter(),
771-
ParallelSorter::Multi { senders, handles, merge_function } => {
779+
match self.0 {
780+
ParallelSorterInner::Single(sorter) => sorter.into_stream_merger_iter(),
781+
ParallelSorterInner::Multi { senders, handles, merge_function } => {
772782
drop(senders);
773783

774784
let mut sources = Vec::new();

0 commit comments

Comments
 (0)
0