diff --git a/tests/snippets/membership.py b/tests/snippets/membership.py index a944c45398..2987c3c0fe 100644 --- a/tests/snippets/membership.py +++ b/tests/snippets/membership.py @@ -12,9 +12,8 @@ assert "whatever" not in "foobar" # test bytes -# TODO: uncomment this when bytes are implemented -# assert b"foo" in b"foobar" -# assert b"whatever" not in b"foobar" +assert b"foo" in b"foobar" +assert b"whatever" not in b"foobar" assert b"1" < b"2" assert b"1" <= b"2" assert b"5" <= b"5" @@ -32,18 +31,20 @@ assert 3 not in set([1, 2]) # test dicts -# TODO: test dicts when keys other than strings will be allowed assert "a" in {"a": 0, "b": 0} assert "c" not in {"a": 0, "b": 0} +assert 1 in {1: 5, 7: 12} +assert 5 not in {9: 10, 50: 100} +assert True in {True: 5} +assert False not in {True: 5} # test iter assert 3 in iter([1, 2, 3]) assert 3 not in iter([1, 2]) # test sequence -# TODO: uncomment this when ranges are usable -# assert 1 in range(0, 2) -# assert 3 not in range(0, 2) +assert 1 in range(0, 2) +assert 3 not in range(0, 2) # test __contains__ in user objects class MyNotContainingClass(): diff --git a/tests/snippets/strings.py b/tests/snippets/strings.py index 0baa35d23c..28df3f47f4 100644 --- a/tests/snippets/strings.py +++ b/tests/snippets/strings.py @@ -29,6 +29,14 @@ assert str(["a", "b", "can't"]) == "['a', 'b', \"can't\"]" +assert "xy" * 3 == "xyxyxy" +assert "x" * 0 == "" +assert "x" * -1 == "" + +assert 3 * "xy" == "xyxyxy" +assert 0 * "x" == "" +assert -1 * "x" == "" + a = 'Hallo' assert a.lower() == 'hallo' assert a.upper() == 'HALLO' @@ -94,6 +102,7 @@ ] +# requires CPython 3.7, and the CI currently runs with 3.6 # assert c.isascii() assert c.index('a') == 1 assert c.rindex('l') == 3 diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index 22eda1f34b..e5e2afe370 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -183,7 +183,12 @@ impl PyString { if objtype::isinstance(&val, &vm.ctx.int_type()) { let value = &self.value; let multiplier = objint::get_value(&val).to_i32().unwrap(); - let mut result = String::new(); + let capacity = if multiplier > 0 { + multiplier.to_usize().unwrap() * value.len() + } else { + 0 + }; + let mut result = String::with_capacity(capacity); for _x in 0..multiplier { result.push_str(value.as_str()); } @@ -193,6 +198,11 @@ impl PyString { } } + #[pymethod(name = "__rmul__")] + fn rmul(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.mul(val, vm) + } + #[pymethod(name = "__str__")] fn str(zelf: PyRef, _vm: &VirtualMachine) -> PyStringRef { zelf @@ -206,7 +216,7 @@ impl PyString { } else { '\'' }; - let mut formatted = String::new(); + let mut formatted = String::with_capacity(value.len()); formatted.push(quote_char); for c in value.chars() { if c == quote_char || c == '\\' { @@ -799,7 +809,7 @@ impl PyString { #[pymethod] fn expandtabs(&self, tab_stop: OptionalArg, _vm: &VirtualMachine) -> String { let tab_stop = tab_stop.into_option().unwrap_or(8 as usize); - let mut expanded_str = String::new(); + let mut expanded_str = String::with_capacity(self.value.len()); let mut tab_size = tab_stop; let mut col_count = 0 as usize; for ch in self.value.chars() {