8000 Fix crash on invalid call extension operator use (#1386) · TypeScriptToLua/TypeScriptToLua@7e8c2fb · GitHub
[go: up one dir, main page]

Skip to content
65EC

Commit 7e8c2fb

Browse files
authored
Fix crash on invalid call extension operator use (#1386)
* Fix crash on invalid call extension operator use * Fix prettier * Update snapshots
1 parent 3d0e98d commit 7e8c2fb

File tree

9 files changed

+225
-81
lines changed

9 files changed

+225
-81
lines changed

src/transformation/utils/diagnostics.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,11 @@ export const unsupportedOptionalCompileMembersOnly = createErrorDiagnosticFactor
155155
export const undefinedInArrayLiteral = createErrorDiagnosticFactory(
156156
"Array literals may not contain undefined or null."
157157
);
158+
159+
export const invalidMethodCallExtensionUse = createErrorDiagnosticFactory(
160+
"This language extension must be called as a method."
161+
);
162+
163+
export const invalidSpreadInCallExtension = createErrorDiagnosticFactory(
164+
"Spread elements are not supported in call extensions."
165+
);

src/transformation/utils/language-extensions.ts

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import * as ts from "typescript";
22
import { TransformationContext } from "../context";
3+
import { invalidMethodCallExtensionUse, invalidSpreadInCallExtension } from "./diagnostics";
34

45
export enum ExtensionKind {
56
MultiFunction = "MultiFunction",
@@ -53,6 +54,7 @@ export enum ExtensionKind {
5354
TableAddKeyType = "TableAddKey",
5455
TableAddKeyMethodType = "TableAddKeyMethod",
5556
}
57+
5658
const extensionValues: Set<string> = new Set(Object.values(ExtensionKind));
5759

5860
export function getExtensionKindForType(context: TransformationContext, type: ts.Type): ExtensionKind | undefined {
@@ -119,3 +121,78 @@ export function getIterableExtensionKindForNode(
119121
const type = context.checker.getTypeAtLocation(node);
120122
return getIterableExtensionTypeForType(context, type);
121123
}
124+
125+
export const methodExtensionKinds: ReadonlySet<ExtensionKind> = new Set<ExtensionKind>([
126+
ExtensionKind.AdditionOperatorMethodType,
127+
ExtensionKind.SubtractionOperatorMethodType,
128+
ExtensionKind.MultiplicationOperatorMethodType,
129+
ExtensionKind.DivisionOperatorMethodType,
130+
ExtensionKind.ModuloOperatorMethodType,
131+
ExtensionKind.PowerOperatorMethodType,
132+
ExtensionKind.FloorDivisionOperatorMethodType,
133+
ExtensionKind.BitwiseAndOperatorMethodType,
134+
ExtensionKind.BitwiseOrOperatorMethodType,
135+
ExtensionKind.BitwiseExclusiveOrOperatorMethodType,
136+
ExtensionKind.BitwiseLeftShiftOperatorMethodType,
137+
ExtensionKind.BitwiseRightShiftOperatorMethodType,
138+
ExtensionKind.ConcatOperatorMethodType,
139+
ExtensionKind.LessThanOperatorMethodType,
140+
ExtensionKind.GreaterThanOperatorMethodType,
141+
ExtensionKind.NegationOperatorMethodType,
142+
ExtensionKind.BitwiseNotOperatorMethodType,
143+
ExtensionKind.LengthOperatorMethodType,
144+
ExtensionKind.TableDeleteMethodType,
145+
ExtensionKind.TableGetMethodType,
146+
ExtensionKind.TableHasMethodType,
147+
ExtensionKind.TableSetMethodType,
148+
ExtensionKind.TableAddKeyMethodType,
149+
]);
150+
151+
export function getNaryCallExtensionArgs(
152+
context: TransformationContext,
153+
node: ts.CallExpression,
154+
kind: ExtensionKind,
155+
numArgs: number
156+
): readonly ts.Expression[] | undefined {
157+
let expressions: readonly ts.Expression[];
158+
if (node.arguments.some(ts.isSpreadElement)) {
159+
context.diagnostics.push(invalidSpreadInCallExtension(node));
160+
return undefined;
161+
}
162+
if (methodExtensionKinds.has(kind)) {
163+
if (!(ts.isPropertyAccessExpression(node.expression) || ts.isElementAccessExpression(node.expression))) {
164+
context.diagnostics.push(invalidMethodCallExtensionUse(node));
165+
return undefined;
166+
}
167+
if (node.arguments.length < numArgs - 1) {
168+
// assumed to be TS error
169+
return undefined;
170+
}
171+
expressions = [node.expression.expression, ...node.arguments];
172+
} else {
173+
if (node.arguments.length < numArgs) {
174+
// assumed to be TS error
175+
return undefined;
176+
}
177+
expressions = node.arguments;
178+
}
179+
return expressions;
180+
}
181+
182+
export function getUnaryCallExtensionArg(
183+
context: TransformationContext,
184+
node: ts.CallExpression,
185+
kind: ExtensionKind
186+
): ts.Expression | undefined {
187+
return getNaryCallExtensionArgs(context, node, kind, 1)?.[0];
188+
}
189+
190+
export function getBinaryCallExtensionArgs(
191+
context: TransformationContext,
192+
node: ts.CallExpression,
193+
kind: ExtensionKind
194+
): readonly [ts.Expression, ts.Expression] | undefined {
195+
const expressions = getNaryCallExtensionArgs(context, node, kind, 2);
196+
if (expressions === undefined) return undefined;
197+
return [expressions[0], expressions[1]];
198+
}

src/transformation/visitors/language-extensions/operators.ts

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ import { TransformationContext } from "../../context";
44
import { assert } from "../../../utils";
55
import { LuaTarget } from "../../../CompilerOptions";
66
import { unsupportedForTarget } from "../../utils/diagnostics";
7-
import { ExtensionKind } from "../../utils/language-extensions";
7+
import { ExtensionKind, getBinaryCallExtensionArgs, getUnaryCallExtensionArg } from "../../utils/language-extensions";
88
import { LanguageExtensionCallTransformerMap } from "./call-extension";
9+
import { transformOrderedExpressions } from "../expression-list";
910

1011
const binaryOperatorMappings = new Map<ExtensionKind, lua.BinaryOperator>([
1112
[ExtensionKind.AdditionOperatorType, lua.SyntaxKind.AdditionOperator],
@@ -81,35 +82,21 @@ for (const kind of unaryOperatorMappings.keys()) {
8182
function transformBinaryOperator(context: TransformationContext, node: ts.CallExpression, kind: ExtensionKind) {
8283
if (requiresLua53.has(kind)) checkHasLua53(context, node, kind);
8384

84-
let args: readonly ts.Expression[] = node.arguments;
85-
if (
86-
args.length === 1 &&
87-
(ts.isPropertyAccessExpression(node.expression) || ts.isElementAccessExpression(node.expression))
88-
) {
89-
args = [node.expression.expression, ...args];
90-
}
85+
const args = getBinaryCallExtensionArgs(context, node, kind);
86+
if (!args) return lua.createNilLiteral();
87+
88+
const [left, right] = transformOrderedExpressions(context, args);
9189

9290
const luaOperator = binaryOperatorMappings.get(kind);
9391
assert(luaOperator);
94-
return lua.createBinaryExpression(
95-
context.transformExpression(args[0]),
96-
context.transformExpression(args[1]),
97-
luaOperator
98-
);
92+
return lua.createBinaryExpression(left, right, luaOperator);
9993
}
10094

10195
function transformUnaryOperator(context: TransformationContext, node: ts.CallExpression, kind: ExtensionKind) {
10296
if (requiresLua53.has(kind)) checkHasLua53(context, node, kind);
10397

104-
let arg: ts.Expression;
105-
if (
106-
node.arguments.length === 0 &&
107-
(ts.isPropertyAccessExpression(node.expression) || ts.isElementAccessExpression(node.expression))
108-
) {
109-
arg = node.expression.expression;
110-
} else {
111-
arg = node.arguments[0];
112-
}
98+
const arg = getUnaryCallExtensionArg(context, node, kind);
99+
if (!arg) return lua.createNilLiteral();
113100

114101
const luaOperator = unaryOperatorMappings.get(kind);
115102
assert(luaOperator);
Lines changed: 42 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
import * as ts from "typescript";
22
import * as lua from "../../../LuaAST";
33
import { TransformationContext } from "../../context";
4-
import { ExtensionKind, getExtensionKindForNode } from "../../utils/language-extensions";
5-
import { transformExpressionList } from "../expression-list";
6-
import { LanguageExtensionCallTransformer } from "./call-extension";
4+
import {
5+
ExtensionKind,
6+
getBinaryCallExtensionArgs,
7+
getExtensionKindForNode,
8+
getNaryCallExtensionArgs,
9+
} from "../../utils/language-extensions";
10+
import { transformOrderedExpressions } from "../expression-list";
11+
import { LanguageExtensionCallTransformerMap } from "./call-extension";
712

813
export function isTableNewCall(context: TransformationContext, node: ts.NewExpression) {
914
return getExtensionKindForNode(context, node.expression) === ExtensionKind.TableNewType;
1015
}
16+
1117
export const tableNewExtensions = [ExtensionKind.TableNewType];
1218

13-
export const tableExtensionTransformers: { [P in ExtensionKind]?: LanguageExtensionCallTransformer } = {
19+
export const tableExtensionTransformers: LanguageExtensionCallTransformerMap = {
1420
[ExtensionKind.TableDeleteType]: transformTableDeleteExpression,
1521
[ExtensionKind.TableDeleteMethodType]: transformTableDeleteExpression,
1622
[ExtensionKind.TableGetType]: transformTableGetExpression,
@@ -19,72 +25,56 @@ export const tableExtensionTransformers: { [P in ExtensionKind]?: LanguageExtens
1925
[ExtensionKind.TableHasMethodType]: transformTableHasExpression,
2026
[ExtensionKind.TableSetType]: transformTableSetExpression,
2127
[ExtensionKind.TableSetMethodType]: transformTableSetExpression,
22-
[ExtensionKind.TableAddKeyType]: transformTableAddExpression,
23-
[ExtensionKind.TableAddKeyMethodType]: transformTableAddExpression,
28+
[ExtensionKind.TableAddKeyType]: transformTableAddKeyExpression,
29+
[ExtensionKind.TableAddKeyMethodType]: transformTableAddKeyExpression,
2430
};
2531

2632
function transformTableDeleteExpression(
2733
context: TransformationContext,
2834
node: ts.CallExpression,
2935
extensionKind: ExtensionKind
3036
): lua.Expression {
31-
const args = node.arguments.slice();
32-
if (
33-
extensionKind === ExtensionKind.TableDeleteMethodType &&
34-
(ts.isPropertyAccessExpression(node.expression) || ts.isElementAccessExpression(node.expression))
35-
) {
36-
// In case of method (no table argument), push method owner to front of args list
37-
args.unshift(node.expression.expression);
37+
const args = getBinaryCallExtensionArgs(context, node, extensionKind);
38+
if (!args) {
39+
return lua.createNilLiteral();
3840
}
3941

40-
const [table, accessExpression] = transformExpressionList(context, args);
42+
const [table, key] = transformOrderedExpressions(context, args);
4143
// arg0[arg1] = nil
4244
context.addPrecedingStatements(
43-
lua.createAssignmentStatement(
44-
lua.createTableIndexExpression(table, accessExpression),
45-
lua.createNilLiteral(),
46-
node
47-
)
45+
lua.createAssignmentStatement(lua.createTableIndexExpression(table, key), lua.createNilLiteral(), node)
4846
);
4947
return lua.createBooleanLiteral(true);
5048
}
5149

52-
function transformWithTableArgument(context: TransformationContext, node: ts.CallExpression): lua.Expression[] {
53-
if (ts.isPropertyAccessExpression(node.expression) || ts.isElementAccessExpression(node.expression)) {
54-
return transformExpressionList(context, [node.expression.expression, ...node.arguments]);
55-
}
56-
// todo: report diagnostic?
57-
return [lua.createNilLiteral(), ...transformExpressionList(context, node.arguments)];
58-
}
59-
6050
function transformTableGetExpression(
6151
context: TransformationContext,
6252
node: ts.CallExpression,
6353
extensionKind: ExtensionKind
6454
): lua.Expression {
65-
const args =
66-
extensionKind === ExtensionKind.TableGetMethodType
67-
? transformWithTableArgument(context, node)
68-
: transformExpressionList(context, node.arguments);
55+
const args = getBinaryCallExtensionArgs(context, node, extensionKind);
56+
if (!args) {
57+
return lua.createNilLiteral();
58+
}
6959

70-
const [table, accessExpression] = args;
60+
const [table, key] = transformOrderedExpressions(context, args);
7161
// arg0[arg1]
72-
return lua.createTableIndexExpression(table, accessExpression, node);
62+
return lua.createTableIndexExpression(table, key, node);
7363
}
7464

7565
function transformTableHasExpression(
7666
context: TransformationContext,
7767
node: ts.CallExpression,
7868
extensionKind: ExtensionKind
7969
): lua.Expression {
80-
const args =
81-
extensionKind === ExtensionKind.TableHasMethodType
82-
? transformWithTableArgument(context, node)
83-
: transformExpressionList(context, node.arguments);
70+
const args = getBinaryCallExtensionArgs(context, node, extensionKind);
71+
if (!args) {
72+
return lua.createNilLiteral();
73+
}
8474

85-
const [table, accessExpression] = args;
75+
const [table, key] = transformOrderedExpressions(context, args);
8676
// arg0[arg1]
87-
const tableIndexExpression = lua.createTableIndexExpression(table, accessExpression);
77+
const tableIndexExpression = lua.createTableIndexExpression(table, key);
8878

8979
// arg0[arg1] ~= nil
9080
return lua.createBinaryExpression(
@@ -100,37 +90,33 @@ function transformTableSetExpression(
10090
node: ts.CallExpression,
10191
extensionKind: ExtensionKind
10292
): lua.Expression {
103-
const args =
104-
extensionKind === ExtensionKind.TableSetMethodType
105-
? transformWithTableArgument(context, node)
106-
: transformExpressionList(context, node.arguments);
93+
const args = getNaryCallExtensionArgs(context, node, extensionKind, 3);
94+
if (!args) {
95+
return lua.createNilLiteral();
96+
}
10797

108-
const [table, accessExpression, value] = args;
98+
const [table, key, value] = transformOrderedExpressions(context, args);
10999
// arg0[arg1] = arg2
110100
context.addPrecedingStatements(
111-
lua.createAssignmentStatement(lua.createTableIndexExpression(table, accessExpression), value, node)
101+
lua.createAssignmentStatement(lua.createTableIndexExpression(table, key), value, node)
112102
);
113103
return lua.createNilLiteral();
114104
}
115105

116-
function transformTableAddExpression(
106+
function transformTableAddKeyExpression(
117107
context: TransformationContext,
118108
node: ts.CallExpression,
119109
extensionKind: ExtensionKind
120110
): lua.Expression {
121-
const args =
122-
extensionKind === ExtensionKind.TableAddKeyMethodType
123-
? transformWithTableArgument(context, node)
124-
: transformExpressionList(context, node.arguments);
111+
const args = getNaryCallExtensionArgs(context, node, extensionKind, 2);
112+
if (!args) {
113+
return lua.createNilLiteral();
114+
}
125115

126-
const [table, value] = args;
116+
const [table, key] = transformOrderedExpressions(context, args);
127117
// arg0[arg1] = true
128118
context.addPrecedingStatements(
129-
lua.createAssignmentStatement(
130-
lua.createTableIndexExpression(table, value),
131-
lua.createBooleanLiteral(true),
132-
node
133-
)
119+
lua.createAssignmentStatement(lua.createTableIndexExpression(table, key), lua.createBooleanLiteral(true), node)
134120
);
135121
return lua.createNilLiteral();
136122
}

test/unit/__snapshots__/optionalChaining.spec.ts.snap

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,13 @@ exports[`Unsupported optional chains Compile members only: diagnostics 1`] = `"m
6363
exports[`Unsupported optional chains Language extensions: code 1`] = `
6464
"local ____opt_0 = ({}).has
6565
if ____opt_0 ~= nil then
66-
local ____ = nil[3] ~= nil
6766
end"
6867
`;
6968

70-
exports[`Unsupported optional chains Language extensions: diagnostics 1`] = `"main.ts(2,17): error TSTL: Optional calls are not supported for builtin or language extension functions."`;
69+
exports[`Unsupported optional chains Language extensions: diagnostics 1`] = `
70+
"main.ts(2,17): error TSTL: Optional calls are not supported for builtin or language extension functions.
71+
main.ts(2,17): error TSTL: This language extension must be called as a method."
72+
`;
7173

7274
exports[`long optional chain 1`] = `
7375
"local ____exports = {}

test/unit/language-extensions/__snapshots__/operators.spec.ts.snap

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,17 @@
11
// Jest Snapshot v1, https://goo.gl/fbAQLP
22

3+
exports[`does not crash on invalid operator use global function: code 1`] = `""`;
4+
5+
exports[`does not crash on invalid operator use global function: diagnostics 1`] = `"main.ts(3,13): error TS2554: Expected 2 arguments, but got 1."`;
6+
7+
exports[`does not crash on invalid operator use method: code 1`] = `"left = {}"`;
8+
9+
exports[`does not crash on invalid operator use method: diagnostics 1`] = `"main.ts(5,18): error TS2554: Expected 1 arguments, but got 0."`;
10+
11+
exports[`does not crash on invalid operator use unary operator: code 1`] = `"op(_G)"`;
12+
13+
exports[`does not crash on invalid operator use unary operator: diagnostics 1`] = `"main.ts(2,31): error TS2304: Cannot find name 'LuaUnaryMinus'."`;
14+
315
exports[`operator mapping - invalid use (const foo = (op as any)(1, 2);): code 1`] = `"foo = op(_G, 1, 2)"`;
416

517
exports[`operator mapping - invalid use (const foo = (op as any)(1, 2);): diagnostics 1`] = `"main.ts(3,22): error TSTL: This function must be called directly and cannot be referred to."`;

test/unit/language-extensions/__snapshots__/table.spec.ts.snap

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,11 @@ __TS__ArrayMap({\\"a\\", \\"b\\", \\"c\\"}, ____table.has)"
133133
`;
134134
135135
exports[`LuaTableHas extension invalid use method expression ("LuaTable<string, number>"): diagnostics 1`] = `"main.ts(3,37): error TSTL: This function must be called directly and cannot be referred to."`;
136+
137+
exports[`does not crash on invalid extension use global function: code 1`] = `""`;
138+
139+
exports[`does not crash on invalid extension use global function: diagnostics 1`] = `"main.ts(3,9): error TS2554: Expected 2 arguments, but got 1."`;
140+
141+
exports[`does not crash on invalid extension use method: code 1`] = `"left = {}"`;
142+
143+
exports[`does not crash on invalid extension use method: diagnostics 1`] = `"main.ts(5,14): error TS2554: Expected 2 arguments, but got 0."`;

0 commit comments

Comments
 (0)
0