Merge pull request #1666 from RustPython/coolreader18/length_hint · RustPython/RustPython@896f3c4 · GitHub
[go: up one dir, main page]

Skip to content

Commit 896f3c4

Browse files
authored
Merge pull request #1666 from RustPython/coolreader18/length_hint
Add __length_hint__ support for iterators
2 parents 7d40f45 + a26e5f2 commit 896f3c4

File tree

13 files changed

+194
-40
lines changed

13 files changed

+194
-40
lines changed

tests/snippets/dict.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,16 @@ def __eq__(self, other):
254254
assert not {}.__ne__({})
255255
assert {}.__ne__({'a':'b'})
256256
assert {}.__ne__(1) == NotImplemented
257+
258+
it = iter({0: 1, 2: 3, 4:5, 6:7})
259+
assert it.__length_hint__() == 4
260+
next(it)
261+
assert it.__length_hint__() == 3
262+
next(it)
263+
assert it.__length_hint__() == 2
264+
next(it)
265+
assert it.__length_hint__() == 1
266+
next(it)
267+
assert it.__length_hint__() == 0
268+
assert_raises(StopIteration, next, it)
269+
assert it.__length_hint__() == 0

vm/src/builtins.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use crate::obj::objdict::PyDictRef;
2222
use crate::obj::objfunction::PyFunctionRef;
2323
use crate::obj::objint::{self, PyIntRef};
2424
use crate::obj::objiter;
25+
use crate::obj::objsequence;
2526
use crate::obj::objstr::{PyString, PyStringRef};
2627
use crate::obj::objtype::{self, PyClassRef};
2728
use crate::pyhash;
@@ -384,11 +385,8 @@ fn builtin_iter(iter_target: PyObjectRef, vm: &VirtualMachine) -> PyResult {
384385
objiter::get_iter(vm, &iter_target)
385386
}
386387

387-
fn builtin_len(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult {
388-
let method = vm.get_method_or_type_error(obj.clone(), "__len__", || {
389-
format!("object of type '{}' has no len()", obj.class().name)
390-
})?;
391-
vm.invoke(&method, PyFuncArgs::default())
388+
fn builtin_len(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
389+
objsequence::len(&obj, vm)
392390
}
393391

394392
fn builtin_locals(vm: &VirtualMachine) -> PyDictRef {

vm/src/dictdatatype.rs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -204,26 +204,26 @@ impl<T: Clone> Dict<T> {
204204
}
205205

206206
pub fn next_entry(&self, position: &mut EntryIndex) -> Option<(&PyObjectRef, &T)> {
207-
while *position < self.entries.len() {
208-
if let Some(DictEntry { key, value, .. }) = &self.entries[*position] {
209-
*position += 1;
210-
return Some((key, value));
211-
}
207+
self.entries[*position..].iter().find_map(|entry| {
212208
*position += 1;
213-
}
214-
None
209+
entry
210+
.as_ref()
211+
.map(|DictEntry { key, value, .. }| (key, value))
212+
})
213+
}
214+
215+
pub fn len_from_entry_index(&self, position: EntryIndex) -> usize {
216+
self.entries[position..].iter().flatten().count()
215217
}
216218

217219
pub fn has_changed_size(&self, position: &DictSize) -> bool {
218220
position.size != self.size || self.entries.len() != position.entries_size
219221
}
220222

221-
pub fn keys<'a>(&'a self) -> Box<dyn Iterator<Item = PyObjectRef> + 'a> {
222-
Box::new(
223-
self.entries
224-
.iter()
225-
.filter_map(|v| v.as_ref().map(|v| v.key.clone())),
226-
)
223+
pub fn keys<'a>(&'a self) -> impl Iterator<Item = PyObjectRef> + 'a {
224+
self.entries
225+
.iter()
226+
.filter_map(|v| v.as_ref().map(|v| v.key.clone()))
227227
}
228228

229229
/// Lookup the index for the given key.

vm/src/exceptions.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,6 @@ pub fn write_exception_inner<W: Write>(
250250
write_traceback_entry(output, traceback)?;
251251
tb = traceback.next.as_ref();
252252
}
253-
} else {
254-
writeln!(output, "No traceback set on exception")?;
255253
}
256254

