Fix ser for PaymentRelay and PaymentConstraints.
[rust-lightning] / lightning / src / blinded_path / payment.rs
index 32181f7889c350fa3f23f2bdfe37a437d99808d4..7f938d6617f71a49ca65b5ddc9637ad9296b5665 100644 (file)
@@ -8,11 +8,12 @@ use crate::blinded_path::BlindedHop;
 use crate::blinded_path::utils;
 use crate::io;
 use crate::ln::PaymentSecret;
+use crate::ln::channelmanager::CounterpartyForwardingInfo;
 use crate::ln::features::BlindedHopFeatures;
 use crate::ln::msgs::DecodeError;
 use crate::offers::invoice::BlindedPayInfo;
 use crate::prelude::*;
-use crate::util::ser::{Readable, Writeable, Writer};
+use crate::util::ser::{HighZeroBytesDroppedBigSize, Readable, Writeable, Writer};
 
 use core::convert::TryFrom;
 
@@ -89,14 +90,34 @@ pub struct PaymentRelay {
 /// [`BlindedHop`]: crate::blinded_path::BlindedHop
 #[derive(Clone, Debug)]
 pub struct PaymentConstraints {
-       /// The maximum total CLTV delta that is acceptable when relaying a payment over this
-       /// [`BlindedHop`].
+       /// The maximum total CLTV that is acceptable when relaying a payment over this [`BlindedHop`].
        pub max_cltv_expiry: u32,
        /// The minimum value, in msat, that may be accepted by the node corresponding to this
        /// [`BlindedHop`].
        pub htlc_minimum_msat: u64,
 }
 
+impl TryFrom<CounterpartyForwardingInfo> for PaymentRelay {
+       type Error = ();
+
+       fn try_from(info: CounterpartyForwardingInfo) -> Result<Self, ()> {
+               let CounterpartyForwardingInfo {
+                       fee_base_msat, fee_proportional_millionths, cltv_expiry_delta
+               } = info;
+
+               // Avoid exposing esoteric CLTV expiry deltas
+               let cltv_expiry_delta = match cltv_expiry_delta {
+                       0..=40 => 40,
+                       41..=80 => 80,
+                       81..=144 => 144,
+                       145..=216 => 216,
+                       _ => return Err(()),
+               };
+
+               Ok(Self { cltv_expiry_delta, fee_proportional_millionths, fee_base_msat })
+       }
+}
+
 impl Writeable for ForwardTlvs {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
                encode_tlv_stream!(w, {
@@ -173,7 +194,7 @@ pub(super) fn blinded_hops<T: secp256k1::Signing + secp256k1::Verification>(
 }
 
 /// `None` if underflow occurs.
-fn amt_to_forward_msat(inbound_amt_msat: u64, payment_relay: &PaymentRelay) -> Option<u64> {
+pub(crate) fn amt_to_forward_msat(inbound_amt_msat: u64, payment_relay: &PaymentRelay) -> Option<u64> {
        let inbound_amt = inbound_amt_msat as u128;
        let base = payment_relay.fee_base_msat as u128;
        let prop = payment_relay.fee_proportional_millionths as u128;
@@ -254,16 +275,35 @@ pub(super) fn compute_payinfo(
        })
 }
 
-impl_writeable_msg!(PaymentRelay, {
-       cltv_expiry_delta,
-       fee_proportional_millionths,
-       fee_base_msat
-}, {});
+impl Writeable for PaymentRelay {
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
+               self.cltv_expiry_delta.write(w)?;
+               self.fee_proportional_millionths.write(w)?;
+               HighZeroBytesDroppedBigSize(self.fee_base_msat).write(w)
+       }
+}
+impl Readable for PaymentRelay {
+       fn read<R: io::Read>(r: &mut R) -> Result<Self, DecodeError> {
+               let cltv_expiry_delta: u16 = Readable::read(r)?;
+               let fee_proportional_millionths: u32 = Readable::read(r)?;
+               let fee_base_msat: HighZeroBytesDroppedBigSize<u32> = Readable::read(r)?;
+               Ok(Self { cltv_expiry_delta, fee_proportional_millionths, fee_base_msat: fee_base_msat.0 })
+       }
+}
 
-impl_writeable_msg!(PaymentConstraints, {
-       max_cltv_expiry,
-       htlc_minimum_msat
-}, {});
+impl Writeable for PaymentConstraints {
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
+               self.max_cltv_expiry.write(w)?;
+               HighZeroBytesDroppedBigSize(self.htlc_minimum_msat).write(w)
+       }
+}
+impl Readable for PaymentConstraints {
+       fn read<R: io::Read>(r: &mut R) -> Result<Self, DecodeError> {
+               let max_cltv_expiry: u32 = Readable::read(r)?;
+               let htlc_minimum_msat: HighZeroBytesDroppedBigSize<u64> = Readable::read(r)?;
+               Ok(Self { max_cltv_expiry, htlc_minimum_msat: htlc_minimum_msat.0 })
+       }
+}
 
 #[cfg(test)]
 mod tests {