Merge pull request #2964 from jbesraa/prune-stale-chanmonitor
authorMatt Corallo <649246+TheBlueMatt@users.noreply.github.com>
Thu, 18 Apr 2024 21:39:43 +0000 (14:39 -0700)
committerGitHub <noreply@github.com>
Thu, 18 Apr 2024 21:39:43 +0000 (14:39 -0700)
Add `archive_fully_resolved_monitors` to `ChainMonitor`

46 files changed:
ci/check-cfg-flags.py
ci/ci-tests.sh
fuzz/src/chanmon_consistency.rs
fuzz/src/onion_message.rs
fuzz/src/router.rs
lightning-background-processor/src/lib.rs
lightning-net-tokio/src/lib.rs
lightning/src/blinded_path/message.rs
lightning/src/blinded_path/mod.rs
lightning/src/blinded_path/payment.rs
lightning/src/chain/chaininterface.rs
lightning/src/chain/mod.rs
lightning/src/events/mod.rs
lightning/src/ln/blinded_payment_tests.rs
lightning/src/ln/chanmon_update_fail_tests.rs
lightning/src/ln/channel.rs
lightning/src/ln/channelmanager.rs
lightning/src/ln/functional_test_utils.rs
lightning/src/ln/functional_tests.rs
lightning/src/ln/monitor_tests.rs
lightning/src/ln/msgs.rs
lightning/src/ln/offers_tests.rs
lightning/src/ln/onion_payment.rs
lightning/src/ln/payment_tests.rs
lightning/src/ln/peer_handler.rs
lightning/src/ln/reorg_tests.rs
lightning/src/ln/wire.rs
lightning/src/offers/invoice.rs
lightning/src/offers/invoice_error.rs
lightning/src/offers/invoice_request.rs
lightning/src/offers/merkle.rs
lightning/src/offers/offer.rs
lightning/src/offers/refund.rs
lightning/src/offers/test_utils.rs
lightning/src/onion_message/functional_tests.rs
lightning/src/onion_message/messenger.rs
lightning/src/onion_message/packet.rs
lightning/src/routing/router.rs
lightning/src/routing/scoring.rs
lightning/src/sign/mod.rs
lightning/src/util/indexed_map.rs
lightning/src/util/mod.rs
lightning/src/util/persist.rs
lightning/src/util/ser_macros.rs
lightning/src/util/sweep.rs [new file with mode: 0644]
lightning/src/util/test_utils.rs

index 0cfa2023ee2b9844a07601548584d7c687f43e04..fd514da657bfd8a12801eea16a9cf8fedd13aac6 100755 (executable)
@@ -98,6 +98,8 @@ def check_cfg_tag(cfg):
         pass
     elif cfg == "dual_funding":
         pass
+    elif cfg == "splicing":
+        pass
     else:
         print("Bad cfg tag: " + cfg)
         assert False
index 5cae6d45de5f56a778fc95b2e78b5c5fd9f5ffb6..0dc654d8bedca04688e5cf5e8afcd248b4cc3587 100755 (executable)
@@ -177,3 +177,5 @@ RUSTFLAGS="--cfg=taproot" cargo test --verbose --color always -p lightning
 RUSTFLAGS="--cfg=async_signing" cargo test --verbose --color always -p lightning
 [ "$CI_MINIMIZE_DISK_USAGE" != "" ] && cargo clean
 RUSTFLAGS="--cfg=dual_funding" cargo test --verbose --color always -p lightning
+[ "$CI_MINIMIZE_DISK_USAGE" != "" ] && cargo clean
+RUSTFLAGS="--cfg=splicing" cargo test --verbose --color always -p lightning
index 36e7cea8a2215e1414af4ef0f46a082401732c03..b3cf867d6e330d5e684c78f2406c815d0a284a04 100644 (file)
@@ -84,7 +84,7 @@ impl FeeEstimator for FuzzEstimator {
                // Background feerate which is <= the minimum Normal feerate.
                match conf_target {
                        ConfirmationTarget::OnChainSweep => MAX_FEE,
-                       ConfirmationTarget::ChannelCloseMinimum|ConfirmationTarget::AnchorChannelFee|ConfirmationTarget::MinAllowedAnchorChannelRemoteFee|ConfirmationTarget::MinAllowedNonAnchorChannelRemoteFee => 253,
+                       ConfirmationTarget::ChannelCloseMinimum|ConfirmationTarget::AnchorChannelFee|ConfirmationTarget::MinAllowedAnchorChannelRemoteFee|ConfirmationTarget::MinAllowedNonAnchorChannelRemoteFee|ConfirmationTarget::OutputSpendingFee => 253,
                        ConfirmationTarget::NonAnchorChannelFee => cmp::min(self.ret_val.load(atomic::Ordering::Acquire), MAX_FEE),
                }
        }
index f2bae246fabeb09f6a2ab827c078b6046c2cd8e5..91fcb9bf2d406484ba1e364cdfc994971781f00e 100644 (file)
@@ -6,7 +6,7 @@ use bitcoin::secp256k1::ecdh::SharedSecret;
 use bitcoin::secp256k1::ecdsa::RecoverableSignature;
 use bitcoin::secp256k1::schnorr;
 
-use lightning::blinded_path::BlindedPath;
+use lightning::blinded_path::{BlindedPath, EmptyNodeIdLookUp};
 use lightning::ln::features::InitFeatures;
 use lightning::ln::msgs::{self, DecodeError, OnionMessageHandler};
 use lightning::ln::script::ShutdownScript;
@@ -36,12 +36,13 @@ pub fn do_test<L: Logger>(data: &[u8], logger: &L) {
                        node_secret: secret,
                        counter: AtomicU64::new(0),
                };
+               let node_id_lookup = EmptyNodeIdLookUp {};
                let message_router = TestMessageRouter {};
                let offers_msg_handler = TestOffersMessageHandler {};
                let custom_msg_handler = TestCustomMessageHandler {};
                let onion_messenger = OnionMessenger::new(
-                       &keys_manager, &keys_manager, logger, &message_router, &offers_msg_handler,
-                       &custom_msg_handler
+                       &keys_manager, &keys_manager, logger, &node_id_lookup, &message_router,
+                       &offers_msg_handler, &custom_msg_handler
                );
 
                let peer_node_id = {
index ad4373c4793bc704f66d6cc095b83a5e82013cf7..afe028131c1de0a3a79eb30e975a9cc6e72097a9 100644 (file)
@@ -11,7 +11,7 @@ use bitcoin::blockdata::constants::ChainHash;
 use bitcoin::blockdata::script::Builder;
 use bitcoin::blockdata::transaction::TxOut;
 
-use lightning::blinded_path::{BlindedHop, BlindedPath};
+use lightning::blinded_path::{BlindedHop, BlindedPath, IntroductionNode};
 use lightning::chain::transaction::OutPoint;
 use lightning::ln::ChannelId;
 use lightning::ln::channelmanager::{self, ChannelDetails, ChannelCounterparty};
@@ -363,7 +363,7 @@ pub fn do_test<Out: test_logger::Output>(data: &[u8], out: Out) {
                                                });
                                        }
                                        (payinfo, BlindedPath {
-                                               introduction_node_id: hop.src_node_id,
+                                               introduction_node: IntroductionNode::NodeId(hop.src_node_id),
                                                blinding_point: dummy_pk,
                                                blinded_hops,
                                        })
index 3736bd603e5bd65977bfc874919b5f9c61ee48bc..3fe4da00d4bcc85a07a14e56969d3ba76c03cf32 100644 (file)
@@ -919,14 +919,16 @@ impl Drop for BackgroundProcessor {
 
 #[cfg(all(feature = "std", test))]
 mod tests {
+       use bitcoin::{ScriptBuf, Txid};
        use bitcoin::blockdata::constants::{genesis_block, ChainHash};
        use bitcoin::blockdata::locktime::absolute::LockTime;
        use bitcoin::blockdata::transaction::{Transaction, TxOut};
+       use bitcoin::hashes::Hash;
        use bitcoin::network::constants::Network;
        use bitcoin::secp256k1::{SecretKey, PublicKey, Secp256k1};
-       use lightning::chain::{BestBlock, Confirm, chainmonitor};
+       use lightning::chain::{BestBlock, Confirm, chainmonitor, Filter};
        use lightning::chain::channelmonitor::ANTI_REORG_DELAY;
-       use lightning::sign::{InMemorySigner, KeysManager};
+       use lightning::sign::{InMemorySigner, KeysManager, ChangeDestinationSource};
        use lightning::chain::transaction::OutPoint;
        use lightning::events::{Event, PathFailure, MessageSendEventsProvider, MessageSendEvent};
        use lightning::{get_event_msg, get_event};
@@ -947,6 +949,7 @@ mod tests {
                CHANNEL_MANAGER_PERSISTENCE_PRIMARY_NAMESPACE, CHANNEL_MANAGER_PERSISTENCE_SECONDARY_NAMESPACE, CHANNEL_MANAGER_PERSISTENCE_KEY,
                NETWORK_GRAPH_PERSISTENCE_PRIMARY_NAMESPACE, NETWORK_GRAPH_PERSISTENCE_SECONDARY_NAMESPACE, NETWORK_GRAPH_PERSISTENCE_KEY,
                SCORER_PERSISTENCE_PRIMARY_NAMESPACE, SCORER_PERSISTENCE_SECONDARY_NAMESPACE, SCORER_PERSISTENCE_KEY};
+       use lightning::util::sweep::{OutputSweeper, OutputSpendStatus};
        use lightning_persister::fs_store::FilesystemStore;
        use std::collections::VecDeque;
        use std::{fs, env};
@@ -1009,6 +1012,9 @@ mod tests {
                logger: Arc<test_utils::TestLogger>,
                best_block: BestBlock,
                scorer: Arc<LockingWrapper<TestScorer>>,
+               sweeper: Arc<OutputSweeper<Arc<test_utils::TestBroadcaster>, Arc<TestWallet>,
+                       Arc<test_utils::TestFeeEstimator>, Arc<dyn Filter + Sync + Send>, Arc<FilesystemStore>,
+                       Arc<test_utils::TestLogger>, Arc<KeysManager>>>,
        }
 
        impl Node {
@@ -1247,6 +1253,14 @@ mod tests {
                }
        }
 
+       struct TestWallet {}
+
+       impl ChangeDestinationSource for TestWallet {
+               fn get_change_destination_script(&self) -> Result<ScriptBuf, ()> {
+                       Ok(ScriptBuf::new())
+               }
+       }
+
        fn get_full_filepath(filepath: String, filename: String) -> String {
                let mut path = PathBuf::from(filepath);
                path.push(filename);
@@ -1271,10 +1285,15 @@ mod tests {
                        let router = Arc::new(DefaultRouter::new(network_graph.clone(), logger.clone(), Arc::clone(&keys_manager), scorer.clone(), Default::default()));
                        let chain_source = Arc::new(test_utils::TestChainSource::new(Network::Bitcoin));
                        let kv_store = Arc::new(FilesystemStore::new(format!("{}_persister_{}", &persist_dir, i).into()));
+                       let now = Duration::from_secs(genesis_block.header.time as u64);
+                       let keys_manager = Arc::new(KeysManager::new(&seed, now.as_secs(), now.subsec_nanos()));
                        let chain_monitor = Arc::new(chainmonitor::ChainMonitor::new(Some(chain_source.clone()), tx_broadcaster.clone(), logger.clone(), fee_estimator.clone(), kv_store.clone()));
                        let best_block = BestBlock::from_network(network);
                        let params = ChainParameters { network, best_block };
                        let manager = Arc::new(ChannelManager::new(fee_estimator.clone(), chain_monitor.clone(), tx_broadcaster.clone(), router.clone(), logger.clone(), keys_manager.clone(), keys_manager.clone(), keys_manager.clone(), UserConfig::default(), params, genesis_block.header.time));
+                       let wallet = Arc::new(TestWallet {});
+                       let sweeper = Arc::new(OutputSweeper::new(best_block, Arc::clone(&tx_broadcaster), Arc::clone(&fee_estimator),
+                               None::<Arc<dyn Filter + Sync + Send>>, Arc::clone(&keys_manager), wallet, Arc::clone(&kv_store), Arc::clone(&logger)));
                        let p2p_gossip_sync = Arc::new(P2PGossipSync::new(network_graph.clone(), Some(chain_source.clone()), logger.clone()));
                        let rapid_gossip_sync = Arc::new(RapidGossipSync::new(network_graph.clone(), logger.clone()));
                        let msg_handler = MessageHandler {
@@ -1283,7 +1302,7 @@ mod tests {
                                onion_message_handler: IgnoringMessageHandler{}, custom_message_handler: IgnoringMessageHandler{}
                        };
                        let peer_manager = Arc::new(PeerManager::new(msg_handler, 0, &seed, logger.clone(), keys_manager.clone()));
-                       let node = Node { node: manager, p2p_gossip_sync, rapid_gossip_sync, peer_manager, chain_monitor, kv_store, tx_broadcaster, network_graph, logger, best_block, scorer };
+                       let node = Node { node: manager, p2p_gossip_sync, rapid_gossip_sync, peer_manager, chain_monitor, kv_store, tx_broadcaster, network_graph, logger, best_block, scorer, sweeper };
                        nodes.push(node);
                }
 
@@ -1352,15 +1371,40 @@ mod tests {
                                1 => {
                                        node.node.transactions_confirmed(&header, &txdata, height);
                                        node.chain_monitor.transactions_confirmed(&header, &txdata, height);
+                                       node.sweeper.transactions_confirmed(&header, &txdata, height);
                                },
                                x if x == depth => {
+                                       // We need the TestBroadcaster to know about the new height so that it doesn't think
+                                       // we're violating the time lock requirements of transactions broadcasted at that
+                                       // point.
+                                       node.tx_broadcaster.blocks.lock().unwrap().push((genesis_block(Network::Bitcoin), height));
                                        node.node.best_block_updated(&header, height);
                                        node.chain_monitor.best_block_updated(&header, height);
+                                       node.sweeper.best_block_updated(&header, height);
                                },
                                _ => {},
                        }
                }
        }
+
+       fn advance_chain(node: &mut Node, num_blocks: u32) {
+               for i in 1..=num_blocks {
+                       let prev_blockhash = node.best_block.block_hash;
+                       let height = node.best_block.height + 1;
+                       let header = create_dummy_header(prev_blockhash, height);
+                       node.best_block = BestBlock::new(header.block_hash(), height);
+                       if i == num_blocks {
+                               // We need the TestBroadcaster to know about the new height so that it doesn't think
+                               // we're violating the time lock requirements of transactions broadcasted at that
+                               // point.
+                               node.tx_broadcaster.blocks.lock().unwrap().push((genesis_block(Network::Bitcoin), height));
+                               node.node.best_block_updated(&header, height);
+                               node.chain_monitor.best_block_updated(&header, height);
+                               node.sweeper.best_block_updated(&header, height);
+                       }
+               }
+       }
+
        fn confirm_transaction(node: &mut Node, tx: &Transaction) {
                confirm_transaction_depth(node, tx, ANTI_REORG_DELAY);
        }
@@ -1592,6 +1636,9 @@ mod tests {
                let _as_channel_update = get_event_msg!(nodes[0], MessageSendEvent::SendChannelUpdate, nodes[1].node.get_our_node_id());
                nodes[1].node.handle_channel_ready(&nodes[0].node.get_our_node_id(), &as_funding);
                let _bs_channel_update = get_event_msg!(nodes[1], MessageSendEvent::SendChannelUpdate, nodes[0].node.get_our_node_id());
+               let broadcast_funding = nodes[0].tx_broadcaster.txn_broadcasted.lock().unwrap().pop().unwrap();
+               assert_eq!(broadcast_funding.txid(), funding_tx.txid());
+               assert!(nodes[0].tx_broadcaster.txn_broadcasted.lock().unwrap().is_empty());
 
                if !std::thread::panicking() {
                        bg_processor.stop().unwrap();
@@ -1617,10 +1664,95 @@ mod tests {
                        .recv_timeout(Duration::from_secs(EVENT_DEADLINE))
                        .expect("Events not handled within deadline");
                match event {
-                       Event::SpendableOutputs { .. } => {},
+                       Event::SpendableOutputs { outputs, channel_id } => {
+                               nodes[0].sweeper.track_spendable_outputs(outputs, channel_id, false, Some(153));
+                       },
                        _ => panic!("Unexpected event: {:?}", event),
                }
 
+               // Check we don't generate an initial sweeping tx until we reach the required height.
+               assert_eq!(nodes[0].sweeper.tracked_spendable_outputs().len(), 1);
+               let tracked_output = nodes[0].sweeper.tracked_spendable_outputs().first().unwrap().clone();
+               if let Some(sweep_tx_0) = nodes[0].tx_broadcaster.txn_broadcasted.lock().unwrap().pop() {
+                       assert!(!tracked_output.is_spent_in(&sweep_tx_0));
+                       match tracked_output.status {
+                               OutputSpendStatus::PendingInitialBroadcast { delayed_until_height } => {
+                                       assert_eq!(delayed_until_height, Some(153));
+                               }
+                               _ => panic!("Unexpected status"),
+                       }
+               }
+
+               advance_chain(&mut nodes[0], 3);
+
+               // Check we generate an initial sweeping tx.
+               assert_eq!(nodes[0].sweeper.tracked_spendable_outputs().len(), 1);
+               let tracked_output = nodes[0].sweeper.tracked_spendable_outputs().first().unwrap().clone();
+               let sweep_tx_0 = nodes[0].tx_broadcaster.txn_broadcasted.lock().unwrap().pop().unwrap();
+               match tracked_output.status {
+                       OutputSpendStatus::PendingFirstConfirmation { latest_spending_tx, .. } => {
+                               assert_eq!(sweep_tx_0.txid(), latest_spending_tx.txid());
+                       }
+                       _ => panic!("Unexpected status"),
+               }
+
+               // Check we regenerate and rebroadcast the sweeping tx each block.
+               advance_chain(&mut nodes[0], 1);
+               assert_eq!(nodes[0].sweeper.tracked_spendable_outputs().len(), 1);
+               let tracked_output = nodes[0].sweeper.tracked_spendable_outputs().first().unwrap().clone();
+               let sweep_tx_1 = nodes[0].tx_broadcaster.txn_broadcasted.lock().unwrap().pop().unwrap();
+               match tracked_output.status {
+                       OutputSpendStatus::PendingFirstConfirmation { latest_spending_tx, .. } => {
+                               assert_eq!(sweep_tx_1.txid(), latest_spending_tx.txid());
+                       }
+                       _ => panic!("Unexpected status"),
+               }
+               assert_ne!(sweep_tx_0, sweep_tx_1);
+
+               advance_chain(&mut nodes[0], 1);
+               assert_eq!(nodes[0].sweeper.tracked_spendable_outputs().len(), 1);
+               let tracked_output = nodes[0].sweeper.tracked_spendable_outputs().first().unwrap().clone();
+               let sweep_tx_2 = nodes[0].tx_broadcaster.txn_broadcasted.lock().unwrap().pop().unwrap();
+               match tracked_output.status {
+                       OutputSpendStatus::PendingFirstConfirmation { latest_spending_tx, .. } => {
+                               assert_eq!(sweep_tx_2.txid(), latest_spending_tx.txid());
+                       }
+                       _ => panic!("Unexpected status"),
+               }
+               assert_ne!(sweep_tx_0, sweep_tx_2);
+               assert_ne!(sweep_tx_1, sweep_tx_2);
+
+               // Check we still track the spendable outputs up to ANTI_REORG_DELAY confirmations.
+               confirm_transaction_depth(&mut nodes[0], &sweep_tx_2, 5);
+               assert_eq!(nodes[0].sweeper.tracked_spendable_outputs().len(), 1);
+               let tracked_output = nodes[0].sweeper.tracked_spendable_outputs().first().unwrap().clone();
+               match tracked_output.status {
+                       OutputSpendStatus::PendingThresholdConfirmations { latest_spending_tx, .. } => {
+                               assert_eq!(sweep_tx_2.txid(), latest_spending_tx.txid());
+                       }
+                       _ => panic!("Unexpected status"),
+               }
+
+               // Check we still see the transaction as confirmed if we unconfirm any untracked
+               // transaction. (We previously had a bug that would mark tracked transactions as
+               // unconfirmed if any transaction at an unknown block height would be unconfirmed.)
+               let unconf_txid = Txid::from_slice(&[0; 32]).unwrap();
+               nodes[0].sweeper.transaction_unconfirmed(&unconf_txid);
+
+               assert_eq!(nodes[0].sweeper.tracked_spendable_outputs().len(), 1);
+               let tracked_output = nodes[0].sweeper.tracked_spendable_outputs().first().unwrap().clone();
+               match tracked_output.status {
+                       OutputSpendStatus::PendingThresholdConfirmations { latest_spending_tx, .. } => {
+                               assert_eq!(sweep_tx_2.txid(), latest_spending_tx.txid());
+                       }
+                       _ => panic!("Unexpected status"),
+               }
+
+               // Check we stop tracking the spendable outputs when one of the txs reaches
+               // ANTI_REORG_DELAY confirmations.
+               confirm_transaction_depth(&mut nodes[0], &sweep_tx_0, ANTI_REORG_DELAY);
+               assert_eq!(nodes[0].sweeper.tracked_spendable_outputs().len(), 0);
+
                if !std::thread::panicking() {
                        bg_processor.stop().unwrap();
                }
index 71d63ecadcfa96e3b7fbd4a2c7c1076af2470e25..6d001ca67fd5e1ec44d0307a198509d83cf2538a 100644 (file)
@@ -624,8 +624,11 @@ mod tests {
                fn handle_open_channel_v2(&self, _their_node_id: &PublicKey, _msg: &OpenChannelV2) {}
                fn handle_accept_channel_v2(&self, _their_node_id: &PublicKey, _msg: &AcceptChannelV2) {}
                fn handle_stfu(&self, _their_node_id: &PublicKey, _msg: &Stfu) {}
+               #[cfg(splicing)]
                fn handle_splice(&self, _their_node_id: &PublicKey, _msg: &Splice) {}
+               #[cfg(splicing)]
                fn handle_splice_ack(&self, _their_node_id: &PublicKey, _msg: &SpliceAck) {}
+               #[cfg(splicing)]
                fn handle_splice_locked(&self, _their_node_id: &PublicKey, _msg: &SpliceLocked) {}
                fn handle_tx_add_input(&self, _their_node_id: &PublicKey, _msg: &TxAddInput) {}
                fn handle_tx_add_output(&self, _their_node_id: &PublicKey, _msg: &TxAddOutput) {}
index bdcbd7726f71189175407bc3dd125aad2c1db199..df7f8e7ad6128ee90b1d6048c3cdef7a5875b28c 100644 (file)
@@ -3,7 +3,7 @@ use bitcoin::secp256k1::{self, PublicKey, Secp256k1, SecretKey};
 #[allow(unused_imports)]
 use crate::prelude::*;
 
-use crate::blinded_path::{BlindedHop, BlindedPath};
+use crate::blinded_path::{BlindedHop, BlindedPath, IntroductionNode, NodeIdLookUp};
 use crate::blinded_path::utils;
 use crate::io;
 use crate::io::Cursor;
@@ -19,8 +19,8 @@ use core::ops::Deref;
 /// TLVs to encode in an intermediate onion message packet's hop data. When provided in a blinded
 /// route, they are encoded into [`BlindedHop::encrypted_payload`].
 pub(crate) struct ForwardTlvs {
-       /// The node id of the next hop in the onion message's path.
-       pub(crate) next_node_id: PublicKey,
+       /// The next hop in the onion message's path.
+       pub(crate) next_hop: NextHop,
        /// Senders to a blinded path use this value to concatenate the route they find to the
        /// introduction node with the blinded path.
        pub(crate) next_blinding_override: Option<PublicKey>,
@@ -34,11 +34,25 @@ pub(crate) struct ReceiveTlvs {
        pub(crate) path_id: Option<[u8; 32]>,
 }
 
+/// The next hop to forward the onion message along its path.
+#[derive(Debug)]
+pub enum NextHop {
+       /// The node id of the next hop.
+       NodeId(PublicKey),
+       /// The short channel id leading to the next hop.
+       ShortChannelId(u64),
+}
+
 impl Writeable for ForwardTlvs {
        fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
+               let (next_node_id, short_channel_id) = match self.next_hop {
+                       NextHop::NodeId(pubkey) => (Some(pubkey), None),
+                       NextHop::ShortChannelId(scid) => (None, Some(scid)),
+               };
                // TODO: write padding
                encode_tlv_stream!(writer, {
-                       (4, self.next_node_id, required),
+                       (2, short_channel_id, option),
+                       (4, next_node_id, option),
                        (8, self.next_blinding_override, option)
                });
                Ok(())
@@ -61,9 +75,8 @@ pub(super) fn blinded_hops<T: secp256k1::Signing + secp256k1::Verification>(
 ) -> Result<Vec<BlindedHop>, secp256k1::Error> {
        let blinded_tlvs = unblinded_path.iter()
                .skip(1) // The first node's TLVs contains the next node's pubkey
-               .map(|pk| {
-                       ControlTlvs::Forward(ForwardTlvs { next_node_id: *pk, next_blinding_override: None })
-               })
+               .map(|pk| ForwardTlvs { next_hop: NextHop::NodeId(*pk), next_blinding_override: None })
+               .map(|tlvs| ControlTlvs::Forward(tlvs))
                .chain(core::iter::once(ControlTlvs::Receive(ReceiveTlvs { path_id: None })));
 
        utils::construct_blinded_hops(secp_ctx, unblinded_path.iter(), blinded_tlvs, session_priv)
@@ -71,18 +84,30 @@ pub(super) fn blinded_hops<T: secp256k1::Signing + secp256k1::Verification>(
 
 // Advance the blinded onion message path by one hop, so make the second hop into the new
 // introduction node.
-pub(crate) fn advance_path_by_one<NS: Deref, T: secp256k1::Signing + secp256k1::Verification>(
-       path: &mut BlindedPath, node_signer: &NS, secp_ctx: &Secp256k1<T>
-) -> Result<(), ()> where NS::Target: NodeSigner {
+pub(crate) fn advance_path_by_one<NS: Deref, NL: Deref, T>(
+       path: &mut BlindedPath, node_signer: &NS, node_id_lookup: &NL, secp_ctx: &Secp256k1<T>
+) -> Result<(), ()>
+where
+       NS::Target: NodeSigner,
+       NL::Target: NodeIdLookUp,
+       T: secp256k1::Signing + secp256k1::Verification,
+{
        let control_tlvs_ss = node_signer.ecdh(Recipient::Node, &path.blinding_point, None)?;
        let rho = onion_utils::gen_rho_from_shared_secret(&control_tlvs_ss.secret_bytes());
        let encrypted_control_tlvs = path.blinded_hops.remove(0).encrypted_payload;
        let mut s = Cursor::new(&encrypted_control_tlvs);
        let mut reader = FixedLengthReader::new(&mut s, encrypted_control_tlvs.len() as u64);
        match ChaChaPolyReadAdapter::read(&mut reader, rho) {
-               Ok(ChaChaPolyReadAdapter { readable: ControlTlvs::Forward(ForwardTlvs {
-                       mut next_node_id, next_blinding_override,
-               })}) => {
+               Ok(ChaChaPolyReadAdapter {
+                       readable: ControlTlvs::Forward(ForwardTlvs { next_hop, next_blinding_override })
+               }) => {
+                       let next_node_id = match next_hop {
+                               NextHop::NodeId(pubkey) => pubkey,
+                               NextHop::ShortChannelId(scid) => match node_id_lookup.next_node_id(scid) {
+                                       Some(pubkey) => pubkey,
+                                       None => return Err(()),
+                               },
+                       };
                        let mut new_blinding_point = match next_blinding_override {
                                Some(blinding_point) => blinding_point,
                                None => {
@@ -91,7 +116,7 @@ pub(crate) fn advance_path_by_one<NS: Deref, T: secp256k1::Signing + secp256k1::
                                }
                        };
                        mem::swap(&mut path.blinding_point, &mut new_blinding_point);
-                       mem::swap(&mut path.introduction_node_id, &mut next_node_id);
+                       path.introduction_node = IntroductionNode::NodeId(next_node_id);
                        Ok(())
                },
                _ => Err(())
index e70f310f5e1d8b00875ff58e069f432bd69526cf..07fa7b770249cae94e1ea18269b63805dc933f93 100644 (file)
@@ -17,6 +17,7 @@ use bitcoin::secp256k1::{self, PublicKey, Secp256k1, SecretKey};
 
 use crate::ln::msgs::DecodeError;
 use crate::offers::invoice::BlindedPayInfo;
+use crate::routing::gossip::{NodeId, ReadOnlyNetworkGraph};
 use crate::sign::EntropySource;
 use crate::util::ser::{Readable, Writeable, Writer};
 
@@ -28,11 +29,11 @@ use crate::prelude::*;
 #[derive(Clone, Debug, Hash, PartialEq, Eq)]
 pub struct BlindedPath {
        /// To send to a blinded path, the sender first finds a route to the unblinded
-       /// `introduction_node_id`, which can unblind its [`encrypted_payload`] to find out the onion
+       /// `introduction_node`, which can unblind its [`encrypted_payload`] to find out the onion
        /// message or payment's next hop and forward it along.
        ///
        /// [`encrypted_payload`]: BlindedHop::encrypted_payload
-       pub introduction_node_id: PublicKey,
+       pub introduction_node: IntroductionNode,
        /// Used by the introduction node to decrypt its [`encrypted_payload`] to forward the onion
        /// message or payment.
        ///
@@ -42,6 +43,52 @@ pub struct BlindedPath {
        pub blinded_hops: Vec<BlindedHop>,
 }
 
+/// The unblinded node in a [`BlindedPath`].
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+pub enum IntroductionNode {
+       /// The node id of the introduction node.
+       NodeId(PublicKey),
+       /// The short channel id of the channel leading to the introduction node. The [`Direction`]
+       /// identifies which side of the channel is the introduction node.
+       DirectedShortChannelId(Direction, u64),
+}
+
+/// The side of a channel that is the [`IntroductionNode`] in a [`BlindedPath`]. [BOLT 7] defines
+/// which nodes is which in the [`ChannelAnnouncement`] message.
+///
+/// [BOLT 7]: https://github.com/lightning/bolts/blob/master/07-routing-gossip.md#the-channel_announcement-message
+/// [`ChannelAnnouncement`]: crate::ln::msgs::ChannelAnnouncement
+#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
+pub enum Direction {
+       /// The lesser node id when compared lexicographically in ascending order.
+       NodeOne,
+       /// The greater node id when compared lexicographically in ascending order.
+       NodeTwo,
+}
+
+/// An interface for looking up the node id of a channel counterparty for the purpose of forwarding
+/// an [`OnionMessage`].
+///
+/// [`OnionMessage`]: crate::ln::msgs::OnionMessage
+pub trait NodeIdLookUp {
+       /// Returns the node id of the forwarding node's channel counterparty with `short_channel_id`.
+       ///
+       /// Here, the forwarding node is referring to the node of the [`OnionMessenger`] parameterized
+       /// by the [`NodeIdLookUp`] and the counterparty to one of that node's peers.
+       ///
+       /// [`OnionMessenger`]: crate::onion_message::messenger::OnionMessenger
+       fn next_node_id(&self, short_channel_id: u64) -> Option<PublicKey>;
+}
+
+/// A [`NodeIdLookUp`] that always returns `None`.
+pub struct EmptyNodeIdLookUp {}
+
+impl NodeIdLookUp for EmptyNodeIdLookUp {
+       fn next_node_id(&self, _short_channel_id: u64) -> Option<PublicKey> {
+               None
+       }
+}
+
 /// An encrypted payload and node id corresponding to a hop in a payment or onion message path, to
 /// be encoded in the sender's onion packet. These hops cannot be identified by outside observers
 /// and thus can be used to hide the identity of the recipient.
@@ -74,10 +121,10 @@ impl BlindedPath {
                if node_pks.is_empty() { return Err(()) }
                let blinding_secret_bytes = entropy_source.get_secure_random_bytes();
                let blinding_secret = SecretKey::from_slice(&blinding_secret_bytes[..]).expect("RNG is busted");
-               let introduction_node_id = node_pks[0];
+               let introduction_node = IntroductionNode::NodeId(node_pks[0]);
 
                Ok(BlindedPath {
-                       introduction_node_id,
+                       introduction_node,
                        blinding_point: PublicKey::from_secret_key(secp_ctx, &blinding_secret),
                        blinded_hops: message::blinded_hops(secp_ctx, node_pks, &blinding_secret).map_err(|_| ())?,
                })
@@ -111,6 +158,9 @@ impl BlindedPath {
                payee_tlvs: payment::ReceiveTlvs, htlc_maximum_msat: u64, min_final_cltv_expiry_delta: u16,
                entropy_source: &ES, secp_ctx: &Secp256k1<T>
        ) -> Result<(BlindedPayInfo, Self), ()> {
+               let introduction_node = IntroductionNode::NodeId(
+                       intermediate_nodes.first().map_or(payee_node_id, |n| n.node_id)
+               );
                let blinding_secret_bytes = entropy_source.get_secure_random_bytes();
                let blinding_secret = SecretKey::from_slice(&blinding_secret_bytes[..]).expect("RNG is busted");
 
@@ -118,18 +168,49 @@ impl BlindedPath {
                        intermediate_nodes, &payee_tlvs, htlc_maximum_msat, min_final_cltv_expiry_delta
                )?;
                Ok((blinded_payinfo, BlindedPath {
-                       introduction_node_id: intermediate_nodes.first().map_or(payee_node_id, |n| n.node_id),
+                       introduction_node,
                        blinding_point: PublicKey::from_secret_key(secp_ctx, &blinding_secret),
                        blinded_hops: payment::blinded_hops(
                                secp_ctx, intermediate_nodes, payee_node_id, payee_tlvs, &blinding_secret
                        ).map_err(|_| ())?,
                }))
        }
+
+       /// Returns the introduction [`NodeId`] of the blinded path, if it is publicly reachable (i.e.,
+       /// it is found in the network graph).
+       pub fn public_introduction_node_id<'a>(
+               &self, network_graph: &'a ReadOnlyNetworkGraph
+       ) -> Option<&'a NodeId> {
+               match &self.introduction_node {
+                       IntroductionNode::NodeId(pubkey) => {
+                               let node_id = NodeId::from_pubkey(pubkey);
+                               network_graph.nodes().get_key_value(&node_id).map(|(key, _)| key)
+                       },
+                       IntroductionNode::DirectedShortChannelId(direction, scid) => {
+                               network_graph
+                                       .channel(*scid)
+                                       .map(|c| match direction {
+                                               Direction::NodeOne => &c.node_one,
+                                               Direction::NodeTwo => &c.node_two,
+                                       })
+                       },
+               }
+       }
 }
 
 impl Writeable for BlindedPath {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
-               self.introduction_node_id.write(w)?;
+               match &self.introduction_node {
+                       IntroductionNode::NodeId(pubkey) => pubkey.write(w)?,
+                       IntroductionNode::DirectedShortChannelId(direction, scid) => {
+                               match direction {
+                                       Direction::NodeOne => 0u8.write(w)?,
+                                       Direction::NodeTwo => 1u8.write(w)?,
+                               }
+                               scid.write(w)?;
+                       },
+               }
+
                self.blinding_point.write(w)?;
                (self.blinded_hops.len() as u8).write(w)?;
                for hop in &self.blinded_hops {
@@ -141,7 +222,17 @@ impl Writeable for BlindedPath {
 
 impl Readable for BlindedPath {
        fn read<R: io::Read>(r: &mut R) -> Result<Self, DecodeError> {
-               let introduction_node_id = Readable::read(r)?;
+               let mut first_byte: u8 = Readable::read(r)?;
+               let introduction_node = match first_byte {
+                       0 => IntroductionNode::DirectedShortChannelId(Direction::NodeOne, Readable::read(r)?),
+                       1 => IntroductionNode::DirectedShortChannelId(Direction::NodeTwo, Readable::read(r)?),
+                       2|3 => {
+                               use io::Read;
+                               let mut pubkey_read = core::slice::from_mut(&mut first_byte).chain(r.by_ref());
+                               IntroductionNode::NodeId(Readable::read(&mut pubkey_read)?)
+                       },
+                       _ => return Err(DecodeError::InvalidValue),
+               };
                let blinding_point = Readable::read(r)?;
                let num_hops: u8 = Readable::read(r)?;
                if num_hops == 0 { return Err(DecodeError::InvalidValue) }
@@ -150,7 +241,7 @@ impl Readable for BlindedPath {
                        blinded_hops.push(Readable::read(r)?);
                }
                Ok(BlindedPath {
-                       introduction_node_id,
+                       introduction_node,
                        blinding_point,
                        blinded_hops,
                })
@@ -162,3 +253,25 @@ impl_writeable!(BlindedHop, {
        encrypted_payload
 });
 
+impl Direction {
+       /// Returns the [`NodeId`] from the inputs corresponding to the direction.
+       pub fn select_node_id<'a>(&self, node_a: &'a NodeId, node_b: &'a NodeId) -> &'a NodeId {
+               match self {
+                       Direction::NodeOne => core::cmp::min(node_a, node_b),
+                       Direction::NodeTwo => core::cmp::max(node_a, node_b),
+               }
+       }
+
+       /// Returns the [`PublicKey`] from the inputs corresponding to the direction.
+       pub fn select_pubkey<'a>(&self, node_a: &'a PublicKey, node_b: &'a PublicKey) -> &'a PublicKey {
+               let (node_one, node_two) = if NodeId::from_pubkey(node_a) < NodeId::from_pubkey(node_b) {
+                       (node_a, node_b)
+               } else {
+                       (node_b, node_a)
+               };
+               match self {
+                       Direction::NodeOne => node_one,
+                       Direction::NodeTwo => node_two,
+               }
+       }
+}
index 3d56020a626d68cd1e233f53a905cfdfa8acf21c..ec441c18c986ed4788948dcf09465e07233878d8 100644 (file)
@@ -12,6 +12,8 @@ use crate::ln::channelmanager::CounterpartyForwardingInfo;
 use crate::ln::features::BlindedHopFeatures;
 use crate::ln::msgs::DecodeError;
 use crate::offers::invoice::BlindedPayInfo;
+use crate::offers::invoice_request::InvoiceRequestFields;
+use crate::offers::offer::OfferId;
 use crate::util::ser::{HighZeroBytesDroppedBigSize, Readable, Writeable, Writer};
 
 #[allow(unused_imports)]
@@ -53,6 +55,8 @@ pub struct ReceiveTlvs {
        pub payment_secret: PaymentSecret,
        /// Constraints for the receiver of this payment.
        pub payment_constraints: PaymentConstraints,
+       /// Context for the receiver of this payment.
+       pub payment_context: PaymentContext,
 }
 
 /// Data to construct a [`BlindedHop`] for sending a payment over.
@@ -97,6 +101,66 @@ pub struct PaymentConstraints {
        pub htlc_minimum_msat: u64,
 }
 
+/// The context of an inbound payment, which is included in a [`BlindedPath`] via [`ReceiveTlvs`]
+/// and surfaced in [`PaymentPurpose`].
+///
+/// [`BlindedPath`]: crate::blinded_path::BlindedPath
+/// [`PaymentPurpose`]: crate::events::PaymentPurpose
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub enum PaymentContext {
+       /// The payment context was unknown.
+       Unknown(UnknownPaymentContext),
+
+       /// The payment was made for an invoice requested from a BOLT 12 [`Offer`].
+       ///
+       /// [`Offer`]: crate::offers::offer::Offer
+       Bolt12Offer(Bolt12OfferContext),
+
+       /// The payment was made for an invoice sent for a BOLT 12 [`Refund`].
+       ///
+       /// [`Refund`]: crate::offers::refund::Refund
+       Bolt12Refund(Bolt12RefundContext),
+}
+
+// Used when writing PaymentContext in Event::PaymentClaimable to avoid cloning.
+pub(crate) enum PaymentContextRef<'a> {
+       Bolt12Offer(&'a Bolt12OfferContext),
+       Bolt12Refund(&'a Bolt12RefundContext),
+}
+
+/// An unknown payment context.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct UnknownPaymentContext(());
+
+/// The context of a payment made for an invoice requested from a BOLT 12 [`Offer`].
+///
+/// [`Offer`]: crate::offers::offer::Offer
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct Bolt12OfferContext {
+       /// The identifier of the [`Offer`].
+       ///
+       /// [`Offer`]: crate::offers::offer::Offer
+       pub offer_id: OfferId,
+
+       /// Fields from an [`InvoiceRequest`] sent for a [`Bolt12Invoice`].
+       ///
+       /// [`InvoiceRequest`]: crate::offers::invoice_request::InvoiceRequest
+       /// [`Bolt12Invoice`]: crate::offers::invoice::Bolt12Invoice
+       pub invoice_request: InvoiceRequestFields,
+}
+
+/// The context of a payment made for an invoice sent for a BOLT 12 [`Refund`].
+///
+/// [`Refund`]: crate::offers::refund::Refund
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct Bolt12RefundContext {}
+
+impl PaymentContext {
+       pub(crate) fn unknown() -> Self {
+               PaymentContext::Unknown(UnknownPaymentContext(()))
+       }
+}
+
 impl TryFrom<CounterpartyForwardingInfo> for PaymentRelay {
        type Error = ();
 
@@ -137,7 +201,8 @@ impl Writeable for ReceiveTlvs {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
                encode_tlv_stream!(w, {
                        (12, self.payment_constraints, required),
-                       (65536, self.payment_secret, required)
+                       (65536, self.payment_secret, required),
+                       (65537, self.payment_context, required)
                });
                Ok(())
        }
@@ -163,11 +228,14 @@ impl Readable for BlindedPaymentTlvs {
                        (12, payment_constraints, required),
                        (14, features, option),
                        (65536, payment_secret, option),
+                       (65537, payment_context, (default_value, PaymentContext::unknown())),
                });
                let _padding: Option<utils::Padding> = _padding;
 
                if let Some(short_channel_id) = scid {
-                       if payment_secret.is_some() { return Err(DecodeError::InvalidValue) }
+                       if payment_secret.is_some() {
+                               return Err(DecodeError::InvalidValue)
+                       }
                        Ok(BlindedPaymentTlvs::Forward(ForwardTlvs {
                                short_channel_id,
                                payment_relay: payment_relay.ok_or(DecodeError::InvalidValue)?,
@@ -179,6 +247,7 @@ impl Readable for BlindedPaymentTlvs {
                        Ok(BlindedPaymentTlvs::Receive(ReceiveTlvs {
                                payment_secret: payment_secret.ok_or(DecodeError::InvalidValue)?,
                                payment_constraints: payment_constraints.0.unwrap(),
+                               payment_context: payment_context.0.unwrap(),
                        }))
                }
        }
@@ -309,10 +378,53 @@ impl Readable for PaymentConstraints {
        }
 }
 
+impl_writeable_tlv_based_enum!(PaymentContext,
+       ;
+       (0, Unknown),
+       (1, Bolt12Offer),
+       (2, Bolt12Refund),
+);
+
+impl<'a> Writeable for PaymentContextRef<'a> {
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
+               match self {
+                       PaymentContextRef::Bolt12Offer(context) => {
+                               1u8.write(w)?;
+                               context.write(w)?;
+                       },
+                       PaymentContextRef::Bolt12Refund(context) => {
+                               2u8.write(w)?;
+                               context.write(w)?;
+                       },
+               }
+
+               Ok(())
+       }
+}
+
+impl Writeable for UnknownPaymentContext {
+       fn write<W: Writer>(&self, _w: &mut W) -> Result<(), io::Error> {
+               Ok(())
+       }
+}
+
+impl Readable for UnknownPaymentContext {
+       fn read<R: io::Read>(_r: &mut R) -> Result<Self, DecodeError> {
+               Ok(UnknownPaymentContext(()))
+       }
+}
+
+impl_writeable_tlv_based!(Bolt12OfferContext, {
+       (0, offer_id, required),
+       (2, invoice_request, required),
+});
+
+impl_writeable_tlv_based!(Bolt12RefundContext, {});
+
 #[cfg(test)]
 mod tests {
        use bitcoin::secp256k1::PublicKey;
-       use crate::blinded_path::payment::{ForwardNode, ForwardTlvs, ReceiveTlvs, PaymentConstraints, PaymentRelay};
+       use crate::blinded_path::payment::{ForwardNode, ForwardTlvs, ReceiveTlvs, PaymentConstraints, PaymentContext, PaymentRelay};
        use crate::ln::PaymentSecret;
        use crate::ln::features::BlindedHopFeatures;
        use crate::ln::functional_test_utils::TEST_FINAL_CLTV;
@@ -361,6 +473,7 @@ mod tests {
                                max_cltv_expiry: 0,
                                htlc_minimum_msat: 1,
                        },
+                       payment_context: PaymentContext::unknown(),
                };
                let htlc_maximum_msat = 100_000;
                let blinded_payinfo = super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_maximum_msat, 12).unwrap();
@@ -379,6 +492,7 @@ mod tests {
                                max_cltv_expiry: 0,
                                htlc_minimum_msat: 1,
                        },
+                       payment_context: PaymentContext::unknown(),
                };
                let blinded_payinfo = super::compute_payinfo(&[], &recv_tlvs, 4242, TEST_FINAL_CLTV as u16).unwrap();
                assert_eq!(blinded_payinfo.fee_base_msat, 0);
@@ -432,6 +546,7 @@ mod tests {
                                max_cltv_expiry: 0,
                                htlc_minimum_msat: 3,
                        },
+                       payment_context: PaymentContext::unknown(),
                };
                let htlc_maximum_msat = 100_000;
                let blinded_payinfo = super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_maximum_msat, TEST_FINAL_CLTV as u16).unwrap();
@@ -482,6 +597,7 @@ mod tests {
                                max_cltv_expiry: 0,
                                htlc_minimum_msat: 1,
                        },
+                       payment_context: PaymentContext::unknown(),
                };
                let htlc_minimum_msat = 3798;
                assert!(super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_minimum_msat - 1, TEST_FINAL_CLTV as u16).is_err());
@@ -536,6 +652,7 @@ mod tests {
                                max_cltv_expiry: 0,
                                htlc_minimum_msat: 1,
                        },
+                       payment_context: PaymentContext::unknown(),
                };
 
                let blinded_payinfo = super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, 10_000, TEST_FINAL_CLTV as u16).unwrap();
index 2bf6d6130e1123eef41a77fd3e5ffd1e5e3b72be..2e37127e038548ebe86888e185f52af1b9264551 100644 (file)
@@ -124,6 +124,17 @@ pub enum ConfirmationTarget {
        ///
        /// [`ChannelManager::close_channel_with_feerate_and_script`]: crate::ln::channelmanager::ChannelManager::close_channel_with_feerate_and_script
        ChannelCloseMinimum,
+       /// The feerate [`OutputSweeper`] will use on transactions spending
+       /// [`SpendableOutputDescriptor`]s after a channel closure.
+       ///
+       /// Generally spending these outputs is safe as long as they eventually confirm, so a value
+       /// (slightly above) the mempool minimum should suffice. However, as this value will influence
+       /// how long funds will be unavailable after channel closure, [`FeeEstimator`] implementors
+       /// might want to choose a higher feerate to regain control over funds faster.
+       ///
+       /// [`OutputSweeper`]: crate::util::sweep::OutputSweeper
+       /// [`SpendableOutputDescriptor`]: crate::sign::SpendableOutputDescriptor
+       OutputSpendingFee,
 }
 
 /// A trait which should be implemented to provide feerate information on a number of time
index e22ccca986aacf622778c7c4c3c5320b3815e341..1fb30a9aeb5f9375b4e1ecdac2950aea62a9ab37 100644 (file)
@@ -20,6 +20,7 @@ use crate::chain::channelmonitor::{ChannelMonitor, ChannelMonitorUpdate, Monitor
 use crate::ln::ChannelId;
 use crate::sign::ecdsa::WriteableEcdsaChannelSigner;
 use crate::chain::transaction::{OutPoint, TransactionData};
+use crate::impl_writeable_tlv_based;
 
 #[allow(unused_imports)]
 use crate::prelude::*;
@@ -56,6 +57,11 @@ impl BestBlock {
        }
 }
 
+impl_writeable_tlv_based!(BestBlock, {
+       (0, block_hash, required),
+       (2, height, required),
+});
+
 
 /// The `Listen` trait is used to notify when blocks have been connected or disconnected from the
 /// chain.
index f6e7f7164874f0de97fb56a8dc3c4294d2206c94..e72bc0228fd18fe43319744c550e295c5efa5fe7 100644 (file)
@@ -18,6 +18,7 @@ pub mod bump_transaction;
 
 pub use bump_transaction::BumpTransactionEvent;
 
+use crate::blinded_path::payment::{Bolt12OfferContext, Bolt12RefundContext, PaymentContext, PaymentContextRef};
 use crate::sign::SpendableOutputDescriptor;
 use crate::ln::channelmanager::{InterceptId, PaymentId, RecipientOnionFields};
 use crate::ln::channel::FUNDING_CONF_DEADLINE_BLOCKS;
@@ -49,11 +50,12 @@ use crate::prelude::*;
 /// spontaneous payment or a "conventional" lightning payment that's paying an invoice.
 #[derive(Clone, Debug, PartialEq, Eq)]
 pub enum PaymentPurpose {
-       /// Information for receiving a payment that we generated an invoice for.
-       InvoicePayment {
+       /// A payment for a BOLT 11 invoice.
+       Bolt11InvoicePayment {
                /// The preimage to the payment_hash, if the payment hash (and secret) were fetched via
-               /// [`ChannelManager::create_inbound_payment`]. If provided, this can be handed directly to
-               /// [`ChannelManager::claim_funds`].
+               /// [`ChannelManager::create_inbound_payment`]. When handling [`Event::PaymentClaimable`],
+               /// this can be passed directly to [`ChannelManager::claim_funds`] to claim the payment. No
+               /// action is needed when seen in [`Event::PaymentClaimed`].
                ///
                /// [`ChannelManager::create_inbound_payment`]: crate::ln::channelmanager::ChannelManager::create_inbound_payment
                /// [`ChannelManager::claim_funds`]: crate::ln::channelmanager::ChannelManager::claim_funds
@@ -70,6 +72,48 @@ pub enum PaymentPurpose {
                /// [`ChannelManager::create_inbound_payment_for_hash`]: crate::ln::channelmanager::ChannelManager::create_inbound_payment_for_hash
                payment_secret: PaymentSecret,
        },
+       /// A payment for a BOLT 12 [`Offer`].
+       ///
+       /// [`Offer`]: crate::offers::offer::Offer
+       Bolt12OfferPayment {
+               /// The preimage to the payment hash. When handling [`Event::PaymentClaimable`], this can be
+               /// passed directly to [`ChannelManager::claim_funds`], if provided. No action is needed
+               /// when seen in [`Event::PaymentClaimed`].
+               ///
+               /// [`ChannelManager::claim_funds`]: crate::ln::channelmanager::ChannelManager::claim_funds
+               payment_preimage: Option<PaymentPreimage>,
+               /// The secret used to authenticate the sender to the recipient, preventing a number of
+               /// de-anonymization attacks while routing a payment.
+               ///
+               /// See [`PaymentPurpose::Bolt11InvoicePayment::payment_secret`] for further details.
+               payment_secret: PaymentSecret,
+               /// The context of the payment such as information about the corresponding [`Offer`] and
+               /// [`InvoiceRequest`].
+               ///
+               /// [`Offer`]: crate::offers::offer::Offer
+               /// [`InvoiceRequest`]: crate::offers::invoice_request::InvoiceRequest
+               payment_context: Bolt12OfferContext,
+       },
+       /// A payment for a BOLT 12 [`Refund`].
+       ///
+       /// [`Refund`]: crate::offers::refund::Refund
+       Bolt12RefundPayment {
+               /// The preimage to the payment hash. When handling [`Event::PaymentClaimable`], this can be
+               /// passed directly to [`ChannelManager::claim_funds`], if provided. No action is needed
+               /// when seen in [`Event::PaymentClaimed`].
+               ///
+               /// [`ChannelManager::claim_funds`]: crate::ln::channelmanager::ChannelManager::claim_funds
+               payment_preimage: Option<PaymentPreimage>,
+               /// The secret used to authenticate the sender to the recipient, preventing a number of
+               /// de-anonymization attacks while routing a payment.
+               ///
+               /// See [`PaymentPurpose::Bolt11InvoicePayment::payment_secret`] for further details.
+               payment_secret: PaymentSecret,
+               /// The context of the payment such as information about the corresponding [`Refund`].
+               ///
+               /// [`Refund`]: crate::offers::refund::Refund
+               payment_context: Bolt12RefundContext,
+       },
        /// Because this is a spontaneous payment, the payer generated their own preimage rather than us
        /// (the payee) providing a preimage.
        SpontaneousPayment(PaymentPreimage),
@@ -79,17 +123,67 @@ impl PaymentPurpose {
        /// Returns the preimage for this payment, if it is known.
        pub fn preimage(&self) -> Option<PaymentPreimage> {
                match self {
-                       PaymentPurpose::InvoicePayment { payment_preimage, .. } => *payment_preimage,
+                       PaymentPurpose::Bolt11InvoicePayment { payment_preimage, .. } => *payment_preimage,
+                       PaymentPurpose::Bolt12OfferPayment { payment_preimage, .. } => *payment_preimage,
+                       PaymentPurpose::Bolt12RefundPayment { payment_preimage, .. } => *payment_preimage,
                        PaymentPurpose::SpontaneousPayment(preimage) => Some(*preimage),
                }
        }
+
+       pub(crate) fn is_keysend(&self) -> bool {
+               match self {
+                       PaymentPurpose::Bolt11InvoicePayment { .. } => false,
+                       PaymentPurpose::Bolt12OfferPayment { .. } => false,
+                       PaymentPurpose::Bolt12RefundPayment { .. } => false,
+                       PaymentPurpose::SpontaneousPayment(..) => true,
+               }
+       }
+
+       pub(crate) fn from_parts(
+               payment_preimage: Option<PaymentPreimage>, payment_secret: PaymentSecret,
+               payment_context: Option<PaymentContext>,
+       ) -> Self {
+               match payment_context {
+                       Some(PaymentContext::Unknown(_)) | None => {
+                               PaymentPurpose::Bolt11InvoicePayment {
+                                       payment_preimage,
+                                       payment_secret,
+                               }
+                       },
+                       Some(PaymentContext::Bolt12Offer(context)) => {
+                               PaymentPurpose::Bolt12OfferPayment {
+                                       payment_preimage,
+                                       payment_secret,
+                                       payment_context: context,
+                               }
+                       },
+                       Some(PaymentContext::Bolt12Refund(context)) => {
+                               PaymentPurpose::Bolt12RefundPayment {
+                                       payment_preimage,
+                                       payment_secret,
+                                       payment_context: context,
+                               }
+                       },
+               }
+       }
 }
 
 impl_writeable_tlv_based_enum!(PaymentPurpose,
-       (0, InvoicePayment) => {
+       (0, Bolt11InvoicePayment) => {
                (0, payment_preimage, option),
                (2, payment_secret, required),
-       };
+       },
+       (4, Bolt12OfferPayment) => {
+               (0, payment_preimage, option),
+               (2, payment_secret, required),
+               (4, payment_context, required),
+       },
+       (6, Bolt12RefundPayment) => {
+               (0, payment_preimage, option),
+               (2, payment_secret, required),
+               (4, payment_context, required),
+       },
+       ;
        (2, SpontaneousPayment)
 );
 
@@ -792,9 +886,15 @@ pub enum Event {
        },
        /// Used to indicate that an output which you should know how to spend was confirmed on chain
        /// and is now spendable.
-       /// Such an output will *not* ever be spent by rust-lightning, and are not at risk of your
+       ///
+       /// Such an output will *never* be spent directly by LDK, and are not at risk of your
        /// counterparty spending them due to some kind of timeout. Thus, you need to store them
        /// somewhere and spend them when you create on-chain transactions.
+       ///
+       /// You may hand them to the [`OutputSweeper`] utility which will store and (re-)generate spending
+       /// transactions for you.
+       ///
+       /// [`OutputSweeper`]: crate::util::sweep::OutputSweeper
        SpendableOutputs {
                /// The outputs which you should store as spendable by you.
                outputs: Vec<SpendableOutputDescriptor>,
@@ -912,8 +1012,8 @@ pub enum Event {
                /// The features that this channel will operate with.
                channel_type: ChannelTypeFeatures,
        },
-       /// Used to indicate that a previously opened channel with the given `channel_id` is in the
-       /// process of closure.
+       /// Used to indicate that a channel that got past the initial handshake with the given `channel_id` is in the
+       /// process of closure. This includes previously opened channels, and channels that time out from not being funded.
        ///
        /// Note that this event is only triggered for accepted channels: if the
        /// [`UserConfig::manually_accept_inbound_channels`] config flag is set to true and the channel is
@@ -1058,10 +1158,27 @@ impl Writeable for Event {
                                1u8.write(writer)?;
                                let mut payment_secret = None;
                                let payment_preimage;
+                               let mut payment_context = None;
                                match &purpose {
-                                       PaymentPurpose::InvoicePayment { payment_preimage: preimage, payment_secret: secret } => {
+                                       PaymentPurpose::Bolt11InvoicePayment {
+                                               payment_preimage: preimage, payment_secret: secret
+                                       } => {
+                                               payment_secret = Some(secret);
+                                               payment_preimage = *preimage;
+                                       },
+                                       PaymentPurpose::Bolt12OfferPayment {
+                                               payment_preimage: preimage, payment_secret: secret, payment_context: context
+                                       } => {
+                                               payment_secret = Some(secret);
+                                               payment_preimage = *preimage;
+                                               payment_context = Some(PaymentContextRef::Bolt12Offer(context));
+                                       },
+                                       PaymentPurpose::Bolt12RefundPayment {
+                                               payment_preimage: preimage, payment_secret: secret, payment_context: context
+                                       } => {
                                                payment_secret = Some(secret);
                                                payment_preimage = *preimage;
+                                               payment_context = Some(PaymentContextRef::Bolt12Refund(context));
                                        },
                                        PaymentPurpose::SpontaneousPayment(preimage) => {
                                                payment_preimage = Some(*preimage);
@@ -1081,6 +1198,7 @@ impl Writeable for Event {
                                        (8, payment_preimage, option),
                                        (9, onion_fields, option),
                                        (10, skimmed_fee_opt, option),
+                                       (11, payment_context, option),
                                });
                        },
                        &Event::PaymentSent { ref payment_id, ref payment_preimage, ref payment_hash, ref fee_paid_msat } => {
@@ -1311,6 +1429,7 @@ impl MaybeReadable for Event {
                                        let mut claim_deadline = None;
                                        let mut via_user_channel_id = None;
                                        let mut onion_fields = None;
+                                       let mut payment_context = None;
                                        read_tlv_fields!(reader, {
                                                (0, payment_hash, required),
                                                (1, receiver_node_id, option),
@@ -1323,12 +1442,10 @@ impl MaybeReadable for Event {
                                                (8, payment_preimage, option),
                                                (9, onion_fields, option),
                                                (10, counterparty_skimmed_fee_msat_opt, option),
+                                               (11, payment_context, option),
                                        });
                                        let purpose = match payment_secret {
-                                               Some(secret) => PaymentPurpose::InvoicePayment {
-                                                       payment_preimage,
-                                                       payment_secret: secret
-                                               },
+                                               Some(secret) => PaymentPurpose::from_parts(payment_preimage, secret, payment_context),
                                                None if payment_preimage.is_some() => PaymentPurpose::SpontaneousPayment(payment_preimage.unwrap()),
                                                None => return Err(msgs::DecodeError::InvalidValue),
                                        };
index eb31be9eecf334ead7ece1d1d6ecd7b587534f61..a0438b2447a004a1fc2261c036c673d363968a37 100644 (file)
@@ -9,7 +9,7 @@
 
 use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey};
 use crate::blinded_path::BlindedPath;
-use crate::blinded_path::payment::{ForwardNode, ForwardTlvs, PaymentConstraints, PaymentRelay, ReceiveTlvs};
+use crate::blinded_path::payment::{ForwardNode, ForwardTlvs, PaymentConstraints, PaymentContext, PaymentRelay, ReceiveTlvs};
 use crate::events::{Event, HTLCDestination, MessageSendEvent, MessageSendEventsProvider, PaymentFailureReason};
 use crate::ln::PaymentSecret;
 use crate::ln::channelmanager;
@@ -63,6 +63,7 @@ fn blinded_payment_path(
                        htlc_minimum_msat:
                                intro_node_min_htlc_opt.unwrap_or_else(|| channel_upds.last().unwrap().htlc_minimum_msat),
                },
+               payment_context: PaymentContext::unknown(),
        };
        let mut secp_ctx = Secp256k1::new();
        BlindedPath::new_for_payment(
@@ -108,6 +109,7 @@ fn do_one_hop_blinded_path(success: bool) {
                        max_cltv_expiry: u32::max_value(),
                        htlc_minimum_msat: chan_upd.htlc_minimum_msat,
                },
+               payment_context: PaymentContext::unknown(),
        };
        let mut secp_ctx = Secp256k1::new();
        let blinded_path = BlindedPath::one_hop_for_payment(
@@ -151,6 +153,7 @@ fn mpp_to_one_hop_blinded_path() {
                        max_cltv_expiry: u32::max_value(),
                        htlc_minimum_msat: chan_upd_1_3.htlc_minimum_msat,
                },
+               payment_context: PaymentContext::unknown(),
        };
        let blinded_path = BlindedPath::one_hop_for_payment(
                nodes[3].node.get_our_node_id(), payee_tlvs, TEST_FINAL_CLTV as u16,
@@ -1281,6 +1284,7 @@ fn custom_tlvs_to_blinded_path() {
                        max_cltv_expiry: u32::max_value(),
                        htlc_minimum_msat: chan_upd.htlc_minimum_msat,
                },
+               payment_context: PaymentContext::unknown(),
        };
        let mut secp_ctx = Secp256k1::new();
        let blinded_path = BlindedPath::one_hop_for_payment(
index 2e95f5c63ff7e173bb5a107f0719fb89ca3e3301..11ab1fbb85158a40f6e796ab1b26856df16c8458 100644 (file)
@@ -173,11 +173,11 @@ fn do_test_simple_monitor_temporary_update_fail(disconnect: bool) {
                        assert_eq!(receiver_node_id.unwrap(), nodes[1].node.get_our_node_id());
                        assert_eq!(via_channel_id, Some(channel_id));
                        match &purpose {
-                               PaymentPurpose::InvoicePayment { payment_preimage, payment_secret, .. } => {
+                               PaymentPurpose::Bolt11InvoicePayment { payment_preimage, payment_secret, .. } => {
                                        assert!(payment_preimage.is_none());
                                        assert_eq!(payment_secret_1, *payment_secret);
                                },
-                               _ => panic!("expected PaymentPurpose::InvoicePayment")
+                               _ => panic!("expected PaymentPurpose::Bolt11InvoicePayment")
                        }
                },
                _ => panic!("Unexpected event"),
@@ -554,11 +554,11 @@ fn do_test_monitor_temporary_update_fail(disconnect_count: usize) {
                        assert_eq!(receiver_node_id.unwrap(), nodes[1].node.get_our_node_id());
                        assert_eq!(via_channel_id, Some(channel_id));
                        match &purpose {
-                               PaymentPurpose::InvoicePayment { payment_preimage, payment_secret, .. } => {
+                               PaymentPurpose::Bolt11InvoicePayment { payment_preimage, payment_secret, .. } => {
                                        assert!(payment_preimage.is_none());
                                        assert_eq!(payment_secret_2, *payment_secret);
                                },
-                               _ => panic!("expected PaymentPurpose::InvoicePayment")
+                               _ => panic!("expected PaymentPurpose::Bolt11InvoicePayment")
                        }
                },
                _ => panic!("Unexpected event"),
@@ -672,11 +672,11 @@ fn test_monitor_update_fail_cs() {
                        assert_eq!(receiver_node_id.unwrap(), nodes[1].node.get_our_node_id());
                        assert_eq!(via_channel_id, Some(channel_id));
                        match &purpose {
-                               PaymentPurpose::InvoicePayment { payment_preimage, payment_secret, .. } => {
+                               PaymentPurpose::Bolt11InvoicePayment { payment_preimage, payment_secret, .. } => {
                                        assert!(payment_preimage.is_none());
                                        assert_eq!(our_payment_secret, *payment_secret);
                                },
-                               _ => panic!("expected PaymentPurpose::InvoicePayment")
+                               _ => panic!("expected PaymentPurpose::Bolt11InvoicePayment")
                        }
                },
                _ => panic!("Unexpected event"),
@@ -1683,11 +1683,11 @@ fn test_monitor_update_fail_claim() {
                        assert_eq!(via_channel_id, Some(channel_id));
                        assert_eq!(via_user_channel_id, Some(42));
                        match &purpose {
-                               PaymentPurpose::InvoicePayment { payment_preimage, payment_secret, .. } => {
+                               PaymentPurpose::Bolt11InvoicePayment { payment_preimage, payment_secret, .. } => {
                                        assert!(payment_preimage.is_none());
                                        assert_eq!(payment_secret_2, *payment_secret);
                                },
-                               _ => panic!("expected PaymentPurpose::InvoicePayment")
+                               _ => panic!("expected PaymentPurpose::Bolt11InvoicePayment")
                        }
                },
                _ => panic!("Unexpected event"),
@@ -1699,11 +1699,11 @@ fn test_monitor_update_fail_claim() {
                        assert_eq!(receiver_node_id.unwrap(), nodes[0].node.get_our_node_id());
                        assert_eq!(via_channel_id, Some(channel_id));
                        match &purpose {
-                               PaymentPurpose::InvoicePayment { payment_preimage, payment_secret, .. } => {
+                               PaymentPurpose::Bolt11InvoicePayment { payment_preimage, payment_secret, .. } => {
                                        assert!(payment_preimage.is_none());
                                        assert_eq!(payment_secret_3, *payment_secret);
                                },
-                               _ => panic!("expected PaymentPurpose::InvoicePayment")
+                               _ => panic!("expected PaymentPurpose::Bolt11InvoicePayment")
                        }
                },
                _ => panic!("Unexpected event"),
index 2698abee2d7db271c0e8a0a2c11d7e77ca976930..c02659cd0b01cc83548b988c362929074c56808d 100644 (file)
@@ -1189,9 +1189,9 @@ impl_writeable_tlv_based!(PendingChannelMonitorUpdate, {
 pub(super) enum ChannelPhase<SP: Deref> where SP::Target: SignerProvider {
        UnfundedOutboundV1(OutboundV1Channel<SP>),
        UnfundedInboundV1(InboundV1Channel<SP>),
-       #[cfg(dual_funding)]
+       #[cfg(any(dual_funding, splicing))]
        UnfundedOutboundV2(OutboundV2Channel<SP>),
-       #[cfg(dual_funding)]
+       #[cfg(any(dual_funding, splicing))]
        UnfundedInboundV2(InboundV2Channel<SP>),
        Funded(Channel<SP>),
 }
@@ -1205,9 +1205,9 @@ impl<'a, SP: Deref> ChannelPhase<SP> where
                        ChannelPhase::Funded(chan) => &chan.context,
                        ChannelPhase::UnfundedOutboundV1(chan) => &chan.context,
                        ChannelPhase::UnfundedInboundV1(chan) => &chan.context,
-                       #[cfg(dual_funding)]
+                       #[cfg(any(dual_funding, splicing))]
                        ChannelPhase::UnfundedOutboundV2(chan) => &chan.context,
-                       #[cfg(dual_funding)]
+                       #[cfg(any(dual_funding, splicing))]
                        ChannelPhase::UnfundedInboundV2(chan) => &chan.context,
                }
        }
@@ -1217,9 +1217,9 @@ impl<'a, SP: Deref> ChannelPhase<SP> where
                        ChannelPhase::Funded(ref mut chan) => &mut chan.context,
                        ChannelPhase::UnfundedOutboundV1(ref mut chan) => &mut chan.context,
                        ChannelPhase::UnfundedInboundV1(ref mut chan) => &mut chan.context,
-                       #[cfg(dual_funding)]
+                       #[cfg(any(dual_funding, splicing))]
                        ChannelPhase::UnfundedOutboundV2(ref mut chan) => &mut chan.context,
-                       #[cfg(dual_funding)]
+                       #[cfg(any(dual_funding, splicing))]
                        ChannelPhase::UnfundedInboundV2(ref mut chan) => &mut chan.context,
                }
        }
@@ -2730,7 +2730,7 @@ impl<SP: Deref> ChannelContext<SP> where SP::Target: SignerProvider  {
                        feerate_per_kw = cmp::max(feerate_per_kw, feerate);
                }
                let feerate_plus_quarter = feerate_per_kw.checked_mul(1250).map(|v| v / 1000);
-               cmp::max(2530, feerate_plus_quarter.unwrap_or(u32::max_value()))
+               cmp::max(feerate_per_kw + 2530, feerate_plus_quarter.unwrap_or(u32::max_value()))
        }
 
        /// Get forwarding information for the counterparty.
@@ -3501,7 +3501,7 @@ pub(crate) fn get_legacy_default_holder_selected_channel_reserve_satoshis(channe
 ///
 /// This is used both for outbound and inbound channels and has lower bound
 /// of `dust_limit_satoshis`.
-#[cfg(dual_funding)]
+#[cfg(any(dual_funding, splicing))]
 fn get_v2_channel_reserve_satoshis(channel_value_satoshis: u64, dust_limit_satoshis: u64) -> u64 {
        // Fixed at 1% of channel value by spec.
        let (q, _) = channel_value_satoshis.overflowing_div(100);
@@ -3524,7 +3524,7 @@ pub(crate) fn commit_tx_fee_msat(feerate_per_kw: u32, num_htlcs: usize, channel_
 }
 
 /// Context for dual-funded channels.
-#[cfg(dual_funding)]
+#[cfg(any(dual_funding, splicing))]
 pub(super) struct DualFundingChannelContext {
        /// The amount in satoshis we will be contributing to the channel.
        pub our_funding_satoshis: u64,
@@ -3541,7 +3541,7 @@ pub(super) struct DualFundingChannelContext {
 // Counterparty designates channel data owned by the another channel participant entity.
 pub(super) struct Channel<SP: Deref> where SP::Target: SignerProvider {
        pub context: ChannelContext<SP>,
-       #[cfg(dual_funding)]
+       #[cfg(any(dual_funding, splicing))]
        pub dual_funding_channel_context: Option<DualFundingChannelContext>,
 }
 
@@ -7704,7 +7704,7 @@ impl<SP: Deref> OutboundV1Channel<SP> where SP::Target: SignerProvider {
 
                let mut channel = Channel {
                        context: self.context,
-                       #[cfg(dual_funding)]
+                       #[cfg(any(dual_funding, splicing))]
                        dual_funding_channel_context: None,
                };
 
@@ -7994,7 +7994,7 @@ impl<SP: Deref> InboundV1Channel<SP> where SP::Target: SignerProvider {
                // `ChannelMonitor`.
                let mut channel = Channel {
                        context: self.context,
-                       #[cfg(dual_funding)]
+                       #[cfg(any(dual_funding, splicing))]
                        dual_funding_channel_context: None,
                };
                let need_channel_ready = channel.check_get_channel_ready(0).is_some();
@@ -8005,15 +8005,15 @@ impl<SP: Deref> InboundV1Channel<SP> where SP::Target: SignerProvider {
 }
 
 // A not-yet-funded outbound (from holder) channel using V2 channel establishment.
-#[cfg(dual_funding)]
+#[cfg(any(dual_funding, splicing))]
 pub(super) struct OutboundV2Channel<SP: Deref> where SP::Target: SignerProvider {
        pub context: ChannelContext<SP>,
        pub unfunded_context: UnfundedChannelContext,
-       #[cfg(dual_funding)]
+       #[cfg(any(dual_funding, splicing))]
        pub dual_funding_context: DualFundingChannelContext,
 }
 
-#[cfg(dual_funding)]
+#[cfg(any(dual_funding, splicing))]
 impl<SP: Deref> OutboundV2Channel<SP> where SP::Target: SignerProvider {
        pub fn new<ES: Deref, F: Deref>(
                fee_estimator: &LowerBoundedFeeEstimator<F>, entropy_source: &ES, signer_provider: &SP,
@@ -8129,14 +8129,14 @@ impl<SP: Deref> OutboundV2Channel<SP> where SP::Target: SignerProvider {
 }
 
 // A not-yet-funded inbound (from counterparty) channel using V2 channel establishment.
-#[cfg(dual_funding)]
+#[cfg(any(dual_funding, splicing))]
 pub(super) struct InboundV2Channel<SP: Deref> where SP::Target: SignerProvider {
        pub context: ChannelContext<SP>,
        pub unfunded_context: UnfundedChannelContext,
        pub dual_funding_context: DualFundingChannelContext,
 }
 
-#[cfg(dual_funding)]
+#[cfg(any(dual_funding, splicing))]
 impl<SP: Deref> InboundV2Channel<SP> where SP::Target: SignerProvider {
        /// Creates a new dual-funded channel from a remote side's request for one.
        /// Assumes chain_hash has already been checked and corresponds with what we expect!
@@ -9303,7 +9303,7 @@ impl<'a, 'b, 'c, ES: Deref, SP: Deref> ReadableArgs<(&'a ES, &'b SP, u32, &'c Ch
 
                                blocked_monitor_updates: blocked_monitor_updates.unwrap(),
                        },
-                       #[cfg(dual_funding)]
+                       #[cfg(any(dual_funding, splicing))]
                        dual_funding_channel_context: None,
                })
        }
index fa8a0b2163d2ffcb828d34c07675f8ee823e9e99..130eab49384ce884c99cc637422c61729dd1617f 100644 (file)
@@ -31,8 +31,8 @@ use bitcoin::secp256k1::{SecretKey,PublicKey};
 use bitcoin::secp256k1::Secp256k1;
 use bitcoin::{secp256k1, Sequence};
 
-use crate::blinded_path::BlindedPath;
-use crate::blinded_path::payment::{PaymentConstraints, ReceiveTlvs};
+use crate::blinded_path::{BlindedPath, NodeIdLookUp};
+use crate::blinded_path::payment::{Bolt12OfferContext, Bolt12RefundContext, PaymentConstraints, PaymentContext, ReceiveTlvs};
 use crate::chain;
 use crate::chain::{Confirm, ChannelMonitorUpdateStatus, Watch, BestBlock};
 use crate::chain::chaininterface::{BroadcasterInterface, ConfirmationTarget, FeeEstimator, LowerBoundedFeeEstimator};
@@ -61,7 +61,6 @@ use crate::ln::wire::Encode;
 use crate::offers::invoice::{BlindedPayInfo, Bolt12Invoice, DEFAULT_RELATIVE_EXPIRY, DerivedSigningPubkey, ExplicitSigningPubkey, InvoiceBuilder, UnsignedBolt12Invoice};
 use crate::offers::invoice_error::InvoiceError;
 use crate::offers::invoice_request::{DerivedPayerId, InvoiceRequestBuilder};
-use crate::offers::merkle::SignError;
 use crate::offers::offer::{Offer, OfferBuilder};
 use crate::offers::parse::Bolt12SemanticError;
 use crate::offers::refund::{Refund, RefundBuilder};
@@ -156,6 +155,11 @@ pub enum PendingHTLCRouting {
                /// [`Event::PaymentClaimable::onion_fields`] as
                /// [`RecipientOnionFields::payment_metadata`].
                payment_metadata: Option<Vec<u8>>,
+               /// The context of the payment included by the recipient in a blinded path, or `None` if a
+               /// blinded path was not used.
+               ///
+               /// Used in part to determine the [`events::PaymentPurpose`].
+               payment_context: Option<PaymentContext>,
                /// CLTV expiry of the received HTLC.
                ///
                /// Used to track when we should expire pending HTLCs that go unclaimed.
@@ -353,6 +357,11 @@ enum OnionPayload {
                /// This is only here for backwards-compatibility in serialization, in the future it can be
                /// removed, breaking clients running 0.0.106 and earlier.
                _legacy_hop_data: Option<msgs::FinalOnionHopData>,
+               /// The context of the payment included by the recipient in a blinded path, or `None` if a
+               /// blinded path was not used.
+               ///
+               /// Used in part to determine the [`events::PaymentPurpose`].
+               payment_context: Option<PaymentContext>,
        },
        /// Contains the payer-provided preimage.
        Spontaneous(PaymentPreimage),
@@ -918,9 +927,9 @@ impl <SP: Deref> PeerState<SP> where SP::Target: SignerProvider {
                        match phase {
                                ChannelPhase::Funded(_) | ChannelPhase::UnfundedOutboundV1(_) => true,
                                ChannelPhase::UnfundedInboundV1(_) => false,
-                               #[cfg(dual_funding)]
+                               #[cfg(any(dual_funding, splicing))]
                                ChannelPhase::UnfundedOutboundV2(_) => true,
-                               #[cfg(dual_funding)]
+                               #[cfg(any(dual_funding, splicing))]
                                ChannelPhase::UnfundedInboundV2(_) => false,
                        }
                )
@@ -1455,12 +1464,12 @@ where
 /// // On the event processing thread
 /// channel_manager.process_pending_events(&|event| match event {
 ///     Event::PaymentClaimable { payment_hash, purpose, .. } => match purpose {
-///         PaymentPurpose::InvoicePayment { payment_preimage: Some(payment_preimage), .. } => {
+///         PaymentPurpose::Bolt11InvoicePayment { payment_preimage: Some(payment_preimage), .. } => {
 ///             assert_eq!(payment_hash, known_payment_hash);
 ///             println!("Claiming payment {}", payment_hash);
 ///             channel_manager.claim_funds(payment_preimage);
 ///         },
-///         PaymentPurpose::InvoicePayment { payment_preimage: None, .. } => {
+///         PaymentPurpose::Bolt11InvoicePayment { payment_preimage: None, .. } => {
 ///             println!("Unknown payment hash: {}", payment_hash);
 ///         },
 ///         PaymentPurpose::SpontaneousPayment(payment_preimage) => {
@@ -1468,6 +1477,8 @@ where
 ///             println!("Claiming spontaneous payment {}", payment_hash);
 ///             channel_manager.claim_funds(payment_preimage);
 ///         },
+///         // ...
+/// #         _ => {},
 ///     },
 ///     Event::PaymentClaimed { payment_hash, amount_msat, .. } => {
 ///         assert_eq!(payment_hash, known_payment_hash);
@@ -1555,11 +1566,11 @@ where
 /// // On the event processing thread
 /// channel_manager.process_pending_events(&|event| match event {
 ///     Event::PaymentClaimable { payment_hash, purpose, .. } => match purpose {
-///         PaymentPurpose::InvoicePayment { payment_preimage: Some(payment_preimage), .. } => {
+///         PaymentPurpose::Bolt12OfferPayment { payment_preimage: Some(payment_preimage), .. } => {
 ///             println!("Claiming payment {}", payment_hash);
 ///             channel_manager.claim_funds(payment_preimage);
 ///         },
-///         PaymentPurpose::InvoicePayment { payment_preimage: None, .. } => {
+///         PaymentPurpose::Bolt12OfferPayment { payment_preimage: None, .. } => {
 ///             println!("Unknown payment hash: {}", payment_hash);
 ///         },
 ///         // ...
@@ -1695,25 +1706,31 @@ where
 /// #
 /// # fn example<T: AChannelManager>(channel_manager: T, refund: &Refund) {
 /// # let channel_manager = channel_manager.get_cm();
-/// match channel_manager.request_refund_payment(refund) {
-///     Ok(()) => println!("Requesting payment for refund"),
-///     Err(e) => println!("Unable to request payment for refund: {:?}", e),
-/// }
+/// let known_payment_hash = match channel_manager.request_refund_payment(refund) {
+///     Ok(invoice) => {
+///         let payment_hash = invoice.payment_hash();
+///         println!("Requesting refund payment {}", payment_hash);
+///         payment_hash
+///     },
+///     Err(e) => panic!("Unable to request payment for refund: {:?}", e),
+/// };
 ///
 /// // On the event processing thread
 /// channel_manager.process_pending_events(&|event| match event {
 ///     Event::PaymentClaimable { payment_hash, purpose, .. } => match purpose {
-///            PaymentPurpose::InvoicePayment { payment_preimage: Some(payment_preimage), .. } => {
+///            PaymentPurpose::Bolt12RefundPayment { payment_preimage: Some(payment_preimage), .. } => {
+///             assert_eq!(payment_hash, known_payment_hash);
 ///             println!("Claiming payment {}", payment_hash);
 ///             channel_manager.claim_funds(payment_preimage);
 ///         },
-///            PaymentPurpose::InvoicePayment { payment_preimage: None, .. } => {
+///            PaymentPurpose::Bolt12RefundPayment { payment_preimage: None, .. } => {
 ///             println!("Unknown payment hash: {}", payment_hash);
 ///            },
 ///         // ...
 /// #         _ => {},
 ///     },
 ///     Event::PaymentClaimed { payment_hash, amount_msat, .. } => {
+///         assert_eq!(payment_hash, known_payment_hash);
 ///         println!("Claimed {} msats", amount_msat);
 ///     },
 ///     // ...
@@ -2774,11 +2791,11 @@ macro_rules! convert_chan_phase_err {
                        ChannelPhase::UnfundedInboundV1(channel) => {
                                convert_chan_phase_err!($self, $err, channel, $channel_id, UNFUNDED_CHANNEL)
                        },
-                       #[cfg(dual_funding)]
+                       #[cfg(any(dual_funding, splicing))]
                        ChannelPhase::UnfundedOutboundV2(channel) => {
                                convert_chan_phase_err!($self, $err, channel, $channel_id, UNFUNDED_CHANNEL)
                        },
-                       #[cfg(dual_funding)]
+                       #[cfg(any(dual_funding, splicing))]
                        ChannelPhase::UnfundedInboundV2(channel) => {
                                convert_chan_phase_err!($self, $err, channel, $channel_id, UNFUNDED_CHANNEL)
                        },
@@ -3653,8 +3670,8 @@ where
                                                // Unfunded channel has no update
                                                (None, chan_phase.context().get_counterparty_node_id())
                                        },
-                                       // TODO(dual_funding): Combine this match arm with above once #[cfg(dual_funding)] is removed.
-                                       #[cfg(dual_funding)]
+                                       // TODO(dual_funding): Combine this match arm with above once #[cfg(any(dual_funding, splicing))] is removed.
+                                       #[cfg(any(dual_funding, splicing))]
                                        ChannelPhase::UnfundedOutboundV2(_) | ChannelPhase::UnfundedInboundV2(_) => {
                                                self.finish_close_channel(chan_phase.context_mut().force_shutdown(false, closure_reason));
                                                // Unfunded channel has no update
@@ -5331,13 +5348,14 @@ where
                                                                let blinded_failure = routing.blinded_failure();
                                                                let (cltv_expiry, onion_payload, payment_data, phantom_shared_secret, mut onion_fields) = match routing {
                                                                        PendingHTLCRouting::Receive {
-                                                                               payment_data, payment_metadata, incoming_cltv_expiry, phantom_shared_secret,
-                                                                               custom_tlvs, requires_blinded_error: _
+                                                                               payment_data, payment_metadata, payment_context,
+                                                                               incoming_cltv_expiry, phantom_shared_secret, custom_tlvs,
+                                                                               requires_blinded_error: _
                                                                        } => {
                                                                                let _legacy_hop_data = Some(payment_data.clone());
                                                                                let onion_fields = RecipientOnionFields { payment_secret: Some(payment_data.payment_secret),
                                                                                                payment_metadata, custom_tlvs };
-                                                                               (incoming_cltv_expiry, OnionPayload::Invoice { _legacy_hop_data },
+                                                                               (incoming_cltv_expiry, OnionPayload::Invoice { _legacy_hop_data, payment_context },
                                                                                        Some(payment_data), phantom_shared_secret, onion_fields)
                                                                        },
                                                                        PendingHTLCRouting::ReceiveKeysend {
@@ -5415,10 +5433,7 @@ where
                                                                macro_rules! check_total_value {
                                                                        ($purpose: expr) => {{
                                                                                let mut payment_claimable_generated = false;
-                                                                               let is_keysend = match $purpose {
-                                                                                       events::PaymentPurpose::SpontaneousPayment(_) => true,
-                                                                                       events::PaymentPurpose::InvoicePayment { .. } => false,
-                                                                               };
+                                                                               let is_keysend = $purpose.is_keysend();
                                                                                let mut claimable_payments = self.claimable_payments.lock().unwrap();
                                                                                if claimable_payments.pending_claiming_payments.contains_key(&payment_hash) {
                                                                                        fail_htlc!(claimable_htlc, payment_hash);
@@ -5515,7 +5530,7 @@ where
                                                                match payment_secrets.entry(payment_hash) {
                                                                        hash_map::Entry::Vacant(_) => {
                                                                                match claimable_htlc.onion_payload {
-                                                                                       OnionPayload::Invoice { .. } => {
+                                                                                       OnionPayload::Invoice { ref payment_context, .. } => {
                                                                                                let payment_data = payment_data.unwrap();
                                                                                                let (payment_preimage, min_final_cltv_expiry_delta) = match inbound_payment::verify(payment_hash, &payment_data, self.highest_seen_timestamp.load(Ordering::Acquire) as u64, &self.inbound_payment_key, &self.logger) {
                                                                                                        Ok(result) => result,
@@ -5532,10 +5547,11 @@ where
                                                                                                                fail_htlc!(claimable_htlc, payment_hash);
                                                                                                        }
                                                                                                }
-                                                                                               let purpose = events::PaymentPurpose::InvoicePayment {
-                                                                                                       payment_preimage: payment_preimage.clone(),
-                                                                                                       payment_secret: payment_data.payment_secret,
-                                                                                               };
+                                                                                               let purpose = events::PaymentPurpose::from_parts(
+                                                                                                       payment_preimage.clone(),
+                                                                                                       payment_data.payment_secret,
+                                                                                                       payment_context.clone(),
+                                                                                               );
                                                                                                check_total_value!(purpose);
                                                                                        },
                                                                                        OnionPayload::Spontaneous(preimage) => {
@@ -5545,10 +5561,13 @@ where
                                                                                }
                                                                        },
                                                                        hash_map::Entry::Occupied(inbound_payment) => {
-                                                                               if let OnionPayload::Spontaneous(_) = claimable_htlc.onion_payload {
-                                                                                       log_trace!(self.logger, "Failing new keysend HTLC with payment_hash {} because we already have an inbound payment with the same payment hash", &payment_hash);
-                                                                                       fail_htlc!(claimable_htlc, payment_hash);
-                                                                               }
+                                                                               let payment_context = match claimable_htlc.onion_payload {
+                                                                                       OnionPayload::Spontaneous(_) => {
+                                                                                               log_trace!(self.logger, "Failing new keysend HTLC with payment_hash {} because we already have an inbound payment with the same payment hash", &payment_hash);
+                                                                                               fail_htlc!(claimable_htlc, payment_hash);
+                                                                                       },
+                                                                                       OnionPayload::Invoice { ref payment_context, .. } => payment_context,
+                                                                               };
                                                                                let payment_data = payment_data.unwrap();
                                                                                if inbound_payment.get().payment_secret != payment_data.payment_secret {
                                                                                        log_trace!(self.logger, "Failing new HTLC with payment_hash {} as it didn't match our expected payment secret.", &payment_hash);
@@ -5558,10 +5577,11 @@ where
                                                                                                &payment_hash, payment_data.total_msat, inbound_payment.get().min_value_msat.unwrap());
                                                                                        fail_htlc!(claimable_htlc, payment_hash);
                                                                                } else {
-                                                                                       let purpose = events::PaymentPurpose::InvoicePayment {
-                                                                                               payment_preimage: inbound_payment.get().payment_preimage,
-                                                                                               payment_secret: payment_data.payment_secret,
-                                                                                       };
+                                                                                       let purpose = events::PaymentPurpose::from_parts(
+                                                                                               inbound_payment.get().payment_preimage,
+                                                                                               payment_data.payment_secret,
+                                                                                               payment_context.clone(),
+                                                                                       );
                                                                                        let payment_claimable_generated = check_total_value!(purpose);
                                                                                        if payment_claimable_generated {
                                                                                                inbound_payment.remove_entry();
@@ -5882,12 +5902,12 @@ where
                                                                process_unfunded_channel_tick(chan_id, &mut chan.context, &mut chan.unfunded_context,
                                                                        pending_msg_events, counterparty_node_id)
                                                        },
-                                                       #[cfg(dual_funding)]
+                                                       #[cfg(any(dual_funding, splicing))]
                                                        ChannelPhase::UnfundedInboundV2(chan) => {
                                                                process_unfunded_channel_tick(chan_id, &mut chan.context, &mut chan.unfunded_context,
                                                                        pending_msg_events, counterparty_node_id)
                                                        },
-                                                       #[cfg(dual_funding)]
+                                                       #[cfg(any(dual_funding, splicing))]
                                                        ChannelPhase::UnfundedOutboundV2(chan) => {
                                                                process_unfunded_channel_tick(chan_id, &mut chan.context, &mut chan.unfunded_context,
                                                                        pending_msg_events, counterparty_node_id)
@@ -7059,8 +7079,8 @@ where
                                                num_unfunded_channels += 1;
                                        }
                                },
-                               // TODO(dual_funding): Combine this match arm with above once #[cfg(dual_funding)] is removed.
-                               #[cfg(dual_funding)]
+                               // TODO(dual_funding): Combine this match arm with above once #[cfg(any(dual_funding, splicing))] is removed.
+                               #[cfg(any(dual_funding, splicing))]
                                ChannelPhase::UnfundedInboundV2(chan) => {
                                        // Only inbound V2 channels that are not 0conf and that we do not contribute to will be
                                        // included in the unfunded count.
@@ -7073,8 +7093,8 @@ where
                                        // Outbound channels don't contribute to the unfunded count in the DoS context.
                                        continue;
                                },
-                               // TODO(dual_funding): Combine this match arm with above once #[cfg(dual_funding)] is removed.
-                               #[cfg(dual_funding)]
+                               // TODO(dual_funding): Combine this match arm with above once #[cfg(any(dual_funding, splicing))] is removed.
+                               #[cfg(any(dual_funding, splicing))]
                                ChannelPhase::UnfundedOutboundV2(_) => {
                                        // Outbound channels don't contribute to the unfunded count in the DoS context.
                                        continue;
@@ -7501,7 +7521,7 @@ where
                                                finish_shutdown = Some(chan.context_mut().force_shutdown(false, ClosureReason::CounterpartyCoopClosedUnfundedChannel));
                                        },
                                        // TODO(dual_funding): Combine this match arm with above.
-                                       #[cfg(dual_funding)]
+                                       #[cfg(any(dual_funding, splicing))]
                                        ChannelPhase::UnfundedInboundV2(_) | ChannelPhase::UnfundedOutboundV2(_) => {
                                                let context = phase.context_mut();
                                                log_error!(self.logger, "Immediately closing unfunded channel {} as peer asked to cooperatively shut it down (which is unnecessary)", &msg.channel_id);
@@ -8775,7 +8795,7 @@ where
        ///
        /// The resulting invoice uses a [`PaymentHash`] recognized by the [`ChannelManager`] and a
        /// [`BlindedPath`] containing the [`PaymentSecret`] needed to reconstruct the corresponding
-       /// [`PaymentPreimage`].
+       /// [`PaymentPreimage`]. It is returned purely for informational purposes.
        ///
        /// # Limitations
        ///
@@ -8792,7 +8812,9 @@ where
        ///   the invoice.
        ///
        /// [`Bolt12Invoice`]: crate::offers::invoice::Bolt12Invoice
-       pub fn request_refund_payment(&self, refund: &Refund) -> Result<(), Bolt12SemanticError> {
+       pub fn request_refund_payment(
+               &self, refund: &Refund
+       ) -> Result<Bolt12Invoice, Bolt12SemanticError> {
                let expanded_key = &self.inbound_payment_key;
                let entropy = &*self.entropy_source;
                let secp_ctx = &self.secp_ctx;
@@ -8808,7 +8830,10 @@ where
 
                match self.create_inbound_payment(Some(amount_msats), relative_expiry, None) {
                        Ok((payment_hash, payment_secret)) => {
-                               let payment_paths = self.create_blinded_payment_paths(amount_msats, payment_secret)
+                               let payment_context = PaymentContext::Bolt12Refund(Bolt12RefundContext {});
+                               let payment_paths = self.create_blinded_payment_paths(
+                                       amount_msats, payment_secret, payment_context
+                               )
                                        .map_err(|_| Bolt12SemanticError::MissingPaths)?;
 
                                #[cfg(feature = "std")]
@@ -8831,7 +8856,7 @@ where
                                let mut pending_offers_messages = self.pending_offers_messages.lock().unwrap();
                                if refund.paths().is_empty() {
                                        let message = new_pending_onion_message(
-                                               OffersMessage::Invoice(invoice),
+                                               OffersMessage::Invoice(invoice.clone()),
                                                Destination::Node(refund.payer_id()),
                                                Some(reply_path),
                                        );
@@ -8847,7 +8872,7 @@ where
                                        }
                                }
 
-                               Ok(())
+                               Ok(invoice)
                        },
                        Err(()) => Err(Bolt12SemanticError::InvalidAmount),
                }
@@ -8859,10 +8884,9 @@ where
        /// This differs from [`create_inbound_payment_for_hash`] only in that it generates the
        /// [`PaymentHash`] and [`PaymentPreimage`] for you.
        ///
-       /// The [`PaymentPreimage`] will ultimately be returned to you in the [`PaymentClaimable`], which
-       /// will have the [`PaymentClaimable::purpose`] be [`PaymentPurpose::InvoicePayment`] with
-       /// its [`PaymentPurpose::InvoicePayment::payment_preimage`] field filled in. That should then be
-       /// passed directly to [`claim_funds`].
+       /// The [`PaymentPreimage`] will ultimately be returned to you in the [`PaymentClaimable`] event, which
+       /// will have the [`PaymentClaimable::purpose`] return `Some` for [`PaymentPurpose::preimage`]. That
+       /// should then be passed directly to [`claim_funds`].
        ///
        /// See [`create_inbound_payment_for_hash`] for detailed documentation on behavior and requirements.
        ///
@@ -8882,8 +8906,7 @@ where
        /// [`claim_funds`]: Self::claim_funds
        /// [`PaymentClaimable`]: events::Event::PaymentClaimable
        /// [`PaymentClaimable::purpose`]: events::Event::PaymentClaimable::purpose
-       /// [`PaymentPurpose::InvoicePayment`]: events::PaymentPurpose::InvoicePayment
-       /// [`PaymentPurpose::InvoicePayment::payment_preimage`]: events::PaymentPurpose::InvoicePayment::payment_preimage
+       /// [`PaymentPurpose::preimage`]: events::PaymentPurpose::preimage
        /// [`create_inbound_payment_for_hash`]: Self::create_inbound_payment_for_hash
        pub fn create_inbound_payment(&self, min_value_msat: Option<u64>, invoice_expiry_delta_secs: u32,
                min_final_cltv_expiry_delta: Option<u16>) -> Result<(PaymentHash, PaymentSecret), ()> {
@@ -8974,7 +8997,7 @@ where
        /// Creates multi-hop blinded payment paths for the given `amount_msats` by delegating to
        /// [`Router::create_blinded_payment_paths`].
        fn create_blinded_payment_paths(
-               &self, amount_msats: u64, payment_secret: PaymentSecret
+               &self, amount_msats: u64, payment_secret: PaymentSecret, payment_context: PaymentContext
        ) -> Result<Vec<(BlindedPayInfo, BlindedPath)>, ()> {
                let secp_ctx = &self.secp_ctx;
 
@@ -8988,6 +9011,7 @@ where
                                max_cltv_expiry,
                                htlc_minimum_msat: 1,
                        },
+                       payment_context,
                };
                self.router.create_blinded_payment_paths(
                        payee_node_id, first_hops, payee_tlvs, amount_msats, secp_ctx
@@ -9450,7 +9474,7 @@ where
                                                // Retain unfunded channels.
                                                ChannelPhase::UnfundedOutboundV1(_) | ChannelPhase::UnfundedInboundV1(_) => true,
                                                // TODO(dual_funding): Combine this match arm with above.
-                                               #[cfg(dual_funding)]
+                                               #[cfg(any(dual_funding, splicing))]
                                                ChannelPhase::UnfundedOutboundV2(_) | ChannelPhase::UnfundedInboundV2(_) => true,
                                                ChannelPhase::Funded(channel) => {
                                                        let res = f(channel);
@@ -9756,18 +9780,21 @@ where
                         msg.channel_id.clone())), *counterparty_node_id);
        }
 
+       #[cfg(splicing)]
        fn handle_splice(&self, counterparty_node_id: &PublicKey, msg: &msgs::Splice) {
                let _: Result<(), _> = handle_error!(self, Err(MsgHandleErrInternal::send_err_msg_no_close(
                        "Splicing not supported".to_owned(),
                         msg.channel_id.clone())), *counterparty_node_id);
        }
 
+       #[cfg(splicing)]
        fn handle_splice_ack(&self, counterparty_node_id: &PublicKey, msg: &msgs::SpliceAck) {
                let _: Result<(), _> = handle_error!(self, Err(MsgHandleErrInternal::send_err_msg_no_close(
                        "Splicing not supported (splice_ack)".to_owned(),
                         msg.channel_id.clone())), *counterparty_node_id);
        }
 
+       #[cfg(splicing)]
        fn handle_splice_locked(&self, counterparty_node_id: &PublicKey, msg: &msgs::SpliceLocked) {
                let _: Result<(), _> = handle_error!(self, Err(MsgHandleErrInternal::send_err_msg_no_close(
                        "Splicing not supported (splice_locked)".to_owned(),
@@ -9925,11 +9952,11 @@ where
                                                ChannelPhase::UnfundedInboundV1(chan) => {
                                                        &mut chan.context
                                                },
-                                               #[cfg(dual_funding)]
+                                               #[cfg(any(dual_funding, splicing))]
                                                ChannelPhase::UnfundedOutboundV2(chan) => {
                                                        &mut chan.context
                                                },
-                                               #[cfg(dual_funding)]
+                                               #[cfg(any(dual_funding, splicing))]
                                                ChannelPhase::UnfundedInboundV2(chan) => {
                                                        &mut chan.context
                                                },
@@ -10090,8 +10117,8 @@ where
                                                        });
                                                }
 
-                                               // TODO(dual_funding): Combine this match arm with above once #[cfg(dual_funding)] is removed.
-                                               #[cfg(dual_funding)]
+                                               // TODO(dual_funding): Combine this match arm with above once #[cfg(any(dual_funding, splicing))] is removed.
+                                               #[cfg(any(dual_funding, splicing))]
                                                ChannelPhase::UnfundedOutboundV2(chan) => {
                                                        pending_msg_events.push(events::MessageSendEvent::SendOpenChannelV2 {
                                                                node_id: chan.context.get_counterparty_node_id(),
@@ -10106,8 +10133,8 @@ where
                                                        debug_assert!(false);
                                                }
 
-                                               // TODO(dual_funding): Combine this match arm with above once #[cfg(dual_funding)] is removed.
-                                               #[cfg(dual_funding)]
+                                               // TODO(dual_funding): Combine this match arm with above once #[cfg(any(dual_funding, splicing))] is removed.
+                                               #[cfg(any(dual_funding, splicing))]
                                                ChannelPhase::UnfundedInboundV2(channel) => {
                                                        // Since unfunded inbound channel maps are cleared upon disconnecting a peer,
                                                        // they are not persisted and won't be recovered after a crash.
@@ -10210,7 +10237,7 @@ where
                                                        return;
                                                }
                                        },
-                                       #[cfg(dual_funding)]
+                                       #[cfg(any(dual_funding, splicing))]
                                        Some(ChannelPhase::UnfundedOutboundV2(ref mut chan)) => {
                                                if let Ok(msg) = chan.maybe_handle_error_without_close(self.chain_hash, &self.fee_estimator) {
                                                        peer_state.pending_msg_events.push(events::MessageSendEvent::SendOpenChannelV2 {
@@ -10221,7 +10248,7 @@ where
                                                }
                                        },
                                        None | Some(ChannelPhase::UnfundedInboundV1(_) | ChannelPhase::Funded(_)) => (),
-                                       #[cfg(dual_funding)]
+                                       #[cfg(any(dual_funding, splicing))]
                                        Some(ChannelPhase::UnfundedInboundV2(_)) => (),
                                }
                        }
@@ -10341,8 +10368,12 @@ where
                                        },
                                };
 
+                               let payment_context = PaymentContext::Bolt12Offer(Bolt12OfferContext {
+                                       offer_id: invoice_request.offer_id,
+                                       invoice_request: invoice_request.fields(),
+                               });
                                let payment_paths = match self.create_blinded_payment_paths(
-                                       amount_msats, payment_secret
+                                       amount_msats, payment_secret, payment_context
                                ) {
                                        Ok(payment_paths) => payment_paths,
                                        Err(()) => {
@@ -10356,7 +10387,7 @@ where
                                        self.highest_seen_timestamp.load(Ordering::Acquire) as u64
                                );
 
-                               if invoice_request.keys.is_some() {
+                               let response = if invoice_request.keys.is_some() {
                                        #[cfg(feature = "std")]
                                        let builder = invoice_request.respond_using_derived_keys(
                                                payment_paths, payment_hash
@@ -10365,12 +10396,10 @@ where
                                        let builder = invoice_request.respond_using_derived_keys_no_std(
                                                payment_paths, payment_hash, created_at
                                        );
-                                       let builder: Result<InvoiceBuilder<DerivedSigningPubkey>, _> =
-                                               builder.map(|b| b.into());
-                                       match builder.and_then(|b| b.allow_mpp().build_and_sign(secp_ctx)) {
-                                               Ok(invoice) => Some(OffersMessage::Invoice(invoice)),
-                                               Err(error) => Some(OffersMessage::InvoiceError(error.into())),
-                                       }
+                                       builder
+                                               .map(InvoiceBuilder::<DerivedSigningPubkey>::from)
+                                               .and_then(|builder| builder.allow_mpp().build_and_sign(secp_ctx))
+                                               .map_err(InvoiceError::from)
                                } else {
                                        #[cfg(feature = "std")]
                                        let builder = invoice_request.respond_with(payment_paths, payment_hash);
@@ -10378,47 +10407,46 @@ where
                                        let builder = invoice_request.respond_with_no_std(
                                                payment_paths, payment_hash, created_at
                                        );
-                                       let builder: Result<InvoiceBuilder<ExplicitSigningPubkey>, _> =
-                                               builder.map(|b| b.into());
-                                       let response = builder.and_then(|builder| builder.allow_mpp().build())
-                                               .map_err(|e| OffersMessage::InvoiceError(e.into()))
+                                       builder
+                                               .map(InvoiceBuilder::<ExplicitSigningPubkey>::from)
+                                               .and_then(|builder| builder.allow_mpp().build())
+                                               .map_err(InvoiceError::from)
                                                .and_then(|invoice| {
                                                        #[cfg(c_bindings)]
                                                        let mut invoice = invoice;
-                                                       match invoice.sign(|invoice: &UnsignedBolt12Invoice|
-                                                               self.node_signer.sign_bolt12_invoice(invoice)
-                                                       ) {
-                                                               Ok(invoice) => Ok(OffersMessage::Invoice(invoice)),
-                                                               Err(SignError::Signing) => Err(OffersMessage::InvoiceError(
-                                                                               InvoiceError::from_string("Failed signing invoice".to_string())
-                                                               )),
-                                                               Err(SignError::Verification(_)) => Err(OffersMessage::InvoiceError(
-                                                                               InvoiceError::from_string("Failed invoice signature verification".to_string())
-                                                               )),
-                                                       }
-                                               });
-                                       match response {
-                                               Ok(invoice) => Some(invoice),
-                                               Err(error) => Some(error),
-                                       }
+                                                       invoice
+                                                               .sign(|invoice: &UnsignedBolt12Invoice|
+                                                                       self.node_signer.sign_bolt12_invoice(invoice)
+                                                               )
+                                                               .map_err(InvoiceError::from)
+                                               })
+                               };
+
+                               match response {
+                                       Ok(invoice) => Some(OffersMessage::Invoice(invoice)),
+                                       Err(error) => Some(OffersMessage::InvoiceError(error.into())),
                                }
                        },
                        OffersMessage::Invoice(invoice) => {
-                               match invoice.verify(expanded_key, secp_ctx) {
-                                       Err(()) => {
-                                               Some(OffersMessage::InvoiceError(InvoiceError::from_string("Unrecognized invoice".to_owned())))
-                                       },
-                                       Ok(_) if invoice.invoice_features().requires_unknown_bits_from(&self.bolt12_invoice_features()) => {
-                                               Some(OffersMessage::InvoiceError(Bolt12SemanticError::UnknownRequiredFeatures.into()))
-                                       },
-                                       Ok(payment_id) => {
-                                               if let Err(e) = self.send_payment_for_bolt12_invoice(&invoice, payment_id) {
-                                                       log_trace!(self.logger, "Failed paying invoice: {:?}", e);
-                                                       Some(OffersMessage::InvoiceError(InvoiceError::from_string(format!("{:?}", e))))
+                               let response = invoice
+                                       .verify(expanded_key, secp_ctx)
+                                       .map_err(|()| InvoiceError::from_string("Unrecognized invoice".to_owned()))
+                                       .and_then(|payment_id| {
+                                               let features = self.bolt12_invoice_features();
+                                               if invoice.invoice_features().requires_unknown_bits_from(&features) {
+                                                       Err(InvoiceError::from(Bolt12SemanticError::UnknownRequiredFeatures))
                                                } else {
-                                                       None
+                                                       self.send_payment_for_bolt12_invoice(&invoice, payment_id)
+                                                               .map_err(|e| {
+                                                                       log_trace!(self.logger, "Failed paying invoice: {:?}", e);
+                                                                       InvoiceError::from_string(format!("{:?}", e))
+                                                               })
                                                }
-                                       },
+                                       });
+
+                               match response {
+                                       Ok(()) => None,
+                                       Err(e) => Some(OffersMessage::InvoiceError(e)),
                                }
                        },
                        OffersMessage::InvoiceError(invoice_error) => {
@@ -10433,6 +10461,23 @@ where
        }
 }
 
+impl<M: Deref, T: Deref, ES: Deref, NS: Deref, SP: Deref, F: Deref, R: Deref, L: Deref>
+NodeIdLookUp for ChannelManager<M, T, ES, NS, SP, F, R, L>
+where
+       M::Target: chain::Watch<<SP::Target as SignerProvider>::EcdsaSigner>,
+       T::Target: BroadcasterInterface,
+       ES::Target: EntropySource,
+       NS::Target: NodeSigner,
+       SP::Target: SignerProvider,
+       F::Target: FeeEstimator,
+       R::Target: Router,
+       L::Target: Logger,
+{
+       fn next_node_id(&self, short_channel_id: u64) -> Option<PublicKey> {
+               self.short_to_chan_info.read().unwrap().get(&short_channel_id).map(|(pubkey, _)| *pubkey)
+       }
+}
+
 /// Fetches the set of [`NodeFeatures`] flags that are provided by or required by
 /// [`ChannelManager`].
 pub(crate) fn provided_node_features(config: &UserConfig) -> NodeFeatures {
@@ -10655,6 +10700,7 @@ impl_writeable_tlv_based_enum!(PendingHTLCRouting,
                (3, payment_metadata, option),
                (5, custom_tlvs, optional_vec),
                (7, requires_blinded_error, (default_value, false)),
+               (9, payment_context, option),
        },
        (2, ReceiveKeysend) => {
                (0, payment_preimage, required),
@@ -10769,9 +10815,11 @@ impl_writeable_tlv_based!(HTLCPreviousHopData, {
 
 impl Writeable for ClaimableHTLC {
        fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
-               let (payment_data, keysend_preimage) = match &self.onion_payload {
-                       OnionPayload::Invoice { _legacy_hop_data } => (_legacy_hop_data.as_ref(), None),
-                       OnionPayload::Spontaneous(preimage) => (None, Some(preimage)),
+               let (payment_data, keysend_preimage, payment_context) = match &self.onion_payload {
+                       OnionPayload::Invoice { _legacy_hop_data, payment_context } => {
+                               (_legacy_hop_data.as_ref(), None, payment_context.as_ref())
+                       },
+                       OnionPayload::Spontaneous(preimage) => (None, Some(preimage), None),
                };
                write_tlv_fields!(writer, {
                        (0, self.prev_hop, required),
@@ -10783,6 +10831,7 @@ impl Writeable for ClaimableHTLC {
                        (6, self.cltv_expiry, required),
                        (8, keysend_preimage, option),
                        (10, self.counterparty_skimmed_fee_msat, option),
+                       (11, payment_context, option),
                });
                Ok(())
        }
@@ -10800,6 +10849,7 @@ impl Readable for ClaimableHTLC {
                        (6, cltv_expiry, required),
                        (8, keysend_preimage, option),
                        (10, counterparty_skimmed_fee_msat, option),
+                       (11, payment_context, option),
                });
                let payment_data: Option<msgs::FinalOnionHopData> = payment_data_opt;
                let value = value_ser.0.unwrap();
@@ -10820,7 +10870,7 @@ impl Readable for ClaimableHTLC {
                                        }
                                        total_msat = Some(payment_data.as_ref().unwrap().total_msat);
                                }
-                               OnionPayload::Invoice { _legacy_hop_data: payment_data }
+                               OnionPayload::Invoice { _legacy_hop_data: payment_data, payment_context }
                        },
                };
                Ok(Self {
@@ -11015,9 +11065,10 @@ where
                        best_block.block_hash.write(writer)?;
                }
 
+               let per_peer_state = self.per_peer_state.write().unwrap();
+
                let mut serializable_peer_count: u64 = 0;
                {
-                       let per_peer_state = self.per_peer_state.read().unwrap();
                        let mut number_of_funded_channels = 0;
                        for (_, peer_state_mutex) in per_peer_state.iter() {
                                let mut peer_state_lock = peer_state_mutex.lock().unwrap();
@@ -11064,8 +11115,6 @@ where
                        decode_update_add_htlcs_opt = Some(decode_update_add_htlcs);
                }
 
-               let per_peer_state = self.per_peer_state.write().unwrap();
-
                let pending_inbound_payments = self.pending_inbound_payments.lock().unwrap();
                let claimable_payments = self.claimable_payments.lock().unwrap();
                let pending_outbound_payments = self.pending_outbound_payments.pending_outbound_payments.lock().unwrap();
@@ -12057,9 +12106,9 @@ where
                                        return Err(DecodeError::InvalidValue);
                                }
                                let purpose = match &htlcs[0].onion_payload {
-                                       OnionPayload::Invoice { _legacy_hop_data } => {
+                                       OnionPayload::Invoice { _legacy_hop_data, payment_context: _ } => {
                                                if let Some(hop_data) = _legacy_hop_data {
-                                                       events::PaymentPurpose::InvoicePayment {
+                                                       events::PaymentPurpose::Bolt11InvoicePayment {
                                                                payment_preimage: match pending_inbound_payments.get(&payment_hash) {
                                                                        Some(inbound_payment) => inbound_payment.payment_preimage,
                                                                        None => match inbound_payment::verify(payment_hash, &hop_data, 0, &expanded_inbound_key, &args.logger) {
index 3a506b57fe2a05b572b7f4c34b2ecc1f617b4f38..cf52d946e34e0832e640b9d175dc1e98fff34cf4 100644 (file)
@@ -415,6 +415,7 @@ type TestOnionMessenger<'chan_man, 'node_cfg, 'chan_mon_cfg> = OnionMessenger<
        DedicatedEntropy,
        &'node_cfg test_utils::TestKeysInterface,
        &'chan_mon_cfg test_utils::TestLogger,
+       &'chan_man TestChannelManager<'node_cfg, 'chan_mon_cfg>,
        &'node_cfg test_utils::TestMessageRouter<'chan_mon_cfg>,
        &'chan_man TestChannelManager<'node_cfg, 'chan_mon_cfg>,
        IgnoringMessageHandler,
@@ -2128,7 +2129,15 @@ pub fn check_payment_claimable(
                        assert_eq!(expected_recv_value, *amount_msat);
                        assert_eq!(expected_receiver_node_id, receiver_node_id.unwrap());
                        match purpose {
-                               PaymentPurpose::InvoicePayment { payment_preimage, payment_secret, .. } => {
+                               PaymentPurpose::Bolt11InvoicePayment { payment_preimage, payment_secret, .. } => {
+                                       assert_eq!(&expected_payment_preimage, payment_preimage);
+                                       assert_eq!(expected_payment_secret, *payment_secret);
+                               },
+                               PaymentPurpose::Bolt12OfferPayment { payment_preimage, payment_secret, .. } => {
+                                       assert_eq!(&expected_payment_preimage, payment_preimage);
+                                       assert_eq!(expected_payment_secret, *payment_secret);
+                               },
+                               PaymentPurpose::Bolt12RefundPayment { payment_preimage, payment_secret, .. } => {
                                        assert_eq!(&expected_payment_preimage, payment_preimage);
                                        assert_eq!(expected_payment_secret, *payment_secret);
                                },
@@ -2605,7 +2614,17 @@ pub fn do_pass_along_path<'a, 'b, 'c>(args: PassAlongPathArgs) -> Option<Event>
                                                assert!(onion_fields.is_some());
                                                assert_eq!(onion_fields.as_ref().unwrap().custom_tlvs, custom_tlvs);
                                                match &purpose {
-                                                       PaymentPurpose::InvoicePayment { payment_preimage, payment_secret, .. } => {
+                                                       PaymentPurpose::Bolt11InvoicePayment { payment_preimage, payment_secret, .. } => {
+                                                               assert_eq!(expected_preimage, *payment_preimage);
+                                                               assert_eq!(our_payment_secret.unwrap(), *payment_secret);
+                                                               assert_eq!(Some(*payment_secret), onion_fields.as_ref().unwrap().payment_secret);
+                                                       },
+                                                       PaymentPurpose::Bolt12OfferPayment { payment_preimage, payment_secret, .. } => {
+                                                               assert_eq!(expected_preimage, *payment_preimage);
+                                                               assert_eq!(our_payment_secret.unwrap(), *payment_secret);
+                                                               assert_eq!(Some(*payment_secret), onion_fields.as_ref().unwrap().payment_secret);
+                                                       },
+                                                       PaymentPurpose::Bolt12RefundPayment { payment_preimage, payment_secret, .. } => {
                                                                assert_eq!(expected_preimage, *payment_preimage);
                                                                assert_eq!(our_payment_secret.unwrap(), *payment_secret);
                                                                assert_eq!(Some(*payment_secret), onion_fields.as_ref().unwrap().payment_secret);
@@ -2762,14 +2781,12 @@ pub fn pass_claimed_payment_along_route<'a, 'b, 'c, 'd>(args: ClaimAlongRouteArg
        let mut fwd_amt_msat = 0;
        match claim_event[0] {
                Event::PaymentClaimed {
-                       purpose: PaymentPurpose::SpontaneousPayment(preimage),
+                       purpose: PaymentPurpose::SpontaneousPayment(preimage)
+                               | PaymentPurpose::Bolt11InvoicePayment { payment_preimage: Some(preimage), .. }
+                               | PaymentPurpose::Bolt12OfferPayment { payment_preimage: Some(preimage), .. }
+                               | PaymentPurpose::Bolt12RefundPayment { payment_preimage: Some(preimage), .. },
                        amount_msat,
                        ref htlcs,
-                       .. }
-               | Event::PaymentClaimed {
-                       purpose: PaymentPurpose::InvoicePayment { payment_preimage: Some(preimage), ..},
-                       ref htlcs,
-                       amount_msat,
                        ..
                } => {
                        assert_eq!(preimage, our_payment_preimage);
@@ -2779,7 +2796,9 @@ pub fn pass_claimed_payment_along_route<'a, 'b, 'c, 'd>(args: ClaimAlongRouteArg
                        fwd_amt_msat = amount_msat;
                },
                Event::PaymentClaimed {
-                       purpose: PaymentPurpose::InvoicePayment { .. },
+                       purpose: PaymentPurpose::Bolt11InvoicePayment { .. }
+                               | PaymentPurpose::Bolt12OfferPayment { .. }
+                               | PaymentPurpose::Bolt12RefundPayment { .. },
                        payment_hash,
                        amount_msat,
                        ref htlcs,
@@ -3199,8 +3218,8 @@ pub fn create_network<'a, 'b: 'a, 'c: 'b>(node_count: usize, cfgs: &'b Vec<NodeC
        for i in 0..node_count {
                let dedicated_entropy = DedicatedEntropy(RandomBytes::new([i as u8; 32]));
                let onion_messenger = OnionMessenger::new(
-                       dedicated_entropy, cfgs[i].keys_manager, cfgs[i].logger, &cfgs[i].message_router,
-                       &chan_mgrs[i], IgnoringMessageHandler {},
+                       dedicated_entropy, cfgs[i].keys_manager, cfgs[i].logger, &chan_mgrs[i],
+                       &cfgs[i].message_router, &chan_mgrs[i], IgnoringMessageHandler {},
                );
                let gossip_sync = P2PGossipSync::new(cfgs[i].network_graph.as_ref(), None, cfgs[i].logger);
                let wallet_source = Arc::new(test_utils::TestWalletSource::new(SecretKey::from_slice(&[i as u8 + 1; 32]).unwrap()));
index 5ea3e6372c0d1c497c0b8fcde11961cfdb4206b0..465d6288d9d3e764662edb2276d7bb5bef477cb3 100644 (file)
@@ -17,7 +17,7 @@ use crate::chain::chaininterface::LowerBoundedFeeEstimator;
 use crate::chain::channelmonitor;
 use crate::chain::channelmonitor::{CLOSED_CHANNEL_UPDATE_ID, CLTV_CLAIM_BUFFER, LATENCY_GRACE_PERIOD_BLOCKS, ANTI_REORG_DELAY};
 use crate::chain::transaction::OutPoint;
-use crate::sign::{ecdsa::EcdsaChannelSigner, EntropySource, SignerProvider};
+use crate::sign::{ecdsa::EcdsaChannelSigner, EntropySource, OutputSpender, SignerProvider};
 use crate::events::{Event, MessageSendEvent, MessageSendEventsProvider, PathFailure, PaymentPurpose, ClosureReason, HTLCDestination, PaymentFailureReason};
 use crate::ln::{ChannelId, PaymentPreimage, PaymentSecret, PaymentHash};
 use crate::ln::channel::{commitment_tx_base_weight, COMMITMENT_TX_WEIGHT_PER_HTLC, CONCURRENT_INBOUND_HTLC_FEE_BUFFER, FEE_SPIKE_BUFFER_FEE_INCREASE_MULTIPLE, MIN_AFFORDABLE_HTLC_COUNT, get_holder_selected_channel_reserve_satoshis, OutboundV1Channel, InboundV1Channel, COINBASE_MATURITY, ChannelPhase};
@@ -2039,11 +2039,11 @@ fn test_channel_reserve_holding_cell_htlcs() {
                        assert_eq!(nodes[2].node.get_our_node_id(), receiver_node_id.unwrap());
                        assert_eq!(via_channel_id, Some(chan_2.2));
                        match &purpose {
-                               PaymentPurpose::InvoicePayment { payment_preimage, payment_secret, .. } => {
+                               PaymentPurpose::Bolt11InvoicePayment { payment_preimage, payment_secret, .. } => {
                                        assert!(payment_preimage.is_none());
                                        assert_eq!(our_payment_secret_21, *payment_secret);
                                },
-                               _ => panic!("expected PaymentPurpose::InvoicePayment")
+                               _ => panic!("expected PaymentPurpose::Bolt11InvoicePayment")
                        }
                },
                _ => panic!("Unexpected event"),
@@ -2055,11 +2055,11 @@ fn test_channel_reserve_holding_cell_htlcs() {
                        assert_eq!(nodes[2].node.get_our_node_id(), receiver_node_id.unwrap());
                        assert_eq!(via_channel_id, Some(chan_2.2));
                        match &purpose {
-                               PaymentPurpose::InvoicePayment { payment_preimage, payment_secret, .. } => {
+                               PaymentPurpose::Bolt11InvoicePayment { payment_preimage, payment_secret, .. } => {
                                        assert!(payment_preimage.is_none());
                                        assert_eq!(our_payment_secret_22, *payment_secret);
                                },
-                               _ => panic!("expected PaymentPurpose::InvoicePayment")
+                               _ => panic!("expected PaymentPurpose::Bolt11InvoicePayment")
                        }
                },
                _ => panic!("Unexpected event"),
@@ -3954,11 +3954,11 @@ fn do_test_drop_messages_peer_disconnect(messages_delivered: u8, simulate_broken
                        assert_eq!(receiver_node_id.unwrap(), nodes[1].node.get_our_node_id());
                        assert_eq!(via_channel_id, Some(channel_id));
                        match &purpose {
-                               PaymentPurpose::InvoicePayment { payment_preimage, payment_secret, .. } => {
+                               PaymentPurpose::Bolt11InvoicePayment { payment_preimage, payment_secret, .. } => {
                                        assert!(payment_preimage.is_none());
                                        assert_eq!(payment_secret_1, *payment_secret);
                                },
-                               _ => panic!("expected PaymentPurpose::InvoicePayment")
+                               _ => panic!("expected PaymentPurpose::Bolt11InvoicePayment")
                        }
                },
                _ => panic!("Unexpected event"),
@@ -4319,11 +4319,11 @@ fn test_drop_messages_peer_disconnect_dual_htlc() {
                Event::PaymentClaimable { ref payment_hash, ref purpose, .. } => {
                        assert_eq!(payment_hash_2, *payment_hash);
                        match &purpose {
-                               PaymentPurpose::InvoicePayment { payment_preimage, payment_secret, .. } => {
+                               PaymentPurpose::Bolt11InvoicePayment { payment_preimage, payment_secret, .. } => {
                                        assert!(payment_preimage.is_none());
                                        assert_eq!(payment_secret_2, *payment_secret);
                                },
-                               _ => panic!("expected PaymentPurpose::InvoicePayment")
+                               _ => panic!("expected PaymentPurpose::Bolt11InvoicePayment")
                        }
                },
                _ => panic!("Unexpected event"),
@@ -8388,10 +8388,10 @@ fn test_preimage_storage() {
        match events[0] {
                Event::PaymentClaimable { ref purpose, .. } => {
                        match &purpose {
-                               PaymentPurpose::InvoicePayment { payment_preimage, .. } => {
+                               PaymentPurpose::Bolt11InvoicePayment { payment_preimage, .. } => {
                                        claim_payment(&nodes[0], &[&nodes[1]], payment_preimage.unwrap());
                                },
-                               _ => panic!("expected PaymentPurpose::InvoicePayment")
+                               _ => panic!("expected PaymentPurpose::Bolt11InvoicePayment")
                        }
                },
                _ => panic!("Unexpected event"),
@@ -9950,7 +9950,10 @@ fn do_test_max_dust_htlc_exposure(dust_outbound_balance: bool, exposure_breach_e
        let dust_outbound_htlc_on_holder_tx_msat: u64 = (dust_buffer_feerate * htlc_timeout_tx_weight(&channel_type_features) / 1000 + open_channel.common_fields.dust_limit_satoshis - 1) * 1000;
        let dust_outbound_htlc_on_holder_tx: u64 = max_dust_htlc_exposure_msat / dust_outbound_htlc_on_holder_tx_msat;
 
-       let dust_inbound_htlc_on_holder_tx_msat: u64 = (dust_buffer_feerate * htlc_success_tx_weight(&channel_type_features) / 1000 + open_channel.common_fields.dust_limit_satoshis - 1) * 1000;
+       // Substract 3 sats for multiplier and 2 sats for fixed limit to make sure we are 50% below the dust limit.
+       // This is to make sure we fully use the dust limit. If we don't, we could end up with `dust_ibd_htlc_on_holder_tx` being 1
+       // while `max_dust_htlc_exposure_msat` is not equal to `dust_outbound_htlc_on_holder_tx_msat`.
+       let dust_inbound_htlc_on_holder_tx_msat: u64 = (dust_buffer_feerate * htlc_success_tx_weight(&channel_type_features) / 1000 + open_channel.common_fields.dust_limit_satoshis - if multiplier_dust_limit { 3 } else { 2 }) * 1000;
        let dust_inbound_htlc_on_holder_tx: u64 = max_dust_htlc_exposure_msat / dust_inbound_htlc_on_holder_tx_msat;
 
        let dust_htlc_on_counterparty_tx: u64 = 4;
index 1c6c4c0e9940aa153e6227c14ba85dd8365e1a4b..30615b86d2f6560057276370770446ce2c99881b 100644 (file)
@@ -9,7 +9,7 @@
 
 //! Further functional tests which test blockchain reorganizations.
 
-use crate::sign::{ecdsa::EcdsaChannelSigner, SpendableOutputDescriptor};
+use crate::sign::{ecdsa::EcdsaChannelSigner, OutputSpender, SpendableOutputDescriptor};
 use crate::chain::channelmonitor::{ANTI_REORG_DELAY, LATENCY_GRACE_PERIOD_BLOCKS, Balance};
 use crate::chain::transaction::OutPoint;
 use crate::chain::chaininterface::{LowerBoundedFeeEstimator, compute_feerate_sat_per_1000_weight};
index 8040d8c420984f6dae30fbd986d21afc036e7acb..136ed4d317b35bfdc43618a0109e7c3ed1a729a1 100644 (file)
@@ -543,6 +543,8 @@ pub struct TxSignatures {
        pub tx_hash: Txid,
        /// The list of witnesses
        pub witnesses: Vec<Witness>,
+       /// Optional signature for the shared input -- the previous funding outpoint -- signed by both peers
+       pub funding_outpoint_sig: Option<Signature>,
 }
 
 /// A tx_init_rbf message which initiates a replacement of the transaction after it's been
@@ -1460,10 +1462,13 @@ pub trait ChannelMessageHandler : MessageSendEventsProvider {
 
        // Splicing
        /// Handle an incoming `splice` message from the given peer.
+       #[cfg(splicing)]
        fn handle_splice(&self, their_node_id: &PublicKey, msg: &Splice);
        /// Handle an incoming `splice_ack` message from the given peer.
+       #[cfg(splicing)]
        fn handle_splice_ack(&self, their_node_id: &PublicKey, msg: &SpliceAck);
        /// Handle an incoming `splice_locked` message from the given peer.
+       #[cfg(splicing)]
        fn handle_splice_locked(&self, their_node_id: &PublicKey, msg: &SpliceLocked);
 
        // Interactive channel construction
@@ -1672,7 +1677,7 @@ pub struct FinalOnionHopData {
 
 mod fuzzy_internal_msgs {
        use bitcoin::secp256k1::PublicKey;
-       use crate::blinded_path::payment::{PaymentConstraints, PaymentRelay};
+       use crate::blinded_path::payment::{PaymentConstraints, PaymentContext, PaymentRelay};
        use crate::ln::{PaymentPreimage, PaymentSecret};
        use crate::ln::features::BlindedHopFeatures;
        use super::{FinalOnionHopData, TrampolineOnionPacket};
@@ -1711,6 +1716,7 @@ mod fuzzy_internal_msgs {
                        cltv_expiry_height: u32,
                        payment_secret: PaymentSecret,
                        payment_constraints: PaymentConstraints,
+                       payment_context: PaymentContext,
                        intro_node_blinding_point: Option<PublicKey>,
                        keysend_preimage: Option<PaymentPreimage>,
                        custom_tlvs: Vec<(u64, Vec<u8>)>,
@@ -2115,7 +2121,9 @@ impl_writeable_msg!(TxSignatures, {
        channel_id,
        tx_hash,
        witnesses,
-}, {});
+}, {
+       (0, funding_outpoint_sig, option),
+});
 
 impl_writeable_msg!(TxInitRbf, {
        channel_id,
@@ -2713,7 +2721,7 @@ impl<NS: Deref> ReadableArgs<(Option<PublicKey>, &NS)> for InboundOnionPayload w
                                        })
                                },
                                ChaChaPolyReadAdapter { readable: BlindedPaymentTlvs::Receive(ReceiveTlvs {
-                                       payment_secret, payment_constraints
+                                       payment_secret, payment_constraints, payment_context
                                })} => {
                                        if total_msat.unwrap_or(0) > MAX_VALUE_MSAT { return Err(DecodeError::InvalidValue) }
                                        Ok(Self::BlindedReceive {
@@ -2722,6 +2730,7 @@ impl<NS: Deref> ReadableArgs<(Option<PublicKey>, &NS)> for InboundOnionPayload w
                                                cltv_expiry_height: cltv_value.ok_or(DecodeError::InvalidValue)?,
                                                payment_secret,
                                                payment_constraints,
+                                               payment_context,
                                                intro_node_blinding_point,
                                                keysend_preimage,
                                                custom_tlvs,
@@ -3952,6 +3961,10 @@ mod tests {
 
        #[test]
        fn encoding_tx_signatures() {
+               let secp_ctx = Secp256k1::new();
+               let (privkey_1, _) = get_keys_from!("0101010101010101010101010101010101010101010101010101010101010101", secp_ctx);
+               let sig_1 = get_sig_on!(privkey_1, secp_ctx, String::from("01010101010101010101010101010101"));
+
                let tx_signatures = msgs::TxSignatures {
                        channel_id: ChannelId::from_bytes([2; 32]),
                        tx_hash: Txid::from_str("c2d4449afa8d26140898dd54d3390b057ba2a5afcf03ba29d7dc0d8b9ffe966e").unwrap(),
@@ -3963,6 +3976,7 @@ mod tests {
                                        <Vec<u8>>::from_hex("3045022100ee00dbf4a862463e837d7c08509de814d620e4d9830fa84818713e0fa358f145022021c3c7060c4d53fe84fd165d60208451108a778c13b92ca4c6bad439236126cc01").unwrap(),
                                        <Vec<u8>>::from_hex("028fbbf0b16f5ba5bcb5dd37cd4047ce6f726a21c06682f9ec2f52b057de1dbdb5").unwrap()]),
                        ],
+                       funding_outpoint_sig: Some(sig_1),
                };
                let encoded_value = tx_signatures.encode();
                let mut target_value = <Vec<u8>>::from_hex("0202020202020202020202020202020202020202020202020202020202020202").unwrap(); // channel_id
@@ -3982,6 +3996,8 @@ mod tests {
                target_value.append(&mut <Vec<u8>>::from_hex("3045022100ee00dbf4a862463e837d7c08509de814d620e4d9830fa84818713e0fa358f145022021c3c7060c4d53fe84fd165d60208451108a778c13b92ca4c6bad439236126cc01").unwrap());
                target_value.append(&mut <Vec<u8>>::from_hex("21").unwrap()); // len of witness element data (VarInt)
                target_value.append(&mut <Vec<u8>>::from_hex("028fbbf0b16f5ba5bcb5dd37cd4047ce6f726a21c06682f9ec2f52b057de1dbdb5").unwrap());
+               target_value.append(&mut <Vec<u8>>::from_hex("0040").unwrap()); // type and len (64)
+               target_value.append(&mut <Vec<u8>>::from_hex("d977cb9b53d93a6ff64bb5f1e158b4094b66e798fb12911168a3ccdf80a83096340a6a95da0ae8d9f776528eecdbb747eb6b545495a4319ed5378e35b21e073a").unwrap());
                assert_eq!(encoded_value, target_value);
        }
 
index e75bd2c70e1e9c3d6b68945fa0ea424a694d4f7d..75a2e290f39e824514fdaa3b6333eb20b319663e 100644 (file)
 
 use bitcoin::network::constants::Network;
 use core::time::Duration;
-use crate::blinded_path::BlindedPath;
+use crate::blinded_path::{BlindedPath, IntroductionNode};
+use crate::blinded_path::payment::{Bolt12OfferContext, Bolt12RefundContext, PaymentContext};
 use crate::events::{Event, MessageSendEventsProvider, PaymentPurpose};
 use crate::ln::channelmanager::{PaymentId, RecentPaymentDetails, Retry, self};
+use crate::ln::features::InvoiceRequestFeatures;
 use crate::ln::functional_test_utils::*;
 use crate::ln::msgs::{ChannelMessageHandler, Init, NodeAnnouncement, OnionMessage, OnionMessageHandler, RoutingMessageHandler, SocketAddress, UnsignedGossipMessage, UnsignedNodeAnnouncement};
 use crate::offers::invoice::Bolt12Invoice;
 use crate::offers::invoice_error::InvoiceError;
-use crate::offers::invoice_request::InvoiceRequest;
+use crate::offers::invoice_request::{InvoiceRequest, InvoiceRequestFields};
 use crate::offers::parse::Bolt12SemanticError;
 use crate::onion_message::messenger::PeeledOnion;
 use crate::onion_message::offers::OffersMessage;
@@ -151,16 +153,28 @@ fn route_bolt12_payment<'a, 'b, 'c>(
        do_pass_along_path(args);
 }
 
-fn claim_bolt12_payment<'a, 'b, 'c>(node: &Node<'a, 'b, 'c>, path: &[&Node<'a, 'b, 'c>]) {
+fn claim_bolt12_payment<'a, 'b, 'c>(
+       node: &Node<'a, 'b, 'c>, path: &[&Node<'a, 'b, 'c>], expected_payment_context: PaymentContext
+) {
        let recipient = &path[path.len() - 1];
-       match get_event!(recipient, Event::PaymentClaimable) {
-               Event::PaymentClaimable {
-                       purpose: PaymentPurpose::InvoicePayment {
-                               payment_preimage: Some(payment_preimage), ..
-                       }, ..
-               } => claim_payment(node, path, payment_preimage),
-               _ => panic!(),
+       let payment_purpose = match get_event!(recipient, Event::PaymentClaimable) {
+               Event::PaymentClaimable { purpose, .. } => purpose,
+               _ => panic!("No Event::PaymentClaimable"),
+       };
+       let payment_preimage = match payment_purpose.preimage() {
+               Some(preimage) => preimage,
+               None => panic!("No preimage in Event::PaymentClaimable"),
        };
+       match payment_purpose {
+               PaymentPurpose::Bolt12OfferPayment { payment_context, .. } => {
+                       assert_eq!(PaymentContext::Bolt12Offer(payment_context), expected_payment_context);
+               },
+               PaymentPurpose::Bolt12RefundPayment { payment_context, .. } => {
+                       assert_eq!(PaymentContext::Bolt12Refund(payment_context), expected_payment_context);
+               },
+               _ => panic!("Unexpected payment purpose: {:?}", payment_purpose),
+       }
+       claim_payment(node, path, payment_preimage);
 }
 
 fn extract_invoice_request<'a, 'b, 'c>(
@@ -260,8 +274,8 @@ fn prefers_non_tor_nodes_in_blinded_paths() {
        assert_ne!(offer.signing_pubkey(), bob_id);
        assert!(!offer.paths().is_empty());
        for path in offer.paths() {
-               assert_ne!(path.introduction_node_id, bob_id);
-               assert_ne!(path.introduction_node_id, charlie_id);
+               assert_ne!(path.introduction_node, IntroductionNode::NodeId(bob_id));
+               assert_ne!(path.introduction_node, IntroductionNode::NodeId(charlie_id));
        }
 
        // Use a one-hop blinded path when Bob is announced and all his peers are Tor-only.
@@ -275,7 +289,7 @@ fn prefers_non_tor_nodes_in_blinded_paths() {
        assert_ne!(offer.signing_pubkey(), bob_id);
        assert!(!offer.paths().is_empty());
        for path in offer.paths() {
-               assert_eq!(path.introduction_node_id, bob_id);
+               assert_eq!(path.introduction_node, IntroductionNode::NodeId(bob_id));
        }
 }
 
@@ -325,7 +339,7 @@ fn prefers_more_connected_nodes_in_blinded_paths() {
        assert_ne!(offer.signing_pubkey(), bob_id);
        assert!(!offer.paths().is_empty());
        for path in offer.paths() {
-               assert_eq!(path.introduction_node_id, nodes[4].node.get_our_node_id());
+               assert_eq!(path.introduction_node, IntroductionNode::NodeId(nodes[4].node.get_our_node_id()));
        }
 }
 
@@ -368,13 +382,14 @@ fn creates_and_pays_for_offer_using_two_hop_blinded_path() {
        disconnect_peers(david, &[bob, &nodes[4], &nodes[5]]);
 
        let offer = alice.node
-               .create_offer_builder("coffee".to_string()).unwrap()
+               .create_offer_builder("coffee".to_string())
+               .unwrap()
                .amount_msats(10_000_000)
                .build().unwrap();
        assert_ne!(offer.signing_pubkey(), alice_id);
        assert!(!offer.paths().is_empty());
        for path in offer.paths() {
-               assert_eq!(path.introduction_node_id, bob_id);
+               assert_eq!(path.introduction_node, IntroductionNode::NodeId(bob_id));
        }
 
        let payment_id = PaymentId([1; 32]);
@@ -393,9 +408,19 @@ fn creates_and_pays_for_offer_using_two_hop_blinded_path() {
        alice.onion_messenger.handle_onion_message(&bob_id, &onion_message);
 
        let (invoice_request, reply_path) = extract_invoice_request(alice, &onion_message);
+       let payment_context = PaymentContext::Bolt12Offer(Bolt12OfferContext {
+               offer_id: offer.id(),
+               invoice_request: InvoiceRequestFields {
+                       payer_id: invoice_request.payer_id(),
+                       amount_msats: None,
+                       features: InvoiceRequestFeatures::empty(),
+                       quantity: None,
+                       payer_note_truncated: None,
+               },
+       });
        assert_eq!(invoice_request.amount_msats(), None);
        assert_ne!(invoice_request.payer_id(), david_id);
-       assert_eq!(reply_path.unwrap().introduction_node_id, charlie_id);
+       assert_eq!(reply_path.unwrap().introduction_node, IntroductionNode::NodeId(charlie_id));
 
        let onion_message = alice.onion_messenger.next_onion_message_for_peer(charlie_id).unwrap();
        charlie.onion_messenger.handle_onion_message(&alice_id, &onion_message);
@@ -408,13 +433,13 @@ fn creates_and_pays_for_offer_using_two_hop_blinded_path() {
        assert_ne!(invoice.signing_pubkey(), alice_id);
        assert!(!invoice.payment_paths().is_empty());
        for (_, path) in invoice.payment_paths() {
-               assert_eq!(path.introduction_node_id, bob_id);
+               assert_eq!(path.introduction_node, IntroductionNode::NodeId(bob_id));
        }
 
        route_bolt12_payment(david, &[charlie, bob, alice], &invoice);
        expect_recent_payment!(david, RecentPaymentDetails::Pending, payment_id);
 
-       claim_bolt12_payment(david, &[charlie, bob, alice]);
+       claim_bolt12_payment(david, &[charlie, bob, alice], payment_context);
        expect_recent_payment!(david, RecentPaymentDetails::Fulfilled, payment_id);
 }
 
@@ -469,11 +494,12 @@ fn creates_and_pays_for_refund_using_two_hop_blinded_path() {
        assert_ne!(refund.payer_id(), david_id);
        assert!(!refund.paths().is_empty());
        for path in refund.paths() {
-               assert_eq!(path.introduction_node_id, charlie_id);
+               assert_eq!(path.introduction_node, IntroductionNode::NodeId(charlie_id));
        }
        expect_recent_payment!(david, RecentPaymentDetails::AwaitingInvoice, payment_id);
 
-       alice.node.request_refund_payment(&refund).unwrap();
+       let payment_context = PaymentContext::Bolt12Refund(Bolt12RefundContext {});
+       let expected_invoice = alice.node.request_refund_payment(&refund).unwrap();
 
        connect_peers(alice, charlie);
 
@@ -484,17 +510,19 @@ fn creates_and_pays_for_refund_using_two_hop_blinded_path() {
        david.onion_messenger.handle_onion_message(&charlie_id, &onion_message);
 
        let invoice = extract_invoice(david, &onion_message);
+       assert_eq!(invoice, expected_invoice);
+
        assert_eq!(invoice.amount_msats(), 10_000_000);
        assert_ne!(invoice.signing_pubkey(), alice_id);
        assert!(!invoice.payment_paths().is_empty());
        for (_, path) in invoice.payment_paths() {
-               assert_eq!(path.introduction_node_id, bob_id);
+               assert_eq!(path.introduction_node, IntroductionNode::NodeId(bob_id));
        }
 
        route_bolt12_payment(david, &[charlie, bob, alice], &invoice);
        expect_recent_payment!(david, RecentPaymentDetails::Pending, payment_id);
 
-       claim_bolt12_payment(david, &[charlie, bob, alice]);
+       claim_bolt12_payment(david, &[charlie, bob, alice], payment_context);
        expect_recent_payment!(david, RecentPaymentDetails::Fulfilled, payment_id);
 }
 
@@ -522,7 +550,7 @@ fn creates_and_pays_for_offer_using_one_hop_blinded_path() {
        assert_ne!(offer.signing_pubkey(), alice_id);
        assert!(!offer.paths().is_empty());
        for path in offer.paths() {
-               assert_eq!(path.introduction_node_id, alice_id);
+               assert_eq!(path.introduction_node, IntroductionNode::NodeId(alice_id));
        }
 
        let payment_id = PaymentId([1; 32]);
@@ -533,9 +561,19 @@ fn creates_and_pays_for_offer_using_one_hop_blinded_path() {
        alice.onion_messenger.handle_onion_message(&bob_id, &onion_message);
 
        let (invoice_request, reply_path) = extract_invoice_request(alice, &onion_message);
+       let payment_context = PaymentContext::Bolt12Offer(Bolt12OfferContext {
+               offer_id: offer.id(),
+               invoice_request: InvoiceRequestFields {
+                       payer_id: invoice_request.payer_id(),
+                       amount_msats: None,
+                       features: InvoiceRequestFeatures::empty(),
+                       quantity: None,
+                       payer_note_truncated: None,
+               },
+       });
        assert_eq!(invoice_request.amount_msats(), None);
        assert_ne!(invoice_request.payer_id(), bob_id);
-       assert_eq!(reply_path.unwrap().introduction_node_id, bob_id);
+       assert_eq!(reply_path.unwrap().introduction_node, IntroductionNode::NodeId(bob_id));
 
        let onion_message = alice.onion_messenger.next_onion_message_for_peer(bob_id).unwrap();
        bob.onion_messenger.handle_onion_message(&alice_id, &onion_message);
@@ -545,13 +583,13 @@ fn creates_and_pays_for_offer_using_one_hop_blinded_path() {
        assert_ne!(invoice.signing_pubkey(), alice_id);
        assert!(!invoice.payment_paths().is_empty());
        for (_, path) in invoice.payment_paths() {
-               assert_eq!(path.introduction_node_id, alice_id);
+               assert_eq!(path.introduction_node, IntroductionNode::NodeId(alice_id));
        }
 
        route_bolt12_payment(bob, &[alice], &invoice);
        expect_recent_payment!(bob, RecentPaymentDetails::Pending, payment_id);
 
-       claim_bolt12_payment(bob, &[alice]);
+       claim_bolt12_payment(bob, &[alice], payment_context);
        expect_recent_payment!(bob, RecentPaymentDetails::Fulfilled, payment_id);
 }
 
@@ -585,27 +623,30 @@ fn creates_and_pays_for_refund_using_one_hop_blinded_path() {
        assert_ne!(refund.payer_id(), bob_id);
        assert!(!refund.paths().is_empty());
        for path in refund.paths() {
-               assert_eq!(path.introduction_node_id, bob_id);
+               assert_eq!(path.introduction_node, IntroductionNode::NodeId(bob_id));
        }
        expect_recent_payment!(bob, RecentPaymentDetails::AwaitingInvoice, payment_id);
 
-       alice.node.request_refund_payment(&refund).unwrap();
+       let payment_context = PaymentContext::Bolt12Refund(Bolt12RefundContext {});
+       let expected_invoice = alice.node.request_refund_payment(&refund).unwrap();
 
        let onion_message = alice.onion_messenger.next_onion_message_for_peer(bob_id).unwrap();
        bob.onion_messenger.handle_onion_message(&alice_id, &onion_message);
 
        let invoice = extract_invoice(bob, &onion_message);
+       assert_eq!(invoice, expected_invoice);
+
        assert_eq!(invoice.amount_msats(), 10_000_000);
        assert_ne!(invoice.signing_pubkey(), alice_id);
        assert!(!invoice.payment_paths().is_empty());
        for (_, path) in invoice.payment_paths() {
-               assert_eq!(path.introduction_node_id, alice_id);
+               assert_eq!(path.introduction_node, IntroductionNode::NodeId(alice_id));
        }
 
        route_bolt12_payment(bob, &[alice], &invoice);
        expect_recent_payment!(bob, RecentPaymentDetails::Pending, payment_id);
 
-       claim_bolt12_payment(bob, &[alice]);
+       claim_bolt12_payment(bob, &[alice], payment_context);
        expect_recent_payment!(bob, RecentPaymentDetails::Fulfilled, payment_id);
 }
 
@@ -641,6 +682,18 @@ fn pays_for_offer_without_blinded_paths() {
        let onion_message = bob.onion_messenger.next_onion_message_for_peer(alice_id).unwrap();
        alice.onion_messenger.handle_onion_message(&bob_id, &onion_message);
 
+       let (invoice_request, _) = extract_invoice_request(alice, &onion_message);
+       let payment_context = PaymentContext::Bolt12Offer(Bolt12OfferContext {
+               offer_id: offer.id(),
+               invoice_request: InvoiceRequestFields {
+                       payer_id: invoice_request.payer_id(),
+                       amount_msats: None,
+                       features: InvoiceRequestFeatures::empty(),
+                       quantity: None,
+                       payer_note_truncated: None,
+               },
+       });
+
        let onion_message = alice.onion_messenger.next_onion_message_for_peer(bob_id).unwrap();
        bob.onion_messenger.handle_onion_message(&alice_id, &onion_message);
 
@@ -648,7 +701,7 @@ fn pays_for_offer_without_blinded_paths() {
        route_bolt12_payment(bob, &[alice], &invoice);
        expect_recent_payment!(bob, RecentPaymentDetails::Pending, payment_id);
 
-       claim_bolt12_payment(bob, &[alice]);
+       claim_bolt12_payment(bob, &[alice], payment_context);
        expect_recent_payment!(bob, RecentPaymentDetails::Fulfilled, payment_id);
 }
 
@@ -681,16 +734,19 @@ fn pays_for_refund_without_blinded_paths() {
        assert!(refund.paths().is_empty());
        expect_recent_payment!(bob, RecentPaymentDetails::AwaitingInvoice, payment_id);
 
-       alice.node.request_refund_payment(&refund).unwrap();
+       let payment_context = PaymentContext::Bolt12Refund(Bolt12RefundContext {});
+       let expected_invoice = alice.node.request_refund_payment(&refund).unwrap();
 
        let onion_message = alice.onion_messenger.next_onion_message_for_peer(bob_id).unwrap();
        bob.onion_messenger.handle_onion_message(&alice_id, &onion_message);
 
        let invoice = extract_invoice(bob, &onion_message);
+       assert_eq!(invoice, expected_invoice);
+
        route_bolt12_payment(bob, &[alice], &invoice);
        expect_recent_payment!(bob, RecentPaymentDetails::Pending, payment_id);
 
-       claim_bolt12_payment(bob, &[alice]);
+       claim_bolt12_payment(bob, &[alice], payment_context);
        expect_recent_payment!(bob, RecentPaymentDetails::Fulfilled, payment_id);
 }
 
@@ -1063,12 +1119,13 @@ fn fails_paying_invoice_more_than_once() {
        david.onion_messenger.handle_onion_message(&charlie_id, &onion_message);
 
        // David pays the first invoice
+       let payment_context = PaymentContext::Bolt12Refund(Bolt12RefundContext {});
        let invoice1 = extract_invoice(david, &onion_message);
 
        route_bolt12_payment(david, &[charlie, bob, alice], &invoice1);
        expect_recent_payment!(david, RecentPaymentDetails::Pending, payment_id);
 
-       claim_bolt12_payment(david, &[charlie, bob, alice]);
+       claim_bolt12_payment(david, &[charlie, bob, alice], payment_context);
        expect_recent_payment!(david, RecentPaymentDetails::Fulfilled, payment_id);
 
        disconnect_peers(alice, &[charlie]);
index aa8ee0ce9be6bfe3a46301c4cb5d9c57610f49c5..db8c4cd033708fd36413aaf4b8c30ec05e4075e8 100644 (file)
@@ -131,17 +131,18 @@ pub(super) fn create_recv_pending_htlc_info(
 ) -> Result<PendingHTLCInfo, InboundHTLCErr> {
        let (
                payment_data, keysend_preimage, custom_tlvs, onion_amt_msat, onion_cltv_expiry,
-               payment_metadata, requires_blinded_error
+               payment_metadata, payment_context, requires_blinded_error
        ) = match hop_data {
                msgs::InboundOnionPayload::Receive {
                        payment_data, keysend_preimage, custom_tlvs, sender_intended_htlc_amt_msat,
                        cltv_expiry_height, payment_metadata, ..
                } =>
                        (payment_data, keysend_preimage, custom_tlvs, sender_intended_htlc_amt_msat,
-                        cltv_expiry_height, payment_metadata, false),
+                        cltv_expiry_height, payment_metadata, None, false),
                msgs::InboundOnionPayload::BlindedReceive {
                        sender_intended_htlc_amt_msat, total_msat, cltv_expiry_height, payment_secret,
-                       intro_node_blinding_point, payment_constraints, keysend_preimage, custom_tlvs
+                       intro_node_blinding_point, payment_constraints, payment_context, keysend_preimage,
+                       custom_tlvs
                } => {
                        check_blinded_payment_constraints(
                                sender_intended_htlc_amt_msat, cltv_expiry, &payment_constraints
@@ -155,7 +156,7 @@ pub(super) fn create_recv_pending_htlc_info(
                                })?;
                        let payment_data = msgs::FinalOnionHopData { payment_secret, total_msat };
                        (Some(payment_data), keysend_preimage, custom_tlvs,
-                        sender_intended_htlc_amt_msat, cltv_expiry_height, None,
+                        sender_intended_htlc_amt_msat, cltv_expiry_height, None, Some(payment_context),
                         intro_node_blinding_point.is_none())
                }
                msgs::InboundOnionPayload::Forward { .. } => {
@@ -241,6 +242,7 @@ pub(super) fn create_recv_pending_htlc_info(
                PendingHTLCRouting::Receive {
                        payment_data: data,
                        payment_metadata,
+                       payment_context,
                        incoming_cltv_expiry: onion_cltv_expiry,
                        phantom_shared_secret,
                        custom_tlvs,
index a75120797cafacb1dd86f07baf7afd2842844af2..d1fa58372dd93dbe7e86c2af853d6b55f150d508 100644 (file)
@@ -2142,9 +2142,11 @@ fn do_accept_underpaying_htlcs_config(num_mpp_parts: usize) {
                        assert_eq!(skimmed_fee_msat * num_mpp_parts as u64, counterparty_skimmed_fee_msat);
                        assert_eq!(nodes[2].node.get_our_node_id(), receiver_node_id.unwrap());
                        match purpose {
-                               crate::events::PaymentPurpose::InvoicePayment { payment_preimage: ev_payment_preimage,
-                                       payment_secret: ev_payment_secret, .. } =>
-                               {
+                               crate::events::PaymentPurpose::Bolt11InvoicePayment {
+                                       payment_preimage: ev_payment_preimage,
+                                       payment_secret: ev_payment_secret,
+                                       ..
+                               } => {
                                        assert_eq!(payment_preimage, ev_payment_preimage.unwrap());
                                        assert_eq!(payment_secret, *ev_payment_secret);
                                },
index 9c27a23467ce7a8e7134225f352fc961a3998ace..9ce230861456927f53bf4a2b65dd33e800e17e5d 100644 (file)
@@ -248,12 +248,15 @@ impl ChannelMessageHandler for ErroringMessageHandler {
        fn handle_stfu(&self, their_node_id: &PublicKey, msg: &msgs::Stfu) {
                ErroringMessageHandler::push_error(&self, their_node_id, msg.channel_id);
        }
+       #[cfg(splicing)]
        fn handle_splice(&self, their_node_id: &PublicKey, msg: &msgs::Splice) {
                ErroringMessageHandler::push_error(&self, their_node_id, msg.channel_id);
        }
+       #[cfg(splicing)]
        fn handle_splice_ack(&self, their_node_id: &PublicKey, msg: &msgs::SpliceAck) {
                ErroringMessageHandler::push_error(&self, their_node_id, msg.channel_id);
        }
+       #[cfg(splicing)]
        fn handle_splice_locked(&self, their_node_id: &PublicKey, msg: &msgs::SpliceLocked) {
                ErroringMessageHandler::push_error(&self, their_node_id, msg.channel_id);
        }
@@ -1475,7 +1478,6 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                                                let networks = self.message_handler.chan_handler.get_chain_hashes();
                                                                let resp = msgs::Init { features, networks, remote_network_address: filter_addresses(peer.their_socket_address.clone()) };
                                                                self.enqueue_message(peer, &resp);
-                                                               peer.awaiting_pong_timer_tick_intervals = 0;
                                                        },
                                                        NextNoiseStep::ActThree => {
                                                                let their_node_id = try_potential_handleerror!(peer,
@@ -1488,7 +1490,6 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                                                let networks = self.message_handler.chan_handler.get_chain_hashes();
                                                                let resp = msgs::Init { features, networks, remote_network_address: filter_addresses(peer.their_socket_address.clone()) };
                                                                self.enqueue_message(peer, &resp);
-                                                               peer.awaiting_pong_timer_tick_intervals = 0;
                                                        },
                                                        NextNoiseStep::NoiseComplete => {
                                                                if peer.pending_read_is_header {
@@ -1681,6 +1682,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                return Err(PeerHandleError { }.into());
                        }
 
+                       peer_lock.awaiting_pong_timer_tick_intervals = 0;
                        peer_lock.their_features = Some(msg.features);
                        return Ok(None);
                } else if peer_lock.their_features.is_none() {
@@ -1785,13 +1787,16 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                self.message_handler.chan_handler.handle_stfu(&their_node_id, &msg);
                        }
 
+                       #[cfg(splicing)]
                        // Splicing messages:
                        wire::Message::Splice(msg) => {
                                self.message_handler.chan_handler.handle_splice(&their_node_id, &msg);
                        }
+                       #[cfg(splicing)]
                        wire::Message::SpliceAck(msg) => {
                                self.message_handler.chan_handler.handle_splice_ack(&their_node_id, &msg);
                        }
+                       #[cfg(splicing)]
                        wire::Message::SpliceLocked(msg) => {
                                self.message_handler.chan_handler.handle_splice_locked(&their_node_id, &msg);
                        }
@@ -2674,7 +2679,7 @@ mod tests {
        use crate::ln::ChannelId;
        use crate::ln::features::{InitFeatures, NodeFeatures};
        use crate::ln::peer_channel_encryptor::PeerChannelEncryptor;
-       use crate::ln::peer_handler::{CustomMessageHandler, PeerManager, MessageHandler, SocketDescriptor, IgnoringMessageHandler, filter_addresses};
+       use crate::ln::peer_handler::{CustomMessageHandler, PeerManager, MessageHandler, SocketDescriptor, IgnoringMessageHandler, filter_addresses, ErroringMessageHandler, MAX_BUFFER_DRAIN_TICK_INTERVALS_PER_PEER};
        use crate::ln::{msgs, wire};
        use crate::ln::msgs::{LightningError, SocketAddress};
        use crate::util::test_utils;
@@ -3216,6 +3221,105 @@ mod tests {
                assert!(peers[0].read_event(&mut fd_a, &b_data).is_err());
        }
 
+       #[test]
+       fn test_inbound_conn_handshake_complete_awaiting_pong() {
+               // Test that we do not disconnect an outbound peer after the noise handshake completes due
+               // to a pong timeout for a ping that was never sent if a timer tick fires after we send act
+               // two of the noise handshake along with our init message but before we receive their init
+               // message.
+               let logger = test_utils::TestLogger::new();
+               let node_signer_a = test_utils::TestNodeSigner::new(SecretKey::from_slice(&[42; 32]).unwrap());
+               let node_signer_b = test_utils::TestNodeSigner::new(SecretKey::from_slice(&[43; 32]).unwrap());
+               let peer_a = PeerManager::new(MessageHandler {
+                       chan_handler: ErroringMessageHandler::new(),
+                       route_handler: IgnoringMessageHandler {},
+                       onion_message_handler: IgnoringMessageHandler {},
+                       custom_message_handler: IgnoringMessageHandler {},
+               }, 0, &[0; 32], &logger, &node_signer_a);
+               let peer_b = PeerManager::new(MessageHandler {
+                       chan_handler: ErroringMessageHandler::new(),
+                       route_handler: IgnoringMessageHandler {},
+                       onion_message_handler: IgnoringMessageHandler {},
+                       custom_message_handler: IgnoringMessageHandler {},
+               }, 0, &[1; 32], &logger, &node_signer_b);
+
+               let a_id = node_signer_a.get_node_id(Recipient::Node).unwrap();
+               let mut fd_a = FileDescriptor {
+                       fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())),
+                       disconnect: Arc::new(AtomicBool::new(false)),
+               };
+               let mut fd_b = FileDescriptor {
+                       fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())),
+                       disconnect: Arc::new(AtomicBool::new(false)),
+               };
+
+               // Exchange messages with both peers until they both complete the init handshake.
+               let act_one = peer_b.new_outbound_connection(a_id, fd_b.clone(), None).unwrap();
+               peer_a.new_inbound_connection(fd_a.clone(), None).unwrap();
+
+               assert_eq!(peer_a.read_event(&mut fd_a, &act_one).unwrap(), false);
+               peer_a.process_events();
+
+               let act_two = fd_a.outbound_data.lock().unwrap().split_off(0);
+               assert_eq!(peer_b.read_event(&mut fd_b, &act_two).unwrap(), false);
+               peer_b.process_events();
+
+               // Calling this here triggers the race on inbound connections.
+               peer_b.timer_tick_occurred();
+
+               let act_three_with_init_b = fd_b.outbound_data.lock().unwrap().split_off(0);
+               assert!(!peer_a.peers.read().unwrap().get(&fd_a).unwrap().lock().unwrap().handshake_complete());
+               assert_eq!(peer_a.read_event(&mut fd_a, &act_three_with_init_b).unwrap(), false);
+               peer_a.process_events();
+               assert!(peer_a.peers.read().unwrap().get(&fd_a).unwrap().lock().unwrap().handshake_complete());
+
+               let init_a = fd_a.outbound_data.lock().unwrap().split_off(0);
+               assert!(!init_a.is_empty());
+
+               assert!(!peer_b.peers.read().unwrap().get(&fd_b).unwrap().lock().unwrap().handshake_complete());
+               assert_eq!(peer_b.read_event(&mut fd_b, &init_a).unwrap(), false);
+               peer_b.process_events();
+               assert!(peer_b.peers.read().unwrap().get(&fd_b).unwrap().lock().unwrap().handshake_complete());
+
+               // Make sure we're still connected.
+               assert_eq!(peer_b.peers.read().unwrap().len(), 1);
+
+               // B should send a ping on the first timer tick after `handshake_complete`.
+               assert!(fd_b.outbound_data.lock().unwrap().split_off(0).is_empty());
+               peer_b.timer_tick_occurred();
+               peer_b.process_events();
+               assert!(!fd_b.outbound_data.lock().unwrap().split_off(0).is_empty());
+
+               let mut send_warning = || {
+                       {
+                               let peers = peer_a.peers.read().unwrap();
+                               let mut peer_b = peers.get(&fd_a).unwrap().lock().unwrap();
+                               peer_a.enqueue_message(&mut peer_b, &msgs::WarningMessage {
+                                       channel_id: ChannelId([0; 32]),
+                                       data: "no disconnect plz".to_string(),
+                               });
+                       }
+                       peer_a.process_events();
+                       let msg = fd_a.outbound_data.lock().unwrap().split_off(0);
+                       assert!(!msg.is_empty());
+                       assert_eq!(peer_b.read_event(&mut fd_b, &msg).unwrap(), false);
+                       peer_b.process_events();
+               };
+
+               // Fire more ticks until we reach the pong timeout. We send any message except pong to
+               // pretend the connection is still alive.
+               send_warning();
+               for _ in 0..MAX_BUFFER_DRAIN_TICK_INTERVALS_PER_PEER {
+                       peer_b.timer_tick_occurred();
+                       send_warning();
+               }
+               assert_eq!(peer_b.peers.read().unwrap().len(), 1);
+
+               // One more tick should enforce the pong timeout.
+               peer_b.timer_tick_occurred();
+               assert_eq!(peer_b.peers.read().unwrap().len(), 0);
+       }
+
        #[test]
        fn test_filter_addresses(){
                // Tests the filter_addresses function.
index 62c82b01f59d7a36edd81cefb1cd45d407788d85..c15365629015cb3b54600efab1486532c3a888fc 100644 (file)
@@ -15,6 +15,7 @@ use crate::chain::transaction::OutPoint;
 use crate::chain::Confirm;
 use crate::events::{Event, MessageSendEventsProvider, ClosureReason, HTLCDestination, MessageSendEvent};
 use crate::ln::msgs::{ChannelMessageHandler, Init};
+use crate::sign::OutputSpender;
 use crate::util::test_utils;
 use crate::util::ser::Writeable;
 use crate::util::string::UntrustedString;
index dc4eecde6306748e972657423049eef17b879929..55e31399ae10074d2e72f1b04a4c1c71ed1808df 100644 (file)
@@ -60,8 +60,11 @@ pub(crate) enum Message<T> where T: core::fmt::Debug + Type + TestEq {
        FundingCreated(msgs::FundingCreated),
        FundingSigned(msgs::FundingSigned),
        Stfu(msgs::Stfu),
+       #[cfg(splicing)]
        Splice(msgs::Splice),
+       #[cfg(splicing)]
        SpliceAck(msgs::SpliceAck),
+       #[cfg(splicing)]
        SpliceLocked(msgs::SpliceLocked),
        TxAddInput(msgs::TxAddInput),
        TxAddOutput(msgs::TxAddOutput),
@@ -115,8 +118,11 @@ impl<T> Writeable for Message<T> where T: core::fmt::Debug + Type + TestEq {
                        &Message::FundingCreated(ref msg) => msg.write(writer),
                        &Message::FundingSigned(ref msg) => msg.write(writer),
                        &Message::Stfu(ref msg) => msg.write(writer),
+                       #[cfg(splicing)]
                        &Message::Splice(ref msg) => msg.write(writer),
+                       #[cfg(splicing)]
                        &Message::SpliceAck(ref msg) => msg.write(writer),
+                       #[cfg(splicing)]
                        &Message::SpliceLocked(ref msg) => msg.write(writer),
                        &Message::TxAddInput(ref msg) => msg.write(writer),
                        &Message::TxAddOutput(ref msg) => msg.write(writer),
@@ -170,8 +176,11 @@ impl<T> Type for Message<T> where T: core::fmt::Debug + Type + TestEq {
                        &Message::FundingCreated(ref msg) => msg.type_id(),
                        &Message::FundingSigned(ref msg) => msg.type_id(),
                        &Message::Stfu(ref msg) => msg.type_id(),
+                       #[cfg(splicing)]
                        &Message::Splice(ref msg) => msg.type_id(),
+                       #[cfg(splicing)]
                        &Message::SpliceAck(ref msg) => msg.type_id(),
+                       #[cfg(splicing)]
                        &Message::SpliceLocked(ref msg) => msg.type_id(),
                        &Message::TxAddInput(ref msg) => msg.type_id(),
                        &Message::TxAddOutput(ref msg) => msg.type_id(),
@@ -270,15 +279,18 @@ fn do_read<R: io::Read, T, H: core::ops::Deref>(buffer: &mut R, message_type: u1
                msgs::FundingSigned::TYPE => {
                        Ok(Message::FundingSigned(Readable::read(buffer)?))
                },
+               #[cfg(splicing)]
                msgs::Splice::TYPE => {
                        Ok(Message::Splice(Readable::read(buffer)?))
                },
                msgs::Stfu::TYPE => {
                        Ok(Message::Stfu(Readable::read(buffer)?))
                },
+               #[cfg(splicing)]
                msgs::SpliceAck::TYPE => {
                        Ok(Message::SpliceAck(Readable::read(buffer)?))
                },
+               #[cfg(splicing)]
                msgs::SpliceLocked::TYPE => {
                        Ok(Message::SpliceLocked(Readable::read(buffer)?))
                },
index f2fb387942dc303e2c895e4e4917777a07aa6492..ee5e6deb4086bb69992015a053efdffab89c8816 100644 (file)
 
 use bitcoin::blockdata::constants::ChainHash;
 use bitcoin::hash_types::{WPubkeyHash, WScriptHash};
-use bitcoin::hashes::Hash;
 use bitcoin::network::constants::Network;
 use bitcoin::secp256k1::{KeyPair, PublicKey, Secp256k1, self};
 use bitcoin::secp256k1::schnorr::Signature;
 use bitcoin::address::{Address, Payload, WitnessProgram, WitnessVersion};
 use bitcoin::key::TweakedPublicKey;
 use core::time::Duration;
+use core::hash::{Hash, Hasher};
 use crate::io;
 use crate::blinded_path::BlindedPath;
 use crate::ln::PaymentHash;
@@ -390,6 +390,7 @@ macro_rules! invoice_builder_methods { (
        /// Successive calls to this method will add another address. Caller is responsible for not
        /// adding duplicate addresses and only calling if capable of receiving to P2WSH addresses.
        pub fn fallback_v0_p2wsh($($self_mut)* $self: $self_type, script_hash: &WScriptHash) -> $return_type {
+               use bitcoin::hashes::Hash;
                let address = FallbackAddress {
                        version: WitnessVersion::V0.to_num(),
                        program: Vec::from(script_hash.to_byte_array()),
@@ -403,6 +404,7 @@ macro_rules! invoice_builder_methods { (
        /// Successive calls to this method will add another address. Caller is responsible for not
        /// adding duplicate addresses and only calling if capable of receiving to P2WPKH addresses.
        pub fn fallback_v0_p2wpkh($($self_mut)* $self: $self_type, pubkey_hash: &WPubkeyHash) -> $return_type {
+               use bitcoin::hashes::Hash;
                let address = FallbackAddress {
                        version: WitnessVersion::V0.to_num(),
                        program: Vec::from(pubkey_hash.to_byte_array()),
@@ -544,7 +546,7 @@ impl UnsignedBolt12Invoice {
                let mut bytes = Vec::new();
                unsigned_tlv_stream.write(&mut bytes).unwrap();
 
-               let tagged_hash = TaggedHash::new(SIGNATURE_TAG, &bytes);
+               let tagged_hash = TaggedHash::from_valid_tlv_stream_bytes(SIGNATURE_TAG, &bytes);
 
                Self { bytes, contents, tagged_hash }
        }
@@ -614,7 +616,6 @@ impl AsRef<TaggedHash> for UnsignedBolt12Invoice {
 /// [`Refund`]: crate::offers::refund::Refund
 /// [`InvoiceRequest`]: crate::offers::invoice_request::InvoiceRequest
 #[derive(Clone, Debug)]
-#[cfg_attr(test, derive(PartialEq))]
 pub struct Bolt12Invoice {
        bytes: Vec<u8>,
        contents: InvoiceContents,
@@ -886,6 +887,20 @@ impl Bolt12Invoice {
        }
 }
 
+impl PartialEq for Bolt12Invoice {
+       fn eq(&self, other: &Self) -> bool {
+               self.bytes.eq(&other.bytes)
+       }
+}
+
+impl Eq for Bolt12Invoice {}
+
+impl Hash for Bolt12Invoice {
+       fn hash<H: Hasher>(&self, state: &mut H) {
+               self.bytes.hash(state);
+       }
+}
+
 impl InvoiceContents {
        /// Whether the original offer or refund has expired.
        #[cfg(feature = "std")]
@@ -1210,7 +1225,7 @@ impl TryFrom<Vec<u8>> for UnsignedBolt12Invoice {
                        (payer_tlv_stream, offer_tlv_stream, invoice_request_tlv_stream, invoice_tlv_stream)
                )?;
 
-               let tagged_hash = TaggedHash::new(SIGNATURE_TAG, &bytes);
+               let tagged_hash = TaggedHash::from_valid_tlv_stream_bytes(SIGNATURE_TAG, &bytes);
 
                Ok(UnsignedBolt12Invoice { bytes, contents, tagged_hash })
        }
@@ -1355,7 +1370,7 @@ impl TryFrom<ParsedMessage<FullInvoiceTlvStream>> for Bolt12Invoice {
                        None => return Err(Bolt12ParseError::InvalidSemantics(Bolt12SemanticError::MissingSignature)),
                        Some(signature) => signature,
                };
-               let tagged_hash = TaggedHash::new(SIGNATURE_TAG, &bytes);
+               let tagged_hash = TaggedHash::from_valid_tlv_stream_bytes(SIGNATURE_TAG, &bytes);
                let pubkey = contents.fields().signing_pubkey;
                merkle::verify_signature(&signature, &tagged_hash, pubkey)?;
 
@@ -1455,7 +1470,7 @@ mod tests {
 
        use core::time::Duration;
 
-       use crate::blinded_path::{BlindedHop, BlindedPath};
+       use crate::blinded_path::{BlindedHop, BlindedPath, IntroductionNode};
        use crate::sign::KeyMaterial;
        use crate::ln::features::{Bolt12InvoiceFeatures, InvoiceRequestFeatures, OfferFeatures};
        use crate::ln::inbound_payment::ExpandedKey;
@@ -1586,7 +1601,7 @@ mod tests {
                assert_eq!(invoice.invoice_features(), &Bolt12InvoiceFeatures::empty());
                assert_eq!(invoice.signing_pubkey(), recipient_pubkey());
 
-               let message = TaggedHash::new(SIGNATURE_TAG, &invoice.bytes);
+               let message = TaggedHash::from_valid_tlv_stream_bytes(SIGNATURE_TAG, &invoice.bytes);
                assert!(merkle::verify_signature(&invoice.signature, &message, recipient_pubkey()).is_ok());
 
                let digest = Message::from_slice(&invoice.signable_hash()).unwrap();
@@ -1683,7 +1698,7 @@ mod tests {
                assert_eq!(invoice.invoice_features(), &Bolt12InvoiceFeatures::empty());
                assert_eq!(invoice.signing_pubkey(), recipient_pubkey());
 
-               let message = TaggedHash::new(SIGNATURE_TAG, &invoice.bytes);
+               let message = TaggedHash::from_valid_tlv_stream_bytes(SIGNATURE_TAG, &invoice.bytes);
                assert!(merkle::verify_signature(&invoice.signature, &message, recipient_pubkey()).is_ok());
 
                assert_eq!(
@@ -1804,7 +1819,7 @@ mod tests {
                let secp_ctx = Secp256k1::new();
 
                let blinded_path = BlindedPath {
-                       introduction_node_id: pubkey(40),
+                       introduction_node: IntroductionNode::NodeId(pubkey(40)),
                        blinding_point: pubkey(41),
                        blinded_hops: vec![
                                BlindedHop { blinded_node_id: pubkey(42), encrypted_payload: vec![0; 43] },
index 1b634525416acef1a249858a7d78cc44b3c03b06..5ae5f457a84822172d970035286a562be0d67973 100644 (file)
@@ -11,6 +11,7 @@
 
 use crate::io;
 use crate::ln::msgs::DecodeError;
+use crate::offers::merkle::SignError;
 use crate::offers::parse::Bolt12SemanticError;
 use crate::util::ser::{HighZeroBytesDroppedBigSize, Readable, WithoutLength, Writeable, Writer};
 use crate::util::string::UntrustedString;
@@ -113,6 +114,19 @@ impl From<Bolt12SemanticError> for InvoiceError {
        }
 }
 
+impl From<SignError> for InvoiceError {
+       fn from(error: SignError) -> Self {
+               let message = match error {
+                       SignError::Signing => "Failed signing invoice",
+                       SignError::Verification(_) => "Failed invoice signature verification",
+               };
+               InvoiceError {
+                       erroneous_field: None,
+                       message: UntrustedString(message.to_string()),
+               }
+       }
+}
+
 #[cfg(test)]
 mod tests {
        use super::{ErroneousField, InvoiceError};
index 5e3ed40ac08db2b2a8a662c357bb4eb475fb45fc..9157613fcd977a132e9f7b5751f196af0592269a 100644 (file)
@@ -72,12 +72,12 @@ use crate::ln::inbound_payment::{ExpandedKey, IV_LEN, Nonce};
 use crate::ln::msgs::DecodeError;
 use crate::offers::invoice::BlindedPayInfo;
 use crate::offers::merkle::{SignError, SignFn, SignatureTlvStream, SignatureTlvStreamRef, TaggedHash, self};
-use crate::offers::offer::{Offer, OfferContents, OfferTlvStream, OfferTlvStreamRef};
+use crate::offers::offer::{Offer, OfferContents, OfferId, OfferTlvStream, OfferTlvStreamRef};
 use crate::offers::parse::{Bolt12ParseError, ParsedMessage, Bolt12SemanticError};
 use crate::offers::payer::{PayerContents, PayerTlvStream, PayerTlvStreamRef};
 use crate::offers::signer::{Metadata, MetadataMaterial};
-use crate::util::ser::{HighZeroBytesDroppedBigSize, SeekReadable, WithoutLength, Writeable, Writer};
-use crate::util::string::PrintableString;
+use crate::util::ser::{HighZeroBytesDroppedBigSize, Readable, SeekReadable, WithoutLength, Writeable, Writer};
+use crate::util::string::{PrintableString, UntrustedString};
 
 #[cfg(not(c_bindings))]
 use {
@@ -529,7 +529,7 @@ impl UnsignedInvoiceRequest {
                let mut bytes = Vec::new();
                unsigned_tlv_stream.write(&mut bytes).unwrap();
 
-               let tagged_hash = TaggedHash::new(SIGNATURE_TAG, &bytes);
+               let tagged_hash = TaggedHash::from_valid_tlv_stream_bytes(SIGNATURE_TAG, &bytes);
 
                Self { bytes, contents, tagged_hash }
        }
@@ -607,6 +607,9 @@ pub struct InvoiceRequest {
 /// ways to respond depending on whether the signing keys were derived.
 #[derive(Clone, Debug)]
 pub struct VerifiedInvoiceRequest {
+       /// The identifier of the [`Offer`] for which the [`InvoiceRequest`] was made.
+       pub offer_id: OfferId,
+
        /// The verified request.
        inner: InvoiceRequest,
 
@@ -764,8 +767,9 @@ macro_rules! invoice_request_verify_method { ($self: ident, $self_type: ty) => {
                #[cfg(c_bindings)]
                secp_ctx: &Secp256k1<secp256k1::All>,
        ) -> Result<VerifiedInvoiceRequest, ()> {
-               let keys = $self.contents.inner.offer.verify(&$self.bytes, key, secp_ctx)?;
+               let (offer_id, keys) = $self.contents.inner.offer.verify(&$self.bytes, key, secp_ctx)?;
                Ok(VerifiedInvoiceRequest {
+                       offer_id,
                        #[cfg(not(c_bindings))]
                        inner: $self,
                        #[cfg(c_bindings)]
@@ -868,6 +872,24 @@ impl VerifiedInvoiceRequest {
        invoice_request_respond_with_derived_signing_pubkey_methods!(self, self.inner, InvoiceBuilder<DerivedSigningPubkey>);
        #[cfg(c_bindings)]
        invoice_request_respond_with_derived_signing_pubkey_methods!(self, self.inner, InvoiceWithDerivedSigningPubkeyBuilder);
+
+       pub(crate) fn fields(&self) -> InvoiceRequestFields {
+               let InvoiceRequestContents {
+                       payer_id,
+                       inner: InvoiceRequestContentsWithoutPayerId {
+                               payer: _, offer: _, chain: _, amount_msats, features, quantity, payer_note
+                       },
+               } = &self.inner.contents;
+
+               InvoiceRequestFields {
+                       payer_id: *payer_id,
+                       amount_msats: *amount_msats,
+                       features: features.clone(),
+                       quantity: *quantity,
+                       payer_note_truncated: payer_note.clone()
+                               .map(|mut s| { s.truncate(PAYER_NOTE_LIMIT); UntrustedString(s) }),
+               }
+       }
 }
 
 impl InvoiceRequestContents {
@@ -1022,7 +1044,7 @@ impl TryFrom<Vec<u8>> for UnsignedInvoiceRequest {
                        (payer_tlv_stream, offer_tlv_stream, invoice_request_tlv_stream)
                )?;
 
-               let tagged_hash = TaggedHash::new(SIGNATURE_TAG, &bytes);
+               let tagged_hash = TaggedHash::from_valid_tlv_stream_bytes(SIGNATURE_TAG, &bytes);
 
                Ok(UnsignedInvoiceRequest { bytes, contents, tagged_hash })
        }
@@ -1046,7 +1068,7 @@ impl TryFrom<Vec<u8>> for InvoiceRequest {
                        None => return Err(Bolt12ParseError::InvalidSemantics(Bolt12SemanticError::MissingSignature)),
                        Some(signature) => signature,
                };
-               let message = TaggedHash::new(SIGNATURE_TAG, &bytes);
+               let message = TaggedHash::from_valid_tlv_stream_bytes(SIGNATURE_TAG, &bytes);
                merkle::verify_signature(&signature, &message, contents.payer_id)?;
 
                Ok(InvoiceRequest { bytes, contents, signature })
@@ -1096,9 +1118,68 @@ impl TryFrom<PartialInvoiceRequestTlvStream> for InvoiceRequestContents {
        }
 }
 
+/// Fields sent in an [`InvoiceRequest`] message to include in [`PaymentContext::Bolt12Offer`].
+///
+/// [`PaymentContext::Bolt12Offer`]: crate::blinded_path::payment::PaymentContext::Bolt12Offer
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct InvoiceRequestFields {
+       /// A possibly transient pubkey used to sign the invoice request.
+       pub payer_id: PublicKey,
+
+       /// The amount to pay in msats (i.e., the minimum lightning-payable unit for [`chain`]), which
+       /// must be greater than or equal to [`Offer::amount`], converted if necessary.
+       ///
+       /// [`chain`]: InvoiceRequest::chain
+       pub amount_msats: Option<u64>,
+
+       /// Features pertaining to requesting an invoice.
+       pub features: InvoiceRequestFeatures,
+
+       /// The quantity of the offer's item conforming to [`Offer::is_valid_quantity`].
+       pub quantity: Option<u64>,
+
+       /// A payer-provided note which will be seen by the recipient and reflected back in the invoice
+       /// response. Truncated to [`PAYER_NOTE_LIMIT`] characters.
+       pub payer_note_truncated: Option<UntrustedString>,
+}
+
+/// The maximum number of characters included in [`InvoiceRequestFields::payer_note_truncated`].
+pub const PAYER_NOTE_LIMIT: usize = 512;
+
+impl Writeable for InvoiceRequestFields {
+       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
+               write_tlv_fields!(writer, {
+                       (0, self.payer_id, required),
+                       (2, self.amount_msats.map(|v| HighZeroBytesDroppedBigSize(v)), option),
+                       (4, WithoutLength(&self.features), required),
+                       (6, self.quantity.map(|v| HighZeroBytesDroppedBigSize(v)), option),
+                       (8, self.payer_note_truncated.as_ref().map(|s| WithoutLength(&s.0)), option),
+               });
+               Ok(())
+       }
+}
+
+impl Readable for InvoiceRequestFields {
+       fn read<R: io::Read>(reader: &mut R) -> Result<Self, DecodeError> {
+               _init_and_read_len_prefixed_tlv_fields!(reader, {
+                       (0, payer_id, required),
+                       (2, amount_msats, (option, encoding: (u64, HighZeroBytesDroppedBigSize))),
+                       (4, features, (option, encoding: (InvoiceRequestFeatures, WithoutLength))),
+                       (6, quantity, (option, encoding: (u64, HighZeroBytesDroppedBigSize))),
+                       (8, payer_note_truncated, (option, encoding: (String, WithoutLength))),
+               });
+               let features = features.unwrap_or(InvoiceRequestFeatures::empty());
+
+               Ok(InvoiceRequestFields {
+                       payer_id: payer_id.0.unwrap(), amount_msats, features, quantity,
+                       payer_note_truncated: payer_note_truncated.map(|s| UntrustedString(s)),
+               })
+       }
+}
+
 #[cfg(test)]
 mod tests {
-       use super::{InvoiceRequest, InvoiceRequestTlvStreamRef, SIGNATURE_TAG, UnsignedInvoiceRequest};
+       use super::{InvoiceRequest, InvoiceRequestFields, InvoiceRequestTlvStreamRef, PAYER_NOTE_LIMIT, SIGNATURE_TAG, UnsignedInvoiceRequest};
 
        use bitcoin::blockdata::constants::ChainHash;
        use bitcoin::network::constants::Network;
@@ -1125,8 +1206,8 @@ mod tests {
        use crate::offers::parse::{Bolt12ParseError, Bolt12SemanticError};
        use crate::offers::payer::PayerTlvStreamRef;
        use crate::offers::test_utils::*;
-       use crate::util::ser::{BigSize, Writeable};
-       use crate::util::string::PrintableString;
+       use crate::util::ser::{BigSize, Readable, Writeable};
+       use crate::util::string::{PrintableString, UntrustedString};
 
        #[test]
        fn builds_invoice_request_with_defaults() {
@@ -1192,7 +1273,7 @@ mod tests {
                assert_eq!(invoice_request.payer_id(), payer_pubkey());
                assert_eq!(invoice_request.payer_note(), None);
 
-               let message = TaggedHash::new(SIGNATURE_TAG, &invoice_request.bytes);
+               let message = TaggedHash::from_valid_tlv_stream_bytes(SIGNATURE_TAG, &invoice_request.bytes);
                assert!(merkle::verify_signature(&invoice_request.signature, &message, payer_pubkey()).is_ok());
 
                assert_eq!(
@@ -1297,7 +1378,7 @@ mod tests {
                let mut bytes = Vec::new();
                tlv_stream.write(&mut bytes).unwrap();
 
-               let message = TaggedHash::new(INVOICE_SIGNATURE_TAG, &bytes);
+               let message = TaggedHash::from_valid_tlv_stream_bytes(INVOICE_SIGNATURE_TAG, &bytes);
                let signature = merkle::sign_message(recipient_sign, &message, recipient_pubkey()).unwrap();
                signature_tlv_stream.signature = Some(&signature);
 
@@ -1320,7 +1401,7 @@ mod tests {
                let mut bytes = Vec::new();
                tlv_stream.write(&mut bytes).unwrap();
 
-               let message = TaggedHash::new(INVOICE_SIGNATURE_TAG, &bytes);
+               let message = TaggedHash::from_valid_tlv_stream_bytes(INVOICE_SIGNATURE_TAG, &bytes);
                let signature = merkle::sign_message(recipient_sign, &message, recipient_pubkey()).unwrap();
                signature_tlv_stream.signature = Some(&signature);
 
@@ -1369,7 +1450,7 @@ mod tests {
                let mut bytes = Vec::new();
                tlv_stream.write(&mut bytes).unwrap();
 
-               let message = TaggedHash::new(INVOICE_SIGNATURE_TAG, &bytes);
+               let message = TaggedHash::from_valid_tlv_stream_bytes(INVOICE_SIGNATURE_TAG, &bytes);
                let signature = merkle::sign_message(recipient_sign, &message, recipient_pubkey()).unwrap();
                signature_tlv_stream.signature = Some(&signature);
 
@@ -1392,7 +1473,7 @@ mod tests {
                let mut bytes = Vec::new();
                tlv_stream.write(&mut bytes).unwrap();
 
-               let message = TaggedHash::new(INVOICE_SIGNATURE_TAG, &bytes);
+               let message = TaggedHash::from_valid_tlv_stream_bytes(INVOICE_SIGNATURE_TAG, &bytes);
                let signature = merkle::sign_message(recipient_sign, &message, recipient_pubkey()).unwrap();
                signature_tlv_stream.signature = Some(&signature);
 
@@ -2162,4 +2243,55 @@ mod tests {
                        Err(e) => assert_eq!(e, Bolt12ParseError::Decode(DecodeError::InvalidValue)),
                }
        }
+
+       #[test]
+       fn copies_verified_invoice_request_fields() {
+               let desc = "foo".to_string();
+               let node_id = recipient_pubkey();
+               let expanded_key = ExpandedKey::new(&KeyMaterial([42; 32]));
+               let entropy = FixedEntropy {};
+               let secp_ctx = Secp256k1::new();
+
+               #[cfg(c_bindings)]
+               use crate::offers::offer::OfferWithDerivedMetadataBuilder as OfferBuilder;
+               let offer = OfferBuilder
+                       ::deriving_signing_pubkey(desc, node_id, &expanded_key, &entropy, &secp_ctx)
+                       .chain(Network::Testnet)
+                       .amount_msats(1000)
+                       .supported_quantity(Quantity::Unbounded)
+                       .build().unwrap();
+               assert_eq!(offer.signing_pubkey(), node_id);
+
+               let invoice_request = offer.request_invoice(vec![1; 32], payer_pubkey()).unwrap()
+                       .chain(Network::Testnet).unwrap()
+                       .amount_msats(1001).unwrap()
+                       .quantity(1).unwrap()
+                       .payer_note("0".repeat(PAYER_NOTE_LIMIT * 2))
+                       .build().unwrap()
+                       .sign(payer_sign).unwrap();
+               match invoice_request.verify(&expanded_key, &secp_ctx) {
+                       Ok(invoice_request) => {
+                               let fields = invoice_request.fields();
+                               assert_eq!(invoice_request.offer_id, offer.id());
+                               assert_eq!(
+                                       fields,
+                                       InvoiceRequestFields {
+                                               payer_id: payer_pubkey(),
+                                               amount_msats: Some(1001),
+                                               features: InvoiceRequestFeatures::empty(),
+                                               quantity: Some(1),
+                                               payer_note_truncated: Some(UntrustedString("0".repeat(PAYER_NOTE_LIMIT))),
+                                       }
+                               );
+
+                               let mut buffer = Vec::new();
+                               fields.write(&mut buffer).unwrap();
+
+                               let deserialized_fields: InvoiceRequestFields =
+                                       Readable::read(&mut buffer.as_slice()).unwrap();
+                               assert_eq!(deserialized_fields, fields);
+                       },
+                       Err(_) => panic!("unexpected error"),
+               }
+       }
 }
index da3fab589966d8d3d5988437a72aec56bbc1b8d8..a3979866926170e8b5dc467ef10c59aae9c91c54 100644 (file)
@@ -38,10 +38,20 @@ pub struct TaggedHash {
 }
 
 impl TaggedHash {
+       /// Creates a tagged hash with the given parameters.
+       ///
+       /// Panics if `bytes` is not a well-formed TLV stream containing at least one TLV record.
+       pub(super) fn from_valid_tlv_stream_bytes(tag: &'static str, bytes: &[u8]) -> Self {
+               let tlv_stream = TlvStream::new(bytes);
+               Self::from_tlv_stream(tag, tlv_stream)
+       }
+
        /// Creates a tagged hash with the given parameters.
        ///
        /// Panics if `tlv_stream` is not a well-formed TLV stream containing at least one TLV record.
-       pub(super) fn new(tag: &'static str, tlv_stream: &[u8]) -> Self {
+       pub(super) fn from_tlv_stream<'a, I: core::iter::Iterator<Item = TlvRecord<'a>>>(
+               tag: &'static str, tlv_stream: I
+       ) -> Self {
                let tag_hash = sha256::Hash::hash(tag.as_bytes());
                let merkle_root = root_hash(tlv_stream);
                let digest = Message::from_slice(tagged_hash(tag_hash, merkle_root).as_byte_array()).unwrap();
@@ -66,6 +76,10 @@ impl TaggedHash {
        pub fn merkle_root(&self) -> sha256::Hash {
                self.merkle_root
        }
+
+       pub(super) fn to_bytes(&self) -> [u8; 32] {
+               *self.digest.as_ref()
+       }
 }
 
 impl AsRef<TaggedHash> for TaggedHash {
@@ -137,9 +151,10 @@ pub(super) fn verify_signature(
 
 /// Computes a merkle root hash for the given data, which must be a well-formed TLV stream
 /// containing at least one TLV record.
-fn root_hash(data: &[u8]) -> sha256::Hash {
+fn root_hash<'a, I: core::iter::Iterator<Item = TlvRecord<'a>>>(tlv_stream: I) -> sha256::Hash {
+       let mut tlv_stream = tlv_stream.peekable();
        let nonce_tag = tagged_hash_engine(sha256::Hash::from_engine({
-               let first_tlv_record = TlvStream::new(&data[..]).next().unwrap();
+               let first_tlv_record = tlv_stream.peek().unwrap();
                let mut engine = sha256::Hash::engine();
                engine.input("LnNonce".as_bytes());
                engine.input(first_tlv_record.record_bytes);
@@ -149,8 +164,7 @@ fn root_hash(data: &[u8]) -> sha256::Hash {
        let branch_tag = tagged_hash_engine(sha256::Hash::hash("LnBranch".as_bytes()));
 
        let mut leaves = Vec::new();
-       let tlv_stream = TlvStream::new(&data[..]);
-       for record in tlv_stream.skip_signatures() {
+       for record in TlvStream::skip_signatures(tlv_stream) {
                leaves.push(tagged_hash_from_engine(leaf_tag.clone(), &record.record_bytes));
                leaves.push(tagged_hash_from_engine(nonce_tag.clone(), &record.type_bytes));
        }
@@ -227,8 +241,10 @@ impl<'a> TlvStream<'a> {
                        .take_while(move |record| take_range.contains(&record.r#type))
        }
 
-       fn skip_signatures(self) -> core::iter::Filter<TlvStream<'a>, fn(&TlvRecord) -> bool> {
-               self.filter(|record| !SIGNATURE_TYPES.contains(&record.r#type))
+       fn skip_signatures(
+               tlv_stream: impl core::iter::Iterator<Item = TlvRecord<'a>>
+       ) -> impl core::iter::Iterator<Item = TlvRecord<'a>> {
+               tlv_stream.filter(|record| !SIGNATURE_TYPES.contains(&record.r#type))
        }
 }
 
@@ -276,7 +292,7 @@ impl<'a> Writeable for WithoutSignatures<'a> {
        #[inline]
        fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
                let tlv_stream = TlvStream::new(self.0);
-               for record in tlv_stream.skip_signatures() {
+               for record in TlvStream::skip_signatures(tlv_stream) {
                        writer.write_all(record.record_bytes)?;
                }
                Ok(())
@@ -304,15 +320,15 @@ mod tests {
                macro_rules! tlv2 { () => { "02080000010000020003" } }
                macro_rules! tlv3 { () => { "03310266e4598d1d3c415f572a8488830b60f7e744ed9235eb0b1ba93283b315c0351800000000000000010000000000000002" } }
                assert_eq!(
-                       super::root_hash(&<Vec<u8>>::from_hex(tlv1!()).unwrap()),
+                       super::root_hash(TlvStream::new(&<Vec<u8>>::from_hex(tlv1!()).unwrap())),
                        sha256::Hash::from_slice(&<Vec<u8>>::from_hex("b013756c8fee86503a0b4abdab4cddeb1af5d344ca6fc2fa8b6c08938caa6f93").unwrap()).unwrap(),
                );
                assert_eq!(
-                       super::root_hash(&<Vec<u8>>::from_hex(concat!(tlv1!(), tlv2!())).unwrap()),
+                       super::root_hash(TlvStream::new(&<Vec<u8>>::from_hex(concat!(tlv1!(), tlv2!())).unwrap())),
                        sha256::Hash::from_slice(&<Vec<u8>>::from_hex("c3774abbf4815aa54ccaa026bff6581f01f3be5fe814c620a252534f434bc0d1").unwrap()).unwrap(),
                );
                assert_eq!(
-                       super::root_hash(&<Vec<u8>>::from_hex(concat!(tlv1!(), tlv2!(), tlv3!())).unwrap()),
+                       super::root_hash(TlvStream::new(&<Vec<u8>>::from_hex(concat!(tlv1!(), tlv2!(), tlv3!())).unwrap())),
                        sha256::Hash::from_slice(&<Vec<u8>>::from_hex("ab2e79b1283b0b31e0b035258de23782df6b89a38cfa7237bde69aed1a658c5d").unwrap()).unwrap(),
                );
        }
@@ -344,7 +360,7 @@ mod tests {
                        "lnr1qqyqqqqqqqqqqqqqqcp4256ypqqkgzshgysy6ct5dpjk6ct5d93kzmpq23ex2ct5d9ek293pqthvwfzadd7jejes8q9lhc4rvjxd022zv5l44g6qah82ru5rdpnpjkppqvjx204vgdzgsqpvcp4mldl3plscny0rt707gvpdh6ndydfacz43euzqhrurageg3n7kafgsek6gz3e9w52parv8gs2hlxzk95tzeswywffxlkeyhml0hh46kndmwf4m6xma3tkq2lu04qz3slje2rfthc89vss",
                );
                assert_eq!(
-                       super::root_hash(&invoice_request.bytes[..]),
+                       super::root_hash(TlvStream::new(&invoice_request.bytes[..])),
                        sha256::Hash::from_slice(&<Vec<u8>>::from_hex("608407c18ad9a94d9ea2bcdbe170b6c20c462a7833a197621c916f78cf18e624").unwrap()).unwrap(),
                );
                assert_eq!(
index f7b75138b5108bbecad7cae5faae4ede0d9749e5..3dedc6cd35d25166263603fb58fb6173faeb8935 100644 (file)
@@ -90,11 +90,11 @@ use crate::blinded_path::BlindedPath;
 use crate::ln::channelmanager::PaymentId;
 use crate::ln::features::OfferFeatures;
 use crate::ln::inbound_payment::{ExpandedKey, IV_LEN, Nonce};
-use crate::ln::msgs::MAX_VALUE_MSAT;
-use crate::offers::merkle::TlvStream;
+use crate::ln::msgs::{DecodeError, MAX_VALUE_MSAT};
+use crate::offers::merkle::{TaggedHash, TlvStream};
 use crate::offers::parse::{Bech32Encode, Bolt12ParseError, Bolt12SemanticError, ParsedMessage};
 use crate::offers::signer::{Metadata, MetadataMaterial, self};
-use crate::util::ser::{HighZeroBytesDroppedBigSize, WithoutLength, Writeable, Writer};
+use crate::util::ser::{HighZeroBytesDroppedBigSize, Readable, WithoutLength, Writeable, Writer};
 use crate::util::string::PrintableString;
 
 #[cfg(not(c_bindings))]
@@ -114,6 +114,37 @@ use std::time::SystemTime;
 
 pub(super) const IV_BYTES: &[u8; IV_LEN] = b"LDK Offer ~~~~~~";
 
+/// An identifier for an [`Offer`] built using [`DerivedMetadata`].
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+pub struct OfferId(pub [u8; 32]);
+
+impl OfferId {
+       const ID_TAG: &'static str = "LDK Offer ID";
+
+       fn from_valid_offer_tlv_stream(bytes: &[u8]) -> Self {
+               let tagged_hash = TaggedHash::from_valid_tlv_stream_bytes(Self::ID_TAG, bytes);
+               Self(tagged_hash.to_bytes())
+       }
+
+       fn from_valid_invreq_tlv_stream(bytes: &[u8]) -> Self {
+               let tlv_stream = TlvStream::new(bytes).range(OFFER_TYPES);
+               let tagged_hash = TaggedHash::from_tlv_stream(Self::ID_TAG, tlv_stream);
+               Self(tagged_hash.to_bytes())
+       }
+}
+
+impl Writeable for OfferId {
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
+               self.0.write(w)
+       }
+}
+
+impl Readable for OfferId {
+       fn read<R: io::Read>(r: &mut R) -> Result<Self, DecodeError> {
+               Ok(OfferId(Readable::read(r)?))
+       }
+}
+
 /// Builds an [`Offer`] for the "offer to be paid" flow.
 ///
 /// See [module-level documentation] for usage.
@@ -370,12 +401,15 @@ macro_rules! offer_builder_methods { (
                let mut bytes = Vec::new();
                $self.offer.write(&mut bytes).unwrap();
 
+               let id = OfferId::from_valid_offer_tlv_stream(&bytes);
+
                Offer {
                        bytes,
                        #[cfg(not(c_bindings))]
                        contents: $self.offer,
                        #[cfg(c_bindings)]
-                       contents: $self.offer.clone()
+                       contents: $self.offer.clone(),
+                       id,
                }
        }
 } }
@@ -488,6 +522,7 @@ pub struct Offer {
        // fields.
        pub(super) bytes: Vec<u8>,
        pub(super) contents: OfferContents,
+       id: OfferId,
 }
 
 /// The contents of an [`Offer`], which may be shared with an [`InvoiceRequest`] or a
@@ -577,6 +612,11 @@ macro_rules! offer_accessors { ($self: ident, $contents: expr) => {
 impl Offer {
        offer_accessors!(self, self.contents);
 
+       /// Returns the id of the offer.
+       pub fn id(&self) -> OfferId {
+               self.id
+       }
+
        pub(super) fn implied_chain(&self) -> ChainHash {
                self.contents.implied_chain()
        }
@@ -853,7 +893,7 @@ impl OfferContents {
        /// Verifies that the offer metadata was produced from the offer in the TLV stream.
        pub(super) fn verify<T: secp256k1::Signing>(
                &self, bytes: &[u8], key: &ExpandedKey, secp_ctx: &Secp256k1<T>
-       ) -> Result<Option<KeyPair>, ()> {
+       ) -> Result<(OfferId, Option<KeyPair>), ()> {
                match self.metadata() {
                        Some(metadata) => {
                                let tlv_stream = TlvStream::new(bytes).range(OFFER_TYPES).filter(|record| {
@@ -865,9 +905,13 @@ impl OfferContents {
                                                _ => true,
                                        }
                                });
-                               signer::verify_recipient_metadata(
+                               let keys = signer::verify_recipient_metadata(
                                        metadata, key, IV_BYTES, self.signing_pubkey(), tlv_stream, secp_ctx
-                               )
+                               )?;
+
+                               let offer_id = OfferId::from_valid_invreq_tlv_stream(bytes);
+
+                               Ok((offer_id, keys))
                        },
                        None => Err(()),
                }
@@ -1002,7 +1046,9 @@ impl TryFrom<Vec<u8>> for Offer {
                let offer = ParsedMessage::<OfferTlvStream>::try_from(bytes)?;
                let ParsedMessage { bytes, tlv_stream } = offer;
                let contents = OfferContents::try_from(tlv_stream)?;
-               Ok(Offer { bytes, contents })
+               let id = OfferId::from_valid_offer_tlv_stream(&bytes);
+
+               Ok(Offer { bytes, contents, id })
        }
 }
 
@@ -1078,7 +1124,7 @@ mod tests {
        use bitcoin::secp256k1::Secp256k1;
        use core::num::NonZeroU64;
        use core::time::Duration;
-       use crate::blinded_path::{BlindedHop, BlindedPath};
+       use crate::blinded_path::{BlindedHop, BlindedPath, IntroductionNode};
        use crate::sign::KeyMaterial;
        use crate::ln::features::OfferFeatures;
        use crate::ln::inbound_payment::ExpandedKey;
@@ -1210,7 +1256,10 @@ mod tests {
                let invoice_request = offer.request_invoice(vec![1; 32], payer_pubkey()).unwrap()
                        .build().unwrap()
                        .sign(payer_sign).unwrap();
-               assert!(invoice_request.verify(&expanded_key, &secp_ctx).is_ok());
+               match invoice_request.verify(&expanded_key, &secp_ctx) {
+                       Ok(invoice_request) => assert_eq!(invoice_request.offer_id, offer.id()),
+                       Err(_) => panic!("unexpected error"),
+               }
 
                // Fails verification with altered offer field
                let mut tlv_stream = offer.as_tlv_stream();
@@ -1249,7 +1298,7 @@ mod tests {
                let secp_ctx = Secp256k1::new();
 
                let blinded_path = BlindedPath {
-                       introduction_node_id: pubkey(40),
+                       introduction_node: IntroductionNode::NodeId(pubkey(40)),
                        blinding_point: pubkey(41),
                        blinded_hops: vec![
                                BlindedHop { blinded_node_id: pubkey(42), encrypted_payload: vec![0; 43] },
@@ -1269,7 +1318,10 @@ mod tests {
                let invoice_request = offer.request_invoice(vec![1; 32], payer_pubkey()).unwrap()
                        .build().unwrap()
                        .sign(payer_sign).unwrap();
-               assert!(invoice_request.verify(&expanded_key, &secp_ctx).is_ok());
+               match invoice_request.verify(&expanded_key, &secp_ctx) {
+                       Ok(invoice_request) => assert_eq!(invoice_request.offer_id, offer.id()),
+                       Err(_) => panic!("unexpected error"),
+               }
 
                // Fails verification with altered offer field
                let mut tlv_stream = offer.as_tlv_stream();
@@ -1395,7 +1447,7 @@ mod tests {
        fn builds_offer_with_paths() {
                let paths = vec![
                        BlindedPath {
-                               introduction_node_id: pubkey(40),
+                               introduction_node: IntroductionNode::NodeId(pubkey(40)),
                                blinding_point: pubkey(41),
                                blinded_hops: vec![
                                        BlindedHop { blinded_node_id: pubkey(43), encrypted_payload: vec![0; 43] },
@@ -1403,7 +1455,7 @@ mod tests {
                                ],
                        },
                        BlindedPath {
-                               introduction_node_id: pubkey(40),
+                               introduction_node: IntroductionNode::NodeId(pubkey(40)),
                                blinding_point: pubkey(41),
                                blinded_hops: vec![
                                        BlindedHop { blinded_node_id: pubkey(45), encrypted_payload: vec![0; 45] },
@@ -1585,7 +1637,7 @@ mod tests {
        fn parses_offer_with_paths() {
                let offer = OfferBuilder::new("foo".into(), pubkey(42))
                        .path(BlindedPath {
-                               introduction_node_id: pubkey(40),
+                               introduction_node: IntroductionNode::NodeId(pubkey(40)),
                                blinding_point: pubkey(41),
                                blinded_hops: vec![
                                        BlindedHop { blinded_node_id: pubkey(43), encrypted_payload: vec![0; 43] },
@@ -1593,7 +1645,7 @@ mod tests {
                                ],
                        })
                        .path(BlindedPath {
-                               introduction_node_id: pubkey(40),
+                               introduction_node: IntroductionNode::NodeId(pubkey(40)),
                                blinding_point: pubkey(41),
                                blinded_hops: vec![
                                        BlindedHop { blinded_node_id: pubkey(45), encrypted_payload: vec![0; 45] },
index 16014cd3c0b7ac6080a31b1c198b7bb669798219..03253fb6400bfe2c1261fcb2b00f13bc6eff1779 100644 (file)
@@ -907,7 +907,7 @@ mod tests {
 
        use core::time::Duration;
 
-       use crate::blinded_path::{BlindedHop, BlindedPath};
+       use crate::blinded_path::{BlindedHop, BlindedPath, IntroductionNode};
        use crate::sign::KeyMaterial;
        use crate::ln::channelmanager::PaymentId;
        use crate::ln::features::{InvoiceRequestFeatures, OfferFeatures};
@@ -1062,7 +1062,7 @@ mod tests {
                let payment_id = PaymentId([1; 32]);
 
                let blinded_path = BlindedPath {
-                       introduction_node_id: pubkey(40),
+                       introduction_node: IntroductionNode::NodeId(pubkey(40)),
                        blinding_point: pubkey(41),
                        blinded_hops: vec![
                                BlindedHop { blinded_node_id: pubkey(43), encrypted_payload: vec![0; 43] },
@@ -1151,7 +1151,7 @@ mod tests {
        fn builds_refund_with_paths() {
                let paths = vec![
                        BlindedPath {
-                               introduction_node_id: pubkey(40),
+                               introduction_node: IntroductionNode::NodeId(pubkey(40)),
                                blinding_point: pubkey(41),
                                blinded_hops: vec![
                                        BlindedHop { blinded_node_id: pubkey(43), encrypted_payload: vec![0; 43] },
@@ -1159,7 +1159,7 @@ mod tests {
                                ],
                        },
                        BlindedPath {
-                               introduction_node_id: pubkey(40),
+                               introduction_node: IntroductionNode::NodeId(pubkey(40)),
                                blinding_point: pubkey(41),
                                blinded_hops: vec![
                                        BlindedHop { blinded_node_id: pubkey(45), encrypted_payload: vec![0; 45] },
@@ -1368,7 +1368,7 @@ mod tests {
                let past_expiry = Duration::from_secs(0);
                let paths = vec![
                        BlindedPath {
-                               introduction_node_id: pubkey(40),
+                               introduction_node: IntroductionNode::NodeId(pubkey(40)),
                                blinding_point: pubkey(41),
                                blinded_hops: vec![
                                        BlindedHop { blinded_node_id: pubkey(43), encrypted_payload: vec![0; 43] },
@@ -1376,7 +1376,7 @@ mod tests {
                                ],
                        },
                        BlindedPath {
-                               introduction_node_id: pubkey(40),
+                               introduction_node: IntroductionNode::NodeId(pubkey(40)),
                                blinding_point: pubkey(41),
                                blinded_hops: vec![
                                        BlindedHop { blinded_node_id: pubkey(45), encrypted_payload: vec![0; 45] },
index b4329803016fadb395689bf696e30d78c2984e97..149ba15c3a2392d4caa98c5d63db02a1c988ab8b 100644 (file)
@@ -13,7 +13,7 @@ use bitcoin::secp256k1::{KeyPair, PublicKey, Secp256k1, SecretKey};
 use bitcoin::secp256k1::schnorr::Signature;
 
 use core::time::Duration;
-use crate::blinded_path::{BlindedHop, BlindedPath};
+use crate::blinded_path::{BlindedHop, BlindedPath, IntroductionNode};
 use crate::sign::EntropySource;
 use crate::ln::PaymentHash;
 use crate::ln::features::BlindedHopFeatures;
@@ -69,7 +69,7 @@ pub(super) fn privkey(byte: u8) -> SecretKey {
 pub(crate) fn payment_paths() -> Vec<(BlindedPayInfo, BlindedPath)> {
        let paths = vec![
                BlindedPath {
-                       introduction_node_id: pubkey(40),
+                       introduction_node: IntroductionNode::NodeId(pubkey(40)),
                        blinding_point: pubkey(41),
                        blinded_hops: vec![
                                BlindedHop { blinded_node_id: pubkey(43), encrypted_payload: vec![0; 43] },
@@ -77,7 +77,7 @@ pub(crate) fn payment_paths() -> Vec<(BlindedPayInfo, BlindedPath)> {
                        ],
                },
                BlindedPath {
-                       introduction_node_id: pubkey(40),
+                       introduction_node: IntroductionNode::NodeId(pubkey(40)),
                        blinding_point: pubkey(41),
                        blinded_hops: vec![
                                BlindedHop { blinded_node_id: pubkey(45), encrypted_payload: vec![0; 45] },
index 16f0babe3612e365accffd4c6a7ee110e65880ba..acf34a3a8c8a112067bd22cd905b328166398321 100644 (file)
@@ -9,7 +9,7 @@
 
 //! Onion message testing and test utilities live here.
 
-use crate::blinded_path::BlindedPath;
+use crate::blinded_path::{BlindedPath, EmptyNodeIdLookUp};
 use crate::events::{Event, EventsProvider};
 use crate::ln::features::{ChannelFeatures, InitFeatures};
 use crate::ln::msgs::{self, DecodeError, OnionMessageHandler};
@@ -42,6 +42,7 @@ struct MessengerNode {
                Arc<test_utils::TestKeysInterface>,
                Arc<test_utils::TestNodeSigner>,
                Arc<test_utils::TestLogger>,
+               Arc<EmptyNodeIdLookUp>,
                Arc<DefaultMessageRouter<
                        Arc<NetworkGraph<Arc<test_utils::TestLogger>>>,
                        Arc<test_utils::TestLogger>,
@@ -175,6 +176,7 @@ fn create_nodes_using_secrets(secrets: Vec<SecretKey>) -> Vec<MessengerNode> {
                let entropy_source = Arc::new(test_utils::TestKeysInterface::new(&seed, Network::Testnet));
                let node_signer = Arc::new(test_utils::TestNodeSigner::new(secret_key));
 
+               let node_id_lookup = Arc::new(EmptyNodeIdLookUp {});
                let message_router = Arc::new(
                        DefaultMessageRouter::new(network_graph.clone(), entropy_source.clone())
                );
@@ -185,7 +187,7 @@ fn create_nodes_using_secrets(secrets: Vec<SecretKey>) -> Vec<MessengerNode> {
                        node_id: node_signer.get_node_id(Recipient::Node).unwrap(),
                        entropy_source: entropy_source.clone(),
                        messenger: OnionMessenger::new(
-                               entropy_source, node_signer, logger.clone(), message_router,
+                               entropy_source, node_signer, logger.clone(), node_id_lookup, message_router,
                                offers_message_handler, custom_message_handler.clone()
                        ),
                        custom_message_handler,
index e213bcbb0e1dc50d9fda8983f164e82611db26d9..1d7a730fa3625126097fd6d25de150e5ba46c742 100644 (file)
@@ -15,15 +15,15 @@ use bitcoin::hashes::hmac::{Hmac, HmacEngine};
 use bitcoin::hashes::sha256::Hash as Sha256;
 use bitcoin::secp256k1::{self, PublicKey, Scalar, Secp256k1, SecretKey};
 
-use crate::blinded_path::BlindedPath;
-use crate::blinded_path::message::{advance_path_by_one, ForwardTlvs, ReceiveTlvs};
+use crate::blinded_path::{BlindedPath, IntroductionNode, NodeIdLookUp};
+use crate::blinded_path::message::{advance_path_by_one, ForwardTlvs, NextHop, ReceiveTlvs};
 use crate::blinded_path::utils;
 use crate::events::{Event, EventHandler, EventsProvider};
 use crate::sign::{EntropySource, NodeSigner, Recipient};
 use crate::ln::features::{InitFeatures, NodeFeatures};
 use crate::ln::msgs::{self, OnionMessage, OnionMessageHandler, SocketAddress};
 use crate::ln::onion_utils;
-use crate::routing::gossip::{NetworkGraph, NodeId};
+use crate::routing::gossip::{NetworkGraph, NodeId, ReadOnlyNetworkGraph};
 use super::packet::OnionMessageContents;
 use super::packet::ParsedOnionMessageContents;
 use super::offers::OffersMessageHandler;
@@ -70,7 +70,7 @@ pub(super) const MAX_TIMER_TICKS: usize = 2;
 /// # use bitcoin::hashes::_export::_core::time::Duration;
 /// # use bitcoin::hashes::hex::FromHex;
 /// # use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey, self};
-/// # use lightning::blinded_path::BlindedPath;
+/// # use lightning::blinded_path::{BlindedPath, EmptyNodeIdLookUp};
 /// # use lightning::sign::{EntropySource, KeysManager};
 /// # use lightning::ln::peer_handler::IgnoringMessageHandler;
 /// # use lightning::onion_message::messenger::{Destination, MessageRouter, OnionMessagePath, OnionMessenger};
@@ -111,14 +111,15 @@ pub(super) const MAX_TIMER_TICKS: usize = 2;
 /// # let hop_node_id1 = PublicKey::from_secret_key(&secp_ctx, &node_secret);
 /// # let (hop_node_id3, hop_node_id4) = (hop_node_id1, hop_node_id1);
 /// # let destination_node_id = hop_node_id1;
+/// # let node_id_lookup = EmptyNodeIdLookUp {};
 /// # let message_router = Arc::new(FakeMessageRouter {});
 /// # let custom_message_handler = IgnoringMessageHandler {};
 /// # let offers_message_handler = IgnoringMessageHandler {};
 /// // Create the onion messenger. This must use the same `keys_manager` as is passed to your
 /// // ChannelManager.
 /// let onion_messenger = OnionMessenger::new(
-///     &keys_manager, &keys_manager, logger, message_router, &offers_message_handler,
-///     &custom_message_handler
+///     &keys_manager, &keys_manager, logger, &node_id_lookup, message_router,
+///     &offers_message_handler, &custom_message_handler
 /// );
 
 /// # #[derive(Debug)]
@@ -155,11 +156,12 @@ pub(super) const MAX_TIMER_TICKS: usize = 2;
 ///
 /// [`InvoiceRequest`]: crate::offers::invoice_request::InvoiceRequest
 /// [`Bolt12Invoice`]: crate::offers::invoice::Bolt12Invoice
-pub struct OnionMessenger<ES: Deref, NS: Deref, L: Deref, MR: Deref, OMH: Deref, CMH: Deref>
+pub struct OnionMessenger<ES: Deref, NS: Deref, L: Deref, NL: Deref, MR: Deref, OMH: Deref, CMH: Deref>
 where
        ES::Target: EntropySource,
        NS::Target: NodeSigner,
        L::Target: Logger,
+       NL::Target: NodeIdLookUp,
        MR::Target: MessageRouter,
        OMH::Target: OffersMessageHandler,
        CMH::Target: CustomOnionMessageHandler,
@@ -169,6 +171,7 @@ where
        logger: L,
        message_recipients: Mutex<HashMap<PublicKey, OnionMessageRecipient>>,
        secp_ctx: Secp256k1<secp256k1::All>,
+       node_id_lookup: NL,
        message_router: MR,
        offers_handler: OMH,
        custom_handler: CMH,
@@ -318,15 +321,21 @@ where
        ES::Target: EntropySource,
 {
        fn find_path(
-               &self, sender: PublicKey, peers: Vec<PublicKey>, destination: Destination
+               &self, sender: PublicKey, peers: Vec<PublicKey>, mut destination: Destination
        ) -> Result<OnionMessagePath, ()> {
-               let first_node = destination.first_node();
+               let network_graph = self.network_graph.deref().read_only();
+               destination.resolve(&network_graph);
+
+               let first_node = match destination.first_node() {
+                       Some(first_node) => first_node,
+                       None => return Err(()),
+               };
+
                if peers.contains(&first_node) || sender == first_node {
                        Ok(OnionMessagePath {
                                intermediate_nodes: vec![], destination, first_node_addresses: None
                        })
                } else {
-                       let network_graph = self.network_graph.deref().read_only();
                        let node_announcement = network_graph
                                .node(&NodeId::from_pubkey(&first_node))
                                .and_then(|node_info| node_info.announcement_info.as_ref())
@@ -416,11 +425,11 @@ pub struct OnionMessagePath {
 
 impl OnionMessagePath {
        /// Returns the first node in the path.
-       pub fn first_node(&self) -> PublicKey {
+       pub fn first_node(&self) -> Option<PublicKey> {
                self.intermediate_nodes
                        .first()
                        .copied()
-                       .unwrap_or_else(|| self.destination.first_node())
+                       .or_else(|| self.destination.first_node())
        }
 }
 
@@ -434,6 +443,22 @@ pub enum Destination {
 }
 
 impl Destination {
+       /// Attempts to resolve the [`IntroductionNode::DirectedShortChannelId`] of a
+       /// [`Destination::BlindedPath`] to a [`IntroductionNode::NodeId`], if applicable, using the
+       /// provided [`ReadOnlyNetworkGraph`].
+       pub fn resolve(&mut self, network_graph: &ReadOnlyNetworkGraph) {
+               if let Destination::BlindedPath(path) = self {
+                       if let IntroductionNode::DirectedShortChannelId(..) = path.introduction_node {
+                               if let Some(pubkey) = path
+                                       .public_introduction_node_id(network_graph)
+                                       .and_then(|node_id| node_id.as_pubkey().ok())
+                               {
+                                       path.introduction_node = IntroductionNode::NodeId(pubkey);
+                               }
+                       }
+               }
+       }
+
        pub(super) fn num_hops(&self) -> usize {
                match self {
                        Destination::Node(_) => 1,
@@ -441,10 +466,15 @@ impl Destination {
                }
        }
 
-       fn first_node(&self) -> PublicKey {
+       fn first_node(&self) -> Option<PublicKey> {
                match self {
-                       Destination::Node(node_id) => *node_id,
-                       Destination::BlindedPath(BlindedPath { introduction_node_id: node_id, .. }) => *node_id,
+                       Destination::Node(node_id) => Some(*node_id),
+                       Destination::BlindedPath(BlindedPath { introduction_node, .. }) => {
+                               match introduction_node {
+                                       IntroductionNode::NodeId(pubkey) => Some(*pubkey),
+                                       IntroductionNode::DirectedShortChannelId(..) => None,
+                               }
+                       },
                }
        }
 }
@@ -487,6 +517,10 @@ pub enum SendError {
        ///
        /// [`NodeSigner`]: crate::sign::NodeSigner
        GetNodeIdFailed,
+       /// The provided [`Destination`] has a blinded path with an unresolved introduction node. An
+       /// attempt to resolve it in the [`MessageRouter`] when finding an [`OnionMessagePath`] likely
+       /// failed.
+       UnresolvedIntroductionNode,
        /// We attempted to send to a blinded path where we are the introduction node, and failed to
        /// advance the blinded path to make the second hop the new introduction node. Either
        /// [`NodeSigner::ecdh`] failed, we failed to tweak the current blinding point to get the
@@ -538,23 +572,56 @@ pub trait CustomOnionMessageHandler {
 #[derive(Debug)]
 pub enum PeeledOnion<T: OnionMessageContents> {
        /// Forwarded onion, with the next node id and a new onion
-       Forward(PublicKey, OnionMessage),
+       Forward(NextHop, OnionMessage),
        /// Received onion message, with decrypted contents, path_id, and reply path
        Receive(ParsedOnionMessageContents<T>, Option<[u8; 32]>, Option<BlindedPath>)
 }
 
+
+/// Creates an [`OnionMessage`] with the given `contents` for sending to the destination of
+/// `path`, first calling [`Destination::resolve`] on `path.destination` with the given
+/// [`ReadOnlyNetworkGraph`].
+///
+/// Returns the node id of the peer to send the message to, the message itself, and any addresses
+/// needed to connect to the first node.
+pub fn create_onion_message_resolving_destination<
+       ES: Deref, NS: Deref, NL: Deref, T: OnionMessageContents
+>(
+       entropy_source: &ES, node_signer: &NS, node_id_lookup: &NL,
+       network_graph: &ReadOnlyNetworkGraph, secp_ctx: &Secp256k1<secp256k1::All>,
+       mut path: OnionMessagePath, contents: T, reply_path: Option<BlindedPath>,
+) -> Result<(PublicKey, OnionMessage, Option<Vec<SocketAddress>>), SendError>
+where
+       ES::Target: EntropySource,
+       NS::Target: NodeSigner,
+       NL::Target: NodeIdLookUp,
+{
+       path.destination.resolve(network_graph);
+       create_onion_message(
+               entropy_source, node_signer, node_id_lookup, secp_ctx, path, contents, reply_path,
+       )
+}
+
 /// Creates an [`OnionMessage`] with the given `contents` for sending to the destination of
 /// `path`.
 ///
 /// Returns the node id of the peer to send the message to, the message itself, and any addresses
-/// need to connect to the first node.
-pub fn create_onion_message<ES: Deref, NS: Deref, T: OnionMessageContents>(
-       entropy_source: &ES, node_signer: &NS, secp_ctx: &Secp256k1<secp256k1::All>,
-       path: OnionMessagePath, contents: T, reply_path: Option<BlindedPath>,
+/// needed to connect to the first node.
+///
+/// Returns [`SendError::UnresolvedIntroductionNode`] if:
+/// - `destination` contains a blinded path with an [`IntroductionNode::DirectedShortChannelId`],
+/// - unless it can be resolved by [`NodeIdLookUp::next_node_id`].
+/// Use [`create_onion_message_resolving_destination`] instead to resolve the introduction node
+/// first with a [`ReadOnlyNetworkGraph`].
+pub fn create_onion_message<ES: Deref, NS: Deref, NL: Deref, T: OnionMessageContents>(
+       entropy_source: &ES, node_signer: &NS, node_id_lookup: &NL,
+       secp_ctx: &Secp256k1<secp256k1::All>, path: OnionMessagePath, contents: T,
+       reply_path: Option<BlindedPath>,
 ) -> Result<(PublicKey, OnionMessage, Option<Vec<SocketAddress>>), SendError>
 where
        ES::Target: EntropySource,
        NS::Target: NodeSigner,
+       NL::Target: NodeIdLookUp,
 {
        let OnionMessagePath { intermediate_nodes, mut destination, first_node_addresses } = path;
        if let Destination::BlindedPath(BlindedPath { ref blinded_hops, .. }) = destination {
@@ -571,8 +638,17 @@ where
                if let Destination::BlindedPath(ref mut blinded_path) = destination {
                        let our_node_id = node_signer.get_node_id(Recipient::Node)
                                .map_err(|()| SendError::GetNodeIdFailed)?;
-                       if blinded_path.introduction_node_id == our_node_id {
-                               advance_path_by_one(blinded_path, node_signer, &secp_ctx)
+                       let introduction_node_id = match blinded_path.introduction_node {
+                               IntroductionNode::NodeId(pubkey) => pubkey,
+                               IntroductionNode::DirectedShortChannelId(direction, scid) => {
+                                       match node_id_lookup.next_node_id(scid) {
+                                               Some(next_node_id) => *direction.select_pubkey(&our_node_id, &next_node_id),
+                                               None => return Err(SendError::UnresolvedIntroductionNode),
+                                       }
+                               },
+                       };
+                       if introduction_node_id == our_node_id {
+                               advance_path_by_one(blinded_path, node_signer, node_id_lookup, &secp_ctx)
                                        .map_err(|()| SendError::BlindedPathAdvanceFailed)?;
                        }
                }
@@ -583,15 +659,21 @@ where
        let (first_node_id, blinding_point) = if let Some(first_node_id) = intermediate_nodes.first() {
                (*first_node_id, PublicKey::from_secret_key(&secp_ctx, &blinding_secret))
        } else {
-               match destination {
-                       Destination::Node(pk) => (pk, PublicKey::from_secret_key(&secp_ctx, &blinding_secret)),
-                       Destination::BlindedPath(BlindedPath { introduction_node_id, blinding_point, .. }) =>
-                               (introduction_node_id, blinding_point),
+               match &destination {
+                       Destination::Node(pk) => (*pk, PublicKey::from_secret_key(&secp_ctx, &blinding_secret)),
+                       Destination::BlindedPath(BlindedPath { introduction_node, blinding_point, .. }) => {
+                               match introduction_node {
+                                       IntroductionNode::NodeId(pubkey) => (*pubkey, *blinding_point),
+                                       IntroductionNode::DirectedShortChannelId(..) => {
+                                               return Err(SendError::UnresolvedIntroductionNode);
+                                       },
+                               }
+                       }
                }
        };
        let (packet_payloads, packet_keys) = packet_payloads_and_keys(
-               &secp_ctx, &intermediate_nodes, destination, contents, reply_path, &blinding_secret)
-               .map_err(|e| SendError::Secp256k1(e))?;
+               &secp_ctx, &intermediate_nodes, destination, contents, reply_path, &blinding_secret
+       )?;
 
        let prng_seed = entropy_source.get_secure_random_bytes();
        let onion_routing_packet = construct_onion_message_packet(
@@ -647,9 +729,9 @@ where
                        Ok(PeeledOnion::Receive(message, path_id, reply_path))
                },
                Ok((Payload::Forward(ForwardControlTlvs::Unblinded(ForwardTlvs {
-                       next_node_id, next_blinding_override
+                       next_hop, next_blinding_override
                })), Some((next_hop_hmac, new_packet_bytes)))) => {
-                       // TODO: we need to check whether `next_node_id` is our node, in which case this is a dummy
+                       // TODO: we need to check whether `next_hop` is our node, in which case this is a dummy
                        // blinded hop and this onion message is destined for us. In this situation, we should keep
                        // unwrapping the onion layers to get to the final payload. Since we don't have the option
                        // of creating blinded paths with dummy hops currently, we should be ok to not handle this
@@ -685,7 +767,7 @@ where
                                onion_routing_packet: outgoing_packet,
                        };
 
-                       Ok(PeeledOnion::Forward(next_node_id, onion_message))
+                       Ok(PeeledOnion::Forward(next_hop, onion_message))
                },
                Err(e) => {
                        log_trace!(logger, "Errored decoding onion message packet: {:?}", e);
@@ -698,12 +780,13 @@ where
        }
 }
 
-impl<ES: Deref, NS: Deref, L: Deref, MR: Deref, OMH: Deref, CMH: Deref>
-OnionMessenger<ES, NS, L, MR, OMH, CMH>
+impl<ES: Deref, NS: Deref, L: Deref, NL: Deref, MR: Deref, OMH: Deref, CMH: Deref>
+OnionMessenger<ES, NS, L, NL, MR, OMH, CMH>
 where
        ES::Target: EntropySource,
        NS::Target: NodeSigner,
        L::Target: Logger,
+       NL::Target: NodeIdLookUp,
        MR::Target: MessageRouter,
        OMH::Target: OffersMessageHandler,
        CMH::Target: CustomOnionMessageHandler,
@@ -711,8 +794,8 @@ where
        /// Constructs a new `OnionMessenger` to send, forward, and delegate received onion messages to
        /// their respective handlers.
        pub fn new(
-               entropy_source: ES, node_signer: NS, logger: L, message_router: MR, offers_handler: OMH,
-               custom_handler: CMH
+               entropy_source: ES, node_signer: NS, logger: L, node_id_lookup: NL, message_router: MR,
+               offers_handler: OMH, custom_handler: CMH
        ) -> Self {
                let mut secp_ctx = Secp256k1::new();
                secp_ctx.seeded_randomize(&entropy_source.get_secure_random_bytes());
@@ -722,6 +805,7 @@ where
                        message_recipients: Mutex::new(new_hash_map()),
                        secp_ctx,
                        logger,
+                       node_id_lookup,
                        message_router,
                        offers_handler,
                        custom_handler,
@@ -804,7 +888,8 @@ where
                log_trace!(self.logger, "Constructing onion message {}: {:?}", log_suffix, contents);
 
                let (first_node_id, onion_message, addresses) = create_onion_message(
-                       &self.entropy_source, &self.node_signer, &self.secp_ctx, path, contents, reply_path
+                       &self.entropy_source, &self.node_signer, &self.node_id_lookup, &self.secp_ctx, path,
+                       contents, reply_path,
                )?;
 
                let mut message_recipients = self.message_recipients.lock().unwrap();
@@ -900,12 +985,13 @@ fn outbound_buffer_full(peer_node_id: &PublicKey, buffer: &HashMap<PublicKey, On
        false
 }
 
-impl<ES: Deref, NS: Deref, L: Deref, MR: Deref, OMH: Deref, CMH: Deref> EventsProvider
-for OnionMessenger<ES, NS, L, MR, OMH, CMH>
+impl<ES: Deref, NS: Deref, L: Deref, NL: Deref, MR: Deref, OMH: Deref, CMH: Deref> EventsProvider
+for OnionMessenger<ES, NS, L, NL, MR, OMH, CMH>
 where
        ES::Target: EntropySource,
        NS::Target: NodeSigner,
        L::Target: Logger,
+       NL::Target: NodeIdLookUp,
        MR::Target: MessageRouter,
        OMH::Target: OffersMessageHandler,
        CMH::Target: CustomOnionMessageHandler,
@@ -921,12 +1007,13 @@ where
        }
 }
 
-impl<ES: Deref, NS: Deref, L: Deref, MR: Deref, OMH: Deref, CMH: Deref> OnionMessageHandler
-for OnionMessenger<ES, NS, L, MR, OMH, CMH>
+impl<ES: Deref, NS: Deref, L: Deref, NL: Deref, MR: Deref, OMH: Deref, CMH: Deref> OnionMessageHandler
+for OnionMessenger<ES, NS, L, NL, MR, OMH, CMH>
 where
        ES::Target: EntropySource,
        NS::Target: NodeSigner,
        L::Target: Logger,
+       NL::Target: NodeIdLookUp,
        MR::Target: MessageRouter,
        OMH::Target: OffersMessageHandler,
        CMH::Target: CustomOnionMessageHandler,
@@ -961,7 +1048,18 @@ where
                                        },
                                }
                        },
-                       Ok(PeeledOnion::Forward(next_node_id, onion_message)) => {
+                       Ok(PeeledOnion::Forward(next_hop, onion_message)) => {
+                               let next_node_id = match next_hop {
+                                       NextHop::NodeId(pubkey) => pubkey,
+                                       NextHop::ShortChannelId(scid) => match self.node_id_lookup.next_node_id(scid) {
+                                               Some(pubkey) => pubkey,
+                                               None => {
+                                                       log_trace!(self.logger, "Dropping forwarded onion messager: unable to resolve next hop using SCID {}", scid);
+                                                       return
+                                               },
+                                       },
+                               };
+
                                let mut message_recipients = self.message_recipients.lock().unwrap();
                                if outbound_buffer_full(&next_node_id, &message_recipients) {
                                        log_trace!(
@@ -1097,6 +1195,7 @@ pub type SimpleArcOnionMessenger<M, T, F, L> = OnionMessenger<
        Arc<KeysManager>,
        Arc<KeysManager>,
        Arc<L>,
+       Arc<SimpleArcChannelManager<M, T, F, L>>,
        Arc<DefaultMessageRouter<Arc<NetworkGraph<Arc<L>>>, Arc<L>, Arc<KeysManager>>>,
        Arc<SimpleArcChannelManager<M, T, F, L>>,
        IgnoringMessageHandler
@@ -1116,8 +1215,9 @@ pub type SimpleRefOnionMessenger<
        &'a KeysManager,
        &'a KeysManager,
        &'b L,
-       &'i DefaultMessageRouter<&'g NetworkGraph<&'b L>, &'b L, &'a KeysManager>,
-       &'j SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, M, T, F, L>,
+       &'i SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, M, T, F, L>,
+       &'j DefaultMessageRouter<&'g NetworkGraph<&'b L>, &'b L, &'a KeysManager>,
+       &'i SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, M, T, F, L>,
        IgnoringMessageHandler
 >;
 
@@ -1126,14 +1226,23 @@ pub type SimpleRefOnionMessenger<
 fn packet_payloads_and_keys<T: OnionMessageContents, S: secp256k1::Signing + secp256k1::Verification>(
        secp_ctx: &Secp256k1<S>, unblinded_path: &[PublicKey], destination: Destination, message: T,
        mut reply_path: Option<BlindedPath>, session_priv: &SecretKey
-) -> Result<(Vec<(Payload<T>, [u8; 32])>, Vec<onion_utils::OnionKeys>), secp256k1::Error> {
+) -> Result<(Vec<(Payload<T>, [u8; 32])>, Vec<onion_utils::OnionKeys>), SendError> {
        let num_hops = unblinded_path.len() + destination.num_hops();
        let mut payloads = Vec::with_capacity(num_hops);
        let mut onion_packet_keys = Vec::with_capacity(num_hops);
 
-       let (mut intro_node_id_blinding_pt, num_blinded_hops) = if let Destination::BlindedPath(BlindedPath {
-               introduction_node_id, blinding_point, blinded_hops }) = &destination {
-               (Some((*introduction_node_id, *blinding_point)), blinded_hops.len()) } else { (None, 0) };
+       let (mut intro_node_id_blinding_pt, num_blinded_hops) = match &destination {
+               Destination::Node(_) => (None, 0),
+               Destination::BlindedPath(BlindedPath { introduction_node, blinding_point, blinded_hops }) => {
+                       let introduction_node_id = match introduction_node {
+                               IntroductionNode::NodeId(pubkey) => pubkey,
+                               IntroductionNode::DirectedShortChannelId(..) => {
+                                       return Err(SendError::UnresolvedIntroductionNode);
+                               },
+                       };
+                       (Some((*introduction_node_id, *blinding_point)), blinded_hops.len())
+               },
+       };
        let num_unblinded_hops = num_hops - num_blinded_hops;
 
        let mut unblinded_path_idx = 0;
@@ -1146,7 +1255,7 @@ fn packet_payloads_and_keys<T: OnionMessageContents, S: secp256k1::Signing + sec
                                if let Some(ss) = prev_control_tlvs_ss.take() {
                                        payloads.push((Payload::Forward(ForwardControlTlvs::Unblinded(
                                                ForwardTlvs {
-                                                       next_node_id: unblinded_pk_opt.unwrap(),
+                                                       next_hop: NextHop::NodeId(unblinded_pk_opt.unwrap()),
                                                        next_blinding_override: None,
                                                }
                                        )), ss));
@@ -1156,7 +1265,7 @@ fn packet_payloads_and_keys<T: OnionMessageContents, S: secp256k1::Signing + sec
                        } else if let Some((intro_node_id, blinding_pt)) = intro_node_id_blinding_pt.take() {
                                if let Some(control_tlvs_ss) = prev_control_tlvs_ss.take() {
                                        payloads.push((Payload::Forward(ForwardControlTlvs::Unblinded(ForwardTlvs {
-                                               next_node_id: intro_node_id,
+                                               next_hop: NextHop::NodeId(intro_node_id),
                                                next_blinding_override: Some(blinding_pt),
                                        })), control_tlvs_ss));
                                }
@@ -1181,7 +1290,7 @@ fn packet_payloads_and_keys<T: OnionMessageContents, S: secp256k1::Signing + sec
                                mu,
                        });
                }
-       )?;
+       ).map_err(|e| SendError::Secp256k1(e))?;
 
        if let Some(control_tlvs) = final_control_tlvs {
                payloads.push((Payload::Receive {
index d9349fdadbfaba6c1f0d08c265c48bf14da49bb7..510f0ea025a0d615b0f54292d865602ac7e103a6 100644 (file)
@@ -13,7 +13,7 @@ use bitcoin::secp256k1::PublicKey;
 use bitcoin::secp256k1::ecdh::SharedSecret;
 
 use crate::blinded_path::BlindedPath;
-use crate::blinded_path::message::{ForwardTlvs, ReceiveTlvs};
+use crate::blinded_path::message::{ForwardTlvs, NextHop, ReceiveTlvs};
 use crate::blinded_path::utils::Padding;
 use crate::ln::msgs::DecodeError;
 use crate::ln::onion_utils;
@@ -284,20 +284,26 @@ impl Readable for ControlTlvs {
        fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
                _init_and_read_tlv_stream!(r, {
                        (1, _padding, option),
-                       (2, _short_channel_id, option),
+                       (2, short_channel_id, option),
                        (4, next_node_id, option),
                        (6, path_id, option),
                        (8, next_blinding_override, option),
                });
                let _padding: Option<Padding> = _padding;
-               let _short_channel_id: Option<u64> = _short_channel_id;
 
-               let valid_fwd_fmt  = next_node_id.is_some() && path_id.is_none();
-               let valid_recv_fmt = next_node_id.is_none() && next_blinding_override.is_none();
+               let next_hop = match (short_channel_id, next_node_id) {
+                       (Some(_), Some(_)) => return Err(DecodeError::InvalidValue),
+                       (Some(scid), None) => Some(NextHop::ShortChannelId(scid)),
+                       (None, Some(pubkey)) => Some(NextHop::NodeId(pubkey)),
+                       (None, None) => None,
+               };
+
+               let valid_fwd_fmt = next_hop.is_some() && path_id.is_none();
+               let valid_recv_fmt = next_hop.is_none() && next_blinding_override.is_none();
 
                let payload_fmt = if valid_fwd_fmt {
                        ControlTlvs::Forward(ForwardTlvs {
-                               next_node_id: next_node_id.unwrap(),
+                               next_hop: next_hop.unwrap(),
                                next_blinding_override,
                        })
                } else if valid_recv_fmt {
index e8276712ee89504e002b619aa70544a147467840..59ec3f6186236d07b13ec3ae7a74b3b74dce3b58 100644 (file)
@@ -11,7 +11,7 @@
 
 use bitcoin::secp256k1::{PublicKey, Secp256k1, self};
 
-use crate::blinded_path::{BlindedHop, BlindedPath};
+use crate::blinded_path::{BlindedHop, BlindedPath, Direction, IntroductionNode};
 use crate::blinded_path::payment::{ForwardNode, ForwardTlvs, PaymentConstraints, PaymentRelay, ReceiveTlvs};
 use crate::ln::PaymentHash;
 use crate::ln::channelmanager::{ChannelDetails, PaymentId, MIN_FINAL_CLTV_EXPIRY_DELTA};
@@ -1144,11 +1144,11 @@ pub struct FirstHopCandidate<'a> {
        ///
        /// [`find_route`] validates this prior to constructing a [`CandidateRouteHop`].
        ///
-       /// This is not exported to bindings users as lifetimes are not expressable in most languages.
+       /// This is not exported to bindings users as lifetimes are not expressible in most languages.
        pub details: &'a ChannelDetails,
        /// The node id of the payer, which is also the source side of this candidate route hop.
        ///
-       /// This is not exported to bindings users as lifetimes are not expressable in most languages.
+       /// This is not exported to bindings users as lifetimes are not expressible in most languages.
        pub payer_node_id: &'a NodeId,
 }
 
@@ -1158,7 +1158,7 @@ pub struct PublicHopCandidate<'a> {
        /// Information about the channel, including potentially its capacity and
        /// direction-specific information.
        ///
-       /// This is not exported to bindings users as lifetimes are not expressable in most languages.
+       /// This is not exported to bindings users as lifetimes are not expressible in most languages.
        pub info: DirectedChannelInfo<'a>,
        /// The short channel ID of the channel, i.e. the identifier by which we refer to this
        /// channel.
@@ -1170,21 +1170,26 @@ pub struct PublicHopCandidate<'a> {
 pub struct PrivateHopCandidate<'a> {
        /// Information about the private hop communicated via BOLT 11.
        ///
-       /// This is not exported to bindings users as lifetimes are not expressable in most languages.
+       /// This is not exported to bindings users as lifetimes are not expressible in most languages.
        pub hint: &'a RouteHintHop,
        /// Node id of the next hop in BOLT 11 route hint.
        ///
-       /// This is not exported to bindings users as lifetimes are not expressable in most languages.
+       /// This is not exported to bindings users as lifetimes are not expressible in most languages.
        pub target_node_id: &'a NodeId
 }
 
 /// A [`CandidateRouteHop::Blinded`] entry.
 #[derive(Clone, Debug)]
 pub struct BlindedPathCandidate<'a> {
+       /// The node id of the introduction node, resolved from either the [`NetworkGraph`] or first
+       /// hops.
+       ///
+       /// This is not exported to bindings users as lifetimes are not expressible in most languages.
+       pub source_node_id: &'a NodeId,
        /// Information about the blinded path including the fee, HTLC amount limits, and
        /// cryptographic material required to build an HTLC through the given path.
        ///
-       /// This is not exported to bindings users as lifetimes are not expressable in most languages.
+       /// This is not exported to bindings users as lifetimes are not expressible in most languages.
        pub hint: &'a (BlindedPayInfo, BlindedPath),
        /// Index of the hint in the original list of blinded hints.
        ///
@@ -1196,12 +1201,17 @@ pub struct BlindedPathCandidate<'a> {
 /// A [`CandidateRouteHop::OneHopBlinded`] entry.
 #[derive(Clone, Debug)]
 pub struct OneHopBlindedPathCandidate<'a> {
+       /// The node id of the introduction node, resolved from either the [`NetworkGraph`] or first
+       /// hops.
+       ///
+       /// This is not exported to bindings users as lifetimes are not expressible in most languages.
+       pub source_node_id: &'a NodeId,
        /// Information about the blinded path including the fee, HTLC amount limits, and
        /// cryptographic material required to build an HTLC terminating with the given path.
        ///
        /// Note that the [`BlindedPayInfo`] is ignored here.
        ///
-       /// This is not exported to bindings users as lifetimes are not expressable in most languages.
+       /// This is not exported to bindings users as lifetimes are not expressible in most languages.
        pub hint: &'a (BlindedPayInfo, BlindedPath),
        /// Index of the hint in the original list of blinded hints.
        ///
@@ -1409,8 +1419,8 @@ impl<'a> CandidateRouteHop<'a> {
                        CandidateRouteHop::FirstHop(hop) => *hop.payer_node_id,
                        CandidateRouteHop::PublicHop(hop) => *hop.info.source(),
                        CandidateRouteHop::PrivateHop(hop) => hop.hint.src_node_id.into(),
-                       CandidateRouteHop::Blinded(hop) => hop.hint.1.introduction_node_id.into(),
-                       CandidateRouteHop::OneHopBlinded(hop) => hop.hint.1.introduction_node_id.into(),
+                       CandidateRouteHop::Blinded(hop) => *hop.source_node_id,
+                       CandidateRouteHop::OneHopBlinded(hop) => *hop.source_node_id,
                }
        }
        /// Returns the target node id of this hop, if known.
@@ -1725,8 +1735,20 @@ impl<'a> fmt::Display for LoggedCandidateHop<'a> {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                match self.0 {
                        CandidateRouteHop::Blinded(BlindedPathCandidate { hint, .. }) | CandidateRouteHop::OneHopBlinded(OneHopBlindedPathCandidate { hint, .. }) => {
-                               "blinded route hint with introduction node id ".fmt(f)?;
-                               hint.1.introduction_node_id.fmt(f)?;
+                               "blinded route hint with introduction node ".fmt(f)?;
+                               match &hint.1.introduction_node {
+                                       IntroductionNode::NodeId(pubkey) => write!(f, "id {}", pubkey)?,
+                                       IntroductionNode::DirectedShortChannelId(direction, scid) => {
+                                               match direction {
+                                                       Direction::NodeOne => {
+                                                               write!(f, "one on channel with SCID {}", scid)?;
+                                                       },
+                                                       Direction::NodeTwo => {
+                                                               write!(f, "two on channel with SCID {}", scid)?;
+                                                       },
+                                               }
+                                       }
+                               }
                                " and blinding point ".fmt(f)?;
                                hint.1.blinding_point.fmt(f)
                        },
@@ -1852,6 +1874,9 @@ where L::Target: Logger {
                return Err(LightningError{err: "Cannot send a payment of 0 msat".to_owned(), action: ErrorAction::IgnoreError});
        }
 
+       let introduction_node_id_cache = payment_params.payee.blinded_route_hints().iter()
+               .map(|(_, path)| path.public_introduction_node_id(network_graph))
+               .collect::<Vec<_>>();
        match &payment_params.payee {
                Payee::Clear { route_hints, node_id, .. } => {
                        for route in route_hints.iter() {
@@ -1863,17 +1888,19 @@ where L::Target: Logger {
                        }
                },
                Payee::Blinded { route_hints, .. } => {
-                       if route_hints.iter().all(|(_, path)| &path.introduction_node_id == our_node_pubkey) {
+                       if introduction_node_id_cache.iter().all(|introduction_node_id| *introduction_node_id == Some(&our_node_id)) {
                                return Err(LightningError{err: "Cannot generate a route to blinded paths if we are the introduction node to all of them".to_owned(), action: ErrorAction::IgnoreError});
                        }
-                       for (_, blinded_path) in route_hints.iter() {
+                       for ((_, blinded_path), introduction_node_id) in route_hints.iter().zip(introduction_node_id_cache.iter()) {
                                if blinded_path.blinded_hops.len() == 0 {
                                        return Err(LightningError{err: "0-hop blinded path provided".to_owned(), action: ErrorAction::IgnoreError});
-                               } else if &blinded_path.introduction_node_id == our_node_pubkey {
+                               } else if *introduction_node_id == Some(&our_node_id) {
                                        log_info!(logger, "Got blinded path with ourselves as the introduction node, ignoring");
                                } else if blinded_path.blinded_hops.len() == 1 &&
-                                       route_hints.iter().any( |(_, p)| p.blinded_hops.len() == 1
-                                               && p.introduction_node_id != blinded_path.introduction_node_id)
+                                       route_hints
+                                               .iter().zip(introduction_node_id_cache.iter())
+                                               .filter(|((_, p), _)| p.blinded_hops.len() == 1)
+                                               .any(|(_, p_introduction_node_id)| p_introduction_node_id != introduction_node_id)
                                {
                                        return Err(LightningError{err: format!("1-hop blinded paths must all have matching introduction node ids"), action: ErrorAction::IgnoreError});
                                }
@@ -2515,26 +2542,53 @@ where L::Target: Logger {
                // earlier than general path finding, they will be somewhat prioritized, although currently
                // it matters only if the fees are exactly the same.
                for (hint_idx, hint) in payment_params.payee.blinded_route_hints().iter().enumerate() {
-                       let intro_node_id = NodeId::from_pubkey(&hint.1.introduction_node_id);
-                       let have_intro_node_in_graph =
-                               // Only add the hops in this route to our candidate set if either
-                               // we have a direct channel to the first hop or the first hop is
-                               // in the regular network graph.
-                               first_hop_targets.get(&intro_node_id).is_some() ||
-                               network_nodes.get(&intro_node_id).is_some();
-                       if !have_intro_node_in_graph || our_node_id == intro_node_id { continue }
+                       // Only add the hops in this route to our candidate set if either
+                       // we have a direct channel to the first hop or the first hop is
+                       // in the regular network graph.
+                       let source_node_id = match introduction_node_id_cache[hint_idx] {
+                               Some(node_id) => node_id,
+                               None => match &hint.1.introduction_node {
+                                       IntroductionNode::NodeId(pubkey) => {
+                                               let node_id = NodeId::from_pubkey(&pubkey);
+                                               match first_hop_targets.get_key_value(&node_id).map(|(key, _)| key) {
+                                                       Some(node_id) => node_id,
+                                                       None => continue,
+                                               }
+                                       },
+                                       IntroductionNode::DirectedShortChannelId(direction, scid) => {
+                                               let first_hop = first_hop_targets.iter().find(|(_, channels)|
+                                                       channels
+                                                               .iter()
+                                                               .any(|details| Some(*scid) == details.get_outbound_payment_scid())
+                                               );
+                                               match first_hop {
+                                                       Some((counterparty_node_id, _)) => {
+                                                               direction.select_node_id(&our_node_id, counterparty_node_id)
+                                                       },
+                                                       None => continue,
+                                               }
+                                       },
+                               },
+                       };
+                       if our_node_id == *source_node_id { continue }
                        let candidate = if hint.1.blinded_hops.len() == 1 {
-                               CandidateRouteHop::OneHopBlinded(OneHopBlindedPathCandidate { hint, hint_idx })
-                       } else { CandidateRouteHop::Blinded(BlindedPathCandidate { hint, hint_idx }) };
+                               CandidateRouteHop::OneHopBlinded(
+                                       OneHopBlindedPathCandidate { source_node_id, hint, hint_idx }
+                               )
+                       } else {
+                               CandidateRouteHop::Blinded(BlindedPathCandidate { source_node_id, hint, hint_idx })
+                       };
                        let mut path_contribution_msat = path_value_msat;
                        if let Some(hop_used_msat) = add_entry!(&candidate,
                                0, path_contribution_msat, 0, 0_u64, 0, 0)
                        {
                                path_contribution_msat = hop_used_msat;
                        } else { continue }
-                       if let Some(first_channels) = first_hop_targets.get_mut(&NodeId::from_pubkey(&hint.1.introduction_node_id)) {
-                               sort_first_hop_channels(first_channels, &used_liquidities, recommended_value_msat,
-                                       our_node_pubkey);
+                       if let Some(first_channels) = first_hop_targets.get(source_node_id) {
+                               let mut first_channels = first_channels.clone();
+                               sort_first_hop_channels(
+                                       &mut first_channels, &used_liquidities, recommended_value_msat, our_node_pubkey
+                               );
                                for details in first_channels {
                                        let first_hop_candidate = CandidateRouteHop::FirstHop(FirstHopCandidate {
                                                details, payer_node_id: &our_node_id,
@@ -2630,9 +2684,11 @@ where L::Target: Logger {
                                                .saturating_add(1);
 
                                        // Searching for a direct channel between last checked hop and first_hop_targets
-                                       if let Some(first_channels) = first_hop_targets.get_mut(target) {
-                                               sort_first_hop_channels(first_channels, &used_liquidities,
-                                                       recommended_value_msat, our_node_pubkey);
+                                       if let Some(first_channels) = first_hop_targets.get(target) {
+                                               let mut first_channels = first_channels.clone();
+                                               sort_first_hop_channels(
+                                                       &mut first_channels, &used_liquidities, recommended_value_msat, our_node_pubkey
+                                               );
                                                for details in first_channels {
                                                        let first_hop_candidate = CandidateRouteHop::FirstHop(FirstHopCandidate {
                                                                details, payer_node_id: &our_node_id,
@@ -2677,9 +2733,11 @@ where L::Target: Logger {
                                                // Note that we *must* check if the last hop was added as `add_entry`
                                                // always assumes that the third argument is a node to which we have a
                                                // path.
-                                               if let Some(first_channels) = first_hop_targets.get_mut(&NodeId::from_pubkey(&hop.src_node_id)) {
-                                                       sort_first_hop_channels(first_channels, &used_liquidities,
-                                                               recommended_value_msat, our_node_pubkey);
+                                               if let Some(first_channels) = first_hop_targets.get(&NodeId::from_pubkey(&hop.src_node_id)) {
+                                                       let mut first_channels = first_channels.clone();
+                                                       sort_first_hop_channels(
+                                                               &mut first_channels, &used_liquidities, recommended_value_msat, our_node_pubkey
+                                                       );
                                                        for details in first_channels {
                                                                let first_hop_candidate = CandidateRouteHop::FirstHop(FirstHopCandidate {
                                                                        details, payer_node_id: &our_node_id,
@@ -3223,7 +3281,7 @@ fn build_route_from_hops_internal<L: Deref>(
 
 #[cfg(test)]
 mod tests {
-       use crate::blinded_path::{BlindedHop, BlindedPath};
+       use crate::blinded_path::{BlindedHop, BlindedPath, IntroductionNode};
        use crate::routing::gossip::{NetworkGraph, P2PGossipSync, NodeId, EffectiveCapacity};
        use crate::routing::utxo::UtxoResult;
        use crate::routing::router::{get_route, build_route_from_hops_internal, add_random_cltv_offset, default_node_features,
@@ -5090,7 +5148,7 @@ mod tests {
                // MPP to a 1-hop blinded path for nodes[2]
                let bolt12_features = channelmanager::provided_bolt12_invoice_features(&config);
                let blinded_path = BlindedPath {
-                       introduction_node_id: nodes[2],
+                       introduction_node: IntroductionNode::NodeId(nodes[2]),
                        blinding_point: ln_test_utils::pubkey(42),
                        blinded_hops: vec![BlindedHop { blinded_node_id: ln_test_utils::pubkey(42 as u8), encrypted_payload: Vec::new() }],
                };
@@ -5108,18 +5166,18 @@ mod tests {
 
                // MPP to 3 2-hop blinded paths
                let mut blinded_path_node_0 = blinded_path.clone();
-               blinded_path_node_0.introduction_node_id = nodes[0];
+               blinded_path_node_0.introduction_node = IntroductionNode::NodeId(nodes[0]);
                blinded_path_node_0.blinded_hops.push(blinded_path.blinded_hops[0].clone());
                let mut node_0_payinfo = blinded_payinfo.clone();
                node_0_payinfo.htlc_maximum_msat = 50_000;
 
                let mut blinded_path_node_7 = blinded_path_node_0.clone();
-               blinded_path_node_7.introduction_node_id = nodes[7];
+               blinded_path_node_7.introduction_node = IntroductionNode::NodeId(nodes[7]);
                let mut node_7_payinfo = blinded_payinfo.clone();
                node_7_payinfo.htlc_maximum_msat = 60_000;
 
                let mut blinded_path_node_1 = blinded_path_node_0.clone();
-               blinded_path_node_1.introduction_node_id = nodes[1];
+               blinded_path_node_1.introduction_node = IntroductionNode::NodeId(nodes[1]);
                let mut node_1_payinfo = blinded_payinfo.clone();
                node_1_payinfo.htlc_maximum_msat = 180_000;
 
@@ -5301,10 +5359,15 @@ mod tests {
                                if let Some(bt) = &path.blinded_tail {
                                        assert_eq!(path.hops.len() + if bt.hops.len() == 1 { 0 } else { 1 }, 2);
                                        if bt.hops.len() > 1 {
-                                               assert_eq!(path.hops.last().unwrap().pubkey,
+                                               let network_graph = network_graph.read_only();
+                                               assert_eq!(
+                                                       NodeId::from_pubkey(&path.hops.last().unwrap().pubkey),
                                                        payment_params.payee.blinded_route_hints().iter()
                                                                .find(|(p, _)| p.htlc_maximum_msat == path.final_value_msat())
-                                                               .map(|(_, p)| p.introduction_node_id).unwrap());
+                                                               .and_then(|(_, p)| p.public_introduction_node_id(&network_graph))
+                                                               .copied()
+                                                               .unwrap()
+                                               );
                                        } else {
                                                assert_eq!(path.hops.last().unwrap().pubkey, nodes[2]);
                                        }
@@ -7200,7 +7263,7 @@ mod tests {
 
                // Make sure this works for blinded route hints.
                let blinded_path = BlindedPath {
-                       introduction_node_id: intermed_node_id,
+                       introduction_node: IntroductionNode::NodeId(intermed_node_id),
                        blinding_point: ln_test_utils::pubkey(42),
                        blinded_hops: vec![
                                BlindedHop { blinded_node_id: ln_test_utils::pubkey(42), encrypted_payload: vec![] },
@@ -7234,7 +7297,7 @@ mod tests {
        #[test]
        fn blinded_route_ser() {
                let blinded_path_1 = BlindedPath {
-                       introduction_node_id: ln_test_utils::pubkey(42),
+                       introduction_node: IntroductionNode::NodeId(ln_test_utils::pubkey(42)),
                        blinding_point: ln_test_utils::pubkey(43),
                        blinded_hops: vec![
                                BlindedHop { blinded_node_id: ln_test_utils::pubkey(44), encrypted_payload: Vec::new() },
@@ -7242,7 +7305,7 @@ mod tests {
                        ],
                };
                let blinded_path_2 = BlindedPath {
-                       introduction_node_id: ln_test_utils::pubkey(46),
+                       introduction_node: IntroductionNode::NodeId(ln_test_utils::pubkey(46)),
                        blinding_point: ln_test_utils::pubkey(47),
                        blinded_hops: vec![
                                BlindedHop { blinded_node_id: ln_test_utils::pubkey(48), encrypted_payload: Vec::new() },
@@ -7301,7 +7364,7 @@ mod tests {
                // account for the blinded tail's final amount_msat.
                let mut inflight_htlcs = InFlightHtlcs::new();
                let blinded_path = BlindedPath {
-                       introduction_node_id: ln_test_utils::pubkey(43),
+                       introduction_node: IntroductionNode::NodeId(ln_test_utils::pubkey(43)),
                        blinding_point: ln_test_utils::pubkey(48),
                        blinded_hops: vec![BlindedHop { blinded_node_id: ln_test_utils::pubkey(49), encrypted_payload: Vec::new() }],
                };
@@ -7316,7 +7379,7 @@ mod tests {
                                maybe_announced_channel: false,
                        },
                        RouteHop {
-                               pubkey: blinded_path.introduction_node_id,
+                               pubkey: ln_test_utils::pubkey(43),
                                node_features: NodeFeatures::empty(),
                                short_channel_id: 43,
                                channel_features: ChannelFeatures::empty(),
@@ -7340,7 +7403,7 @@ mod tests {
        fn blinded_path_cltv_shadow_offset() {
                // Make sure we add a shadow offset when sending to blinded paths.
                let blinded_path = BlindedPath {
-                       introduction_node_id: ln_test_utils::pubkey(43),
+                       introduction_node: IntroductionNode::NodeId(ln_test_utils::pubkey(43)),
                        blinding_point: ln_test_utils::pubkey(44),
                        blinded_hops: vec![
                                BlindedHop { blinded_node_id: ln_test_utils::pubkey(45), encrypted_payload: Vec::new() },
@@ -7358,7 +7421,7 @@ mod tests {
                                maybe_announced_channel: false,
                        },
                        RouteHop {
-                               pubkey: blinded_path.introduction_node_id,
+                               pubkey: ln_test_utils::pubkey(43),
                                node_features: NodeFeatures::empty(),
                                short_channel_id: 43,
                                channel_features: ChannelFeatures::empty(),
@@ -7400,7 +7463,7 @@ mod tests {
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
 
                let mut blinded_path = BlindedPath {
-                       introduction_node_id: nodes[2],
+                       introduction_node: IntroductionNode::NodeId(nodes[2]),
                        blinding_point: ln_test_utils::pubkey(42),
                        blinded_hops: Vec::with_capacity(num_blinded_hops),
                };
@@ -7432,7 +7495,10 @@ mod tests {
                assert_eq!(tail.final_value_msat, 1001);
 
                let final_hop = route.paths[0].hops.last().unwrap();
-               assert_eq!(final_hop.pubkey, blinded_path.introduction_node_id);
+               assert_eq!(
+                       NodeId::from_pubkey(&final_hop.pubkey),
+                       *blinded_path.public_introduction_node_id(&network_graph).unwrap()
+               );
                if tail.hops.len() > 1 {
                        assert_eq!(final_hop.fee_msat,
                                blinded_payinfo.fee_base_msat as u64 + blinded_payinfo.fee_proportional_millionths as u64 * tail.final_value_msat / 1000000);
@@ -7455,7 +7521,7 @@ mod tests {
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
 
                let mut invalid_blinded_path = BlindedPath {
-                       introduction_node_id: nodes[2],
+                       introduction_node: IntroductionNode::NodeId(nodes[2]),
                        blinding_point: ln_test_utils::pubkey(42),
                        blinded_hops: vec![
                                BlindedHop { blinded_node_id: ln_test_utils::pubkey(43), encrypted_payload: vec![0; 43] },
@@ -7471,7 +7537,7 @@ mod tests {
                };
 
                let mut invalid_blinded_path_2 = invalid_blinded_path.clone();
-               invalid_blinded_path_2.introduction_node_id = ln_test_utils::pubkey(45);
+               invalid_blinded_path_2.introduction_node = IntroductionNode::NodeId(ln_test_utils::pubkey(45));
                let payment_params = PaymentParameters::blinded(vec![
                        (blinded_payinfo.clone(), invalid_blinded_path.clone()),
                        (blinded_payinfo.clone(), invalid_blinded_path_2)]);
@@ -7485,7 +7551,7 @@ mod tests {
                        _ => panic!("Expected error")
                }
 
-               invalid_blinded_path.introduction_node_id = our_id;
+               invalid_blinded_path.introduction_node = IntroductionNode::NodeId(our_id);
                let payment_params = PaymentParameters::blinded(vec![(blinded_payinfo.clone(), invalid_blinded_path.clone())]);
                let route_params = RouteParameters::from_payment_params_and_value(payment_params, 1001);
                match get_route(&our_id, &route_params, &network_graph, None, Arc::clone(&logger), &scorer,
@@ -7497,7 +7563,7 @@ mod tests {
                        _ => panic!("Expected error")
                }
 
-               invalid_blinded_path.introduction_node_id = ln_test_utils::pubkey(46);
+               invalid_blinded_path.introduction_node = IntroductionNode::NodeId(ln_test_utils::pubkey(46));
                invalid_blinded_path.blinded_hops.clear();
                let payment_params = PaymentParameters::blinded(vec![(blinded_payinfo, invalid_blinded_path)]);
                let route_params = RouteParameters::from_payment_params_and_value(payment_params, 1001);
@@ -7526,7 +7592,7 @@ mod tests {
 
                let bolt12_features = channelmanager::provided_bolt12_invoice_features(&config);
                let blinded_path_1 = BlindedPath {
-                       introduction_node_id: nodes[2],
+                       introduction_node: IntroductionNode::NodeId(nodes[2]),
                        blinding_point: ln_test_utils::pubkey(42),
                        blinded_hops: vec![
                                BlindedHop { blinded_node_id: ln_test_utils::pubkey(42 as u8), encrypted_payload: Vec::new() },
@@ -7623,7 +7689,7 @@ mod tests {
                        get_channel_details(Some(1), nodes[1], InitFeatures::from_le_bytes(vec![0b11]), 10_000_000)];
 
                let blinded_path = BlindedPath {
-                       introduction_node_id: nodes[1],
+                       introduction_node: IntroductionNode::NodeId(nodes[1]),
                        blinding_point: ln_test_utils::pubkey(42),
                        blinded_hops: vec![
                                BlindedHop { blinded_node_id: ln_test_utils::pubkey(42 as u8), encrypted_payload: Vec::new() },
@@ -7692,7 +7758,7 @@ mod tests {
                                18446744073709551615)];
 
                let blinded_path = BlindedPath {
-                       introduction_node_id: nodes[1],
+                       introduction_node: IntroductionNode::NodeId(nodes[1]),
                        blinding_point: ln_test_utils::pubkey(42),
                        blinded_hops: vec![
                                BlindedHop { blinded_node_id: ln_test_utils::pubkey(42 as u8), encrypted_payload: Vec::new() },
@@ -7748,7 +7814,7 @@ mod tests {
                let amt_msat = 21_7020_5185_1423_0019;
 
                let blinded_path = BlindedPath {
-                       introduction_node_id: our_id,
+                       introduction_node: IntroductionNode::NodeId(our_id),
                        blinding_point: ln_test_utils::pubkey(42),
                        blinded_hops: vec![
                                BlindedHop { blinded_node_id: ln_test_utils::pubkey(42 as u8), encrypted_payload: Vec::new() },
@@ -7767,7 +7833,7 @@ mod tests {
                        (blinded_payinfo.clone(), blinded_path.clone()),
                        (blinded_payinfo.clone(), blinded_path.clone()),
                ];
-               blinded_hints[1].1.introduction_node_id = nodes[6];
+               blinded_hints[1].1.introduction_node = IntroductionNode::NodeId(nodes[6]);
 
                let bolt12_features = channelmanager::provided_bolt12_invoice_features(&config);
                let payment_params = PaymentParameters::blinded(blinded_hints.clone())
@@ -7800,7 +7866,7 @@ mod tests {
                let amt_msat = 21_7020_5185_1423_0019;
 
                let blinded_path = BlindedPath {
-                       introduction_node_id: our_id,
+                       introduction_node: IntroductionNode::NodeId(our_id),
                        blinding_point: ln_test_utils::pubkey(42),
                        blinded_hops: vec![
                                BlindedHop { blinded_node_id: ln_test_utils::pubkey(42 as u8), encrypted_payload: Vec::new() },
@@ -7824,7 +7890,7 @@ mod tests {
                blinded_hints[1].0.htlc_minimum_msat = 21_7020_5185_1423_0019;
                blinded_hints[1].0.htlc_maximum_msat = 1844_6744_0737_0955_1615;
 
-               blinded_hints[2].1.introduction_node_id = nodes[6];
+               blinded_hints[2].1.introduction_node = IntroductionNode::NodeId(nodes[6]);
 
                let bolt12_features = channelmanager::provided_bolt12_invoice_features(&config);
                let payment_params = PaymentParameters::blinded(blinded_hints.clone())
@@ -7871,7 +7937,7 @@ mod tests {
                let htlc_min = 2_5165_8240;
                let payment_params = if blinded_payee {
                        let blinded_path = BlindedPath {
-                               introduction_node_id: nodes[0],
+                               introduction_node: IntroductionNode::NodeId(nodes[0]),
                                blinding_point: ln_test_utils::pubkey(42),
                                blinded_hops: vec![
                                        BlindedHop { blinded_node_id: ln_test_utils::pubkey(42 as u8), encrypted_payload: Vec::new() },
@@ -7951,7 +8017,7 @@ mod tests {
                let htlc_mins = [1_4392, 19_7401, 1027, 6_5535];
                let payment_params = if blinded_payee {
                        let blinded_path = BlindedPath {
-                               introduction_node_id: nodes[0],
+                               introduction_node: IntroductionNode::NodeId(nodes[0]),
                                blinding_point: ln_test_utils::pubkey(42),
                                blinded_hops: vec![
                                        BlindedHop { blinded_node_id: ln_test_utils::pubkey(42 as u8), encrypted_payload: Vec::new() },
@@ -8052,7 +8118,7 @@ mod tests {
                                cltv_expiry_delta: 10,
                                features: BlindedHopFeatures::empty(),
                        }, BlindedPath {
-                               introduction_node_id: nodes[0],
+                               introduction_node: IntroductionNode::NodeId(nodes[0]),
                                blinding_point: ln_test_utils::pubkey(42),
                                blinded_hops: vec![
                                        BlindedHop { blinded_node_id: ln_test_utils::pubkey(42 as u8), encrypted_payload: Vec::new() },
@@ -8102,7 +8168,7 @@ mod tests {
                let htlc_mins = [49_0000, 1125_0000];
                let payment_params = {
                        let blinded_path = BlindedPath {
-                               introduction_node_id: nodes[0],
+                               introduction_node: IntroductionNode::NodeId(nodes[0]),
                                blinding_point: ln_test_utils::pubkey(42),
                                blinded_hops: vec![
                                        BlindedHop { blinded_node_id: ln_test_utils::pubkey(42 as u8), encrypted_payload: Vec::new() },
index 4850479b8992905cc20e6d1e219dab6251081b28..4cb9144d3394ed94896b861e5ac847994776f122 100644 (file)
@@ -2152,7 +2152,7 @@ impl Readable for ChannelLiquidity {
 #[cfg(test)]
 mod tests {
        use super::{ChannelLiquidity, HistoricalBucketRangeTracker, ProbabilisticScoringFeeParameters, ProbabilisticScoringDecayParameters, ProbabilisticScorer};
-       use crate::blinded_path::{BlindedHop, BlindedPath};
+       use crate::blinded_path::{BlindedHop, BlindedPath, IntroductionNode};
        use crate::util::config::UserConfig;
 
        use crate::ln::channelmanager;
@@ -3567,7 +3567,7 @@ mod tests {
                let mut path = payment_path_for_amount(768);
                let recipient_hop = path.hops.pop().unwrap();
                let blinded_path = BlindedPath {
-                       introduction_node_id: path.hops.last().as_ref().unwrap().pubkey,
+                       introduction_node: IntroductionNode::NodeId(path.hops.last().as_ref().unwrap().pubkey),
                        blinding_point: test_utils::pubkey(42),
                        blinded_hops: vec![
                                BlindedHop { blinded_node_id: test_utils::pubkey(44), encrypted_payload: Vec::new() }
index 23266c13bd473cc78007f0b7eff8e3b277d1419a..9b5efee4b9695cc10a3df1cef99e4d4328dc05b1 100644 (file)
@@ -805,6 +805,28 @@ pub trait NodeSigner {
        fn sign_gossip_message(&self, msg: UnsignedGossipMessage) -> Result<Signature, ()>;
 }
 
+/// A trait that describes a wallet capable of creating a spending [`Transaction`] from a set of
+/// [`SpendableOutputDescriptor`]s.
+pub trait OutputSpender {
+       /// Creates a [`Transaction`] which spends the given descriptors to the given outputs, plus an
+       /// output to the given change destination (if sufficient change value remains). The
+       /// transaction will have a feerate, at least, of the given value.
+       ///
+       /// The `locktime` argument is used to set the transaction's locktime. If `None`, the
+       /// transaction will have a locktime of 0. It it recommended to set this to the current block
+       /// height to avoid fee sniping, unless you have some specific reason to use a different
+       /// locktime.
+       ///
+       /// Returns `Err(())` if the output value is greater than the input value minus required fee,
+       /// if a descriptor was duplicated, or if an output descriptor `script_pubkey`
+       /// does not match the one we can spend.
+       fn spend_spendable_outputs<C: Signing>(
+               &self, descriptors: &[&SpendableOutputDescriptor], outputs: Vec<TxOut>,
+               change_destination_script: ScriptBuf, feerate_sat_per_1000_weight: u32,
+               locktime: Option<LockTime>, secp_ctx: &Secp256k1<C>,
+       ) -> Result<Transaction, ()>;
+}
+
 // Primarily needed in doctests because of https://github.com/rust-lang/rust/issues/67295
 /// A dynamic [`SignerProvider`] temporarily needed for doc tests.
 #[cfg(taproot)]
@@ -882,6 +904,17 @@ pub trait SignerProvider {
        fn get_shutdown_scriptpubkey(&self) -> Result<ShutdownScript, ()>;
 }
 
+/// A helper trait that describes an on-chain wallet capable of returning a (change) destination
+/// script.
+pub trait ChangeDestinationSource {
+       /// Returns a script pubkey which can be used as a change destination for
+       /// [`OutputSpender::spend_spendable_outputs`].
+       ///
+       /// This method should return a different value each time it is called, to avoid linking
+       /// on-chain funds controlled to the same user.
+       fn get_change_destination_script(&self) -> Result<ScriptBuf, ()>;
+}
+
 /// A simple implementation of [`WriteableEcdsaChannelSigner`] that just keeps the private keys in memory.
 ///
 /// This implementation performs no policy checks and is insufficient by itself as
@@ -1991,50 +2024,6 @@ impl KeysManager {
 
                Ok(psbt)
        }
-
-       /// Creates a [`Transaction`] which spends the given descriptors to the given outputs, plus an
-       /// output to the given change destination (if sufficient change value remains). The
-       /// transaction will have a feerate, at least, of the given value.
-       ///
-       /// The `locktime` argument is used to set the transaction's locktime. If `None`, the
-       /// transaction will have a locktime of 0. It it recommended to set this to the current block
-       /// height to avoid fee sniping, unless you have some specific reason to use a different
-       /// locktime.
-       ///
-       /// Returns `Err(())` if the output value is greater than the input value minus required fee,
-       /// if a descriptor was duplicated, or if an output descriptor `script_pubkey`
-       /// does not match the one we can spend.
-       ///
-       /// We do not enforce that outputs meet the dust limit or that any output scripts are standard.
-       ///
-       /// May panic if the [`SpendableOutputDescriptor`]s were not generated by channels which used
-       /// this [`KeysManager`] or one of the [`InMemorySigner`] created by this [`KeysManager`].
-       pub fn spend_spendable_outputs<C: Signing>(
-               &self, descriptors: &[&SpendableOutputDescriptor], outputs: Vec<TxOut>,
-               change_destination_script: ScriptBuf, feerate_sat_per_1000_weight: u32,
-               locktime: Option<LockTime>, secp_ctx: &Secp256k1<C>,
-       ) -> Result<Transaction, ()> {
-               let (mut psbt, expected_max_weight) =
-                       SpendableOutputDescriptor::create_spendable_outputs_psbt(
-                               descriptors,
-                               outputs,
-                               change_destination_script,
-                               feerate_sat_per_1000_weight,
-                               locktime,
-                       )?;
-               psbt = self.sign_spendable_outputs_psbt(descriptors, psbt, secp_ctx)?;
-
-               let spend_tx = psbt.extract_tx();
-
-               debug_assert!(expected_max_weight >= spend_tx.weight().to_wu());
-               // Note that witnesses with a signature vary somewhat in size, so allow
-               // `expected_max_weight` to overshoot by up to 3 bytes per input.
-               debug_assert!(
-                       expected_max_weight <= spend_tx.weight().to_wu() + descriptors.len() as u64 * 3
-               );
-
-               Ok(spend_tx)
-       }
 }
 
 impl EntropySource for KeysManager {
@@ -2106,6 +2095,44 @@ impl NodeSigner for KeysManager {
        }
 }
 
+impl OutputSpender for KeysManager {
+       /// Creates a [`Transaction`] which spends the given descriptors to the given outputs, plus an
+       /// output to the given change destination (if sufficient change value remains).
+       ///
+       /// See [`OutputSpender::spend_spendable_outputs`] documentation for more information.
+       ///
+       /// We do not enforce that outputs meet the dust limit or that any output scripts are standard.
+       ///
+       /// May panic if the [`SpendableOutputDescriptor`]s were not generated by channels which used
+       /// this [`KeysManager`] or one of the [`InMemorySigner`] created by this [`KeysManager`].
+       fn spend_spendable_outputs<C: Signing>(
+               &self, descriptors: &[&SpendableOutputDescriptor], outputs: Vec<TxOut>,
+               change_destination_script: ScriptBuf, feerate_sat_per_1000_weight: u32,
+               locktime: Option<LockTime>, secp_ctx: &Secp256k1<C>,
+       ) -> Result<Transaction, ()> {
+               let (mut psbt, expected_max_weight) =
+                       SpendableOutputDescriptor::create_spendable_outputs_psbt(
+                               descriptors,
+                               outputs,
+                               change_destination_script,
+                               feerate_sat_per_1000_weight,
+                               locktime,
+                       )?;
+               psbt = self.sign_spendable_outputs_psbt(descriptors, psbt, secp_ctx)?;
+
+               let spend_tx = psbt.extract_tx();
+
+               debug_assert!(expected_max_weight >= spend_tx.weight().to_wu());
+               // Note that witnesses with a signature vary somewhat in size, so allow
+               // `expected_max_weight` to overshoot by up to 3 bytes per input.
+               debug_assert!(
+                       expected_max_weight <= spend_tx.weight().to_wu() + descriptors.len() as u64 * 3
+               );
+
+               Ok(spend_tx)
+       }
+}
+
 impl SignerProvider for KeysManager {
        type EcdsaSigner = InMemorySigner;
        #[cfg(taproot)]
@@ -2238,6 +2265,25 @@ impl NodeSigner for PhantomKeysManager {
        }
 }
 
+impl OutputSpender for PhantomKeysManager {
+       /// See [`OutputSpender::spend_spendable_outputs`] and [`KeysManager::spend_spendable_outputs`]
+       /// for documentation on this method.
+       fn spend_spendable_outputs<C: Signing>(
+               &self, descriptors: &[&SpendableOutputDescriptor], outputs: Vec<TxOut>,
+               change_destination_script: ScriptBuf, feerate_sat_per_1000_weight: u32,
+               locktime: Option<LockTime>, secp_ctx: &Secp256k1<C>,
+       ) -> Result<Transaction, ()> {
+               self.inner.spend_spendable_outputs(
+                       descriptors,
+                       outputs,
+                       change_destination_script,
+                       feerate_sat_per_1000_weight,
+                       locktime,
+                       secp_ctx,
+               )
+       }
+}
+
 impl SignerProvider for PhantomKeysManager {
        type EcdsaSigner = InMemorySigner;
        #[cfg(taproot)]
@@ -2299,22 +2345,6 @@ impl PhantomKeysManager {
                }
        }
 
-       /// See [`KeysManager::spend_spendable_outputs`] for documentation on this method.
-       pub fn spend_spendable_outputs<C: Signing>(
-               &self, descriptors: &[&SpendableOutputDescriptor], outputs: Vec<TxOut>,
-               change_destination_script: ScriptBuf, feerate_sat_per_1000_weight: u32,
-               locktime: Option<LockTime>, secp_ctx: &Secp256k1<C>,
-       ) -> Result<Transaction, ()> {
-               self.inner.spend_spendable_outputs(
-                       descriptors,
-                       outputs,
-                       change_destination_script,
-                       feerate_sat_per_1000_weight,
-                       locktime,
-                       secp_ctx,
-               )
-       }
-
        /// See [`KeysManager::derive_channel_keys`] for documentation on this method.
        pub fn derive_channel_keys(
                &self, channel_value_satoshis: u64, params: &[u8; 32],
index 4f694bd2b6ec723e5dcaed614db36bdad0ea2aeb..97788ffe68acbd5f90d67bd501710c3b5942ce51 100644 (file)
@@ -56,6 +56,11 @@ impl<K: Clone + Hash + Ord, V> IndexedMap<K, V> {
                self.map.get_mut(key)
        }
 
+       /// Fetches the key-value pair corresponding to the supplied key, if one exists.
+       pub fn get_key_value(&self, key: &K) -> Option<(&K, &V)> {
+               self.map.get_key_value(key)
+       }
+
        #[inline]
        /// Returns true if an element with the given `key` exists in the map.
        pub fn contains_key(&self, key: &K) -> bool {
index 31bdf1ca53c2562836e49cb5721e5a0950f2c070..c1ab8c75c2ee70efb51e7bb214ee76a8d06bef03 100644 (file)
@@ -22,6 +22,7 @@ pub mod invoice;
 pub mod persist;
 pub mod scid_utils;
 pub mod string;
+pub mod sweep;
 pub mod wakers;
 #[cfg(fuzzing)]
 pub mod base32;
index 3f918935f16700008cae3b2854babeae60bbd740..249a089cd4883170be76bf8574855379ab48c643 100644 (file)
@@ -75,6 +75,20 @@ pub const SCORER_PERSISTENCE_SECONDARY_NAMESPACE: &str = "";
 /// The key under which the [`WriteableScore`] will be persisted.
 pub const SCORER_PERSISTENCE_KEY: &str = "scorer";
 
+/// The primary namespace under which [`OutputSweeper`] state will be persisted.
+///
+/// [`OutputSweeper`]: crate::util::sweep::OutputSweeper
+pub const OUTPUT_SWEEPER_PERSISTENCE_PRIMARY_NAMESPACE: &str = "";
+/// The secondary namespace under which [`OutputSweeper`] state will be persisted.
+///
+/// [`OutputSweeper`]: crate::util::sweep::OutputSweeper
+pub const OUTPUT_SWEEPER_PERSISTENCE_SECONDARY_NAMESPACE: &str = "";
+/// The secondary namespace under which [`OutputSweeper`] state will be persisted.
+/// The key under which [`OutputSweeper`] state will be persisted.
+///
+/// [`OutputSweeper`]: crate::util::sweep::OutputSweeper
+pub const OUTPUT_SWEEPER_PERSISTENCE_KEY: &str = "output_sweeper";
+
 /// A sentinel value to be prepended to monitors persisted by the [`MonitorUpdatingPersister`].
 ///
 /// This serves to prevent someone from accidentally loading such monitors (which may need
index df030d0b01eb4452ccf8e7b391e8c738bb9af453..740b7c12561ce8466d1a2e30f56faacc797961ac 100644 (file)
@@ -1148,7 +1148,7 @@ mod tests {
 
        use crate::io::{self, Cursor};
        use crate::ln::msgs::DecodeError;
-       use crate::util::ser::{Writeable, HighZeroBytesDroppedBigSize, VecWriter};
+       use crate::util::ser::{MaybeReadable, Readable, Writeable, HighZeroBytesDroppedBigSize, VecWriter};
        use bitcoin::hashes::hex::FromHex;
        use bitcoin::secp256k1::PublicKey;
 
@@ -1258,6 +1258,131 @@ mod tests {
                } else { panic!(); }
        }
 
+       /// A "V1" enum with only one variant
+       enum InnerEnumV1 {
+               StructVariantA {
+                       field: u32,
+               },
+       }
+
+       impl_writeable_tlv_based_enum_upgradable!(InnerEnumV1,
+               (0, StructVariantA) => {
+                       (0, field, required),
+               },
+       );
+
+       struct OuterStructOptionalEnumV1 {
+               inner_enum: Option<InnerEnumV1>,
+               other_field: u32,
+       }
+
+       impl_writeable_tlv_based!(OuterStructOptionalEnumV1, {
+               (0, inner_enum, upgradable_option),
+               (2, other_field, required),
+       });
+
+       /// An upgraded version of [`InnerEnumV1`] that added a second variant
+       enum InnerEnumV2 {
+               StructVariantA {
+                       field: u32,
+               },
+               StructVariantB {
+                       field2: u64,
+               }
+       }
+
+       impl_writeable_tlv_based_enum_upgradable!(InnerEnumV2,
+               (0, StructVariantA) => {
+                       (0, field, required),
+               },
+               (1, StructVariantB) => {
+                       (0, field2, required),
+               },
+       );
+
+       struct OuterStructOptionalEnumV2 {
+               inner_enum: Option<InnerEnumV2>,
+               other_field: u32,
+       }
+
+       impl_writeable_tlv_based!(OuterStructOptionalEnumV2, {
+               (0, inner_enum, upgradable_option),
+               (2, other_field, required),
+       });
+
+       #[test]
+       fn upgradable_enum_option() {
+               // Test downgrading from `OuterStructOptionalEnumV2` to `OuterStructOptionalEnumV1` and
+               // ensure we still read the `other_field` just fine.
+               let serialized_bytes = OuterStructOptionalEnumV2 {
+                       inner_enum: Some(InnerEnumV2::StructVariantB { field2: 64 }),
+                       other_field: 0x1bad1dea,
+               }.encode();
+               let mut s = Cursor::new(serialized_bytes);
+
+               let outer_struct: OuterStructOptionalEnumV1 = Readable::read(&mut s).unwrap();
+               assert!(outer_struct.inner_enum.is_none());
+               assert_eq!(outer_struct.other_field, 0x1bad1dea);
+       }
+
+       /// A struct that is read with an [`InnerEnumV1`] but is written with an [`InnerEnumV2`].
+       struct OuterStructRequiredEnum {
+               #[allow(unused)]
+               inner_enum: InnerEnumV1,
+       }
+
+       impl MaybeReadable for OuterStructRequiredEnum {
+               fn read<R: io::Read>(reader: &mut R) -> Result<Option<Self>, DecodeError> {
+                       let mut inner_enum = crate::util::ser::UpgradableRequired(None);
+                       read_tlv_fields!(reader, {
+                               (0, inner_enum, upgradable_required),
+                       });
+                       Ok(Some(Self {
+                               inner_enum: inner_enum.0.unwrap(),
+                       }))
+               }
+       }
+
+       impl Writeable for OuterStructRequiredEnum {
+               fn write<W: crate::util::ser::Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
+                       write_tlv_fields!(writer, {
+                               (0, InnerEnumV2::StructVariantB { field2: 0xdeadbeef }, required),
+                       });
+                       Ok(())
+               }
+       }
+
+       struct OuterOuterStruct {
+               outer_struct: Option<OuterStructRequiredEnum>,
+               other_field: u32,
+       }
+
+       impl_writeable_tlv_based!(OuterOuterStruct, {
+               (0, outer_struct, upgradable_option),
+               (2, other_field, required),
+       });
+
+
+       #[test]
+       fn upgradable_enum_required() {
+               // Test downgrading from an `OuterOuterStruct` (i.e. test downgrading an
+               // `upgradable_required` `InnerEnumV2` to an `InnerEnumV1`).
+               //
+               // Note that `OuterStructRequiredEnum` has a split write/read implementation that writes an
+               // `InnerEnumV2::StructVariantB` irrespective of the value of `inner_enum`.
+
+               let dummy_inner_enum = InnerEnumV1::StructVariantA { field: 42 };
+               let serialized_bytes = OuterOuterStruct {
+                       outer_struct: Some(OuterStructRequiredEnum { inner_enum: dummy_inner_enum }),
+                       other_field: 0x1bad1dea,
+               }.encode();
+               let mut s = Cursor::new(serialized_bytes);
+
+               let outer_outer_struct: OuterOuterStruct = Readable::read(&mut s).unwrap();
+               assert!(outer_outer_struct.outer_struct.is_none());
+               assert_eq!(outer_outer_struct.other_field, 0x1bad1dea);
+       }
+
        // BOLT TLV test cases
        fn tlv_reader_n1(s: &[u8]) -> Result<(Option<HighZeroBytesDroppedBigSize<u64>>, Option<u64>, Option<(PublicKey, u64, u64)>, Option<u16>), DecodeError> {
                let mut s = Cursor::new(s);
diff --git a/lightning/src/util/sweep.rs b/lightning/src/util/sweep.rs
new file mode 100644 (file)
index 0000000..59d3a08
--- /dev/null
@@ -0,0 +1,875 @@
+// This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
+// or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
+// You may not use this file except in accordance with one or both of these
+// licenses.
+
+//! This module contains an [`OutputSweeper`] utility that keeps track of
+//! [`SpendableOutputDescriptor`]s, i.e., persists them in a given [`KVStore`] and regularly retries
+//! sweeping them.
+
+use crate::chain::chaininterface::{BroadcasterInterface, ConfirmationTarget, FeeEstimator};
+use crate::chain::channelmonitor::ANTI_REORG_DELAY;
+use crate::chain::{self, BestBlock, Confirm, Filter, Listen, WatchedOutput};
+use crate::io;
+use crate::ln::msgs::DecodeError;
+use crate::ln::ChannelId;
+use crate::prelude::Vec;
+use crate::sign::{ChangeDestinationSource, OutputSpender, SpendableOutputDescriptor};
+use crate::sync::Mutex;
+use crate::util::logger::Logger;
+use crate::util::persist::{
+       KVStore, OUTPUT_SWEEPER_PERSISTENCE_KEY, OUTPUT_SWEEPER_PERSISTENCE_PRIMARY_NAMESPACE,
+       OUTPUT_SWEEPER_PERSISTENCE_SECONDARY_NAMESPACE,
+};
+use crate::util::ser::{Readable, ReadableArgs, Writeable};
+use crate::{impl_writeable_tlv_based, log_debug, log_error};
+
+use bitcoin::blockdata::block::Header;
+use bitcoin::blockdata::locktime::absolute::LockTime;
+use bitcoin::secp256k1::Secp256k1;
+use bitcoin::{BlockHash, Transaction, Txid};
+
+use core::ops::Deref;
+
+/// The state of a spendable output currently tracked by an [`OutputSweeper`].
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct TrackedSpendableOutput {
+       /// The tracked output descriptor.
+       pub descriptor: SpendableOutputDescriptor,
+       /// The channel this output belongs to.
+       ///
+       /// Will be `None` if no `channel_id` was given to [`OutputSweeper::track_spendable_outputs`]
+       pub channel_id: Option<ChannelId>,
+       /// The current status of the output spend.
+       pub status: OutputSpendStatus,
+}
+
+impl TrackedSpendableOutput {
+       fn to_watched_output(&self, cur_hash: BlockHash) -> WatchedOutput {
+               let block_hash = self.status.first_broadcast_hash().or(Some(cur_hash));
+               match &self.descriptor {
+                       SpendableOutputDescriptor::StaticOutput { outpoint, output, channel_keys_id: _ } => {
+                               WatchedOutput {
+                                       block_hash,
+                                       outpoint: *outpoint,
+                                       script_pubkey: output.script_pubkey.clone(),
+                               }
+                       },
+                       SpendableOutputDescriptor::DelayedPaymentOutput(output) => WatchedOutput {
+                               block_hash,
+                               outpoint: output.outpoint,
+                               script_pubkey: output.output.script_pubkey.clone(),
+                       },
+                       SpendableOutputDescriptor::StaticPaymentOutput(output) => WatchedOutput {
+                               block_hash,
+                               outpoint: output.outpoint,
+                               script_pubkey: output.output.script_pubkey.clone(),
+                       },
+               }
+       }
+
+       /// Returns whether the output is spent in the given transaction.
+       pub fn is_spent_in(&self, tx: &Transaction) -> bool {
+               let prev_outpoint = match &self.descriptor {
+                       SpendableOutputDescriptor::StaticOutput { outpoint, .. } => *outpoint,
+                       SpendableOutputDescriptor::DelayedPaymentOutput(output) => output.outpoint,
+                       SpendableOutputDescriptor::StaticPaymentOutput(output) => output.outpoint,
+               }
+               .into_bitcoin_outpoint();
+
+               tx.input.iter().any(|input| input.previous_output == prev_outpoint)
+       }
+}
+
+impl_writeable_tlv_based!(TrackedSpendableOutput, {
+       (0, descriptor, required),
+       (2, channel_id, option),
+       (4, status, required),
+});
+
+/// The current status of the output spend.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum OutputSpendStatus {
+       /// The output is tracked but an initial spending transaction hasn't been generated and
+       /// broadcasted yet.
+       PendingInitialBroadcast {
+               /// The height at which we will first generate and broadcast a spending transaction.
+               delayed_until_height: Option<u32>,
+       },
+       /// A transaction spending the output has been broadcasted but is pending its first confirmation on-chain.
+       PendingFirstConfirmation {
+               /// The hash of the chain tip when we first broadcast a transaction spending this output.
+               first_broadcast_hash: BlockHash,
+               /// The best height when we last broadcast a transaction spending this output.
+               latest_broadcast_height: u32,
+               /// The transaction spending this output we last broadcasted.
+               latest_spending_tx: Transaction,
+       },
+       /// A transaction spending the output has been confirmed on-chain but will be tracked until it
+       /// reaches [`ANTI_REORG_DELAY`] confirmations.
+       PendingThresholdConfirmations {
+               /// The hash of the chain tip when we first broadcast a transaction spending this output.
+               first_broadcast_hash: BlockHash,
+               /// The best height when we last broadcast a transaction spending this output.
+               latest_broadcast_height: u32,
+               /// The transaction spending this output we saw confirmed on-chain.
+               latest_spending_tx: Transaction,
+               /// The height at which the spending transaction was confirmed.
+               confirmation_height: u32,
+               /// The hash of the block in which the spending transaction was confirmed.
+               confirmation_hash: BlockHash,
+       },
+}
+
+impl OutputSpendStatus {
+       fn broadcast(&mut self, cur_hash: BlockHash, cur_height: u32, latest_spending_tx: Transaction) {
+               match self {
+                       Self::PendingInitialBroadcast { delayed_until_height } => {
+                               if let Some(delayed_until_height) = delayed_until_height {
+                                       debug_assert!(
+                                               cur_height >= *delayed_until_height,
+                                               "We should never broadcast before the required height is reached."
+                                       );
+                               }
+                               *self = Self::PendingFirstConfirmation {
+                                       first_broadcast_hash: cur_hash,
+                                       latest_broadcast_height: cur_height,
+                                       latest_spending_tx,
+                               };
+                       },
+                       Self::PendingFirstConfirmation { first_broadcast_hash, .. } => {
+                               *self = Self::PendingFirstConfirmation {
+                                       first_broadcast_hash: *first_broadcast_hash,
+                                       latest_broadcast_height: cur_height,
+                                       latest_spending_tx,
+                               };
+                       },
+                       Self::PendingThresholdConfirmations { .. } => {
+                               debug_assert!(false, "We should never rebroadcast confirmed transactions.");
+                       },
+               }
+       }
+
+       fn confirmed(
+               &mut self, confirmation_hash: BlockHash, confirmation_height: u32,
+               latest_spending_tx: Transaction,
+       ) {
+               match self {
+                       Self::PendingInitialBroadcast { .. } => {
+                               // Generally we can't see any of our transactions confirmed if they haven't been
+                               // broadcasted yet, so this should never be reachable via `transactions_confirmed`.
+                               debug_assert!(false, "We should never confirm when we haven't broadcasted. This a bug and should never happen, please report.");
+                               *self = Self::PendingThresholdConfirmations {
+                                       first_broadcast_hash: confirmation_hash,
+                                       latest_broadcast_height: confirmation_height,
+                                       latest_spending_tx,
+                                       confirmation_height,
+                                       confirmation_hash,
+                               };
+                       },
+                       Self::PendingFirstConfirmation {
+                               first_broadcast_hash,
+                               latest_broadcast_height,
+                               ..
+                       } => {
+                               debug_assert!(confirmation_height >= *latest_broadcast_height);
+                               *self = Self::PendingThresholdConfirmations {
+                                       first_broadcast_hash: *first_broadcast_hash,
+                                       latest_broadcast_height: *latest_broadcast_height,
+                                       latest_spending_tx,
+                                       confirmation_height,
+                                       confirmation_hash,
+                               };
+                       },
+                       Self::PendingThresholdConfirmations {
+                               first_broadcast_hash,
+                               latest_broadcast_height,
+                               ..
+                       } => {
+                               *self = Self::PendingThresholdConfirmations {
+                                       first_broadcast_hash: *first_broadcast_hash,
+                                       latest_broadcast_height: *latest_broadcast_height,
+                                       latest_spending_tx,
+                                       confirmation_height,
+                                       confirmation_hash,
+                               };
+                       },
+               }
+       }
+
+       fn unconfirmed(&mut self) {
+               match self {
+                       Self::PendingInitialBroadcast { .. } => {
+                               debug_assert!(
+                                       false,
+                                       "We should only mark a spend as unconfirmed if it used to be confirmed."
+                               );
+                       },
+                       Self::PendingFirstConfirmation { .. } => {
+                               debug_assert!(
+                                       false,
+                                       "We should only mark a spend as unconfirmed if it used to be confirmed."
+                               );
+                       },
+                       Self::PendingThresholdConfirmations {
+                               first_broadcast_hash,
+                               latest_broadcast_height,
+                               latest_spending_tx,
+                               ..
+                       } => {
+                               *self = Self::PendingFirstConfirmation {
+                                       first_broadcast_hash: *first_broadcast_hash,
+                                       latest_broadcast_height: *latest_broadcast_height,
+                                       latest_spending_tx: latest_spending_tx.clone(),
+                               };
+                       },
+               }
+       }
+
+       fn is_delayed(&self, cur_height: u32) -> bool {
+               match self {
+                       Self::PendingInitialBroadcast { delayed_until_height } => {
+                               delayed_until_height.map_or(false, |req_height| cur_height < req_height)
+                       },
+                       Self::PendingFirstConfirmation { .. } => false,
+                       Self::PendingThresholdConfirmations { .. } => false,
+               }
+       }
+
+       fn first_broadcast_hash(&self) -> Option<BlockHash> {
+               match self {
+                       Self::PendingInitialBroadcast { .. } => None,
+                       Self::PendingFirstConfirmation { first_broadcast_hash, .. } => {
+                               Some(*first_broadcast_hash)
+                       },
+                       Self::PendingThresholdConfirmations { first_broadcast_hash, .. } => {
+                               Some(*first_broadcast_hash)
+                       },
+               }
+       }
+
+       fn latest_broadcast_height(&self) -> Option<u32> {
+               match self {
+                       Self::PendingInitialBroadcast { .. } => None,
+                       Self::PendingFirstConfirmation { latest_broadcast_height, .. } => {
+                               Some(*latest_broadcast_height)
+                       },
+                       Self::PendingThresholdConfirmations { latest_broadcast_height, .. } => {
+                               Some(*latest_broadcast_height)
+                       },
+               }
+       }
+
+       fn confirmation_height(&self) -> Option<u32> {
+               match self {
+                       Self::PendingInitialBroadcast { .. } => None,
+                       Self::PendingFirstConfirmation { .. } => None,
+                       Self::PendingThresholdConfirmations { confirmation_height, .. } => {
+                               Some(*confirmation_height)
+                       },
+               }
+       }
+
+       fn confirmation_hash(&self) -> Option<BlockHash> {
+               match self {
+                       Self::PendingInitialBroadcast { .. } => None,
+                       Self::PendingFirstConfirmation { .. } => None,
+                       Self::PendingThresholdConfirmations { confirmation_hash, .. } => {
+                               Some(*confirmation_hash)
+                       },
+               }
+       }
+
+       fn latest_spending_tx(&self) -> Option<&Transaction> {
+               match self {
+                       Self::PendingInitialBroadcast { .. } => None,
+                       Self::PendingFirstConfirmation { latest_spending_tx, .. } => Some(latest_spending_tx),
+                       Self::PendingThresholdConfirmations { latest_spending_tx, .. } => {
+                               Some(latest_spending_tx)
+                       },
+               }
+       }
+
+       fn is_confirmed(&self) -> bool {
+               match self {
+                       Self::PendingInitialBroadcast { .. } => false,
+                       Self::PendingFirstConfirmation { .. } => false,
+                       Self::PendingThresholdConfirmations { .. } => true,
+               }
+       }
+}
+
+impl_writeable_tlv_based_enum!(OutputSpendStatus,
+       (0, PendingInitialBroadcast) => {
+               (0, delayed_until_height, option),
+       },
+       (2, PendingFirstConfirmation) => {
+               (0, first_broadcast_hash, required),
+               (2, latest_broadcast_height, required),
+               (4, latest_spending_tx, required),
+       },
+       (4, PendingThresholdConfirmations) => {
+               (0, first_broadcast_hash, required),
+               (2, latest_broadcast_height, required),
+               (4, latest_spending_tx, required),
+               (6, confirmation_height, required),
+               (8, confirmation_hash, required),
+       };
+);
+
+/// A utility that keeps track of [`SpendableOutputDescriptor`]s, persists them in a given
+/// [`KVStore`] and regularly retries sweeping them based on a callback given to the constructor
+/// methods.
+///
+/// Users should call [`Self::track_spendable_outputs`] for any [`SpendableOutputDescriptor`]s received via [`Event::SpendableOutputs`].
+///
+/// This needs to be notified of chain state changes either via its [`Listen`] or [`Confirm`]
+/// implementation and hence has to be connected with the utilized chain data sources.
+///
+/// If chain data is provided via the [`Confirm`] interface or via filtered blocks, users are
+/// required to give their chain data sources (i.e., [`Filter`] implementation) to the respective
+/// constructor.
+///
+/// [`Event::SpendableOutputs`]: crate::events::Event::SpendableOutputs
+pub struct OutputSweeper<B: Deref, D: Deref, E: Deref, F: Deref, K: Deref, L: Deref, O: Deref>
+where
+       B::Target: BroadcasterInterface,
+       D::Target: ChangeDestinationSource,
+       E::Target: FeeEstimator,
+       F::Target: Filter + Sync + Send,
+       K::Target: KVStore,
+       L::Target: Logger,
+       O::Target: OutputSpender,
+{
+       sweeper_state: Mutex<SweeperState>,
+       broadcaster: B,
+       fee_estimator: E,
+       chain_data_source: Option<F>,
+       output_spender: O,
+       change_destination_source: D,
+       kv_store: K,
+       logger: L,
+}
+
+impl<B: Deref, D: Deref, E: Deref, F: Deref, K: Deref, L: Deref, O: Deref>
+       OutputSweeper<B, D, E, F, K, L, O>
+where
+       B::Target: BroadcasterInterface,
+       D::Target: ChangeDestinationSource,
+       E::Target: FeeEstimator,
+       F::Target: Filter + Sync + Send,
+       K::Target: KVStore,
+       L::Target: Logger,
+       O::Target: OutputSpender,
+{
+       /// Constructs a new [`OutputSweeper`].
+       ///
+       /// If chain data is provided via the [`Confirm`] interface or via filtered blocks, users also
+       /// need to register their [`Filter`] implementation via the given `chain_data_source`.
+       pub fn new(
+               best_block: BestBlock, broadcaster: B, fee_estimator: E, chain_data_source: Option<F>,
+               output_spender: O, change_destination_source: D, kv_store: K, logger: L,
+       ) -> Self {
+               let outputs = Vec::new();
+               let sweeper_state = Mutex::new(SweeperState { outputs, best_block });
+               Self {
+                       sweeper_state,
+                       broadcaster,
+                       fee_estimator,
+                       chain_data_source,
+                       output_spender,
+                       change_destination_source,
+                       kv_store,
+                       logger,
+               }
+       }
+
+       /// Tells the sweeper to track the given outputs descriptors.
+       ///
+       /// Usually, this should be called based on the values emitted by the
+       /// [`Event::SpendableOutputs`].
+       ///
+       /// The given `exclude_static_ouputs` flag controls whether the sweeper will filter out
+       /// [`SpendableOutputDescriptor::StaticOutput`]s, which may be handled directly by the on-chain
+       /// wallet implementation.
+       ///
+       /// If `delay_until_height` is set, we will delay the spending until the respective block
+       /// height is reached. This can be used to batch spends, e.g., to reduce on-chain fees.
+       ///
+       /// [`Event::SpendableOutputs`]: crate::events::Event::SpendableOutputs
+       pub fn track_spendable_outputs(
+               &self, output_descriptors: Vec<SpendableOutputDescriptor>, channel_id: Option<ChannelId>,
+               exclude_static_ouputs: bool, delay_until_height: Option<u32>,
+       ) {
+               let mut relevant_descriptors = output_descriptors
+                       .into_iter()
+                       .filter(|desc| {
+                               !(exclude_static_ouputs
+                                       && matches!(desc, SpendableOutputDescriptor::StaticOutput { .. }))
+                       })
+                       .peekable();
+
+               if relevant_descriptors.peek().is_none() {
+                       return;
+               }
+
+               let mut spending_tx_opt;
+               {
+                       let mut state_lock = self.sweeper_state.lock().unwrap();
+                       for descriptor in relevant_descriptors {
+                               let output_info = TrackedSpendableOutput {
+                                       descriptor,
+                                       channel_id,
+                                       status: OutputSpendStatus::PendingInitialBroadcast {
+                                               delayed_until_height: delay_until_height,
+                                       },
+                               };
+
+                               if state_lock
+                                       .outputs
+                                       .iter()
+                                       .find(|o| o.descriptor == output_info.descriptor)
+                                       .is_some()
+                               {
+                                       continue;
+                               }
+
+                               state_lock.outputs.push(output_info);
+                       }
+                       spending_tx_opt = self.regenerate_spend_if_necessary(&mut *state_lock);
+                       self.persist_state(&*state_lock).unwrap_or_else(|e| {
+                               log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e);
+                               // Skip broadcasting if the persist failed.
+                               spending_tx_opt = None;
+                       });
+               }
+
+               if let Some(spending_tx) = spending_tx_opt {
+                       self.broadcaster.broadcast_transactions(&[&spending_tx]);
+               }
+       }
+
+       /// Returns a list of the currently tracked spendable outputs.
+       pub fn tracked_spendable_outputs(&self) -> Vec<TrackedSpendableOutput> {
+               self.sweeper_state.lock().unwrap().outputs.clone()
+       }
+
+       /// Gets the latest best block which was connected either via the [`Listen`] or
+       /// [`Confirm`] interfaces.
+       pub fn current_best_block(&self) -> BestBlock {
+               self.sweeper_state.lock().unwrap().best_block
+       }
+
+       fn regenerate_spend_if_necessary(
+               &self, sweeper_state: &mut SweeperState,
+       ) -> Option<Transaction> {
+               let cur_height = sweeper_state.best_block.height;
+               let cur_hash = sweeper_state.best_block.block_hash;
+               let filter_fn = |o: &TrackedSpendableOutput| {
+                       if o.status.is_confirmed() {
+                               // Don't rebroadcast confirmed txs.
+                               return false;
+                       }
+
+                       if o.status.is_delayed(cur_height) {
+                               // Don't generate and broadcast if still delayed
+                               return false;
+                       }
+
+                       if o.status.latest_broadcast_height() >= Some(cur_height) {
+                               // Only broadcast once per block height.
+                               return false;
+                       }
+
+                       true
+               };
+
+               let respend_descriptors: Vec<&SpendableOutputDescriptor> =
+                       sweeper_state.outputs.iter().filter(|o| filter_fn(*o)).map(|o| &o.descriptor).collect();
+
+               if respend_descriptors.is_empty() {
+                       // Nothing to do.
+                       return None;
+               }
+
+               let spending_tx = match self.spend_outputs(&*sweeper_state, respend_descriptors) {
+                       Ok(spending_tx) => {
+                               log_debug!(
+                                       self.logger,
+                                       "Generating and broadcasting sweeping transaction {}",
+                                       spending_tx.txid()
+                               );
+                               spending_tx
+                       },
+                       Err(e) => {
+                               log_error!(self.logger, "Error spending outputs: {:?}", e);
+                               return None;
+                       },
+               };
+
+               // As we didn't modify the state so far, the same filter_fn yields the same elements as
+               // above.
+               let respend_outputs = sweeper_state.outputs.iter_mut().filter(|o| filter_fn(&**o));
+               for output_info in respend_outputs {
+                       if let Some(filter) = self.chain_data_source.as_ref() {
+                               let watched_output = output_info.to_watched_output(cur_hash);
+                               filter.register_output(watched_output);
+                       }
+
+                       output_info.status.broadcast(cur_hash, cur_height, spending_tx.clone());
+               }
+
+               Some(spending_tx)
+       }
+
+       fn prune_confirmed_outputs(&self, sweeper_state: &mut SweeperState) {
+               let cur_height = sweeper_state.best_block.height;
+
+               // Prune all outputs that have sufficient depth by now.
+               sweeper_state.outputs.retain(|o| {
+                       if let Some(confirmation_height) = o.status.confirmation_height() {
+                               if cur_height >= confirmation_height + ANTI_REORG_DELAY - 1 {
+                                       log_debug!(self.logger,
+                                               "Pruning swept output as sufficiently confirmed via spend in transaction {:?}. Pruned descriptor: {:?}",
+                                               o.status.latest_spending_tx().map(|t| t.txid()), o.descriptor
+                                       );
+                                       return false;
+                               }
+                       }
+                       true
+               });
+       }
+
+       fn persist_state(&self, sweeper_state: &SweeperState) -> Result<(), io::Error> {
+               self.kv_store
+                       .write(
+                               OUTPUT_SWEEPER_PERSISTENCE_PRIMARY_NAMESPACE,
+                               OUTPUT_SWEEPER_PERSISTENCE_SECONDARY_NAMESPACE,
+                               OUTPUT_SWEEPER_PERSISTENCE_KEY,
+                               &sweeper_state.encode(),
+                       )
+                       .map_err(|e| {
+                               log_error!(
+                                       self.logger,
+                                       "Write for key {}/{}/{} failed due to: {}",
+                                       OUTPUT_SWEEPER_PERSISTENCE_PRIMARY_NAMESPACE,
+                                       OUTPUT_SWEEPER_PERSISTENCE_SECONDARY_NAMESPACE,
+                                       OUTPUT_SWEEPER_PERSISTENCE_KEY,
+                                       e
+                               );
+                               e
+                       })
+       }
+
+       fn spend_outputs(
+               &self, sweeper_state: &SweeperState, descriptors: Vec<&SpendableOutputDescriptor>,
+       ) -> Result<Transaction, ()> {
+               let tx_feerate =
+                       self.fee_estimator.get_est_sat_per_1000_weight(ConfirmationTarget::OutputSpendingFee);
+               let change_destination_script =
+                       self.change_destination_source.get_change_destination_script()?;
+               let cur_height = sweeper_state.best_block.height;
+               let locktime = Some(LockTime::from_height(cur_height).unwrap_or(LockTime::ZERO));
+               self.output_spender.spend_spendable_outputs(
+                       &descriptors,
+                       Vec::new(),
+                       change_destination_script,
+                       tx_feerate,
+                       locktime,
+                       &Secp256k1::new(),
+               )
+       }
+
+       fn transactions_confirmed_internal(
+               &self, sweeper_state: &mut SweeperState, header: &Header,
+               txdata: &chain::transaction::TransactionData, height: u32,
+       ) {
+               let confirmation_hash = header.block_hash();
+               for (_, tx) in txdata {
+                       for output_info in sweeper_state.outputs.iter_mut() {
+                               if output_info.is_spent_in(*tx) {
+                                       output_info.status.confirmed(confirmation_hash, height, (*tx).clone())
+                               }
+                       }
+               }
+       }
+
+       fn best_block_updated_internal(
+               &self, sweeper_state: &mut SweeperState, header: &Header, height: u32,
+       ) -> Option<Transaction> {
+               sweeper_state.best_block = BestBlock::new(header.block_hash(), height);
+               self.prune_confirmed_outputs(sweeper_state);
+               let spending_tx_opt = self.regenerate_spend_if_necessary(sweeper_state);
+               spending_tx_opt
+       }
+}
+
+impl<B: Deref, D: Deref, E: Deref, F: Deref, K: Deref, L: Deref, O: Deref> Listen
+       for OutputSweeper<B, D, E, F, K, L, O>
+where
+       B::Target: BroadcasterInterface,
+       D::Target: ChangeDestinationSource,
+       E::Target: FeeEstimator,
+       F::Target: Filter + Sync + Send,
+       K::Target: KVStore,
+       L::Target: Logger,
+       O::Target: OutputSpender,
+{
+       fn filtered_block_connected(
+               &self, header: &Header, txdata: &chain::transaction::TransactionData, height: u32,
+       ) {
+               let mut spending_tx_opt;
+               {
+                       let mut state_lock = self.sweeper_state.lock().unwrap();
+                       assert_eq!(state_lock.best_block.block_hash, header.prev_blockhash,
+                               "Blocks must be connected in chain-order - the connected header must build on the last connected header");
+                       assert_eq!(state_lock.best_block.height, height - 1,
+                               "Blocks must be connected in chain-order - the connected block height must be one greater than the previous height");
+
+                       self.transactions_confirmed_internal(&mut *state_lock, header, txdata, height);
+                       spending_tx_opt = self.best_block_updated_internal(&mut *state_lock, header, height);
+
+                       self.persist_state(&*state_lock).unwrap_or_else(|e| {
+                               log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e);
+                               // Skip broadcasting if the persist failed.
+                               spending_tx_opt = None;
+                       });
+               }
+
+               if let Some(spending_tx) = spending_tx_opt {
+                       self.broadcaster.broadcast_transactions(&[&spending_tx]);
+               }
+       }
+
+       fn block_disconnected(&self, header: &Header, height: u32) {
+               let mut state_lock = self.sweeper_state.lock().unwrap();
+
+               let new_height = height - 1;
+               let block_hash = header.block_hash();
+
+               assert_eq!(state_lock.best_block.block_hash, block_hash,
+               "Blocks must be disconnected in chain-order - the disconnected header must be the last connected header");
+               assert_eq!(state_lock.best_block.height, height,
+                       "Blocks must be disconnected in chain-order - the disconnected block must have the correct height");
+               state_lock.best_block = BestBlock::new(header.prev_blockhash, new_height);
+
+               for output_info in state_lock.outputs.iter_mut() {
+                       if output_info.status.confirmation_hash() == Some(block_hash) {
+                               debug_assert_eq!(output_info.status.confirmation_height(), Some(height));
+                               output_info.status.unconfirmed();
+                       }
+               }
+
+               self.persist_state(&*state_lock).unwrap_or_else(|e| {
+                       log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e);
+               });
+       }
+}
+
+impl<B: Deref, D: Deref, E: Deref, F: Deref, K: Deref, L: Deref, O: Deref> Confirm
+       for OutputSweeper<B, D, E, F, K, L, O>
+where
+       B::Target: BroadcasterInterface,
+       D::Target: ChangeDestinationSource,
+       E::Target: FeeEstimator,
+       F::Target: Filter + Sync + Send,
+       K::Target: KVStore,
+       L::Target: Logger,
+       O::Target: OutputSpender,
+{
+       fn transactions_confirmed(
+               &self, header: &Header, txdata: &chain::transaction::TransactionData, height: u32,
+       ) {
+               let mut state_lock = self.sweeper_state.lock().unwrap();
+               self.transactions_confirmed_internal(&mut *state_lock, header, txdata, height);
+               self.persist_state(&*state_lock).unwrap_or_else(|e| {
+                       log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e);
+               });
+       }
+
+       fn transaction_unconfirmed(&self, txid: &Txid) {
+               let mut state_lock = self.sweeper_state.lock().unwrap();
+
+               // Get what height was unconfirmed.
+               let unconf_height = state_lock
+                       .outputs
+                       .iter()
+                       .find(|o| o.status.latest_spending_tx().map(|tx| tx.txid()) == Some(*txid))
+                       .and_then(|o| o.status.confirmation_height());
+
+               if let Some(unconf_height) = unconf_height {
+                       // Unconfirm all >= this height.
+                       state_lock
+                               .outputs
+                               .iter_mut()
+                               .filter(|o| o.status.confirmation_height() >= Some(unconf_height))
+                               .for_each(|o| o.status.unconfirmed());
+
+                       self.persist_state(&*state_lock).unwrap_or_else(|e| {
+                               log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e);
+                       });
+               }
+       }
+
+       fn best_block_updated(&self, header: &Header, height: u32) {
+               let mut spending_tx_opt;
+               {
+                       let mut state_lock = self.sweeper_state.lock().unwrap();
+                       spending_tx_opt = self.best_block_updated_internal(&mut *state_lock, header, height);
+                       self.persist_state(&*state_lock).unwrap_or_else(|e| {
+                               log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e);
+                               // Skip broadcasting if the persist failed.
+                               spending_tx_opt = None;
+                       });
+               }
+
+               if let Some(spending_tx) = spending_tx_opt {
+                       self.broadcaster.broadcast_transactions(&[&spending_tx]);
+               }
+       }
+
+       fn get_relevant_txids(&self) -> Vec<(Txid, u32, Option<BlockHash>)> {
+               let state_lock = self.sweeper_state.lock().unwrap();
+               state_lock
+                       .outputs
+                       .iter()
+                       .filter_map(|o| match o.status {
+                               OutputSpendStatus::PendingThresholdConfirmations {
+                                       ref latest_spending_tx,
+                                       confirmation_height,
+                                       confirmation_hash,
+                                       ..
+                               } => Some((latest_spending_tx.txid(), confirmation_height, Some(confirmation_hash))),
+                               _ => None,
+                       })
+                       .collect::<Vec<_>>()
+       }
+}
+
+#[derive(Debug, Clone)]
+struct SweeperState {
+       outputs: Vec<TrackedSpendableOutput>,
+       best_block: BestBlock,
+}
+
+impl_writeable_tlv_based!(SweeperState, {
+       (0, outputs, required_vec),
+       (2, best_block, required),
+});
+
+/// A `enum` signalling to the [`OutputSweeper`] that it should delay spending an output until a
+/// future block height is reached.
+#[derive(Debug, Clone)]
+pub enum SpendingDelay {
+       /// A relative delay indicating we shouldn't spend the output before `cur_height + num_blocks`
+       /// is reached.
+       Relative {
+               /// The number of blocks until we'll generate and broadcast the spending transaction.
+               num_blocks: u32,
+       },
+       /// An absolute delay indicating we shouldn't spend the output before `height` is reached.
+       Absolute {
+               /// The height at which we'll generate and broadcast the spending transaction.
+               height: u32,
+       },
+}
+
+impl<B: Deref, D: Deref, E: Deref, F: Deref, K: Deref, L: Deref, O: Deref>
+       ReadableArgs<(B, E, Option<F>, O, D, K, L)> for OutputSweeper<B, D, E, F, K, L, O>
+where
+       B::Target: BroadcasterInterface,
+       D::Target: ChangeDestinationSource,
+       E::Target: FeeEstimator,
+       F::Target: Filter + Sync + Send,
+       K::Target: KVStore,
+       L::Target: Logger,
+       O::Target: OutputSpender,
+{
+       #[inline]
+       fn read<R: io::Read>(
+               reader: &mut R, args: (B, E, Option<F>, O, D, K, L),
+       ) -> Result<Self, DecodeError> {
+               let (
+                       broadcaster,
+                       fee_estimator,
+                       chain_data_source,
+                       output_spender,
+                       change_destination_source,
+                       kv_store,
+                       logger,
+               ) = args;
+               let state = SweeperState::read(reader)?;
+               let best_block = state.best_block;
+
+               if let Some(filter) = chain_data_source.as_ref() {
+                       for output_info in &state.outputs {
+                               let watched_output = output_info.to_watched_output(best_block.block_hash);
+                               filter.register_output(watched_output);
+                       }
+               }
+
+               let sweeper_state = Mutex::new(state);
+               Ok(Self {
+                       sweeper_state,
+                       broadcaster,
+                       fee_estimator,
+                       chain_data_source,
+                       output_spender,
+                       change_destination_source,
+                       kv_store,
+                       logger,
+               })
+       }
+}
+
+impl<B: Deref, D: Deref, E: Deref, F: Deref, K: Deref, L: Deref, O: Deref>
+       ReadableArgs<(B, E, Option<F>, O, D, K, L)> for (BestBlock, OutputSweeper<B, D, E, F, K, L, O>)
+where
+       B::Target: BroadcasterInterface,
+       D::Target: ChangeDestinationSource,
+       E::Target: FeeEstimator,
+       F::Target: Filter + Sync + Send,
+       K::Target: KVStore,
+       L::Target: Logger,
+       O::Target: OutputSpender,
+{
+       #[inline]
+       fn read<R: io::Read>(
+               reader: &mut R, args: (B, E, Option<F>, O, D, K, L),
+       ) -> Result<Self, DecodeError> {
+               let (
+                       broadcaster,
+                       fee_estimator,
+                       chain_data_source,
+                       output_spender,
+                       change_destination_source,
+                       kv_store,
+                       logger,
+               ) = args;
+               let state = SweeperState::read(reader)?;
+               let best_block = state.best_block;
+
+               if let Some(filter) = chain_data_source.as_ref() {
+                       for output_info in &state.outputs {
+                               let watched_output = output_info.to_watched_output(best_block.block_hash);
+                               filter.register_output(watched_output);
+                       }
+               }
+
+               let sweeper_state = Mutex::new(state);
+               Ok((
+                       best_block,
+                       OutputSweeper {
+                               sweeper_state,
+                               broadcaster,
+                               fee_estimator,
+                               chain_data_source,
+                               output_spender,
+                               change_destination_source,
+                               kv_store,
+                               logger,
+                       },
+               ))
+       }
+}
index 36018b23f79690666792dde6d7d157304c3d269c..95bc2a7c661982ea9062497f1647af434111dcd6 100644 (file)
@@ -784,12 +784,15 @@ impl msgs::ChannelMessageHandler for TestChannelMessageHandler {
        fn handle_stfu(&self, _their_node_id: &PublicKey, msg: &msgs::Stfu) {
                self.received_msg(wire::Message::Stfu(msg.clone()));
        }
+       #[cfg(splicing)]
        fn handle_splice(&self, _their_node_id: &PublicKey, msg: &msgs::Splice) {
                self.received_msg(wire::Message::Splice(msg.clone()));
        }
+       #[cfg(splicing)]
        fn handle_splice_ack(&self, _their_node_id: &PublicKey, msg: &msgs::SpliceAck) {
                self.received_msg(wire::Message::SpliceAck(msg.clone()));
        }
+       #[cfg(splicing)]
        fn handle_splice_locked(&self, _their_node_id: &PublicKey, msg: &msgs::SpliceLocked) {
                self.received_msg(wire::Message::SpliceLocked(msg.clone()));
        }