Skip to content

Commit 5b061b2

Browse files
committed
Auto merge of #13475 - lowr:fix/lookup-impl-method-trait-ref, r=Veykril
fix: Test all type args for trait when finding matching impl Addresses #13463 (comment) When finding matching impl for a trait method, we've been testing the unifiability of self type. However, there can be multiple impl of a trait for the same type with different generic arguments for the trait. This patch takes it into account and tests the unifiability of all type arguments for the trait (the first being the self type) thus enables rust-analyzer to find the correct impl even in such cases.
2 parents 98aa678 + 61fbde0 commit 5b061b2

File tree

3 files changed

+139
-40
lines changed

3 files changed

+139
-40
lines changed

crates/hir-ty/src/method_resolution.rs

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use crate::{
2525
static_lifetime,
2626
utils::all_super_traits,
2727
AdtId, Canonical, CanonicalVarKinds, DebruijnIndex, ForeignDefId, InEnvironment, Interner,
28-
Scalar, TraitEnvironment, TraitRefExt, Ty, TyBuilder, TyExt, TyKind,
28+
Scalar, Substitution, TraitEnvironment, TraitRefExt, Ty, TyBuilder, TyExt, TyKind,
2929
};
3030

3131
/// This is used as a key for indexing impls.
@@ -625,17 +625,22 @@ pub(crate) fn iterate_method_candidates<T>(
625625
}
626626