257255
let varargs = exc.args();

vm/src/obj/objdict.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,14 @@ macro_rules! dict_iterator {
545545
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
546546
zelf
547547
}
548+
549+
#[pymethod(name = "__length_hint__")]
550+
fn length_hint(&self, _vm: &VirtualMachine) -> usize {
551+
self.dict
552+
.entries
553+
.borrow()
554+
.len_from_entry_index(self.position.get())
555+
}
548556
}
549557

550558
impl PyValue for $iter_name {

vm/src/obj/objiter.rs

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
* Various types to support iteration.
33
*/
44

5+
use num_traits::{Signed, ToPrimitive};
56
use std::cell::Cell;
67

8+
use super::objint::PyInt;
9+
use super::objsequence;
710
use super::objtype::{self, PyClassRef};
811
use crate::exceptions::PyBaseExceptionRef;
912
use crate::pyobject::{
@@ -61,10 +64,12 @@ pub fn get_next_object(
6164

6265
/* Retrieve all elements from an iterator */
6366
pub fn get_all<T: TryFromObject>(vm: &VirtualMachine, iter_obj: &PyObjectRef) -> PyResult<Vec<T>> {
64-
let mut elements = vec![];
67+
let cap = length_hint(vm, iter_obj.clone())?.unwrap_or(0);
68+
let mut elements = Vec::with_capacity(cap);
6569
while let Some(element) = get_next_object(vm, iter_obj)? {
6670
elements.push(T::try_from_object(vm, element)?);
6771
}
72+
elements.shrink_to_fit();
6873
Ok(elements)
6974
}
7075

@@ -83,6 +88,49 @@ pub fn stop_iter_value(vm: &VirtualMachine, exc: &PyBaseExceptionRef) -> PyResul
8388
Ok(val)
8489
}
8590

91+
pub fn length_hint(vm: &VirtualMachine, iter: PyObjectRef) -> PyResult<Option<usize>> {
92+
if let Some(len) = objsequence::opt_len(&iter, vm) {
93+
match len {
94+
Ok(len) => return Ok(Some(len)),
95+
Err(e) => {
96+
if !objtype::isinstance(&e, &vm.ctx.exceptions.type_error) {
97+
return Err(e);
98+
}
99+
}
100+
}
101+
}
102+
let hint = match vm.get_method(iter, "__length_hint__") {
103+
Some(hint) => hint?,
104+
None => return Ok(None),
105+
};
106+
let result = match vm.invoke(&hint, vec![]) {
107+
Ok(res) => res,
108+
Err(e) => {
109+
if objtype::isinstance(&e, &vm.ctx.exceptions.type_error) {
110+
return Ok(None);
111+
} else {
112+
return Err(e);
113+
}
114+
}
115+
};
116+
let result = result
117+
.payload_if_subclass::<PyInt>(vm)
118+
.ok_or_else(|| {
119+
vm.new_type_error(format!(
120+
"'{}' object cannot be interpreted as an integer",
121+
result.class().name
122+
))
123+
})?
124+
.as_bigint();
125+
if result.is_negative() {
126+
return Err(vm.new_value_error("__length_hint__() should return >= 0".to_string()));
127+
}
128+
let hint = result.to_usize().ok_or_else(|| {
129+
vm.new_value_error("Python int too large to convert to Rust usize".to_string())
130+
})?;
131+
Ok(Some(hint))
132+
}
133+
86134
#[pyclass]
87135
#[derive(Debug)]
88136
pub struct PySequenceIterator {
@@ -124,6 +172,20 @@ impl PySequenceIterator {
124172
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
125173
zelf
126174
}
175+
176+
#[pymethod(name = "__length_hint__")]
177+
fn length_hint(&self, vm: &VirtualMachine) -> PyResult<isize> {
178+
let pos = self.position.get();
179+
let hint = if self.reversed {
180+
pos + 1
181+
} else {
182+
let len = objsequence::opt_len(&self.obj, vm).unwrap_or_else(|| {
183+
Err(vm.new_type_error("sequence has no __len__ method".to_string()))
184+
})?;
185+
len as isize - pos
186+
};
187+
Ok(hint)
188+
}
127189
}
128190

129191
pub fn seq_iter_method(obj: PyObjectRef, _vm: &VirtualMachine) -> PySequenceIterator {

vm/src/obj/objlist.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -873,6 +873,11 @@ impl PyListIterator {
873873
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
874874
zelf
875875
}
876+
877+
#[pymethod(name = "__length_hint__")]
878+
fn length_hint(&self, _vm: &VirtualMachine) -> usize {
879+
self.list.elements.borrow().len() - self.position.get()
880+
}
876881
}
877882

878883
#[pyclass]
@@ -906,6 +911,11 @@ impl PyListReverseIterator {
906911
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
907912
zelf
908913
}
914+
915+
#[pymethod(name = "__length_hint__")]
916+
fn length_hint(&self, _vm: &VirtualMachine) -> usize {
917+
self.position.get()
918+
}
909919
}
910920

911921
pub fn init(context: &PyContext) {

vm/src/obj/objmap.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,15 @@ impl PyMap {
5858
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
5959
zelf
6060
}
61+
62+
#[pymethod(name = "__length_hint__")]
63+
fn length_hint(&self, vm: &VirtualMachine) -> PyResult<usize> {
64+
self.iterators.iter().try_fold(0, |prev, cur| {
65+
let cur = objiter::length_hint(vm, cur.clone())?.unwrap_or(0);
66+
let max = std::cmp::max(prev, cur);
67+
Ok(max)
68+
})
69+
}
6170
}
6271

6372
pub fn init(context: &PyContext) {

vm/src/obj/objrange.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,8 @@ impl PyRange {
210210
}
211211

212212
#[pymethod(name = "__len__")]
213-
fn len(&self, _vm: &VirtualMachine) -> PyInt {
214-
PyInt::new(self.length())
213+
fn len(&self, _vm: &VirtualMachine) -> BigInt {
214+
self.length()
215215
}
216216

217217
#[pymethod(name = "__repr__")]

vm/src/obj/objsequence.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,33 @@ pub fn is_valid_slice_arg(
263263
Ok(None)
264264
}
265265
}
266+
267+
pub fn opt_len(obj: &PyObjectRef, vm: &VirtualMachine) -> Option<PyResult<usize>> {
268+
vm.get_method(obj.clone(), "__len__").map(|len| {
269+
let len = vm.invoke(&len?, vec![])?;
270+
let len = len
271+
.payload_if_subclass::<PyInt>(vm)
272+
.ok_or_else(|| {
273+
vm.new_type_error(format!(
274+
"'{}' object cannot be interpreted as an integer",
275+
len.class().name
276+
))
277+
})?
278+
.as_bigint();
279+
if len.is_negative() {
280+
return Err(vm.new_value_error("__len__() should return >= 0".to_string()));
281+
}
282+
len.to_usize().ok_or_else(|| {
283+
vm.new_overflow_error("cannot fit __len__() result into usize".to_string())
284+
})
285+
})
286+
}
287+
288+
pub fn len(obj: &PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
289+
opt_len(obj, vm).unwrap_or_else(|| {
290+
Err(vm.new_type_error(format!(
291+
"object of type '{}' has no len()",
292+
obj.class().name
293+
)))
294+
})
295+
}

vm/src/stdlib/itertools.rs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
11
use std::cell::{Cell, RefCell};
2-
use std::cmp::Ordering;
32
use std::iter;
4-
use std::ops::{AddAssign, SubAssign};
53
use std::rc::Rc;
64

75
use num_bigint::BigInt;
8-
use num_traits::sign::Signed;
9-
use num_traits::ToPrimitive;
6+
use num_traits::{One, Signed, ToPrimitive, Zero};
107

118
use crate::function::{Args, OptionalArg, OptionalOption, PyFuncArgs};
129
use crate::obj::objbool;
@@ -166,11 +163,11 @@ impl PyItertoolsCount {
166163
) -> PyResult<PyRef<Self>> {
167164
let start = match start.into_option() {
168165
Some(int) => int.as_bigint().clone(),
169-
None => BigInt::from(0),
166+
None => BigInt::zero(),
170167
};
171168
let step = match step.into_option() {
172169
Some(int) => int.as_bigint().clone(),
173-
None => BigInt::from(1),
170+
None => BigInt::one(),
174171
};
175172

176173
PyItertoolsCount {
@@ -183,7 +180,7 @@ impl PyItertoolsCount {
183180
#[pymethod(name = "__next__")]
184181
fn next(&self, _vm: &VirtualMachine) -> PyResult<PyInt> {
185182
let result = self.cur.borrow().clone();
186-
AddAssign::add_assign(&mut self.cur.borrow_mut() as &mut BigInt, &self.step);
183+
*self.cur.borrow_mut() += &self.step;
187184
Ok(PyInt::new(result))
188185
}
189186

@@ -296,16 +293,11 @@ impl PyItertoolsRepeat {
296293

297294
#[pymethod(name = "__next__")]
298295
fn next(&self, vm: &VirtualMachine) -> PyResult {
299-
if self.times.is_some() {
300-
match self.times.as_ref().unwrap().borrow().cmp(&BigInt::from(0)) {
301-
Ordering::Less | Ordering::Equal => return Err(new_stop_iteration(vm)),
302-
_ => (),
303-
};
304-
305-
SubAssign::sub_assign(
306-
&mut self.times.as_ref().unwrap().borrow_mut() as &mut BigInt,
307-
&BigInt::from(1),
308-
);
296+
if let Some(ref times) = self.times {
297+
if *times.borrow() <= BigInt::zero() {
298+
return Err(new_stop_iteration(vm));
299+
}
300+
*times.borrow_mut() -= 1;
309301
}
310302

311303
Ok(self.object.clone())
@@ -315,6 +307,14 @@ impl PyItertoolsRepeat {
315307
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
316308
zelf
317309
}
310+
311+
#[pymethod(name = "__length_hint__")]
312+
fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef {
313+
match self.times {
314+
Some(ref times) => vm.new_int(times.borrow().clone()),
315+
None => vm.new_int(0),
316+
}
317+
}
318318
}
319319

320320
#[pyclass(name = "starmap")]

vm/src/stdlib/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ mod json;
1616
mod keyword;
1717
mod marshal;
1818
mod math;
19+
mod operator;
1920
mod platform;
2021
mod pystruct;
2122
mod random;
@@ -74,6 +75,7 @@ pub fn get_module_inits() -> HashMap<String, StdlibInitFunc> {
7475
"json".to_string() => Box::new(json::make_module),
7576
"marshal".to_string() => Box::new(marshal::make_module),
7677
"math".to_string() => Box::new(math::make_module),
78+
"_operator".to_string() => Box::new(operator::make_module),
7779
"platform".to_string() => Box::new(platform::make_module),
7880
"regex_crate".to_string() => Box::new(re::make_module),
7981
"random".to_string() => Box::new(random::make_module),

vm/src/stdlib/operator.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
use crate::function::OptionalArg;
2+
use crate::obj::{objiter, objtype};
3+
use crate::pyobject::{PyObjectRef, PyResult, TypeProtocol};
4+
use crate::VirtualMachine;
5+
6+
fn operator_length_hint(obj: PyObjectRef, default: OptionalArg, vm: &VirtualMachine) -> PyResult {
7+
let default = default.unwrap_or_else(|| vm.new_int(0));
8+
if !objtype::isinstance(&default, &vm.ctx.types.int_type) {
9+
return Err(vm.new_type_error(format!(
10+
"'{}' type cannot be interpreted as an integer",
11+
default.class().name
12+
)));
13+
}
14+
let hint = objiter::length_hint(vm, obj)?
15+
.map(|i| vm.new_int(i))
16+
.unwrap_or(default);
17+
Ok(hint)
18+
}
19+
20+
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
21+
py_module!(vm, "_operator", {
22+
"length_hint" => vm.ctx.new_rustfunc(operator_length_hint),
23+
})
24+
}

0 commit comments

Comments
 (0)
0