Merge pull request #2277 from valentinewallace/2023-05-fix-big-oms
[rust-lightning] / lightning / src / routing / router.rs
index f547d9e4420ce27d14e34611e691aa66e79237ec..46010475cebfd56550ef06ede0cba4824265684a 100644 (file)
@@ -16,7 +16,7 @@ use bitcoin::hashes::sha256::Hash as Sha256;
 use crate::blinded_path::{BlindedHop, BlindedPath};
 use crate::ln::PaymentHash;
 use crate::ln::channelmanager::{ChannelDetails, PaymentId};
-use crate::ln::features::{ChannelFeatures, InvoiceFeatures, NodeFeatures};
+use crate::ln::features::{Bolt12InvoiceFeatures, ChannelFeatures, InvoiceFeatures, NodeFeatures};
 use crate::ln::msgs::{DecodeError, ErrorAction, LightningError, MAX_VALUE_MSAT};
 use crate::offers::invoice::BlindedPayInfo;
 use crate::routing::gossip::{DirectedChannelInfo, EffectiveCapacity, ReadOnlyNetworkGraph, NetworkGraph, NodeId, RoutingFees};
@@ -442,7 +442,7 @@ impl Writeable for RouteParameters {
                        (2, self.final_value_msat, required),
                        // LDK versions prior to 0.0.114 had the `final_cltv_expiry_delta` parameter in
                        // `RouteParameters` directly. For compatibility, we write it here.
-                       (4, self.payment_params.final_cltv_expiry_delta, required),
+                       (4, self.payment_params.payee.final_cltv_expiry_delta(), option),
                });
                Ok(())
        }
