Skip to content

Commit 55c3674

Browse files
committed
Rust: Add type inference for dyn types
1 parent 9592185 commit 55c3674

File tree

10 files changed

+457
-199
lines changed

10 files changed

+457
-199
lines changed

rust/ql/.generated.list

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/ql/.gitattributes

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/ql/lib/codeql/rust/elements/internal/DynTraitTypeReprImpl.qll

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
// generated by codegen, remove this comment if you wish to edit this file
21
/**
32
* This module provides a hand-modifiable wrapper around the generated class `DynTraitTypeRepr`.
43
*
@@ -12,6 +11,10 @@ private import codeql.rust.elements.internal.generated.DynTraitTypeRepr
1211
* be referenced directly.
1312
*/
1413
module Impl {
14+
private import rust
15+
private import codeql.rust.internal.PathResolution as PathResolution
16+
17+
// the following QLdoc is generated: if you need to edit it, do it in the schema file
1518
/**
1619
* A dynamic trait object type.
1720
*
@@ -21,5 +24,16 @@ module Impl {
2124
* // ^^^^^^^^^
2225
* ```
2326
*/
24-
class DynTraitTypeRepr extends Generated::DynTraitTypeRepr { }
27+
class DynTraitTypeRepr extends Generated::DynTraitTypeRepr {
28+
/** Gets the trait that this trait object refers to. */
29+
pragma[nomagic]
30+
Trait getTrait() {
31+
result =
32+
PathResolution::resolvePath(this.getTypeBoundList()
33+
.getBound(0)
34+
.getTypeRepr()
35+
.(PathTypeRepr)
36+
.getPath())
37+
}
38+
}
2539
}

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,15 @@ newtype TType =
1616
TArrayType() or // todo: add size?
1717
TRefType() or // todo: add mut?
1818
TImplTraitType(ImplTraitTypeRepr impl) or
19+
TDynTraitType(Trait t) { t = any(DynTraitTypeRepr dt).getTrait() } or
1920
TSliceType() or
2021
TTypeParamTypeParameter(TypeParam t) or
2122
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
2223
TArrayTypeParameter() or
24+
TDynTraitTypeParameter(Trait t, TypeParam tp) {
25+
t = any(DynTraitTypeRepr dt).getTrait() and
26+
tp = t.getGenericParamList().getAGenericParam()
27+
} or
2328
TRefTypeParameter() or
2429
TSelfTypeParameter(Trait t) or
2530
TSliceTypeParameter()
@@ -226,6 +231,26 @@ class ImplTraitType extends Type, TImplTraitType {
226231
override Location getLocation() { result = impl.getLocation() }
227232
}
228233

234+
class DynTraitType extends Type, TDynTraitType {
235+
Trait trait;
236+
237+
DynTraitType() { this = TDynTraitType(trait) }
238+
239+
override StructField getStructField(string name) { none() }
240+
241+
override TupleField getTupleField(int i) { none() }
242+
243+
override DynTraitTypeParameter getTypeParameter(int i) {
244+
result = TDynTraitTypeParameter(trait, trait.getGenericParamList().getTypeParam(i))
245+
}
246+
247+
Trait getTrait() { result = trait }
248+
249+
override string toString() { result = "dyn " + trait.getName().toString() }
250+
251+
override Location getLocation() { result = trait.getLocation() }
252+
}
253+
229254
/**
230255
* An [impl Trait in return position][1] type, for example:
231256
*
@@ -336,6 +361,23 @@ class ArrayTypeParameter extends TypeParameter, TArrayTypeParameter {
336361
override Location getLocation() { result instanceof EmptyLocation }
337362
}
338363

364+
class DynTraitTypeParameter extends TypeParameter, TDynTraitTypeParameter {
365+
private Trait trait;
366+
private TypeParam typeParam;
367+
368+
DynTraitTypeParameter() { this = TDynTraitTypeParameter(trait, typeParam) }
369+
370+
Trait getTrait() { result = trait }
371+
372+
TypeParam getTypeParam() { result = typeParam }
373+
374+
override string toString() {
375+
result = "dyn " + trait.getName().toString() + "<" + typeParam.toString() + ">"
376+
}
377+
378+
override Location getLocation() { result = typeParam.getLocation() }
379+
}
380+
339381
/** An implicit reference type parameter. */
340382
class RefTypeParameter extends TypeParameter, TRefTypeParameter {
341383
override string toString() { result = "&T" }
@@ -420,6 +462,13 @@ final class ImplTypeAbstraction extends TypeAbstraction, Impl {
420462
}
421463
}
422464

