[mypyc] Add bytes concat op (#10926) · python/mypy@e7161ac · GitHub
[go: up one dir, main page]

Skip to content

Commit e7161ac

Browse files
[mypyc] Add bytes concat op (#10926)
1 parent daed963 commit e7161ac

File tree

8 files changed

+115
-46
lines changed

8 files changed

+115
-46
lines changed

mypyc/lib-rt/CPy.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,7 @@ Py_ssize_t CPyStr_Size_size_t(PyObject *str);
400400
// Bytes operations
401401

402402

403+
PyObject *CPyBytes_Concat(PyObject *a, PyObject *b);
403404
PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter);
404405

405406

mypyc/lib-rt/bytes_ops.c

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,24 @@
55
#include <Python.h>
66
#include "CPy.h"
77

8+
PyObject *CPyBytes_Concat(PyObject *a, PyObject *b) {
9+
if (PyBytes_Check(a) && PyBytes_Check(b)) {
10+
Py_ssize_t a_len = ((PyVarObject *)a)->ob_size;
11+
Py_ssize_t b_len = ((PyVarObject *)b)->ob_size;
12+
PyBytesObject *ret = (PyBytesObject *)PyBytes_FromStringAndSize(NULL, a_len + b_len);
13+
if (ret != NULL) {
14+
memcpy(ret->ob_sval, ((PyBytesObject *)a)->ob_sval, a_len);
15+
memcpy(ret->ob_sval + a_len, ((PyBytesObject *)b)->ob_sval, b_len);
16+
}
17+
return (PyObject *)ret;
18+
} else if (PyByteArray_Check(a)) {
19+
return PyByteArray_Concat(a, b);
20+
} else {
21+
PyBytes_Concat(&a, b);
22+
return a;
23+
}
24+
}
25+
826
// Like _PyBytes_Join but fallback to dynamic call if 'sep' is not bytes
927
// (most commonly, for bytearrays)
1028
PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter) {

mypyc/primitives/bytes_ops.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
object_rprimitive, bytes_rprimitive, list_rprimitive, dict_rprimitive,
66
str_rprimitive, RUnion
77
)
8-
from mypyc.primitives.registry import load_address_op, function_op, method_op
9-
8+
from mypyc.primitives.registry import (
9+
load_address_op, function_op, method_op, binary_op
10+
)
1011

1112
# Get the 'bytes' type object.
1213
load_address_op(
@@ -30,6 +31,16 @@
3031
c_function_name='PyByteArray_FromObject',
3132
error_kind=ERR_MAGIC)
3233

34+
# bytes + bytes
35+
# bytearray + bytearray
36+
binary_op(
37+
name='+',
38+
arg_types=[bytes_rprimitive, bytes_rprimitive],
39+
return_type=bytes_rprimitive,
40+
c_function_name='CPyBytes_Concat',
41+
error_kind=ERR_MAGIC,
42+
steals=[True, False])
43+
3344
# bytes.join(obj)
3445
method_op(
3546
name='join',

mypyc/primitives/list_ops.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -179,11 +179,12 @@
179179
error_kind=ERR_MAGIC)
180180

181181
# int * list
182-
binary_op(name='*',
183-
arg_types=[int_rprimitive, list_rprimitive],
184-
return_type=list_rprimitive,
185-
c_function_name='CPySequence_RMultiply',
186-
error_kind=ERR_MAGIC)
182+
binary_op(
183+
name='*',
184+
arg_types=[int_rprimitive, list_rprimitive],
185+
return_type=list_rprimitive,
186+
c_function_name='CPySequence_RMultiply',
187+
error_kind=ERR_MAGIC)
187188

188189
# list[begin:end]
189190
list_slice_op = custom_op(

mypyc/primitives/str_ops.py

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,46 @@
2929
error_kind=ERR_MAGIC)
3030

3131
# str1 + str2
32-
binary_op(name='+',
33-
arg_types=[str_rprimitive, str_rprimitive],
34-
return_type=str_rprimitive,
35-
c_function_name='PyUnicode_Concat',
36-
error_kind=ERR_MAGIC)
32+
binary_op(
33+
name='+',
34+
arg_types=[str_rprimitive, str_rprimitive],
35+
return_type=str_rprimitive,
36+
c_function_name='PyUnicode_Concat',
37+
error_kind=ERR_MAGIC)
38+
39+
# str1 += str2
40+
#
41+
# PyUnicode_Append makes an effort to reuse the LHS when the refcount
42+
# is 1. This is super dodgy but oh well, the interpreter does it.
43+
binary_op(
44+
name='+=',
45+
arg_types=[str_rprimitive, str_rprimitive],
46+
return_type=str_rprimitive,
47+
c_function_name='CPyStr_Append',
48+
error_kind=ERR_MAGIC,
49+
steals=[True, False])
50+
51+
unicode_compare = custom_op(
52+
arg_types=[str_rprimitive, str_rprimitive],
53+
return_type=c_int_rprimitive,
54+
c_function_name='PyUnicode_Compare',
55+
error_kind=ERR_NEVER)
56+
57+
# str[index] (for an int index)
58+
method_op(
59+
name='__getitem__',
60+
arg_types=[str_rprimitive, int_rprimitive],
61+
return_type=str_rprimitive,
62+
c_function_name='CPyStr_GetItem',
63+
error_kind=ERR_MAGIC
64+
)
65+
66+
# str[begin:end]
67+
str_slice_op = custom_op(
68+
arg_types=[str_rprimitive, int_rprimitive, int_rprimitive],
69+
return_type=object_rprimitive,
70+
c_function_name='CPyStr_GetSlice',
71+
error_kind=ERR_MAGIC)
3772