@@ -453,11 +453,13 @@ impl Readable for RouteParameters {
                _init_and_read_tlv_fields!(reader, {
                        (0, payment_params, (required: ReadableArgs, 0)),
                        (2, final_value_msat, required),
-                       (4, final_cltv_expiry_delta, required),
+                       (4, final_cltv_delta, option),
                });
                let mut payment_params: PaymentParameters = payment_params.0.unwrap();
-               if payment_params.final_cltv_expiry_delta == 0 {
-                       payment_params.final_cltv_expiry_delta = final_cltv_expiry_delta.0.unwrap();
+               if let Payee::Clear { ref mut final_cltv_expiry_delta, .. } = payment_params.payee {
+                       if final_cltv_expiry_delta == &0 {
+                               *final_cltv_expiry_delta = final_cltv_delta.ok_or(DecodeError::InvalidValue)?;
+                       }
                }
                Ok(Self {
                        payment_params,
@@ -526,9 +528,6 @@ pub struct PaymentParameters {
        /// payment to fail. Future attempts for the same payment shouldn't be relayed through any of
        /// these SCIDs.
        pub previously_failed_channels: Vec<u64>,
-
-       /// The minimum CLTV delta at the end of the route. This value must not be zero.
-       pub final_cltv_expiry_delta: u32,
 }
 
 impl Writeable for PaymentParameters {
@@ -537,19 +536,19 @@ impl Writeable for PaymentParameters {
                let mut blinded_hints = &vec![];
                match &self.payee {
                        Payee::Clear { route_hints, .. } => clear_hints = route_hints,
-                       Payee::Blinded { route_hints } => blinded_hints = route_hints,
+                       Payee::Blinded { route_hints, .. } => blinded_hints = route_hints,
                }
                write_tlv_fields!(writer, {
                        (0, self.payee.node_id(), option),
                        (1, self.max_total_cltv_expiry_delta, required),
-                       (2, if let Payee::Clear { features, .. } = &self.payee { features } else { &None }, option),
+                       (2, self.payee.features(), option),
                        (3, self.max_path_count, required),
                        (4, *clear_hints, vec_type),
                        (5, self.max_channel_saturation_power_of_half, required),
                        (6, self.expiry_time, option),
                        (7, self.previously_failed_channels, vec_type),
                        (8, *blinded_hints, optional_vec),
-                       (9, self.final_cltv_expiry_delta, required),
+                       (9, self.payee.final_cltv_expiry_delta(), option),
                });
                Ok(())
        }
@@ -560,7 +559,7 @@ impl ReadableArgs<u32> for PaymentParameters {
                _init_and_read_tlv_fields!(reader, {
                        (0, payee_pubkey, option),
                        (1, max_total_cltv_expiry_delta, (default_value, DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA)),
-                       (2, features, option),
+                       (2, features, (option: ReadableArgs, payee_pubkey.is_some())),
                        (3, max_path_count, (default_value, DEFAULT_MAX_PATH_COUNT)),
                        (4, route_hints, vec_type),
                        (5, max_channel_saturation_power_of_half, (default_value, 2)),
@@ -573,12 +572,16 @@ impl ReadableArgs<u32> for PaymentParameters {
                let blinded_route_hints = blinded_route_hints.unwrap_or(vec![]);
                let payee = if blinded_route_hints.len() != 0 {
                        if clear_route_hints.len() != 0 || payee_pubkey.is_some() { return Err(DecodeError::InvalidValue) }
-                       Payee::Blinded { route_hints: blinded_route_hints }
+                       Payee::Blinded {
+                               route_hints: blinded_route_hints,
+                               features: features.and_then(|f: Features| f.bolt12()),
+                       }
                } else {
                        Payee::Clear {
                                route_hints: clear_route_hints,
                                node_id: payee_pubkey.ok_or(DecodeError::InvalidValue)?,
-                               features,
+                               features: features.and_then(|f| f.bolt11()),
+                               final_cltv_expiry_delta: final_cltv_expiry_delta.0.unwrap(),
                        }
                };
                Ok(Self {
@@ -588,7 +591,6 @@ impl ReadableArgs<u32> for PaymentParameters {
                        max_channel_saturation_power_of_half: _init_tlv_based_struct_field!(max_channel_saturation_power_of_half, (default_value, unused)),
                        expiry_time,
                        previously_failed_channels: previously_failed_channels.unwrap_or(Vec::new()),
-                       final_cltv_expiry_delta: _init_tlv_based_struct_field!(final_cltv_expiry_delta, (default_value, unused)),
                })
        }
 }
@@ -601,13 +603,12 @@ impl PaymentParameters {
        /// provided.
        pub fn from_node_id(payee_pubkey: PublicKey, final_cltv_expiry_delta: u32) -> Self {
                Self {
-                       payee: Payee::Clear { node_id: payee_pubkey, route_hints: vec![], features: None },
+                       payee: Payee::Clear { node_id: payee_pubkey, route_hints: vec![], features: None, final_cltv_expiry_delta },
                        expiry_time: None,
                        max_total_cltv_expiry_delta: DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA,
                        max_path_count: DEFAULT_MAX_PATH_COUNT,
                        max_channel_saturation_power_of_half: 2,
                        previously_failed_channels: Vec::new(),
-                       final_cltv_expiry_delta,
                }
        }
 
@@ -626,8 +627,12 @@ impl PaymentParameters {
        pub fn with_bolt11_features(self, features: InvoiceFeatures) -> Result<Self, ()> {
                match self.payee {
                        Payee::Blinded { .. } => Err(()),
-                       Payee::Clear { route_hints, node_id, .. } =>
-                               Ok(Self { payee: Payee::Clear { route_hints, node_id, features: Some(features) }, ..self })
+                       Payee::Clear { route_hints, node_id, final_cltv_expiry_delta, .. } =>
+                               Ok(Self {
+                                       payee: Payee::Clear {
+                                               route_hints, node_id, features: Some(features), final_cltv_expiry_delta
+                                       }, ..self
+                               })
                }
        }
 
@@ -638,8 +643,12 @@ impl PaymentParameters {
        pub fn with_route_hints(self, route_hints: Vec<RouteHint>) -> Result<Self, ()> {
                match self.payee {
                        Payee::Blinded { .. } => Err(()),
-                       Payee::Clear { node_id, features, .. } =>
-                               Ok(Self { payee: Payee::Clear { route_hints, node_id, features }, ..self })
+                       Payee::Clear { node_id, features, final_cltv_expiry_delta, .. } =>
+                               Ok(Self {
+                                       payee: Payee::Clear {
+                                               route_hints, node_id, features, final_cltv_expiry_delta,
+                                       }, ..self
+                               })
                }
        }
 
@@ -682,6 +691,11 @@ pub enum Payee {
                /// Aggregated routing info and blinded paths, for routing to the payee without knowing their
                /// node id.
                route_hints: Vec<(BlindedPayInfo, BlindedPath)>,
+               /// Features supported by the payee.
+               ///
+               /// May be set from the payee's invoice. May be `None` if the invoice does not contain any
+               /// features.
+               features: Option<Bolt12InvoiceFeatures>,
        },
        /// The recipient included these route hints in their BOLT11 invoice.
        Clear {
@@ -696,6 +710,8 @@ pub enum Payee {
                ///
                /// [`for_keysend`]: PaymentParameters::for_keysend
                features: Option<InvoiceFeatures>,
+               /// The minimum CLTV delta at the end of the route. This value must not be zero.
+               final_cltv_expiry_delta: u32,
        },
 }
 
@@ -709,17 +725,69 @@ impl Payee {
        fn node_features(&self) -> Option<NodeFeatures> {
                match self {
                        Self::Clear { features, .. } => features.as_ref().map(|f| f.to_context()),
-                       _ => None,
+                       Self::Blinded { features, .. } => features.as_ref().map(|f| f.to_context()),
                }
        }
        fn supports_basic_mpp(&self) -> bool {
                match self {
                        Self::Clear { features, .. } => features.as_ref().map_or(false, |f| f.supports_basic_mpp()),
-                       _ => false,
+                       Self::Blinded { features, .. } => features.as_ref().map_or(false, |f| f.supports_basic_mpp()),
+               }
+       }
+       fn features(&self) -> Option<FeaturesRef> {
+               match self {
+                       Self::Clear { features, .. } => features.as_ref().map(|f| FeaturesRef::Bolt11(f)),
+                       Self::Blinded { features, .. } => features.as_ref().map(|f| FeaturesRef::Bolt12(f)),
+               }
+       }
+       fn final_cltv_expiry_delta(&self) -> Option<u32> {
+               match self {
+                       Self::Clear { final_cltv_expiry_delta, .. } => Some(*final_cltv_expiry_delta),
+                       _ => None,
                }
        }
 }
 
+enum FeaturesRef<'a> {
+       Bolt11(&'a InvoiceFeatures),
+       Bolt12(&'a Bolt12InvoiceFeatures),
+}
+enum Features {
+       Bolt11(InvoiceFeatures),
+       Bolt12(Bolt12InvoiceFeatures),
+}
+
+impl Features {
+       fn bolt12(self) -> Option<Bolt12InvoiceFeatures> {
+               match self {
+                       Self::Bolt12(f) => Some(f),
+                       _ => None,
+               }
+       }
+       fn bolt11(self) -> Option<InvoiceFeatures> {
+               match self {
+                       Self::Bolt11(f) => Some(f),
+                       _ => None,
+               }
+       }
+}
+
+impl<'a> Writeable for FeaturesRef<'a> {
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
+               match self {
+                       Self::Bolt11(f) => Ok(f.write(w)?),
+                       Self::Bolt12(f) => Ok(f.write(w)?),
+               }
+       }
+}
+
+impl ReadableArgs<bool> for Features {
+       fn read<R: io::Read>(reader: &mut R, bolt11: bool) -> Result<Self, DecodeError> {
+               if bolt11 { return Ok(Self::Bolt11(Readable::read(reader)?)) }
+               Ok(Self::Bolt12(Readable::read(reader)?))
+       }
+}
+
 /// A list of hops along a payment path terminating with a channel to the recipient.
 #[derive(Clone, Debug, Hash, Eq, PartialEq)]
 pub struct RouteHint(pub Vec<RouteHintHop>);
@@ -1202,7 +1270,8 @@ where L::Target: Logger {
                _ => return Err(LightningError{err: "Routing to blinded paths isn't supported yet".to_owned(), action: ErrorAction::IgnoreError}),
 
        }
-       if payment_params.max_total_cltv_expiry_delta <= payment_params.final_cltv_expiry_delta {
+       let final_cltv_expiry_delta = payment_params.payee.final_cltv_expiry_delta().unwrap_or(0);
+       if payment_params.max_total_cltv_expiry_delta <= final_cltv_expiry_delta {
                return Err(LightningError{err: "Can't find a route where the maximum total CLTV expiry delta is below the final CLTV expiry.".to_owned(), action: ErrorAction::IgnoreError});
        }
 
@@ -1433,9 +1502,9 @@ where L::Target: Logger {
                                        // In order to already account for some of the privacy enhancing random CLTV
                                        // expiry delta offset we add on top later, we subtract a rough estimate
                                        // (2*MEDIAN_HOP_CLTV_EXPIRY_DELTA) here.
-                                       let max_total_cltv_expiry_delta = (payment_params.max_total_cltv_expiry_delta - payment_params.final_cltv_expiry_delta)
+                                       let max_total_cltv_expiry_delta = (payment_params.max_total_cltv_expiry_delta - final_cltv_expiry_delta)
                                                .checked_sub(2*MEDIAN_HOP_CLTV_EXPIRY_DELTA)
-                                               .unwrap_or(payment_params.max_total_cltv_expiry_delta - payment_params.final_cltv_expiry_delta);
+                                               .unwrap_or(payment_params.max_total_cltv_expiry_delta - final_cltv_expiry_delta);
                                        let hop_total_cltv_delta = ($next_hops_cltv_delta as u32)
                                                .saturating_add($candidate.cltv_expiry_delta());
                                        let exceeds_cltv_delta_limit = hop_total_cltv_delta > max_total_cltv_expiry_delta;
@@ -2125,7 +2194,7 @@ where L::Target: Logger {
                }).collect::<Vec<_>>();
                // Propagate the cltv_expiry_delta one hop backwards since the delta from the current hop is
                // applicable for the previous hop.
-               path.iter_mut().rev().fold(payment_params.final_cltv_expiry_delta, |prev_cltv_expiry_delta, hop| {
+               path.iter_mut().rev().fold(final_cltv_expiry_delta, |prev_cltv_expiry_delta, hop| {
                        core::mem::replace(&mut hop.as_mut().unwrap().cltv_expiry_delta, prev_cltv_expiry_delta)
                });
                selected_paths.push(path);
@@ -2334,7 +2403,7 @@ mod tests {
        use crate::routing::scoring::{ChannelUsage, FixedPenaltyScorer, Score, ProbabilisticScorer, ProbabilisticScoringParameters};
        use crate::routing::test_utils::{add_channel, add_or_update_node, build_graph, build_line_graph, id_to_feature_flags, get_nodes, update_channel};
        use crate::chain::transaction::OutPoint;
-       use crate::chain::keysinterface::EntropySource;
+       use crate::sign::EntropySource;
        use crate::ln::features::{ChannelFeatures, InitFeatures, NodeFeatures};
        use crate::ln::msgs::{ErrorAction, LightningError, UnsignedChannelUpdate, MAX_VALUE_MSAT};
        use crate::ln::channelmanager;
@@ -6008,7 +6077,7 @@ mod benches {
        use bitcoin::hashes::Hash;
        use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey};
        use crate::chain::transaction::OutPoint;
-       use crate::chain::keysinterface::{EntropySource, KeysManager};
+       use crate::sign::{EntropySource, KeysManager};
        use crate::ln::channelmanager::{self, ChannelCounterparty, ChannelDetails};
        use crate::ln::features::InvoiceFeatures;
        use crate::routing::gossip::NetworkGraph;