diff --git a/chalk-engine/src/slg/aggregate.rs b/chalk-engine/src/slg/aggregate.rs index cce4e14dd61..24da5bdb127 100644 --- a/chalk-engine/src/slg/aggregate.rs +++ b/chalk-engine/src/slg/aggregate.rs @@ -17,7 +17,7 @@ pub trait AggregateOps { &self, root_goal: &UCanonical>>, answers: impl context::AnswerStream, - should_continue: impl std::ops::Fn() -> bool, + should_continue: impl std::ops::Fn() -> bool + Clone, ) -> Option>; } @@ -28,7 +28,7 @@ impl AggregateOps for SlgContextOps<'_, I> { &self, root_goal: &UCanonical>>, mut answers: impl context::AnswerStream, - should_continue: impl std::ops::Fn() -> bool, + should_continue: impl std::ops::Fn() -> bool + Clone, ) -> Option> { let interner = self.program.interner(); let CompleteAnswer { subst, ambiguous } = match answers.next_answer(&should_continue) { diff --git a/chalk-recursive/src/fixed_point.rs b/chalk-recursive/src/fixed_point.rs index b948c1c827b..517ca253b4d 100644 --- a/chalk-recursive/src/fixed_point.rs +++ b/chalk-recursive/src/fixed_point.rs @@ -43,6 +43,7 @@ where context: &mut RecursiveContext, goal: &K, minimums: &mut Minimums, + should_continue: impl std::ops::Fn() -> bool + Clone, ) -> V; fn reached_fixed_point(self, old_value: &V, new_value: &V) -> bool; fn error_value(self) -> V; @@ -104,22 +105,24 @@ where &mut self, canonical_goal: &K, solver_stuff: impl SolverStuff, + should_continue: impl std::ops::Fn() -> bool + Clone, ) -> V { debug!("solve_root_goal(canonical_goal={:?})", canonical_goal); assert!(self.stack.is_empty()); let minimums = &mut Minimums::new(); - self.solve_goal(canonical_goal, minimums, solver_stuff) + self.solve_goal(canonical_goal, minimums, solver_stuff, should_continue) } /// Attempt to solve a goal that has been fully broken down into leaf form /// and canonicalized. This is where the action really happens, and is the /// place where we would perform caching in rustc (and may eventually do in Chalk). - #[instrument(level = "info", skip(self, minimums, solver_stuff,))] + #[instrument(level = "info", skip(self, minimums, solver_stuff, should_continue))] pub fn solve_goal( &mut self, goal: &K, minimums: &mut Minimums, solver_stuff: impl SolverStuff, + should_continue: impl std::ops::Fn() -> bool + Clone, ) -> V { // First check the cache. if let Some(cache) = &self.cache { @@ -159,7 +162,8 @@ where let depth = self.stack.push(coinductive_goal); let dfn = self.search_graph.insert(goal, depth, initial_solution); - let subgoal_minimums = self.solve_new_subgoal(goal, depth, dfn, solver_stuff); + let subgoal_minimums = + self.solve_new_subgoal(goal, depth, dfn, solver_stuff, should_continue); self.search_graph[dfn].links = subgoal_minimums; self.search_graph[dfn].stack_depth = None; @@ -190,13 +194,14 @@ where } } - #[instrument(level = "debug", skip(self, solver_stuff))] + #[instrument(level = "debug", skip(self, solver_stuff, should_continue))] fn solve_new_subgoal( &mut self, canonical_goal: &K, depth: StackDepth, dfn: DepthFirstNumber, solver_stuff: impl SolverStuff, + should_continue: impl std::ops::Fn() -> bool + Clone, ) -> Minimums { // We start with `answer = None` and try to solve the goal. At the end of the iteration, // `answer` will be updated with the result of the solving process. If we detect a cycle @@ -209,7 +214,12 @@ where // so this function will eventually be constant and the loop terminates. loop { let minimums = &mut Minimums::new(); - let current_answer = solver_stuff.solve_iteration(self, canonical_goal, minimums); + let current_answer = solver_stuff.solve_iteration( + self, + canonical_goal, + minimums, + should_continue.clone(), // Note: cloning required as workaround for https://github.com/rust-lang/rust/issues/95734 + ); debug!( "solve_new_subgoal: loop iteration result = {:?} with minimums {:?}", diff --git a/chalk-recursive/src/fulfill.rs b/chalk-recursive/src/fulfill.rs index 6dbfb1f4086..2524ff795b8 100644 --- a/chalk-recursive/src/fulfill.rs +++ b/chalk-recursive/src/fulfill.rs @@ -342,16 +342,19 @@ impl<'s, I: Interner, Solver: SolveDatabase> Fulfill<'s, I, Solver> { Ok(()) } - #[instrument(level = "debug", skip(self, minimums))] + #[instrument(level = "debug", skip(self, minimums, should_continue))] fn prove( &mut self, wc: InEnvironment>, minimums: &mut Minimums, + should_continue: impl std::ops::Fn() -> bool + Clone, ) -> Fallible> { let interner = self.solver.interner(); let (quantified, free_vars) = canonicalize(&mut self.infer, interner, wc); let (quantified, universes) = u_canonicalize(&mut self.infer, interner, &quantified); - let result = self.solver.solve_goal(quantified, minimums); + let result = self + .solver + .solve_goal(quantified, minimums, should_continue); Ok(PositiveSolution { free_vars, universes, @@ -359,7 +362,11 @@ impl<'s, I: Interner, Solver: SolveDatabase> Fulfill<'s, I, Solver> { }) } - fn refute(&mut self, goal: InEnvironment>) -> Fallible { + fn refute( + &mut self, + goal: InEnvironment>, + should_continue: impl std::ops::Fn() -> bool + Clone, + ) -> Fallible { let canonicalized = match self .infer .invert_then_canonicalize(self.solver.interner(), goal) @@ -376,7 +383,10 @@ impl<'s, I: Interner, Solver: SolveDatabase> Fulfill<'s, I, Solver> { let (quantified, _) = u_canonicalize(&mut self.infer, self.solver.interner(), &canonicalized); let mut minimums = Minimums::new(); // FIXME -- minimums here seems wrong - if let Ok(solution) = self.solver.solve_goal(quantified, &mut minimums) { + if let Ok(solution) = self + .solver + .solve_goal(quantified, &mut minimums, should_continue) + { if solution.is_unique() { Err(NoSolution) } else { @@ -431,7 +441,11 @@ impl<'s, I: Interner, Solver: SolveDatabase> Fulfill<'s, I, Solver> { } } - fn fulfill(&mut self, minimums: &mut Minimums) -> Fallible { + fn fulfill( + &mut self, + minimums: &mut Minimums, + should_continue: impl std::ops::Fn() -> bool + Clone, + ) -> Fallible { debug_span!("fulfill", obligations=?self.obligations); // Try to solve all the obligations. We do this via a fixed-point @@ -460,7 +474,7 @@ impl<'s, I: Interner, Solver: SolveDatabase> Fulfill<'s, I, Solver> { free_vars, universes, solution, - } = self.prove(wc.clone(), minimums)?; + } = self.prove(wc.clone(), minimums, should_continue.clone())?; if let Some(constrained_subst) = solution.definite_subst(self.interner()) { // If the substitution is trivial, we won't actually make any progress by applying it! @@ -484,7 +498,7 @@ impl<'s, I: Interner, Solver: SolveDatabase> Fulfill<'s, I, Solver> { solution.is_ambig() } Obligation::Refute(goal) => { - let answer = self.refute(goal.clone())?; + let answer = self.refute(goal.clone(), should_continue.clone())?; answer == NegativeSolution::Ambiguous } }; @@ -514,8 +528,12 @@ impl<'s, I: Interner, Solver: SolveDatabase> Fulfill<'s, I, Solver> { /// Try to fulfill all pending obligations and build the resulting /// solution. The returned solution will transform `subst` substitution with /// the outcome of type inference by updating the replacements it provides. - pub(super) fn solve(mut self, minimums: &mut Minimums) -> Fallible> { - let outcome = match self.fulfill(minimums) { + pub(super) fn solve( + mut self, + minimums: &mut Minimums, + should_continue: impl std::ops::Fn() -> bool + Clone, + ) -> Fallible> { + let outcome = match self.fulfill(minimums, should_continue.clone()) { Ok(o) => o, Err(e) => return Err(e), }; @@ -567,7 +585,7 @@ impl<'s, I: Interner, Solver: SolveDatabase> Fulfill<'s, I, Solver> { free_vars, universes, solution, - } = self.prove(goal, minimums).unwrap(); + } = self.prove(goal, minimums, should_continue.clone()).unwrap(); if let Some(constrained_subst) = solution.constrained_subst(self.solver.interner()) { diff --git a/chalk-recursive/src/recursive.rs b/chalk-recursive/src/recursive.rs index 7f5203f01f7..61680c75269 100644 --- a/chalk-recursive/src/recursive.rs +++ b/chalk-recursive/src/recursive.rs @@ -76,8 +76,9 @@ impl SolverStuff, Fallible>> for &dyn context: &mut RecursiveContext, Fallible>>, goal: &UCanonicalGoal, minimums: &mut Minimums, + should_continue: impl std::ops::Fn() -> bool + Clone, ) -> Fallible> { - Solver::new(context, self).solve_iteration(goal, minimums) + Solver::new(context, self).solve_iteration(goal, minimums, should_continue) } fn reached_fixed_point( @@ -108,8 +109,10 @@ impl<'me, I: Interner> SolveDatabase for Solver<'me, I> { &mut self, goal: UCanonicalGoal, minimums: &mut Minimums, + should_continue: impl std::ops::Fn() -> bool + Clone, ) -> Fallible> { - self.context.solve_goal(&goal, minimums, self.program) + self.context + .solve_goal(&goal, minimums, self.program, should_continue) } fn interner(&self) -> I { @@ -131,17 +134,18 @@ impl chalk_solve::Solver for RecursiveSolver { program: &dyn RustIrDatabase, goal: &UCanonical>>, ) -> Option> { - self.ctx.solve_root_goal(goal, program).ok() + self.ctx.solve_root_goal(goal, program, || true).ok() } fn solve_limited( &mut self, program: &dyn RustIrDatabase, goal: &UCanonical>>, - _should_continue: &dyn std::ops::Fn() -> bool, + should_continue: &dyn std::ops::Fn() -> bool, ) -> Option> { - // TODO support should_continue in recursive solver - self.ctx.solve_root_goal(goal, program).ok() + self.ctx + .solve_root_goal(goal, program, should_continue) + .ok() } fn solve_multiple( diff --git a/chalk-recursive/src/solve.rs b/chalk-recursive/src/solve.rs index 62b029d33ee..ba2525467b9 100644 --- a/chalk-recursive/src/solve.rs +++ b/chalk-recursive/src/solve.rs @@ -20,6 +20,7 @@ pub(super) trait SolveDatabase: Sized { &mut self, goal: UCanonical>>, minimums: &mut Minimums, + should_continue: impl std::ops::Fn() -> bool + Clone, ) -> Fallible>; fn max_size(&self) -> usize; @@ -35,12 +36,17 @@ pub(super) trait SolveIteration: SolveDatabase { /// Executes one iteration of the recursive solver, computing the current /// solution to the given canonical goal. This is used as part of a loop in /// the case of cyclic goals. - #[instrument(level = "debug", skip(self))] + #[instrument(level = "debug", skip(self, should_continue))] fn solve_iteration( &mut self, canonical_goal: &UCanonicalGoal, minimums: &mut Minimums, + should_continue: impl std::ops::Fn() -> bool + Clone, ) -> Fallible> { + if !should_continue() { + return Ok(Solution::Ambig(Guidance::Unknown)); + } + let UCanonical { universes, canonical: @@ -72,7 +78,7 @@ pub(super) trait SolveIteration: SolveDatabase { let prog_solution = { debug_span!("prog_clauses"); - self.solve_from_clauses(&canonical_goal, minimums) + self.solve_from_clauses(&canonical_goal, minimums, should_continue) }; debug!(?prog_solution); @@ -88,7 +94,7 @@ pub(super) trait SolveIteration: SolveDatabase { }, }; - self.solve_via_simplification(&canonical_goal, minimums) + self.solve_via_simplification(&canonical_goal, minimums, should_continue) } } } @@ -103,15 +109,16 @@ where /// Helper methods for `solve_iteration`, private to this module. trait SolveIterationHelpers: SolveDatabase { - #[instrument(level = "debug", skip(self, minimums))] + #[instrument(level = "debug", skip(self, minimums, should_continue))] fn solve_via_simplification( &mut self, canonical_goal: &UCanonicalGoal, minimums: &mut Minimums, + should_continue: impl std::ops::Fn() -> bool + Clone, ) -> Fallible> { let (infer, subst, goal) = self.new_inference_table(canonical_goal); match Fulfill::new_with_simplification(self, infer, subst, goal) { - Ok(fulfill) => fulfill.solve(minimums), + Ok(fulfill) => fulfill.solve(minimums, should_continue), Err(e) => Err(e), } } @@ -123,6 +130,7 @@ trait SolveIterationHelpers: SolveDatabase { &mut self, canonical_goal: &UCanonical>>, minimums: &mut Minimums, + should_continue: impl std::ops::Fn() -> bool + Clone, ) -> Fallible> { let mut clauses = vec![]; @@ -159,7 +167,10 @@ trait SolveIterationHelpers: SolveDatabase { let subst = subst.clone(); let goal = goal.clone(); let res = match Fulfill::new_with_clause(self, infer, subst, goal, implication) { - Ok(fulfill) => (fulfill.solve(minimums), implication.skip_binders().priority), + Ok(fulfill) => ( + fulfill.solve(minimums, should_continue.clone()), + implication.skip_binders().priority, + ), Err(e) => (Err(e), ClausePriority::High), };