Skip to content

Commit 888f87a

Browse files
committed
Rust: Improve type inference for closures
1 parent 5e187b4 commit 888f87a

File tree

5 files changed

+553
-155
lines changed

5 files changed

+553
-155
lines changed

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 104 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,14 @@ private predicate isPanicMacroCall(MacroExpr me) {
416416
me.getMacroCall().resolveMacro().(MacroRules).getName().getText() = "panic"
417417
}
418418

419+
// Due to "binding modes" the type of the pattern is not necessarily the
420+
// same as the type of the initializer. The pattern being an identifier
421+
// pattern is sufficient to ensure that this is not the case.
422+
private predicate identLetStmt(LetStmt let, IdentPat lhs, Expr rhs) {
423+
let.getPat() = lhs and
424+
let.getInitializer() = rhs
425+
}
426+
419427
/** Module for inferring certain type information. */
420428
module CertainTypeInference {
421429
pragma[nomagic]
@@ -493,11 +501,7 @@ module CertainTypeInference {
493501
// is not a certain type equality.
494502
exists(LetStmt let |
495503
not let.hasTypeRepr() and
496-
// Due to "binding modes" the type of the pattern is not necessarily the
497-
// same as the type of the initializer. The pattern being an identifier
498-
// pattern is sufficient to ensure that this is not the case.
499-
let.getPat().(IdentPat) = n1 and
500-
let.getInitializer() = n2
504+
identLetStmt(let, n1, n2)
501505
)
502506
or
503507
exists(LetExpr let |
@@ -521,6 +525,25 @@ module CertainTypeInference {
521525
)
522526
else prefix2.isEmpty()
523527
)
528+
or
529+
exists(CallExprImpl::DynamicCallExpr dce, TupleType tt, int i |
530+
n1 = dce.getArgList() and
531+
tt.getArity() = dce.getNumberOfSyntacticArguments() and
532+
n2 = dce.getSyntacticPositionalArgument(i) and
533+
prefix1 = TypePath::singleton(tt.getPositionalTypeParameter(i)) and
534+
prefix2.isEmpty()
535+
)
536+
or
537+
exists(ClosureExpr ce, int index |
538+
n1 = ce and
539+
n2 = ce.getParam(index).getPat() and
540+
prefix1 = closureParameterPath(ce.getNumberOfParams(), index) and
541+
prefix2.isEmpty()
542+
)
543+
or
544+
n1 = any(ClosureExpr ce | not ce.hasRetType() and ce.getClosureBody() = n2) and
545+
prefix1 = closureReturnPath() and
546+
prefix2.isEmpty()
524547
}
525548

526549
pragma[nomagic]
@@ -783,17 +806,6 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
783806
n1.(ArrayRepeatExpr).getRepeatOperand() = n2 and
784807
prefix1 = TypePath::singleton(getArrayTypeParameter()) and
785808
prefix2.isEmpty()
786-
or
787-
exists(ClosureExpr ce, int index |
788-
n1 = ce and
789-
n2 = ce.getParam(index).getPat() and
790-
prefix1 = closureParameterPath(ce.getNumberOfParams(), index) and
791-
prefix2.isEmpty()
792-
)
793-
or
794-
n1.(ClosureExpr).getClosureBody() = n2 and
795-
prefix1 = closureReturnPath() and
796-
prefix2.isEmpty()
797809
}
798810

