8000 Fixed a bug that led to a false positive when a generic function was … · sourcegraph/scip-python@15b47ae · GitHub
[go: up one dir, main page]

Skip to content

Commit 15b47ae

Browse files
committed
Fixed a bug that led to a false positive when a generic function was passed as an argument to another generic function multiple times. In such a case, the second (and subsequent) instances of the function must be given unique type parameters so they are distinguished from other instances of the same function. This addresses microsoft/pyright#4852.
1 parent 523762d commit 15b47ae

File tree

4 files changed

+168
-7
lines changed

4 files changed

+168
-7
lines changed

packages/pyright-internal/src/analyzer/typeEvaluator.ts

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ import {
258258
convertTypeToParamSpecValue,
259259
derivesFromClassRecursive,
260260
doForEachSubtype,
261+
ensureFunctionSignaturesAreUnique,
261262
explodeGenericClass,
262263
getContainerDepth,
263264
getDeclaredGeneratorReturnType,
@@ -305,6 +306,7 @@ import {
305306
specializeTupleClass,
306307
synthesizeTypeVarForSelfCls,
307308
transformPossibleRecursiveTypeAlias,
309+
UniqueSignatureTracker,
308310
validateTypeVarDefault,
309311
} from './typeUtils';
310312
import { TypeVarContext, TypeVarSignatureContext } from './typeVarContext';
@@ -10877,6 +10879,8 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
1087710879
// where more than two passes are needed.
1087810880
let passCount = Math.min(typeVarMatchingCount, 2);
1087910881
for (let i = 0; i < passCount; i++) {
10882+
const signatureTracker = new UniqueSignatureTracker();
10883+
1088010884
useSpeculativeMode(errorNode, () => {
1088110885
matchResults.argParams.forEach((argParam) => {
1088210886
if (!argParam.requiresTypeVarMatching) {
@@ -10893,6 +10897,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
1089310897
const argResult = validateArgType(
1089410898
argParam,
1089510899
typeVarContext,
10900+
signatureTracker,
1089610901
{ type, isIncomplete: matchResults.isTypeIncomplete },
1089710902
skipUnknownArgCheck,
1089810903
/* skipOverloadArg */ i === 0,
@@ -10925,10 +10930,12 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
1092510930
let condition: TypeCondition[] = [];
1092610931
const argResults: ArgResult[] = [];
1092710932

10933+
const signatureTracker = new UniqueSignatureTracker();
1092810934
matchResults.argParams.forEach((argParam) => {
1092910935
const argResult = validateArgType(
1093010936
argParam,
1093110937
typeVarContext,
10938+
signatureTracker,
1093210939
{ type, isIncomplete: matchResults.isTypeIncomplete },
1093310940
skipUnknownArgCheck,
1093410941
/* skipOverloadArg */ false,
@@ -11275,6 +11282,8 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
1127511282
(paramInfo) => paramInfo.category === ParameterCategory.VarArgDictionary
1127611283
);
1127711284

11285+
const signatureTracker = new UniqueSignatureTracker();
11286+
1127811287
argList.forEach((arg) => {
1127911288
if (arg.argumentCategory === ArgumentCategory.Simple) {
1128011289
let paramType: Type | undefined;
@@ -11331,6 +11340,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
1133111340
errorNode: arg.valueExpression || errorNode,
1133211341
},
1133311342
srcTypeVarContext,
11343+
signatureTracker,
1133411344
/* functionType */ undefined,
1133511345
/* skipUnknownArgCheck */ false,
1133611346
/* skipOverloadArg */ false,
@@ -11382,6 +11392,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
1138211392
function validateArgType(
1138311393
argParam: ValidateArgTypeParams,
1138411394
typeVarContext: TypeVarContext,
11395+
signatureTracker: UniqueSignatureTracker,
1138511396
typeResult: TypeResult<FunctionType> | undefined,
1138611397
skipUnknownCheck: boolean,
1138711398
skipOverloadArg: boolean,
@@ -11481,6 +11492,11 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
1148111492
}
1148211493
}
1148311494

11495+
// If the type includes multiple instances of a generic function
11496+
// signature, force the type arguments for the duplicates to have
11497+
// unique names.
11498+
argType = ensureFunctionSignaturesAreUnique(argType, signatureTracker);
11499+
1148411500
// If we're assigning to a var arg dictionary with a TypeVar type,
1148511501
// strip literals before performing the assignment. This is used in
1148611502
// places like a dict constructor.
@@ -17170,6 +17186,8 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
1717017186
}
1717117187

1717217188
argList.forEach((arg) => {
17189+
const signatureTracker = new UniqueSignatureTracker();
17190+
1717317191
if (arg.argumentCategory === ArgumentCategory.Simple && arg.name) {
1717417192
const paramIndex = paramMap.get(arg.name.value) ?? paramListDetails.kwargsIndex;
1717517193

@@ -17189,6 +17207,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
1718917207
validateArgType(
1719017208
argParam,
1719117209
new TypeVarContext(),
17210+
signatureTracker,
1719217211
{ type: newMethodType },
1719317212
/* skipUnknownCheck */ true,
1719417213
/* skipOverloadArg */ true,

packages/pyright-internal/src/analyzer/typeUtils.ts

Lines changed: 81 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,34 @@ export interface InferenceContext {
203203
typeVarContext?: TypeVarContext;
204204
}
205205

206+
export interface SignatureWithCount {
207+
type: FunctionType;
208+
count: number;
209+
}
210+
211+
export class UniqueSignatureTracker {
212+
public signaturesSeen: SignatureWithCount[];
213+
214+
constructor() {
215+
this.signaturesSeen = [];
216+
}
217+
218+
findSignature(signature: FunctionType): SignatureWithCount | undefined {
219+
return this.signaturesSeen.find((s) => {
220+
return isTypeSame(signature, s.type);
221+
});
222+
}
223+
224+
addSignature(signature: FunctionType) {
225+
const existingSignature = this.findSignature(signature);
226+
if (existingSignature) {
227+
existingSignature.count++;
228+
} else {
229+
this.signaturesSeen.push({ type: signature, count: 1 });
230+
}
231+
}
232+
}
233+
206234
export function isOptionalType(type: Type): boolean {
207235
if (isUnion(type)) {
208236
return findSubtype(type, (subtype) => isNoneInstance(subtype)) !== undefined;
@@ -936,6 +964,13 @@ export function populateTypeVarContextForSelfType(
936964
typeVarContext.setTypeVarType(synthesizedSelfTypeVar, convertToInstance(selfClass));
937965
}
938966

967+
// Looks for duplicate function types within the type and ensures that
968+
// if they are generic, they have unique type variables.
969+
export function ensureFunctionSignaturesAreUnique(type: Type, signatureTracker: UniqueSignatureTracker): Type {
970+
const transformer = new UniqueFunctionSignatureTransformer(signatureTracker);
971+
return transformer.apply(type, 0);
972+
}
973+
939974
// Specializes a (potentially generic) type by substituting
940975
// type variables from a type var map.
941976
export function applySolvedTypeVars(
@@ -2766,7 +2801,7 @@ class TypeVarTransformer {
27662801
}
27672802
recursionCount++;
27682803

2769-
type = this._transformGenericTypeAlias(type, recursionCount);
2804+
type = this.transformGenericTypeAlias(type, recursionCount);
27702805

27712806
// Shortcut the operation if possible.
27722807
if (!requiresSpecialization(type)) {
@@ -2873,7 +2908,7 @@ class TypeVarTransformer {
28732908
}
28742909

28752910
if (isClass(type)) {
2876-
return this._transformTypeVarsInClassType(type, recursionCount);
2911+
return this.transformTypeVarsInClassType(type, recursionCount);
28772912
}
28782913

28792914
if (isFunction(type)) {
@@ -2883,7 +2918,7 @@ class TypeVarTransformer {
28832918
}
28842919

28852920
this._pendingFunctionTransformations.push(type);
2886-
const result = this._transformTypeVarsInFunctionType(type, recursionCount);
2921+
const result = this.transformTypeVarsInFunctionType(type, recursionCount);
28872922
this._pendingFunctionTransformations.pop();
28882923

28892924
return result;
@@ -2902,7 +2937,7 @@ class TypeVarTransformer {
29022937
// Specialize each of the functions in the overload.
29032938
const newOverloads: FunctionType[] = [];
29042939
type.overloads.forEach((entry) => {
2905-
const replacementType = this._transformTypeVarsInFunctionType(entry, recursionCount);
2940+
const replacementType = this.transformTypeVarsInFunctionType(entry, recursionCount);
29062941

29072942
if (isFunction(replacementType)) {
29082943
newOverloads.push(replacementType);
@@ -2946,7 +2981,7 @@ class TypeVarTransformer {
29462981
return callback();
29472982
}
29482983

2949-
private _transformGenericTypeAlias(type: Type, recursionCount: number) {
2984+
transformGenericTypeAlias(type: Type, recursionCount: number) {
29502985
if (!type.typeAliasInfo || !type.typeAliasInfo.typeParameters || !type.typeAliasInfo.typeArguments) {
29512986
return type;
29522987
}
@@ -2972,7 +3007,7 @@ class TypeVarTransformer {
29723007
: type;
29733008
}
29743009

2975-
private _transformTypeVarsInClassType(classType: ClassType, recursionCount: number): ClassType {
3010+
transformTypeVarsInClassType(classType: ClassType, recursionCount: number): ClassType {
29763011
// Handle the common case where the class has no type parameters.
29773012
if (ClassType.getTypeParameters(classType).length === 0 && !ClassType.isSpecialBuiltIn(classType)) {
29783013
return classType;
@@ -3091,7 +3126,7 @@ class TypeVarTransformer {
30913126
);
30923127
}
30933128

3094-
private _transformTypeVarsInFunctionType(
3129+
transformTypeVarsInFunctionType(
30953130
sourceType: FunctionType,
30963131
recursionCount: number
30973132
): FunctionType | OverloadedFunctionType {
@@ -3317,6 +3352,45 @@ class TypeVarDefaultValidator extends TypeVarTransformer {
33173352
}
33183353
}
33193354

3355+
class UniqueFunctionSignatureTransformer extends TypeVarTransformer {
3356+
constructor(private _signatureTracker: UniqueSignatureTracker) {
3357+
super();
3358+
}
3359+
3360+
override transformTypeVarsInFunctionType(
3361+
sourceType: FunctionType,
3362+
recursionCount: number
3363+
): FunctionType | OverloadedFunctionType {
3364+
// If this function has already been specialized or is not generic,
3365+
// there's no need to check for uniqueness.
3366+
if (sourceType.specializedTypes || sourceType.details.typeParameters.length === 0) {
3367+
return super.transformTypeVarsInFunctionType(sourceType, recursionCount);
3368+
}
3369+
3370+
let updatedSourceType: Type = sourceType;
3371+
const existingSignature = this._signatureTracker.findSignature(sourceType);
3372+
if (existingSignature) {
3373+
const typeVarContext = new TypeVarContext(getTypeVarScopeId(sourceType));
3374+
3375+
// Create new type variables with the same scope but with
3376+
// different (unique) names.
3377+
sourceType.details.typeParameters.forEach((typeParam) => {
3378+
const replacement = TypeVarType.cloneForNewName(
3379+
typeParam,
3380+
`${typeParam.details.name}(${existingSignature.count})`
3381+
);
3382+
typeVarContext.setTypeVarType(typeParam, replacement);
3383+
3384+
updatedSourceType = applySolvedTypeVars(sourceType, typeVarContext);
3385+
});
3386+
}
3387+
3388+
this._signatureTracker.addSignature(sourceType);
3389+
3390+
return updatedSourceType;
3391+
}
3392+
}
3393+
33203394
// Specializes a (potentially generic) type by substituting
33213395
// type variables from a type var map.
33223396
class ApplySolvedTypeVarsTransformer extends TypeVarTransformer {

packages/pyright-internal/src/analyzer/types.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2236,6 +2236,18 @@ export namespace TypeVarType {
22362236
return newInstance;
22372237
}
22382238

2239+
export function cloneForNewName(type: TypeVarType, name: string): TypeVarType {
2240+
const newInstance = TypeBase.cloneType(type);
2241+
newInstance.details = { ...type.details };
2242+
newInstance.details.name = name;
2243+
2244+
if (newInstance.scopeId) {
2245+
newInstance.nameWithScope = makeNameWithScope(name, newInstance.scopeId);
2246+
}
2247+
2248+
return newInstance;
2249+
}
2250+
22392251
export function cloneForScopeId(
22402252
type: TypeVarType,
22412253
scopeId: string,
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# This sample tests the case where a generic function is passed
2+
# as an argument to another generic function multiple times.
3+
4+
from typing import TypeVar, Callable
5+
6+
T = TypeVar("T")
7+
A = TypeVar("A")
8+
B = TypeVar("B")
9+
C = TypeVar("C")
10+
X = TypeVar("X")
11+
Y = TypeVar("Y")
12+
Z = TypeVar("Z")
13+
14+
15+
def identity(x: T) -> T:
16+
return x
17+
18+
19+
def triple_1(
20+
f: Callable[[A], X], g: Callable[[B], Y], h: Callable[[C], Z]
21+
) -> Callable[[A, B, C], tuple[X, Y, Z]]:
22+
def wrapped(a: A, b: B, c: C) -> tuple[X, Y, Z]:
23+
return f(a), g(b), h(c)
24+
25+
return wrapped
26+
27+
28+
def triple_2(
29+
f: tuple[Callable[[A], X], Callable[[B], Y], Callable[[C], Z]]
30+
) -> Callable[[A, B, C], tuple[X, Y, Z]]:
31+
def wrapped(a: A, b: B, c: C) -> tuple[X, Y, Z]:
32+
return f[0](a), f[1](b), f[2](c)
33+
34+
return wrapped
35+
36+
37+
def test_1(f: Callable[[A], X]) -> Callable[[A, B, C], tuple[X, B, C]]:
38+
val = triple_1(f, identity, identity)
39+
40+
reveal_type(
41+
val,
42+
expected_text="(A@test_1, T@identity, T(1)@identity) -> tuple[X@test_1, T@identity, T(1)@identity]",
43+
)
44+
45+
return val
46+
47+
48+
def test_2(f: Callable[[A], X]) -> Callable[[A, B, C], tuple[X, B, C]]:
49+
val = triple_2((f, identity, identity))
50+
51+
reveal_type(
52+
val,
53+
expected_text="(A@test_2, T(1)@identity, T(2)@identity) -> tuple[X@test_2, T(1)@identity, T(2)@identity]",
54+
)
55+
56+
return val

0 commit comments

Comments
 (0)
0