8000 Added support for type guards of the form `x[I] is B` and `x[I] is no… · codemuse-app/scip-python@f1db884 · GitHub
[go: up one dir, main page]

Skip to content

Commit f1db884

Browse files
committed
Added support for type guards of the form x[I] is B and x[I] is not B where x is a tuple and B is a boolean literal True or False. This addresses part of microsoft/pyright#4875.
1 parent ee851cf commit f1db884

File tree

3 files changed

+53
-10
lines changed

3 files changed

+53
-10
lines changed

docs/type-concepts-advanced.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ In addition to assignment-based type narrowing, Pyright supports the following t
6868
* `x.y == L` and `x.y != L` (where L is a literal expression and x is a type that is distinguished by a field or property with a literal type)
6969
* `x[K] == V`, `x[K] != V`, `x[K] is V`, and `x[K] is not V` (where K and V are literal expressions and x is a type that is distinguished by a TypedDict field with a literal type)
7070
* `x[I] == V` and `x[I] != V` (where I and V are literal expressions and x is a known-length tuple that is distinguished by the index indicated by I)
71+
* `x[I] is B` and `x[I] is not B` (where I is a literal expression, B is a `bool` literal, and x is a known-length tuple that is distinguished by the index indicated by I)
7172
* `x[I] is None` and `x[I] is not None` (where I is a literal expression and x is a known-length tuple that is distinguished by the index indicated by I)
7273
* `len(x) == L` and `len(x) != L` (where x is tuple and L is a literal integer)
7374
* `x in y` or `x not in y` (where y is instance of list, set, frozenset, deque, tuple, dict, defaultdict, or OrderedDict)

packages/pyright-internal/src/analyzer/typeGuards.ts

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,28 @@ export function getTypeNarrowingCallback(
277277
};
278278
};
279279
}
280+
} else if (ClassType.isBuiltIn(indexType, 'int')) {
281+
const rightTypeResult = evaluator.getTypeOfExpression(testExpression.rightExpression);
282+
const rightType = rightTypeResult.type;
283+
284+
if (
285+
isClassInstance(rightType) &&
286+
ClassType.isBuiltIn(rightType, 'bool') &&
287+
rightType.literalValue !== undefined
288+
) {
289+
return (type: Type) => {
290+
return {
291+
type: narrowTypeForDiscriminatedTupleComparison(
292+
evaluator,
293+
type,
294+
indexType,
295+
rightType,
296+
adjIsPositiveTest
297+
),
298+
isIncomplete: !!rightTypeResult.isIncomplete,
299+
};
300+
};
301+
}
280302
}
281303
}
282304
}
Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,43 @@
11
# This sample tests the type narrowing for known-length tuples
22
# that have an entry with a declared literal type.
33

4-
from typing import Tuple, Union, Literal
4+
from typing import Literal
55

6-
MsgA = Tuple[Literal[1], str]
7-
MsgB = Tuple[Literal[2], float]
6+
MsgA = tuple[Literal[1], str]
7+
MsgB = tuple[Literal[2], float]
88

9-
Msg = Union[MsgA, MsgB]
9+
MsgAOrB = MsgA | MsgB
1010

1111

12-
def func1(m: Msg):
12+
def func1(m: MsgAOrB):
1313
if m[0] == 1:
14-
reveal_type(m, expected_text="Tuple[Literal[1], str]")
14+
reveal_type(m, expected_text="tuple[Literal[1], str]")
1515
else:
16-
reveal_type(m, expected_text="Tuple[Literal[2], float]")
16+
reveal_type(m, expected_text="tuple[Literal[2], float]")
1717

1818

19-
def func2(m: Msg):
19+
def func2(m: MsgAOrB):
2020
if m[0] != 1:
21-
reveal_type(m, expected_text="Tuple[Literal[2], float]")
21+
reveal_type(m, expected_text="tuple[Literal[2], float]")
2222
else:
23-
reveal_type(m, expected_text="Tuple[Literal[1], str]")
23+
reveal_type(m, expected_text="tuple[Literal[1], str]")
24+
25+
26+
MsgC = tuple[Literal[True], str]
27+
MsgD = tuple[Literal[False], float]
28+
29+
MsgCOrD = MsgC | MsgD
30+
31+
32+
def func3(m: MsgCOrD):
33+
if m[0] is True:
34+
reveal_type(m, expected_text="tuple[Literal[True], str]")
35+
else:
36+
reveal_type(m, expected_text="tuple[Literal[False], float]")
37+
38+
39+
def func4(m: MsgCOrD):
40+
if m[0] is not True:
41+
reveal_type(m, expected_text="tuple[Literal[False], float]")
42+
else:
43+
reveal_type(m, expected_text="tuple[Literal[True], str]")

0 commit comments

Comments
 (0)
0