Use `crate::prelude::*` rather than specific imports
[rust-lightning] / lightning / src / blinded_path / payment.rs
index 39f16a91692cb30fb1be087191b4c32a506de830..3d56020a626d68cd1e233f53a905cfdfa8acf21c 100644 (file)
@@ -8,13 +8,14 @@ use crate::blinded_path::BlindedHop;
 use crate::blinded_path::utils;
 use crate::io;
 use crate::ln::PaymentSecret;
+use crate::ln::channelmanager::CounterpartyForwardingInfo;
 use crate::ln::features::BlindedHopFeatures;
 use crate::ln::msgs::DecodeError;
 use crate::offers::invoice::BlindedPayInfo;
-use crate::prelude::*;
-use crate::util::ser::{Readable, Writeable, Writer};
+use crate::util::ser::{HighZeroBytesDroppedBigSize, Readable, Writeable, Writer};
 
-use core::convert::TryFrom;
+#[allow(unused_imports)]
+use crate::prelude::*;
 
 /// An intermediate node, its outbound channel, and relay parameters.
 #[derive(Clone, Debug)]
@@ -89,21 +90,44 @@ pub struct PaymentRelay {
 /// [`BlindedHop`]: crate::blinded_path::BlindedHop
 #[derive(Clone, Debug)]
 pub struct PaymentConstraints {
-       /// The maximum total CLTV delta that is acceptable when relaying a payment over this
-       /// [`BlindedHop`].
+       /// The maximum total CLTV that is acceptable when relaying a payment over this [`BlindedHop`].
        pub max_cltv_expiry: u32,
        /// The minimum value, in msat, that may be accepted by the node corresponding to this
        /// [`BlindedHop`].
        pub htlc_minimum_msat: u64,
 }
 
+impl TryFrom<CounterpartyForwardingInfo> for PaymentRelay {
+       type Error = ();
+
+       fn try_from(info: CounterpartyForwardingInfo) -> Result<Self, ()> {
+               let CounterpartyForwardingInfo {
+                       fee_base_msat, fee_proportional_millionths, cltv_expiry_delta
+               } = info;
+
+               // Avoid exposing esoteric CLTV expiry deltas
+               let cltv_expiry_delta = match cltv_expiry_delta {
+                       0..=40 => 40,
+                       41..=80 => 80,
+                       81..=144 => 144,
+                       145..=216 => 216,
+                       _ => return Err(()),
+               };
+
+               Ok(Self { cltv_expiry_delta, fee_proportional_millionths, fee_base_msat })
+       }
+}
+
 impl Writeable for ForwardTlvs {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
+               let features_opt =
+                       if self.features == BlindedHopFeatures::empty() { None }
+                       else { Some(&self.features) };
                encode_tlv_stream!(w, {
                        (2, self.short_channel_id, required),
                        (10, self.payment_relay, required),
                        (12, self.payment_constraints, required),
-                       (14, self.features, required)
+                       (14, features_opt, option)
                });
                Ok(())
        }
@@ -119,21 +143,6 @@ impl Writeable for ReceiveTlvs {
        }
 }
 
