Merge pull request #3127 from wvanlint/non_strict_forwarding
[rust-lightning] / lightning / src / blinded_path / payment.rs
1 // This file is Copyright its original authors, visible in version control
2 // history.
3 //
4 // This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
5 // or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
7 // You may not use this file except in accordance with one or both of these
8 // licenses.
9
10 //! Data structures and methods for constructing [`BlindedPath`]s to send a payment over.
11 //!
12 //! [`BlindedPath`]: crate::blinded_path::BlindedPath
13
14 use bitcoin::secp256k1::{self, PublicKey, Secp256k1, SecretKey};
15
16 use crate::blinded_path::{BlindedHop, BlindedPath, IntroductionNode, NodeIdLookUp};
17 use crate::blinded_path::utils;
18 use crate::crypto::streams::ChaChaPolyReadAdapter;
19 use crate::io;
20 use crate::io::Cursor;
21 use crate::ln::types::PaymentSecret;
22 use crate::ln::channel_state::CounterpartyForwardingInfo;
23 use crate::ln::features::BlindedHopFeatures;
24 use crate::ln::msgs::DecodeError;
25 use crate::ln::onion_utils;
26 use crate::offers::invoice::BlindedPayInfo;
27 use crate::offers::invoice_request::InvoiceRequestFields;
28 use crate::offers::offer::OfferId;
29 use crate::sign::{NodeSigner, Recipient};
30 use crate::util::ser::{FixedLengthReader, LengthReadableArgs, HighZeroBytesDroppedBigSize, Readable, Writeable, Writer};
31
32 use core::mem;
33 use core::ops::Deref;
34
35 #[allow(unused_imports)]
36 use crate::prelude::*;
37
/// An intermediate node, its outbound channel, and relay parameters.
///
/// A slice of these, ordered from the first forwarding hop onwards, is what this module's
/// blinded-hop construction and [`BlindedPayInfo`] aggregation operate on.
#[derive(Clone, Debug)]
pub struct ForwardNode {
        /// The TLVs for this node's [`BlindedHop`], where the fee parameters contained within are also
        /// used for [`BlindedPayInfo`] construction.
        pub tlvs: ForwardTlvs,
        /// This node's pubkey.
        pub node_id: PublicKey,
        /// The maximum value, in msat, that may be accepted by this node.
        ///
        /// Acts as an upper bound when the path's aggregated `htlc_maximum_msat` is computed.
        pub htlc_maximum_msat: u64,
}
49
/// Data to construct a [`BlindedHop`] for forwarding a payment.
///
/// Serialized as a TLV stream using types 2 (`short_channel_id`), 10 (`payment_relay`),
/// 12 (`payment_constraints`) and 14 (`features`).
#[derive(Clone, Debug)]
pub struct ForwardTlvs {
        /// The short channel id this payment should be forwarded out over.
        pub short_channel_id: u64,
        /// Payment parameters for relaying over [`Self::short_channel_id`].
        pub payment_relay: PaymentRelay,
        /// Payment constraints for relaying over [`Self::short_channel_id`].
        pub payment_constraints: PaymentConstraints,
        /// Supported and required features when relaying a payment onion containing this object's
        /// corresponding [`BlindedHop::encrypted_payload`].
        ///
        /// An empty feature set is omitted from the serialized TLV stream.
        ///
        /// [`BlindedHop::encrypted_payload`]: crate::blinded_path::BlindedHop::encrypted_payload
        pub features: BlindedHopFeatures,
}
65
/// Data to construct a [`BlindedHop`] for receiving a payment. This payload is custom to LDK and
/// may not be valid if received by another lightning implementation.
///
/// Serialized as a TLV stream using types 12 (`payment_constraints`), 65536 (`payment_secret`)
/// and 65537 (`payment_context`).
#[derive(Clone, Debug)]
pub struct ReceiveTlvs {
        /// Used to authenticate the sender of a payment to the receiver and tie MPP HTLCs together.
        pub payment_secret: PaymentSecret,
        /// Constraints for the receiver of this payment.
        pub payment_constraints: PaymentConstraints,
        /// Context for the receiver of this payment.
        ///
        /// When the corresponding TLV is absent while reading, this defaults to
        /// [`PaymentContext::Unknown`].
        pub payment_context: PaymentContext,
}
77
/// Data to construct a [`BlindedHop`] for sending a payment over.
///
/// When read from a TLV stream, the presence of the short channel id TLV (type 2) selects
/// [`Self::Forward`]; otherwise the stream is interpreted as [`Self::Receive`].
///
/// [`BlindedHop`]: crate::blinded_path::BlindedHop
pub(crate) enum BlindedPaymentTlvs {
        /// This blinded payment data is for a forwarding node.
        Forward(ForwardTlvs),
        /// This blinded payment data is for the receiving node.
        Receive(ReceiveTlvs),
}
87
// Used to include forward and receive TLVs in the same iterator for encoding.
//
// Borrowing counterpart of `BlindedPaymentTlvs`: each variant only references the TLVs it wraps,
// so hops can be serialized without cloning them.
enum BlindedPaymentTlvsRef<'a> {
        // Payload for a forwarding hop.
        Forward(&'a ForwardTlvs),
        // Payload for the final (receiving) hop.
        Receive(&'a ReceiveTlvs),
}
93
/// Parameters for relaying over a given [`BlindedHop`].
///
/// Serialized as `cltv_expiry_delta`, then `fee_proportional_millionths`, then `fee_base_msat`,
/// the latter with its high zero bytes dropped.
///
/// [`BlindedHop`]: crate::blinded_path::BlindedHop
#[derive(Clone, Debug)]
pub struct PaymentRelay {
        /// Number of blocks subtracted from an incoming HTLC's `cltv_expiry` for this [`BlindedHop`].
        pub cltv_expiry_delta: u16,
        /// Liquidity fee charged (in millionths of the amount transferred) for relaying a payment over
        /// this [`BlindedHop`], (i.e., 10,000 is 1%).
        pub fee_proportional_millionths: u32,
        /// Base fee charged (in millisatoshi) for relaying a payment over this [`BlindedHop`].
        pub fee_base_msat: u32,
}
107
/// Constraints for relaying over a given [`BlindedHop`].
///
/// Serialized as `max_cltv_expiry` followed by `htlc_minimum_msat`, the latter with its high zero
/// bytes dropped.
///
/// [`BlindedHop`]: crate::blinded_path::BlindedHop
#[derive(Clone, Debug)]
pub struct PaymentConstraints {
        /// The maximum total CLTV that is acceptable when relaying a payment over this [`BlindedHop`].
        pub max_cltv_expiry: u32,
        /// The minimum value, in msat, that may be accepted by the node corresponding to this
        /// [`BlindedHop`].
        pub htlc_minimum_msat: u64,
}
119
/// The context of an inbound payment, which is included in a [`BlindedPath`] via [`ReceiveTlvs`]
/// and surfaced in [`PaymentPurpose`].
///
/// Variants are serialized under the ids 0 (`Unknown`), 1 (`Bolt12Offer`) and 2 (`Bolt12Refund`).
///
/// [`BlindedPath`]: crate::blinded_path::BlindedPath
/// [`PaymentPurpose`]: crate::events::PaymentPurpose
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PaymentContext {
        /// The payment context was unknown.
        ///
        /// This is the value used when [`ReceiveTlvs`] are read without a payment context TLV.
        Unknown(UnknownPaymentContext),

        /// The payment was made for an invoice requested from a BOLT 12 [`Offer`].
        ///
        /// [`Offer`]: crate::offers::offer::Offer
        Bolt12Offer(Bolt12OfferContext),

        /// The payment was made for an invoice sent for a BOLT 12 [`Refund`].
        ///
        /// [`Refund`]: crate::offers::refund::Refund
        Bolt12Refund(Bolt12RefundContext),
}
140
// Used when writing PaymentContext in Event::PaymentClaimable to avoid cloning.
//
// There is deliberately no counterpart for `PaymentContext::Unknown` here; only the BOLT 12
// contexts can be written through this type.
pub(crate) enum PaymentContextRef<'a> {
        // Borrowed form of `PaymentContext::Bolt12Offer`.
        Bolt12Offer(&'a Bolt12OfferContext),
        // Borrowed form of `PaymentContext::Bolt12Refund`.
        Bolt12Refund(&'a Bolt12RefundContext),
}
146
/// An unknown payment context.
///
/// The single field is private, so values of this type can only be created from within this
/// module. It carries no data and serializes to nothing.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct UnknownPaymentContext(());
150
/// The context of a payment made for an invoice requested from a BOLT 12 [`Offer`].
///
/// Serialized as a TLV stream with `offer_id` under type 0 and `invoice_request` under type 2.
///
/// [`Offer`]: crate::offers::offer::Offer
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Bolt12OfferContext {
        /// The identifier of the [`Offer`].
        ///
        /// [`Offer`]: crate::offers::offer::Offer
        pub offer_id: OfferId,

        /// Fields from an [`InvoiceRequest`] sent for a [`Bolt12Invoice`].
        ///
        /// [`InvoiceRequest`]: crate::offers::invoice_request::InvoiceRequest
        /// [`Bolt12Invoice`]: crate::offers::invoice::Bolt12Invoice
        pub invoice_request: InvoiceRequestFields,
}
167
/// The context of a payment made for an invoice sent for a BOLT 12 [`Refund`].
///
/// Currently carries no fields.
///
/// [`Refund`]: crate::offers::refund::Refund
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Bolt12RefundContext {}
173
174 impl PaymentContext {
175         pub(crate) fn unknown() -> Self {
176                 PaymentContext::Unknown(UnknownPaymentContext(()))
177         }
178 }
179
180 impl TryFrom<CounterpartyForwardingInfo> for PaymentRelay {
181         type Error = ();
182
183         fn try_from(info: CounterpartyForwardingInfo) -> Result<Self, ()> {
184                 let CounterpartyForwardingInfo {
185                         fee_base_msat, fee_proportional_millionths, cltv_expiry_delta
186                 } = info;
187
188                 // Avoid exposing esoteric CLTV expiry deltas
189                 let cltv_expiry_delta = match cltv_expiry_delta {
190                         0..=40 => 40,
191                         41..=80 => 80,
192                         81..=144 => 144,
193                         145..=216 => 216,
194                         _ => return Err(()),
195                 };
196
197                 Ok(Self { cltv_expiry_delta, fee_proportional_millionths, fee_base_msat })
198         }
199 }
200
201 impl Writeable for ForwardTlvs {
202         fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
203                 let features_opt =
204                         if self.features == BlindedHopFeatures::empty() { None }
205                         else { Some(&self.features) };
206                 encode_tlv_stream!(w, {
207                         (2, self.short_channel_id, required),
208                         (10, self.payment_relay, required),
209                         (12, self.payment_constraints, required),
210                         (14, features_opt, option)
211                 });
212                 Ok(())
213         }
214 }
215
impl Writeable for ReceiveTlvs {
        /// Writes these TLVs as a TLV stream: the payment constraints (type 12), the payment secret
        /// (type 65536) and the payment context (type 65537).
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
                encode_tlv_stream!(w, {
                        (12, self.payment_constraints, required),
                        (65536, self.payment_secret, required),
                        (65537, self.payment_context, required)
                });
                Ok(())
        }
}
226
227 impl<'a> Writeable for BlindedPaymentTlvsRef<'a> {
228         fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
229                 // TODO: write padding
230                 match self {
231                         Self::Forward(tlvs) => tlvs.write(w)?,
232                         Self::Receive(tlvs) => tlvs.write(w)?,
233                 }
234                 Ok(())
235         }
236 }
237
238 impl Readable for BlindedPaymentTlvs {
239         fn read<R: io::Read>(r: &mut R) -> Result<Self, DecodeError> {
240                 _init_and_read_tlv_stream!(r, {
241                         (1, _padding, option),
242                         (2, scid, option),
243                         (10, payment_relay, option),
244                         (12, payment_constraints, required),
245                         (14, features, option),
246                         (65536, payment_secret, option),
247                         (65537, payment_context, (default_value, PaymentContext::unknown())),
248                 });
249                 let _padding: Option<utils::Padding> = _padding;
250
251                 if let Some(short_channel_id) = scid {
252                         if payment_secret.is_some() {
253                                 return Err(DecodeError::InvalidValue)
254                         }
255                         Ok(BlindedPaymentTlvs::Forward(ForwardTlvs {
256                                 short_channel_id,
257                                 payment_relay: payment_relay.ok_or(DecodeError::InvalidValue)?,
258                                 payment_constraints: payment_constraints.0.unwrap(),
259                                 features: features.unwrap_or_else(BlindedHopFeatures::empty),
260                         }))
261                 } else {
262                         if payment_relay.is_some() || features.is_some() { return Err(DecodeError::InvalidValue) }
263                         Ok(BlindedPaymentTlvs::Receive(ReceiveTlvs {
264                                 payment_secret: payment_secret.ok_or(DecodeError::InvalidValue)?,
265                                 payment_constraints: payment_constraints.0.unwrap(),
266                                 payment_context: payment_context.0.unwrap(),
267                         }))
268                 }
269         }
270 }
271
272 /// Construct blinded payment hops for the given `intermediate_nodes` and payee info.
273 pub(super) fn blinded_hops<T: secp256k1::Signing + secp256k1::Verification>(
274         secp_ctx: &Secp256k1<T>, intermediate_nodes: &[ForwardNode],
275         payee_node_id: PublicKey, payee_tlvs: ReceiveTlvs, session_priv: &SecretKey
276 ) -> Result<Vec<BlindedHop>, secp256k1::Error> {
277         let pks = intermediate_nodes.iter().map(|node| &node.node_id)
278                 .chain(core::iter::once(&payee_node_id));
279         let tlvs = intermediate_nodes.iter().map(|node| BlindedPaymentTlvsRef::Forward(&node.tlvs))
280                 .chain(core::iter::once(BlindedPaymentTlvsRef::Receive(&payee_tlvs)));
281         utils::construct_blinded_hops(secp_ctx, pks, tlvs, session_priv)
282 }
283
284 // Advance the blinded onion payment path by one hop, so make the second hop into the new
285 // introduction node.
286 //
287 // Will only modify `path` when returning `Ok`.
288 pub(crate) fn advance_path_by_one<NS: Deref, NL: Deref, T>(
289         path: &mut BlindedPath, node_signer: &NS, node_id_lookup: &NL, secp_ctx: &Secp256k1<T>
290 ) -> Result<(), ()>
291 where
292         NS::Target: NodeSigner,
293         NL::Target: NodeIdLookUp,
294         T: secp256k1::Signing + secp256k1::Verification,
295 {
296         let control_tlvs_ss = node_signer.ecdh(Recipient::Node, &path.blinding_point, None)?;
297         let rho = onion_utils::gen_rho_from_shared_secret(&control_tlvs_ss.secret_bytes());
298         let encrypted_control_tlvs = &path.blinded_hops.get(0).ok_or(())?.encrypted_payload;
299         let mut s = Cursor::new(encrypted_control_tlvs);
300         let mut reader = FixedLengthReader::new(&mut s, encrypted_control_tlvs.len() as u64);
301         match ChaChaPolyReadAdapter::read(&mut reader, rho) {
302                 Ok(ChaChaPolyReadAdapter {
303                         readable: BlindedPaymentTlvs::Forward(ForwardTlvs { short_channel_id, .. })
304                 }) => {
305                         let next_node_id = match node_id_lookup.next_node_id(short_channel_id) {
306                                 Some(node_id) => node_id,
307                                 None => return Err(()),
308                         };
309                         let mut new_blinding_point = onion_utils::next_hop_pubkey(
310                                 secp_ctx, path.blinding_point, control_tlvs_ss.as_ref()
311                         ).map_err(|_| ())?;
312                         mem::swap(&mut path.blinding_point, &mut new_blinding_point);
313                         path.introduction_node = IntroductionNode::NodeId(next_node_id);
314                         path.blinded_hops.remove(0);
315                         Ok(())
316                 },
317                 _ => Err(())
318         }
319 }
320
321 /// `None` if underflow occurs.
322 pub(crate) fn amt_to_forward_msat(inbound_amt_msat: u64, payment_relay: &PaymentRelay) -> Option<u64> {
323         let inbound_amt = inbound_amt_msat as u128;
324         let base = payment_relay.fee_base_msat as u128;
325         let prop = payment_relay.fee_proportional_millionths as u128;
326
327         let post_base_fee_inbound_amt =
328                 if let Some(amt) = inbound_amt.checked_sub(base) { amt } else { return None };
329         let mut amt_to_forward =
330                 (post_base_fee_inbound_amt * 1_000_000 + 1_000_000 + prop - 1) / (prop + 1_000_000);
331
332         let fee = ((amt_to_forward * prop) / 1_000_000) + base;
333         if inbound_amt - fee < amt_to_forward {
334                 // Rounding up the forwarded amount resulted in underpaying this node, so take an extra 1 msat
335                 // in fee to compensate.
336                 amt_to_forward -= 1;
337         }
338         debug_assert_eq!(amt_to_forward + fee, inbound_amt);
339         u64::try_from(amt_to_forward).ok()
340 }
341
342 pub(super) fn compute_payinfo(
343         intermediate_nodes: &[ForwardNode], payee_tlvs: &ReceiveTlvs, payee_htlc_maximum_msat: u64,
344         min_final_cltv_expiry_delta: u16
345 ) -> Result<BlindedPayInfo, ()> {
346         let mut curr_base_fee: u64 = 0;
347         let mut curr_prop_mil: u64 = 0;
348         let mut cltv_expiry_delta: u16 = min_final_cltv_expiry_delta;
349         for tlvs in intermediate_nodes.iter().rev().map(|n| &n.tlvs) {
350                 // In the future, we'll want to take the intersection of all supported features for the
351                 // `BlindedPayInfo`, but there are no features in that context right now.
352                 if tlvs.features.requires_unknown_bits_from(&BlindedHopFeatures::empty()) { return Err(()) }
353
354                 let next_base_fee = tlvs.payment_relay.fee_base_msat as u64;
355                 let next_prop_mil = tlvs.payment_relay.fee_proportional_millionths as u64;
356                 // Use integer arithmetic to compute `ceil(a/b)` as `(a+b-1)/b`
357                 // ((curr_base_fee * (1_000_000 + next_prop_mil)) / 1_000_000) + next_base_fee
358                 curr_base_fee = curr_base_fee.checked_mul(1_000_000 + next_prop_mil)
359                         .and_then(|f| f.checked_add(1_000_000 - 1))
360                         .map(|f| f / 1_000_000)
361                         .and_then(|f| f.checked_add(next_base_fee))
362                         .ok_or(())?;
363                 // ceil(((curr_prop_mil + 1_000_000) * (next_prop_mil + 1_000_000)) / 1_000_000) - 1_000_000
364                 curr_prop_mil = curr_prop_mil.checked_add(1_000_000)
365                         .and_then(|f1| next_prop_mil.checked_add(1_000_000).and_then(|f2| f2.checked_mul(f1)))
366                         .and_then(|f| f.checked_add(1_000_000 - 1))
367                         .map(|f| f / 1_000_000)
368                         .and_then(|f| f.checked_sub(1_000_000))
369                         .ok_or(())?;
370
371                 cltv_expiry_delta = cltv_expiry_delta.checked_add(tlvs.payment_relay.cltv_expiry_delta).ok_or(())?;
372         }
373
374         let mut htlc_minimum_msat: u64 = 1;
375         let mut htlc_maximum_msat: u64 = 21_000_000 * 100_000_000 * 1_000; // Total bitcoin supply
376         for node in intermediate_nodes.iter() {
377                 // The min htlc for an intermediate node is that node's min minus the fees charged by all of the
378                 // following hops for forwarding that min, since that fee amount will automatically be included
379                 // in the amount that this node receives and contribute towards reaching its min.
380                 htlc_minimum_msat = amt_to_forward_msat(
381                         core::cmp::max(node.tlvs.payment_constraints.htlc_minimum_msat, htlc_minimum_msat),
382                         &node.tlvs.payment_relay
383                 ).unwrap_or(1); // If underflow occurs, we definitely reached this node's min
384                 htlc_maximum_msat = amt_to_forward_msat(
385                         core::cmp::min(node.htlc_maximum_msat, htlc_maximum_msat), &node.tlvs.payment_relay
386                 ).ok_or(())?; // If underflow occurs, we cannot send to this hop without exceeding their max
387         }
388         htlc_minimum_msat = core::cmp::max(
389                 payee_tlvs.payment_constraints.htlc_minimum_msat, htlc_minimum_msat
390         );
391         htlc_maximum_msat = core::cmp::min(payee_htlc_maximum_msat, htlc_maximum_msat);
392
393         if htlc_maximum_msat < htlc_minimum_msat { return Err(()) }
394         Ok(BlindedPayInfo {
395                 fee_base_msat: u32::try_from(curr_base_fee).map_err(|_| ())?,
396                 fee_proportional_millionths: u32::try_from(curr_prop_mil).map_err(|_| ())?,
397                 cltv_expiry_delta,
398                 htlc_minimum_msat,
399                 htlc_maximum_msat,
400                 features: BlindedHopFeatures::empty(),
401         })
402 }
403
404 impl Writeable for PaymentRelay {
405         fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
406                 self.cltv_expiry_delta.write(w)?;
407                 self.fee_proportional_millionths.write(w)?;
408                 HighZeroBytesDroppedBigSize(self.fee_base_msat).write(w)
409         }
410 }
411 impl Readable for PaymentRelay {
412         fn read<R: io::Read>(r: &mut R) -> Result<Self, DecodeError> {
413                 let cltv_expiry_delta: u16 = Readable::read(r)?;
414                 let fee_proportional_millionths: u32 = Readable::read(r)?;
415                 let fee_base_msat: HighZeroBytesDroppedBigSize<u32> = Readable::read(r)?;
416                 Ok(Self { cltv_expiry_delta, fee_proportional_millionths, fee_base_msat: fee_base_msat.0 })
417         }
418 }
419
420 impl Writeable for PaymentConstraints {
421         fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
422                 self.max_cltv_expiry.write(w)?;
423                 HighZeroBytesDroppedBigSize(self.htlc_minimum_msat).write(w)
424         }
425 }
426 impl Readable for PaymentConstraints {
427         fn read<R: io::Read>(r: &mut R) -> Result<Self, DecodeError> {
428                 let max_cltv_expiry: u32 = Readable::read(r)?;
429                 let htlc_minimum_msat: HighZeroBytesDroppedBigSize<u64> = Readable::read(r)?;
430                 Ok(Self { max_cltv_expiry, htlc_minimum_msat: htlc_minimum_msat.0 })
431         }
432 }
433
// Serialization for `PaymentContext`, assigning each variant a numeric id. The ids for
// `Bolt12Offer` (1) and `Bolt12Refund` (2) match the leading byte written by the `Writeable`
// implementation of `PaymentContextRef` and must be kept in sync with it.
impl_writeable_tlv_based_enum!(PaymentContext,
        ;
        (0, Unknown),
        (1, Bolt12Offer),
        (2, Bolt12Refund),
);
440
441 impl<'a> Writeable for PaymentContextRef<'a> {
442         fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
443                 match self {
444                         PaymentContextRef::Bolt12Offer(context) => {
445                                 1u8.write(w)?;
446                                 context.write(w)?;
447                         },
448                         PaymentContextRef::Bolt12Refund(context) => {
449                                 2u8.write(w)?;
450                                 context.write(w)?;
451                         },
452                 }
453
454                 Ok(())
455         }
456 }
457
impl Writeable for UnknownPaymentContext {
        /// Writes nothing: an unknown context carries no data.
        fn write<W: Writer>(&self, _w: &mut W) -> Result<(), io::Error> {
                Ok(())
        }
}
463
464 impl Readable for UnknownPaymentContext {
465         fn read<R: io::Read>(_r: &mut R) -> Result<Self, DecodeError> {
466                 Ok(UnknownPaymentContext(()))
467         }
468 }
469
// Serialization for `Bolt12OfferContext`: both fields are required, with `offer_id` under TLV
// type 0 and `invoice_request` under TLV type 2.
impl_writeable_tlv_based!(Bolt12OfferContext, {
        (0, offer_id, required),
        (2, invoice_request, required),
});
474
// Serialization for `Bolt12RefundContext`, which currently has no fields to encode.
impl_writeable_tlv_based!(Bolt12RefundContext, {});
476
477 #[cfg(test)]
478 mod tests {
479         use bitcoin::secp256k1::PublicKey;
480         use crate::blinded_path::payment::{ForwardNode, ForwardTlvs, ReceiveTlvs, PaymentConstraints, PaymentContext, PaymentRelay};
481         use crate::ln::types::PaymentSecret;
482         use crate::ln::features::BlindedHopFeatures;
483         use crate::ln::functional_test_utils::TEST_FINAL_CLTV;
484
485         #[test]
486         fn compute_payinfo() {
487                 // Taken from the spec example for aggregating blinded payment info. See
488                 // https://github.com/lightning/bolts/blob/master/proposals/route-blinding.md#blinded-payments
489                 let dummy_pk = PublicKey::from_slice(&[2; 33]).unwrap();
490                 let intermediate_nodes = vec![ForwardNode {
491                         node_id: dummy_pk,
492                         tlvs: ForwardTlvs {
493                                 short_channel_id: 0,
494                                 payment_relay: PaymentRelay {
495                                         cltv_expiry_delta: 144,
496                                         fee_proportional_millionths: 500,
497                                         fee_base_msat: 100,
498                                 },
499                                 payment_constraints: PaymentConstraints {
500                                         max_cltv_expiry: 0,
501                                         htlc_minimum_msat: 100,
502                                 },
503                                 features: BlindedHopFeatures::empty(),
504                         },
505                         htlc_maximum_msat: u64::max_value(),
506                 }, ForwardNode {
507                         node_id: dummy_pk,
508                         tlvs: ForwardTlvs {
509                                 short_channel_id: 0,
510                                 payment_relay: PaymentRelay {
511                                         cltv_expiry_delta: 144,
512                                         fee_proportional_millionths: 500,
513                                         fee_base_msat: 100,
514                                 },
515                                 payment_constraints: PaymentConstraints {
516                                         max_cltv_expiry: 0,
517                                         htlc_minimum_msat: 1_000,
518                                 },
519                                 features: BlindedHopFeatures::empty(),
520                         },
521                         htlc_maximum_msat: u64::max_value(),
522                 }];
523                 let recv_tlvs = ReceiveTlvs {
524                         payment_secret: PaymentSecret([0; 32]),
525                         payment_constraints: PaymentConstraints {
526                                 max_cltv_expiry: 0,
527                                 htlc_minimum_msat: 1,
528                         },
529                         payment_context: PaymentContext::unknown(),
530                 };
531                 let htlc_maximum_msat = 100_000;
532                 let blinded_payinfo = super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_maximum_msat, 12).unwrap();
533                 assert_eq!(blinded_payinfo.fee_base_msat, 201);
534                 assert_eq!(blinded_payinfo.fee_proportional_millionths, 1001);
535                 assert_eq!(blinded_payinfo.cltv_expiry_delta, 300);
536                 assert_eq!(blinded_payinfo.htlc_minimum_msat, 900);
537                 assert_eq!(blinded_payinfo.htlc_maximum_msat, htlc_maximum_msat);
538         }
539
540         #[test]
541         fn compute_payinfo_1_hop() {
542                 let recv_tlvs = ReceiveTlvs {
543                         payment_secret: PaymentSecret([0; 32]),
544                         payment_constraints: PaymentConstraints {
545                                 max_cltv_expiry: 0,
546                                 htlc_minimum_msat: 1,
547                         },
548                         payment_context: PaymentContext::unknown(),
549                 };
550                 let blinded_payinfo = super::compute_payinfo(&[], &recv_tlvs, 4242, TEST_FINAL_CLTV as u16).unwrap();
551                 assert_eq!(blinded_payinfo.fee_base_msat, 0);
552                 assert_eq!(blinded_payinfo.fee_proportional_millionths, 0);
553                 assert_eq!(blinded_payinfo.cltv_expiry_delta, TEST_FINAL_CLTV as u16);
554                 assert_eq!(blinded_payinfo.htlc_minimum_msat, 1);
555                 assert_eq!(blinded_payinfo.htlc_maximum_msat, 4242);
556         }
557
558         #[test]
559         fn simple_aggregated_htlc_min() {
560                 // If no hops charge fees, the htlc_minimum_msat should just be the maximum htlc_minimum_msat
561                 // along the path.
562                 let dummy_pk = PublicKey::from_slice(&[2; 33]).unwrap();
563                 let intermediate_nodes = vec![ForwardNode {
564                         node_id: dummy_pk,
565                         tlvs: ForwardTlvs {
566                                 short_channel_id: 0,
567                                 payment_relay: PaymentRelay {
568                                         cltv_expiry_delta: 0,
569                                         fee_proportional_millionths: 0,
570                                         fee_base_msat: 0,
571                                 },
572                                 payment_constraints: PaymentConstraints {
573                                         max_cltv_expiry: 0,
574                                         htlc_minimum_msat: 1,
575                                 },
576                                 features: BlindedHopFeatures::empty(),
577                         },
578                         htlc_maximum_msat: u64::max_value()
579                 }, ForwardNode {
580                         node_id: dummy_pk,
581                         tlvs: ForwardTlvs {
582                                 short_channel_id: 0,
583                                 payment_relay: PaymentRelay {
584                                         cltv_expiry_delta: 0,
585                                         fee_proportional_millionths: 0,
586                                         fee_base_msat: 0,
587                                 },
588                                 payment_constraints: PaymentConstraints {
589                                         max_cltv_expiry: 0,
590                                         htlc_minimum_msat: 2_000,
591                                 },
592                                 features: BlindedHopFeatures::empty(),
593                         },
594                         htlc_maximum_msat: u64::max_value()
595                 }];
596                 let recv_tlvs = ReceiveTlvs {
597                         payment_secret: PaymentSecret([0; 32]),
598                         payment_constraints: PaymentConstraints {
599                                 max_cltv_expiry: 0,
600                                 htlc_minimum_msat: 3,
601                         },
602                         payment_context: PaymentContext::unknown(),
603                 };
604                 let htlc_maximum_msat = 100_000;
605                 let blinded_payinfo = super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_maximum_msat, TEST_FINAL_CLTV as u16).unwrap();
606                 assert_eq!(blinded_payinfo.htlc_minimum_msat, 2_000);
607         }
608
609         #[test]
610         fn aggregated_htlc_min() {
611                 // Create a path with varying fees and htlc_mins, and make sure htlc_minimum_msat ends up as the
612                 // max (htlc_min - following_fees) along the path.
613                 let dummy_pk = PublicKey::from_slice(&[2; 33]).unwrap();
614                 let intermediate_nodes = vec![ForwardNode {
615                         node_id: dummy_pk,
616                         tlvs: ForwardTlvs {
617                                 short_channel_id: 0,
618                                 payment_relay: PaymentRelay {
619                                         cltv_expiry_delta: 0,
620                                         fee_proportional_millionths: 500,
621                                         fee_base_msat: 1_000,
622                                 },
623                                 payment_constraints: PaymentConstraints {
624                                         max_cltv_expiry: 0,
625                                         htlc_minimum_msat: 5_000,
626                                 },
627                                 features: BlindedHopFeatures::empty(),
628                         },
629                         htlc_maximum_msat: u64::max_value()
630                 }, ForwardNode {
631                         node_id: dummy_pk,
632                         tlvs: ForwardTlvs {
633                                 short_channel_id: 0,
634                                 payment_relay: PaymentRelay {
635                                         cltv_expiry_delta: 0,
636                                         fee_proportional_millionths: 500,
637                                         fee_base_msat: 200,
638                                 },
639                                 payment_constraints: PaymentConstraints {
640                                         max_cltv_expiry: 0,
641                                         htlc_minimum_msat: 2_000,
642                                 },
643                                 features: BlindedHopFeatures::empty(),
644                         },
645                         htlc_maximum_msat: u64::max_value()
646                 }];
647                 let recv_tlvs = ReceiveTlvs {
648                         payment_secret: PaymentSecret([0; 32]),
649                         payment_constraints: PaymentConstraints {
650                                 max_cltv_expiry: 0,
651                                 htlc_minimum_msat: 1,
652                         },
653                         payment_context: PaymentContext::unknown(),
654                 };
655                 let htlc_minimum_msat = 3798;
656                 assert!(super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_minimum_msat - 1, TEST_FINAL_CLTV as u16).is_err());
657
658                 let htlc_maximum_msat = htlc_minimum_msat + 1;
659                 let blinded_payinfo = super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_maximum_msat, TEST_FINAL_CLTV as u16).unwrap();
660                 assert_eq!(blinded_payinfo.htlc_minimum_msat, htlc_minimum_msat);
661                 assert_eq!(blinded_payinfo.htlc_maximum_msat, htlc_maximum_msat);
662         }
663
664         #[test]
665         fn aggregated_htlc_max() {
666                 // Create a path with varying fees and `htlc_maximum_msat`s, and make sure the aggregated max
667                 // htlc ends up as the min (htlc_max - following_fees) along the path.
668                 let dummy_pk = PublicKey::from_slice(&[2; 33]).unwrap();
669                 let intermediate_nodes = vec![ForwardNode {
670                         node_id: dummy_pk,
671                         tlvs: ForwardTlvs {
672                                 short_channel_id: 0,
673                                 payment_relay: PaymentRelay {
674                                         cltv_expiry_delta: 0,
675                                         fee_proportional_millionths: 500,
676                                         fee_base_msat: 1_000,
677                                 },
678                                 payment_constraints: PaymentConstraints {
679                                         max_cltv_expiry: 0,
680                                         htlc_minimum_msat: 1,
681                                 },
682                                 features: BlindedHopFeatures::empty(),
683                         },
684                         htlc_maximum_msat: 5_000,
685                 }, ForwardNode {
686                         node_id: dummy_pk,
687                         tlvs: ForwardTlvs {
688                                 short_channel_id: 0,
689                                 payment_relay: PaymentRelay {
690                                         cltv_expiry_delta: 0,
691                                         fee_proportional_millionths: 500,
692                                         fee_base_msat: 1,
693                                 },
694                                 payment_constraints: PaymentConstraints {
695                                         max_cltv_expiry: 0,
696                                         htlc_minimum_msat: 1,
697                                 },
698                                 features: BlindedHopFeatures::empty(),
699                         },
700                         htlc_maximum_msat: 10_000
701                 }];
702                 let recv_tlvs = ReceiveTlvs {
703                         payment_secret: PaymentSecret([0; 32]),
704                         payment_constraints: PaymentConstraints {
705                                 max_cltv_expiry: 0,
706                                 htlc_minimum_msat: 1,
707                         },
708                         payment_context: PaymentContext::unknown(),
709                 };
710
711                 let blinded_payinfo = super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, 10_000, TEST_FINAL_CLTV as u16).unwrap();
712                 assert_eq!(blinded_payinfo.htlc_maximum_msat, 3997);
713         }
714 }