799811
/**
@@ -836,6 +848,19 @@ private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) {
836848
)
837849
}
838850

851+
private Type inferUnknownTypeFromAnnotation(AstNode n, TypePath path) {
852+
inferType(n, path) = TUnknownType() and
853+
// Normally, these are coercion sites, but in case a type is unknown we
854+
// allow for type information to flow from the type annotation.
855+
exists(TypeMention tm | result = tm.getTypeAt(path) |
856+
tm = any(LetStmt let | identLetStmt(let, _, n)).getTypeRepr()
857+
or
858+
tm = any(ClosureExpr ce | n = ce.getBody()).getRetType().getTypeRepr()
859+
or
860+
tm = getReturnTypeMention(any(Function f | n = f.getBody()))
861+
)
862+
}
863+
839864
/**
840865
* Holds if the type tree of `n1` at `prefix1` should be equal to the type tree
841866
* of `n2` at `prefix2`, but type information should only propagate from `n1` to
@@ -1685,6 +1710,8 @@ private module AssocFunctionResolution {
16851710
* 3. `AssocFunctionCallCallExpr`: a qualified function call, `Q::f(x)`; or
16861711
* 4. `AssocFunctionCallOperation`: an operation expression, `x + y`, which is syntactic sugar
16871712
* for `Add::add(x, y)`.
1713+
* 5. `ClosureMethodCall`: a call to a closure, `c(x)`, which is syntactic sugar for
1714+
* `c.call_once(x)`, `c.call_mut(x)`, or `c.call(x)`.
16881715
*
16891716
* Note that only in case 1 and 2 is auto-dereferencing and borrowing allowed.
16901717
*
@@ -1701,7 +1728,7 @@ private module AssocFunctionResolution {
17011728
pragma[nomagic]
17021729
abstract predicate hasNameAndArity(string name, int arity);
17031730

1704-
abstract Expr getNonReturnNodeAt(FunctionPosition pos);
1731+
abstract AstNode getNonReturnNodeAt(FunctionPosition pos);
17051732

17061733
AstNode getNodeAt(FunctionPosition pos) {
17071734
result = this.getNonReturnNodeAt(pos)
@@ -2235,7 +2262,7 @@ private module AssocFunctionResolution {
22352262
}
22362263
}
22372264

2238-
private class AssocFunctionCallMethodCallExpr extends AssocFunctionCall instanceof MethodCallExpr {
2265+
private class MethodCallExprAssocFunctionCall extends AssocFunctionCall instanceof MethodCallExpr {
22392266
override predicate hasNameAndArity(string name, int arity) {
22402267
name = super.getIdentifier().getText() and
22412268
arity = super.getNumberOfSyntacticArguments()
@@ -2255,7 +2282,7 @@ private module AssocFunctionResolution {
22552282
override Trait getTrait() { none() }
22562283
}
22572284

2258-
private class AssocFunctionCallIndexExpr extends AssocFunctionCall, IndexExpr {
2285+
private class IndexExprAssocFunctionCall extends AssocFunctionCall, IndexExpr {
22592286
private predicate isInMutableContext() {
22602287
// todo: does not handle all cases yet
22612288
VariableImpl::assignmentOperationDescendant(_, this)
@@ -2285,8 +2312,8 @@ private module AssocFunctionResolution {
22852312
}
22862313
}
22872314

2288-
private class AssocFunctionCallCallExpr extends AssocFunctionCall, CallExpr {
2289-
AssocFunctionCallCallExpr() {
2315+
private class CallExprAssocFunctionCall extends AssocFunctionCall, CallExpr {
2316+
CallExprAssocFunctionCall() {
22902317
exists(getCallExprPathQualifier(this)) and
22912318
// even if a target cannot be resolved by path resolution, it may still
22922319
// be possible to resolve a blanket implementation (so not `forex`)
@@ -2318,7 +2345,7 @@ private module AssocFunctionResolution {
23182345
override Trait getTrait() { result = getCallExprTraitQualifier(this) }
23192346
}
23202347

2321-
final class AssocFunctionCallOperation extends AssocFunctionCall, Operation {
2348+
final class OperationAssocFunctionCall extends AssocFunctionCall, Operation {
23222349
override predicate hasNameAndArity(string name, int arity) {
23232350
this.isOverloaded(_, name, _) and
23242351
arity = this.getNumberOfOperands()
@@ -2376,6 +2403,29 @@ private module AssocFunctionResolution {
23762403
override Trait getTrait() { this.isOverloaded(result, _, _) }
23772404
}
23782405

2406+
private class DynamicAssocFunctionCall extends AssocFunctionCall instanceof CallExprImpl::DynamicCallExpr
2407+
{
2408+
pragma[nomagic]
2409+
override predicate hasNameAndArity(string name, int arity) {
2410+
name = "call_once" and // todo: handle call_mut and call
2411+
arity = 2 // args are passed in a tuple
2412+
}
2413+
2414+
override predicate hasReceiver() { any() }
2415+
2416+
override AstNode getNonReturnNodeAt(FunctionPosition pos) {
2417+
pos.asPosition() = 0 and
2418+
result = super.getFunction()
2419+
or
2420+
pos.asPosition() = 1 and
2421+
result = super.getArgList()
2422+
}
2423+
2424+
override predicate supportsAutoDerefAndBorrow() { any() }
2425+
2426+
override Trait getTrait() { result instanceof AnyFnTrait }
2427+
}
2428+
23792429
pragma[nomagic]
23802430
private AssocFunctionDeclaration getAssocFunctionSuccessor(
23812431
ImplOrTraitItemNode i, string name, int arity
@@ -3239,7 +3289,7 @@ private module OperationMatchingInput implements MatchingInputSig {
32393289
}
32403290
}
32413291

3242-
class Access extends AssocFunctionResolution::AssocFunctionCallOperation {
3292+
class Access extends AssocFunctionResolution::OperationAssocFunctionCall {
32433293
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) { none() }
32443294

32453295
pragma[nomagic]
@@ -3844,14 +3894,6 @@ private module InvokedClosureSatisfiesTypeInput implements SatisfiesTypeInputSig
38443894
}
38453895
}
38463896

3847-
private module InvokedClosureSatisfiesType =
3848-
SatisfiesType<InvokedClosureExpr, InvokedClosureSatisfiesTypeInput>;
3849-
3850-
/** Gets the type of `ce` when viewed as an implementation of `FnOnce`. */
3851-
private Type invokedClosureFnTypeAt(InvokedClosureExpr ce, TypePath path) {
3852-
InvokedClosureSatisfiesType::satisfiesConstraintType(ce, _, path, result)
3853-
}
3854-
38553897
/**
38563898
* Gets the root type of a closure.
38573899
*
@@ -3878,73 +3920,39 @@ private TypePath closureParameterPath(int arity, int index) {
38783920
TypePath::singleton(getTupleTypeParameter(arity, index)))
38793921
}
38803922

3881-
/** Gets the path to the return type of the `FnOnce` trait. */
3882-
private TypePath fnReturnPath() {
3883-
result = TypePath::singleton(getAssociatedTypeTypeParameter(any(FnOnceTrait t).getOutputType()))
3884-
}
3885-
3886-
/**
3887-
* Gets the path to the parameter type of the `FnOnce` trait with arity `arity`
3888-
* and index `index`.
3889-
*/
38903923
pragma[nomagic]
3891-
private TypePath fnParameterPath(int arity, int index) {
3892-
result =
3893-
TypePath::cons(TTypeParamTypeParameter(any(FnOnceTrait t).getTypeParam()),
3894-
TypePath::singleton(getTupleTypeParameter(arity, index)))
3895-
}
3896-
3897-
pragma[nomagic]
3898-
private Type inferDynamicCallExprType(Expr n, TypePath path) {
3899-
exists(InvokedClosureExpr ce |
3900-
// Propagate the function's return type to the call expression
3901-
exists(TypePath path0 | result = invokedClosureFnTypeAt(ce, path0) |
3902-
n = ce.getCall() and
3903-
path = path0.stripPrefix(fnReturnPath())
3924+
private Type inferClosureExprType(AstNode n, TypePath path) {
3925+
exists(ClosureExpr ce |
3926+
n = ce and
3927+
(
3928+
path.isEmpty() and
3929+
result = closureRootType()
3930+
or
3931+
path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam())) and
3932+
result.(TupleType).getArity() = ce.getNumberOfParams()
39043933
or
3905-
// Propagate the function's parameter type to the arguments
3906-
exists(int index |
3907-
n = ce.getCall().getSyntacticPositionalArgument(index) and
3908-
path =
3909-
path0.stripPrefix(fnParameterPath(ce.getCall().getArgList().getNumberOfArgs(), index))
3934+
exists(TypePath path0 |
3935+
result = ce.getRetType().getTypeRepr().(TypeMention).getTypeAt(path0) and
3936+
path = closureReturnPath().append(path0)
39103937
)
39113938
)
39123939
or
3913-
// _If_ the invoked expression has the type of a closure, then we propagate
3914-
// the surrounding types into the closure.
3915-
exists(int arity, TypePath path0 | ce.getTypeAt(TypePath::nil()) = closureRootType() |
3916-
// Propagate the type of arguments to the parameter types of closure
3917-
exists(int index, ArgList args |
3918-
n = ce and
3919-
args = ce.getCall().getArgList() and
3920-
arity = args.getNumberOfArgs() and
3921-
result = inferType(args.getArg(index), path0) and
3922-
path = closureParameterPath(arity, index).append(path0)
3923-
)
3924-
or
3925-
// Propagate the type of the call expression to the return type of the closure
3926-
n = ce and
3927-
arity = ce.getCall().getArgList().getNumberOfArgs() and
3928-
result = inferType(ce.getCall(), path0) and
3929-
path = closureReturnPath().append(path0)
3940+
exists(Param p |
3941+
p = ce.getAParam() and
3942+
not p.hasTypeRepr() and
3943+
n = p.getPat() and
3944+
result = TUnknownType() and
3945+
path.isEmpty()
39303946
)
39313947
)
39323948
}
39333949

39343950
pragma[nomagic]
3935-
private Type inferClosureExprType(AstNode n, TypePath path) {
3936-
exists(ClosureExpr ce |
3937-
n = ce and
3938-
path.isEmpty() and
3939-
result = closureRootType()
3940-
or
3941-
n = ce and
3942-
path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam())) and
3943-
result.(TupleType).getArity() = ce.getNumberOfParams()
3944-
or
3945-
// Propagate return type annotation to body
3946-
n = ce.getClosureBody() and
3947-
result = ce.getRetType().getTypeRepr().(TypeMention).getTypeAt(path)
3951+
private TupleType inferArgList(ArgList args, TypePath path) {
3952+
exists(CallExprImpl::DynamicCallExpr dce |
3953+
args = dce.getArgList() and
3954+
result.getArity() = dce.getNumberOfSyntacticArguments() and
3955+
path.isEmpty()
39483956
)
39493957
}
39503958

@@ -3992,7 +4000,8 @@ private module Cached {
39924000
or
39934001
i instanceof ImplItemNode and dispatch = false
39944002
|
3995-
result = call.(AssocFunctionResolution::AssocFunctionCall).resolveCallTarget(i, _, _, _)
4003+
result = call.(AssocFunctionResolution::AssocFunctionCall).resolveCallTarget(i, _, _, _) and
4004+
not call instanceof CallExprImpl::DynamicCallExpr // todo
39964005
)
39974006
}
39984007

@@ -4101,13 +4110,15 @@ private module Cached {
41014110
or
41024111
result = inferForLoopExprType(n, path)
41034112
or
4104-
result = inferDynamicCallExprType(n, path)
4105-
or
41064113
result = inferClosureExprType(n, path)
41074114
or
4115+
result = inferArgList(n, path)
4116+
or
41084117
result = inferStructPatType(n, path)
41094118
or
41104119
result = inferTupleStructPatType(n, path)
4120+
or
4121+
result = inferUnknownTypeFromAnnotation(n, path)
41114122
)
41124123
}
41134124
}
@@ -4124,8 +4135,8 @@ private module Debug {
41244135
Locatable getRelevantLocatable() {
41254136
exists(string filepath, int startline, int startcolumn, int endline, int endcolumn |
41264137
result.getLocation().hasLocationInfo(filepath, startline, startcolumn, endline, endcolumn) and
4127-
filepath.matches("%/main.rs") and
4128-
startline = 103
4138+
filepath.matches("%/regressions.rs") and
4139+
startline = 24
41294140
)
41304141
}
41314142

0 commit comments

Comments
 (0)