465+
final class DynTypeAbstraction extends TypeAbstraction, DynTraitTypeRepr {
466+
override TypeParameter getATypeParameter() {
467+
result.(TypeParamTypeParameter).getTypeParam() =
468+
this.getTrait().getGenericParamList().getATypeParam()
469+
}
470+
}
471+
423472
final class TraitTypeAbstraction extends TypeAbstraction, Trait {
424473
override TypeParameter getATypeParameter() {
425474
result.(TypeParamTypeParameter).getTypeParam() = this.getGenericParamList().getATypeParam()

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ private module Input1 implements InputSig1<Location> {
9797
id = 2
9898
or
9999
kind = 1 and
100+
id = idOfTypeParameterAstNode(tp0.(DynTraitTypeParameter).getTypeParam())
101+
or
102+
kind = 2 and
100103
exists(AstNode node | id = idOfTypeParameterAstNode(node) |
101104
node = tp0.(TypeParamTypeParameter).getTypeParam() or
102105
node = tp0.(AssociatedTypeTypeParameter).getTypeAlias() or
@@ -182,6 +185,14 @@ private module Input2 implements InputSig2 {
182185
condition = impl and
183186
constraint = impl.getTypeBoundList().getABound().getTypeRepr()
184187
)
188+
or
189+
// a `dyn Trait` type implements `Trait`. See the comment on
190+
// `DynTypeBoundListMention` for further details.
191+
exists(DynTraitTypeRepr object |
192+
abs = object and
193+
condition = object.getTypeBoundList() and
194+
constraint = object.getTrait()
195+
)
185196
}
186197
}
187198

@@ -1655,10 +1666,16 @@ private Function getMethodFromImpl(MethodCall mc) {
16551666

16561667
bindingset[trait, name]
16571668
pragma[inline_late]
1658-
private Function getTraitMethod(ImplTraitReturnType trait, string name) {
1669+
private Function getImplTraitMethod(ImplTraitReturnType trait, string name) {
16591670
result = getMethodSuccessor(trait.getImplTraitTypeRepr(), name)
16601671
}
16611672

1673+
bindingset[traitObject, name]
1674+
pragma[inline_late]
1675+
private Function getDynTraitMethod(DynTraitType traitObject, string name) {
1676+
result = getMethodSuccessor(traitObject.getTrait(), name)
1677+
}
1678+
16621679
pragma[nomagic]
16631680
private Function resolveMethodCallTarget(MethodCall mc) {
16641681
// The method comes from an `impl` block targeting the type of the receiver.
@@ -1669,7 +1686,10 @@ private Function resolveMethodCallTarget(MethodCall mc) {
16691686
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
16701687
or
16711688
// The type of the receiver is an `impl Trait` type.
1672-
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
1689+
result = getImplTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
1690+
or
1691+
// The type of the receiver is a trait object `dyn Trait` type.
1692+
result = getDynTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
16731693
}
16741694

16751695
pragma[nomagic]
@@ -2006,6 +2026,13 @@ private module Debug {
20062026
result = resolveCallTarget(c)
20072027
}
20082028

2029+
predicate debugConditionSatisfiesConstraint(
2030+
TypeAbstraction abs, TypeMention condition, TypeMention constraint
2031+
) {
2032+
abs = getRelevantLocatable() and
2033+
Input2::conditionSatisfiesConstraint(abs, condition, constraint)
2034+
}
2035+
20092036
predicate debugInferImplicitSelfType(SelfParam self, TypePath path, Type t) {
20102037
self = getRelevantLocatable() and
20112038
t = inferImplicitSelfType(self, path)

rust/ql/lib/codeql/rust/internal/TypeMention.qll

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,57 @@ class SelfTypeParameterMention extends TypeMention instanceof Name {
268268
result = TSelfTypeParameter(trait)
269269
}
270270
}
271+
272+
class DynTraitTypeReprMention extends TypeMention instanceof DynTraitTypeRepr {
273+
override Type resolveTypeAt(TypePath path) {
274+
exists(DynTraitType rootTy | rootTy.getTrait() = super.getTrait() |
275+
path.isEmpty() and
276+
result = rootTy
277+
or
278+
exists(DynTraitTypeParameter tp, TypePath path0, TypePath suffix |
279+
tp = rootTy.getTypeParameter(_) and
280+
path = TypePath::cons(tp, suffix) and
281+
result =
282+
super.getTypeBoundList().getBound(0).getTypeRepr().(TypeMention).resolveTypeAt(path0) and
283+
path0.isCons(TTypeParamTypeParameter(tp.getTypeParam()), suffix)
284+
)
285+
)
286+
}
287+
}
288+
289+
// We want a type of the form `dyn Trait` to implement `Trait`. If `Trait` has
290+
// type parameters then `dyn Trait` has equivalent type parameters and the
291+
// implementation should be abstracted over them.
292+
//
293+
// Intuitively we want something to the effect of:
294+
// ```
295+
// impl<A, B, ..> Trait<A, B, ..> for (dyn Trait)<A, B, ..>
296+
// ```
297+
// To achieve this:
298+
// - `DynTypeAbstraction` is an abstraction over type parameters of the trait.
299+
// - `DynTypeBoundListMention` (this class) is a type mention which has `dyn
300+
// Trait` at the root and which for every type parameter of `dyn Trait` has the
301+
// corresponding type parameter of the trait.
302+
// - `TraitMention` (which is used for other things as well) is a type mention
303+
// for the trait applied to its own type parameters.
304+
//
305+
// We arbitrarily use the `TypeBoundList` inside `DynTraitTypeRepr` to encode
306+
// this type mention, since it doesn't syntactically appear in the AST. This
307+
// works because there is a one-to-one correspondence between a trait object and
308+
// its list of type bounds.
309+
class DynTypeBoundListMention extends TypeMention instanceof TypeBoundList {
310+
private DynTraitTypeRepr dyn;
311+
312+
DynTypeBoundListMention() { this = dyn.getTypeBoundList() }
313+
314+
override Type resolveTypeAt(TypePath path) {
315+
path.isEmpty() and
316+
result.(DynTraitType).getTrait() = dyn.getTrait()
317+
or
318+
exists(DynTraitTypeParameter tp |
319+
tp.getTrait() = dyn.getTrait() and
320+
path = TypePath::singleton(tp) and
321+
result = TTypeParamTypeParameter(tp.getTypeParam())
322+
)
323+
}
324+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
illFormedTypeMention
2+
| gen_dyn_trait_type_repr.rs:7:13:7:21 | DynTraitTypeRepr |
3+
| gen_dyn_trait_type_repr.rs:7:17:7:21 | TypeBoundList |

rust/ql/test/library-tests/type-inference/dyn_type.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,17 @@ fn get_box_trait<A: Clone + Debug + 'static>(a: A) -> Box<dyn GenericGet<A>> {
4545
}
4646

4747
fn test_basic_dyn_trait(obj: &dyn MyTrait1) {
48-
let _result = (*obj).m(); // $ target=deref MISSING: target=MyTrait1::m type=_result:String
48+
let _result = (*obj).m(); // $ target=deref target=MyTrait1::m type=_result:String
4949
}
5050

5151
fn test_generic_dyn_trait(obj: &dyn GenericGet<String>) {
52-
let _result1 = (*obj).get(); // $ target=deref MISSING: target=GenericGet::get type=_result1:String
53-
let _result2 = get_a(obj); // $ target=get_a MISSING: type=_result2:String
52+
let _result1 = (*obj).get(); // $ target=deref target=GenericGet::get type=_result1:String
53+
let _result2 = get_a(obj); // $ target=get_a type=_result2:String
5454
}
5555

5656
fn test_poly_dyn_trait() {
5757
let obj = get_box_trait(true); // $ target=get_box_trait
58-
let _result = (*obj).get(); // $ target=deref MISSING: target=GenericGet::get type=_result:bool
58+
let _result = (*obj).get(); // $ target=deref target=GenericGet::get type=_result:bool
5959
}
6060

6161
pub fn test() {

0 commit comments

Comments
 (0)