diff --git a/lightning/src/routing/mod.rs b/lightning/src/routing/mod.rs index 1cf73c2689e..51ffd91b504 100644 --- a/lightning/src/routing/mod.rs +++ b/lightning/src/routing/mod.rs @@ -13,10 +13,13 @@ pub mod network_graph; pub mod router; pub mod scorer; +use routing::network_graph::NodeId; + /// An interface used to score payment channels for path finding. /// /// Scoring is in terms of fees willing to be paid in order to avoid routing through a channel. pub trait Score { - /// Returns the fee in msats willing to be paid to avoid routing through the given channel. - fn channel_penalty_msat(&self, short_channel_id: u64) -> u64; + /// Returns the fee in msats willing to be paid to avoid routing through the given channel + /// in the direction from `source` to `target`. + fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId) -> u64; } diff --git a/lightning/src/routing/router.rs b/lightning/src/routing/router.rs index 4083114dcb9..b617eebd42d 100644 --- a/lightning/src/routing/router.rs +++ b/lightning/src/routing/router.rs @@ -748,7 +748,7 @@ where L::Target: Logger { } let path_penalty_msat = $next_hops_path_penalty_msat - .checked_add(scorer.channel_penalty_msat($chan_id.clone())) + .checked_add(scorer.channel_penalty_msat($chan_id.clone(), &$src_node_id, &$dest_node_id)) .unwrap_or_else(|| u64::max_value()); let new_graph_node = RouteGraphNode { node_id: $src_node_id, @@ -973,15 +973,17 @@ where L::Target: Logger { _ => aggregate_next_hops_fee_msat.checked_add(999).unwrap_or(u64::max_value()) }) { Some( val / 1000 ) } else { break; }; // converting from msat or breaking if max ~ infinity + let src_node_id = NodeId::from_pubkey(&hop.src_node_id); + let dest_node_id = NodeId::from_pubkey(&prev_hop_id); aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat - .checked_add(scorer.channel_penalty_msat(hop.short_channel_id)) + .checked_add(scorer.channel_penalty_msat(hop.short_channel_id, &src_node_id, &dest_node_id)) .unwrap_or_else(|| u64::max_value()); // We assume that the recipient only included route hints for routes which had // sufficient value to route `final_value_msat`. Note that in the case of "0-value" // invoices where the invoice does not specify value this may not be the case, but // better to include the hints than not. - if !add_entry!(hop.short_channel_id, NodeId::from_pubkey(&hop.src_node_id), NodeId::from_pubkey(&prev_hop_id), directional_info, reqd_channel_cap, &empty_channel_features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat) { + if !add_entry!(hop.short_channel_id, src_node_id, dest_node_id, directional_info, reqd_channel_cap, &empty_channel_features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat) { // If this hop was not used then there is no use checking the preceding hops // in the RouteHint. We can break by just searching for a direct channel between // last checked hop and first_hop_targets @@ -1322,7 +1324,8 @@ where L::Target: Logger { #[cfg(test)] mod tests { - use routing::network_graph::{NetworkGraph, NetGraphMsgHandler}; + use routing; + use routing::network_graph::{NetworkGraph, NetGraphMsgHandler, NodeId}; use routing::router::{get_route, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees}; use routing::scorer::Scorer; use chain::transaction::OutPoint; @@ -4351,42 +4354,92 @@ mod tests { let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph(); let (_, our_id, _, nodes) = get_nodes(&secp_ctx); + // Without penalizing each hop 100 msats, a longer path with lower fees is chosen. + let scorer = Scorer::new(0); + let route = get_route( + &our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None, + &last_hops(&nodes).iter().collect::>(), 100, 42, Arc::clone(&logger), &scorer + ).unwrap(); + let path = route.paths[0].iter().map(|hop| hop.short_channel_id).collect::>(); + + assert_eq!(route.get_total_fees(), 100); + assert_eq!(route.get_total_amount(), 100); + assert_eq!(path, vec![2, 4, 6, 11, 8]); + // Applying a 100 msat penalty to each hop results in taking channels 7 and 10 to nodes[6] // from nodes[2] rather than channel 6, 11, and 8, even though the longer path is cheaper. let scorer = Scorer::new(100); - let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None, &last_hops(&nodes).iter().collect::>(), 100, 42, Arc::clone(&logger), &scorer).unwrap(); - assert_eq!(route.paths[0].len(), 4); + let route = get_route( + &our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None, + &last_hops(&nodes).iter().collect::>(), 100, 42, Arc::clone(&logger), &scorer + ).unwrap(); + let path = route.paths[0].iter().map(|hop| hop.short_channel_id).collect::>(); - assert_eq!(route.paths[0][0].pubkey, nodes[1]); - assert_eq!(route.paths[0][0].short_channel_id, 2); - assert_eq!(route.paths[0][0].fee_msat, 200); - assert_eq!(route.paths[0][0].cltv_expiry_delta, (4 << 8) | 1); - assert_eq!(route.paths[0][0].node_features.le_flags(), &id_to_feature_flags(2)); - assert_eq!(route.paths[0][0].channel_features.le_flags(), &id_to_feature_flags(2)); + assert_eq!(route.get_total_fees(), 300); + assert_eq!(route.get_total_amount(), 100); + assert_eq!(path, vec![2, 4, 7, 10]); + } - assert_eq!(route.paths[0][1].pubkey, nodes[2]); - assert_eq!(route.paths[0][1].short_channel_id, 4); - assert_eq!(route.paths[0][1].fee_msat, 100); - assert_eq!(route.paths[0][1].cltv_expiry_delta, (7 << 8) | 1); - assert_eq!(route.paths[0][1].node_features.le_flags(), &id_to_feature_flags(3)); - assert_eq!(route.paths[0][1].channel_features.le_flags(), &id_to_feature_flags(4)); + struct BadChannelScorer { + short_channel_id: u64, + } - assert_eq!(route.paths[0][2].pubkey, nodes[5]); - assert_eq!(route.paths[0][2].short_channel_id, 7); - assert_eq!(route.paths[0][2].fee_msat, 0); - assert_eq!(route.paths[0][2].cltv_expiry_delta, (10 << 8) | 1); - assert_eq!(route.paths[0][2].node_features.le_flags(), &id_to_feature_flags(6)); - assert_eq!(route.paths[0][2].channel_features.le_flags(), &id_to_feature_flags(7)); + impl routing::Score for BadChannelScorer { + fn channel_penalty_msat(&self, short_channel_id: u64, _source: &NodeId, _target: &NodeId) -> u64 { + if short_channel_id == self.short_channel_id { u64::max_value() } else { 0 } + } + } - assert_eq!(route.paths[0][3].pubkey, nodes[6]); - assert_eq!(route.paths[0][3].short_channel_id, 10); - assert_eq!(route.paths[0][3].fee_msat, 100); - assert_eq!(route.paths[0][3].cltv_expiry_delta, 42); - assert_eq!(route.paths[0][3].node_features.le_flags(), &Vec::::new()); // We don't pass flags in from invoices yet - assert_eq!(route.paths[0][3].channel_features.le_flags(), &Vec::::new()); // We can't learn any flags from invoices, sadly + struct BadNodeScorer { + node_id: NodeId, + } + + impl routing::Score for BadNodeScorer { + fn channel_penalty_msat(&self, _short_channel_id: u64, _source: &NodeId, target: &NodeId) -> u64 { + if *target == self.node_id { u64::max_value() } else { 0 } + } + } + + #[test] + fn avoids_routing_through_bad_channels_and_nodes() { + let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph(); + let (_, our_id, _, nodes) = get_nodes(&secp_ctx); + + // A path to nodes[6] exists when no penalties are applied to any channel. + let scorer = Scorer::new(0); + let route = get_route( + &our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None, + &last_hops(&nodes).iter().collect::>(), 100, 42, Arc::clone(&logger), &scorer + ).unwrap(); + let path = route.paths[0].iter().map(|hop| hop.short_channel_id).collect::>(); + + assert_eq!(route.get_total_fees(), 100); + assert_eq!(route.get_total_amount(), 100); + assert_eq!(path, vec![2, 4, 6, 11, 8]); + + // A different path to nodes[6] exists if channel 6 cannot be routed over. + let scorer = BadChannelScorer { short_channel_id: 6 }; + let route = get_route( + &our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None, + &last_hops(&nodes).iter().collect::>(), 100, 42, Arc::clone(&logger), &scorer + ).unwrap(); + let path = route.paths[0].iter().map(|hop| hop.short_channel_id).collect::>(); assert_eq!(route.get_total_fees(), 300); assert_eq!(route.get_total_amount(), 100); + assert_eq!(path, vec![2, 4, 7, 10]); + + // A path to nodes[6] does not exist if nodes[2] cannot be routed through. + let scorer = BadNodeScorer { node_id: NodeId::from_pubkey(&nodes[2]) }; + match get_route( + &our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None, + &last_hops(&nodes).iter().collect::>(), 100, 42, Arc::clone(&logger), &scorer + ) { + Err(LightningError { err, .. } ) => { + assert_eq!(err, "Failed to find a path to the given destination"); + }, + Ok(_) => panic!("Expected error"), + } } #[test] diff --git a/lightning/src/routing/scorer.rs b/lightning/src/routing/scorer.rs index f58da652096..0f43c3d7928 100644 --- a/lightning/src/routing/scorer.rs +++ b/lightning/src/routing/scorer.rs @@ -44,6 +44,8 @@ use routing; +use routing::network_graph::NodeId; + /// [`routing::Score`] implementation that provides reasonable default behavior. /// /// Used to apply a fixed penalty to each channel, thus avoiding long paths when shorter paths with @@ -71,5 +73,9 @@ impl Default for Scorer { } impl routing::Score for Scorer { - fn channel_penalty_msat(&self, _short_channel_id: u64) -> u64 { self.base_penalty_msat } + fn channel_penalty_msat( + &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId + ) -> u64 { + self.base_penalty_msat + } }