-// This will be removed once we support forwarding blinded HTLCs, because we'll always read a
-// `BlindedPaymentTlvs` instead.
-impl Readable for ReceiveTlvs {
-       fn read<R: io::Read>(r: &mut R) -> Result<Self, DecodeError> {
-               _init_and_read_tlv_stream!(r, {
-                       (12, payment_constraints, required),
-                       (65536, payment_secret, required),
-               });
-               Ok(Self {
-                       payment_secret: payment_secret.0.unwrap(),
-                       payment_constraints: payment_constraints.0.unwrap()
-               })
-       }
-}
-
 impl<'a> Writeable for BlindedPaymentTlvsRef<'a> {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
                // TODO: write padding
@@ -163,7 +172,7 @@ impl Readable for BlindedPaymentTlvs {
                                short_channel_id,
                                payment_relay: payment_relay.ok_or(DecodeError::InvalidValue)?,
                                payment_constraints: payment_constraints.0.unwrap(),
-                               features: features.ok_or(DecodeError::InvalidValue)?,
+                               features: features.unwrap_or_else(BlindedHopFeatures::empty),
                        }))
                } else {
                        if payment_relay.is_some() || features.is_some() { return Err(DecodeError::InvalidValue) }
@@ -188,7 +197,7 @@ pub(super) fn blinded_hops<T: secp256k1::Signing + secp256k1::Verification>(
 }
 
 /// `None` if underflow occurs.
-fn amt_to_forward_msat(inbound_amt_msat: u64, payment_relay: &PaymentRelay) -> Option<u64> {
+pub(crate) fn amt_to_forward_msat(inbound_amt_msat: u64, payment_relay: &PaymentRelay) -> Option<u64> {
        let inbound_amt = inbound_amt_msat as u128;
        let base = payment_relay.fee_base_msat as u128;
        let prop = payment_relay.fee_proportional_millionths as u128;
@@ -209,11 +218,12 @@ fn amt_to_forward_msat(inbound_amt_msat: u64, payment_relay: &PaymentRelay) -> O
 }
 
 pub(super) fn compute_payinfo(
-       intermediate_nodes: &[ForwardNode], payee_tlvs: &ReceiveTlvs, payee_htlc_maximum_msat: u64
+       intermediate_nodes: &[ForwardNode], payee_tlvs: &ReceiveTlvs, payee_htlc_maximum_msat: u64,
+       min_final_cltv_expiry_delta: u16
 ) -> Result<BlindedPayInfo, ()> {
        let mut curr_base_fee: u64 = 0;
        let mut curr_prop_mil: u64 = 0;
-       let mut cltv_expiry_delta: u16 = 0;
+       let mut cltv_expiry_delta: u16 = min_final_cltv_expiry_delta;
        for tlvs in intermediate_nodes.iter().rev().map(|n| &n.tlvs) {
                // In the future, we'll want to take the intersection of all supported features for the
                // `BlindedPayInfo`, but there are no features in that context right now.
@@ -269,16 +279,35 @@ pub(super) fn compute_payinfo(
        })
 }
 
-impl_writeable_msg!(PaymentRelay, {
-       cltv_expiry_delta,
-       fee_proportional_millionths,
-       fee_base_msat
-}, {});
+impl Writeable for PaymentRelay {
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
+               self.cltv_expiry_delta.write(w)?;
+               self.fee_proportional_millionths.write(w)?;
+               HighZeroBytesDroppedBigSize(self.fee_base_msat).write(w)
+       }
+}
+impl Readable for PaymentRelay {
+       fn read<R: io::Read>(r: &mut R) -> Result<Self, DecodeError> {
+               let cltv_expiry_delta: u16 = Readable::read(r)?;
+               let fee_proportional_millionths: u32 = Readable::read(r)?;
+               let fee_base_msat: HighZeroBytesDroppedBigSize<u32> = Readable::read(r)?;
+               Ok(Self { cltv_expiry_delta, fee_proportional_millionths, fee_base_msat: fee_base_msat.0 })
+       }
+}
 
-impl_writeable_msg!(PaymentConstraints, {
-       max_cltv_expiry,
-       htlc_minimum_msat
-}, {});
+impl Writeable for PaymentConstraints {
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
+               self.max_cltv_expiry.write(w)?;
+               HighZeroBytesDroppedBigSize(self.htlc_minimum_msat).write(w)
+       }
+}
+impl Readable for PaymentConstraints {
+       fn read<R: io::Read>(r: &mut R) -> Result<Self, DecodeError> {
+               let max_cltv_expiry: u32 = Readable::read(r)?;
+               let htlc_minimum_msat: HighZeroBytesDroppedBigSize<u64> = Readable::read(r)?;
+               Ok(Self { max_cltv_expiry, htlc_minimum_msat: htlc_minimum_msat.0 })
+       }
+}
 
 #[cfg(test)]
 mod tests {
@@ -286,6 +315,7 @@ mod tests {
        use crate::blinded_path::payment::{ForwardNode, ForwardTlvs, ReceiveTlvs, PaymentConstraints, PaymentRelay};
        use crate::ln::PaymentSecret;
        use crate::ln::features::BlindedHopFeatures;
+       use crate::ln::functional_test_utils::TEST_FINAL_CLTV;
 
        #[test]
        fn compute_payinfo() {
@@ -333,10 +363,10 @@ mod tests {
                        },
                };
                let htlc_maximum_msat = 100_000;
-               let blinded_payinfo = super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_maximum_msat).unwrap();
+               let blinded_payinfo = super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_maximum_msat, 12).unwrap();
                assert_eq!(blinded_payinfo.fee_base_msat, 201);
                assert_eq!(blinded_payinfo.fee_proportional_millionths, 1001);
-               assert_eq!(blinded_payinfo.cltv_expiry_delta, 288);
+               assert_eq!(blinded_payinfo.cltv_expiry_delta, 300);
                assert_eq!(blinded_payinfo.htlc_minimum_msat, 900);
                assert_eq!(blinded_payinfo.htlc_maximum_msat, htlc_maximum_msat);
        }
@@ -350,10 +380,10 @@ mod tests {
                                htlc_minimum_msat: 1,
                        },
                };
-               let blinded_payinfo = super::compute_payinfo(&[], &recv_tlvs, 4242).unwrap();
+               let blinded_payinfo = super::compute_payinfo(&[], &recv_tlvs, 4242, TEST_FINAL_CLTV as u16).unwrap();
                assert_eq!(blinded_payinfo.fee_base_msat, 0);
                assert_eq!(blinded_payinfo.fee_proportional_millionths, 0);
-               assert_eq!(blinded_payinfo.cltv_expiry_delta, 0);
+               assert_eq!(blinded_payinfo.cltv_expiry_delta, TEST_FINAL_CLTV as u16);
                assert_eq!(blinded_payinfo.htlc_minimum_msat, 1);
                assert_eq!(blinded_payinfo.htlc_maximum_msat, 4242);
        }
@@ -404,7 +434,7 @@ mod tests {
                        },
                };
                let htlc_maximum_msat = 100_000;
-               let blinded_payinfo = super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_maximum_msat).unwrap();
+               let blinded_payinfo = super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_maximum_msat, TEST_FINAL_CLTV as u16).unwrap();
                assert_eq!(blinded_payinfo.htlc_minimum_msat, 2_000);
        }
 
@@ -454,10 +484,10 @@ mod tests {
                        },
                };
                let htlc_minimum_msat = 3798;
-               assert!(super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_minimum_msat - 1).is_err());
+               assert!(super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_minimum_msat - 1, TEST_FINAL_CLTV as u16).is_err());
 
                let htlc_maximum_msat = htlc_minimum_msat + 1;
-               let blinded_payinfo = super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_maximum_msat).unwrap();
+               let blinded_payinfo = super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_maximum_msat, TEST_FINAL_CLTV as u16).unwrap();
                assert_eq!(blinded_payinfo.htlc_minimum_msat, htlc_minimum_msat);
                assert_eq!(blinded_payinfo.htlc_maximum_msat, htlc_maximum_msat);
        }
@@ -508,7 +538,7 @@ mod tests {
                        },
                };
 
-               let blinded_payinfo = super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, 10_000).unwrap();
+               let blinded_payinfo = super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, 10_000, TEST_FINAL_CLTV as u16).unwrap();
                assert_eq!(blinded_payinfo.htlc_maximum_msat, 3997);
        }
 }