627627
pub fn lookup_impl_method(
628-
self_ty: &Ty,
629628
db: &dyn HirDatabase,
630629
env: Arc<TraitEnvironment>,
631630
trait_: TraitId,
632631
name: &Name,
632+
fn_subst: Substitution,
633633
) -> Option<FunctionId> {
634+
let trait_params = db.generic_params(trait_.into()).type_or_consts.len();
635+
let fn_params = fn_subst.len(Interner) - trait_params;
636+
let trait_subst = Substitution::from_iter(Interner, fn_subst.iter(Interner).skip(fn_params));
637+
638+
let self_ty = trait_subst.at(Interner, 0).ty(Interner)?;
634639
let self_ty_fp = TyFingerprint::for_trait_impl(self_ty)?;
635640
let trait_impls = db.trait_impls_in_deps(env.krate);
636641
let impls = trait_impls.for_trait_and_self_ty(trait_, self_ty_fp);
637-
let mut table = InferenceTable::new(db, env.clone());
638-
find_matching_impl(impls, &mut table, &self_ty).and_then(|data| {
642+
let mut table = InferenceTable::new(db, env);
643+
find_matching_impl(impls, &mut table, trait_subst).and_then(|data| {
639644
data.items.iter().find_map(|it| match it {
640645
AssocItemId::FunctionId(f) => (db.function_data(*f).name == *name).then(|| *f),
641646
_ => None,
@@ -646,30 +651,41 @@ pub fn lookup_impl_method(
646651
fn find_matching_impl(
647652
mut impls: impl Iterator<Item = ImplId>,
648653
table: &mut InferenceTable<'_>,
649-
self_ty: &Ty,
654+
expected_subst: Substitution,
650655
) -> Option<Arc<ImplData>> {
651656
let db = table.db;
652657
loop {
653658
let impl_ = impls.next()?;
654659
let r = table.run_in_snapshot(|table| {
655660
let impl_data = db.impl_data(impl_);
656-
let substs =
661+
let impl_substs =
657662
TyBuilder::subst_for_def(db, impl_, None).fill_with_inference_vars(table).build();
658-
let impl_ty = db.impl_self_ty(impl_).substitute(Interner, &substs);
659-
660-
table
661-
.unify(self_ty, &impl_ty)
662-
.then(|| {
663-
let wh_goals =
664-
crate::chalk_db::convert_where_clauses(db, impl_.into(), &substs)
665-
.into_iter()
666-
.map(|b| b.cast(Interner));
667-
668-
let goal = crate::Goal::all(Interner, wh_goals);
663+
let trait_ref = db
664+
.impl_trait(impl_)
665+
.expect("non-trait method in find_matching_impl")
666+
.substitute(Interner, &impl_substs);
667+
assert_eq!(trait_ref.substitution.len(Interner), expected_subst.len(Interner));
668+
669+
for (actual, expected) in
670+
trait_ref.substitution.iter(Interner).zip(expected_subst.iter(Interner))
671+
{
672+
// FIXME: test unifiability of const args when supported.
673+
if let Some(actual) = actual.ty(Interner) {
674+
if !stdx::always!(expected.ty(Interner).is_some()) {
675+
return None;
676+
}
677+
let expected = expected.assert_ty_ref(Interner);
678+
if !table.unify(actual, expected) {
679+
return None;
680+
}
681+
}
682+
}
669683

670-
table.try_obligation(goal).map(|_| impl_data)
671-
})
672-
.flatten()
684+
let wcs = crate::chalk_db::convert_where_clauses(db, impl_.into(), &impl_substs)
685+
.into_iter()
686+
.map(|b| b.cast(Interner));
687+
let goal = crate::Goal::all(Interner, wcs);
688+
table.try_obligation(goal).map(|_| impl_data)
673689
});
674690
if r.is_some() {
675691
break r;

crates/hir/src/source_analyzer.rs

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ impl SourceAnalyzer {
270270
let expr_id = self.expr_id(db, &call.clone().into())?;
271271
let (f_in_trait, substs) = self.infer.as_ref()?.method_resolution(expr_id)?;
272272

273-
Some(self.resolve_impl_method_or_trait_def(db, f_in_trait, &substs))
273+
Some(self.resolve_impl_method_or_trait_def(db, f_in_trait, substs))
274274
}
275275

276276
pub(crate) fn resolve_await_to_poll(
@@ -311,7 +311,7 @@ impl SourceAnalyzer {
311311
// HACK: subst for `poll()` coincides with that for `Future` because `poll()` itself
312312
// doesn't have any generic parameters, so we skip building another subst for `poll()`.
313313
let substs = hir_ty::TyBuilder::subst_for_def(db, future_trait, None).push(ty).build();
314-
Some(self.resolve_impl_method_or_trait_def(db, poll_fn, &substs))
314+
Some(self.resolve_impl_method_or_trait_def(db, poll_fn, substs))
315315
}
316316

317317
pub(crate) fn resolve_prefix_expr(
@@ -331,7 +331,7 @@ impl SourceAnalyzer {
331331
// don't have any generic parameters, so we skip building another subst for the methods.
332332
let substs = hir_ty::TyBuilder::subst_for_def(db, op_trait, None).push(ty.clone()).build();
333333

334-
Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs))
334+
Some(self.resolve_impl_method_or_trait_def(db, op_fn, substs))
335335
}
336336

337337
pub(crate) fn resolve_index_expr(
@@ -351,7 +351,7 @@ impl SourceAnalyzer {
351351
.push(base_ty.clone())
352352
.push(index_ty.clone())
353353
.build();
354-
Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs))
354+
Some(self.resolve_impl_method_or_trait_def(db, op_fn, substs))
355355
}
356356

357357
pub(crate) fn resolve_bin_expr(
@@ -372,7 +372,7 @@ impl SourceAnalyzer {
372372
.push(rhs.clone())
373373
.build();
374374

375-
Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs))
375+
Some(self.resolve_impl_method_or_trait_def(db, op_fn, substs))
376376
}
377377

378378
pub(crate) fn resolve_try_expr(
@@ -392,7 +392,7 @@ impl SourceAnalyzer {
392392
// doesn't have any generic parameters, so we skip building another subst for `branch()`.
393393
let substs = hir_ty::TyBuilder::subst_for_def(db, op_trait, None).push(ty.clone()).build();
394394

395-
Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs))
395+
Some(self.resolve_impl_method_or_trait_def(db, op_fn, substs))
396396
}
397397

398398
pub(crate) fn resolve_field(
@@ -487,17 +487,17 @@ impl SourceAnalyzer {
487487

488488
let mut prefer_value_ns = false;
489489
let resolved = (|| {
490+
let infer = self.infer.as_deref()?;
490491
if let Some(path_expr) = parent().and_then(ast::PathExpr::cast) {
491492
let expr_id = self.expr_id(db, &path_expr.into())?;
492-
let infer = self.infer.as_ref()?;
493493
if let Some(assoc) = infer.assoc_resolutions_for_expr(expr_id) {
494494
let assoc = match assoc {
495495
AssocItemId::FunctionId(f_in_trait) => {
496496
match infer.type_of_expr.get(expr_id) {
497497
None => assoc,
498498
Some(func_ty) => {
499499
if let TyKind::FnDef(_fn_def, subs) = func_ty.kind(Interner) {
500-
self.resolve_impl_method(db, f_in_trait, subs)
500+
self.resolve_impl_method(db, f_in_trait, subs.clone())
501501
.map(AssocItemId::FunctionId)
502502
.unwrap_or(assoc)
503503
} else {
@@ -520,18 +520,18 @@ impl SourceAnalyzer {
520520
prefer_value_ns = true;
521521
} else if let Some(path_pat) = parent().and_then(ast::PathPat::cast) {
522522
let pat_id = self.pat_id(&path_pat.into())?;
523-
if let Some(assoc) = self.infer.as_ref()?.assoc_resolutions_for_pat(pat_id) {
523+
if let Some(assoc) = infer.assoc_resolutions_for_pat(pat_id) {
524524
return Some(PathResolution::Def(AssocItem::from(assoc).into()));
525525
}
526526
if let Some(VariantId::EnumVariantId(variant)) =
527-
self.infer.as_ref()?.variant_resolution_for_pat(pat_id)
527+
infer.variant_resolution_for_pat(pat_id)
528528
{
529529
return Some(PathResolution::Def(ModuleDef::Variant(variant.into())));
530530
}
531531
} else if let Some(rec_lit) = parent().and_then(ast::RecordExpr::cast) {
532532
let expr_id = self.expr_id(db, &rec_lit.into())?;
533533
if let Some(VariantId::EnumVariantId(variant)) =
534-
self.infer.as_ref()?.variant_resolution_for_expr(expr_id)
534+
infer.variant_resolution_for_expr(expr_id)
535535
{
536536
return Some(PathResolution::Def(ModuleDef::Variant(variant.into())));
537537
}
@@ -541,8 +541,7 @@ impl SourceAnalyzer {
541541
|| parent().and_then(ast::TupleStructPat::cast).map(ast::Pat::from);
542542
if let Some(pat) = record_pat.or_else(tuple_struct_pat) {
543543
let pat_id = self.pat_id(&pat)?;
544-
let variant_res_for_pat =
545-
self.infer.as_ref()?.variant_resolution_for_pat(pat_id);
544+
let variant_res_for_pat = infer.variant_resolution_for_pat(pat_id);
546545
if let Some(VariantId::EnumVariantId(variant)) = variant_res_for_pat {
547546
return Some(PathResolution::Def(ModuleDef::Variant(variant.into())));
548547
}
@@ -784,31 +783,33 @@ impl SourceAnalyzer {
784783
&self,
785784
db: &dyn HirDatabase,
786785
func: FunctionId,
787-
substs: &Substitution,
786+
fn_substs: Substitution,
788787
) -> Option<FunctionId> {
789788
let impled_trait = match func.lookup(db.upcast()).container {
790789
ItemContainerId::TraitId(trait_id) => trait_id,
791790
_ => return None,
792791
};
793-
if substs.is_empty(Interner) {
794-
return None;
795-
}
796-
let self_ty = substs.at(Interner, 0).ty(Interner)?;
797792
let krate = self.resolver.krate();
798793
let trait_env = self.resolver.body_owner()?.as_generic_def_id().map_or_else(
799794
|| Arc::new(hir_ty::TraitEnvironment::empty(krate)),
800795
|d| db.trait_environment(d),
801796
);
802797

803798
let fun_data = db.function_data(func);
804-
method_resolution::lookup_impl_method(self_ty, db, trait_env, impled_trait, &fun_data.name)
799+
method_resolution::lookup_impl_method(
800+
db,
801+
trait_env,
802+
impled_trait,
803+
&fun_data.name,
804+
fn_substs,
805+
)
805806
}
806807

807808
fn resolve_impl_method_or_trait_def(
808809
&self,
809810
db: &dyn HirDatabase,
810811
func: FunctionId,
811-
substs: &Substitution,
812+
substs: Substitution,
812813
) -> FunctionId {
813814
self.resolve_impl_method(db, func, substs).unwrap_or(func)
814815
}

crates/ide/src/goto_definition.rs

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1834,4 +1834,86 @@ fn f() {
18341834
"#,
18351835
);
18361836
}
1837+
1838+
#[test]
1839+
fn goto_bin_op_multiple_impl() {
1840+
check(
1841+
r#"
1842+
//- minicore: add
1843+
struct S;
1844+
impl core::ops::Add for S {
1845+
fn add(
1846+
//^^^
1847+
) {}
1848+
}
1849+
impl core::ops::Add<usize> for S {
1850+
fn add(
1851+
) {}
1852+
}
1853+
1854+
fn f() {
1855+
S +$0 S
1856+
}
1857+
"#,
1858+
);
1859+
1860+
check(
1861+
r#"
1862+
//- minicore: add
1863+
struct S;
1864+
impl core::ops::Add for S {
1865+
fn add(
1866+
) {}
1867+
}
1868+
impl core::ops::Add<usize> for S {
1869+
fn add(
1870+
//^^^
1871+
) {}
1872+
}
1873+
1874+
fn f() {
1875+
S +$0 0usize
1876+
}
1877+
"#,
1878+
);
1879+
}
1880+
1881+
#[test]
1882+
fn path_call_multiple_trait_impl() {
1883+
check(
1884+
r#"
1885+
trait Trait<T> {
1886+
fn f(_: T);
1887+
}
1888+
impl Trait<i32> for usize {
1889+
fn f(_: i32) {}
1890+
//^
1891+
}
1892+
impl Trait<i64> for usize {
1893+
fn f(_: i64) {}
1894+
}
1895+
fn main() {
1896+
usize::f$0(0i32);
1897+
}
1898+
"#,
1899+
);
1900+
1901+
check(
1902+
r#"
1903+
trait Trait<T> {
1904+
fn f(_: T);
1905+
}
1906+
impl Trait<i32> for usize {
1907+
fn f(_: i32) {}
1908+
}
1909+
impl Trait<i64> for usize {
1910+
fn f(_: i64) {}
1911+
//^
1912+
}
1913+
fn main() {
1914+
usize::f$0(0i64);
1915+
}
1916+
"#,
1917+
)
1918+
}
18371919
}

0 commit comments

Comments
 (0)