3873
# str.join(obj)
3974
method_op(
@@ -70,15 +105,6 @@
70105
error_kind=ERR_NEVER
71106
)
72107

73-
# str[index] (for an int index)
74-
method_op(
75-
name='__getitem__',
76-
arg_types=[str_rprimitive, int_rprimitive],
77-
return_type=str_rprimitive,
78-
c_function_name='CPyStr_GetItem',
79-
error_kind=ERR_MAGIC
80-
)
81-
82108
# str.split(...)
83109
str_split_types: List[RType] = [str_rprimitive, str_rprimitive, int_rprimitive]
84110
str_split_functions = ["PyUnicode_Split", "PyUnicode_Split", "CPyStr_Split"]
@@ -96,30 +122,6 @@
96122
extra_int_constants=str_split_constants[i],
97123
error_kind=ERR_MAGIC)
98124

99-
# str1 += str2
100-
#
101-
# PyUnicode_Append makes an effort to reuse the LHS when the refcount
102-
# is 1. This is super dodgy but oh well, the interpreter does it.
103-
binary_op(name='+=',
104-
arg_types=[str_rprimitive, str_rprimitive],
105-
return_type=str_rprimitive,
106-
c_function_name='CPyStr_Append',
107-
error_kind=ERR_MAGIC,
108-
steals=[True, False])
109-
110-
unicode_compare = custom_op(
111-
arg_types=[str_rprimitive, str_rprimitive],
112-
return_type=c_int_rprimitive,
113-
c_function_name='PyUnicode_Compare',
114-
error_kind=ERR_NEVER)
115-
116-
# str[begin:end]
117-
str_slice_op = custom_op(
118-
arg_types=[str_rprimitive, int_rprimitive, int_rprimitive],
119-
return_type=object_rprimitive,
120-
c_function_name='CPyStr_GetSlice',
121-
error_kind=ERR_MAGIC)
122-
123125
# str.replace(old, new)
124126
method_op(
125127
name='replace',

mypyc/test-data/fixtures/ir.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def __init__(self) -> None: pass
115115
def __init__(self, x: object) -> None: pass
116116
@overload
117117
def __init__(self, string: str, encoding: str, err: str = ...) -> None: pass
118+
def __add__(self, s: bytes) -> bytearray: ...
118119

119120
class bool(int):
120121
def __init__(self, o: object = ...) -> None: ...

mypyc/test-data/irbuild-bytes.test

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,16 @@ L0:
6262
c = r6
6363
return 1
6464

65+
[case testBytesConcat]
66+
def f1(a: bytes, b: bytes) -> bytes:
67+
return a + b
68+
[out]
69+
def f1(a, b):
70+
a, b, r0 :: bytes
71+
L0:
72+
r0 = CPyBytes_Concat(a, b)
73+
return r0
74+
6575
[case testBytesJoin]
6676
from typing import List
6777
def f(b: List[bytes]) -> bytes:
@@ -90,4 +100,3 @@ L0:
90100
keep_alive b
91101
r2 = r1 << 1
92102
return r2
93-

mypyc/test-data/run-bytes.test

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,32 @@ def test_concat() -> None:
5858
b1 = b'123' + bytes()
5959
b2 = b'456' + bytes()
6060
assert b1 + b2 == b'123456'
61+
b3 = b1 + b2
62+
b3 = b3 + b1
63+
assert b3 == b'123456123'
64+
assert b1 == b'123'
65+
assert b2 == b'456'
66+
assert type(b1) == bytes
67+
assert type(b2) == bytes
68+
assert type(b3) == bytes
69+
brr1: bytes = bytearray(3)
70+
brr2: bytes = bytearray(range(5))
71+
b4 = b1 + brr1
72+
assert b4 == b'123\x00\x00\x00'
73+
assert type(brr1) == bytearray
74+
assert type(b4) == bytes
75+
brr3 = brr1 + brr2
76+
assert brr3 == bytearray(b'\x00\x00\x00\x00\x01\x02\x03\x04')
77+
assert len(brr3) == 8
78+
assert type(brr3) == bytearray
79+
brr3 = brr3 + bytearray([10])
80+
assert brr3 == bytearray(b'\x00\x00\x00\x00\x01\x02\x03\x04\n')
81+
b5 = brr2 + b2
82+
assert b5 == bytearray(b'\x00\x01\x02\x03\x04456')
83+
assert type(b5) == bytearray
84+
b5 = b2 + brr2
85+
assert b5 == b'456\x00\x01\x02\x03\x04'
86+
assert type(b5) == bytes
6187

6288
def test_join() -> None:
6389
seq = (b'1', b'"', b'\xf0')

0 commit comments

Comments
 (0)
0