8000 Add `raise_if_stop!` macro for `protocol/iter.rs` by ShaharNaveh · Pull Request #5885 · RustPython/RustPython · GitHub
[go: up one dir, main page]

Skip to content

Add raise_if_stop! macro for protocol/iter.rs #5885

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jul 2, 2025
Next Next commit
Replace struct name with Self
  • Loading branch information
ShaharNaveh committed Jul 2, 2025
commit e7e7ddbfcc8aeafdec7942e50382fe0c84990002
57 changes: 30 additions & 27 deletions vm/src/stdlib/itertools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ mod decl {
#[pyslot]
fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
let args_list = PyList::from(args.args);
PyItertoolsChain {
Self {
source: PyRwLock::new(Some(args_list.to_pyobject(vm).get_iter(vm)?)),
active: PyRwLock::new(None),
}
Expand Down Expand Up @@ -91,25 +91,26 @@ mod decl {
fn __setstate__(zelf: PyRef<Self>, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> {
let args = state.as_slice();
if args.is_empty() {
let msg = String::from("function takes at least 1 arguments (0 given)");
return Err(vm.new_type_error(msg));
return Err(vm.new_type_error("function takes at least 1 arguments (0 given)"));
}
if args.len() > 2 {
let msg = format!("function takes at most 2 arguments ({} given)", args.len());
return Err(vm.new_type_error(msg));
return Err(vm.new_type_error(format!(
"function takes at most 2 arguments ({} given)",
args.len()
)));
}
let source = &args[0];
if args.len() == 1 {
if !PyIter::check(source.as_ref()) {
return Err(vm.new_type_error(String::from("Arguments must be iterators.")));
return Err(vm.new_type_error("Arguments must be iterators."));
}
*zelf.source.write() = source.to_owned().try_into_value(vm)?;
return Ok(());
}
let active = &args[1];

if !PyIter::check(source.as_ref()) || !PyIter::check(active.as_ref()) {
return Err(vm.new_type_error(String::from("Arguments must be iterators.")));
return Err(vm.new_type_error("Arguments must be iterators."));
}
let mut source_lock = zelf.source.write();
let mut active_lock = zelf.active.write();
Expand Down Expand Up @@ -387,7 +388,7 @@ mod decl {
}
None => None,
};
PyItertoolsRepeat { object, times }
Self { object, times }
.into_ref_with_type(vm, cls)
.map(Into::into)
}
Expand Down Expand Up @@ -466,7 +467,7 @@ mod decl {
Self::Args { function, iterable }: Self::Args,
vm: &VirtualMachine,
) -> PyResult {
PyItertoolsStarmap { function, iterable }
Self { function, iterable }
.into_ref_with_type(vm, cls)
.map(Into::into)
}
Expand Down Expand Up @@ -527,7 +528,7 @@ mod decl {
}: Self::Args,
vm: &VirtualMachine,
) -> PyResult {
PyItertoolsTakewhile {
Self {
predicate,
iterable,
stop_flag: AtomicCell::new(false),
Expand Down Expand Up @@ -614,7 +615,7 @@ mod decl {
}: Self::Args,
vm: &VirtualMachine,
) -> PyResult {
PyItertoolsDropwhile {
Self {
predicate,
iterable,
start_flag: AtomicCell::new(false),
Expand All @@ -634,6 +635,7 @@ mod decl {
(zelf.start_flag.load() as _),
)
}

#[pymethod]
fn __setstate__(
zelf: PyRef<Self>,
Expand Down Expand Up @@ -734,7 +736,7 @@ mod decl {
Self::Args { iterable, key }: Self::Args,
vm: &VirtualMachine,
) -> PyResult {
PyItertoolsGroupBy {
Self {
iterable,
key_func: key.flatten(),
state: PyMutex::new(GroupByState {
Expand Down Expand Up @@ -952,7 +954,7 @@ mod decl {

let iter = iter.get_iter(vm)?;

PyItertoolsIslice {
Self {
iterable: iter,
cur: AtomicCell::new(0),
next: AtomicCell::new(start),
Expand Down Expand Up @@ -980,14 +982,16 @@ mod decl {
fn __setstate__(zelf: PyRef<Self>, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> {
let args = state.as_slice();
if args.len() != 1 {
let msg = format!("function takes exactly 1 argument ({} given)", args.len());
return Err(vm.new_type_error(msg));
return Err(vm.new_type_error(format!(
"function takes exactly 1 argument ({} given)",
args.len()
)));
}
let cur = &args[0];
if let Ok(cur) = cur.try_to_value(vm) {
zelf.cur.store(cur);
} else {
return Err(vm.new_type_error(String::from("Argument must be usize.")));
return Err(vm.new_type_error("Argument must be usize."));
}
Ok(())
}
Expand Down Expand Up @@ -1049,7 +1053,7 @@ mod decl {
}: Self::Args,
vm: &VirtualMachine,
) -> PyResult {
PyItertoolsFilterFalse {
Self {
predicate,
iterable,
}
Expand Down Expand Up @@ -1117,7 +1121,7 @@ mod decl {
type Args = AccumulateArgs;

fn py_new(cls: PyTypeRef, args: AccumulateArgs, vm: &VirtualMachine) -> PyResult {
PyItertoolsAccumulate {
Self {
iterable: args.iterable,
bin_op: args.func.flatten(),
initial: args.initial.flatten(),
Expand Down Expand Up @@ -1226,8 +1230,8 @@ mod decl {
}

impl PyItertoolsTeeData {
fn new(iterable: PyIter, _vm: &VirtualMachine) -> PyResult<PyRc<PyItertoolsTeeData>> {
Ok(PyRc::new(PyItertoolsTeeData {
fn new(iterable: PyIter, _vm: &VirtualMachine) -> PyResult<PyRc<Self>> {
Ok(PyRc::new(Self {
iterable,
values: PyRwLock::new(vec![]),
}))
Expand Down Expand Up @@ -1295,7 +1299,7 @@ mod decl {
if iterator.class().is(PyItertoolsTee::class(&vm.ctx)) {
return vm.call_special_method(&iterator, identifier!(vm, __copy__), ());
}
Ok(PyItertoolsTee {
Ok(Self {
tee_data: PyItertoolsTeeData::new(iterator, vm)?,
index: AtomicCell::new(0),
}
Expand Down Expand Up @@ -1354,7 +1358,7 @@ mod decl {

let l = pools.len();

PyItertoolsProduct {
Self {
pools,
idxs: PyRwLock::new(vec![0; l]),
cur: AtomicCell::new(l.wrapping_sub(1)),
Expand Down Expand Up @@ -1394,8 +1398,7 @@ mod decl {
fn __setstate__(zelf: PyRef<Self>, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> {
let args = state.as_slice();
if args.len() != zelf.pools.len() {
let msg = "Invalid number of arguments".to_string();
return Err(vm.new_type_error(msg));
return Err(vm.new_type_error("Invalid number of arguments"));
}
let mut idxs: PyRwLockWriteGuard<'_, Vec<usize>> = zelf.idxs.write();
idxs.clear();
Expand Down Expand Up @@ -1642,7 +1645,7 @@ mod decl {

let n = pool.len();

PyItertoolsCombinationsWithReplacement {
Self {
pool,
indices: PyRwLock::new(vec![0; r]),
r: AtomicCell::new(r),
Expand Down Expand Up @@ -1860,7 +1863,7 @@ mod decl {
fn py_new(cls: PyTypeRef, (iterators, args): Self::Args, vm: &VirtualMachine) -> PyResult {
let fillvalue = args.fillvalue.unwrap_or_none(vm);
let iterators = iterators.into_vec();
PyItertoolsZipLongest {
Self {
iterators,
fillvalue: PyRwLock::new(fillvalue),
}
Expand Down Expand Up @@ -1943,7 +1946,7 @@ mod decl {
type Args = PyIter;

fn py_new(cls: PyTypeRef, iterator: Self::Args, vm: &VirtualMachine) -> PyResult {
PyItertoolsPairwise {
Self {
iterator,
old: PyRwLock::new(None),
}
Expand Down
0