Merge pull request #1481 from TheBlueMatt/2022-05-new-chain-tests
author valentinewallace <valentinewallace@users.noreply.github.com>
Fri, 27 May 2022 17:38:42 +0000 (10:38 -0700)
committer GitHub <noreply@github.com>
Fri, 27 May 2022 17:38:42 +0000 (10:38 -0700)
Test coverage for `transaction_unconfirmed`

52 files changed:
.editorconfig
.github/workflows/build.yml
Cargo.toml
fuzz/Cargo.toml
fuzz/src/bin/gen_target.sh
fuzz/src/bin/process_network_graph_target.rs [new file with mode: 0644]
fuzz/src/chanmon_consistency.rs
fuzz/src/full_stack.rs
fuzz/src/lib.rs
fuzz/src/peer_crypt.rs
fuzz/src/process_network_graph.rs [new file with mode: 0644]
fuzz/targets.h
lightning-background-processor/src/lib.rs
lightning-invoice/src/lib.rs
lightning-invoice/src/payment.rs
lightning-invoice/src/time_utils.rs [new symlink]
lightning-persister/src/lib.rs
lightning-rapid-gossip-sync/Cargo.toml [new file with mode: 0644]
lightning-rapid-gossip-sync/README.md [new file with mode: 0644]
lightning-rapid-gossip-sync/res/.gitignore [new file with mode: 0644]
lightning-rapid-gossip-sync/src/error.rs [new file with mode: 0644]
lightning-rapid-gossip-sync/src/lib.rs [new file with mode: 0644]
lightning-rapid-gossip-sync/src/processing.rs [new file with mode: 0644]
lightning/src/chain/chainmonitor.rs
lightning/src/chain/channelmonitor.rs
lightning/src/chain/mod.rs
lightning/src/lib.rs
lightning/src/ln/chan_utils.rs
lightning/src/ln/chanmon_update_fail_tests.rs
lightning/src/ln/channel.rs
lightning/src/ln/channelmanager.rs
lightning/src/ln/features.rs
lightning/src/ln/functional_test_utils.rs
lightning/src/ln/functional_tests.rs
lightning/src/ln/monitor_tests.rs
lightning/src/ln/msgs.rs
lightning/src/ln/payment_tests.rs
lightning/src/ln/peer_channel_encryptor.rs
lightning/src/ln/peer_handler.rs
lightning/src/ln/priv_short_conf_tests.rs
lightning/src/ln/reorg_tests.rs
lightning/src/ln/script.rs
lightning/src/ln/shutdown_tests.rs
lightning/src/ln/wire.rs
lightning/src/routing/network_graph.rs
lightning/src/routing/router.rs
lightning/src/routing/scoring.rs
lightning/src/util/events.rs
lightning/src/util/mod.rs
lightning/src/util/ser.rs
lightning/src/util/test_utils.rs
lightning/src/util/time.rs [new file with mode: 0644]

index e5657670c10706dfbc274144d4e9289ba265ef62..dab24fe3dd37f2aedbbe45c70f826eecda517bcb 100644 (file)
@@ -3,3 +3,4 @@
 [*]
 indent_style = tab
 insert_final_newline = true
+trim_trailing_whitespace = true
index cff1713440035a0b76d89a9e6c88904c97c0b87e..59a11b4e4a11efc7788b51fe04ad664ca15da7f9 100644 (file)
@@ -141,6 +141,7 @@ jobs:
         run: |
           cargo test --verbose --color always  -p lightning
           cargo test --verbose --color always  -p lightning-invoice
+          cargo test --verbose --color always  -p lightning-rapid-gossip-sync
           cargo build --verbose  --color always -p lightning-persister
           cargo build --verbose  --color always -p lightning-background-processor
       - name: Test C Bindings Modifications on Rust ${{ matrix.toolchain }}
@@ -221,11 +222,24 @@ jobs:
       - name: Fetch routing graph snapshot
         if: steps.cache-graph.outputs.cache-hit != 'true'
         run: |
-          wget -O lightning/net_graph-2021-05-31.bin https://bitcoin.ninja/ldk-net_graph-v0.0.15-2021-05-31.bin
-          if [ "$(sha256sum lightning/net_graph-2021-05-31.bin | awk '{ print $1 }')" != "05a5361278f68ee2afd086cc04a1f927a63924be451f3221d380533acfacc303" ]; then
+          curl --verbose -L -o lightning/net_graph-2021-05-31.bin https://bitcoin.ninja/ldk-net_graph-v0.0.15-2021-05-31.bin
+          echo "Sha sum: $(sha256sum lightning/net_graph-2021-05-31.bin | awk '{ print $1 }')"
+          if [ "$(sha256sum lightning/net_graph-2021-05-31.bin | awk '{ print $1 }')" != "${EXPECTED_ROUTING_GRAPH_SNAPSHOT_SHASUM}" ]; then
             echo "Bad hash"
             exit 1
           fi
+        env:
+          EXPECTED_ROUTING_GRAPH_SNAPSHOT_SHASUM: 05a5361278f68ee2afd086cc04a1f927a63924be451f3221d380533acfacc303
+      - name: Fetch rapid graph sync reference input
+        run: |
+          curl --verbose -L -o lightning-rapid-gossip-sync/res/full_graph.lngossip https://bitcoin.ninja/ldk-compressed_graph-bc08df7542-2022-05-05.bin
+          echo "Sha sum: $(sha256sum lightning-rapid-gossip-sync/res/full_graph.lngossip | awk '{ print $1 }')"
+          if [ "$(sha256sum lightning-rapid-gossip-sync/res/full_graph.lngossip | awk '{ print $1 }')" != "${EXPECTED_RAPID_GOSSIP_SHASUM}" ]; then
+            echo "Bad hash"
+            exit 1
+          fi
+        env:
+          EXPECTED_RAPID_GOSSIP_SHASUM: 9637b91cea9d64320cf48fc0787c70fe69fc062f90d3512e207044110cadfd7b
       - name: Test with Network Graph on Rust ${{ matrix.toolchain }}
         run: |
           cd lightning
index 6e03fc1ac4cfc4bab1feed65368651bc5135cb9e..f263dc8eccb16414c5e440b70c02bd6e9e4ab2df 100644 (file)
@@ -7,6 +7,7 @@ members = [
     "lightning-net-tokio",
     "lightning-persister",
     "lightning-background-processor",
+    "lightning-rapid-gossip-sync"
 ]
 
 exclude = [
index 88e577617b54faeab07ba0e61a213d29f726b1ba..66dabcfe4be3242eac2240a287476745c8f03e2a 100644 (file)
@@ -19,6 +19,7 @@ stdin_fuzz = []
 [dependencies]
 afl = { version = "0.4", optional = true }
 lightning = { path = "../lightning", features = ["regex"] }
+lightning-rapid-gossip-sync = { path = "../lightning-rapid-gossip-sync" }
 bitcoin = { version = "0.28.1", features = ["secp-lowmemory"] }
 hex = "0.3"
 honggfuzz = { version = "0.5", optional = true }
index eb07df6342f86dc6d99904953f376f414de5289f..72fefe51609103c7ed7c8cbf74ed2b33f6b9552f 100755 (executable)
@@ -10,6 +10,7 @@ GEN_TEST chanmon_deser
 GEN_TEST chanmon_consistency
 GEN_TEST full_stack
 GEN_TEST peer_crypt
+GEN_TEST process_network_graph
 GEN_TEST router
 GEN_TEST zbase32
 
diff --git a/fuzz/src/bin/process_network_graph_target.rs b/fuzz/src/bin/process_network_graph_target.rs
new file mode 100644 (file)
index 0000000..380efdf
--- /dev/null
@@ -0,0 +1,113 @@
+// This file is Copyright its original authors, visible in version control
+// history.
+//
+// This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
+// or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
+// You may not use this file except in accordance with one or both of these
+// licenses.
+
+// This file is auto-generated by gen_target.sh based on target_template.txt
+// To modify it, modify target_template.txt and run gen_target.sh instead.
+
+#![cfg_attr(feature = "libfuzzer_fuzz", no_main)]
+
+#[cfg(not(fuzzing))]
+compile_error!("Fuzz targets need cfg=fuzzing");
+
+extern crate lightning_fuzz;
+use lightning_fuzz::process_network_graph::*;
+
+#[cfg(feature = "afl")]
+#[macro_use] extern crate afl;
+#[cfg(feature = "afl")]
+fn main() {
+       fuzz!(|data| {
+               process_network_graph_run(data.as_ptr(), data.len());
+       });
+}
+
+#[cfg(feature = "honggfuzz")]
+#[macro_use] extern crate honggfuzz;
+#[cfg(feature = "honggfuzz")]
+fn main() {
+       loop {
+               fuzz!(|data| {
+                       process_network_graph_run(data.as_ptr(), data.len());
+               });
+       }
+}
+
+#[cfg(feature = "libfuzzer_fuzz")]
+#[macro_use] extern crate libfuzzer_sys;
+#[cfg(feature = "libfuzzer_fuzz")]
+fuzz_target!(|data: &[u8]| {
+       process_network_graph_run(data.as_ptr(), data.len());
+});
+
+#[cfg(feature = "stdin_fuzz")]
+fn main() {
+       use std::io::Read;
+
+       let mut data = Vec::with_capacity(8192);
+       std::io::stdin().read_to_end(&mut data).unwrap();
+       process_network_graph_run(data.as_ptr(), data.len());
+}
+
+#[test]
+fn run_test_cases() {
+       use std::fs;
+       use std::io::Read;
+       use lightning_fuzz::utils::test_logger::StringBuffer;
+
+       use std::sync::{atomic, Arc};
+       {
+               let data: Vec<u8> = vec![0];
+               process_network_graph_run(data.as_ptr(), data.len());
+       }
+       let mut threads = Vec::new();
+       let threads_running = Arc::new(atomic::AtomicUsize::new(0));
+       if let Ok(tests) = fs::read_dir("test_cases/process_network_graph") {
+               for test in tests {
+                       let mut data: Vec<u8> = Vec::new();
+                       let path = test.unwrap().path();
+                       fs::File::open(&path).unwrap().read_to_end(&mut data).unwrap();
+                       threads_running.fetch_add(1, atomic::Ordering::AcqRel);
+
+                       let thread_count_ref = Arc::clone(&threads_running);
+                       let main_thread_ref = std::thread::current();
+                       threads.push((path.file_name().unwrap().to_str().unwrap().to_string(),
+                               std::thread::spawn(move || {
+                                       let string_logger = StringBuffer::new();
+
+                                       let panic_logger = string_logger.clone();
+                                       let res = if ::std::panic::catch_unwind(move || {
+                                               process_network_graph_test(&data, panic_logger);
+                                       }).is_err() {
+                                               Some(string_logger.into_string())
+                                       } else { None };
+                                       thread_count_ref.fetch_sub(1, atomic::Ordering::AcqRel);
+                                       main_thread_ref.unpark();
+                                       res
+                               })
+                       ));
+                       while threads_running.load(atomic::Ordering::Acquire) > 32 {
+                               std::thread::park();
+                       }
+               }
+       }
+       let mut failed_outputs = Vec::new();
+       for (test, thread) in threads.drain(..) {
+               if let Some(output) = thread.join().unwrap() {
+                       println!("\nOutput of {}:\n{}\n", test, output);
+                       failed_outputs.push(test);
+               }
+       }
+       if !failed_outputs.is_empty() {
+               println!("Test cases which failed: ");
+               for case in failed_outputs {
+                       println!("{}", case);
+               }
+               panic!();
+       }
+}
index 8472f8fa6278a4e158e810a600895545cbbd27ae..1b662c0e721c05b49b55bd65a35af52026988481 100644 (file)
@@ -148,7 +148,7 @@ impl chain::Watch<EnforcingSigner> for TestChainMonitor {
                self.chain_monitor.update_channel(funding_txo, update)
        }
 
-       fn release_pending_monitor_events(&self) -> Vec<MonitorEvent> {
+       fn release_pending_monitor_events(&self) -> Vec<(OutPoint, Vec<MonitorEvent>)> {
                return self.chain_monitor.release_pending_monitor_events();
        }
 }
@@ -442,7 +442,7 @@ pub fn do_test<Out: Output>(data: &[u8], underlying_out: Out) {
                                                value: *channel_value_satoshis, script_pubkey: output_script.clone(),
                                        }]};
                                        funding_output = OutPoint { txid: tx.txid(), index: 0 };
-                                       $source.funding_transaction_generated(&temporary_channel_id, tx.clone()).unwrap();
+                                       $source.funding_transaction_generated(&temporary_channel_id, &$dest.get_our_node_id(), tx.clone()).unwrap();
                                        channel_txn.push(tx);
                                } else { panic!("Wrong event type"); }
                        }
index 330124f8a8942a23585517148096a80bc984dd4b..e3c8290b16ea820e4a26bf585a5684384f1f46c3 100644 (file)
@@ -408,7 +408,7 @@ pub fn do_test(data: &[u8], logger: &Arc<dyn Logger>) {
        let mut should_forward = false;
        let mut payments_received: Vec<PaymentHash> = Vec::new();
        let mut payments_sent = 0;
-       let mut pending_funding_generation: Vec<([u8; 32], u64, Script)> = Vec::new();
+       let mut pending_funding_generation: Vec<([u8; 32], PublicKey, u64, Script)> = Vec::new();
        let mut pending_funding_signatures = HashMap::new();
 
        loop {
@@ -516,7 +516,7 @@ pub fn do_test(data: &[u8], logger: &Arc<dyn Logger>) {
                                let channel_id = get_slice!(1)[0] as usize;
                                if channel_id >= channels.len() { return; }
                                channels.sort_by(|a, b| { a.channel_id.cmp(&b.channel_id) });
-                               if channelmanager.close_channel(&channels[channel_id].channel_id).is_err() { return; }
+                               if channelmanager.close_channel(&channels[channel_id].channel_id, &channels[channel_id].counterparty.node_id).is_err() { return; }
                        },
                        7 => {
                                if should_forward {
@@ -556,7 +556,7 @@ pub fn do_test(data: &[u8], logger: &Arc<dyn Logger>) {
                        10 => {
                                'outer_loop: for funding_generation in pending_funding_generation.drain(..) {
                                        let mut tx = Transaction { version: 0, lock_time: 0, input: Vec::new(), output: vec![TxOut {
-                                                       value: funding_generation.1, script_pubkey: funding_generation.2,
+                                                       value: funding_generation.2, script_pubkey: funding_generation.3,
                                                }] };
                                        let funding_output = 'search_loop: loop {
                                                let funding_txid = tx.txid();
@@ -575,7 +575,7 @@ pub fn do_test(data: &[u8], logger: &Arc<dyn Logger>) {
                                                        continue 'outer_loop;
                                                }
                                        };
-                                       if let Err(e) = channelmanager.funding_transaction_generated(&funding_generation.0, tx.clone()) {
+                                       if let Err(e) = channelmanager.funding_transaction_generated(&funding_generation.0, &funding_generation.1, tx.clone()) {
                                                // It's possible the channel has been closed in the mean time, but any other
                                                // failure may be a bug.
                                                if let APIError::ChannelUnavailable { err } = e {
@@ -624,7 +624,7 @@ pub fn do_test(data: &[u8], logger: &Arc<dyn Logger>) {
                                let channel_id = get_slice!(1)[0] as usize;
                                if channel_id >= channels.len() { return; }
                                channels.sort_by(|a, b| { a.channel_id.cmp(&b.channel_id) });
-                               channelmanager.force_close_channel(&channels[channel_id].channel_id).unwrap();
+                               channelmanager.force_close_channel(&channels[channel_id].channel_id, &channels[channel_id].counterparty.node_id).unwrap();
                        },
                        // 15 is above
                        _ => return,
@@ -632,8 +632,8 @@ pub fn do_test(data: &[u8], logger: &Arc<dyn Logger>) {
                loss_detector.handler.process_events();
                for event in loss_detector.manager.get_and_clear_pending_events() {
                        match event {
-                               Event::FundingGenerationReady { temporary_channel_id, channel_value_satoshis, output_script, .. } => {
-                                       pending_funding_generation.push((temporary_channel_id, channel_value_satoshis, output_script));
+                               Event::FundingGenerationReady { temporary_channel_id, counterparty_node_id, channel_value_satoshis, output_script, .. } => {
+                                       pending_funding_generation.push((temporary_channel_id, counterparty_node_id, channel_value_satoshis, output_script));
                                },
                                Event::PaymentReceived { payment_hash, .. } => {
                                        //TODO: enhance by fetching random amounts from fuzz input?
index a0cc42b8189f9ae8837fb8ea64106c902185781d..5e158aee36ffe050e6f53176ddd1fc7887935871 100644 (file)
@@ -9,6 +9,7 @@
 
 extern crate bitcoin;
 extern crate lightning;
+extern crate lightning_rapid_gossip_sync;
 extern crate hex;
 
 pub mod utils;
@@ -17,6 +18,7 @@ pub mod chanmon_deser;
 pub mod chanmon_consistency;
 pub mod full_stack;
 pub mod peer_crypt;
+pub mod process_network_graph;
 pub mod router;
 pub mod zbase32;
 
index 9bef432982497af54dc5c0f9912e22558f79f30b..de0443ebd5b2bd3eea3e4b117442620eae57f46d 100644 (file)
@@ -9,7 +9,7 @@
 
 use lightning::ln::peer_channel_encryptor::PeerChannelEncryptor;
 
-use bitcoin::secp256k1::{PublicKey,SecretKey};
+use bitcoin::secp256k1::{Secp256k1, PublicKey, SecretKey};
 
 use utils::test_logger;
 
@@ -35,6 +35,8 @@ pub fn do_test(data: &[u8]) {
                }
        }
 
+       let secp_ctx = Secp256k1::signing_only();
+
        let our_network_key = match SecretKey::from_slice(get_slice!(32)) {
                Ok(key) => key,
                Err(_) => return,
@@ -50,16 +52,16 @@ pub fn do_test(data: &[u8]) {
                        Err(_) => return,
                };
                let mut crypter = PeerChannelEncryptor::new_outbound(their_pubkey, ephemeral_key);
-               crypter.get_act_one();
-               match crypter.process_act_two(get_slice!(50), &our_network_key) {
+               crypter.get_act_one(&secp_ctx);
+               match crypter.process_act_two(get_slice!(50), &our_network_key, &secp_ctx) {
                        Ok(_) => {},
                        Err(_) => return,
                }
                assert!(crypter.is_ready_for_encryption());
                crypter
        } else {
-               let mut crypter = PeerChannelEncryptor::new_inbound(&our_network_key);
-               match crypter.process_act_one_with_keys(get_slice!(50), &our_network_key, ephemeral_key) {
+               let mut crypter = PeerChannelEncryptor::new_inbound(&our_network_key, &secp_ctx);
+               match crypter.process_act_one_with_keys(get_slice!(50), &our_network_key, ephemeral_key, &secp_ctx) {
                        Ok(_) => {},
                        Err(_) => return,
                }
diff --git a/fuzz/src/process_network_graph.rs b/fuzz/src/process_network_graph.rs
new file mode 100644 (file)
index 0000000..3f30335
--- /dev/null
@@ -0,0 +1,20 @@
+// Import that needs to be added manually
+use utils::test_logger;
+
+/// Actual fuzz test, method signature and name are fixed
+fn do_test(data: &[u8]) {
+       let block_hash = bitcoin::BlockHash::default();
+       let network_graph = lightning::routing::network_graph::NetworkGraph::new(block_hash);
+       lightning_rapid_gossip_sync::processing::update_network_graph(&network_graph, data);
+}
+
+/// Method that needs to be added manually, {name}_test
+pub fn process_network_graph_test<Out: test_logger::Output>(data: &[u8], _out: Out) {
+       do_test(data);
+}
+
+/// Method that needs to be added manually, {name}_run
+#[no_mangle]
+pub extern "C" fn process_network_graph_run(data: *const u8, datalen: usize) {
+       do_test(unsafe { std::slice::from_raw_parts(data, datalen) });
+}
index 5d45e3d02388ea9a1659bf0cbb6e88e379d2f940..798fb66479519cf9b00a84ce3d5873787c449b58 100644 (file)
@@ -3,6 +3,7 @@ void chanmon_deser_run(const unsigned char* data, size_t data_len);
 void chanmon_consistency_run(const unsigned char* data, size_t data_len);
 void full_stack_run(const unsigned char* data, size_t data_len);
 void peer_crypt_run(const unsigned char* data, size_t data_len);
+void process_network_graph_run(const unsigned char* data, size_t data_len);
 void router_run(const unsigned char* data, size_t data_len);
 void zbase32_run(const unsigned char* data, size_t data_len);
 void msg_accept_channel_run(const unsigned char* data, size_t data_len);
index 107f65f9b74f141cdb087eded0befe59d4c9dfac..dab376178488654d70322155400c0bd263b740a4 100644 (file)
@@ -379,7 +379,7 @@ mod tests {
        use lightning::util::ser::Writeable;
        use lightning::util::test_utils;
        use lightning::util::persist::KVStorePersister;
-       use lightning_invoice::payment::{InvoicePayer, RetryAttempts};
+       use lightning_invoice::payment::{InvoicePayer, Retry};
        use lightning_invoice::utils::DefaultRouter;
        use lightning_persister::FilesystemPersister;
        use std::fs;
@@ -540,7 +540,7 @@ mod tests {
        macro_rules! handle_funding_generation_ready {
                ($event: expr, $channel_value: expr) => {{
                        match $event {
-                               &Event::FundingGenerationReady { temporary_channel_id, channel_value_satoshis, ref output_script, user_channel_id } => {
+                               &Event::FundingGenerationReady { temporary_channel_id, channel_value_satoshis, ref output_script, user_channel_id, .. } => {
                                        assert_eq!(channel_value_satoshis, $channel_value);
                                        assert_eq!(user_channel_id, 42);
 
@@ -556,7 +556,7 @@ mod tests {
 
        macro_rules! end_open_channel {
                ($node_a: expr, $node_b: expr, $temporary_channel_id: expr, $tx: expr) => {{
-                       $node_a.node.funding_transaction_generated(&$temporary_channel_id, $tx.clone()).unwrap();
+                       $node_a.node.funding_transaction_generated(&$temporary_channel_id, &$node_b.node.get_our_node_id(), $tx.clone()).unwrap();
                        $node_b.node.handle_funding_created(&$node_a.node.get_our_node_id(), &get_event_msg!($node_a, MessageSendEvent::SendFundingCreated, $node_b.node.get_our_node_id()));
                        $node_a.node.handle_funding_signed(&$node_b.node.get_our_node_id(), &get_event_msg!($node_b, MessageSendEvent::SendFundingSigned, $node_a.node.get_our_node_id()));
                }}
@@ -637,7 +637,7 @@ mod tests {
                }
 
                // Force-close the channel.
-               nodes[0].node.force_close_channel(&OutPoint { txid: tx.txid(), index: 0 }.to_channel_id()).unwrap();
+               nodes[0].node.force_close_channel(&OutPoint { txid: tx.txid(), index: 0 }.to_channel_id(), &nodes[1].node.get_our_node_id()).unwrap();
 
                // Check that the force-close updates are persisted.
                check_persisted_data!(nodes[0].node, filepath.clone());
@@ -776,7 +776,7 @@ mod tests {
                let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()));
 
                // Force close the channel and check that the SpendableOutputs event was handled.
-               nodes[0].node.force_close_channel(&nodes[0].node.list_channels()[0].channel_id).unwrap();
+               nodes[0].node.force_close_channel(&nodes[0].node.list_channels()[0].channel_id, &nodes[1].node.get_our_node_id()).unwrap();
                let commitment_tx = nodes[0].tx_broadcaster.txn_broadcasted.lock().unwrap().pop().unwrap();
                confirm_transaction_depth(&mut nodes[0], &commitment_tx, BREAKDOWN_TIMEOUT as u32);
                let event = receiver
@@ -801,7 +801,7 @@ mod tests {
                let data_dir = nodes[0].persister.get_data_dir();
                let persister = Arc::new(Persister::new(data_dir));
                let router = DefaultRouter::new(Arc::clone(&nodes[0].network_graph), Arc::clone(&nodes[0].logger), random_seed_bytes);
-               let invoice_payer = Arc::new(InvoicePayer::new(Arc::clone(&nodes[0].node), router, Arc::clone(&nodes[0].scorer), Arc::clone(&nodes[0].logger), |_: &_| {}, RetryAttempts(2)));
+               let invoice_payer = Arc::new(InvoicePayer::new(Arc::clone(&nodes[0].node), router, Arc::clone(&nodes[0].scorer), Arc::clone(&nodes[0].logger), |_: &_| {}, Retry::Attempts(2)));
                let event_handler = Arc::clone(&invoice_payer);
                let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()));
                assert!(bg_processor.stop().is_ok());
index 616ea99f0fe253cf2275862b1a63e6f0db3232f2..008d3344b55c5f2becae9cfd553d1c3f3f84dc2a 100644 (file)
@@ -11,7 +11,7 @@
 #![cfg_attr(all(not(feature = "std"), not(test)), no_std)]
 
 //! This crate provides data structures to represent
-//! [lightning BOLT11](https://github.com/lightningnetwork/lightning-rfc/blob/master/11-payment-encoding.md)
+//! [lightning BOLT11](https://github.com/lightning/bolts/blob/master/11-payment-encoding.md)
 //! invoices and functions to create, encode and decode these. If you just want to use the standard
 //! en-/decoding functionality this should get you started:
 //!
@@ -25,6 +25,8 @@ compile_error!("at least one of the `std` or `no-std` features must be enabled")
 pub mod payment;
 pub mod utils;
 
+pub(crate) mod time_utils;
+
 extern crate bech32;
 extern crate bitcoin_hashes;
 #[macro_use] extern crate lightning;
@@ -793,18 +795,15 @@ impl SignedRawInvoice {
 /// variant. If no element was found `None` gets returned.
 ///
 /// The following example would extract the first B.
-/// ```
-/// use Enum::*
 ///
 /// enum Enum {
 ///    A(u8),
 ///    B(u16)
 /// }
 ///
-/// let elements = vec![A(1), A(2), B(3), A(4)]
+/// let elements = vec![Enum::A(1), Enum::A(2), Enum::B(3), Enum::A(4)];
 ///
-/// assert_eq!(find_extract!(elements.iter(), Enum::B(ref x), x), Some(3u16))
-/// ```
+/// assert_eq!(find_extract!(elements.iter(), Enum::B(x), x), Some(3u16));
 macro_rules! find_extract {
        ($iter:expr, $enm:pat, $enm_var:ident) => {
                find_all_extract!($iter, $enm, $enm_var).next()
@@ -815,20 +814,18 @@ macro_rules! find_extract {
 /// variant through an iterator.
 ///
 /// The following example would extract all A.
-/// ```
-/// use Enum::*
 ///
 /// enum Enum {
 ///    A(u8),
 ///    B(u16)
 /// }
 ///
-/// let elements = vec![A(1), A(2), B(3), A(4)]
+/// let elements = vec![Enum::A(1), Enum::A(2), Enum::B(3), Enum::A(4)];
 ///
 /// assert_eq!(
-///    find_all_extract!(elements.iter(), Enum::A(ref x), x).collect::<Vec<u8>>(),
-///    vec![1u8, 2u8, 4u8])
-/// ```
+///    find_all_extract!(elements.iter(), Enum::A(x), x).collect::<Vec<u8>>(),
+///    vec![1u8, 2u8, 4u8]
+/// );
 macro_rules! find_all_extract {
        ($iter:expr, $enm:pat, $enm_var:ident) => {
                $iter.filter_map(|tf| match *tf {
index 6b79a0123d9e872cc3dcebdea7905c5b6d8613b2..c095f74dd16e2e0dc9c3c41dacf7122c75e179a9 100644 (file)
 //! # use lightning::ln::{PaymentHash, PaymentPreimage, PaymentSecret};
 //! # use lightning::ln::channelmanager::{ChannelDetails, PaymentId, PaymentSendFailure};
 //! # use lightning::ln::msgs::LightningError;
-//! # use lightning::routing::scoring::Score;
 //! # use lightning::routing::network_graph::NodeId;
 //! # use lightning::routing::router::{Route, RouteHop, RouteParameters};
+//! # use lightning::routing::scoring::{ChannelUsage, Score};
 //! # use lightning::util::events::{Event, EventHandler, EventsProvider};
 //! # use lightning::util::logger::{Logger, Record};
 //! # use lightning::util::ser::{Writeable, Writer};
 //! # use lightning_invoice::Invoice;
-//! # use lightning_invoice::payment::{InvoicePayer, Payer, RetryAttempts, Router};
+//! # use lightning_invoice::payment::{InvoicePayer, Payer, Retry, Router};
 //! # use secp256k1::PublicKey;
 //! # use std::cell::RefCell;
 //! # use std::ops::Deref;
@@ -90,7 +90,7 @@
 //! # }
 //! # impl Score for FakeScorer {
 //! #     fn channel_penalty_msat(
-//! #         &self, _short_channel_id: u64, _send_amt: u64, _chan_amt: u64, _source: &NodeId, _target: &NodeId
+//! #         &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId, _usage: ChannelUsage
 //! #     ) -> u64 { 0 }
 //! #     fn payment_path_failed(&mut self, _path: &[&RouteHop], _short_channel_id: u64) {}
 //! #     fn payment_path_successful(&mut self, _path: &[&RouteHop]) {}
 //! # let router = FakeRouter {};
 //! # let scorer = RefCell::new(FakeScorer {});
 //! # let logger = FakeLogger {};
-//! let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+//! let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2));
 //!
 //! let invoice = "...";
 //! if let Ok(invoice) = invoice.parse::<Invoice>() {
@@ -146,10 +146,13 @@ use lightning::routing::scoring::{LockableScore, Score};
 use lightning::routing::router::{PaymentParameters, Route, RouteParameters};
 use lightning::util::events::{Event, EventHandler};
 use lightning::util::logger::Logger;
+use time_utils::Time;
 use crate::sync::Mutex;
 
 use secp256k1::PublicKey;
 
+use core::fmt;
+use core::fmt::{Debug, Display, Formatter};
 use core::ops::Deref;
 use core::time::Duration;
 #[cfg(feature = "std")]
@@ -160,7 +163,17 @@ use std::time::SystemTime;
 /// See [module-level documentation] for details.
 ///
 /// [module-level documentation]: crate::payment
-pub struct InvoicePayer<P: Deref, R, S: Deref, L: Deref, E: EventHandler>
+pub type InvoicePayer<P, R, S, L, E> = InvoicePayerUsingTime::<P, R, S, L, E, ConfiguredTime>;
+
+#[cfg(not(feature = "no-std"))]
+type ConfiguredTime = std::time::Instant;
+#[cfg(feature = "no-std")]
+use time_utils;
+#[cfg(feature = "no-std")]
+type ConfiguredTime = time_utils::Eternity;
+
+/// (C-not exported) generally all users should use the [`InvoicePayer`] type alias.
+pub struct InvoicePayerUsingTime<P: Deref, R, S: Deref, L: Deref, E: EventHandler, T: Time>
 where
        P::Target: Payer,
        R: for <'a> Router<<<S as Deref>::Target as LockableScore<'a>>::Locked>,
@@ -173,8 +186,42 @@ where
        logger: L,
        event_handler: E,
        /// Caches the overall attempts at making a payment, which is updated prior to retrying.
-       payment_cache: Mutex<HashMap<PaymentHash, usize>>,
-       retry_attempts: RetryAttempts,
+       payment_cache: Mutex<HashMap<PaymentHash, PaymentAttempts<T>>>,
+       retry: Retry,
+}
+
+/// Storing minimal payment attempts information required for determining if an outbound payment can
+/// be retried.
+#[derive(Clone, Copy)]
+struct PaymentAttempts<T: Time> {
+       /// This count will be incremented only after the result of the attempt is known. When it's 0,
+	/// it means the result of the first attempt is not known yet.
+       count: usize,
+	/// This field is only used when retry is [`Retry::Timeout`], which is only built with feature std
+       first_attempted_at: T
+}
+
+impl<T: Time> PaymentAttempts<T> {
+       fn new() -> Self {
+               PaymentAttempts {
+                       count: 0,
+                       first_attempted_at: T::now()
+               }
+       }
+}
+
+impl<T: Time> Display for PaymentAttempts<T> {
+       fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
+               #[cfg(feature = "no-std")]
+               return write!( f, "attempts: {}", self.count);
+               #[cfg(not(feature = "no-std"))]
+               return write!(
+                       f,
+                       "attempts: {}, duration: {}s",
+                       self.count,
+                       T::now().duration_since(self.first_attempted_at).as_secs()
+               );
+       }
 }
 
 /// A trait defining behavior of an [`Invoice`] payer.
@@ -211,13 +258,33 @@ pub trait Router<S: Score> {
        ) -> Result<Route, LightningError>;
 }
 
-/// Number of attempts to retry payment path failures for an [`Invoice`].
+/// Strategies available to retry payment path failures for an [`Invoice`].
 ///
-/// Note that this is the number of *path* failures, not full payment retries. For multi-path
-/// payments, if this is less than the total number of paths, we will never even retry all of the
-/// payment's paths.
 #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
-pub struct RetryAttempts(pub usize);
+pub enum Retry {
+       /// Max number of attempts to retry payment.
+       ///
+       /// Note that this is the number of *path* failures, not full payment retries. For multi-path
+       /// payments, if this is less than the total number of paths, we will never even retry all of the
+       /// payment's paths.
+       Attempts(usize),
+       #[cfg(feature = "std")]
+       /// Time elapsed before abandoning retries for a payment.
+       Timeout(Duration),
+}
+
+impl Retry {
+       fn is_retryable_now<T: Time>(&self, attempts: &PaymentAttempts<T>) -> bool {
+               match (self, attempts) {
+                       (Retry::Attempts(max_retry_count), PaymentAttempts { count, .. }) => {
+				max_retry_count >= count
+                       },
+                       #[cfg(feature = "std")]
+                       (Retry::Timeout(max_duration), PaymentAttempts { first_attempted_at, .. } ) =>
+                               *max_duration >= T::now().duration_since(*first_attempted_at),
+               }
+       }
+}
 
 /// An error that may occur when making a payment.
 #[derive(Clone, Debug)]
@@ -230,7 +297,7 @@ pub enum PaymentError {
        Sending(PaymentSendFailure),
 }
 
-impl<P: Deref, R, S: Deref, L: Deref, E: EventHandler> InvoicePayer<P, R, S, L, E>
+impl<P: Deref, R, S: Deref, L: Deref, E: EventHandler, T: Time> InvoicePayerUsingTime<P, R, S, L, E, T>
 where
        P::Target: Payer,
        R: for <'a> Router<<<S as Deref>::Target as LockableScore<'a>>::Locked>,
@@ -240,9 +307,9 @@ where
        /// Creates an invoice payer that retries failed payment paths.
        ///
        /// Will forward any [`Event::PaymentPathFailed`] events to the decorated `event_handler` once
-       /// `retry_attempts` has been exceeded for a given [`Invoice`].
+       /// `retry` has been exceeded for a given [`Invoice`].
        pub fn new(
-               payer: P, router: R, scorer: S, logger: L, event_handler: E, retry_attempts: RetryAttempts
+               payer: P, router: R, scorer: S, logger: L, event_handler: E, retry: Retry
        ) -> Self {
                Self {
                        payer,
@@ -251,7 +318,7 @@ where
                        logger,
                        event_handler,
                        payment_cache: Mutex::new(HashMap::new()),
-                       retry_attempts,
+                       retry,
                }
        }
 
@@ -292,7 +359,7 @@ where
                let payment_hash = PaymentHash(invoice.payment_hash().clone().into_inner());
                match self.payment_cache.lock().unwrap().entry(payment_hash) {
                        hash_map::Entry::Occupied(_) => return Err(PaymentError::Invoice("payment pending")),
-                       hash_map::Entry::Vacant(entry) => entry.insert(0),
+                       hash_map::Entry::Vacant(entry) => entry.insert(PaymentAttempts::new()),
                };
 
                let payment_secret = Some(invoice.payment_secret().clone());
@@ -311,6 +378,7 @@ where
                let send_payment = |route: &Route| {
                        self.payer.send_payment(route, payment_hash, &payment_secret)
                };
+
                self.pay_internal(&route_params, payment_hash, send_payment)
                        .map_err(|e| { self.payment_cache.lock().unwrap().remove(&payment_hash); e })
        }
@@ -327,7 +395,7 @@ where
                let payment_hash = PaymentHash(Sha256::hash(&payment_preimage.0).into_inner());
                match self.payment_cache.lock().unwrap().entry(payment_hash) {
                        hash_map::Entry::Occupied(_) => return Err(PaymentError::Invoice("payment pending")),
-                       hash_map::Entry::Vacant(entry) => entry.insert(0),
+                       hash_map::Entry::Vacant(entry) => entry.insert(PaymentAttempts::new()),
                };
 
                let route_params = RouteParameters {
@@ -367,13 +435,13 @@ where
                                PaymentSendFailure::PathParameterError(_) => Err(e),
                                PaymentSendFailure::AllFailedRetrySafe(_) => {
                                        let mut payment_cache = self.payment_cache.lock().unwrap();
-                                       let retry_count = payment_cache.get_mut(&payment_hash).unwrap();
-                                       if *retry_count >= self.retry_attempts.0 {
-                                               Err(e)
-                                       } else {
-                                               *retry_count += 1;
+                                       let payment_attempts = payment_cache.get_mut(&payment_hash).unwrap();
+                                       payment_attempts.count += 1;
+                                       if self.retry.is_retryable_now(payment_attempts) {
                                                core::mem::drop(payment_cache);
                                                Ok(self.pay_internal(params, payment_hash, send_payment)?)
+                                       } else {
+                                               Err(e)
                                        }
                                },
                                PaymentSendFailure::PartialFailure { failed_paths_retry, payment_id, .. } => {
@@ -399,20 +467,22 @@ where
        fn retry_payment(
                &self, payment_id: PaymentId, payment_hash: PaymentHash, params: &RouteParameters
        ) -> Result<(), ()> {
-               let max_payment_attempts = self.retry_attempts.0 + 1;
-               let attempts = *self.payment_cache.lock().unwrap()
-                       .entry(payment_hash)
-                       .and_modify(|attempts| *attempts += 1)
-                       .or_insert(1);
-
-               if attempts >= max_payment_attempts {
-                       log_trace!(self.logger, "Payment {} exceeded maximum attempts; not retrying (attempts: {})", log_bytes!(payment_hash.0), attempts);
+               let attempts =
+                       *self.payment_cache.lock().unwrap().entry(payment_hash)
+                       .and_modify(|attempts| attempts.count += 1)
+                       .or_insert(PaymentAttempts {
+                               count: 1,
+                               first_attempted_at: T::now()
+                       });
+
+               if !self.retry.is_retryable_now(&attempts) {
+                       log_trace!(self.logger, "Payment {} exceeded maximum attempts; not retrying ({})", log_bytes!(payment_hash.0), attempts);
                        return Err(());
                }
 
                #[cfg(feature = "std")] {
                        if has_expired(params) {
-                               log_trace!(self.logger, "Invoice expired for payment {}; not retrying (attempts: {})", log_bytes!(payment_hash.0), attempts);
+                               log_trace!(self.logger, "Invoice expired for payment {}; not retrying ({:})", log_bytes!(payment_hash.0), attempts);
                                return Err(());
                        }
                }
@@ -424,7 +494,7 @@ where
                        &self.scorer.lock()
                );
                if route.is_err() {
-                       log_trace!(self.logger, "Failed to find a route for payment {}; not retrying (attempts: {})", log_bytes!(payment_hash.0), attempts);
+                       log_trace!(self.logger, "Failed to find a route for payment {}; not retrying ({:})", log_bytes!(payment_hash.0), attempts);
                        return Err(());
                }
 
@@ -468,7 +538,7 @@ fn has_expired(route_params: &RouteParameters) -> bool {
        } else { false }
 }
 
-impl<P: Deref, R, S: Deref, L: Deref, E: EventHandler> EventHandler for InvoicePayer<P, R, S, L, E>
+impl<P: Deref, R, S: Deref, L: Deref, E: EventHandler, T: Time> EventHandler for InvoicePayerUsingTime<P, R, S, L, E, T>
 where
        P::Target: Payer,
        R: for <'a> Router<<<S as Deref>::Target as LockableScore<'a>>::Locked>,
@@ -511,7 +581,7 @@ where
                                let mut payment_cache = self.payment_cache.lock().unwrap();
                                let attempts = payment_cache
                                        .remove(payment_hash)
-                                       .map_or(1, |attempts| attempts + 1);
+                                       .map_or(1, |attempts| attempts.count + 1);
                                log_trace!(self.logger, "Payment {} succeeded (attempts: {})", log_bytes!(payment_hash.0), attempts);
                        },
                        _ => {},
@@ -534,6 +604,7 @@ mod tests {
        use lightning::ln::msgs::{ChannelMessageHandler, ErrorAction, LightningError};
        use lightning::routing::network_graph::NodeId;
        use lightning::routing::router::{PaymentParameters, Route, RouteHop};
+       use lightning::routing::scoring::ChannelUsage;
        use lightning::util::test_utils::TestLogger;
        use lightning::util::errors::APIError;
        use lightning::util::events::{Event, EventsProvider, MessageSendEvent, MessageSendEventsProvider};
@@ -541,6 +612,7 @@ mod tests {
        use std::cell::RefCell;
        use std::collections::VecDeque;
        use std::time::{SystemTime, Duration};
+       use time_utils::tests::SinceEpoch;
        use DEFAULT_EXPIRY_TIME;
 
        fn invoice(payment_preimage: PaymentPreimage) -> Invoice {
@@ -624,7 +696,7 @@ mod tests {
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(0));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -653,7 +725,7 @@ mod tests {
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -698,7 +770,7 @@ mod tests {
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2));
 
                assert!(invoice_payer.pay_invoice(&invoice).is_ok());
        }
@@ -720,7 +792,7 @@ mod tests {
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2));
 
                let payment_id = Some(PaymentId([1; 32]));
                let event = Event::PaymentPathFailed {
@@ -749,7 +821,7 @@ mod tests {
        }
 
        #[test]
-       fn fails_paying_invoice_after_max_retries() {
+       fn fails_paying_invoice_after_max_retry_counts() {
                let event_handled = core::cell::RefCell::new(false);
                let event_handler = |_: &_| { *event_handled.borrow_mut() = true; };
 
@@ -765,7 +837,7 @@ mod tests {
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -805,6 +877,52 @@ mod tests {
                assert_eq!(*payer.attempts.borrow(), 3);
        }
 
+       #[cfg(feature = "std")]
+       #[test]
+       fn fails_paying_invoice_after_max_retry_timeout() {
+               let event_handled = core::cell::RefCell::new(false);
+               let event_handler = |_: &_| { *event_handled.borrow_mut() = true; };
+
+               let payment_preimage = PaymentPreimage([1; 32]);
+               let invoice = invoice(payment_preimage);
+               let final_value_msat = invoice.amount_milli_satoshis().unwrap();
+
+               let payer = TestPayer::new()
+                       .expect_send(Amount::ForInvoice(final_value_msat))
+                       .expect_send(Amount::OnRetry(final_value_msat / 2));
+
+               let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
+               let logger = TestLogger::new();
+               type InvoicePayerUsingSinceEpoch <P, R, S, L, E> = InvoicePayerUsingTime::<P, R, S, L, E, SinceEpoch>;
+
+               let invoice_payer =
+                       InvoicePayerUsingSinceEpoch::new(&payer, router, &scorer, &logger, event_handler, Retry::Timeout(Duration::from_secs(120)));
+
+               let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
+               assert_eq!(*payer.attempts.borrow(), 1);
+
+               let event = Event::PaymentPathFailed {
+                       payment_id,
+                       payment_hash: PaymentHash(invoice.payment_hash().clone().into_inner()),
+                       network_update: None,
+                       rejected_by_dest: false,
+                       all_paths_failed: true,
+                       path: TestRouter::path_for_value(final_value_msat),
+                       short_channel_id: None,
+                       retry: Some(TestRouter::retry_for_invoice(&invoice)),
+               };
+               invoice_payer.handle_event(&event);
+               assert_eq!(*event_handled.borrow(), false);
+               assert_eq!(*payer.attempts.borrow(), 2);
+
+               SinceEpoch::advance(Duration::from_secs(121));
+
+               invoice_payer.handle_event(&event);
+               assert_eq!(*event_handled.borrow(), true);
+               assert_eq!(*payer.attempts.borrow(), 2);
+       }
+
        #[test]
        fn fails_paying_invoice_with_missing_retry_params() {
                let event_handled = core::cell::RefCell::new(false);
@@ -819,7 +937,7 @@ mod tests {
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -851,7 +969,7 @@ mod tests {
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = expired_invoice(payment_preimage);
@@ -876,7 +994,7 @@ mod tests {
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -917,7 +1035,7 @@ mod tests {
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -951,7 +1069,7 @@ mod tests {
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -987,7 +1105,7 @@ mod tests {
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(0));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
 
@@ -1026,7 +1144,7 @@ mod tests {
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &scorer, &logger, |_: &_| {}, RetryAttempts(0));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, |_: &_| {}, Retry::Attempts(0));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = invoice(payment_preimage);
@@ -1050,7 +1168,7 @@ mod tests {
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &scorer, &logger, |_: &_| {}, RetryAttempts(0));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, |_: &_| {}, Retry::Attempts(0));
 
                match invoice_payer.pay_invoice(&invoice) {
                        Err(PaymentError::Sending(_)) => {},
@@ -1074,7 +1192,7 @@ mod tests {
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(0));
 
                let payment_id =
                        Some(invoice_payer.pay_zero_value_invoice(&invoice, final_value_msat).unwrap());
@@ -1097,7 +1215,7 @@ mod tests {
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(0));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = invoice(payment_preimage);
@@ -1128,7 +1246,7 @@ mod tests {
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2));
 
                let payment_id = Some(invoice_payer.pay_pubkey(
                                pubkey, payment_preimage, final_value_msat, final_cltv_expiry_delta
@@ -1183,7 +1301,7 @@ mod tests {
                }));
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                let event = Event::PaymentPathFailed {
@@ -1219,7 +1337,7 @@ mod tests {
                );
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2));
 
                let payment_id = invoice_payer.pay_invoice(&invoice).unwrap();
                let event = Event::PaymentPathSuccessful {
@@ -1327,7 +1445,7 @@ mod tests {
 
        impl Score for TestScorer {
                fn channel_penalty_msat(
-                       &self, _short_channel_id: u64, _send_amt: u64, _chan_amt: u64, _source: &NodeId, _target: &NodeId
+                       &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId, _usage: ChannelUsage
                ) -> u64 { 0 }
 
                fn payment_path_failed(&mut self, actual_path: &[&RouteHop], actual_short_channel_id: u64) {
@@ -1554,7 +1672,7 @@ mod tests {
 
                let event_handler = |_: &_| { panic!(); };
                let scorer = RefCell::new(TestScorer::new());
-               let invoice_payer = InvoicePayer::new(nodes[0].node, router, &scorer, nodes[0].logger, event_handler, RetryAttempts(1));
+               let invoice_payer = InvoicePayer::new(nodes[0].node, router, &scorer, nodes[0].logger, event_handler, Retry::Attempts(1));
 
                assert!(invoice_payer.pay_invoice(&create_invoice_from_channelmanager_and_duration_since_epoch(
                        &nodes[1].node, nodes[1].keys_manager, Currency::Bitcoin, Some(100_010_000), "Invoice".to_string(),
@@ -1600,7 +1718,7 @@ mod tests {
 
                let event_handler = |_: &_| { panic!(); };
                let scorer = RefCell::new(TestScorer::new());
-               let invoice_payer = InvoicePayer::new(nodes[0].node, router, &scorer, nodes[0].logger, event_handler, RetryAttempts(1));
+               let invoice_payer = InvoicePayer::new(nodes[0].node, router, &scorer, nodes[0].logger, event_handler, Retry::Attempts(1));
 
                assert!(invoice_payer.pay_invoice(&create_invoice_from_channelmanager_and_duration_since_epoch(
                        &nodes[1].node, nodes[1].keys_manager, Currency::Bitcoin, Some(100_010_000), "Invoice".to_string(),
@@ -1682,7 +1800,7 @@ mod tests {
                        event_checker(event);
                };
                let scorer = RefCell::new(TestScorer::new());
-               let invoice_payer = InvoicePayer::new(nodes[0].node, router, &scorer, nodes[0].logger, event_handler, RetryAttempts(1));
+               let invoice_payer = InvoicePayer::new(nodes[0].node, router, &scorer, nodes[0].logger, event_handler, Retry::Attempts(1));
 
                assert!(invoice_payer.pay_invoice(&create_invoice_from_channelmanager_and_duration_since_epoch(
                        &nodes[1].node, nodes[1].keys_manager, Currency::Bitcoin, Some(100_010_000), "Invoice".to_string(),
diff --git a/lightning-invoice/src/time_utils.rs b/lightning-invoice/src/time_utils.rs
new file mode 120000 (symlink)
index 0000000..5326cff
--- /dev/null
@@ -0,0 +1 @@
+../../lightning/src/util/time.rs
\ No newline at end of file
index c23baf8ad34bccb7ea7efcaa32639a43f4b23c42..28845150c44606f79aced14c8984e72c10317d37 100644 (file)
@@ -213,7 +213,7 @@ mod tests {
 
                // Force close because cooperative close doesn't result in any persisted
                // updates.
-               nodes[0].node.force_close_channel(&nodes[0].node.list_channels()[0].channel_id).unwrap();
+               nodes[0].node.force_close_channel(&nodes[0].node.list_channels()[0].channel_id, &nodes[1].node.get_our_node_id()).unwrap();
                check_closed_event!(nodes[0], 1, ClosureReason::HolderForceClosed);
                check_closed_broadcast!(nodes[0], true);
                check_added_monitors!(nodes[0], 1);
@@ -247,7 +247,7 @@ mod tests {
                let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
                let nodes = create_network(2, &node_cfgs, &node_chanmgrs);
                let chan = create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known());
-               nodes[1].node.force_close_channel(&chan.2).unwrap();
+               nodes[1].node.force_close_channel(&chan.2, &nodes[0].node.get_our_node_id()).unwrap();
                check_closed_event!(nodes[1], 1, ClosureReason::HolderForceClosed);
                let mut added_monitors = nodes[1].chain_monitor.added_monitors.lock().unwrap();
                let update_map = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap();
@@ -286,7 +286,7 @@ mod tests {
                let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
                let nodes = create_network(2, &node_cfgs, &node_chanmgrs);
                let chan = create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known());
-               nodes[1].node.force_close_channel(&chan.2).unwrap();
+               nodes[1].node.force_close_channel(&chan.2, &nodes[0].node.get_our_node_id()).unwrap();
                check_closed_event!(nodes[1], 1, ClosureReason::HolderForceClosed);
                let mut added_monitors = nodes[1].chain_monitor.added_monitors.lock().unwrap();
                let update_map = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap();
diff --git a/lightning-rapid-gossip-sync/Cargo.toml b/lightning-rapid-gossip-sync/Cargo.toml
new file mode 100644 (file)
index 0000000..58446a4
--- /dev/null
@@ -0,0 +1,20 @@
+[package]
+name = "lightning-rapid-gossip-sync"
+version = "0.0.106"
+authors = ["Arik Sosman <git@arik.io>"]
+license = "MIT OR Apache-2.0"
+repository = "https://github.com/lightningdevkit/rust-lightning"
+edition = "2018"
+description = """
+Utility to process gossip routing data from the Rapid Gossip Sync Server.
+"""
+
+[features]
+_bench_unstable = []
+
+[dependencies]
+lightning = { version = "0.0.106", path = "../lightning" }
+bitcoin = { version = "0.28.1", default-features = false }
+
+[dev-dependencies]
+lightning = { version = "0.0.106", path = "../lightning", features = ["_test_utils"] }
diff --git a/lightning-rapid-gossip-sync/README.md b/lightning-rapid-gossip-sync/README.md
new file mode 100644 (file)
index 0000000..86d4981
--- /dev/null
@@ -0,0 +1,120 @@
+# lightning-rapid-gossip-sync
+
+This crate exposes functionality for rapid gossip graph syncing, aimed primarily at mobile clients.
+Its server counterpart is the
+[rapid-gossip-sync-server](https://github.com/lightningdevkit/rapid-gossip-sync-server) repository.
+
+## Mechanism
+
+The (presumed) server sends a compressed response containing gossip data. The data is formatted
+compactly, omitting signatures, and is opportunistically incremental where previous channel updates
+are known.
+
+Essentially, the serialization structure is as follows:
+
+1. Fixed prefix bytes `76, 68, 75, 1` (the first three bytes are ASCII for `LDK`)
+    - The purpose of this prefix is to identify the serialization format, should other rapid gossip
+      sync formats arise in the future
+    - The fourth byte is the protocol version in case our format gets updated
+2. Chain hash (32 bytes)
+3. Latest seen timestamp (`u32`)
+4. An unsigned int indicating the number of node IDs to follow
+5. An array of compressed node ID pubkeys (all pubkeys are presumed to be standard
+   compressed 33-byte serializations)
+6. An unsigned int indicating the number of channel announcement messages to follow
+7. An array of significantly stripped down customized channel announcements
+8. An unsigned int indicating the number of channel update messages to follow
+9. A series of default values used for non-incremental channel updates
+    - The values are defined as follows:
+        1. `default_cltv_expiry_delta`
+        2. `default_htlc_minimum_msat`
+        3. `default_fee_base_msat`
+        4. `default_fee_proportional_millionths`
+        5. `default_htlc_maximum_msat` (`u64`, and if the default is no maximum, `u64::MAX`)
+    - The defaults are calculated by the server based on the frequency among non-incremental
+      updates within a given delta set
+10. An array of customized channel updates
+
+You will also notice that `NodeAnnouncement` messages are omitted altogether as the node IDs are
+implicitly extracted from the channel announcements and updates.
+
+The data is then applied to the current network graph and artificially dated to one week before
+the timestamp of the latest message, be it an announcement or an update, seen by the server. The
+network graph should not be pruned until the graph sync completes.
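+
+For illustration only, here is a minimal sketch of reading the header fields in items 1-3 above
+with plain `std::io`. This is not the crate's actual parser (which uses LDK's `Readable` trait),
+and the function name is hypothetical:
+
+```rust
+use std::io::{self, Read};
+
+/// Hypothetical helper: reads the format prefix, chain hash, and latest-seen timestamp.
+fn read_header<R: Read>(reader: &mut R) -> io::Result<([u8; 32], u32)> {
+	let mut prefix = [0u8; 4];
+	reader.read_exact(&mut prefix)?;
+	// The first three bytes are ASCII for `LDK`; the fourth is the protocol version.
+	assert_eq!(prefix, [76, 68, 75, 1], "unknown serialization format or version");
+
+	let mut chain_hash = [0u8; 32];
+	reader.read_exact(&mut chain_hash)?;
+
+	// Integers in this format are big-endian.
+	let mut timestamp_bytes = [0u8; 4];
+	reader.read_exact(&mut timestamp_bytes)?;
+	let latest_seen_timestamp = u32::from_be_bytes(timestamp_bytes);
+
+	Ok((chain_hash, latest_seen_timestamp))
+}
+```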
+
+### Custom Channel Announcement
+
+To achieve compactness and avoid data repetition, we're sending a significantly stripped-down
+version of the channel announcement message, which contains only the following data:
+
+1. `channel_features`: `u16` + `n`, where `n` is the number of bytes indicated by the first `u16`
+2. `short_channel_id`: `CompactSize` (incremental `CompactSize` deltas starting from 0)
+3. `node_id_1_index`: `CompactSize` (index of node id within the previously sent sequence)
+4. `node_id_2_index`: `CompactSize` (index of node id within the previously sent sequence)
+
+### Custom Channel Update
+
+For the purpose of rapid syncing, we have deviated significantly from the channel update format
+specified in BOLT 7. Our custom channel updates are structured as follows:
+
+1. `short_channel_id`: `CompactSize` (incremental `CompactSize` deltas starting at 0)
+2. `custom_channel_flags`: `u8`
+3. `update_data`
+
+Specifically, our custom channel flags break down like this:
+
+| 128                 | 64 | 32 | 16 | 8 | 4 | 2                | 1         |
+|---------------------|----|----|----|---|---|------------------|-----------|
+| Incremental update? |    |    |    |   |   | Disable channel? | Direction |
+
+If the most significant bit is set to `1`, indicating an incremental update, the intermediate bit
+flags assume the following meaning:
+
+| 64                              | 32                              | 16                          | 8                                         | 4                               |
+|---------------------------------|---------------------------------|-----------------------------|-------------------------------------------|---------------------------------|
+| `cltv_expiry_delta` has changed | `htlc_minimum_msat` has changed | `fee_base_msat` has changed | `fee_proportional_millionths` has changed | `htlc_maximum_msat` has changed |
+
+If the most significant bit is set to `0`, the meaning is almost identical, except instead of a
+change, the flags now represent a deviation from the defaults sent at the beginning of the update
+sequence.
+
+In both cases, `update_data` contains only the fields that the channel flags indicate to have
+changed or to deviate from the defaults.
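+
+As an illustration (this is not the crate's actual decoder, and the names below are hypothetical),
+the flag byte could be unpacked as follows:
+
+```rust
+/// Which optional fields follow in `update_data`. For incremental updates a set field bit means
+/// the field has changed; for full updates it means the field deviates from the announced defaults.
+struct ChannelFlags {
+	incremental: bool,
+	disable: bool,
+	direction: u8,
+	cltv_expiry_delta_updated: bool,
+	htlc_minimum_msat_updated: bool,
+	fee_base_msat_updated: bool,
+	fee_proportional_millionths_updated: bool,
+	htlc_maximum_msat_updated: bool,
+}
+
+fn parse_channel_flags(flags: u8) -> ChannelFlags {
+	ChannelFlags {
+		incremental: flags & 0b_1000_0000 != 0,
+		cltv_expiry_delta_updated: flags & 0b_0100_0000 != 0,
+		htlc_minimum_msat_updated: flags & 0b_0010_0000 != 0,
+		fee_base_msat_updated: flags & 0b_0001_0000 != 0,
+		fee_proportional_millionths_updated: flags & 0b_0000_1000 != 0,
+		htlc_maximum_msat_updated: flags & 0b_0000_0100 != 0,
+		disable: flags & 0b_0000_0010 != 0,
+		direction: flags & 0b_0000_0001,
+	}
+}
+```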
+
+## Delta Calculation
+
+The way a server is meant to calculate this rapid gossip sync data is by taking the latest time
+any change, be it an announcement or an update, was seen. That timestamp is included in each rapid
+sync message, so all the client needs to do is cache that single timestamp.
+
+If no update for a particular channel had been seen before, the full update is sent. If a channel has
+had updates prior to the provided timestamp, the latest update prior to the timestamp is taken as a
+reference, and the delta is calculated against it.
+
+Depending on whether the rapid sync message is calculated on the fly or a snapshotted version is
+returned, intermediate changes between the latest update seen by the client and the latest update
+broadcast on the network may be taken into account when calculating the delta.
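+
+On the client side, the bookkeeping is minimal. A sketch of what that might look like follows; the
+function name is hypothetical, persistence is left as a comment, and the snapshot is assumed to
+have been downloaded already:
+
+```rust
+use lightning::routing::network_graph::NetworkGraph;
+use lightning_rapid_gossip_sync::error::GraphSyncError;
+
+/// Hypothetical helper: applies a downloaded snapshot, then returns the URL to request next time.
+fn apply_snapshot_and_remember(
+	network_graph: &NetworkGraph, snapshot_bytes: &[u8],
+) -> Result<String, GraphSyncError> {
+	let new_timestamp =
+		lightning_rapid_gossip_sync::processing::update_network_graph(network_graph, snapshot_bytes)?;
+	// Persist `new_timestamp` somewhere durable; it is the only state the client needs to keep.
+	Ok(format!("https://rapidsync.lightningdevkit.org/snapshot/{}", new_timestamp))
+}
+```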
+
+## Performance
+
+Given that the primary purpose of this utility is a faster graph sync, we thought it might be
+helpful to provide some examples of various delta sets. These examples were calculated as of
+May 19th, 2022, with a network graph comprising 80,000 channel announcements and 160,000 directed
+channel updates.
+
+| Full sync                   |        |
+|-----------------------------|--------|
+| Message Length              | 4.7 MB |
+| Gzipped Message Length      | 2.0 MB |
+| Client-side Processing Time | 1.4 s  |
+
+| Week-old sync               |        |
+|-----------------------------|--------|
+| Message Length              | 2.7 MB |
+| Gzipped Message Length      | 862 kB |
+| Client-side Processing Time | 907 ms |
+
+| Day-old sync                |         |
+|-----------------------------|---------|
+| Message Length              | 191 kB  |
+| Gzipped Message Length      | 92.8 kB |
+| Client-side Processing Time | 196 ms  |
diff --git a/lightning-rapid-gossip-sync/res/.gitignore b/lightning-rapid-gossip-sync/res/.gitignore
new file mode 100644 (file)
index 0000000..d6b7ef3
--- /dev/null
@@ -0,0 +1,2 @@
+*
+!.gitignore
diff --git a/lightning-rapid-gossip-sync/src/error.rs b/lightning-rapid-gossip-sync/src/error.rs
new file mode 100644 (file)
index 0000000..fee8fea
--- /dev/null
@@ -0,0 +1,40 @@
+use core::fmt::Debug;
+use std::fmt::Formatter;
+use lightning::ln::msgs::{DecodeError, LightningError};
+
+/// All-encompassing standard error type that processing can return
+pub enum GraphSyncError {
+       /// Error trying to read the update data, typically due to an erroneous data length indication
+       /// that is greater than the actual amount of data provided
+       DecodeError(DecodeError),
+	/// Error applying the patch to the network graph, usually the result of updates that are too
+	/// old or that are missing prerequisite data needed to apply them, e.g. when updates arrive out
+	/// of order
+       LightningError(LightningError),
+}
+
+impl From<std::io::Error> for GraphSyncError {
+       fn from(error: std::io::Error) -> Self {
+               Self::DecodeError(DecodeError::Io(error.kind()))
+       }
+}
+
+impl From<DecodeError> for GraphSyncError {
+       fn from(error: DecodeError) -> Self {
+               Self::DecodeError(error)
+       }
+}
+
+impl From<LightningError> for GraphSyncError {
+       fn from(error: LightningError) -> Self {
+               Self::LightningError(error)
+       }
+}
+
+impl Debug for GraphSyncError {
+       fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+               match self {
+                       GraphSyncError::DecodeError(e) => f.write_fmt(format_args!("DecodeError: {:?}", e)),
+                       GraphSyncError::LightningError(e) => f.write_fmt(format_args!("LightningError: {:?}", e))
+               }
+       }
+}
diff --git a/lightning-rapid-gossip-sync/src/lib.rs b/lightning-rapid-gossip-sync/src/lib.rs
new file mode 100644 (file)
index 0000000..123f323
--- /dev/null
@@ -0,0 +1,244 @@
+#![deny(missing_docs)]
+#![deny(unsafe_code)]
+#![deny(broken_intra_doc_links)]
+#![deny(non_upper_case_globals)]
+#![deny(non_camel_case_types)]
+#![deny(non_snake_case)]
+#![deny(unused_mut)]
+#![deny(unused_variables)]
+#![deny(unused_imports)]
+//! This crate exposes functionality to rapidly sync gossip data, aimed primarily at mobile
+//! devices.
+//!
+//! The server sends a compressed response containing differential gossip data. The gossip data is
+//! formatted compactly, omitting signatures and opportunistically incremental where previous
+//! channel updates are known (a mechanism that is enabled when the timestamp of the last known
+//! channel update is communicated). A reference server implementation can be found
+//! [here](https://github.com/lightningdevkit/rapid-gossip-sync-server).
+//!
+//! An example server request could look as simple as the following. Note that the first ever rapid
+//! sync should use `0` for `last_sync_timestamp`:
+//!
+//! ```shell
+//! curl -o rapid_sync.lngossip https://rapidsync.lightningdevkit.org/snapshot/<last_sync_timestamp>
+//! ```
+//!
+//! Then, call the network processing function. In this example, we process the update by reading
+//! its contents from disk, which we do by calling the `sync_network_graph_with_file_path` method:
+//!
+//! ```
+//! use bitcoin::blockdata::constants::genesis_block;
+//! use bitcoin::Network;
+//! use lightning::routing::network_graph::NetworkGraph;
+//!
+//! let block_hash = genesis_block(Network::Bitcoin).header.block_hash();
+//! let network_graph = NetworkGraph::new(block_hash);
+//! let new_last_sync_timestamp_result = lightning_rapid_gossip_sync::sync_network_graph_with_file_path(&network_graph, "./rapid_sync.lngossip");
+//! ```
+//!
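+//! If the snapshot bytes are already in memory, for example because the application fetched them
+//! over HTTP itself, the in-memory variant can be used instead. The following is only a sketch with
+//! a placeholder (empty) byte slice; error handling is elided:
+//!
+//! ```
+//! use bitcoin::blockdata::constants::genesis_block;
+//! use bitcoin::Network;
+//! use lightning::routing::network_graph::NetworkGraph;
+//!
+//! let block_hash = genesis_block(Network::Bitcoin).header.block_hash();
+//! let network_graph = NetworkGraph::new(block_hash);
+//! let snapshot_contents: &[u8] = &[]; // placeholder for the downloaded snapshot bytes
+//! let new_last_sync_timestamp_result = lightning_rapid_gossip_sync::processing::update_network_graph(&network_graph, snapshot_contents);
+//! ```
+//!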
+//! The primary benefit this syncing mechanism provides is that, given a trusted server, a
+//! low-powered client can offload the validation of gossip signatures. This enables a client to
+//! privately calculate routes for payments, and to do so much faster and earlier than it could if
+//! it had to wait for a full peer-to-peer gossip sync to complete.
+//!
+//! The rapid sync server requires trust because it could provide bogus data, though at worst this
+//! would only result in a fake network topology, which would not enable the server to steal or
+//! siphon off funds. It could, however, reduce the client's privacy by forcing all payments to be
+//! routed via channels the server controls.
+//!
+//! The way a server is meant to calculate this rapid gossip sync data is by using a `latest_seen`
+//! timestamp provided by the client. This timestamp is not included in either channel announcements
+//! or updates (not least because announcements carry no timestamps at all, only a block height);
+//! rather, it is the time at which the server saw a particular message.
+
+// Allow and import test features for benching
+#![cfg_attr(all(test, feature = "_bench_unstable"), feature(test))]
+#[cfg(all(test, feature = "_bench_unstable"))]
+extern crate test;
+
+use std::fs::File;
+
+use lightning::routing::network_graph;
+
+use crate::error::GraphSyncError;
+
+/// Error types that these functions can return
+pub mod error;
+
+/// Core functionality of this crate
+pub mod processing;
+
+/// Sync gossip data from a file.
+///
+/// Returns the last sync timestamp to be used the next time rapid sync data is queried.
+///
+/// `network_graph`: The network graph to apply the updates to
+///
+/// `sync_path`: Path to the file where the gossip update data is located
+///
+pub fn sync_network_graph_with_file_path(
+       network_graph: &network_graph::NetworkGraph,
+       sync_path: &str,
+) -> Result<u32, GraphSyncError> {
+       let mut file = File::open(sync_path)?;
+       processing::update_network_graph_from_byte_stream(&network_graph, &mut file)
+}
+
+#[cfg(test)]
+mod tests {
+       use std::fs;
+
+       use bitcoin::blockdata::constants::genesis_block;
+       use bitcoin::Network;
+
+       use lightning::ln::msgs::DecodeError;
+       use lightning::routing::network_graph::NetworkGraph;
+
+       use crate::sync_network_graph_with_file_path;
+
+       #[test]
+       fn test_sync_from_file() {
+               struct FileSyncTest {
+                       directory: String,
+               }
+
+               impl FileSyncTest {
+                       fn new(tmp_directory: &str, valid_response: &[u8]) -> FileSyncTest {
+                               let test = FileSyncTest { directory: tmp_directory.to_owned() };
+
+                               let graph_sync_test_directory = test.get_test_directory();
+                               fs::create_dir_all(graph_sync_test_directory).unwrap();
+
+                               let graph_sync_test_file = test.get_test_file_path();
+                               fs::write(&graph_sync_test_file, valid_response).unwrap();
+
+                               test
+                       }
+                       fn get_test_directory(&self) -> String {
+                               let graph_sync_test_directory = self.directory.clone() + "/graph-sync-tests";
+                               graph_sync_test_directory
+                       }
+                       fn get_test_file_path(&self) -> String {
+                               let graph_sync_test_directory = self.get_test_directory();
+                               let graph_sync_test_file = graph_sync_test_directory.to_owned() + "/test_data.lngossip";
+                               graph_sync_test_file
+                       }
+               }
+
+               impl Drop for FileSyncTest {
+                       fn drop(&mut self) {
+                               fs::remove_dir_all(self.directory.clone()).unwrap();
+                       }
+               }
+
+               // same as incremental_only_update_fails_without_prior_same_direction_updates
+               let valid_response = vec![
+                       76, 68, 75, 1, 111, 226, 140, 10, 182, 241, 179, 114, 193, 166, 162, 70, 174, 99, 247,
+                       79, 147, 30, 131, 101, 225, 90, 8, 156, 104, 214, 25, 0, 0, 0, 0, 0, 97, 227, 98, 218,
+                       0, 0, 0, 4, 2, 22, 7, 207, 206, 25, 164, 197, 231, 230, 231, 56, 102, 61, 250, 251,
+                       187, 172, 38, 46, 79, 247, 108, 44, 155, 48, 219, 238, 252, 53, 192, 6, 67, 2, 36, 125,
+                       157, 176, 223, 175, 234, 116, 94, 248, 201, 225, 97, 235, 50, 47, 115, 172, 63, 136,
+                       88, 216, 115, 11, 111, 217, 114, 84, 116, 124, 231, 107, 2, 158, 1, 242, 121, 152, 106,
+                       204, 131, 186, 35, 93, 70, 216, 10, 237, 224, 183, 89, 95, 65, 3, 83, 185, 58, 138,
+                       181, 64, 187, 103, 127, 68, 50, 2, 201, 19, 17, 138, 136, 149, 185, 226, 156, 137, 175,
+                       110, 32, 237, 0, 217, 90, 31, 100, 228, 149, 46, 219, 175, 168, 77, 4, 143, 38, 128,
+                       76, 97, 0, 0, 0, 2, 0, 0, 255, 8, 153, 192, 0, 2, 27, 0, 0, 0, 1, 0, 0, 255, 2, 68,
+                       226, 0, 6, 11, 0, 1, 2, 3, 0, 0, 0, 2, 0, 40, 0, 0, 0, 0, 0, 0, 3, 232, 0, 0, 3, 232,
+                       0, 0, 0, 1, 0, 0, 0, 0, 58, 85, 116, 216, 255, 8, 153, 192, 0, 2, 27, 0, 0, 25, 0, 0,
+                       0, 1, 0, 0, 0, 125, 255, 2, 68, 226, 0, 6, 11, 0, 1, 5, 0, 0, 0, 0, 29, 129, 25, 192,
+               ];
+
+               let tmp_directory = "./rapid-gossip-sync-tests-tmp";
+               let sync_test = FileSyncTest::new(tmp_directory, &valid_response);
+               let graph_sync_test_file = sync_test.get_test_file_path();
+
+               let block_hash = genesis_block(Network::Bitcoin).block_hash();
+               let network_graph = NetworkGraph::new(block_hash);
+
+               assert_eq!(network_graph.read_only().channels().len(), 0);
+
+               let sync_result = sync_network_graph_with_file_path(&network_graph, &graph_sync_test_file);
+
+               if sync_result.is_err() {
+                       panic!("Unexpected sync result: {:?}", sync_result)
+               }
+
+               assert_eq!(network_graph.read_only().channels().len(), 2);
+               let after = network_graph.to_string();
+               assert!(
+                       after.contains("021607cfce19a4c5e7e6e738663dfafbbbac262e4ff76c2c9b30dbeefc35c00643")
+               );
+               assert!(
+                       after.contains("02247d9db0dfafea745ef8c9e161eb322f73ac3f8858d8730b6fd97254747ce76b")
+               );
+               assert!(
+                       after.contains("029e01f279986acc83ba235d46d80aede0b7595f410353b93a8ab540bb677f4432")
+               );
+               assert!(
+                       after.contains("02c913118a8895b9e29c89af6e20ed00d95a1f64e4952edbafa84d048f26804c61")
+               );
+               assert!(after.contains("619737530008010752"));
+               assert!(after.contains("783241506229452801"));
+       }
+
+       #[test]
+       fn measure_native_read_from_file() {
+               let block_hash = genesis_block(Network::Bitcoin).block_hash();
+               let network_graph = NetworkGraph::new(block_hash);
+
+               assert_eq!(network_graph.read_only().channels().len(), 0);
+
+               let start = std::time::Instant::now();
+               let sync_result =
+                       sync_network_graph_with_file_path(&network_graph, "./res/full_graph.lngossip");
+               if let Err(crate::error::GraphSyncError::DecodeError(DecodeError::Io(io_error))) = &sync_result {
+			let error_string = format!("Input file lightning-rapid-gossip-sync/res/full_graph.lngossip is missing! Download it from https://bitcoin.ninja/ldk-compressed_graph-bc08df7542-2022-05-05.bin\n\n{:?}", io_error);
+                       #[cfg(not(require_route_graph_test))]
+                       {
+                               println!("{}", error_string);
+                               return;
+                       }
+                       #[cfg(require_route_graph_test)]
+                       panic!("{}", error_string);
+               }
+               let elapsed = start.elapsed();
+               println!("initialization duration: {:?}", elapsed);
+               if sync_result.is_err() {
+                       panic!("Unexpected sync result: {:?}", sync_result)
+               }
+       }
+}
+
+#[cfg(all(test, feature = "_bench_unstable"))]
+pub mod bench {
+       use test::Bencher;
+
+       use bitcoin::blockdata::constants::genesis_block;
+       use bitcoin::Network;
+
+       use lightning::ln::msgs::DecodeError;
+       use lightning::routing::network_graph::NetworkGraph;
+
+       use crate::sync_network_graph_with_file_path;
+
+       #[bench]
+       fn bench_reading_full_graph_from_file(b: &mut Bencher) {
+               let block_hash = genesis_block(Network::Bitcoin).block_hash();
+               b.iter(|| {
+                       let network_graph = NetworkGraph::new(block_hash);
+                       let sync_result = sync_network_graph_with_file_path(
+                               &network_graph,
+                               "./res/full_graph.lngossip",
+                       );
+                       if let Err(crate::error::GraphSyncError::DecodeError(DecodeError::Io(io_error))) = &sync_result {
+				let error_string = format!("Input file lightning-rapid-gossip-sync/res/full_graph.lngossip is missing! Download it from https://bitcoin.ninja/ldk-compressed_graph-bc08df7542-2022-05-05.bin\n\n{:?}", io_error);
+                               #[cfg(not(require_route_graph_test))]
+                               {
+                                       println!("{}", error_string);
+                                       return;
+                               }
+                               panic!("{}", error_string);
+                       }
+                       assert!(sync_result.is_ok())
+               });
+       }
+}
diff --git a/lightning-rapid-gossip-sync/src/processing.rs b/lightning-rapid-gossip-sync/src/processing.rs
new file mode 100644 (file)
index 0000000..ceb8b82
--- /dev/null
@@ -0,0 +1,499 @@
+use std::cmp::max;
+use std::io;
+use std::io::Read;
+
+use bitcoin::BlockHash;
+use bitcoin::secp256k1::PublicKey;
+
+use lightning::ln::msgs::{
+       DecodeError, ErrorAction, LightningError, OptionalField, UnsignedChannelUpdate,
+};
+use lightning::routing::network_graph;
+use lightning::util::ser::{BigSize, Readable};
+
+use crate::error::GraphSyncError;
+
+/// The purpose of this prefix is to identify the serialization format, should other rapid gossip
+/// sync formats arise in the future.
+///
+/// The fourth byte is the protocol version in case our format gets updated.
+const GOSSIP_PREFIX: [u8; 4] = [76, 68, 75, 1];
+
+/// Maximum vector allocation capacity for distinct node IDs. This constraint is necessary to
+/// avoid malicious updates being able to trigger excessive memory allocation.
+const MAX_INITIAL_NODE_ID_VECTOR_CAPACITY: u32 = 50_000;
+
+/// Update network graph from binary data.
+/// Returns the last sync timestamp to be used the next time rapid sync data is queried.
+///
+/// `network_graph`: network graph to be updated
+///
+/// `update_data`: `&[u8]` binary stream that comprises the update data
+pub fn update_network_graph(
+       network_graph: &network_graph::NetworkGraph,
+       update_data: &[u8],
+) -> Result<u32, GraphSyncError> {
+       let mut read_cursor = io::Cursor::new(update_data);
+       update_network_graph_from_byte_stream(&network_graph, &mut read_cursor)
+}
+
+pub(crate) fn update_network_graph_from_byte_stream<R: Read>(
+       network_graph: &network_graph::NetworkGraph,
+       mut read_cursor: &mut R,
+) -> Result<u32, GraphSyncError> {
+       let mut prefix = [0u8; 4];
+       read_cursor.read_exact(&mut prefix)?;
+
+       match prefix {
+               GOSSIP_PREFIX => {},
+               _ => {
+                       return Err(DecodeError::UnknownVersion.into());
+               }
+       };
+
+       let chain_hash: BlockHash = Readable::read(read_cursor)?;
+       let latest_seen_timestamp: u32 = Readable::read(read_cursor)?;
+       // backdate the applied timestamp by a week
+       let backdated_timestamp = latest_seen_timestamp.saturating_sub(24 * 3600 * 7);
+
+       let node_id_count: u32 = Readable::read(read_cursor)?;
+       let mut node_ids: Vec<PublicKey> = Vec::with_capacity(std::cmp::min(
+               node_id_count,
+               MAX_INITIAL_NODE_ID_VECTOR_CAPACITY,
+       ) as usize);
+       for _ in 0..node_id_count {
+               let current_node_id = Readable::read(read_cursor)?;
+               node_ids.push(current_node_id);
+       }
+
+       let mut previous_scid: u64 = 0;
+       let announcement_count: u32 = Readable::read(read_cursor)?;
+       for _ in 0..announcement_count {
+               let features = Readable::read(read_cursor)?;
+
+               // handle SCID
+               let scid_delta: BigSize = Readable::read(read_cursor)?;
+               let short_channel_id = previous_scid
+                       .checked_add(scid_delta.0)
+                       .ok_or(DecodeError::InvalidValue)?;
+               previous_scid = short_channel_id;
+
+               let node_id_1_index: BigSize = Readable::read(read_cursor)?;
+               let node_id_2_index: BigSize = Readable::read(read_cursor)?;
+               if max(node_id_1_index.0, node_id_2_index.0) >= node_id_count as u64 {
+                       return Err(DecodeError::InvalidValue.into());
+               };
+               let node_id_1 = node_ids[node_id_1_index.0 as usize];
+               let node_id_2 = node_ids[node_id_2_index.0 as usize];
+
+               let announcement_result = network_graph.add_channel_from_partial_announcement(
+                       short_channel_id,
+                       backdated_timestamp as u64,
+                       features,
+                       node_id_1,
+                       node_id_2,
+               );
+               if let Err(lightning_error) = announcement_result {
+                       if let ErrorAction::IgnoreDuplicateGossip = lightning_error.action {
+                               // everything is fine, just a duplicate channel announcement
+                       } else {
+                               return Err(lightning_error.into());
+                       }
+               }
+       }
+
+       previous_scid = 0; // updates start at a new scid
+
+       let update_count: u32 = Readable::read(read_cursor)?;
+       if update_count == 0 {
+               return Ok(latest_seen_timestamp);
+       }
+
+       // obtain default values for non-incremental updates
+       let default_cltv_expiry_delta: u16 = Readable::read(&mut read_cursor)?;
+       let default_htlc_minimum_msat: u64 = Readable::read(&mut read_cursor)?;
+       let default_fee_base_msat: u32 = Readable::read(&mut read_cursor)?;
+       let default_fee_proportional_millionths: u32 = Readable::read(&mut read_cursor)?;
+       let tentative_default_htlc_maximum_msat: u64 = Readable::read(&mut read_cursor)?;
+       let default_htlc_maximum_msat = if tentative_default_htlc_maximum_msat == u64::max_value() {
+               OptionalField::Absent
+       } else {
+               OptionalField::Present(tentative_default_htlc_maximum_msat)
+       };
+
+       for _ in 0..update_count {
+               let scid_delta: BigSize = Readable::read(read_cursor)?;
+               let short_channel_id = previous_scid
+                       .checked_add(scid_delta.0)
+                       .ok_or(DecodeError::InvalidValue)?;
+               previous_scid = short_channel_id;
+
+               let channel_flags: u8 = Readable::read(read_cursor)?;
+
+               // flags are always sent in full, and hence always need updating
+               let standard_channel_flags = channel_flags & 0b_0000_0011;
+
+               let mut synthetic_update = if channel_flags & 0b_1000_0000 == 0 {
+                       // full update, field flags will indicate deviations from the default
+                       UnsignedChannelUpdate {
+                               chain_hash,
+                               short_channel_id,
+                               timestamp: backdated_timestamp,
+                               flags: standard_channel_flags,
+                               cltv_expiry_delta: default_cltv_expiry_delta,
+                               htlc_minimum_msat: default_htlc_minimum_msat,
+                               htlc_maximum_msat: default_htlc_maximum_msat.clone(),
+                               fee_base_msat: default_fee_base_msat,
+                               fee_proportional_millionths: default_fee_proportional_millionths,
+                               excess_data: vec![],
+                       }
+               } else {
+                       // incremental update, field flags will indicate mutated values
+                       let read_only_network_graph = network_graph.read_only();
+                       let channel = read_only_network_graph
+                               .channels()
+                               .get(&short_channel_id)
+                               .ok_or(LightningError {
+                                       err: "Couldn't find channel for update".to_owned(),
+                                       action: ErrorAction::IgnoreError,
+                               })?;
+
+                       let directional_info = channel
+                               .get_directional_info(channel_flags)
+                               .ok_or(LightningError {
+                                       err: "Couldn't find previous directional data for update".to_owned(),
+                                       action: ErrorAction::IgnoreError,
+                               })?;
+
+                       let htlc_maximum_msat =
+                               if let Some(htlc_maximum_msat) = directional_info.htlc_maximum_msat {
+                                       OptionalField::Present(htlc_maximum_msat)
+                               } else {
+                                       OptionalField::Absent
+                               };
+
+                       UnsignedChannelUpdate {
+                               chain_hash,
+                               short_channel_id,
+                               timestamp: backdated_timestamp,
+                               flags: standard_channel_flags,
+                               cltv_expiry_delta: directional_info.cltv_expiry_delta,
+                               htlc_minimum_msat: directional_info.htlc_minimum_msat,
+                               htlc_maximum_msat,
+                               fee_base_msat: directional_info.fees.base_msat,
+                               fee_proportional_millionths: directional_info.fees.proportional_millionths,
+                               excess_data: vec![],
+                       }
+               };
+
+               if channel_flags & 0b_0100_0000 > 0 {
+                       let cltv_expiry_delta: u16 = Readable::read(read_cursor)?;
+                       synthetic_update.cltv_expiry_delta = cltv_expiry_delta;
+               }
+
+               if channel_flags & 0b_0010_0000 > 0 {
+                       let htlc_minimum_msat: u64 = Readable::read(read_cursor)?;
+                       synthetic_update.htlc_minimum_msat = htlc_minimum_msat;
+               }
+
+               if channel_flags & 0b_0001_0000 > 0 {
+                       let fee_base_msat: u32 = Readable::read(read_cursor)?;
+                       synthetic_update.fee_base_msat = fee_base_msat;
+               }
+
+               if channel_flags & 0b_0000_1000 > 0 {
+                       let fee_proportional_millionths: u32 = Readable::read(read_cursor)?;
+                       synthetic_update.fee_proportional_millionths = fee_proportional_millionths;
+               }
+
+               if channel_flags & 0b_0000_0100 > 0 {
+                       let tentative_htlc_maximum_msat: u64 = Readable::read(read_cursor)?;
+                       synthetic_update.htlc_maximum_msat = if tentative_htlc_maximum_msat == u64::max_value()
+                       {
+                               OptionalField::Absent
+                       } else {
+                               OptionalField::Present(tentative_htlc_maximum_msat)
+                       };
+               }
+
+               network_graph.update_channel_unsigned(&synthetic_update)?;
+       }
+
+       Ok(latest_seen_timestamp)
+}
+
+#[cfg(test)]
+mod tests {
+       use bitcoin::blockdata::constants::genesis_block;
+       use bitcoin::Network;
+
+       use lightning::ln::msgs::DecodeError;
+       use lightning::routing::network_graph::NetworkGraph;
+
+       use crate::error::GraphSyncError;
+       use crate::processing::update_network_graph;
+
+       #[test]
+       fn network_graph_fails_to_update_from_clipped_input() {
+               let block_hash = genesis_block(Network::Bitcoin).block_hash();
+               let network_graph = NetworkGraph::new(block_hash);
+
+               let example_input = vec![
+                       76, 68, 75, 1, 111, 226, 140, 10, 182, 241, 179, 114, 193, 166, 162, 70, 174, 99, 247,
+                       79, 147, 30, 131, 101, 225, 90, 8, 156, 104, 214, 25, 0, 0, 0, 0, 0, 97, 227, 98, 218,
+                       0, 0, 0, 4, 2, 22, 7, 207, 206, 25, 164, 197, 231, 230, 231, 56, 102, 61, 250, 251,
+                       187, 172, 38, 46, 79, 247, 108, 44, 155, 48, 219, 238, 252, 53, 192, 6, 67, 2, 36, 125,
+                       157, 176, 223, 175, 234, 116, 94, 248, 201, 225, 97, 235, 50, 47, 115, 172, 63, 136,
+                       88, 216, 115, 11, 111, 217, 114, 84, 116, 124, 231, 107, 2, 158, 1, 242, 121, 152, 106,
+                       204, 131, 186, 35, 93, 70, 216, 10, 237, 224, 183, 89, 95, 65, 3, 83, 185, 58, 138,
+                       181, 64, 187, 103, 127, 68, 50, 2, 201, 19, 17, 138, 136, 149, 185, 226, 156, 137, 175,
+                       110, 32, 237, 0, 217, 90, 31, 100, 228, 149, 46, 219, 175, 168, 77, 4, 143, 38, 128,
+                       76, 97, 0, 0, 0, 2, 0, 0, 255, 8, 153, 192, 0, 2, 27, 0, 0, 0, 1, 0, 0, 255, 2, 68,
+                       226, 0, 6, 11, 0, 1, 2, 3, 0, 0, 0, 2, 0, 40, 0, 0, 0, 0, 0, 0, 3, 232, 0, 0, 0, 100,
+                       0, 0, 2, 224, 0, 0, 0, 0, 29, 129, 25, 192, 255, 8, 153, 192, 0, 2, 27, 0, 0, 36, 0, 0,
+                       0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 58, 85, 116, 216, 255, 2, 68, 226, 0, 6, 11, 0, 1, 24, 0,
+                       0, 3, 232, 0, 0, 0,
+               ];
+               let update_result = update_network_graph(&network_graph, &example_input[..]);
+               assert!(update_result.is_err());
+               if let Err(GraphSyncError::DecodeError(DecodeError::ShortRead)) = update_result {
+                       // this is the expected error type
+               } else {
+                       panic!("Unexpected update result: {:?}", update_result)
+               }
+       }
+
+       #[test]
+       fn incremental_only_update_fails_without_prior_announcements() {
+               let incremental_update_input = vec![
+                       76, 68, 75, 1, 111, 226, 140, 10, 182, 241, 179, 114, 193, 166, 162, 70, 174, 99, 247,
+                       79, 147, 30, 131, 101, 225, 90, 8, 156, 104, 214, 25, 0, 0, 0, 0, 0, 97, 229, 183, 167,
+                       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                       0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 8, 153, 192, 0, 2, 27, 0, 0, 136, 0, 0, 0, 221, 255, 2,
+                       68, 226, 0, 6, 11, 0, 1, 128,
+               ];
+
+               let block_hash = genesis_block(Network::Bitcoin).block_hash();
+               let network_graph = NetworkGraph::new(block_hash);
+
+               assert_eq!(network_graph.read_only().channels().len(), 0);
+
+               let update_result = update_network_graph(&network_graph, &incremental_update_input[..]);
+               assert!(update_result.is_err());
+               if let Err(GraphSyncError::LightningError(lightning_error)) = update_result {
+                       assert_eq!(lightning_error.err, "Couldn't find channel for update");
+               } else {
+                       panic!("Unexpected update result: {:?}", update_result)
+               }
+       }
+
+       #[test]
+       fn incremental_only_update_fails_without_prior_updates() {
+               let announced_update_input = vec![
+                       76, 68, 75, 1, 111, 226, 140, 10, 182, 241, 179, 114, 193, 166, 162, 70, 174, 99, 247,
+                       79, 147, 30, 131, 101, 225, 90, 8, 156, 104, 214, 25, 0, 0, 0, 0, 0, 97, 229, 183, 167,
+                       0, 0, 0, 4, 2, 22, 7, 207, 206, 25, 164, 197, 231, 230, 231, 56, 102, 61, 250, 251,
+                       187, 172, 38, 46, 79, 247, 108, 44, 155, 48, 219, 238, 252, 53, 192, 6, 67, 2, 36, 125,
+                       157, 176, 223, 175, 234, 116, 94, 248, 201, 225, 97, 235, 50, 47, 115, 172, 63, 136,
+                       88, 216, 115, 11, 111, 217, 114, 84, 116, 124, 231, 107, 2, 158, 1, 242, 121, 152, 106,
+                       204, 131, 186, 35, 93, 70, 216, 10, 237, 224, 183, 89, 95, 65, 3, 83, 185, 58, 138,
+                       181, 64, 187, 103, 127, 68, 50, 2, 201, 19, 17, 138, 136, 149, 185, 226, 156, 137, 175,
+                       110, 32, 237, 0, 217, 90, 31, 100, 228, 149, 46, 219, 175, 168, 77, 4, 143, 38, 128,
+                       76, 97, 0, 0, 0, 2, 0, 0, 255, 8, 153, 192, 0, 2, 27, 0, 0, 0, 1, 0, 0, 255, 2, 68,
+                       226, 0, 6, 11, 0, 1, 2, 3, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 8, 153, 192, 0, 2, 27, 0, 0, 136, 0, 0, 0, 221, 255,
+                       2, 68, 226, 0, 6, 11, 0, 1, 128,
+               ];
+
+               let block_hash = genesis_block(Network::Bitcoin).block_hash();
+               let network_graph = NetworkGraph::new(block_hash);
+
+               assert_eq!(network_graph.read_only().channels().len(), 0);
+
+               let update_result = update_network_graph(&network_graph, &announced_update_input[..]);
+               assert!(update_result.is_err());
+               if let Err(GraphSyncError::LightningError(lightning_error)) = update_result {
+                       assert_eq!(
+                               lightning_error.err,
+                               "Couldn't find previous directional data for update"
+                       );
+               } else {
+                       panic!("Unexpected update result: {:?}", update_result)
+               }
+       }
+
+       #[test]
+       fn incremental_only_update_fails_without_prior_same_direction_updates() {
+               let initialization_input = vec![
+                       76, 68, 75, 1, 111, 226, 140, 10, 182, 241, 179, 114, 193, 166, 162, 70, 174, 99, 247,
+                       79, 147, 30, 131, 101, 225, 90, 8, 156, 104, 214, 25, 0, 0, 0, 0, 0, 97, 227, 98, 218,
+                       0, 0, 0, 4, 2, 22, 7, 207, 206, 25, 164, 197, 231, 230, 231, 56, 102, 61, 250, 251,
+                       187, 172, 38, 46, 79, 247, 108, 44, 155, 48, 219, 238, 252, 53, 192, 6, 67, 2, 36, 125,
+                       157, 176, 223, 175, 234, 116, 94, 248, 201, 225, 97, 235, 50, 47, 115, 172, 63, 136,
+                       88, 216, 115, 11, 111, 217, 114, 84, 116, 124, 231, 107, 2, 158, 1, 242, 121, 152, 106,
+                       204, 131, 186, 35, 93, 70, 216, 10, 237, 224, 183, 89, 95, 65, 3, 83, 185, 58, 138,
+                       181, 64, 187, 103, 127, 68, 50, 2, 201, 19, 17, 138, 136, 149, 185, 226, 156, 137, 175,
+                       110, 32, 237, 0, 217, 90, 31, 100, 228, 149, 46, 219, 175, 168, 77, 4, 143, 38, 128,
+                       76, 97, 0, 0, 0, 2, 0, 0, 255, 8, 153, 192, 0, 2, 27, 0, 0, 0, 1, 0, 0, 255, 2, 68,
+                       226, 0, 6, 11, 0, 1, 2, 3, 0, 0, 0, 2, 0, 40, 0, 0, 0, 0, 0, 0, 3, 232, 0, 0, 3, 232,
+                       0, 0, 0, 1, 0, 0, 0, 0, 58, 85, 116, 216, 255, 8, 153, 192, 0, 2, 27, 0, 0, 25, 0, 0,
+                       0, 1, 0, 0, 0, 125, 255, 2, 68, 226, 0, 6, 11, 0, 1, 5, 0, 0, 0, 0, 29, 129, 25, 192,
+               ];
+
+               let block_hash = genesis_block(Network::Bitcoin).block_hash();
+               let network_graph = NetworkGraph::new(block_hash);
+
+               assert_eq!(network_graph.read_only().channels().len(), 0);
+
+               let initialization_result = update_network_graph(&network_graph, &initialization_input[..]);
+               if initialization_result.is_err() {
+                       panic!(
+                               "Unexpected initialization result: {:?}",
+                               initialization_result
+                       )
+               }
+
+               assert_eq!(network_graph.read_only().channels().len(), 2);
+               let initialized = network_graph.to_string();
+               assert!(initialized
+                       .contains("021607cfce19a4c5e7e6e738663dfafbbbac262e4ff76c2c9b30dbeefc35c00643"));
+               assert!(initialized
+                       .contains("02247d9db0dfafea745ef8c9e161eb322f73ac3f8858d8730b6fd97254747ce76b"));
+               assert!(initialized
+                       .contains("029e01f279986acc83ba235d46d80aede0b7595f410353b93a8ab540bb677f4432"));
+               assert!(initialized
+                       .contains("02c913118a8895b9e29c89af6e20ed00d95a1f64e4952edbafa84d048f26804c61"));
+               assert!(initialized.contains("619737530008010752"));
+               assert!(initialized.contains("783241506229452801"));
+
+               let opposite_direction_incremental_update_input = vec![
+                       76, 68, 75, 1, 111, 226, 140, 10, 182, 241, 179, 114, 193, 166, 162, 70, 174, 99, 247,
+                       79, 147, 30, 131, 101, 225, 90, 8, 156, 104, 214, 25, 0, 0, 0, 0, 0, 97, 229, 183, 167,
+                       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                       0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 8, 153, 192, 0, 2, 27, 0, 0, 136, 0, 0, 0, 221, 255, 2,
+                       68, 226, 0, 6, 11, 0, 1, 128,
+               ];
+               let update_result = update_network_graph(
+                       &network_graph,
+                       &opposite_direction_incremental_update_input[..],
+               );
+               assert!(update_result.is_err());
+               if let Err(GraphSyncError::LightningError(lightning_error)) = update_result {
+                       assert_eq!(
+                               lightning_error.err,
+                               "Couldn't find previous directional data for update"
+                       );
+               } else {
+                       panic!("Unexpected update result: {:?}", update_result)
+               }
+       }
+
+       #[test]
+       fn incremental_update_succeeds_with_prior_announcements_and_full_updates() {
+               let initialization_input = vec![
+                       76, 68, 75, 1, 111, 226, 140, 10, 182, 241, 179, 114, 193, 166, 162, 70, 174, 99, 247,
+                       79, 147, 30, 131, 101, 225, 90, 8, 156, 104, 214, 25, 0, 0, 0, 0, 0, 97, 227, 98, 218,
+                       0, 0, 0, 4, 2, 22, 7, 207, 206, 25, 164, 197, 231, 230, 231, 56, 102, 61, 250, 251,
+                       187, 172, 38, 46, 79, 247, 108, 44, 155, 48, 219, 238, 252, 53, 192, 6, 67, 2, 36, 125,
+                       157, 176, 223, 175, 234, 116, 94, 248, 201, 225, 97, 235, 50, 47, 115, 172, 63, 136,
+                       88, 216, 115, 11, 111, 217, 114, 84, 116, 124, 231, 107, 2, 158, 1, 242, 121, 152, 106,
+                       204, 131, 186, 35, 93, 70, 216, 10, 237, 224, 183, 89, 95, 65, 3, 83, 185, 58, 138,
+                       181, 64, 187, 103, 127, 68, 50, 2, 201, 19, 17, 138, 136, 149, 185, 226, 156, 137, 175,
+                       110, 32, 237, 0, 217, 90, 31, 100, 228, 149, 46, 219, 175, 168, 77, 4, 143, 38, 128,
+                       76, 97, 0, 0, 0, 2, 0, 0, 255, 8, 153, 192, 0, 2, 27, 0, 0, 0, 1, 0, 0, 255, 2, 68,
+                       226, 0, 6, 11, 0, 1, 2, 3, 0, 0, 0, 4, 0, 40, 0, 0, 0, 0, 0, 0, 3, 232, 0, 0, 3, 232,
+                       0, 0, 0, 1, 0, 0, 0, 0, 58, 85, 116, 216, 255, 8, 153, 192, 0, 2, 27, 0, 0, 56, 0, 0,
+                       0, 0, 0, 0, 0, 1, 0, 0, 0, 100, 0, 0, 2, 224, 0, 25, 0, 0, 0, 1, 0, 0, 0, 125, 255, 2,
+                       68, 226, 0, 6, 11, 0, 1, 4, 0, 0, 0, 0, 29, 129, 25, 192, 0, 5, 0, 0, 0, 0, 29, 129,
+                       25, 192,
+               ];
+
+               let block_hash = genesis_block(Network::Bitcoin).block_hash();
+               let network_graph = NetworkGraph::new(block_hash);
+
+               assert_eq!(network_graph.read_only().channels().len(), 0);
+
+               let initialization_result = update_network_graph(&network_graph, &initialization_input[..]);
+               assert!(initialization_result.is_ok());
+
+               let single_direction_incremental_update_input = vec![
+                       76, 68, 75, 1, 111, 226, 140, 10, 182, 241, 179, 114, 193, 166, 162, 70, 174, 99, 247,
+                       79, 147, 30, 131, 101, 225, 90, 8, 156, 104, 214, 25, 0, 0, 0, 0, 0, 97, 229, 183, 167,
+                       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                       0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 8, 153, 192, 0, 2, 27, 0, 0, 136, 0, 0, 0, 221, 255, 2,
+                       68, 226, 0, 6, 11, 0, 1, 128,
+               ];
+               let update_result = update_network_graph(
+                       &network_graph,
+                       &single_direction_incremental_update_input[..],
+               );
+               if update_result.is_err() {
+                       panic!("Unexpected update result: {:?}", update_result)
+               }
+
+               assert_eq!(network_graph.read_only().channels().len(), 2);
+               let after = network_graph.to_string();
+               assert!(
+                       after.contains("021607cfce19a4c5e7e6e738663dfafbbbac262e4ff76c2c9b30dbeefc35c00643")
+               );
+               assert!(
+                       after.contains("02247d9db0dfafea745ef8c9e161eb322f73ac3f8858d8730b6fd97254747ce76b")
+               );
+               assert!(
+                       after.contains("029e01f279986acc83ba235d46d80aede0b7595f410353b93a8ab540bb677f4432")
+               );
+               assert!(
+                       after.contains("02c913118a8895b9e29c89af6e20ed00d95a1f64e4952edbafa84d048f26804c61")
+               );
+               assert!(after.contains("619737530008010752"));
+               assert!(after.contains("783241506229452801"));
+       }
+
+       #[test]
+       fn full_update_succeeds() {
+               let valid_input = vec![
+                       76, 68, 75, 1, 111, 226, 140, 10, 182, 241, 179, 114, 193, 166, 162, 70, 174, 99, 247,
+                       79, 147, 30, 131, 101, 225, 90, 8, 156, 104, 214, 25, 0, 0, 0, 0, 0, 97, 227, 98, 218,
+                       0, 0, 0, 4, 2, 22, 7, 207, 206, 25, 164, 197, 231, 230, 231, 56, 102, 61, 250, 251,
+                       187, 172, 38, 46, 79, 247, 108, 44, 155, 48, 219, 238, 252, 53, 192, 6, 67, 2, 36, 125,
+                       157, 176, 223, 175, 234, 116, 94, 248, 201, 225, 97, 235, 50, 47, 115, 172, 63, 136,
+                       88, 216, 115, 11, 111, 217, 114, 84, 116, 124, 231, 107, 2, 158, 1, 242, 121, 152, 106,
+                       204, 131, 186, 35, 93, 70, 216, 10, 237, 224, 183, 89, 95, 65, 3, 83, 185, 58, 138,
+                       181, 64, 187, 103, 127, 68, 50, 2, 201, 19, 17, 138, 136, 149, 185, 226, 156, 137, 175,
+                       110, 32, 237, 0, 217, 90, 31, 100, 228, 149, 46, 219, 175, 168, 77, 4, 143, 38, 128,
+                       76, 97, 0, 0, 0, 2, 0, 0, 255, 8, 153, 192, 0, 2, 27, 0, 0, 0, 1, 0, 0, 255, 2, 68,
+                       226, 0, 6, 11, 0, 1, 2, 3, 0, 0, 0, 4, 0, 40, 0, 0, 0, 0, 0, 0, 3, 232, 0, 0, 3, 232,
+                       0, 0, 0, 1, 0, 0, 0, 0, 29, 129, 25, 192, 255, 8, 153, 192, 0, 2, 27, 0, 0, 60, 0, 0,
+                       0, 0, 0, 0, 0, 1, 0, 0, 0, 100, 0, 0, 2, 224, 0, 0, 0, 0, 58, 85, 116, 216, 0, 29, 0,
+                       0, 0, 1, 0, 0, 0, 125, 0, 0, 0, 0, 58, 85, 116, 216, 255, 2, 68, 226, 0, 6, 11, 0, 1,
+                       0, 0, 1,
+               ];
+
+               let block_hash = genesis_block(Network::Bitcoin).block_hash();
+               let network_graph = NetworkGraph::new(block_hash);
+
+               assert_eq!(network_graph.read_only().channels().len(), 0);
+
+               let update_result = update_network_graph(&network_graph, &valid_input[..]);
+               if update_result.is_err() {
+                       panic!("Unexpected update result: {:?}", update_result)
+               }
+
+               assert_eq!(network_graph.read_only().channels().len(), 2);
+               let after = network_graph.to_string();
+               assert!(
+                       after.contains("021607cfce19a4c5e7e6e738663dfafbbbac262e4ff76c2c9b30dbeefc35c00643")
+               );
+               assert!(
+                       after.contains("02247d9db0dfafea745ef8c9e161eb322f73ac3f8858d8730b6fd97254747ce76b")
+               );
+               assert!(
+                       after.contains("029e01f279986acc83ba235d46d80aede0b7595f410353b93a8ab540bb677f4432")
+               );
+               assert!(
+                       after.contains("02c913118a8895b9e29c89af6e20ed00d95a1f64e4952edbafa84d048f26804c61")
+               );
+               assert!(after.contains("619737530008010752"));
+               assert!(after.contains("783241506229452801"));
+       }
+}
index aae260e735bdbae5c0a538af8024e7e2ef2a78cb..503e6bdee0669551d1853932447a5db08fc92c17 100644 (file)
@@ -235,7 +235,7 @@ pub struct ChainMonitor<ChannelSigner: Sign, C: Deref, T: Deref, F: Deref, L: De
        persister: P,
        /// "User-provided" (ie persistence-completion/-failed) [`MonitorEvent`]s. These came directly
        /// from the user and not from a [`ChannelMonitor`].
-       pending_monitor_events: Mutex<Vec<MonitorEvent>>,
+       pending_monitor_events: Mutex<Vec<(OutPoint, Vec<MonitorEvent>)>>,
        /// The best block height seen, used as a proxy for the passage of time.
        highest_chain_height: AtomicUsize,
 }
@@ -299,7 +299,7 @@ where C::Target: chain::Filter,
                                                        log_trace!(self.logger, "Finished syncing Channel Monitor for channel {}", log_funding_info!(monitor)),
                                                Err(ChannelMonitorUpdateErr::PermanentFailure) => {
                                                        monitor_state.channel_perm_failed.store(true, Ordering::Release);
-                                                       self.pending_monitor_events.lock().unwrap().push(MonitorEvent::UpdateFailed(*funding_outpoint));
+                                                       self.pending_monitor_events.lock().unwrap().push((*funding_outpoint, vec![MonitorEvent::UpdateFailed(*funding_outpoint)]));
                                                },
                                                Err(ChannelMonitorUpdateErr::TemporaryFailure) => {
                                                        log_debug!(self.logger, "Channel Monitor sync for channel {} in progress, holding events until completion!", log_funding_info!(monitor));
@@ -455,10 +455,10 @@ where C::Target: chain::Filter,
                                        // UpdateCompleted event.
                                        return Ok(());
                                }
-                               self.pending_monitor_events.lock().unwrap().push(MonitorEvent::UpdateCompleted {
+                               self.pending_monitor_events.lock().unwrap().push((funding_txo, vec![MonitorEvent::UpdateCompleted {
                                        funding_txo,
                                        monitor_update_id: monitor_data.monitor.get_latest_update_id(),
-                               });
+                               }]));
                        },
                        MonitorUpdateId { contents: UpdateOrigin::ChainSync(_) } => {
                                if !monitor_data.has_pending_chainsync_updates(&pending_monitor_updates) {
@@ -476,10 +476,10 @@ where C::Target: chain::Filter,
        /// channel_monitor_updated once with the highest ID.
        #[cfg(any(test, fuzzing))]
        pub fn force_channel_monitor_updated(&self, funding_txo: OutPoint, monitor_update_id: u64) {
-               self.pending_monitor_events.lock().unwrap().push(MonitorEvent::UpdateCompleted {
+               self.pending_monitor_events.lock().unwrap().push((funding_txo, vec![MonitorEvent::UpdateCompleted {
                        funding_txo,
                        monitor_update_id,
-               });
+               }]));
        }
 
        #[cfg(any(test, fuzzing, feature = "_test_utils"))]
@@ -666,7 +666,7 @@ where C::Target: chain::Filter,
                }
        }
 
-       fn release_pending_monitor_events(&self) -> Vec<MonitorEvent> {
+       fn release_pending_monitor_events(&self) -> Vec<(OutPoint, Vec<MonitorEvent>)> {
                let mut pending_monitor_events = self.pending_monitor_events.lock().unwrap().split_off(0);
                for monitor_state in self.monitors.read().unwrap().values() {
                        let is_pending_monitor_update = monitor_state.has_pending_chainsync_updates(&monitor_state.pending_monitor_updates.lock().unwrap());
@@ -692,7 +692,11 @@ where C::Target: chain::Filter,
                                        log_error!(self.logger, "   To avoid funds-loss, we are allowing monitor updates to be released.");
                                        log_error!(self.logger, "   This may cause duplicate payment events to be generated.");
                                }
-                               pending_monitor_events.append(&mut monitor_state.monitor.get_and_clear_pending_monitor_events());
+                               let monitor_events = monitor_state.monitor.get_and_clear_pending_monitor_events();
+                               if monitor_events.len() > 0 {
+                                       let monitor_outpoint = monitor_state.monitor.get_funding_txo().0;
+                                       pending_monitor_events.push((monitor_outpoint, monitor_events));
+                               }
                        }
                }
                pending_monitor_events
index 738fff3837ca045f5c8c9e3e93e588bb88974191..62dc48607f6a17ecb96ccf7ba268903e7d16c7a5 100644 (file)
@@ -166,11 +166,11 @@ pub struct HTLCUpdate {
        pub(crate) payment_hash: PaymentHash,
        pub(crate) payment_preimage: Option<PaymentPreimage>,
        pub(crate) source: HTLCSource,
-       pub(crate) onchain_value_satoshis: Option<u64>,
+       pub(crate) htlc_value_satoshis: Option<u64>,
 }
 impl_writeable_tlv_based!(HTLCUpdate, {
        (0, payment_hash, required),
-       (1, onchain_value_satoshis, option),
+       (1, htlc_value_satoshis, option),
        (2, source, required),
        (4, payment_preimage, option),
 });
@@ -357,10 +357,10 @@ enum OnchainEvent {
        HTLCUpdate {
                source: HTLCSource,
                payment_hash: PaymentHash,
-               onchain_value_satoshis: Option<u64>,
+               htlc_value_satoshis: Option<u64>,
                /// None in the second case, above, ie when there is no relevant output in the commitment
                /// transaction which appeared on chain.
-               input_idx: Option<u32>,
+               commitment_tx_output_idx: Option<u32>,
        },
        MaturingOutput {
                descriptor: SpendableOutputDescriptor,
@@ -381,7 +381,7 @@ enum OnchainEvent {
        ///  * a revoked-state HTLC transaction was broadcasted, which was claimed by the revocation
        ///    signature.
        HTLCSpendConfirmation {
-               input_idx: u32,
+               commitment_tx_output_idx: u32,
                /// If the claim was made by either party with a preimage, this is filled in
                preimage: Option<PaymentPreimage>,
                /// If the claim was made by us on an inbound HTLC against a local commitment transaction,
@@ -423,9 +423,9 @@ impl MaybeReadable for OnchainEventEntry {
 impl_writeable_tlv_based_enum_upgradable!(OnchainEvent,
        (0, HTLCUpdate) => {
                (0, source, required),
-               (1, onchain_value_satoshis, option),
+               (1, htlc_value_satoshis, option),
                (2, payment_hash, required),
-               (3, input_idx, option),
+               (3, commitment_tx_output_idx, option),
        },
        (1, MaturingOutput) => {
                (0, descriptor, required),
@@ -434,7 +434,7 @@ impl_writeable_tlv_based_enum_upgradable!(OnchainEvent,
                (0, on_local_output_csv, option),
        },
        (5, HTLCSpendConfirmation) => {
-               (0, input_idx, required),
+               (0, commitment_tx_output_idx, required),
                (2, preimage, option),
                (4, on_to_local_output_csv, option),
        },
@@ -452,7 +452,7 @@ pub(crate) enum ChannelMonitorUpdateStep {
                commitment_txid: Txid,
                htlc_outputs: Vec<(HTLCOutputInCommitment, Option<Box<HTLCSource>>)>,
                commitment_number: u64,
-               their_revocation_point: PublicKey,
+               their_per_commitment_point: PublicKey,
        },
        PaymentPreimage {
                payment_preimage: PaymentPreimage,
@@ -494,7 +494,7 @@ impl_writeable_tlv_based_enum_upgradable!(ChannelMonitorUpdateStep,
        (1, LatestCounterpartyCommitmentTXInfo) => {
                (0, commitment_txid, required),
                (2, commitment_number, required),
-               (4, their_revocation_point, required),
+               (4, their_per_commitment_point, required),
                (6, htlc_outputs, vec_type),
        },
        (2, PaymentPreimage) => {
@@ -568,13 +568,13 @@ pub enum Balance {
 /// An HTLC which has been irrevocably resolved on-chain, and has reached ANTI_REORG_DELAY.
 #[derive(PartialEq)]
 struct IrrevocablyResolvedHTLC {
-       input_idx: u32,
+       commitment_tx_output_idx: u32,
        /// Only set if the HTLC claim was ours using a payment preimage
        payment_preimage: Option<PaymentPreimage>,
 }
 
 impl_writeable_tlv_based!(IrrevocablyResolvedHTLC, {
-       (0, input_idx, required),
+       (0, commitment_tx_output_idx, required),
        (2, payment_preimage, option),
 });
 
@@ -619,8 +619,8 @@ pub(crate) struct ChannelMonitorImpl<Signer: Sign> {
        counterparty_commitment_params: CounterpartyCommitmentParameters,
        funding_redeemscript: Script,
        channel_value_satoshis: u64,
-       // first is the idx of the first of the two revocation points
-       their_cur_revocation_points: Option<(u64, PublicKey, Option<PublicKey>)>,
+       // first is the idx of the first of the two per-commitment points
+       their_cur_per_commitment_points: Option<(u64, PublicKey, Option<PublicKey>)>,
 
        on_holder_tx_csv: u16,
 
@@ -753,7 +753,7 @@ impl<Signer: Sign> PartialEq for ChannelMonitorImpl<Signer> {
                        self.counterparty_commitment_params != other.counterparty_commitment_params ||
                        self.funding_redeemscript != other.funding_redeemscript ||
                        self.channel_value_satoshis != other.channel_value_satoshis ||
-                       self.their_cur_revocation_points != other.their_cur_revocation_points ||
+                       self.their_cur_per_commitment_points != other.their_cur_per_commitment_points ||
                        self.on_holder_tx_csv != other.on_holder_tx_csv ||
                        self.commitment_secrets != other.commitment_secrets ||
                        self.counterparty_claimable_outpoints != other.counterparty_claimable_outpoints ||
@@ -828,7 +828,7 @@ impl<Signer: Sign> Writeable for ChannelMonitorImpl<Signer> {
                self.funding_redeemscript.write(writer)?;
                self.channel_value_satoshis.write(writer)?;
 
-               match self.their_cur_revocation_points {
+               match self.their_cur_per_commitment_points {
                        Some((idx, pubkey, second_option)) => {
                                writer.write_all(&byte_utils::be48_to_array(idx))?;
                                writer.write_all(&pubkey.serialize())?;
@@ -1020,7 +1020,7 @@ impl<Signer: Sign> ChannelMonitor<Signer> {
                                counterparty_commitment_params,
                                funding_redeemscript,
                                channel_value_satoshis,
-                               their_cur_revocation_points: None,
+                               their_cur_per_commitment_points: None,
 
                                on_holder_tx_csv: counterparty_channel_parameters.selected_contest_delay,
 
@@ -1070,11 +1070,11 @@ impl<Signer: Sign> ChannelMonitor<Signer> {
                txid: Txid,
                htlc_outputs: Vec<(HTLCOutputInCommitment, Option<Box<HTLCSource>>)>,
                commitment_number: u64,
-               their_revocation_point: PublicKey,
+               their_per_commitment_point: PublicKey,
                logger: &L,
        ) where L::Target: Logger {
                self.inner.lock().unwrap().provide_latest_counterparty_commitment_tx(
-                       txid, htlc_outputs, commitment_number, their_revocation_point, logger)
+                       txid, htlc_outputs, commitment_number, their_per_commitment_point, logger)
        }
 
        #[cfg(test)]
@@ -1391,10 +1391,10 @@ impl<Signer: Sign> ChannelMonitor<Signer> {
                macro_rules! walk_htlcs {
                        ($holder_commitment: expr, $htlc_iter: expr) => {
                                for htlc in $htlc_iter {
-                                       if let Some(htlc_input_idx) = htlc.transaction_output_index {
+                                       if let Some(htlc_commitment_tx_output_idx) = htlc.transaction_output_index {
                                                if let Some(conf_thresh) = us.onchain_events_awaiting_threshold_conf.iter().find_map(|event| {
                                                        if let OnchainEvent::MaturingOutput { descriptor: SpendableOutputDescriptor::DelayedPaymentOutput(descriptor) } = &event.event {
-                                                               if descriptor.outpoint.index as u32 == htlc_input_idx { Some(event.confirmation_threshold()) } else { None }
+                                                               if descriptor.outpoint.index as u32 == htlc_commitment_tx_output_idx { Some(event.confirmation_threshold()) } else { None }
                                                        } else { None }
                                                }) {
                                                        debug_assert!($holder_commitment);
@@ -1402,7 +1402,7 @@ impl<Signer: Sign> ChannelMonitor<Signer> {
                                                                claimable_amount_satoshis: htlc.amount_msat / 1000,
                                                                confirmation_height: conf_thresh,
                                                        });
-                                               } else if us.htlcs_resolved_on_chain.iter().any(|v| v.input_idx == htlc_input_idx) {
+                                               } else if us.htlcs_resolved_on_chain.iter().any(|v| v.commitment_tx_output_idx == htlc_commitment_tx_output_idx) {
                                                        // Funding transaction spends should be fully confirmed by the time any
                                                        // HTLC transactions are resolved, unless we're talking about a holder
                                                        // commitment tx, whose resolution is delayed until the CSV timeout is
@@ -1414,8 +1414,9 @@ impl<Signer: Sign> ChannelMonitor<Signer> {
                                                        // indicating we have spent this HTLC with a timeout, claiming it back
                                                        // and awaiting confirmations on it.
                                                        let htlc_update_pending = us.onchain_events_awaiting_threshold_conf.iter().find_map(|event| {
-                                                               if let OnchainEvent::HTLCUpdate { input_idx: Some(input_idx), .. } = event.event {
-                                                                       if input_idx == htlc_input_idx { Some(event.confirmation_threshold()) } else { None }
+                                                               if let OnchainEvent::HTLCUpdate { commitment_tx_output_idx: Some(commitment_tx_output_idx), .. } = event.event {
+                                                                       if commitment_tx_output_idx == htlc_commitment_tx_output_idx {
+                                                                               Some(event.confirmation_threshold()) } else { None }
                                                                } else { None }
                                                        });
                                                        if let Some(conf_thresh) = htlc_update_pending {
@@ -1436,8 +1437,8 @@ impl<Signer: Sign> ChannelMonitor<Signer> {
                                                        // preimage, we lost funds to our counterparty! We will then continue
                                                        // to show it as ContentiousClaimable until ANTI_REORG_DELAY.
                                                        let htlc_spend_pending = us.onchain_events_awaiting_threshold_conf.iter().find_map(|event| {
-                                                               if let OnchainEvent::HTLCSpendConfirmation { input_idx, preimage, .. } = event.event {
-                                                                       if input_idx == htlc_input_idx {
+                                                               if let OnchainEvent::HTLCSpendConfirmation { commitment_tx_output_idx, preimage, .. } = event.event {
+                                                                       if commitment_tx_output_idx == htlc_commitment_tx_output_idx {
                                                                                Some((event.confirmation_threshold(), preimage.is_some()))
                                                                        } else { None }
                                                                } else { None }
@@ -1546,7 +1547,7 @@ impl<Signer: Sign> ChannelMonitor<Signer> {
                macro_rules! walk_htlcs {
                        ($holder_commitment: expr, $htlc_iter: expr) => {
                                for (htlc, source) in $htlc_iter {
-                                       if us.htlcs_resolved_on_chain.iter().any(|v| Some(v.input_idx) == htlc.transaction_output_index) {
+                                       if us.htlcs_resolved_on_chain.iter().any(|v| Some(v.commitment_tx_output_idx) == htlc.transaction_output_index) {
                                                // We should assert that funding_spend_confirmed is_some() here, but we
                                                // have some unit tests which violate HTLC transaction CSVs entirely and
                                                // would fail.
@@ -1557,17 +1558,17 @@ impl<Signer: Sign> ChannelMonitor<Signer> {
                                                // indicating we have spent this HTLC with a timeout, claiming it back
                                                // and awaiting confirmations on it.
                                                let htlc_update_confd = us.onchain_events_awaiting_threshold_conf.iter().any(|event| {
-                                                       if let OnchainEvent::HTLCUpdate { input_idx: Some(input_idx), .. } = event.event {
+                                                       if let OnchainEvent::HTLCUpdate { commitment_tx_output_idx: Some(commitment_tx_output_idx), .. } = event.event {
                                                                // If the HTLC was timed out, we wait for ANTI_REORG_DELAY blocks
                                                                // before considering it "no longer pending" - this matches when we
                                                                // provide the ChannelManager an HTLC failure event.
-                                                               Some(input_idx) == htlc.transaction_output_index &&
+                                                               Some(commitment_tx_output_idx) == htlc.transaction_output_index &&
                                                                        us.best_block.height() >= event.height + ANTI_REORG_DELAY - 1
-                                                       } else if let OnchainEvent::HTLCSpendConfirmation { input_idx, .. } = event.event {
+                                                       } else if let OnchainEvent::HTLCSpendConfirmation { commitment_tx_output_idx, .. } = event.event {
                                                                // If the HTLC was fulfilled with a preimage, we consider the HTLC
                                                                // immediately non-pending, matching when we provide ChannelManager
                                                                // the preimage.
-                                                               Some(input_idx) == htlc.transaction_output_index
+                                                               Some(commitment_tx_output_idx) == htlc.transaction_output_index
                                                        } else { false }
                                                });
                                                if !htlc_update_confd {
@@ -1688,8 +1689,8 @@ macro_rules! fail_unbroadcast_htlcs {
                                                                event: OnchainEvent::HTLCUpdate {
                                                                        source: (**source).clone(),
                                                                        payment_hash: htlc.payment_hash.clone(),
-                                                                       onchain_value_satoshis: Some(htlc.amount_msat / 1000),
-                                                                       input_idx: None,
+                                                                       htlc_value_satoshis: Some(htlc.amount_msat / 1000),
+                                                                       commitment_tx_output_idx: None,
                                                                },
                                                        };
                                                        log_trace!($logger, "Failing HTLC with payment_hash {} from {} counterparty commitment tx due to broadcast of {} commitment transaction, waiting for confirmation (at height {})",
@@ -1761,7 +1762,7 @@ impl<Signer: Sign> ChannelMonitorImpl<Signer> {
                Ok(())
        }
 
-       pub(crate) fn provide_latest_counterparty_commitment_tx<L: Deref>(&mut self, txid: Txid, htlc_outputs: Vec<(HTLCOutputInCommitment, Option<Box<HTLCSource>>)>, commitment_number: u64, their_revocation_point: PublicKey, logger: &L) where L::Target: Logger {
+       pub(crate) fn provide_latest_counterparty_commitment_tx<L: Deref>(&mut self, txid: Txid, htlc_outputs: Vec<(HTLCOutputInCommitment, Option<Box<HTLCSource>>)>, commitment_number: u64, their_per_commitment_point: PublicKey, logger: &L) where L::Target: Logger {
                // TODO: Encrypt the htlc_outputs data with the single-hash of the commitment transaction
                // so that a remote monitor doesn't learn anything unless there is a malicious close.
                // (only maybe, sadly we cant do the same for local info, as we need to be aware of
@@ -1776,22 +1777,22 @@ impl<Signer: Sign> ChannelMonitorImpl<Signer> {
                self.counterparty_claimable_outpoints.insert(txid, htlc_outputs.clone());
                self.current_counterparty_commitment_number = commitment_number;
                //TODO: Merge this into the other per-counterparty-transaction output storage stuff
-               match self.their_cur_revocation_points {
+               match self.their_cur_per_commitment_points {
                        Some(old_points) => {
                                if old_points.0 == commitment_number + 1 {
-                                       self.their_cur_revocation_points = Some((old_points.0, old_points.1, Some(their_revocation_point)));
+                                       self.their_cur_per_commitment_points = Some((old_points.0, old_points.1, Some(their_per_commitment_point)));
                                } else if old_points.0 == commitment_number + 2 {
                                        if let Some(old_second_point) = old_points.2 {
-                                               self.their_cur_revocation_points = Some((old_points.0 - 1, old_second_point, Some(their_revocation_point)));
+                                               self.their_cur_per_commitment_points = Some((old_points.0 - 1, old_second_point, Some(their_per_commitment_point)));
                                        } else {
-                                               self.their_cur_revocation_points = Some((commitment_number, their_revocation_point, None));
+                                               self.their_cur_per_commitment_points = Some((commitment_number, their_per_commitment_point, None));
                                        }
                                } else {
-                                       self.their_cur_revocation_points = Some((commitment_number, their_revocation_point, None));
+                                       self.their_cur_per_commitment_points = Some((commitment_number, their_per_commitment_point, None));
                                }
                        },
                        None => {
-                               self.their_cur_revocation_points = Some((commitment_number, their_revocation_point, None));
+                               self.their_cur_per_commitment_points = Some((commitment_number, their_per_commitment_point, None));
                        }
                }
                let mut htlcs = Vec::with_capacity(htlc_outputs.len());
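
Counterparty commitment numbers count down, and revocation is not atomic, so the monitor keeps up to two live per-commitment points keyed by the commitment number of the first. A self-contained sketch of the transitions in the match above, with `u8` standing in for `PublicKey` (illustrative only, not code from this patch):

// `points` mirrors their_cur_per_commitment_points:
// (commitment number of the first point, first point, optional point for the next lower number).
fn update_points(points: Option<(u64, u8, Option<u8>)>, commitment_number: u64, new_point: u8)
	-> Option<(u64, u8, Option<u8>)>
{
	match points {
		// The new commitment directly follows the stored one: record it in the second slot.
		Some(old) if old.0 == commitment_number + 1 => Some((old.0, old.1, Some(new_point))),
		// We already track two points: slide the window down by one commitment.
		Some(old) if old.0 == commitment_number + 2 && old.2.is_some() =>
			Some((old.0 - 1, old.2.unwrap(), Some(new_point))),
		// Otherwise start tracking fresh from this commitment.
		_ => Some((commitment_number, new_point, None)),
	}
}

fn main() {
	let s = update_points(None, 10, 0xA); // Some((10, 0xA, None))
	let s = update_points(s, 9, 0xB);     // Some((10, 0xA, Some(0xB)))
	let s = update_points(s, 8, 0xC);     // Some((9, 0xB, Some(0xC)))
	assert_eq!(s, Some((9, 0xB, Some(0xC))));
}
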
@@ -1929,9 +1930,9 @@ impl<Signer: Sign> ChannelMonitorImpl<Signer> {
                                                ret = Err(());
                                        }
                                }
-                               ChannelMonitorUpdateStep::LatestCounterpartyCommitmentTXInfo { commitment_txid, htlc_outputs, commitment_number, their_revocation_point } => {
+                               ChannelMonitorUpdateStep::LatestCounterpartyCommitmentTXInfo { commitment_txid, htlc_outputs, commitment_number, their_per_commitment_point } => {
                                        log_trace!(logger, "Updating ChannelMonitor with latest counterparty commitment transaction info");
-                                       self.provide_latest_counterparty_commitment_tx(*commitment_txid, htlc_outputs.clone(), *commitment_number, *their_revocation_point, logger)
+                                       self.provide_latest_counterparty_commitment_tx(*commitment_txid, htlc_outputs.clone(), *commitment_number, *their_per_commitment_point, logger)
                                },
                                ChannelMonitorUpdateStep::PaymentPreimage { payment_preimage } => {
                                        log_trace!(logger, "Updating ChannelMonitor with payment preimage");
@@ -2120,18 +2121,18 @@ impl<Signer: Sign> ChannelMonitorImpl<Signer> {
        fn get_counterparty_htlc_output_claim_reqs(&self, commitment_number: u64, commitment_txid: Txid, tx: Option<&Transaction>) -> Vec<PackageTemplate> {
                let mut claimable_outpoints = Vec::new();
                if let Some(htlc_outputs) = self.counterparty_claimable_outpoints.get(&commitment_txid) {
-                       if let Some(revocation_points) = self.their_cur_revocation_points {
-                               let revocation_point_option =
+                       if let Some(per_commitment_points) = self.their_cur_per_commitment_points {
+                               let per_commitment_point_option =
                                        // If the counterparty commitment tx is the latest valid state, use their latest
                                        // per-commitment point
-                                       if revocation_points.0 == commitment_number { Some(&revocation_points.1) }
-                                       else if let Some(point) = revocation_points.2.as_ref() {
+                                       if per_commitment_points.0 == commitment_number { Some(&per_commitment_points.1) }
+                                       else if let Some(point) = per_commitment_points.2.as_ref() {
                                                // If counterparty commitment tx is the state previous to the latest valid state, use
                                                // their previous per-commitment point (non-atomicity of revocation means it's valid for
                                                // them to temporarily have two valid commitment txns from our viewpoint)
-                                               if revocation_points.0 == commitment_number + 1 { Some(point) } else { None }
+                                               if per_commitment_points.0 == commitment_number + 1 { Some(point) } else { None }
                                        } else { None };
-                               if let Some(revocation_point) = revocation_point_option {
+                               if let Some(per_commitment_point) = per_commitment_point_option {
                                        for (_, &(ref htlc, _)) in htlc_outputs.iter().enumerate() {
                                                if let Some(transaction_output_index) = htlc.transaction_output_index {
                                                        if let Some(transaction) = tx {
@@ -2142,7 +2143,19 @@ impl<Signer: Sign> ChannelMonitorImpl<Signer> {
                                                        }
                                                        let preimage = if htlc.offered { if let Some(p) = self.payment_preimages.get(&htlc.payment_hash) { Some(*p) } else { None } } else { None };
                                                        if preimage.is_some() || !htlc.offered {
-                                                               let counterparty_htlc_outp = if htlc.offered { PackageSolvingData::CounterpartyOfferedHTLCOutput(CounterpartyOfferedHTLCOutput::build(*revocation_point, self.counterparty_commitment_params.counterparty_delayed_payment_base_key, self.counterparty_commitment_params.counterparty_htlc_base_key, preimage.unwrap(), htlc.clone())) } else { PackageSolvingData::CounterpartyReceivedHTLCOutput(CounterpartyReceivedHTLCOutput::build(*revocation_point, self.counterparty_commitment_params.counterparty_delayed_payment_base_key, self.counterparty_commitment_params.counterparty_htlc_base_key, htlc.clone())) };
+                                                               let counterparty_htlc_outp = if htlc.offered {
+                                                                       PackageSolvingData::CounterpartyOfferedHTLCOutput(
+                                                                               CounterpartyOfferedHTLCOutput::build(*per_commitment_point,
+                                                                                       self.counterparty_commitment_params.counterparty_delayed_payment_base_key,
+                                                                                       self.counterparty_commitment_params.counterparty_htlc_base_key,
+                                                                                       preimage.unwrap(), htlc.clone()))
+                                                               } else {
+                                                                       PackageSolvingData::CounterpartyReceivedHTLCOutput(
+                                                                               CounterpartyReceivedHTLCOutput::build(*per_commitment_point,
+                                                                                       self.counterparty_commitment_params.counterparty_delayed_payment_base_key,
+                                                                                       self.counterparty_commitment_params.counterparty_htlc_base_key,
+                                                                                       htlc.clone()))
+                                                               };
                                                                let aggregation = if !htlc.offered { false } else { true };
                                                                let counterparty_package = PackageTemplate::build_package(commitment_txid, transaction_output_index, counterparty_htlc_outp, htlc.cltv_expiry,aggregation, 0);
                                                                claimable_outpoints.push(counterparty_package);
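
The selection above simply picks whichever tracked point matches the commitment number of the transaction being claimed against. A self-contained sketch of that selection, continuing the illustrative state from the previous sketch (`u8` again stands in for `PublicKey`):

// Returns the per-commitment point to use for a broadcast counterparty commitment, if any.
fn select_per_commitment_point(points: &(u64, u8, Option<u8>), commitment_number: u64) -> Option<u8> {
	if points.0 == commitment_number {
		Some(points.1)
	} else if points.0 == commitment_number + 1 {
		// Revocation isn't atomic, so the adjacent commitment may also be valid;
		// use the second tracked point if we have one.
		points.2
	} else {
		None
	}
}

fn main() {
	let points = (9u64, 0xB, Some(0xC));
	assert_eq!(select_per_commitment_point(&points, 9), Some(0xB));
	assert_eq!(select_per_commitment_point(&points, 8), Some(0xC));
	assert_eq!(select_per_commitment_point(&points, 7), None);
}
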
@@ -2522,7 +2535,7 @@ impl<Signer: Sign> ChannelMonitorImpl<Signer> {
                // Produce actionable events from on-chain events having reached their threshold.
                for entry in onchain_events_reaching_threshold_conf.drain(..) {
                        match entry.event {
-                               OnchainEvent::HTLCUpdate { ref source, payment_hash, onchain_value_satoshis, input_idx } => {
+                               OnchainEvent::HTLCUpdate { ref source, payment_hash, htlc_value_satoshis, commitment_tx_output_idx } => {
                                        // Check for duplicate HTLC resolutions.
                                        #[cfg(debug_assertions)]
                                        {
@@ -2544,10 +2557,10 @@ impl<Signer: Sign> ChannelMonitorImpl<Signer> {
                                                payment_hash,
                                                payment_preimage: None,
                                                source: source.clone(),
-                                               onchain_value_satoshis,
+                                               htlc_value_satoshis,
                                        }));
-                                       if let Some(idx) = input_idx {
-                                               self.htlcs_resolved_on_chain.push(IrrevocablyResolvedHTLC { input_idx: idx, payment_preimage: None });
+                                       if let Some(idx) = commitment_tx_output_idx {
+                                               self.htlcs_resolved_on_chain.push(IrrevocablyResolvedHTLC { commitment_tx_output_idx: idx, payment_preimage: None });
                                        }
                                },
                                OnchainEvent::MaturingOutput { descriptor } => {
@@ -2556,8 +2569,8 @@ impl<Signer: Sign> ChannelMonitorImpl<Signer> {
                                                outputs: vec![descriptor]
                                        });
                                },
-                               OnchainEvent::HTLCSpendConfirmation { input_idx, preimage, .. } => {
-                                       self.htlcs_resolved_on_chain.push(IrrevocablyResolvedHTLC { input_idx, payment_preimage: preimage });
+                               OnchainEvent::HTLCSpendConfirmation { commitment_tx_output_idx, preimage, .. } => {
+                                       self.htlcs_resolved_on_chain.push(IrrevocablyResolvedHTLC { commitment_tx_output_idx, payment_preimage: preimage });
                                },
                                OnchainEvent::FundingSpendConfirmation { .. } => {
                                        self.funding_spend_confirmed = Some(entry.txid);
@@ -2826,7 +2839,7 @@ impl<Signer: Sign> ChannelMonitorImpl<Signer> {
                                                                        self.onchain_events_awaiting_threshold_conf.push(OnchainEventEntry {
                                                                                txid: tx.txid(), height,
                                                                                event: OnchainEvent::HTLCSpendConfirmation {
-                                                                                       input_idx: input.previous_output.vout,
+                                                                                       commitment_tx_output_idx: input.previous_output.vout,
                                                                                        preimage: if accepted_preimage_claim || offered_preimage_claim {
                                                                                                Some(payment_preimage) } else { None },
                                                                                        // If this is a payment to us (!outbound_htlc, above),
@@ -2877,7 +2890,7 @@ impl<Signer: Sign> ChannelMonitorImpl<Signer> {
                                                        txid: tx.txid(),
                                                        height,
                                                        event: OnchainEvent::HTLCSpendConfirmation {
-                                                               input_idx: input.previous_output.vout,
+                                                               commitment_tx_output_idx: input.previous_output.vout,
                                                                preimage: Some(payment_preimage),
                                                                on_to_local_output_csv: None,
                                                        },
@@ -2886,7 +2899,7 @@ impl<Signer: Sign> ChannelMonitorImpl<Signer> {
                                                        source,
                                                        payment_preimage: Some(payment_preimage),
                                                        payment_hash,
-                                                       onchain_value_satoshis: Some(amount_msat / 1000),
+                                                       htlc_value_satoshis: Some(amount_msat / 1000),
                                                }));
                                        }
                                } else if offered_preimage_claim {
@@ -2898,7 +2911,7 @@ impl<Signer: Sign> ChannelMonitorImpl<Signer> {
                                                        txid: tx.txid(),
                                                        height,
                                                        event: OnchainEvent::HTLCSpendConfirmation {
-                                                               input_idx: input.previous_output.vout,
+                                                               commitment_tx_output_idx: input.previous_output.vout,
                                                                preimage: Some(payment_preimage),
                                                                on_to_local_output_csv: None,
                                                        },
@@ -2907,7 +2920,7 @@ impl<Signer: Sign> ChannelMonitorImpl<Signer> {
                                                        source,
                                                        payment_preimage: Some(payment_preimage),
                                                        payment_hash,
-                                                       onchain_value_satoshis: Some(amount_msat / 1000),
+                                                       htlc_value_satoshis: Some(amount_msat / 1000),
                                                }));
                                        }
                                } else {
@@ -2925,8 +2938,8 @@ impl<Signer: Sign> ChannelMonitorImpl<Signer> {
                                                height,
                                                event: OnchainEvent::HTLCUpdate {
                                                        source, payment_hash,
-                                                       onchain_value_satoshis: Some(amount_msat / 1000),
-                                                       input_idx: Some(input.previous_output.vout),
+                                                       htlc_value_satoshis: Some(amount_msat / 1000),
+                                                       commitment_tx_output_idx: Some(input.previous_output.vout),
                                                },
                                        };
                                        log_info!(logger, "Failing HTLC with payment_hash {} timeout by a spend tx, waiting for confirmation (at height {})", log_bytes!(payment_hash.0), entry.confirmation_threshold());
@@ -3094,7 +3107,7 @@ impl<'a, Signer: Sign, K: KeysInterface<Signer = Signer>> ReadableArgs<&'a K>
                let funding_redeemscript = Readable::read(reader)?;
                let channel_value_satoshis = Readable::read(reader)?;
 
-               let their_cur_revocation_points = {
+               let their_cur_per_commitment_points = {
                        let first_idx = <U48 as Readable>::read(reader)?.0;
                        if first_idx == 0 {
                                None
@@ -3283,7 +3296,7 @@ impl<'a, Signer: Sign, K: KeysInterface<Signer = Signer>> ReadableArgs<&'a K>
                                counterparty_commitment_params,
                                funding_redeemscript,
                                channel_value_satoshis,
-                               their_cur_revocation_points,
+                               their_cur_per_commitment_points,
 
                                on_holder_tx_csv,
 
index 24eb09e7090b212ea9867aba9844ed6993c83735..5959508de3b728e29b13d326b76c066d665c48bd 100644 (file)
@@ -76,7 +76,7 @@ pub trait Access {
        /// Returns an error if `genesis_hash` is for a different chain or if such a transaction output
        /// is unknown.
        ///
-       /// [`short_channel_id`]: https://github.com/lightningnetwork/lightning-rfc/blob/master/07-routing-gossip.md#definition-of-short_channel_id
+       /// [`short_channel_id`]: https://github.com/lightning/bolts/blob/master/07-routing-gossip.md#definition-of-short_channel_id
        fn get_utxo(&self, genesis_hash: &BlockHash, short_channel_id: u64) -> Result<TxOut, AccessError>;
 }
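
For context on the trait whose documentation link is fixed above, here is a hedged sketch of a `chain::Access` implementation backed by an in-memory map. `MyUtxoSource` and its fields are hypothetical; only the trait signature and error variants come from the `lightning` crate (assumed alongside `bitcoin` as dependencies):

// Hypothetical UTXO source keyed by short_channel_id; not part of this patch.
use std::collections::HashMap;
use bitcoin::{BlockHash, TxOut};
use lightning::chain::{Access, AccessError};

struct MyUtxoSource {
	genesis_hash: BlockHash,
	utxos: HashMap<u64, TxOut>,
}

impl Access for MyUtxoSource {
	fn get_utxo(&self, genesis_hash: &BlockHash, short_channel_id: u64) -> Result<TxOut, AccessError> {
		if *genesis_hash != self.genesis_hash {
			// Wrong chain for this source.
			return Err(AccessError::UnknownChain);
		}
		// An unknown short_channel_id maps to UnknownTx.
		self.utxos.get(&short_channel_id).cloned().ok_or(AccessError::UnknownTx)
	}
}
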
 
@@ -302,7 +302,7 @@ pub trait Watch<ChannelSigner: Sign> {
        ///
        /// For details on asynchronous [`ChannelMonitor`] updating and returning
        /// [`MonitorEvent::UpdateCompleted`] here, see [`ChannelMonitorUpdateErr::TemporaryFailure`].
-       fn release_pending_monitor_events(&self) -> Vec<MonitorEvent>;
+       fn release_pending_monitor_events(&self) -> Vec<(OutPoint, Vec<MonitorEvent>)>;
 }
 
 /// The `Filter` trait defines behavior for indicating chain activity of interest pertaining to
index abdc10c577a4f476b70929dad51499b26f1b0cfe..3f88a208e9d00da1976fafd6e42fbbf6c6ace030 100644 (file)
 //! figure out how best to make networking happen/timers fire/things get written to disk/keys get
 //! generated/etc. This makes it a good candidate for tight integration into an existing wallet
 //! instead of having a rather-separate lightning appendage to a wallet.
+//!
+//! `default` features are:
+//!
+//! * `std` - enables functionalities which require `std`, including `std::io` trait implementations and things which utilize time
+//! * `grind_signatures` - enables generation of [low-r bitcoin signatures](https://bitcoin.stackexchange.com/questions/111660/what-is-signature-grinding),
+//! which saves 1 byte per signature in 50% of the cases (see [bitcoin PR #13666](https://github.com/bitcoin/bitcoin/pull/13666))
+//!
+//! Available features are:
+//!
+//! * `std`
+//! * `grind_signatures`
+//! * `no-std` - exposes write trait implementations from the `core2` crate (at least one of `no-std` or `std` is required)
+//! * Skip logging of messages at levels below the given log level:
+//!     * `max_level_off`
+//!     * `max_level_error`
+//!     * `max_level_warn`
+//!     * `max_level_info`
+//!     * `max_level_debug`
+//!     * `max_level_trace`
 
 #![cfg_attr(not(any(test, fuzzing, feature = "_test_utils")), deny(missing_docs))]
 #![cfg_attr(not(any(test, fuzzing, feature = "_test_utils")), forbid(unsafe_code))]
@@ -158,7 +177,7 @@ mod sync {
        #[cfg(test)]
        pub use debug_sync::*;
        #[cfg(not(test))]
-       pub use ::std::sync::{Arc, Mutex, Condvar, MutexGuard, RwLock, RwLockReadGuard};
+       pub use ::std::sync::{Arc, Mutex, Condvar, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};
        #[cfg(not(test))]
        pub use crate::util::fairrwlock::FairRwLock;
 }
index 9e987c3deec005debbdb81fd7557a6be75e955f1..41d1eff856a71ed88cf8f43c14d8ffe234d27447 100644 (file)
@@ -139,7 +139,7 @@ pub fn build_closing_transaction(to_holder_value_sat: u64, to_counterparty_value
 }
 
 /// Implements the per-commitment secret storage scheme from
-/// [BOLT 3](https://github.com/lightningnetwork/lightning-rfc/blob/dcbf8583976df087c79c3ce0b535311212e6812d/03-transactions.md#efficient-per-commitment-secret-storage).
+/// [BOLT 3](https://github.com/lightning/bolts/blob/dcbf8583976df087c79c3ce0b535311212e6812d/03-transactions.md#efficient-per-commitment-secret-storage).
 ///
 /// Allows us to keep track of all of the revocation secrets of our counterparty in just 50*32 bytes
 /// or so.
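
The compact storage works because BOLT 3 derives every per-commitment secret from a seed by bit-flip-and-hash: a stored secret whose index ends in zero bits can re-derive every secret matching it in the higher bits, so at most 49 (index, secret) pairs cover all 2^48 commitments. A self-contained sketch of that derivation, using the `sha2` crate for brevity (illustrative, not the crate's internal helper):

// Walk the 48 index bits from high to low; for each set bit, flip it in the running
// value and hash, per BOLT 3's generate_from_seed.
use sha2::{Digest, Sha256};

fn derive_per_commitment_secret(commitment_seed: &[u8; 32], idx: u64) -> [u8; 32] {
	let mut res = *commitment_seed;
	for i in 0..48 {
		let bitpos = 47 - i;
		if idx & (1 << bitpos) != 0 {
			res[bitpos / 8] ^= 1 << (bitpos & 7);
			res = Sha256::digest(&res).into();
		}
	}
	res
}

fn main() {
	let seed = [0x42u8; 32];
	// Derive two consecutive per-commitment secrets from the same seed.
	let s0 = derive_per_commitment_secret(&seed, 0xFFFFFFFFFFFF);
	let s1 = derive_per_commitment_secret(&seed, 0xFFFFFFFFFFFE);
	assert_ne!(s0, s1);
}
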
index 5706870e6c1cbd9c74e88b89e6cb8fafbe8d6ea8..9eeec319b5ff0c07fbf55aa9fbc2dfc8cdc28d72 100644 (file)
@@ -224,7 +224,7 @@ fn do_test_simple_monitor_temporary_update_fail(disconnect: bool) {
        }
 
        // ...and make sure we can force-close a frozen channel
-       nodes[0].node.force_close_channel(&channel_id).unwrap();
+       nodes[0].node.force_close_channel(&channel_id, &nodes[1].node.get_our_node_id()).unwrap();
        check_added_monitors!(nodes[0], 1);
        check_closed_broadcast!(nodes[0], true);
 
@@ -1102,7 +1102,7 @@ fn test_monitor_update_fail_reestablish() {
        assert!(updates.update_fee.is_none());
        assert_eq!(updates.update_fulfill_htlcs.len(), 1);
        nodes[1].node.handle_update_fulfill_htlc(&nodes[2].node.get_our_node_id(), &updates.update_fulfill_htlcs[0]);
-       expect_payment_forwarded!(nodes[1], nodes[0], Some(1000), false);
+       expect_payment_forwarded!(nodes[1], nodes[0], nodes[2], Some(1000), false, false);
        check_added_monitors!(nodes[1], 1);
        assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty());
        commitment_signed_dance!(nodes[1], nodes[2], updates.commitment_signed, false);
@@ -1809,9 +1809,9 @@ fn do_during_funding_monitor_fail(confirm_a_first: bool, restore_b_before_conf:
        nodes[1].node.handle_open_channel(&nodes[0].node.get_our_node_id(), InitFeatures::known(), &get_event_msg!(nodes[0], MessageSendEvent::SendOpenChannel, nodes[1].node.get_our_node_id()));
        nodes[0].node.handle_accept_channel(&nodes[1].node.get_our_node_id(), InitFeatures::known(), &get_event_msg!(nodes[1], MessageSendEvent::SendAcceptChannel, nodes[0].node.get_our_node_id()));
 
-       let (temporary_channel_id, funding_tx, funding_output) = create_funding_transaction(&nodes[0], 100000, 43);
+       let (temporary_channel_id, funding_tx, funding_output) = create_funding_transaction(&nodes[0], &nodes[1].node.get_our_node_id(), 100000, 43);
 
-       nodes[0].node.funding_transaction_generated(&temporary_channel_id, funding_tx.clone()).unwrap();
+       nodes[0].node.funding_transaction_generated(&temporary_channel_id, &nodes[1].node.get_our_node_id(), funding_tx.clone()).unwrap();
        check_added_monitors!(nodes[0], 0);
 
        chanmon_cfgs[1].persister.set_update_ret(Err(ChannelMonitorUpdateErr::TemporaryFailure));
@@ -2087,7 +2087,7 @@ fn test_fail_htlc_on_broadcast_after_claim() {
        nodes[1].node.handle_update_fulfill_htlc(&nodes[2].node.get_our_node_id(), &cs_updates.update_fulfill_htlcs[0]);
        let bs_updates = get_htlc_update_msgs!(nodes[1], nodes[0].node.get_our_node_id());
        check_added_monitors!(nodes[1], 1);
-       expect_payment_forwarded!(nodes[1], nodes[0], Some(1000), false);
+       expect_payment_forwarded!(nodes[1], nodes[0], nodes[2], Some(1000), false, false);
 
        mine_transaction(&nodes[1], &bs_txn[0]);
        check_closed_event!(nodes[1], 1, ClosureReason::CommitmentTxConfirmed);
@@ -2468,7 +2468,7 @@ fn do_test_reconnect_dup_htlc_claims(htlc_status: HTLCStatusAtDupClaim, second_f
                assert_eq!(fulfill_msg, cs_updates.update_fulfill_htlcs[0]);
        }
        nodes[1].node.handle_update_fulfill_htlc(&nodes[2].node.get_our_node_id(), &fulfill_msg);
-       expect_payment_forwarded!(nodes[1], nodes[0], Some(1000), false);
+       expect_payment_forwarded!(nodes[1], nodes[0], nodes[2], Some(1000), false, false);
        check_added_monitors!(nodes[1], 1);
 
        let mut bs_updates = None;
@@ -2538,7 +2538,7 @@ fn test_temporary_error_during_shutdown() {
        chanmon_cfgs[0].persister.set_update_ret(Err(ChannelMonitorUpdateErr::TemporaryFailure));
        chanmon_cfgs[1].persister.set_update_ret(Err(ChannelMonitorUpdateErr::TemporaryFailure));
 
-       nodes[0].node.close_channel(&channel_id).unwrap();
+       nodes[0].node.close_channel(&channel_id, &nodes[1].node.get_our_node_id()).unwrap();
        nodes[1].node.handle_shutdown(&nodes[0].node.get_our_node_id(), &InitFeatures::known(), &get_event_msg!(nodes[0], MessageSendEvent::SendShutdown, nodes[1].node.get_our_node_id()));
        check_added_monitors!(nodes[1], 1);
 
@@ -2591,7 +2591,7 @@ fn test_permanent_error_during_sending_shutdown() {
        let channel_id = create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known()).2;
        chanmon_cfgs[0].persister.set_update_ret(Err(ChannelMonitorUpdateErr::PermanentFailure));
 
-       assert!(nodes[0].node.close_channel(&channel_id).is_ok());
+       assert!(nodes[0].node.close_channel(&channel_id, &nodes[1].node.get_our_node_id()).is_ok());
        check_closed_broadcast!(nodes[0], true);
        check_added_monitors!(nodes[0], 2);
        check_closed_event!(nodes[0], 1, ClosureReason::ProcessingError { err: "ChannelMonitor storage failure".to_string() });
@@ -2612,7 +2612,7 @@ fn test_permanent_error_during_handling_shutdown() {
        let channel_id = create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known()).2;
        chanmon_cfgs[1].persister.set_update_ret(Err(ChannelMonitorUpdateErr::PermanentFailure));
 
-       assert!(nodes[0].node.close_channel(&channel_id).is_ok());
+       assert!(nodes[0].node.close_channel(&channel_id, &nodes[1].node.get_our_node_id()).is_ok());
        let shutdown = get_event_msg!(nodes[0], MessageSendEvent::SendShutdown, nodes[1].node.get_our_node_id());
        nodes[1].node.handle_shutdown(&nodes[0].node.get_our_node_id(), &InitFeatures::known(), &shutdown);
        check_closed_broadcast!(nodes[1], true);
index 43032c51a3c09cdc67988e96050835fd7a17d9d7..863cae4624adf9b62db787c3653b1763ff39426b 100644 (file)
@@ -126,7 +126,7 @@ enum InboundHTLCState {
        /// signatures in a commitment_signed message.
        /// Implies AwaitingRemoteRevoke.
        ///
-       /// [BOLT #2]: https://github.com/lightningnetwork/lightning-rfc/blob/master/02-peer-protocol.md
+       /// [BOLT #2]: https://github.com/lightning/bolts/blob/master/02-peer-protocol.md
        AwaitingRemoteRevokeToAnnounce(PendingHTLCStatus),
        /// Included in a received commitment_signed message (implying we've revoke_and_ack'd it).
        /// We have also included this HTLC in our latest commitment_signed and are now just waiting
@@ -775,7 +775,7 @@ pub const MAX_CHAN_DUST_LIMIT_SATOSHIS: u64 = MAX_STD_OUTPUT_DUST_LIMIT_SATOSHIS
 /// In order to avoid having to concern ourselves with standardness during the closing process, we
 /// simply require our counterparty to use a dust limit which will leave any segwit output
 /// standard.
-/// See https://github.com/lightningnetwork/lightning-rfc/issues/905 for more details.
+/// See https://github.com/lightning/bolts/issues/905 for more details.
 pub const MIN_CHAN_DUST_LIMIT_SATOSHIS: u64 = 354;
 
 /// Used to return a simple Error back to ChannelManager. Will get converted to a
@@ -5400,7 +5400,7 @@ impl<Signer: Sign> Channel<Signer> {
                                commitment_txid: counterparty_commitment_txid,
                                htlc_outputs: htlcs.clone(),
                                commitment_number: self.cur_counterparty_commitment_transaction_number,
-                               their_revocation_point: self.counterparty_cur_commitment_point.unwrap()
+                               their_per_commitment_point: self.counterparty_cur_commitment_point.unwrap()
                        }]
                };
                self.channel_state |= ChannelState::AwaitingRemoteRevoke as u32;
index 066cddd42ccdb0245fbadbfba08eea89d6c498b4..528447be364b39fa692727d0a327856715855322 100644 (file)
@@ -1769,17 +1769,18 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                });
        }
 
-       fn close_channel_internal(&self, channel_id: &[u8; 32], target_feerate_sats_per_1000_weight: Option<u32>) -> Result<(), APIError> {
+       fn close_channel_internal(&self, channel_id: &[u8; 32], counterparty_node_id: &PublicKey, target_feerate_sats_per_1000_weight: Option<u32>) -> Result<(), APIError> {
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
 
-               let counterparty_node_id;
                let mut failed_htlcs: Vec<(HTLCSource, PaymentHash)>;
                let result: Result<(), _> = loop {
                        let mut channel_state_lock = self.channel_state.lock().unwrap();
                        let channel_state = &mut *channel_state_lock;
                        match channel_state.by_id.entry(channel_id.clone()) {
                                hash_map::Entry::Occupied(mut chan_entry) => {
-                                       counterparty_node_id = chan_entry.get().get_counterparty_node_id();
+                                       if *counterparty_node_id != chan_entry.get().get_counterparty_node_id(){
+                                               return Err(APIError::APIMisuseError { err: "The passed counterparty_node_id doesn't match the channel's counterparty node_id".to_owned() });
+                                       }
                                        let per_peer_state = self.per_peer_state.read().unwrap();
                                        let (shutdown_msg, monitor_update, htlcs) = match per_peer_state.get(&counterparty_node_id) {
                                                Some(peer_state) => {
@@ -1804,7 +1805,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        }
 
                                        channel_state.pending_msg_events.push(events::MessageSendEvent::SendShutdown {
-                                               node_id: counterparty_node_id,
+                                               node_id: *counterparty_node_id,
                                                msg: shutdown_msg
                                        });
 
@@ -1827,7 +1828,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                        self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), htlc_source.0, &htlc_source.1, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() });
                }
 
-               let _ = handle_error!(self, result, counterparty_node_id);
+               let _ = handle_error!(self, result, *counterparty_node_id);
                Ok(())
        }
 
@@ -1848,8 +1849,8 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
        /// [`ChannelConfig::force_close_avoidance_max_fee_satoshis`]: crate::util::config::ChannelConfig::force_close_avoidance_max_fee_satoshis
        /// [`Background`]: crate::chain::chaininterface::ConfirmationTarget::Background
        /// [`Normal`]: crate::chain::chaininterface::ConfirmationTarget::Normal
-       pub fn close_channel(&self, channel_id: &[u8; 32]) -> Result<(), APIError> {
-               self.close_channel_internal(channel_id, None)
+       pub fn close_channel(&self, channel_id: &[u8; 32], counterparty_node_id: &PublicKey) -> Result<(), APIError> {
+               self.close_channel_internal(channel_id, counterparty_node_id, None)
        }
 
        /// Begins the process of closing a channel. After this call (plus some timeout), no new HTLCs
@@ -1871,8 +1872,8 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
        /// [`ChannelConfig::force_close_avoidance_max_fee_satoshis`]: crate::util::config::ChannelConfig::force_close_avoidance_max_fee_satoshis
        /// [`Background`]: crate::chain::chaininterface::ConfirmationTarget::Background
        /// [`Normal`]: crate::chain::chaininterface::ConfirmationTarget::Normal
-       pub fn close_channel_with_target_feerate(&self, channel_id: &[u8; 32], target_feerate_sats_per_1000_weight: u32) -> Result<(), APIError> {
-               self.close_channel_internal(channel_id, Some(target_feerate_sats_per_1000_weight))
+       pub fn close_channel_with_target_feerate(&self, channel_id: &[u8; 32], counterparty_node_id: &PublicKey, target_feerate_sats_per_1000_weight: u32) -> Result<(), APIError> {
+               self.close_channel_internal(channel_id, counterparty_node_id, Some(target_feerate_sats_per_1000_weight))
        }
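
With the hunks above, both cooperative-close entry points take the counterparty's node id alongside the channel id. A hedged caller-side fragment; `channel_manager` is a hypothetical handle, and the `ChannelDetails` fields used are the same ones `force_close_all_channels` relies on later in this patch:

// Close the first listed channel with the new two-argument API (sketch only).
if let Some(chan) = channel_manager.list_channels().first() {
	channel_manager
		.close_channel(&chan.channel_id, &chan.counterparty.node_id)
		.expect("cooperative close failed");
	// Or, targeting an explicit feerate (sats per 1000 weight-units):
	// channel_manager.close_channel_with_target_feerate(
	// 	&chan.channel_id, &chan.counterparty.node_id, 253).expect("close failed");
}
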
 
        #[inline]
@@ -1891,22 +1892,18 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                }
        }
 
-       /// `peer_node_id` should be set when we receive a message from a peer, but not set when the
+       /// `peer_msg` should be set when we receive a message from a peer, but not set when the
        /// user closes, which will be re-exposed as the `ChannelClosed` reason.
-       fn force_close_channel_with_peer(&self, channel_id: &[u8; 32], peer_node_id: Option<&PublicKey>, peer_msg: Option<&String>) -> Result<PublicKey, APIError> {
+       fn force_close_channel_with_peer(&self, channel_id: &[u8; 32], peer_node_id: &PublicKey, peer_msg: Option<&String>) -> Result<PublicKey, APIError> {
                let mut chan = {
                        let mut channel_state_lock = self.channel_state.lock().unwrap();
                        let channel_state = &mut *channel_state_lock;
                        if let hash_map::Entry::Occupied(chan) = channel_state.by_id.entry(channel_id.clone()) {
-                               if let Some(node_id) = peer_node_id {
-                                       if chan.get().get_counterparty_node_id() != *node_id {
-                                               return Err(APIError::ChannelUnavailable{err: "No such channel".to_owned()});
-                                       }
+                               if chan.get().get_counterparty_node_id() != *peer_node_id {
+                                       return Err(APIError::ChannelUnavailable{err: "No such channel".to_owned()});
                                }
-                               if peer_node_id.is_some() {
-                                       if let Some(peer_msg) = peer_msg {
-                                               self.issue_channel_close_events(chan.get(),ClosureReason::CounterpartyForceClosed { peer_msg: peer_msg.to_string() });
-                                       }
+                               if let Some(peer_msg) = peer_msg {
+                                       self.issue_channel_close_events(chan.get(),ClosureReason::CounterpartyForceClosed { peer_msg: peer_msg.to_string() });
                                } else {
                                        self.issue_channel_close_events(chan.get(),ClosureReason::HolderForceClosed);
                                }
@@ -1928,10 +1925,12 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
        }
 
        /// Force closes a channel, immediately broadcasting the latest local commitment transaction to
-       /// the chain and rejecting new HTLCs on the given channel. Fails if channel_id is unknown to the manager.
-       pub fn force_close_channel(&self, channel_id: &[u8; 32]) -> Result<(), APIError> {
+       /// the chain and rejecting new HTLCs on the given channel. Fails if `channel_id` is unknown to
+       /// the manager, or if the `counterparty_node_id` isn't the counterparty of the corresponding
+       /// channel.
+       pub fn force_close_channel(&self, channel_id: &[u8; 32], counterparty_node_id: &PublicKey) -> Result<(), APIError> {
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
-               match self.force_close_channel_with_peer(channel_id, None, None) {
+               match self.force_close_channel_with_peer(channel_id, counterparty_node_id, None) {
                        Ok(counterparty_node_id) => {
                                self.channel_state.lock().unwrap().pending_msg_events.push(
                                        events::MessageSendEvent::HandleError {
@@ -1951,7 +1950,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
        /// for each to the chain and rejecting new HTLCs on each.
        pub fn force_close_all_channels(&self) {
                for chan in self.list_channels() {
-                       let _ = self.force_close_channel(&chan.channel_id);
+                       let _ = self.force_close_channel(&chan.channel_id, &chan.counterparty.node_id);
                }
        }
 
@@ -2259,7 +2258,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                        msg.cltv_expiry.write(&mut res).expect("Writes cannot fail");
                                                }
                                                else if code == 0x1000 | 20 {
-                                                       // TODO: underspecified, follow https://github.com/lightningnetwork/lightning-rfc/issues/791
+                                                       // TODO: underspecified, follow https://github.com/lightning/bolts/issues/791
                                                        0u16.write(&mut res).expect("Writes cannot fail");
                                                }
                                                (chan_update.serialized_length() as u16 + 2).write(&mut res).expect("Writes cannot fail");
@@ -2687,8 +2686,9 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
 
        /// Handles the generation of a funding transaction, optionally (for tests) with a function
        /// which checks the correctness of the funding transaction given the associated channel.
-       fn funding_transaction_generated_intern<FundingOutput: Fn(&Channel<Signer>, &Transaction) -> Result<OutPoint, APIError>>
-                       (&self, temporary_channel_id: &[u8; 32], funding_transaction: Transaction, find_funding_output: FundingOutput) -> Result<(), APIError> {
+       fn funding_transaction_generated_intern<FundingOutput: Fn(&Channel<Signer>, &Transaction) -> Result<OutPoint, APIError>>(
+               &self, temporary_channel_id: &[u8; 32], _counterparty_node_id: &PublicKey, funding_transaction: Transaction, find_funding_output: FundingOutput
+       ) -> Result<(), APIError> {
                let (chan, msg) = {
                        let (res, chan) = match self.channel_state.lock().unwrap().by_id.remove(temporary_channel_id) {
                                Some(mut chan) => {
@@ -2729,8 +2729,8 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
        }
 
        #[cfg(test)]
-       pub(crate) fn funding_transaction_generated_unchecked(&self, temporary_channel_id: &[u8; 32], funding_transaction: Transaction, output_index: u16) -> Result<(), APIError> {
-               self.funding_transaction_generated_intern(temporary_channel_id, funding_transaction, |_, tx| {
+       pub(crate) fn funding_transaction_generated_unchecked(&self, temporary_channel_id: &[u8; 32], counterparty_node_id: &PublicKey, funding_transaction: Transaction, output_index: u16) -> Result<(), APIError> {
+               self.funding_transaction_generated_intern(temporary_channel_id, counterparty_node_id, funding_transaction, |_, tx| {
                        Ok(OutPoint { txid: tx.txid(), index: output_index })
                })
        }
@@ -2757,7 +2757,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
        ///
        /// [`Event::FundingGenerationReady`]: crate::util::events::Event::FundingGenerationReady
        /// [`Event::ChannelClosed`]: crate::util::events::Event::ChannelClosed
-       pub fn funding_transaction_generated(&self, temporary_channel_id: &[u8; 32], funding_transaction: Transaction) -> Result<(), APIError> {
+       pub fn funding_transaction_generated(&self, temporary_channel_id: &[u8; 32], counterparty_node_id: &PublicKey, funding_transaction: Transaction) -> Result<(), APIError> {
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
 
                for inp in funding_transaction.input.iter() {
@@ -2767,7 +2767,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                });
                        }
                }
-               self.funding_transaction_generated_intern(temporary_channel_id, funding_transaction, |chan, tx| {
+               self.funding_transaction_generated_intern(temporary_channel_id, counterparty_node_id, funding_transaction, |chan, tx| {
                        let mut output_index = None;
                        let expected_spk = chan.get_funding_redeemscript().to_v0_p2wsh();
                        for (idx, outp) in tx.output.iter().enumerate() {
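As a sketch, a caller handling the extended `FundingGenerationReady` event might now look roughly like this (`build_funding_tx` is a hypothetical wallet helper producing a `Transaction` paying `output_script`; other names are illustrative):

        if let Event::FundingGenerationReady {
                temporary_channel_id, counterparty_node_id, channel_value_satoshis, output_script, ..
        } = event {
                // Hypothetical wallet call; it must create an output of channel_value_satoshis to output_script.
                let funding_tx = build_funding_tx(channel_value_satoshis, output_script);
                channel_manager
                        .funding_transaction_generated(&temporary_channel_id, &counterparty_node_id, funding_tx)
                        .expect("ids should match the pending channel");
        }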
@@ -3958,7 +3958,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                }
        }
 
-       fn claim_funds_internal(&self, mut channel_state_lock: MutexGuard<ChannelHolder<Signer>>, source: HTLCSource, payment_preimage: PaymentPreimage, forwarded_htlc_value_msat: Option<u64>, from_onchain: bool) {
+       fn claim_funds_internal(&self, mut channel_state_lock: MutexGuard<ChannelHolder<Signer>>, source: HTLCSource, payment_preimage: PaymentPreimage, forwarded_htlc_value_msat: Option<u64>, from_onchain: bool, next_channel_id: [u8; 32]) {
                match source {
                        HTLCSource::OutboundRoute { session_priv, payment_id, path, .. } => {
                                mem::drop(channel_state_lock);
@@ -4049,12 +4049,14 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                } else { None };
 
                                                let mut pending_events = self.pending_events.lock().unwrap();
+                                               let prev_channel_id = Some(prev_outpoint.to_channel_id());
+                                               let next_channel_id = Some(next_channel_id);
 
-                                               let source_channel_id = Some(prev_outpoint.to_channel_id());
                                                pending_events.push(events::Event::PaymentForwarded {
-                                                       source_channel_id,
                                                        fee_earned_msat,
                                                        claim_from_onchain_tx: from_onchain,
+                                                       prev_channel_id,
+                                                       next_channel_id,
                                                });
                                        }
                                }
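Consumers of `Event::PaymentForwarded` pick up the renamed and added fields roughly as follows (sketch only; the field names match this patch):

        if let Event::PaymentForwarded { fee_earned_msat, prev_channel_id, next_channel_id, claim_from_onchain_tx } = event {
                // prev_channel_id / next_channel_id identify the inbound and outbound channels of the forward.
                println!("forwarded: prev={:?} next={:?} fee_msat={:?} onchain_claim={}",
                        prev_channel_id, next_channel_id, fee_earned_msat, claim_from_onchain_tx);
        }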
@@ -4112,7 +4114,9 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
        /// Called to accept a request to open a channel after [`Event::OpenChannelRequest`] has been
        /// triggered.
        ///
-       /// The `temporary_channel_id` parameter indicates which inbound channel should be accepted.
+       /// The `temporary_channel_id` parameter indicates which inbound channel should be accepted,
+       /// and the `counterparty_node_id` parameter is the id of the peer which has requested to open
+       /// the channel.
        ///
        /// For inbound channels, the `user_channel_id` parameter will be provided back in
        /// [`Event::ChannelClosed::user_channel_id`] to allow tracking of which events correspond
@@ -4120,7 +4124,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
        ///
        /// [`Event::OpenChannelRequest`]: events::Event::OpenChannelRequest
        /// [`Event::ChannelClosed::user_channel_id`]: events::Event::ChannelClosed::user_channel_id
-       pub fn accept_inbound_channel(&self, temporary_channel_id: &[u8; 32], user_channel_id: u64) -> Result<(), APIError> {
+       pub fn accept_inbound_channel(&self, temporary_channel_id: &[u8; 32], counterparty_node_id: &PublicKey, user_channel_id: u64) -> Result<(), APIError> {
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
 
                let mut channel_state_lock = self.channel_state.lock().unwrap();
@@ -4130,6 +4134,9 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                if !channel.get().inbound_is_awaiting_accept() {
                                        return Err(APIError::APIMisuseError { err: "The channel isn't currently awaiting to be accepted.".to_owned() });
                                }
+                               if *counterparty_node_id != channel.get().get_counterparty_node_id() {
+                                       return Err(APIError::APIMisuseError { err: "The passed counterparty_node_id doesn't match the channel's counterparty node_id".to_owned() });
+                               }
                                channel_state.pending_msg_events.push(events::MessageSendEvent::SendAcceptChannel {
                                        node_id: channel.get().get_counterparty_node_id(),
                                        msg: channel.get_mut().accept_inbound_channel(user_channel_id),
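A sketch of manually accepting an inbound channel under the new check (this assumes `Event::OpenChannelRequest` exposes the requesting peer's id; `42` is an arbitrary user_channel_id):

        if let Event::OpenChannelRequest { temporary_channel_id, counterparty_node_id, .. } = event {
                // Passing a mismatched counterparty_node_id now yields APIError::APIMisuseError.
                channel_manager
                        .accept_inbound_channel(&temporary_channel_id, &counterparty_node_id, 42)
                        .expect("channel should still be awaiting acceptance");
        }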
@@ -4212,6 +4219,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                let mut pending_events = self.pending_events.lock().unwrap();
                pending_events.push(events::Event::FundingGenerationReady {
                        temporary_channel_id: msg.temporary_channel_id,
+                       counterparty_node_id: *counterparty_node_id,
                        channel_value_satoshis: value,
                        output_script,
                        user_channel_id: user_id,
@@ -4507,7 +4515,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                hash_map::Entry::Vacant(_) => return Err(MsgHandleErrInternal::send_err_msg_no_close("Failed to find corresponding channel".to_owned(), msg.channel_id))
                        }
                };
-               self.claim_funds_internal(channel_lock, htlc_source, msg.payment_preimage.clone(), Some(forwarded_htlc_value), false);
+               self.claim_funds_internal(channel_lock, htlc_source, msg.payment_preimage.clone(), Some(forwarded_htlc_value), false, msg.channel_id);
                Ok(())
        }
 
@@ -4827,48 +4835,50 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                let mut failed_channels = Vec::new();
                let mut pending_monitor_events = self.chain_monitor.release_pending_monitor_events();
                let has_pending_monitor_events = !pending_monitor_events.is_empty();
-               for monitor_event in pending_monitor_events.drain(..) {
-                       match monitor_event {
-                               MonitorEvent::HTLCEvent(htlc_update) => {
-                                       if let Some(preimage) = htlc_update.payment_preimage {
-                                               log_trace!(self.logger, "Claiming HTLC with preimage {} from our monitor", log_bytes!(preimage.0));
-                                               self.claim_funds_internal(self.channel_state.lock().unwrap(), htlc_update.source, preimage, htlc_update.onchain_value_satoshis.map(|v| v * 1000), true);
-                                       } else {
-                                               log_trace!(self.logger, "Failing HTLC with hash {} from our monitor", log_bytes!(htlc_update.payment_hash.0));
-                                               self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), htlc_update.source, &htlc_update.payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() });
-                                       }
-                               },
-                               MonitorEvent::CommitmentTxConfirmed(funding_outpoint) |
-                               MonitorEvent::UpdateFailed(funding_outpoint) => {
-                                       let mut channel_lock = self.channel_state.lock().unwrap();
-                                       let channel_state = &mut *channel_lock;
-                                       let by_id = &mut channel_state.by_id;
-                                       let pending_msg_events = &mut channel_state.pending_msg_events;
-                                       if let hash_map::Entry::Occupied(chan_entry) = by_id.entry(funding_outpoint.to_channel_id()) {
-                                               let mut chan = remove_channel!(self, channel_state, chan_entry);
-                                               failed_channels.push(chan.force_shutdown(false));
-                                               if let Ok(update) = self.get_channel_update_for_broadcast(&chan) {
-                                                       pending_msg_events.push(events::MessageSendEvent::BroadcastChannelUpdate {
-                                                               msg: update
+               for (funding_outpoint, mut monitor_events) in pending_monitor_events.drain(..) {
+                       for monitor_event in monitor_events.drain(..) {
+                               match monitor_event {
+                                       MonitorEvent::HTLCEvent(htlc_update) => {
+                                               if let Some(preimage) = htlc_update.payment_preimage {
+                                                       log_trace!(self.logger, "Claiming HTLC with preimage {} from our monitor", log_bytes!(preimage.0));
+                                                       self.claim_funds_internal(self.channel_state.lock().unwrap(), htlc_update.source, preimage, htlc_update.htlc_value_satoshis.map(|v| v * 1000), true, funding_outpoint.to_channel_id());
+                                               } else {
+                                                       log_trace!(self.logger, "Failing HTLC with hash {} from our monitor", log_bytes!(htlc_update.payment_hash.0));
+                                                       self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), htlc_update.source, &htlc_update.payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() });
+                                               }
+                                       },
+                                       MonitorEvent::CommitmentTxConfirmed(funding_outpoint) |
+                                       MonitorEvent::UpdateFailed(funding_outpoint) => {
+                                               let mut channel_lock = self.channel_state.lock().unwrap();
+                                               let channel_state = &mut *channel_lock;
+                                               let by_id = &mut channel_state.by_id;
+                                               let pending_msg_events = &mut channel_state.pending_msg_events;
+                                               if let hash_map::Entry::Occupied(chan_entry) = by_id.entry(funding_outpoint.to_channel_id()) {
+                                                       let mut chan = remove_channel!(self, channel_state, chan_entry);
+                                                       failed_channels.push(chan.force_shutdown(false));
+                                                       if let Ok(update) = self.get_channel_update_for_broadcast(&chan) {
+                                                               pending_msg_events.push(events::MessageSendEvent::BroadcastChannelUpdate {
+                                                                       msg: update
+                                                               });
+                                                       }
+                                                       let reason = if let MonitorEvent::UpdateFailed(_) = monitor_event {
+                                                               ClosureReason::ProcessingError { err: "Failed to persist ChannelMonitor update during chain sync".to_string() }
+                                                       } else {
+                                                               ClosureReason::CommitmentTxConfirmed
+                                                       };
+                                                       self.issue_channel_close_events(&chan, reason);
+                                                       pending_msg_events.push(events::MessageSendEvent::HandleError {
+                                                               node_id: chan.get_counterparty_node_id(),
+                                                               action: msgs::ErrorAction::SendErrorMessage {
+                                                                       msg: msgs::ErrorMessage { channel_id: chan.channel_id(), data: "Channel force-closed".to_owned() }
+                                                               },
                                                        });
                                                }
-                                               let reason = if let MonitorEvent::UpdateFailed(_) = monitor_event {
-                                                       ClosureReason::ProcessingError { err: "Failed to persist ChannelMonitor update during chain sync".to_string() }
-                                               } else {
-                                                       ClosureReason::CommitmentTxConfirmed
-                                               };
-                                               self.issue_channel_close_events(&chan, reason);
-                                               pending_msg_events.push(events::MessageSendEvent::HandleError {
-                                                       node_id: chan.get_counterparty_node_id(),
-                                                       action: msgs::ErrorAction::SendErrorMessage {
-                                                               msg: msgs::ErrorMessage { channel_id: chan.channel_id(), data: "Channel force-closed".to_owned() }
-                                                       },
-                                               });
-                                       }
-                               },
-                               MonitorEvent::UpdateCompleted { funding_txo, monitor_update_id } => {
-                                       self.channel_monitor_updated(&funding_txo, monitor_update_id);
-                               },
+                                       },
+                                       MonitorEvent::UpdateCompleted { funding_txo, monitor_update_id } => {
+                                               self.channel_monitor_updated(&funding_txo, monitor_update_id);
+                                       },
+                               }
                        }
                }
 
@@ -5796,7 +5806,7 @@ impl<Signer: Sign, M: Deref , T: Deref , K: Deref , F: Deref , L: Deref >
                        for chan in self.list_channels() {
                                if chan.counterparty.node_id == *counterparty_node_id {
                                        // Untrusted messages from peer, we throw away the error if id points to a non-existent channel
-                                       let _ = self.force_close_channel_with_peer(&chan.channel_id, Some(counterparty_node_id), Some(&msg.data));
+                                       let _ = self.force_close_channel_with_peer(&chan.channel_id, counterparty_node_id, Some(&msg.data));
                                }
                        }
                } else {
@@ -5818,7 +5828,7 @@ impl<Signer: Sign, M: Deref , T: Deref , K: Deref , F: Deref , L: Deref >
                        }
 
                        // Untrusted messages from peer, we throw away the error if id points to a non-existent channel
-                       let _ = self.force_close_channel_with_peer(&msg.channel_id, Some(counterparty_node_id), Some(&msg.data));
+                       let _ = self.force_close_channel_with_peer(&msg.channel_id, counterparty_node_id, Some(&msg.data));
                }
        }
 }
@@ -7406,7 +7416,7 @@ pub mod bench {
                        tx = Transaction { version: 2, lock_time: 0, input: Vec::new(), output: vec![TxOut {
                                value: 8_000_000, script_pubkey: output_script,
                        }]};
-                       node_a.funding_transaction_generated(&temporary_channel_id, tx.clone()).unwrap();
+                       node_a.funding_transaction_generated(&temporary_channel_id, &node_b.get_our_node_id(), tx.clone()).unwrap();
                } else { panic!(); }
 
                node_b.handle_funding_created(&node_a.get_our_node_id(), &get_event_msg!(node_a_holder, MessageSendEvent::SendFundingCreated, node_b.get_our_node_id()));
index 2580874640e438e611bea9b88ddda0384dc9db5d..5c4a94dabfd59c367a91012719d5b5c82de609f8 100644 (file)
@@ -19,7 +19,7 @@
 //! supports a feature if it advertises the feature (as either required or optional) to its peers.
 //! And the implementation can interpret a feature if the feature is known to it.
 //!
-//! [BOLT #9]: https://github.com/lightningnetwork/lightning-rfc/blob/master/09-features.md
+//! [BOLT #9]: https://github.com/lightning/bolts/blob/master/09-features.md
 //! [messages]: crate::ln::msgs
 
 use {io, io_extras};
@@ -244,7 +244,7 @@ mod sealed {
                        ///
                        /// See [BOLT #9] for details.
                        ///
-                       /// [BOLT #9]: https://github.com/lightningnetwork/lightning-rfc/blob/master/09-features.md
+                       /// [BOLT #9]: https://github.com/lightning/bolts/blob/master/09-features.md
                        pub trait $feature: Context {
                                /// The bit used to signify that the feature is required.
                                const EVEN_BIT: usize = $odd_bit - 1;
index d59ae68d26191cf3a731192888decc524fc4e7b0..316ef7774ce125d364d6048376270deb1992c9ed 100644 (file)
@@ -599,13 +599,14 @@ macro_rules! check_added_monitors {
        }
 }
 
-pub fn create_funding_transaction<'a, 'b, 'c>(node: &Node<'a, 'b, 'c>, expected_chan_value: u64, expected_user_chan_id: u64) -> ([u8; 32], Transaction, OutPoint) {
+pub fn create_funding_transaction<'a, 'b, 'c>(node: &Node<'a, 'b, 'c>, expected_counterparty_node_id: &PublicKey, expected_chan_value: u64, expected_user_chan_id: u64) -> ([u8; 32], Transaction, OutPoint) {
        let chan_id = *node.network_chan_count.borrow();
 
        let events = node.node.get_and_clear_pending_events();
        assert_eq!(events.len(), 1);
        match events[0] {
-               Event::FundingGenerationReady { ref temporary_channel_id, ref channel_value_satoshis, ref output_script, user_channel_id } => {
+               Event::FundingGenerationReady { ref temporary_channel_id, ref counterparty_node_id, ref channel_value_satoshis, ref output_script, user_channel_id } => {
+                       assert_eq!(counterparty_node_id, expected_counterparty_node_id);
                        assert_eq!(*channel_value_satoshis, expected_chan_value);
                        assert_eq!(user_channel_id, expected_user_chan_id);
 
@@ -619,10 +620,10 @@ pub fn create_funding_transaction<'a, 'b, 'c>(node: &Node<'a, 'b, 'c>, expected_
        }
 }
 pub fn sign_funding_transaction<'a, 'b, 'c>(node_a: &Node<'a, 'b, 'c>, node_b: &Node<'a, 'b, 'c>, channel_value: u64, expected_temporary_channel_id: [u8; 32]) -> Transaction {
-       let (temporary_channel_id, tx, funding_output) = create_funding_transaction(node_a, channel_value, 42);
+       let (temporary_channel_id, tx, funding_output) = create_funding_transaction(node_a, &node_b.node.get_our_node_id(), channel_value, 42);
        assert_eq!(temporary_channel_id, expected_temporary_channel_id);
 
-       assert!(node_a.node.funding_transaction_generated(&temporary_channel_id, tx.clone()).is_ok());
+       assert!(node_a.node.funding_transaction_generated(&temporary_channel_id, &node_b.node.get_our_node_id(), tx.clone()).is_ok());
        check_added_monitors!(node_a, 0);
 
        let funding_created_msg = get_event_msg!(node_a, MessageSendEvent::SendFundingCreated, node_b.node.get_our_node_id());
@@ -651,7 +652,7 @@ pub fn sign_funding_transaction<'a, 'b, 'c>(node_a: &Node<'a, 'b, 'c>, node_b: &
        node_a.tx_broadcaster.txn_broadcasted.lock().unwrap().clear();
 
        // Ensure that funding_transaction_generated is idempotent.
-       assert!(node_a.node.funding_transaction_generated(&temporary_channel_id, tx.clone()).is_err());
+       assert!(node_a.node.funding_transaction_generated(&temporary_channel_id, &node_b.node.get_our_node_id(), tx.clone()).is_err());
        assert!(node_a.node.get_and_clear_pending_msg_events().is_empty());
        check_added_monitors!(node_a, 0);
 
@@ -768,8 +769,8 @@ pub fn create_unannounced_chan_between_nodes_with_value<'a, 'b, 'c, 'd>(nodes: &
        let accept_channel = get_event_msg!(nodes[b], MessageSendEvent::SendAcceptChannel, nodes[a].node.get_our_node_id());
        nodes[a].node.handle_accept_channel(&nodes[b].node.get_our_node_id(), b_flags, &accept_channel);
 
-       let (temporary_channel_id, tx, _) = create_funding_transaction(&nodes[a], channel_value, 42);
-       nodes[a].node.funding_transaction_generated(&temporary_channel_id, tx.clone()).unwrap();
+       let (temporary_channel_id, tx, _) = create_funding_transaction(&nodes[a], &nodes[b].node.get_our_node_id(), channel_value, 42);
+       nodes[a].node.funding_transaction_generated(&temporary_channel_id, &nodes[b].node.get_our_node_id(), tx.clone()).unwrap();
        nodes[b].node.handle_funding_created(&nodes[a].node.get_our_node_id(), &get_event_msg!(nodes[a], MessageSendEvent::SendFundingCreated, nodes[b].node.get_our_node_id()));
        check_added_monitors!(nodes[b], 1);
 
@@ -1014,7 +1015,7 @@ pub fn close_channel<'a, 'b, 'c>(outbound_node: &Node<'a, 'b, 'c>, inbound_node:
        let (node_b, broadcaster_b, struct_b) = if close_inbound_first { (&outbound_node.node, &outbound_node.tx_broadcaster, outbound_node) } else { (&inbound_node.node, &inbound_node.tx_broadcaster, inbound_node) };
        let (tx_a, tx_b);
 
-       node_a.close_channel(channel_id).unwrap();
+       node_a.close_channel(channel_id, &node_b.get_our_node_id()).unwrap();
        node_b.handle_shutdown(&node_a.get_our_node_id(), &InitFeatures::known(), &get_event_msg!(struct_a, MessageSendEvent::SendShutdown, node_b.get_our_node_id()));
 
        let events_1 = node_b.get_and_clear_pending_msg_events();
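Cooperative close now takes the peer's id as well; a minimal sketch of the updated call, reusing the illustrative names above:

        // Fails if channel_id does not refer to an open channel with this counterparty.
        channel_manager.close_channel(&channel_id, &counterparty_pubkey).unwrap();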
@@ -1380,15 +1381,20 @@ macro_rules! expect_payment_path_successful {
 }
 
 macro_rules! expect_payment_forwarded {
-       ($node: expr, $source_node: expr, $expected_fee: expr, $upstream_force_closed: expr) => {
+       ($node: expr, $prev_node: expr, $next_node: expr, $expected_fee: expr, $upstream_force_closed: expr, $downstream_force_closed: expr) => {
                let events = $node.node.get_and_clear_pending_events();
                assert_eq!(events.len(), 1);
                match events[0] {
-                       Event::PaymentForwarded { fee_earned_msat, source_channel_id, claim_from_onchain_tx } => {
+                       Event::PaymentForwarded { fee_earned_msat, prev_channel_id, claim_from_onchain_tx, next_channel_id } => {
                                assert_eq!(fee_earned_msat, $expected_fee);
                                if fee_earned_msat.is_some() {
-                                       // Is the event channel_id in one of the channels between the two nodes?
-                                       assert!($node.node.list_channels().iter().any(|x| x.counterparty.node_id == $source_node.node.get_our_node_id() && x.channel_id == source_channel_id.unwrap()));
+                                       // Is the event prev_channel_id in one of the channels between the two nodes?
+                                       assert!($node.node.list_channels().iter().any(|x| x.counterparty.node_id == $prev_node.node.get_our_node_id() && x.channel_id == prev_channel_id.unwrap()));
+                               }
+                               // We check for force closures since a force closed channel is removed from the
+                               // node's channel list
+                               if !$downstream_force_closed {
+                                       assert!($node.node.list_channels().iter().any(|x| x.counterparty.node_id == $next_node.node.get_our_node_id() && x.channel_id == next_channel_id.unwrap()));
                                }
                                assert_eq!(claim_from_onchain_tx, $upstream_force_closed);
                        },
@@ -1632,7 +1638,7 @@ pub fn do_claim_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>,
                                {
                                        $node.node.handle_update_fulfill_htlc(&$prev_node.node.get_our_node_id(), &next_msgs.as_ref().unwrap().0);
                                        let fee = $node.node.channel_state.lock().unwrap().by_id.get(&next_msgs.as_ref().unwrap().0.channel_id).unwrap().config.forwarding_fee_base_msat;
-                                       expect_payment_forwarded!($node, $next_node, Some(fee as u64), false);
+                                       expect_payment_forwarded!($node, $next_node, $prev_node, Some(fee as u64), false, false);
                                        expected_total_fee_msat += fee as u64;
                                        check_added_monitors!($node, 1);
                                        let new_next_msgs = if $new_msgs {
index 2d52d56aaed46745ce30c5e0de5284b3b1a8efbd..2a095b90f2e62a3e4f7c0bc44ecce67763b7fbfb 100644 (file)
@@ -515,10 +515,10 @@ fn do_test_sanity_on_in_flight_opens(steps: u8) {
        if steps & 0x0f == 2 { return; }
        nodes[0].node.handle_accept_channel(&nodes[1].node.get_our_node_id(), InitFeatures::known(), &accept_channel);
 
-       let (temporary_channel_id, tx, funding_output) = create_funding_transaction(&nodes[0], 100000, 42);
+       let (temporary_channel_id, tx, funding_output) = create_funding_transaction(&nodes[0], &nodes[1].node.get_our_node_id(), 100000, 42);
 
        if steps & 0x0f == 3 { return; }
-       nodes[0].node.funding_transaction_generated(&temporary_channel_id, tx.clone()).unwrap();
+       nodes[0].node.funding_transaction_generated(&temporary_channel_id, &nodes[1].node.get_our_node_id(), tx.clone()).unwrap();
        check_added_monitors!(nodes[0], 0);
        let funding_created = get_event_msg!(nodes[0], MessageSendEvent::SendFundingCreated, nodes[1].node.get_our_node_id());
 
@@ -2196,7 +2196,7 @@ fn channel_monitor_network_test() {
        send_payment(&nodes[0], &vec!(&nodes[1], &nodes[2], &nodes[3], &nodes[4])[..], 8000000);
 
        // Simple case with no pending HTLCs:
-       nodes[1].node.force_close_channel(&chan_1.2).unwrap();
+       nodes[1].node.force_close_channel(&chan_1.2, &nodes[0].node.get_our_node_id()).unwrap();
        check_added_monitors!(nodes[1], 1);
        check_closed_broadcast!(nodes[1], true);
        {
@@ -2217,7 +2217,7 @@ fn channel_monitor_network_test() {
 
        // Simple case of one pending HTLC to HTLC-Timeout (note that the HTLC-Timeout is not
        // broadcasted until we reach the timelock time).
-       nodes[1].node.force_close_channel(&chan_2.2).unwrap();
+       nodes[1].node.force_close_channel(&chan_2.2, &nodes[2].node.get_our_node_id()).unwrap();
        check_closed_broadcast!(nodes[1], true);
        check_added_monitors!(nodes[1], 1);
        {
@@ -2256,7 +2256,7 @@ fn channel_monitor_network_test() {
 
        // nodes[3] gets the preimage, but nodes[2] already disconnected, resulting in a nodes[2]
        // HTLC-Timeout and a nodes[3] claim against it (+ its own announces)
-       nodes[2].node.force_close_channel(&chan_3.2).unwrap();
+       nodes[2].node.force_close_channel(&chan_3.2, &nodes[3].node.get_our_node_id()).unwrap();
        check_added_monitors!(nodes[2], 1);
        check_closed_broadcast!(nodes[2], true);
        let node2_commitment_txid;
@@ -2723,18 +2723,20 @@ fn test_htlc_on_chain_success() {
        }
        let chan_id = Some(chan_1.2);
        match forwarded_events[1] {
-               Event::PaymentForwarded { fee_earned_msat, source_channel_id, claim_from_onchain_tx } => {
+               Event::PaymentForwarded { fee_earned_msat, prev_channel_id, claim_from_onchain_tx, next_channel_id } => {
                        assert_eq!(fee_earned_msat, Some(1000));
-                       assert_eq!(source_channel_id, chan_id);
+                       assert_eq!(prev_channel_id, chan_id);
                        assert_eq!(claim_from_onchain_tx, true);
+                       assert_eq!(next_channel_id, Some(chan_2.2));
                },
                _ => panic!()
        }
        match forwarded_events[2] {
-               Event::PaymentForwarded { fee_earned_msat, source_channel_id, claim_from_onchain_tx } => {
+               Event::PaymentForwarded { fee_earned_msat, prev_channel_id, claim_from_onchain_tx, next_channel_id } => {
                        assert_eq!(fee_earned_msat, Some(1000));
-                       assert_eq!(source_channel_id, chan_id);
+                       assert_eq!(prev_channel_id, chan_id);
                        assert_eq!(claim_from_onchain_tx, true);
+                       assert_eq!(next_channel_id, Some(chan_2.2));
                },
                _ => panic!()
        }
@@ -3377,7 +3379,7 @@ fn test_htlc_ignore_latest_remote_commitment() {
        create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known());
 
        route_payment(&nodes[0], &[&nodes[1]], 10000000);
-       nodes[0].node.force_close_channel(&nodes[0].node.list_channels()[0].channel_id).unwrap();
+       nodes[0].node.force_close_channel(&nodes[0].node.list_channels()[0].channel_id, &nodes[1].node.get_our_node_id()).unwrap();
        connect_blocks(&nodes[0], TEST_FINAL_CLTV + LATENCY_GRACE_PERIOD_BLOCKS + 1);
        check_closed_broadcast!(nodes[0], true);
        check_added_monitors!(nodes[0], 1);
@@ -3440,7 +3442,7 @@ fn test_force_close_fail_back() {
        // state or updated nodes[1]' state. Now force-close and broadcast that commitment/HTLC
        // transaction and ensure nodes[1] doesn't fail-backwards (this was originally a bug!).
 
-       nodes[2].node.force_close_channel(&payment_event.commitment_msg.channel_id).unwrap();
+       nodes[2].node.force_close_channel(&payment_event.commitment_msg.channel_id, &nodes[1].node.get_our_node_id()).unwrap();
        check_closed_broadcast!(nodes[2], true);
        check_added_monitors!(nodes[2], 1);
        check_closed_event!(nodes[2], 1, ClosureReason::HolderForceClosed);
@@ -3521,10 +3523,10 @@ fn test_peer_disconnected_before_funding_broadcasted() {
        let accept_channel = get_event_msg!(nodes[1], MessageSendEvent::SendAcceptChannel, nodes[0].node.get_our_node_id());
        nodes[0].node.handle_accept_channel(&nodes[1].node.get_our_node_id(), InitFeatures::known(), &accept_channel);
 
-       let (temporary_channel_id, tx, _funding_output) = create_funding_transaction(&nodes[0], 1_000_000, 42);
+       let (temporary_channel_id, tx, _funding_output) = create_funding_transaction(&nodes[0], &nodes[1].node.get_our_node_id(), 1_000_000, 42);
        assert_eq!(temporary_channel_id, expected_temporary_channel_id);
 
-       assert!(nodes[0].node.funding_transaction_generated(&temporary_channel_id, tx.clone()).is_ok());
+       assert!(nodes[0].node.funding_transaction_generated(&temporary_channel_id, &nodes[1].node.get_our_node_id(), tx.clone()).is_ok());
 
        let funding_created_msg = get_event_msg!(nodes[0], MessageSendEvent::SendFundingCreated, nodes[1].node.get_our_node_id());
        assert_eq!(funding_created_msg.temporary_channel_id, expected_temporary_channel_id);
@@ -4441,9 +4443,9 @@ fn test_manager_serialize_deserialize_events() {
        node_b.node.handle_open_channel(&node_a.node.get_our_node_id(), a_flags, &get_event_msg!(node_a, MessageSendEvent::SendOpenChannel, node_b.node.get_our_node_id()));
        node_a.node.handle_accept_channel(&node_b.node.get_our_node_id(), b_flags, &get_event_msg!(node_b, MessageSendEvent::SendAcceptChannel, node_a.node.get_our_node_id()));
 
-       let (temporary_channel_id, tx, funding_output) = create_funding_transaction(&node_a, channel_value, 42);
+       let (temporary_channel_id, tx, funding_output) = create_funding_transaction(&node_a, &node_b.node.get_our_node_id(), channel_value, 42);
 
-       node_a.node.funding_transaction_generated(&temporary_channel_id, tx.clone()).unwrap();
+       node_a.node.funding_transaction_generated(&temporary_channel_id, &node_b.node.get_our_node_id(), tx.clone()).unwrap();
        check_added_monitors!(node_a, 0);
 
        node_b.node.handle_funding_created(&node_a.node.get_our_node_id(), &get_event_msg!(node_a, MessageSendEvent::SendFundingCreated, node_b.node.get_our_node_id()));
@@ -4764,7 +4766,7 @@ fn test_claim_sizeable_push_msat() {
        let nodes = create_network(2, &node_cfgs, &node_chanmgrs);
 
        let chan = create_announced_chan_between_nodes_with_value(&nodes, 0, 1, 100_000, 98_000_000, InitFeatures::known(), InitFeatures::known());
-       nodes[1].node.force_close_channel(&chan.2).unwrap();
+       nodes[1].node.force_close_channel(&chan.2, &nodes[0].node.get_our_node_id()).unwrap();
        check_closed_broadcast!(nodes[1], true);
        check_added_monitors!(nodes[1], 1);
        check_closed_event!(nodes[1], 1, ClosureReason::HolderForceClosed);
@@ -4793,7 +4795,7 @@ fn test_claim_on_remote_sizeable_push_msat() {
        let nodes = create_network(2, &node_cfgs, &node_chanmgrs);
 
        let chan = create_announced_chan_between_nodes_with_value(&nodes, 0, 1, 100_000, 98_000_000, InitFeatures::known(), InitFeatures::known());
-       nodes[0].node.force_close_channel(&chan.2).unwrap();
+       nodes[0].node.force_close_channel(&chan.2, &nodes[1].node.get_our_node_id()).unwrap();
        check_closed_broadcast!(nodes[0], true);
        check_added_monitors!(nodes[0], 1);
        check_closed_event!(nodes[0], 1, ClosureReason::HolderForceClosed);
@@ -5195,10 +5197,11 @@ fn test_onchain_to_onchain_claim() {
                _ => panic!("Unexpected event"),
        }
        match events[1] {
-               Event::PaymentForwarded { fee_earned_msat, source_channel_id, claim_from_onchain_tx } => {
+               Event::PaymentForwarded { fee_earned_msat, prev_channel_id, claim_from_onchain_tx, next_channel_id } => {
                        assert_eq!(fee_earned_msat, Some(1000));
-                       assert_eq!(source_channel_id, Some(chan_1.2));
+                       assert_eq!(prev_channel_id, Some(chan_1.2));
                        assert_eq!(claim_from_onchain_tx, true);
+                       assert_eq!(next_channel_id, Some(chan_2.2));
                },
                _ => panic!("Unexpected event"),
        }
@@ -5374,7 +5377,7 @@ fn test_duplicate_payment_hash_one_failure_one_success() {
        // Note that the fee paid is effectively double as the HTLC value (including the nodes[1] fee
        // and nodes[2] fee) is rounded down and then claimed in full.
        mine_transaction(&nodes[1], &htlc_success_txn[0]);
-       expect_payment_forwarded!(nodes[1], nodes[0], Some(196*2), true);
+       expect_payment_forwarded!(nodes[1], nodes[0], nodes[2], Some(196*2), true, true);
        let updates = get_htlc_update_msgs!(nodes[1], nodes[0].node.get_our_node_id());
        assert!(updates.update_add_htlcs.is_empty());
        assert!(updates.update_fail_htlcs.is_empty());
@@ -8318,7 +8321,7 @@ fn test_manually_accept_inbound_channel_request() {
        let events = nodes[1].node.get_and_clear_pending_events();
        match events[0] {
                Event::OpenChannelRequest { temporary_channel_id, .. } => {
-                       nodes[1].node.accept_inbound_channel(&temporary_channel_id, 23).unwrap();
+                       nodes[1].node.accept_inbound_channel(&temporary_channel_id, &nodes[0].node.get_our_node_id(), 23).unwrap();
                }
                _ => panic!("Unexpected event"),
        }
@@ -8333,7 +8336,7 @@ fn test_manually_accept_inbound_channel_request() {
                _ => panic!("Unexpected event"),
        }
 
-       nodes[1].node.force_close_channel(&temp_channel_id).unwrap();
+       nodes[1].node.force_close_channel(&temp_channel_id, &nodes[0].node.get_our_node_id()).unwrap();
 
        let close_msg_ev = nodes[1].node.get_and_clear_pending_msg_events();
        assert_eq!(close_msg_ev.len(), 1);
@@ -8368,7 +8371,7 @@ fn test_manually_reject_inbound_channel_request() {
        let events = nodes[1].node.get_and_clear_pending_events();
        match events[0] {
                Event::OpenChannelRequest { temporary_channel_id, .. } => {
-                       nodes[1].node.force_close_channel(&temporary_channel_id).unwrap();
+                       nodes[1].node.force_close_channel(&temporary_channel_id, &nodes[0].node.get_our_node_id()).unwrap();
                }
                _ => panic!("Unexpected event"),
        }
@@ -8423,9 +8426,9 @@ fn test_reject_funding_before_inbound_channel_accepted() {
                nodes[0].node.handle_accept_channel(&nodes[1].node.get_our_node_id(), InitFeatures::known(), &accept_chan_msg);
        }
 
-       let (temporary_channel_id, tx, _) = create_funding_transaction(&nodes[0], 100000, 42);
+       let (temporary_channel_id, tx, _) = create_funding_transaction(&nodes[0], &nodes[1].node.get_our_node_id(), 100000, 42);
 
-       nodes[0].node.funding_transaction_generated(&temporary_channel_id, tx.clone()).unwrap();
+       nodes[0].node.funding_transaction_generated(&temporary_channel_id, &nodes[1].node.get_our_node_id(), tx.clone()).unwrap();
        let funding_created_msg = get_event_msg!(nodes[0], MessageSendEvent::SendFundingCreated, nodes[1].node.get_our_node_id());
 
        // The `funding_created_msg` should be rejected by `nodes[1]` as it hasn't accepted the channel
@@ -8468,8 +8471,8 @@ fn test_can_not_accept_inbound_channel_twice() {
        let events = nodes[1].node.get_and_clear_pending_events();
        match events[0] {
                Event::OpenChannelRequest { temporary_channel_id, .. } => {
-                       nodes[1].node.accept_inbound_channel(&temporary_channel_id, 0).unwrap();
-                       let api_res = nodes[1].node.accept_inbound_channel(&temporary_channel_id, 0);
+                       nodes[1].node.accept_inbound_channel(&temporary_channel_id, &nodes[0].node.get_our_node_id(), 0).unwrap();
+                       let api_res = nodes[1].node.accept_inbound_channel(&temporary_channel_id, &nodes[0].node.get_our_node_id(), 0);
                        match api_res {
                                Err(APIError::APIMisuseError { err }) => {
                                        assert_eq!(err, "The channel isn't currently awaiting to be accepted.");
@@ -8495,13 +8498,13 @@ fn test_can_not_accept_inbound_channel_twice() {
 
 #[test]
 fn test_can_not_accept_unknown_inbound_channel() {
-       let chanmon_cfg = create_chanmon_cfgs(1);
-       let node_cfg = create_node_cfgs(1, &chanmon_cfg);
-       let node_chanmgr = create_node_chanmgrs(1, &node_cfg, &[None]);
-       let node = create_network(1, &node_cfg, &node_chanmgr)[0].node;
+       let chanmon_cfg = create_chanmon_cfgs(2);
+       let node_cfg = create_node_cfgs(2, &chanmon_cfg);
+       let node_chanmgr = create_node_chanmgrs(2, &node_cfg, &[None, None]);
+       let nodes = create_network(2, &node_cfg, &node_chanmgr);
 
        let unknown_channel_id = [0; 32];
-       let api_res = node.accept_inbound_channel(&unknown_channel_id, 0);
+       let api_res = nodes[0].node.accept_inbound_channel(&unknown_channel_id, &nodes[1].node.get_our_node_id(), 0);
        match api_res {
                Err(APIError::ChannelUnavailable { err }) => {
                        assert_eq!(err, "Can't accept a channel that doesn't exist");
@@ -8912,9 +8915,9 @@ fn test_pre_lockin_no_chan_closed_update() {
        nodes[0].node.handle_accept_channel(&nodes[1].node.get_our_node_id(), InitFeatures::known(), &accept_chan_msg);
 
        // Move the first channel through the funding flow...
-       let (temporary_channel_id, tx, _) = create_funding_transaction(&nodes[0], 100000, 42);
+       let (temporary_channel_id, tx, _) = create_funding_transaction(&nodes[0], &nodes[1].node.get_our_node_id(), 100000, 42);
 
-       nodes[0].node.funding_transaction_generated(&temporary_channel_id, tx.clone()).unwrap();
+       nodes[0].node.funding_transaction_generated(&temporary_channel_id, &nodes[1].node.get_our_node_id(), tx.clone()).unwrap();
        check_added_monitors!(nodes[0], 0);
 
        let funding_created_msg = get_event_msg!(nodes[0], MessageSendEvent::SendFundingCreated, nodes[1].node.get_our_node_id());
@@ -9012,8 +9015,14 @@ fn do_test_onchain_htlc_settlement_after_close(broadcast_alice: bool, go_onchain
        // If `go_onchain_before_fulfill`, broadcast the relevant commitment transaction and check that Bob
        // responds by (1) broadcasting a channel update and (2) adding a new ChannelMonitor.
        let mut force_closing_node = 0; // Alice force-closes
-       if !broadcast_alice { force_closing_node = 1; } // Bob force-closes
-       nodes[force_closing_node].node.force_close_channel(&chan_ab.2).unwrap();
+       let mut counterparty_node = 1; // Bob if Alice force-closes
+
+       // Bob force-closes
+       if !broadcast_alice {
+               force_closing_node = 1;
+               counterparty_node = 0;
+       }
+       nodes[force_closing_node].node.force_close_channel(&chan_ab.2, &nodes[counterparty_node].node.get_our_node_id()).unwrap();
        check_closed_broadcast!(nodes[force_closing_node], true);
        check_added_monitors!(nodes[force_closing_node], 1);
        check_closed_event!(nodes[force_closing_node], 1, ClosureReason::HolderForceClosed);
@@ -9047,7 +9056,7 @@ fn do_test_onchain_htlc_settlement_after_close(broadcast_alice: bool, go_onchain
        assert_eq!(carol_updates.update_fulfill_htlcs.len(), 1);
 
        nodes[1].node.handle_update_fulfill_htlc(&nodes[2].node.get_our_node_id(), &carol_updates.update_fulfill_htlcs[0]);
-       expect_payment_forwarded!(nodes[1], nodes[0], if go_onchain_before_fulfill || force_closing_node == 1 { None } else { Some(1000) }, false);
+       expect_payment_forwarded!(nodes[1], nodes[0], nodes[2], if go_onchain_before_fulfill || force_closing_node == 1 { None } else { Some(1000) }, false, false);
        // If Alice broadcasted but Bob doesn't know yet, here he prepares to tell her about the preimage.
        if !go_onchain_before_fulfill && broadcast_alice {
                let events = nodes[1].node.get_and_clear_pending_msg_events();
@@ -9201,9 +9210,9 @@ fn test_duplicate_chan_id() {
        }
 
        // Move the first channel through the funding flow...
-       let (temporary_channel_id, tx, funding_output) = create_funding_transaction(&nodes[0], 100000, 42);
+       let (temporary_channel_id, tx, funding_output) = create_funding_transaction(&nodes[0], &nodes[1].node.get_our_node_id(), 100000, 42);
 
-       nodes[0].node.funding_transaction_generated(&temporary_channel_id, tx.clone()).unwrap();
+       nodes[0].node.funding_transaction_generated(&temporary_channel_id, &nodes[1].node.get_our_node_id(), tx.clone()).unwrap();
        check_added_monitors!(nodes[0], 0);
 
        let mut funding_created_msg = get_event_msg!(nodes[0], MessageSendEvent::SendFundingCreated, nodes[1].node.get_our_node_id());
@@ -9246,7 +9255,7 @@ fn test_duplicate_chan_id() {
        let open_chan_2_msg = get_event_msg!(nodes[0], MessageSendEvent::SendOpenChannel, nodes[1].node.get_our_node_id());
        nodes[1].node.handle_open_channel(&nodes[0].node.get_our_node_id(), InitFeatures::known(), &open_chan_2_msg);
        nodes[0].node.handle_accept_channel(&nodes[1].node.get_our_node_id(), InitFeatures::known(), &get_event_msg!(nodes[1], MessageSendEvent::SendAcceptChannel, nodes[0].node.get_our_node_id()));
-       create_funding_transaction(&nodes[0], 100000, 42); // Get and check the FundingGenerationReady event
+       create_funding_transaction(&nodes[0], &nodes[1].node.get_our_node_id(), 100000, 42); // Get and check the FundingGenerationReady event
 
        let funding_created = {
                let mut a_channel_lock = nodes[0].node.channel_state.lock().unwrap();
@@ -9379,13 +9388,13 @@ fn test_invalid_funding_tx() {
        nodes[1].node.handle_open_channel(&nodes[0].node.get_our_node_id(), InitFeatures::known(), &get_event_msg!(nodes[0], MessageSendEvent::SendOpenChannel, nodes[1].node.get_our_node_id()));
        nodes[0].node.handle_accept_channel(&nodes[1].node.get_our_node_id(), InitFeatures::known(), &get_event_msg!(nodes[1], MessageSendEvent::SendAcceptChannel, nodes[0].node.get_our_node_id()));
 
-       let (temporary_channel_id, mut tx, _) = create_funding_transaction(&nodes[0], 100_000, 42);
+       let (temporary_channel_id, mut tx, _) = create_funding_transaction(&nodes[0], &nodes[1].node.get_our_node_id(), 100_000, 42);
        for output in tx.output.iter_mut() {
                // Make the confirmed funding transaction have a bogus script_pubkey
                output.script_pubkey = bitcoin::Script::new();
        }
 
-       nodes[0].node.funding_transaction_generated_unchecked(&temporary_channel_id, tx.clone(), 0).unwrap();
+       nodes[0].node.funding_transaction_generated_unchecked(&temporary_channel_id, &nodes[1].node.get_our_node_id(), tx.clone(), 0).unwrap();
        nodes[1].node.handle_funding_created(&nodes[0].node.get_our_node_id(), &get_event_msg!(nodes[0], MessageSendEvent::SendFundingCreated, nodes[1].node.get_our_node_id()));
        check_added_monitors!(nodes[1], 1);
 
@@ -9443,7 +9452,7 @@ fn do_test_tx_confirmed_skipping_blocks_immediate_broadcast(test_height_before_t
        nodes[1].node.peer_disconnected(&nodes[2].node.get_our_node_id(), false);
        nodes[2].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
 
-       nodes[1].node.force_close_channel(&channel_id).unwrap();
+       nodes[1].node.force_close_channel(&channel_id, &nodes[2].node.get_our_node_id()).unwrap();
        check_closed_broadcast!(nodes[1], true);
        check_closed_event!(nodes[1], 1, ClosureReason::HolderForceClosed);
        check_added_monitors!(nodes[1], 1);
@@ -9915,7 +9924,7 @@ fn do_test_max_dust_htlc_exposure(dust_outbound_balance: bool, exposure_breach_e
 
        let opt_anchors = false;
 
-       let (temporary_channel_id, tx, _) = create_funding_transaction(&nodes[0], 1_000_000, 42);
+       let (temporary_channel_id, tx, _) = create_funding_transaction(&nodes[0], &nodes[1].node.get_our_node_id(), 1_000_000, 42);
 
        if on_holder_tx {
                if let Some(mut chan) = nodes[0].node.channel_state.lock().unwrap().by_id.get_mut(&temporary_channel_id) {
@@ -9923,7 +9932,7 @@ fn do_test_max_dust_htlc_exposure(dust_outbound_balance: bool, exposure_breach_e
                }
        }
 
-       nodes[0].node.funding_transaction_generated(&temporary_channel_id, tx.clone()).unwrap();
+       nodes[0].node.funding_transaction_generated(&temporary_channel_id, &nodes[1].node.get_our_node_id(), tx.clone()).unwrap();
        nodes[1].node.handle_funding_created(&nodes[0].node.get_our_node_id(), &get_event_msg!(nodes[0], MessageSendEvent::SendFundingCreated, nodes[1].node.get_our_node_id()));
        check_added_monitors!(nodes[1], 1);
 
index 6fa0aab9d311fb2ed5f9dff09ac3b487b2298051..7e1eec96b1ebd6e53d1aefc9c714d111694c1c9e 100644 (file)
@@ -20,6 +20,7 @@ use util::events::{Event, MessageSendEvent, MessageSendEventsProvider, ClosureRe
 use bitcoin::blockdata::script::Builder;
 use bitcoin::blockdata::opcodes;
 use bitcoin::secp256k1::Secp256k1;
+use bitcoin::Transaction;
 
 use prelude::*;
 
@@ -82,6 +83,17 @@ fn chanmon_fail_from_stale_commitment() {
        expect_payment_failed_with_update!(nodes[0], payment_hash, false, update_a.contents.short_channel_id, true);
 }
 
+fn test_spendable_output<'a, 'b, 'c, 'd>(node: &'a Node<'b, 'c, 'd>, spendable_tx: &Transaction) {
+       let mut spendable = node.chain_monitor.chain_monitor.get_and_clear_pending_events();
+       assert_eq!(spendable.len(), 1);
+       if let Event::SpendableOutputs { outputs } = spendable.pop().unwrap() {
+               assert_eq!(outputs.len(), 1);
+               let spend_tx = node.keys_manager.backing.spend_spendable_outputs(&[&outputs[0]], Vec::new(),
+                       Builder::new().push_opcode(opcodes::all::OP_RETURN).into_script(), 253, &Secp256k1::new()).unwrap();
+               check_spends!(spend_tx, spendable_tx);
+       } else { panic!(); }
+}
+
 #[test]
 fn chanmon_claim_value_coop_close() {
        // Tests `get_claimable_balances` returns the correct values across a simple cooperative claim.
@@ -108,7 +120,7 @@ fn chanmon_claim_value_coop_close() {
        assert_eq!(vec![Balance::ClaimableOnChannelClose { claimable_amount_satoshis: 1_000, }],
                nodes[1].chain_monitor.chain_monitor.get_monitor(funding_outpoint).unwrap().get_claimable_balances());
 
-       nodes[0].node.close_channel(&chan_id).unwrap();
+       nodes[0].node.close_channel(&chan_id, &nodes[1].node.get_our_node_id()).unwrap();
        let node_0_shutdown = get_event_msg!(nodes[0], MessageSendEvent::SendShutdown, nodes[1].node.get_our_node_id());
        nodes[1].node.handle_shutdown(&nodes[0].node.get_our_node_id(), &InitFeatures::known(), &node_0_shutdown);
        let node_1_shutdown = get_event_msg!(nodes[1], MessageSendEvent::SendShutdown, nodes[0].node.get_our_node_id());
@@ -155,23 +167,9 @@ fn chanmon_claim_value_coop_close() {
        assert_eq!(Vec::<Balance>::new(),
                nodes[1].chain_monitor.chain_monitor.get_monitor(funding_outpoint).unwrap().get_claimable_balances());
 
-       let mut node_a_spendable = nodes[0].chain_monitor.chain_monitor.get_and_clear_pending_events();
-       assert_eq!(node_a_spendable.len(), 1);
-       if let Event::SpendableOutputs { outputs } = node_a_spendable.pop().unwrap() {
-               assert_eq!(outputs.len(), 1);
-               let spend_tx = nodes[0].keys_manager.backing.spend_spendable_outputs(&[&outputs[0]], Vec::new(),
-                       Builder::new().push_opcode(opcodes::all::OP_RETURN).into_script(), 253, &Secp256k1::new()).unwrap();
-               check_spends!(spend_tx, shutdown_tx[0]);
-       }
+       test_spendable_output(&nodes[0], &shutdown_tx[0]);
+       test_spendable_output(&nodes[1], &shutdown_tx[0]);
 
-       let mut node_b_spendable = nodes[1].chain_monitor.chain_monitor.get_and_clear_pending_events();
-       assert_eq!(node_b_spendable.len(), 1);
-       if let Event::SpendableOutputs { outputs } = node_b_spendable.pop().unwrap() {
-               assert_eq!(outputs.len(), 1);
-               let spend_tx = nodes[1].keys_manager.backing.spend_spendable_outputs(&[&outputs[0]], Vec::new(),
-                       Builder::new().push_opcode(opcodes::all::OP_RETURN).into_script(), 253, &Secp256k1::new()).unwrap();
-               check_spends!(spend_tx, shutdown_tx[0]);
-       }
        check_closed_event!(nodes[0], 1, ClosureReason::CooperativeClosure);
        check_closed_event!(nodes[1], 1, ClosureReason::CooperativeClosure);
 }
@@ -384,15 +382,7 @@ fn do_test_claim_value_force_close(prev_commitment_tx: bool) {
                }]),
                sorted_vec(nodes[1].chain_monitor.chain_monitor.get_monitor(funding_outpoint).unwrap().get_claimable_balances()));
 
-       let mut node_a_spendable = nodes[0].chain_monitor.chain_monitor.get_and_clear_pending_events();
-       assert_eq!(node_a_spendable.len(), 1);
-       if let Event::SpendableOutputs { outputs } = node_a_spendable.pop().unwrap() {
-               assert_eq!(outputs.len(), 1);
-               let spend_tx = nodes[0].keys_manager.backing.spend_spendable_outputs(&[&outputs[0]], Vec::new(),
-                       Builder::new().push_opcode(opcodes::all::OP_RETURN).into_script(), 253, &Secp256k1::new()).unwrap();
-               check_spends!(spend_tx, remote_txn[0]);
-       }
-
+       test_spendable_output(&nodes[0], &remote_txn[0]);
        assert!(nodes[1].chain_monitor.chain_monitor.get_and_clear_pending_events().is_empty());
 
        // After broadcasting the HTLC claim transaction, node A will still consider the HTLC
@@ -449,14 +439,7 @@ fn do_test_claim_value_force_close(prev_commitment_tx: bool) {
                nodes[0].chain_monitor.chain_monitor.get_monitor(funding_outpoint).unwrap().get_claimable_balances());
        expect_payment_failed!(nodes[0], timeout_payment_hash, true);
 
-       let mut node_a_spendable = nodes[0].chain_monitor.chain_monitor.get_and_clear_pending_events();
-       assert_eq!(node_a_spendable.len(), 1);
-       if let Event::SpendableOutputs { outputs } = node_a_spendable.pop().unwrap() {
-               assert_eq!(outputs.len(), 1);
-               let spend_tx = nodes[0].keys_manager.backing.spend_spendable_outputs(&[&outputs[0]], Vec::new(),
-                       Builder::new().push_opcode(opcodes::all::OP_RETURN).into_script(), 253, &Secp256k1::new()).unwrap();
-               check_spends!(spend_tx, a_broadcast_txn[2]);
-       } else { panic!(); }
+       test_spendable_output(&nodes[0], &a_broadcast_txn[2]);
 
        // Node B will no longer consider the HTLC "contentious" after the HTLC claim transaction
        // confirms, and consider it simply "awaiting confirmations". Note that it has to wait for the
@@ -479,15 +462,7 @@ fn do_test_claim_value_force_close(prev_commitment_tx: bool) {
        // After reaching the commitment output CSV, we'll get a SpendableOutputs event for it and have
        // only the HTLCs claimable on node B.
        connect_blocks(&nodes[1], node_b_commitment_claimable - nodes[1].best_block_info().1);
-
-       let mut node_b_spendable = nodes[1].chain_monitor.chain_monitor.get_and_clear_pending_events();
-       assert_eq!(node_b_spendable.len(), 1);
-       if let Event::SpendableOutputs { outputs } = node_b_spendable.pop().unwrap() {
-               assert_eq!(outputs.len(), 1);
-               let spend_tx = nodes[1].keys_manager.backing.spend_spendable_outputs(&[&outputs[0]], Vec::new(),
-                       Builder::new().push_opcode(opcodes::all::OP_RETURN).into_script(), 253, &Secp256k1::new()).unwrap();
-               check_spends!(spend_tx, remote_txn[0]);
-       }
+       test_spendable_output(&nodes[1], &remote_txn[0]);
 
        assert_eq!(sorted_vec(vec![Balance::ClaimableAwaitingConfirmations {
                        claimable_amount_satoshis: 3_000,
@@ -501,15 +476,7 @@ fn do_test_claim_value_force_close(prev_commitment_tx: bool) {
        // After reaching the claimed HTLC output CSV, we'll get a SpendableOutputs event for it and
        // have only one HTLC output left spendable.
        connect_blocks(&nodes[1], node_b_htlc_claimable - nodes[1].best_block_info().1);
-
-       let mut node_b_spendable = nodes[1].chain_monitor.chain_monitor.get_and_clear_pending_events();
-       assert_eq!(node_b_spendable.len(), 1);
-       if let Event::SpendableOutputs { outputs } = node_b_spendable.pop().unwrap() {
-               assert_eq!(outputs.len(), 1);
-               let spend_tx = nodes[1].keys_manager.backing.spend_spendable_outputs(&[&outputs[0]], Vec::new(),
-                       Builder::new().push_opcode(opcodes::all::OP_RETURN).into_script(), 253, &Secp256k1::new()).unwrap();
-               check_spends!(spend_tx, b_broadcast_txn[0]);
-       } else { panic!(); }
+       test_spendable_output(&nodes[1], &b_broadcast_txn[0]);
 
        assert_eq!(vec![Balance::ContentiousClaimable {
                        claimable_amount_satoshis: 4_000,
@@ -704,25 +671,11 @@ fn test_balances_on_local_commitment_htlcs() {
                        confirmation_height: node_a_htlc_claimable,
                }],
                nodes[0].chain_monitor.chain_monitor.get_monitor(funding_outpoint).unwrap().get_claimable_balances());
-       let mut node_a_spendable = nodes[0].chain_monitor.chain_monitor.get_and_clear_pending_events();
-       assert_eq!(node_a_spendable.len(), 1);
-       if let Event::SpendableOutputs { outputs } = node_a_spendable.pop().unwrap() {
-               assert_eq!(outputs.len(), 1);
-               let spend_tx = nodes[0].keys_manager.backing.spend_spendable_outputs(&[&outputs[0]], Vec::new(),
-                       Builder::new().push_opcode(opcodes::all::OP_RETURN).into_script(), 253, &Secp256k1::new()).unwrap();
-               check_spends!(spend_tx, as_txn[0]);
-       }
+       test_spendable_output(&nodes[0], &as_txn[0]);
 
        // Connect blocks until the HTLC-Timeout's CSV expires, providing us the relevant
        // `SpendableOutputs` event and removing the claimable balance entry.
        connect_blocks(&nodes[0], node_a_htlc_claimable - nodes[0].best_block_info().1);
        assert!(nodes[0].chain_monitor.chain_monitor.get_monitor(funding_outpoint).unwrap().get_claimable_balances().is_empty());
-       let mut node_a_spendable = nodes[0].chain_monitor.chain_monitor.get_and_clear_pending_events();
-       assert_eq!(node_a_spendable.len(), 1);
-       if let Event::SpendableOutputs { outputs } = node_a_spendable.pop().unwrap() {
-               assert_eq!(outputs.len(), 1);
-               let spend_tx = nodes[0].keys_manager.backing.spend_spendable_outputs(&[&outputs[0]], Vec::new(),
-                       Builder::new().push_opcode(opcodes::all::OP_RETURN).into_script(), 253, &Secp256k1::new()).unwrap();
-               check_spends!(spend_tx, as_txn[1]);
-       }
+       test_spendable_output(&nodes[0], &as_txn[1]);
 }
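
The repeated `SpendableOutputs` checks above now go through a single helper. Its definition is not part of this hunk; reconstructed from the removed blocks, it plausibly looks like the sketch below (the exact signature, lifetimes, and panic arm in `functional_test_utils.rs` may differ):

pub fn test_spendable_output<'a, 'b, 'c>(node: &Node<'a, 'b, 'c>, spendable_tx: &Transaction) {
	// Drain pending chain events and expect exactly one SpendableOutputs event...
	let mut events = node.chain_monitor.chain_monitor.get_and_clear_pending_events();
	assert_eq!(events.len(), 1);
	if let Event::SpendableOutputs { outputs } = events.pop().unwrap() {
		assert_eq!(outputs.len(), 1);
		// ...then sweep the single output to an OP_RETURN script at 253 sat/kW and
		// check that the sweep spends the transaction the caller expected.
		let spend_tx = node.keys_manager.backing.spend_spendable_outputs(&[&outputs[0]], Vec::new(),
			Builder::new().push_opcode(opcodes::all::OP_RETURN).into_script(), 253, &Secp256k1::new()).unwrap();
		check_spends!(spend_tx, spendable_tx);
	} else { panic!("Expected a SpendableOutputs event"); }
}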
index 281a2a8e977123158f5daeba45ff8666c96ec403..5aa48c32bf07686aba4adad53b9615a6cdd91bb1 100644 (file)
@@ -630,7 +630,10 @@ pub struct UnsignedChannelUpdate {
        pub fee_base_msat: u32,
        /// The amount to fee multiplier, in micro-satoshi
        pub fee_proportional_millionths: u32,
-       pub(crate) excess_data: Vec<u8>,
+       /// Excess data which was signed as part of the message but which we do not (yet) understand
+       /// how to decode. This is stored to ensure forward-compatibility as new fields are added to
+       /// the lightning gossip protocol.
+       pub excess_data: Vec<u8>,
 }
 /// A channel_update message to be sent or received from a peer
 #[derive(Clone, Debug, PartialEq)]
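
With `excess_data` now public, callers can see that trailing bytes they could not decode are retained and written back out. A rough, non-normative sketch from the consumer side, assuming the existing `Writeable::encode` helper and the `Writeable` impl on `UnsignedChannelUpdate`:

use lightning::ln::msgs::UnsignedChannelUpdate;
use lightning::util::ser::Writeable;

// Sketch only: because `excess_data` is serialized verbatim, fields added to the
// gossip protocol after this code was built survive a decode/re-encode round trip.
fn reencode_for_relay(update: &UnsignedChannelUpdate) -> Vec<u8> {
	update.encode()
}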
index 088ad855683327d1c89bd6aa963aa8417d807a41..8f5c9d6ae6a21a4d1ffb2d9f86302225b3dfa498 100644 (file)
@@ -495,7 +495,7 @@ fn do_retry_with_no_persist(confirm_before_reload: bool) {
        let bs_htlc_claim_txn = nodes[1].tx_broadcaster.txn_broadcasted.lock().unwrap().split_off(0);
        assert_eq!(bs_htlc_claim_txn.len(), 1);
        check_spends!(bs_htlc_claim_txn[0], as_commitment_tx);
-       expect_payment_forwarded!(nodes[1], nodes[0], None, false);
+       expect_payment_forwarded!(nodes[1], nodes[0], nodes[2], None, false, false);
 
        if !confirm_before_reload {
                mine_transaction(&nodes[0], &as_commitment_tx);
@@ -570,7 +570,7 @@ fn do_test_dup_htlc_onchain_fails_on_reload(persist_manager_post_event: bool, co
        // Route a payment, but force-close the channel before the HTLC fulfill message arrives at
        // nodes[0].
        let (payment_preimage, payment_hash, _) = route_payment(&nodes[0], &[&nodes[1]], 10000000);
-       nodes[0].node.force_close_channel(&nodes[0].node.list_channels()[0].channel_id).unwrap();
+       nodes[0].node.force_close_channel(&nodes[0].node.list_channels()[0].channel_id, &nodes[1].node.get_our_node_id()).unwrap();
        check_closed_broadcast!(nodes[0], true);
        check_added_monitors!(nodes[0], 1);
        check_closed_event!(nodes[0], 1, ClosureReason::HolderForceClosed);
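
Force-closing (like the cooperative-close paths changed below) now takes the counterparty's node id alongside the channel id, since channel ids are only meaningful per peer. A hedged caller-side sketch, with `channel_manager` standing in for the caller's `ChannelManager` and assuming `ChannelDetails::counterparty` carries the peer's `node_id` as in recent LDK versions:

// Sketch only: look up the channel's counterparty and pass it with the channel id.
let channels = channel_manager.list_channels();
let details = &channels[0];
channel_manager
	.force_close_channel(&details.channel_id, &details.counterparty.node_id)
	.expect("channel should exist for this (channel_id, peer) pair");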
index da6dc67aba5c4f52ec9e9dcd819aebf131d33084..29fa84505fcefc55a255b842f05fd5eca84bed65 100644 (file)
@@ -25,8 +25,8 @@ use util::crypto::hkdf_extract_expand_twice;
 use bitcoin::hashes::hex::ToHex;
 
 /// Maximum Lightning message data length according to
-/// [BOLT-8](https://github.com/lightningnetwork/lightning-rfc/blob/v1.0/08-transport.md#lightning-message-specification)
-/// and [BOLT-1](https://github.com/lightningnetwork/lightning-rfc/blob/master/01-messaging.md#lightning-message-format):
+/// [BOLT-8](https://github.com/lightning/bolts/blob/v1.0/08-transport.md#lightning-message-specification)
+/// and [BOLT-1](https://github.com/lightning/bolts/blob/master/01-messaging.md#lightning-message-format):
 pub const LN_MAX_MSG_LEN: usize = ::core::u16::MAX as usize; // Must be equal to 65535
 
 // Sha256("Noise_XK_secp256k1_ChaChaPoly_SHA256")
@@ -80,7 +80,6 @@ enum NoiseState {
 }
 
 pub struct PeerChannelEncryptor {
-       secp_ctx: Secp256k1<secp256k1::SignOnly>,
        their_node_id: Option<PublicKey>, // filled in for outbound, or inbound after noise_state is Finished
 
        noise_state: NoiseState,
@@ -88,8 +87,6 @@ pub struct PeerChannelEncryptor {
 
 impl PeerChannelEncryptor {
        pub fn new_outbound(their_node_id: PublicKey, ephemeral_key: SecretKey) -> PeerChannelEncryptor {
-               let secp_ctx = Secp256k1::signing_only();
-
                let mut sha = Sha256::engine();
                sha.input(&NOISE_H);
                sha.input(&their_node_id.serialize()[..]);
@@ -97,7 +94,6 @@ impl PeerChannelEncryptor {
 
                PeerChannelEncryptor {
                        their_node_id: Some(their_node_id),
-                       secp_ctx,
                        noise_state: NoiseState::InProgress {
                                state: NoiseStep::PreActOne,
                                directional_state: DirectionalNoiseState::Outbound {
@@ -111,9 +107,7 @@ impl PeerChannelEncryptor {
                }
        }
 
-       pub fn new_inbound(our_node_secret: &SecretKey) -> PeerChannelEncryptor {
-               let secp_ctx = Secp256k1::signing_only();
-
+       pub fn new_inbound<C: secp256k1::Signing>(our_node_secret: &SecretKey, secp_ctx: &Secp256k1<C>) -> PeerChannelEncryptor {
                let mut sha = Sha256::engine();
                sha.input(&NOISE_H);
                let our_node_id = PublicKey::from_secret_key(&secp_ctx, our_node_secret);
@@ -122,7 +116,6 @@ impl PeerChannelEncryptor {
 
                PeerChannelEncryptor {
                        their_node_id: None,
-                       secp_ctx,
                        noise_state: NoiseState::InProgress {
                                state: NoiseStep::PreActOne,
                                directional_state: DirectionalNoiseState::Inbound {
@@ -224,7 +217,7 @@ impl PeerChannelEncryptor {
                Ok((their_pub, temp_k))
        }
 
-       pub fn get_act_one(&mut self) -> [u8; 50] {
+       pub fn get_act_one<C: secp256k1::Signing>(&mut self, secp_ctx: &Secp256k1<C>) -> [u8; 50] {
                match self.noise_state {
                        NoiseState::InProgress { ref mut state, ref directional_state, ref mut bidirectional_state } =>
                                match directional_state {
@@ -233,7 +226,7 @@ impl PeerChannelEncryptor {
                                                        panic!("Requested act at wrong step");
                                                }
 
-                                               let (res, _) = PeerChannelEncryptor::outbound_noise_act(&self.secp_ctx, bidirectional_state, &ie, &self.their_node_id.unwrap());
+                                               let (res, _) = PeerChannelEncryptor::outbound_noise_act(secp_ctx, bidirectional_state, &ie, &self.their_node_id.unwrap());
                                                *state = NoiseStep::PostActOne;
                                                res
                                        },
@@ -243,7 +236,9 @@ impl PeerChannelEncryptor {
                }
        }
 
-       pub fn process_act_one_with_keys(&mut self, act_one: &[u8], our_node_secret: &SecretKey, our_ephemeral: SecretKey) -> Result<[u8; 50], LightningError> {
+       pub fn process_act_one_with_keys<C: secp256k1::Signing>(
+               &mut self, act_one: &[u8], our_node_secret: &SecretKey, our_ephemeral: SecretKey, secp_ctx: &Secp256k1<C>)
+       -> Result<[u8; 50], LightningError> {
                assert_eq!(act_one.len(), 50);
 
                match self.noise_state {
@@ -259,7 +254,8 @@ impl PeerChannelEncryptor {
 
                                                re.get_or_insert(our_ephemeral);
 
-                                               let (res, temp_k) = PeerChannelEncryptor::outbound_noise_act(&self.secp_ctx, bidirectional_state, &re.unwrap(), &ie.unwrap());
+                                               let (res, temp_k) =
+                                                       PeerChannelEncryptor::outbound_noise_act(secp_ctx, bidirectional_state, &re.unwrap(), &ie.unwrap());
                                                *temp_k2 = Some(temp_k);
                                                *state = NoiseStep::PostActTwo;
                                                Ok(res)
@@ -270,7 +266,9 @@ impl PeerChannelEncryptor {
                }
        }
 
-       pub fn process_act_two(&mut self, act_two: &[u8], our_node_secret: &SecretKey) -> Result<([u8; 66], PublicKey), LightningError> {
+       pub fn process_act_two<C: secp256k1::Signing>(
+               &mut self, act_two: &[u8], our_node_secret: &SecretKey, secp_ctx: &Secp256k1<C>)
+       -> Result<([u8; 66], PublicKey), LightningError> {
                assert_eq!(act_two.len(), 50);
 
                let final_hkdf;
@@ -286,7 +284,7 @@ impl PeerChannelEncryptor {
                                                let (re, temp_k2) = PeerChannelEncryptor::inbound_noise_act(bidirectional_state, act_two, &ie)?;
 
                                                let mut res = [0; 66];
-                                               let our_node_id = PublicKey::from_secret_key(&self.secp_ctx, &our_node_secret);
+                                               let our_node_id = PublicKey::from_secret_key(secp_ctx, &our_node_secret);
 
                                                PeerChannelEncryptor::encrypt_with_ad(&mut res[1..50], 1, &temp_k2, &bidirectional_state.h, &our_node_id.serialize()[..]);
 
@@ -474,6 +472,7 @@ mod tests {
        use super::LN_MAX_MSG_LEN;
 
        use bitcoin::secp256k1::{PublicKey,SecretKey};
+       use bitcoin::secp256k1::Secp256k1;
 
        use hex;
 
@@ -481,9 +480,10 @@ mod tests {
 
        fn get_outbound_peer_for_initiator_test_vectors() -> PeerChannelEncryptor {
                let their_node_id = PublicKey::from_slice(&hex::decode("028d7500dd4c12685d1f568b4c2b5048e8534b873319f3a8daa612b469132ec7f7").unwrap()[..]).unwrap();
+               let secp_ctx = Secp256k1::signing_only();
 
                let mut outbound_peer = PeerChannelEncryptor::new_outbound(their_node_id, SecretKey::from_slice(&hex::decode("1212121212121212121212121212121212121212121212121212121212121212").unwrap()[..]).unwrap());
-               assert_eq!(outbound_peer.get_act_one()[..], hex::decode("00036360e856310ce5d294e8be33fc807077dc56ac80d95d9cd4ddbd21325eff73f70df6086551151f58b8afe6c195782c6a").unwrap()[..]);
+               assert_eq!(outbound_peer.get_act_one(&secp_ctx)[..], hex::decode("00036360e856310ce5d294e8be33fc807077dc56ac80d95d9cd4ddbd21325eff73f70df6086551151f58b8afe6c195782c6a").unwrap()[..]);
                outbound_peer
        }
 
@@ -491,11 +491,12 @@ mod tests {
                // transport-responder successful handshake
                let our_node_id = SecretKey::from_slice(&hex::decode("2121212121212121212121212121212121212121212121212121212121212121").unwrap()[..]).unwrap();
                let our_ephemeral = SecretKey::from_slice(&hex::decode("2222222222222222222222222222222222222222222222222222222222222222").unwrap()[..]).unwrap();
+               let secp_ctx = Secp256k1::signing_only();
 
-               let mut inbound_peer = PeerChannelEncryptor::new_inbound(&our_node_id);
+               let mut inbound_peer = PeerChannelEncryptor::new_inbound(&our_node_id, &secp_ctx);
 
                let act_one = hex::decode("00036360e856310ce5d294e8be33fc807077dc56ac80d95d9cd4ddbd21325eff73f70df6086551151f58b8afe6c195782c6a").unwrap().to_vec();
-               assert_eq!(inbound_peer.process_act_one_with_keys(&act_one[..], &our_node_id, our_ephemeral.clone()).unwrap()[..], hex::decode("0002466d7fcae563e5cb09a0d1870bb580344804617879a14949cf22285f1bae3f276e2470b93aac583c9ef6eafca3f730ae").unwrap()[..]);
+               assert_eq!(inbound_peer.process_act_one_with_keys(&act_one[..], &our_node_id, our_ephemeral.clone(), &secp_ctx).unwrap()[..], hex::decode("0002466d7fcae563e5cb09a0d1870bb580344804617879a14949cf22285f1bae3f276e2470b93aac583c9ef6eafca3f730ae").unwrap()[..]);
 
                let act_three = hex::decode("00b9e3a702e93e3a9948c2ed6e5fd7590a6e1c3a0344cfc9d5b57357049aa22355361aa02e55a8fc28fef5bd6d71ad0c38228dc68b1c466263b47fdf31e560e139ba").unwrap().to_vec();
                // test vector doesn't specify the initiator static key, but it's the same as the one
@@ -520,13 +521,14 @@ mod tests {
        #[test]
        fn noise_initiator_test_vectors() {
                let our_node_id = SecretKey::from_slice(&hex::decode("1111111111111111111111111111111111111111111111111111111111111111").unwrap()[..]).unwrap();
+               let secp_ctx = Secp256k1::signing_only();
 
                {
                        // transport-initiator successful handshake
                        let mut outbound_peer = get_outbound_peer_for_initiator_test_vectors();
 
                        let act_two = hex::decode("0002466d7fcae563e5cb09a0d1870bb580344804617879a14949cf22285f1bae3f276e2470b93aac583c9ef6eafca3f730ae").unwrap().to_vec();
-                       assert_eq!(outbound_peer.process_act_two(&act_two[..], &our_node_id).unwrap().0[..], hex::decode("00b9e3a702e93e3a9948c2ed6e5fd7590a6e1c3a0344cfc9d5b57357049aa22355361aa02e55a8fc28fef5bd6d71ad0c38228dc68b1c466263b47fdf31e560e139ba").unwrap()[..]);
+                       assert_eq!(outbound_peer.process_act_two(&act_two[..], &our_node_id, &secp_ctx).unwrap().0[..], hex::decode("00b9e3a702e93e3a9948c2ed6e5fd7590a6e1c3a0344cfc9d5b57357049aa22355361aa02e55a8fc28fef5bd6d71ad0c38228dc68b1c466263b47fdf31e560e139ba").unwrap()[..]);
 
                        match outbound_peer.noise_state {
                                NoiseState::Finished { sk, sn, sck, rk, rn, rck } => {
@@ -549,7 +551,7 @@ mod tests {
                        let mut outbound_peer = get_outbound_peer_for_initiator_test_vectors();
 
                        let act_two = hex::decode("0102466d7fcae563e5cb09a0d1870bb580344804617879a14949cf22285f1bae3f276e2470b93aac583c9ef6eafca3f730ae").unwrap().to_vec();
-                       assert!(outbound_peer.process_act_two(&act_two[..], &our_node_id).is_err());
+                       assert!(outbound_peer.process_act_two(&act_two[..], &our_node_id, &secp_ctx).is_err());
                }
 
                {
@@ -557,7 +559,7 @@ mod tests {
                        let mut outbound_peer = get_outbound_peer_for_initiator_test_vectors();
 
                        let act_two = hex::decode("0004466d7fcae563e5cb09a0d1870bb580344804617879a14949cf22285f1bae3f276e2470b93aac583c9ef6eafca3f730ae").unwrap().to_vec();
-                       assert!(outbound_peer.process_act_two(&act_two[..], &our_node_id).is_err());
+                       assert!(outbound_peer.process_act_two(&act_two[..], &our_node_id, &secp_ctx).is_err());
                }
 
                {
@@ -565,7 +567,7 @@ mod tests {
                        let mut outbound_peer = get_outbound_peer_for_initiator_test_vectors();
 
                        let act_two = hex::decode("0002466d7fcae563e5cb09a0d1870bb580344804617879a14949cf22285f1bae3f276e2470b93aac583c9ef6eafca3f730af").unwrap().to_vec();
-                       assert!(outbound_peer.process_act_two(&act_two[..], &our_node_id).is_err());
+                       assert!(outbound_peer.process_act_two(&act_two[..], &our_node_id, &secp_ctx).is_err());
                }
        }
 
@@ -573,6 +575,7 @@ mod tests {
        fn noise_responder_test_vectors() {
                let our_node_id = SecretKey::from_slice(&hex::decode("2121212121212121212121212121212121212121212121212121212121212121").unwrap()[..]).unwrap();
                let our_ephemeral = SecretKey::from_slice(&hex::decode("2222222222222222222222222222222222222222222222222222222222222222").unwrap()[..]).unwrap();
+               let secp_ctx = Secp256k1::signing_only();
 
                {
                        let _ = get_inbound_peer_for_test_vectors();
@@ -583,31 +586,31 @@ mod tests {
                }
                {
                        // transport-responder act1 bad version test
-                       let mut inbound_peer = PeerChannelEncryptor::new_inbound(&our_node_id);
+                       let mut inbound_peer = PeerChannelEncryptor::new_inbound(&our_node_id, &secp_ctx);
 
                        let act_one = hex::decode("01036360e856310ce5d294e8be33fc807077dc56ac80d95d9cd4ddbd21325eff73f70df6086551151f58b8afe6c195782c6a").unwrap().to_vec();
-                       assert!(inbound_peer.process_act_one_with_keys(&act_one[..], &our_node_id, our_ephemeral.clone()).is_err());
+                       assert!(inbound_peer.process_act_one_with_keys(&act_one[..], &our_node_id, our_ephemeral.clone(), &secp_ctx).is_err());
                }
                {
                        // transport-responder act1 bad key serialization test
-                       let mut inbound_peer = PeerChannelEncryptor::new_inbound(&our_node_id);
+                       let mut inbound_peer = PeerChannelEncryptor::new_inbound(&our_node_id, &secp_ctx);
 
                        let act_one =hex::decode("00046360e856310ce5d294e8be33fc807077dc56ac80d95d9cd4ddbd21325eff73f70df6086551151f58b8afe6c195782c6a").unwrap().to_vec();
-                       assert!(inbound_peer.process_act_one_with_keys(&act_one[..], &our_node_id, our_ephemeral.clone()).is_err());
+                       assert!(inbound_peer.process_act_one_with_keys(&act_one[..], &our_node_id, our_ephemeral.clone(), &secp_ctx).is_err());
                }
                {
                        // transport-responder act1 bad MAC test
-                       let mut inbound_peer = PeerChannelEncryptor::new_inbound(&our_node_id);
+                       let mut inbound_peer = PeerChannelEncryptor::new_inbound(&our_node_id, &secp_ctx);
 
                        let act_one = hex::decode("00036360e856310ce5d294e8be33fc807077dc56ac80d95d9cd4ddbd21325eff73f70df6086551151f58b8afe6c195782c6b").unwrap().to_vec();
-                       assert!(inbound_peer.process_act_one_with_keys(&act_one[..], &our_node_id, our_ephemeral.clone()).is_err());
+                       assert!(inbound_peer.process_act_one_with_keys(&act_one[..], &our_node_id, our_ephemeral.clone(), &secp_ctx).is_err());
                }
                {
                        // transport-responder act3 bad version test
-                       let mut inbound_peer = PeerChannelEncryptor::new_inbound(&our_node_id);
+                       let mut inbound_peer = PeerChannelEncryptor::new_inbound(&our_node_id, &secp_ctx);
 
                        let act_one = hex::decode("00036360e856310ce5d294e8be33fc807077dc56ac80d95d9cd4ddbd21325eff73f70df6086551151f58b8afe6c195782c6a").unwrap().to_vec();
-                       assert_eq!(inbound_peer.process_act_one_with_keys(&act_one[..], &our_node_id, our_ephemeral.clone()).unwrap()[..], hex::decode("0002466d7fcae563e5cb09a0d1870bb580344804617879a14949cf22285f1bae3f276e2470b93aac583c9ef6eafca3f730ae").unwrap()[..]);
+                       assert_eq!(inbound_peer.process_act_one_with_keys(&act_one[..], &our_node_id, our_ephemeral.clone(), &secp_ctx).unwrap()[..], hex::decode("0002466d7fcae563e5cb09a0d1870bb580344804617879a14949cf22285f1bae3f276e2470b93aac583c9ef6eafca3f730ae").unwrap()[..]);
 
                        let act_three = hex::decode("01b9e3a702e93e3a9948c2ed6e5fd7590a6e1c3a0344cfc9d5b57357049aa22355361aa02e55a8fc28fef5bd6d71ad0c38228dc68b1c466263b47fdf31e560e139ba").unwrap().to_vec();
                        assert!(inbound_peer.process_act_three(&act_three[..]).is_err());
@@ -618,30 +621,30 @@ mod tests {
                }
                {
                        // transport-responder act3 bad MAC for ciphertext test
-                       let mut inbound_peer = PeerChannelEncryptor::new_inbound(&our_node_id);
+                       let mut inbound_peer = PeerChannelEncryptor::new_inbound(&our_node_id, &secp_ctx);
 
                        let act_one = hex::decode("00036360e856310ce5d294e8be33fc807077dc56ac80d95d9cd4ddbd21325eff73f70df6086551151f58b8afe6c195782c6a").unwrap().to_vec();
-                       assert_eq!(inbound_peer.process_act_one_with_keys(&act_one[..], &our_node_id, our_ephemeral.clone()).unwrap()[..], hex::decode("0002466d7fcae563e5cb09a0d1870bb580344804617879a14949cf22285f1bae3f276e2470b93aac583c9ef6eafca3f730ae").unwrap()[..]);
+                       assert_eq!(inbound_peer.process_act_one_with_keys(&act_one[..], &our_node_id, our_ephemeral.clone(), &secp_ctx).unwrap()[..], hex::decode("0002466d7fcae563e5cb09a0d1870bb580344804617879a14949cf22285f1bae3f276e2470b93aac583c9ef6eafca3f730ae").unwrap()[..]);
 
                        let act_three = hex::decode("00c9e3a702e93e3a9948c2ed6e5fd7590a6e1c3a0344cfc9d5b57357049aa22355361aa02e55a8fc28fef5bd6d71ad0c38228dc68b1c466263b47fdf31e560e139ba").unwrap().to_vec();
                        assert!(inbound_peer.process_act_three(&act_three[..]).is_err());
                }
                {
                        // transport-responder act3 bad rs test
-                       let mut inbound_peer = PeerChannelEncryptor::new_inbound(&our_node_id);
+                       let mut inbound_peer = PeerChannelEncryptor::new_inbound(&our_node_id, &secp_ctx);
 
                        let act_one = hex::decode("00036360e856310ce5d294e8be33fc807077dc56ac80d95d9cd4ddbd21325eff73f70df6086551151f58b8afe6c195782c6a").unwrap().to_vec();
-                       assert_eq!(inbound_peer.process_act_one_with_keys(&act_one[..], &our_node_id, our_ephemeral.clone()).unwrap()[..], hex::decode("0002466d7fcae563e5cb09a0d1870bb580344804617879a14949cf22285f1bae3f276e2470b93aac583c9ef6eafca3f730ae").unwrap()[..]);
+                       assert_eq!(inbound_peer.process_act_one_with_keys(&act_one[..], &our_node_id, our_ephemeral.clone(), &secp_ctx).unwrap()[..], hex::decode("0002466d7fcae563e5cb09a0d1870bb580344804617879a14949cf22285f1bae3f276e2470b93aac583c9ef6eafca3f730ae").unwrap()[..]);
 
                        let act_three = hex::decode("00bfe3a702e93e3a9948c2ed6e5fd7590a6e1c3a0344cfc9d5b57357049aa2235536ad09a8ee351870c2bb7f78b754a26c6cef79a98d25139c856d7efd252c2ae73c").unwrap().to_vec();
                        assert!(inbound_peer.process_act_three(&act_three[..]).is_err());
                }
                {
                        // transport-responder act3 bad MAC test
-                       let mut inbound_peer = PeerChannelEncryptor::new_inbound(&our_node_id);
+                       let mut inbound_peer = PeerChannelEncryptor::new_inbound(&our_node_id, &secp_ctx);
 
                        let act_one = hex::decode("00036360e856310ce5d294e8be33fc807077dc56ac80d95d9cd4ddbd21325eff73f70df6086551151f58b8afe6c195782c6a").unwrap().to_vec();
-                       assert_eq!(inbound_peer.process_act_one_with_keys(&act_one[..], &our_node_id, our_ephemeral.clone()).unwrap()[..], hex::decode("0002466d7fcae563e5cb09a0d1870bb580344804617879a14949cf22285f1bae3f276e2470b93aac583c9ef6eafca3f730ae").unwrap()[..]);
+                       assert_eq!(inbound_peer.process_act_one_with_keys(&act_one[..], &our_node_id, our_ephemeral.clone(), &secp_ctx).unwrap()[..], hex::decode("0002466d7fcae563e5cb09a0d1870bb580344804617879a14949cf22285f1bae3f276e2470b93aac583c9ef6eafca3f730ae").unwrap()[..]);
 
                        let act_three = hex::decode("00b9e3a702e93e3a9948c2ed6e5fd7590a6e1c3a0344cfc9d5b57357049aa22355361aa02e55a8fc28fef5bd6d71ad0c38228dc68b1c466263b47fdf31e560e139bb").unwrap().to_vec();
                        assert!(inbound_peer.process_act_three(&act_three[..]).is_err());
@@ -654,12 +657,13 @@ mod tests {
                // We use the same keys as the initiator and responder test vectors, so we copy those tests
                // here and use them to encrypt.
                let mut outbound_peer = get_outbound_peer_for_initiator_test_vectors();
+               let secp_ctx = Secp256k1::signing_only();
 
                {
                        let our_node_id = SecretKey::from_slice(&hex::decode("1111111111111111111111111111111111111111111111111111111111111111").unwrap()[..]).unwrap();
 
                        let act_two = hex::decode("0002466d7fcae563e5cb09a0d1870bb580344804617879a14949cf22285f1bae3f276e2470b93aac583c9ef6eafca3f730ae").unwrap().to_vec();
-                       assert_eq!(outbound_peer.process_act_two(&act_two[..], &our_node_id).unwrap().0[..], hex::decode("00b9e3a702e93e3a9948c2ed6e5fd7590a6e1c3a0344cfc9d5b57357049aa22355361aa02e55a8fc28fef5bd6d71ad0c38228dc68b1c466263b47fdf31e560e139ba").unwrap()[..]);
+                       assert_eq!(outbound_peer.process_act_two(&act_two[..], &our_node_id, &secp_ctx).unwrap().0[..], hex::decode("00b9e3a702e93e3a9948c2ed6e5fd7590a6e1c3a0344cfc9d5b57357049aa22355361aa02e55a8fc28fef5bd6d71ad0c38228dc68b1c466263b47fdf31e560e139ba").unwrap()[..]);
 
                        match outbound_peer.noise_state {
                                NoiseState::Finished { sk, sn, sck, rk, rn, rck } => {
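
With the context removed from `PeerChannelEncryptor` itself, every handshake step takes an explicit `Secp256k1` handle. A sketch of the outbound flow under the new signatures, mirroring the updated tests; the keys below are placeholders, not the BOLT-8 test vectors, and the type is crate-internal, so this is test-level code:

fn outbound_handshake_sketch() {
	use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey};
	use ln::peer_channel_encryptor::PeerChannelEncryptor; // crate-internal path

	// The encryptor no longer owns a context; the caller threads one through each act.
	let secp_ctx = Secp256k1::signing_only();
	let their_node_secret = SecretKey::from_slice(&[0x21; 32]).unwrap(); // placeholder
	let their_node_id = PublicKey::from_secret_key(&secp_ctx, &their_node_secret);
	let our_ephemeral = SecretKey::from_slice(&[0x12; 32]).unwrap(); // placeholder

	let mut outbound = PeerChannelEncryptor::new_outbound(their_node_id, our_ephemeral);
	let _act_one: [u8; 50] = outbound.get_act_one(&secp_ctx);
	// Send act one; once the peer's 50-byte act two arrives, continue with:
	// let (act_three, their_id) = outbound.process_act_two(&act_two, &our_node_secret, &secp_ctx)?;
}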
index 4b777b0b99446e0d4781e6029592b60ee0e24ff7..f128f6c5259802466cf99f27496b554d7b88f9ea 100644 (file)
@@ -15,7 +15,7 @@
 //! call into the provided message handlers (probably a ChannelManager and NetGraphmsgHandler) with messages
 //! they should handle, and encoding/sending response messages.
 
-use bitcoin::secp256k1::{SecretKey,PublicKey};
+use bitcoin::secp256k1::{self, Secp256k1, SecretKey, PublicKey};
 
 use ln::features::InitFeatures;
 use ln::msgs;
@@ -455,6 +455,7 @@ pub struct PeerManager<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: De
        peer_counter: AtomicCounter,
 
        logger: L,
+       secp_ctx: Secp256k1<secp256k1::SignOnly>
 }
 
 enum MessageHandlingError {
@@ -573,6 +574,10 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
                let mut ephemeral_key_midstate = Sha256::engine();
                ephemeral_key_midstate.input(ephemeral_random_data);
 
+               let mut secp_ctx = Secp256k1::signing_only();
+               let ephemeral_hash = Sha256::from_engine(ephemeral_key_midstate.clone()).into_inner();
+               secp_ctx.seeded_randomize(&ephemeral_hash);
+
                PeerManager {
                        message_handler,
                        peers: FairRwLock::new(HashMap::new()),
@@ -584,6 +589,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
                        peer_counter: AtomicCounter::new(),
                        logger,
                        custom_message_handler,
+                       secp_ctx,
                }
        }
 
@@ -628,7 +634,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
        /// [`socket_disconnected()`]: PeerManager::socket_disconnected
        pub fn new_outbound_connection(&self, their_node_id: PublicKey, descriptor: Descriptor, remote_network_address: Option<NetAddress>) -> Result<Vec<u8>, PeerHandleError> {
                let mut peer_encryptor = PeerChannelEncryptor::new_outbound(their_node_id.clone(), self.get_ephemeral_key());
-               let res = peer_encryptor.get_act_one().to_vec();
+               let res = peer_encryptor.get_act_one(&self.secp_ctx).to_vec();
                let pending_read_buffer = [0; 50].to_vec(); // Noise act two is 50 bytes
 
                let mut peers = self.peers.write().unwrap();
@@ -675,7 +681,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
        ///
        /// [`socket_disconnected()`]: PeerManager::socket_disconnected
        pub fn new_inbound_connection(&self, descriptor: Descriptor, remote_network_address: Option<NetAddress>) -> Result<(), PeerHandleError> {
-               let peer_encryptor = PeerChannelEncryptor::new_inbound(&self.our_node_secret);
+               let peer_encryptor = PeerChannelEncryptor::new_inbound(&self.our_node_secret, &self.secp_ctx);
                let pending_read_buffer = [0; 50].to_vec(); // Noise act one is 50 bytes
 
                let mut peers = self.peers.write().unwrap();
@@ -940,14 +946,16 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
                                                let next_step = peer.channel_encryptor.get_noise_step();
                                                match next_step {
                                                        NextNoiseStep::ActOne => {
-                                                               let act_two = try_potential_handleerror!(peer,
-                                                                       peer.channel_encryptor.process_act_one_with_keys(&peer.pending_read_buffer[..], &self.our_node_secret, self.get_ephemeral_key())).to_vec();
+                                                               let act_two = try_potential_handleerror!(peer, peer.channel_encryptor
+                                                                       .process_act_one_with_keys(&peer.pending_read_buffer[..],
+                                                                               &self.our_node_secret, self.get_ephemeral_key(), &self.secp_ctx)).to_vec();
                                                                peer.pending_outbound_buffer.push_back(act_two);
                                                                peer.pending_read_buffer = [0; 66].to_vec(); // act three is 66 bytes long
                                                        },
                                                        NextNoiseStep::ActTwo => {
                                                                let (act_three, their_node_id) = try_potential_handleerror!(peer,
-                                                                       peer.channel_encryptor.process_act_two(&peer.pending_read_buffer[..], &self.our_node_secret));
+                                                                       peer.channel_encryptor.process_act_two(&peer.pending_read_buffer[..],
+                                                                               &self.our_node_secret, &self.secp_ctx));
                                                                peer.pending_outbound_buffer.push_back(act_three.to_vec());
                                                                peer.pending_read_buffer = [0; 18].to_vec(); // Message length header is 18 bytes
                                                                peer.pending_read_is_header = true;
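
`PeerManager` now keeps a single signing-only context and re-randomizes it with the caller-supplied ephemeral entropy, which blinds secp256k1 operations against side-channel leakage without needing a context per connection. The pattern in isolation, as a sketch; the real constructor reuses the same SHA-256 midstate it also feeds into ephemeral key derivation:

use bitcoin::hashes::{Hash, HashEngine};
use bitcoin::hashes::sha256::Hash as Sha256;
use bitcoin::secp256k1::{Secp256k1, SignOnly};

fn randomized_signing_ctx(ephemeral_random_data: &[u8; 32]) -> Secp256k1<SignOnly> {
	// Hash the provided entropy down to 32 bytes and use it to blind the context.
	let mut engine = Sha256::engine();
	engine.input(ephemeral_random_data);
	let seed = Sha256::from_engine(engine).into_inner();
	let mut secp_ctx = Secp256k1::signing_only();
	secp_ctx.seeded_randomize(&seed);
	secp_ctx
}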
index d7b940eb90e09dfce8e6557656327440fb987159..47e2fb33e3c7321103350865f97b5f5cc5897d14 100644 (file)
@@ -389,8 +389,8 @@ fn test_inbound_scid_privacy() {
        let accept_channel = get_event_msg!(nodes[2], MessageSendEvent::SendAcceptChannel, nodes[1].node.get_our_node_id());
        nodes[1].node.handle_accept_channel(&nodes[2].node.get_our_node_id(), InitFeatures::known(), &accept_channel);
 
-       let (temporary_channel_id, tx, _) = create_funding_transaction(&nodes[1], 100_000, 42);
-       nodes[1].node.funding_transaction_generated(&temporary_channel_id, tx.clone()).unwrap();
+       let (temporary_channel_id, tx, _) = create_funding_transaction(&nodes[1], &nodes[2].node.get_our_node_id(), 100_000, 42);
+       nodes[1].node.funding_transaction_generated(&temporary_channel_id, &nodes[2].node.get_our_node_id(), tx.clone()).unwrap();
        nodes[2].node.handle_funding_created(&nodes[1].node.get_our_node_id(), &get_event_msg!(nodes[1], MessageSendEvent::SendFundingCreated, nodes[2].node.get_our_node_id()));
        check_added_monitors!(nodes[2], 1);
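
The funding flow follows the same addressing change: the temporary channel id is paired with the counterparty's node id. Caller-side, roughly, where `channel_manager`, `temporary_channel_id`, `counterparty_node_id` and `funding_tx` stand in for the caller's own values:

// Sketch only: the temporary channel id alone no longer identifies the channel;
// the counterparty's node id is passed alongside it.
channel_manager
	.funding_transaction_generated(&temporary_channel_id, &counterparty_node_id, funding_tx)
	.expect("temporary channel id should match an outbound channel with this peer");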
 
index 616feb1eea1f6a14dddc7f876246976f49eaea01..518cca3a65ec52a90a56ada0af8db86bd2be6e71 100644 (file)
@@ -138,7 +138,7 @@ fn do_test_onchain_htlc_reorg(local_commitment: bool, claim: bool) {
                // ChannelManager only polls chain::Watch::release_pending_monitor_events when we
                // probe it for events, so we probe non-message events here (which should just be the
                // PaymentForwarded event).
-               expect_payment_forwarded!(nodes[1], nodes[0], Some(1000), true);
+               expect_payment_forwarded!(nodes[1], nodes[0], nodes[2], Some(1000), true, true);
        } else {
                // Confirm the timeout tx and check that we fail the HTLC backwards
                let block = Block {
index 7abe3060fa7b666f22fc83048ca330a194e3fdfb..0e25f46d472141fb2f1e7a2f58ada2bbabc25e15 100644 (file)
@@ -16,7 +16,7 @@ use io;
 
 /// A script pubkey for shutting down a channel as defined by [BOLT #2].
 ///
-/// [BOLT #2]: https://github.com/lightningnetwork/lightning-rfc/blob/master/02-peer-protocol.md
+/// [BOLT #2]: https://github.com/lightning/bolts/blob/master/02-peer-protocol.md
 #[derive(Clone, PartialEq)]
 pub struct ShutdownScript(ShutdownScriptImpl);
 
@@ -25,7 +25,7 @@ pub struct ShutdownScript(ShutdownScriptImpl);
 pub struct InvalidShutdownScript {
        /// The script that did not meet the requirements from [BOLT #2].
        ///
-       /// [BOLT #2]: https://github.com/lightningnetwork/lightning-rfc/blob/master/02-peer-protocol.md
+       /// [BOLT #2]: https://github.com/lightning/bolts/blob/master/02-peer-protocol.md
        pub script: Script
 }
 
index 59004547b1489afe21f0a88ae5b71788b1f90fe1..063fa01608679e71011dbb39d89afbdfff1a4774 100644 (file)
@@ -46,7 +46,7 @@ fn pre_funding_lock_shutdown_test() {
        mine_transaction(&nodes[0], &tx);
        mine_transaction(&nodes[1], &tx);
 
-       nodes[0].node.close_channel(&OutPoint { txid: tx.txid(), index: 0 }.to_channel_id()).unwrap();
+       nodes[0].node.close_channel(&OutPoint { txid: tx.txid(), index: 0 }.to_channel_id(), &nodes[1].node.get_our_node_id()).unwrap();
        let node_0_shutdown = get_event_msg!(nodes[0], MessageSendEvent::SendShutdown, nodes[1].node.get_our_node_id());
        nodes[1].node.handle_shutdown(&nodes[0].node.get_our_node_id(), &InitFeatures::known(), &node_0_shutdown);
        let node_1_shutdown = get_event_msg!(nodes[1], MessageSendEvent::SendShutdown, nodes[0].node.get_our_node_id());
@@ -83,7 +83,7 @@ fn updates_shutdown_wait() {
 
        let (payment_preimage, _, _) = route_payment(&nodes[0], &[&nodes[1], &nodes[2]], 100000);
 
-       nodes[0].node.close_channel(&chan_1.2).unwrap();
+       nodes[0].node.close_channel(&chan_1.2, &nodes[1].node.get_our_node_id()).unwrap();
        let node_0_shutdown = get_event_msg!(nodes[0], MessageSendEvent::SendShutdown, nodes[1].node.get_our_node_id());
        nodes[1].node.handle_shutdown(&nodes[0].node.get_our_node_id(), &InitFeatures::known(), &node_0_shutdown);
        let node_1_shutdown = get_event_msg!(nodes[1], MessageSendEvent::SendShutdown, nodes[0].node.get_our_node_id());
@@ -110,7 +110,7 @@ fn updates_shutdown_wait() {
        assert!(updates.update_fee.is_none());
        assert_eq!(updates.update_fulfill_htlcs.len(), 1);
        nodes[1].node.handle_update_fulfill_htlc(&nodes[2].node.get_our_node_id(), &updates.update_fulfill_htlcs[0]);
-       expect_payment_forwarded!(nodes[1], nodes[0], Some(1000), false);
+       expect_payment_forwarded!(nodes[1], nodes[0], nodes[2], Some(1000), false, false);
        check_added_monitors!(nodes[1], 1);
        let updates_2 = get_htlc_update_msgs!(nodes[1], nodes[0].node.get_our_node_id());
        commitment_signed_dance!(nodes[1], nodes[2], updates.commitment_signed, false);
@@ -166,7 +166,7 @@ fn htlc_fail_async_shutdown() {
        assert!(updates.update_fail_malformed_htlcs.is_empty());
        assert!(updates.update_fee.is_none());
 
-       nodes[1].node.close_channel(&chan_1.2).unwrap();
+       nodes[1].node.close_channel(&chan_1.2, &nodes[0].node.get_our_node_id()).unwrap();
        let node_1_shutdown = get_event_msg!(nodes[1], MessageSendEvent::SendShutdown, nodes[0].node.get_our_node_id());
        nodes[0].node.handle_shutdown(&nodes[1].node.get_our_node_id(), &InitFeatures::known(), &node_1_shutdown);
        let node_0_shutdown = get_event_msg!(nodes[0], MessageSendEvent::SendShutdown, nodes[1].node.get_our_node_id());
@@ -232,7 +232,7 @@ fn do_test_shutdown_rebroadcast(recv_count: u8) {
 
        let (payment_preimage, _, _) = route_payment(&nodes[0], &[&nodes[1], &nodes[2]], 100000);
 
-       nodes[1].node.close_channel(&chan_1.2).unwrap();
+       nodes[1].node.close_channel(&chan_1.2, &nodes[0].node.get_our_node_id()).unwrap();
        let node_1_shutdown = get_event_msg!(nodes[1], MessageSendEvent::SendShutdown, nodes[0].node.get_our_node_id());
        if recv_count > 0 {
                nodes[0].node.handle_shutdown(&nodes[1].node.get_our_node_id(), &InitFeatures::known(), &node_1_shutdown);
@@ -279,7 +279,7 @@ fn do_test_shutdown_rebroadcast(recv_count: u8) {
        assert!(updates.update_fee.is_none());
        assert_eq!(updates.update_fulfill_htlcs.len(), 1);
        nodes[1].node.handle_update_fulfill_htlc(&nodes[2].node.get_our_node_id(), &updates.update_fulfill_htlcs[0]);
-       expect_payment_forwarded!(nodes[1], nodes[0], Some(1000), false);
+       expect_payment_forwarded!(nodes[1], nodes[0], nodes[2], Some(1000), false, false);
        check_added_monitors!(nodes[1], 1);
        let updates_2 = get_htlc_update_msgs!(nodes[1], nodes[0].node.get_our_node_id());
        commitment_signed_dance!(nodes[1], nodes[2], updates.commitment_signed, false);
@@ -417,7 +417,7 @@ fn test_upfront_shutdown_script() {
        // We test that in case of peer committing upfront to a script, if it changes at closing, we refuse to sign
        let flags = InitFeatures::known();
        let chan = create_announced_chan_between_nodes_with_value(&nodes, 0, 2, 1000000, 1000000, flags.clone(), flags.clone());
-       nodes[0].node.close_channel(&OutPoint { txid: chan.3.txid(), index: 0 }.to_channel_id()).unwrap();
+       nodes[0].node.close_channel(&OutPoint { txid: chan.3.txid(), index: 0 }.to_channel_id(), &nodes[2].node.get_our_node_id()).unwrap();
        let node_0_orig_shutdown = get_event_msg!(nodes[0], MessageSendEvent::SendShutdown, nodes[2].node.get_our_node_id());
        let mut node_0_shutdown = node_0_orig_shutdown.clone();
        node_0_shutdown.scriptpubkey = Builder::new().push_opcode(opcodes::all::OP_RETURN).into_script().to_p2sh();
@@ -432,7 +432,7 @@ fn test_upfront_shutdown_script() {
 
        // We test that in case of peer committing upfront to a script, if it doesn't change at closing, we sign
        let chan = create_announced_chan_between_nodes_with_value(&nodes, 0, 2, 1000000, 1000000, flags.clone(), flags.clone());
-       nodes[0].node.close_channel(&OutPoint { txid: chan.3.txid(), index: 0 }.to_channel_id()).unwrap();
+       nodes[0].node.close_channel(&OutPoint { txid: chan.3.txid(), index: 0 }.to_channel_id(), &nodes[2].node.get_our_node_id()).unwrap();
        let node_0_shutdown = get_event_msg!(nodes[0], MessageSendEvent::SendShutdown, nodes[2].node.get_our_node_id());
        // We test that in case of peer committing upfront to a script, if it doesn't change at closing, we sign
        nodes[2].node.handle_shutdown(&nodes[0].node.get_our_node_id(), &InitFeatures::known(), &node_0_shutdown);
@@ -446,7 +446,7 @@ fn test_upfront_shutdown_script() {
        // We test that in case of peer non-signaling we don't enforce committed script at channel opening
        let flags_no = InitFeatures::known().clear_upfront_shutdown_script();
        let chan = create_announced_chan_between_nodes_with_value(&nodes, 0, 1, 1000000, 1000000, flags_no, flags.clone());
-       nodes[0].node.close_channel(&OutPoint { txid: chan.3.txid(), index: 0 }.to_channel_id()).unwrap();
+       nodes[0].node.close_channel(&OutPoint { txid: chan.3.txid(), index: 0 }.to_channel_id(), &nodes[1].node.get_our_node_id()).unwrap();
        let node_1_shutdown = get_event_msg!(nodes[0], MessageSendEvent::SendShutdown, nodes[1].node.get_our_node_id());
        nodes[1].node.handle_shutdown(&nodes[0].node.get_our_node_id(), &InitFeatures::known(), &node_1_shutdown);
        check_added_monitors!(nodes[1], 1);
@@ -460,7 +460,7 @@ fn test_upfront_shutdown_script() {
        // We test that if user opt-out, we provide a zero-length script at channel opening and we are able to close
        // channel smoothly, opt-out is from channel initiator here
        let chan = create_announced_chan_between_nodes_with_value(&nodes, 1, 0, 1000000, 1000000, flags.clone(), flags.clone());
-       nodes[1].node.close_channel(&OutPoint { txid: chan.3.txid(), index: 0 }.to_channel_id()).unwrap();
+       nodes[1].node.close_channel(&OutPoint { txid: chan.3.txid(), index: 0 }.to_channel_id(), &nodes[0].node.get_our_node_id()).unwrap();
        check_added_monitors!(nodes[1], 1);
        let node_0_shutdown = get_event_msg!(nodes[1], MessageSendEvent::SendShutdown, nodes[0].node.get_our_node_id());
        nodes[0].node.handle_shutdown(&nodes[1].node.get_our_node_id(), &InitFeatures::known(), &node_0_shutdown);
@@ -474,7 +474,7 @@ fn test_upfront_shutdown_script() {
        //// We test that if user opt-out, we provide a zero-length script at channel opening and we are able to close
        //// channel smoothly
        let chan = create_announced_chan_between_nodes_with_value(&nodes, 0, 1, 1000000, 1000000, flags.clone(), flags.clone());
-       nodes[1].node.close_channel(&OutPoint { txid: chan.3.txid(), index: 0 }.to_channel_id()).unwrap();
+       nodes[1].node.close_channel(&OutPoint { txid: chan.3.txid(), index: 0 }.to_channel_id(), &nodes[0].node.get_our_node_id()).unwrap();
        check_added_monitors!(nodes[1], 1);
        let node_0_shutdown = get_event_msg!(nodes[1], MessageSendEvent::SendShutdown, nodes[0].node.get_our_node_id());
        nodes[0].node.handle_shutdown(&nodes[1].node.get_our_node_id(), &InitFeatures::known(), &node_0_shutdown);
@@ -580,7 +580,7 @@ fn test_segwit_v0_shutdown_script() {
        let nodes = create_network(3, &node_cfgs, &node_chanmgrs);
 
        let chan = create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known());
-       nodes[1].node.close_channel(&OutPoint { txid: chan.3.txid(), index: 0 }.to_channel_id()).unwrap();
+       nodes[1].node.close_channel(&OutPoint { txid: chan.3.txid(), index: 0 }.to_channel_id(), &nodes[0].node.get_our_node_id()).unwrap();
        check_added_monitors!(nodes[1], 1);
 
        // Use a segwit v0 script supported even without option_shutdown_anysegwit
@@ -615,7 +615,7 @@ fn test_anysegwit_shutdown_script() {
        let nodes = create_network(3, &node_cfgs, &node_chanmgrs);
 
        let chan = create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known());
-       nodes[1].node.close_channel(&OutPoint { txid: chan.3.txid(), index: 0 }.to_channel_id()).unwrap();
+       nodes[1].node.close_channel(&OutPoint { txid: chan.3.txid(), index: 0 }.to_channel_id(), &nodes[0].node.get_our_node_id()).unwrap();
        check_added_monitors!(nodes[1], 1);
 
        // Use a non-v0 segwit script supported by option_shutdown_anysegwit
@@ -660,14 +660,14 @@ fn test_unsupported_anysegwit_shutdown_script() {
                .expect(OnGetShutdownScriptpubkey { returns: supported_shutdown_script });
 
        let chan = create_announced_chan_between_nodes(&nodes, 0, 1, node_cfgs[0].features.clone(), node_cfgs[1].features.clone());
-       match nodes[1].node.close_channel(&OutPoint { txid: chan.3.txid(), index: 0 }.to_channel_id()) {
+       match nodes[1].node.close_channel(&OutPoint { txid: chan.3.txid(), index: 0 }.to_channel_id(), &nodes[0].node.get_our_node_id()) {
                Err(APIError::IncompatibleShutdownScript { script }) => {
                        assert_eq!(script.into_inner(), unsupported_shutdown_script.clone().into_inner());
                },
                Err(e) => panic!("Unexpected error: {:?}", e),
                Ok(_) => panic!("Expected error"),
        }
-       nodes[1].node.close_channel(&OutPoint { txid: chan.3.txid(), index: 0 }.to_channel_id()).unwrap();
+       nodes[1].node.close_channel(&OutPoint { txid: chan.3.txid(), index: 0 }.to_channel_id(), &nodes[0].node.get_our_node_id()).unwrap();
        check_added_monitors!(nodes[1], 1);
 
        // Use a non-v0 segwit script unsupported without option_shutdown_anysegwit
@@ -692,7 +692,7 @@ fn test_invalid_shutdown_script() {
        let nodes = create_network(3, &node_cfgs, &node_chanmgrs);
 
        let chan = create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known());
-       nodes[1].node.close_channel(&OutPoint { txid: chan.3.txid(), index: 0 }.to_channel_id()).unwrap();
+       nodes[1].node.close_channel(&OutPoint { txid: chan.3.txid(), index: 0 }.to_channel_id(), &nodes[0].node.get_our_node_id()).unwrap();
        check_added_monitors!(nodes[1], 1);
 
        // Use a segwit v0 script with an unsupported witness program
@@ -730,7 +730,7 @@ fn do_test_closing_signed_reinit_timeout(timeout_step: TimeoutStep) {
 
        send_payment(&nodes[0], &[&nodes[1]], 8_000_000);
 
-       nodes[0].node.close_channel(&chan_id).unwrap();
+       nodes[0].node.close_channel(&chan_id, &nodes[1].node.get_our_node_id()).unwrap();
        let node_0_shutdown = get_event_msg!(nodes[0], MessageSendEvent::SendShutdown, nodes[1].node.get_our_node_id());
        nodes[1].node.handle_shutdown(&nodes[0].node.get_our_node_id(), &InitFeatures::known(), &node_0_shutdown);
        let node_1_shutdown = get_event_msg!(nodes[1], MessageSendEvent::SendShutdown, nodes[0].node.get_our_node_id());
@@ -831,7 +831,7 @@ fn do_simple_legacy_shutdown_test(high_initiator_fee: bool) {
                *feerate_lock *= 10;
        }
 
-       nodes[0].node.close_channel(&OutPoint { txid: chan.3.txid(), index: 0 }.to_channel_id()).unwrap();
+       nodes[0].node.close_channel(&OutPoint { txid: chan.3.txid(), index: 0 }.to_channel_id(), &nodes[1].node.get_our_node_id()).unwrap();
        let node_0_shutdown = get_event_msg!(nodes[0], MessageSendEvent::SendShutdown, nodes[1].node.get_our_node_id());
        nodes[1].node.handle_shutdown(&nodes[0].node.get_our_node_id(), &InitFeatures::known(), &node_0_shutdown);
        let node_1_shutdown = get_event_msg!(nodes[1], MessageSendEvent::SendShutdown, nodes[0].node.get_our_node_id());
@@ -873,9 +873,9 @@ fn simple_target_feerate_shutdown() {
        let chan = create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known());
        let chan_id = OutPoint { txid: chan.3.txid(), index: 0 }.to_channel_id();
 
-       nodes[0].node.close_channel_with_target_feerate(&chan_id, 253 * 10).unwrap();
+       nodes[0].node.close_channel_with_target_feerate(&chan_id, &nodes[1].node.get_our_node_id(), 253 * 10).unwrap();
        let node_0_shutdown = get_event_msg!(nodes[0], MessageSendEvent::SendShutdown, nodes[1].node.get_our_node_id());
-       nodes[1].node.close_channel_with_target_feerate(&chan_id, 253 * 5).unwrap();
+       nodes[1].node.close_channel_with_target_feerate(&chan_id, &nodes[0].node.get_our_node_id(), 253 * 5).unwrap();
        let node_1_shutdown = get_event_msg!(nodes[1], MessageSendEvent::SendShutdown, nodes[0].node.get_our_node_id());
 
        nodes[1].node.handle_shutdown(&nodes[0].node.get_our_node_id(), &InitFeatures::known(), &node_0_shutdown);
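
Cooperative close follows the same pattern, with an optional feerate override. Mirroring the calls above, a sketch within the same test harness (`chan_id` is the channel id of an open channel between `nodes[0]` and `nodes[1]`):

// Sketch only: both cooperative-close entry points are now addressed by
// (channel_id, counterparty node id).
let peer = nodes[1].node.get_our_node_id();
nodes[0].node.close_channel(&chan_id, &peer).unwrap();
// Or pin the closing feerate explicitly (in sat per 1000 weight units):
// nodes[0].node.close_channel_with_target_feerate(&chan_id, &peer, 253 * 10).unwrap();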
index 8fd5c16f36261ce3782d72bb81846b3e5b503376..a0b549452830ecd85f05ba5fa76ab543ae29aa81 100644 (file)
@@ -10,7 +10,7 @@
 //! Wire encoding/decoding for Lightning messages according to [BOLT #1], and for
 //! custom message through the [`CustomMessageReader`] trait.
 //! 
-//! [BOLT #1]: https://github.com/lightningnetwork/lightning-rfc/blob/master/01-messaging.md
+//! [BOLT #1]: https://github.com/lightning/bolts/blob/master/01-messaging.md
 
 use io;
 use ln::msgs;
index 2e0679eba79f6cbbb1473ebdf89a0e113cf8fd5e..26719cc279061f180f5617f204846922271f3f3b 100644 (file)
@@ -67,7 +67,7 @@ impl NodeId {
        pub fn from_pubkey(pubkey: &PublicKey) -> Self {
                NodeId(pubkey.serialize())
        }
-       
+
        /// Get the public key slice from this NodeId
        pub fn as_slice(&self) -> &[u8] {
                &self.0
@@ -150,7 +150,7 @@ pub struct ReadOnlyNetworkGraph<'a> {
 /// Update to the [`NetworkGraph`] based on payment failure information conveyed via the Onion
 /// return packet by a node along the route. See [BOLT #4] for details.
 ///
-/// [BOLT #4]: https://github.com/lightningnetwork/lightning-rfc/blob/master/04-onion-routing.md
+/// [BOLT #4]: https://github.com/lightning/bolts/blob/master/04-onion-routing.md
 #[derive(Clone, Debug, PartialEq)]
 pub enum NetworkUpdate {
        /// An error indicating a `channel_update` messages should be applied via
@@ -695,7 +695,7 @@ impl ChannelInfo {
                                return None;
                        }
                };
-               Some((DirectedChannelInfo { channel: self, direction }, source))
+               Some((DirectedChannelInfo::new(self, direction), source))
        }
 
        /// Returns a [`DirectedChannelInfo`] for the channel directed from the given `source` to a
@@ -710,7 +710,17 @@ impl ChannelInfo {
                                return None;
                        }
                };
-               Some((DirectedChannelInfo { channel: self, direction }, target))
+               Some((DirectedChannelInfo::new(self, direction), target))
+       }
+
+       /// Returns a [`ChannelUpdateInfo`] based on the direction implied by `channel_flags`.
+       pub fn get_directional_info(&self, channel_flags: u8) -> Option<&ChannelUpdateInfo> {
+               let direction = channel_flags & 1u8;
+               if direction == 0 {
+                       self.one_to_two.as_ref()
+               } else {
+                       self.two_to_one.as_ref()
+               }
        }
 }
 
@@ -739,35 +749,53 @@ impl_writeable_tlv_based!(ChannelInfo, {
 pub struct DirectedChannelInfo<'a> {
        channel: &'a ChannelInfo,
        direction: Option<&'a ChannelUpdateInfo>,
+       htlc_maximum_msat: u64,
+       effective_capacity: EffectiveCapacity,
 }
 
 impl<'a> DirectedChannelInfo<'a> {
+       #[inline]
+       fn new(channel: &'a ChannelInfo, direction: Option<&'a ChannelUpdateInfo>) -> Self {
+               let htlc_maximum_msat = direction.and_then(|direction| direction.htlc_maximum_msat);
+               let capacity_msat = channel.capacity_sats.map(|capacity_sats| capacity_sats * 1000);
+
+               let (htlc_maximum_msat, effective_capacity) = match (htlc_maximum_msat, capacity_msat) {
+                       (Some(amount_msat), Some(capacity_msat)) => {
+                               let htlc_maximum_msat = cmp::min(amount_msat, capacity_msat);
+                               (htlc_maximum_msat, EffectiveCapacity::Total { capacity_msat })
+                       },
+                       (Some(amount_msat), None) => {
+                               (amount_msat, EffectiveCapacity::MaximumHTLC { amount_msat })
+                       },
+                       (None, Some(capacity_msat)) => {
+                               (capacity_msat, EffectiveCapacity::Total { capacity_msat })
+                       },
+                       (None, None) => (EffectiveCapacity::Unknown.as_msat(), EffectiveCapacity::Unknown),
+               };
+
+               Self {
+                       channel, direction, htlc_maximum_msat, effective_capacity
+               }
+       }
+
        /// Returns information for the channel.
        pub fn channel(&self) -> &'a ChannelInfo { self.channel }
 
        /// Returns information for the direction.
        pub fn direction(&self) -> Option<&'a ChannelUpdateInfo> { self.direction }
 
+       /// Returns the maximum HTLC amount allowed over the channel in the direction.
+       pub fn htlc_maximum_msat(&self) -> u64 {
+               self.htlc_maximum_msat
+       }
+
        /// Returns the [`EffectiveCapacity`] of the channel in the direction.
        ///
        /// This is either the total capacity from the funding transaction, if known, or the
        /// `htlc_maximum_msat` for the direction as advertised by the gossip network, if known,
-       /// whichever is smaller.
+       /// otherwise.
        pub fn effective_capacity(&self) -> EffectiveCapacity {
-               let capacity_msat = self.channel.capacity_sats.map(|capacity_sats| capacity_sats * 1000);
-               self.direction
-                       .and_then(|direction| direction.htlc_maximum_msat)
-                       .map(|max_htlc_msat| {
-                               let capacity_msat = capacity_msat.unwrap_or(u64::max_value());
-                               if max_htlc_msat < capacity_msat {
-                                       EffectiveCapacity::MaximumHTLC { amount_msat: max_htlc_msat }
-                               } else {
-                                       EffectiveCapacity::Total { capacity_msat }
-                               }
-                       })
-                       .or_else(|| capacity_msat.map(|capacity_msat|
-                                       EffectiveCapacity::Total { capacity_msat }))
-                       .unwrap_or(EffectiveCapacity::Unknown)
+               self.effective_capacity
        }
 
        /// Returns `Some` if [`ChannelUpdateInfo`] is available in the direction.
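
Precomputing both values in `DirectedChannelInfo::new` (alongside the new `get_directional_info` accessor for the raw per-direction data) means repeated `htlc_maximum_msat()` and `effective_capacity()` calls during routing do not re-derive them. The selection rule, restated outside the constructor as a sketch rather than new behaviour:

// Sketch restating the precomputation above: the per-direction HTLC ceiling is the
// advertised htlc_maximum_msat clamped to the funding capacity when both are known,
// with either value alone used as a fallback.
fn htlc_ceiling_msat(advertised_max_msat: Option<u64>, capacity_msat: Option<u64>) -> Option<u64> {
	match (advertised_max_msat, capacity_msat) {
		(Some(max), Some(cap)) => Some(core::cmp::min(max, cap)),
		(Some(max), None) => Some(max),
		(None, Some(cap)) => Some(cap),
		// The real code falls back to EffectiveCapacity::Unknown.as_msat() here.
		(None, None) => None,
	}
}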
@@ -805,6 +833,10 @@ impl<'a> DirectedChannelInfoWithUpdate<'a> {
        /// Returns the [`EffectiveCapacity`] of the channel in the direction.
        #[inline]
        pub(super) fn effective_capacity(&self) -> EffectiveCapacity { self.inner.effective_capacity() }
+
+       /// Returns the maximum HTLC amount allowed over the channel in the direction.
+       #[inline]
+       pub(super) fn htlc_maximum_msat(&self) -> u64 { self.inner.htlc_maximum_msat() }
 }
 
 impl<'a> fmt::Debug for DirectedChannelInfoWithUpdate<'a> {
@@ -817,6 +849,7 @@ impl<'a> fmt::Debug for DirectedChannelInfoWithUpdate<'a> {
 ///
 /// While this may be smaller than the actual channel capacity, amounts greater than
 /// [`Self::as_msat`] should not be routed through the channel.
+#[derive(Clone, Copy)]
 pub enum EffectiveCapacity {
        /// The available liquidity in the channel known from being a channel counterparty, and thus a
        /// direct hop.
@@ -1132,6 +1165,83 @@ impl NetworkGraph {
                self.update_channel_from_unsigned_announcement_intern(msg, None, chain_access)
        }
 
+       /// Update channel from partial announcement data received via rapid gossip sync.
+       ///
+       /// `timestamp: u64`: Timestamp emulating the backdated original announcement receipt (by the
+       /// rapid gossip sync server).
+       ///
+       /// All other parameters are as used in the [`msgs::UnsignedChannelAnnouncement`] fields.
+       pub fn add_channel_from_partial_announcement(&self, short_channel_id: u64, timestamp: u64, features: ChannelFeatures, node_id_1: PublicKey, node_id_2: PublicKey) -> Result<(), LightningError> {
+               if node_id_1 == node_id_2 {
+                       return Err(LightningError{err: "Channel announcement node had a channel with itself".to_owned(), action: ErrorAction::IgnoreError});
+               };
+
+               let node_1 = NodeId::from_pubkey(&node_id_1);
+               let node_2 = NodeId::from_pubkey(&node_id_2);
+               let channel_info = ChannelInfo {
+                       features,
+                       node_one: node_1.clone(),
+                       one_to_two: None,
+                       node_two: node_2.clone(),
+                       two_to_one: None,
+                       capacity_sats: None,
+                       announcement_message: None,
+                       announcement_received_time: timestamp,
+               };
+
+               self.add_channel_between_nodes(short_channel_id, channel_info, None)
+       }
+
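
A hedged usage sketch of the new partial-announcement entry point, roughly as a rapid gossip sync client would call it. The `network_graph`, the two peer public keys, and the snapshot-provided `scid` and `backdated_timestamp` are assumed inputs, not part of the patch:

    use lightning::ln::features::ChannelFeatures;

    // `network_graph`, `node_id_1`, `node_id_2`, `scid`, and `backdated_timestamp` are assumed
    // to come from the caller / the rapid gossip sync snapshot being processed.
    match network_graph.add_channel_from_partial_announcement(
        scid, backdated_timestamp, ChannelFeatures::empty(), node_id_1, node_id_2,
    ) {
        Ok(()) => {},
        // Duplicate announcements are reported as errors but are safe to ignore.
        Err(e) => eprintln!("skipping channel {}: {}", scid, e.err),
    }
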
+       fn add_channel_between_nodes(&self, short_channel_id: u64, channel_info: ChannelInfo, utxo_value: Option<u64>) -> Result<(), LightningError> {
+               let mut channels = self.channels.write().unwrap();
+               let mut nodes = self.nodes.write().unwrap();
+
+               let node_id_a = channel_info.node_one.clone();
+               let node_id_b = channel_info.node_two.clone();
+
+               match channels.entry(short_channel_id) {
+                       BtreeEntry::Occupied(mut entry) => {
+                               //TODO: because asking the blockchain if short_channel_id is valid is only optional
+                               //in the blockchain API, we need to handle it smartly here, though it's unclear
+                               //exactly how...
+                               if utxo_value.is_some() {
+                                       // Either our UTXO provider is busted, there was a reorg, or the UTXO provider
+                                       // only sometimes returns results. In any case remove the previous entry. Note
+                                       // that the spec expects us to "blacklist" the node_ids involved, but we can't
+                                       // do that because
+                                       // a) we don't *require* a UTXO provider that always returns results.
+                                       // b) we don't track UTXOs of channels we know about and remove them if they
+                                       //    get reorg'd out.
+                                       // c) it's unclear how to do so without exposing ourselves to massive DoS risk.
+                                       Self::remove_channel_in_nodes(&mut nodes, &entry.get(), short_channel_id);
+                                       *entry.get_mut() = channel_info;
+                               } else {
+                                       return Err(LightningError{err: "Already have knowledge of channel".to_owned(), action: ErrorAction::IgnoreDuplicateGossip});
+                               }
+                       },
+                       BtreeEntry::Vacant(entry) => {
+                               entry.insert(channel_info);
+                       }
+               };
+
+               for current_node_id in [node_id_a, node_id_b].iter() {
+                       match nodes.entry(current_node_id.clone()) {
+                               BtreeEntry::Occupied(node_entry) => {
+                                       node_entry.into_mut().channels.push(short_channel_id);
+                               },
+                               BtreeEntry::Vacant(node_entry) => {
+                                       node_entry.insert(NodeInfo {
+                                               channels: vec!(short_channel_id),
+                                               lowest_inbound_channel_fees: None,
+                                               announcement_info: None,
+                                       });
+                               }
+                       };
+               };
+
+               Ok(())
+       }
+
        fn update_channel_from_unsigned_announcement_intern<C: Deref>(
                &self, msg: &msgs::UnsignedChannelAnnouncement, full_msg: Option<&msgs::ChannelAnnouncement>, chain_access: &Option<C>
        ) -> Result<(), LightningError>
@@ -1180,65 +1290,18 @@ impl NetworkGraph {
                }
 
                let chan_info = ChannelInfo {
-                               features: msg.features.clone(),
-                               node_one: NodeId::from_pubkey(&msg.node_id_1),
-                               one_to_two: None,
-                               node_two: NodeId::from_pubkey(&msg.node_id_2),
-                               two_to_one: None,
-                               capacity_sats: utxo_value,
-                               announcement_message: if msg.excess_data.len() <= MAX_EXCESS_BYTES_FOR_RELAY
-                                       { full_msg.cloned() } else { None },
-                               announcement_received_time,
-                       };
-
-               let mut channels = self.channels.write().unwrap();
-               let mut nodes = self.nodes.write().unwrap();
-               match channels.entry(msg.short_channel_id) {
-                       BtreeEntry::Occupied(mut entry) => {
-                               //TODO: because asking the blockchain if short_channel_id is valid is only optional
-                               //in the blockchain API, we need to handle it smartly here, though it's unclear
-                               //exactly how...
-                               if utxo_value.is_some() {
-                                       // Either our UTXO provider is busted, there was a reorg, or the UTXO provider
-                                       // only sometimes returns results. In any case remove the previous entry. Note
-                                       // that the spec expects us to "blacklist" the node_ids involved, but we can't
-                                       // do that because
-                                       // a) we don't *require* a UTXO provider that always returns results.
-                                       // b) we don't track UTXOs of channels we know about and remove them if they
-                                       //    get reorg'd out.
-                                       // c) it's unclear how to do so without exposing ourselves to massive DoS risk.
-                                       Self::remove_channel_in_nodes(&mut nodes, &entry.get(), msg.short_channel_id);
-                                       *entry.get_mut() = chan_info;
-                               } else {
-                                       return Err(LightningError{err: "Already have knowledge of channel".to_owned(), action: ErrorAction::IgnoreDuplicateGossip});
-                               }
-                       },
-                       BtreeEntry::Vacant(entry) => {
-                               entry.insert(chan_info);
-                       }
+                       features: msg.features.clone(),
+                       node_one: NodeId::from_pubkey(&msg.node_id_1),
+                       one_to_two: None,
+                       node_two: NodeId::from_pubkey(&msg.node_id_2),
+                       two_to_one: None,
+                       capacity_sats: utxo_value,
+                       announcement_message: if msg.excess_data.len() <= MAX_EXCESS_BYTES_FOR_RELAY
+                               { full_msg.cloned() } else { None },
+                       announcement_received_time,
                };
 
-               macro_rules! add_channel_to_node {
-                       ( $node_id: expr ) => {
-                               match nodes.entry($node_id) {
-                                       BtreeEntry::Occupied(node_entry) => {
-                                               node_entry.into_mut().channels.push(msg.short_channel_id);
-                                       },
-                                       BtreeEntry::Vacant(node_entry) => {
-                                               node_entry.insert(NodeInfo {
-                                                       channels: vec!(msg.short_channel_id),
-                                                       lowest_inbound_channel_fees: None,
-                                                       announcement_info: None,
-                                               });
-                                       }
-                               }
-                       };
-               }
-
-               add_channel_to_node!(NodeId::from_pubkey(&msg.node_id_1));
-               add_channel_to_node!(NodeId::from_pubkey(&msg.node_id_2));
-
-               Ok(())
+               self.add_channel_between_nodes(msg.short_channel_id, chan_info, utxo_value)
        }
 
        /// Close a channel if a corresponding HTLC fail was sent.
@@ -1558,7 +1621,7 @@ mod tests {
        use ln::features::{ChannelFeatures, InitFeatures, NodeFeatures};
        use routing::network_graph::{NetGraphMsgHandler, NetworkGraph, NetworkUpdate, MAX_EXCESS_BYTES_FOR_RELAY};
        use ln::msgs::{Init, OptionalField, RoutingMessageHandler, UnsignedNodeAnnouncement, NodeAnnouncement,
-               UnsignedChannelAnnouncement, ChannelAnnouncement, UnsignedChannelUpdate, ChannelUpdate, 
+               UnsignedChannelAnnouncement, ChannelAnnouncement, UnsignedChannelUpdate, ChannelUpdate,
                ReplyChannelRange, QueryChannelRange, QueryShortChannelIds, MAX_VALUE_MSAT};
        use util::test_utils;
        use util::logger::Logger;
index dc094b9b6e0309a8fc51cd01a75a1ad480c3d432..5bb81a033eb697193cde37776c8f9a9dadde04c8 100644 (file)
@@ -17,9 +17,9 @@ use bitcoin::secp256k1::PublicKey;
 use ln::channelmanager::ChannelDetails;
 use ln::features::{ChannelFeatures, InvoiceFeatures, NodeFeatures};
 use ln::msgs::{DecodeError, ErrorAction, LightningError, MAX_VALUE_MSAT};
-use routing::scoring::Score;
+use routing::scoring::{ChannelUsage, Score};
 use routing::network_graph::{DirectedChannelInfoWithUpdate, EffectiveCapacity, NetworkGraph, ReadOnlyNetworkGraph, NodeId, RoutingFees};
-use util::ser::{Writeable, Readable};
+use util::ser::{Writeable, Readable, Writer};
 use util::logger::{Level, Logger};
 use util::chacha20::ChaCha20;
 
@@ -65,11 +65,10 @@ impl_writeable_tlv_based!(RouteHop, {
 #[derive(Clone, Hash, PartialEq, Eq)]
 pub struct Route {
        /// The list of routes taken for a single (potentially-)multi-part payment. The pubkey of the
-       /// last RouteHop in each path must be the same.
-       /// Each entry represents a list of hops, NOT INCLUDING our own, where the last hop is the
-       /// destination. Thus, this must always be at least length one. While the maximum length of any
-       /// given path is variable, keeping the length of any path to less than 20 should currently
-       /// ensure it is viable.
+       /// last RouteHop in each path must be the same. Each entry represents a list of hops, NOT
+       /// INCLUDING our own, where the last hop is the destination. Thus, this must always be at
+       /// least length one. While the maximum length of any given path is variable, keeping the length
+       /// of any path less than or equal to 19 should currently ensure it is viable.
        pub paths: Vec<Vec<RouteHop>>,
        /// The `payment_params` parameter passed to [`find_route`].
        /// This is used by `ChannelManager` to track information which may be required for retries,
@@ -152,8 +151,8 @@ impl Readable for Route {
 
 /// Parameters needed to find a [`Route`].
 ///
-/// Passed to [`find_route`] and also provided in [`Event::PaymentPathFailed`] for retrying a failed
-/// payment path.
+/// Passed to [`find_route`] and [`build_route_from_hops`], but also provided in
+/// [`Event::PaymentPathFailed`] for retrying a failed payment path.
 ///
 /// [`Event::PaymentPathFailed`]: crate::util::events::Event::PaymentPathFailed
 #[derive(Clone, Debug)]
@@ -177,9 +176,23 @@ impl_writeable_tlv_based!(RouteParameters, {
 /// Maximum total CTLV difference we allow for a full payment path.
 pub const DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA: u32 = 1008;
 
-/// The median hop CLTV expiry delta currently seen in the network.
+// The median hop CLTV expiry delta currently seen in the network.
 const MEDIAN_HOP_CLTV_EXPIRY_DELTA: u32 = 40;
 
+// During routing, we only consider paths shorter than our maximum length estimate.
+// In the legacy onion format, the maximum number of hops used to be a fixed value of 20.
+// However, in the TLV onion format, there is no fixed maximum length, but the `hop_payloads`
+// field is always 1300 bytes. As the `tlv_payload` for each hop may vary in length, we have to
+// estimate how many hops the route may have so that it actually fits the `hop_payloads` field.
+//
+// We estimate 3+32 (payload length and HMAC) + 2+8 (amt_to_forward) + 2+4 (outgoing_cltv_value) +
+// 2+8 (short_channel_id) = 61 bytes for each intermediate hop and 3+32
+// (payload length and HMAC) + 2+8 (amt_to_forward) + 2+4 (outgoing_cltv_value) + 2+32+8
+// (payment_secret and total_msat) = 93 bytes for the final hop.
+// Since the length of the potentially included `payment_metadata` is unknown to us, we round
+// down from (1300-93) / 61 = 19.78... to arrive at a conservative estimate of 19.
+const MAX_PATH_LENGTH_ESTIMATE: u8 = 19;
+
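
The byte accounting behind `MAX_PATH_LENGTH_ESTIMATE` can be re-derived in a few lines. The constants below are illustrative names taken from the comment above, not identifiers from the crate:

    fn main() {
        const HOP_PAYLOADS_LEN: u64 = 1300;                                // fixed onion hop_payloads size
        const INTERMEDIATE_HOP_LEN: u64 = 3 + 32 + 2 + 8 + 2 + 4 + 2 + 8;  // 61 bytes per intermediate hop
        const FINAL_HOP_LEN: u64 = 3 + 32 + 2 + 8 + 2 + 4 + 2 + 32 + 8;    // 93 bytes for the final hop

        // (1300 - 93) / 61 = 19.78..., which integer division rounds down to the
        // conservative estimate of 19.
        let estimate = (HOP_PAYLOADS_LEN - FINAL_HOP_LEN) / INTERMEDIATE_HOP_LEN;
        assert_eq!(estimate, 19);
    }
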
 /// The recipient of a payment.
 #[derive(Clone, Debug, Hash, PartialEq, Eq)]
 pub struct PaymentParameters {
@@ -327,6 +340,8 @@ struct RouteGraphNode {
        /// All penalties incurred from this hop on the way to the destination, as calculated using
        /// channel scoring.
        path_penalty_msat: u64,
+       /// The number of hops walked up to this node.
+       path_length_to_node: u8,
 }
 
 impl cmp::Ord for RouteGraphNode {
@@ -399,6 +414,16 @@ impl<'a> CandidateRouteHop<'a> {
                }
        }
 
+       fn htlc_maximum_msat(&self) -> u64 {
+               match self {
+                       CandidateRouteHop::FirstHop { details } => details.next_outbound_htlc_limit_msat,
+                       CandidateRouteHop::PublicHop { info, .. } => info.htlc_maximum_msat(),
+                       CandidateRouteHop::PrivateHop { hint } => {
+                               hint.htlc_maximum_msat.unwrap_or(u64::max_value())
+                       },
+               }
+       }
+
        fn fees(&self) -> RoutingFees {
                match self {
                        CandidateRouteHop::FirstHop { .. } => RoutingFees {
@@ -466,7 +491,8 @@ struct PathBuildingHop<'a> {
 
 impl<'a> core::fmt::Debug for PathBuildingHop<'a> {
        fn fmt(&self, f: &mut core::fmt::Formatter) -> Result<(), core::fmt::Error> {
-               f.debug_struct("PathBuildingHop")
+               let mut debug_struct = f.debug_struct("PathBuildingHop");
+               debug_struct
                        .field("node_id", &self.node_id)
                        .field("short_channel_id", &self.candidate.short_channel_id())
                        .field("total_fee_msat", &self.total_fee_msat)
@@ -475,8 +501,11 @@ impl<'a> core::fmt::Debug for PathBuildingHop<'a> {
                        .field("total_fee_msat - (next_hops_fee_msat + hop_use_fee_msat)", &(&self.total_fee_msat - (&self.next_hops_fee_msat + &self.hop_use_fee_msat)))
                        .field("path_penalty_msat", &self.path_penalty_msat)
                        .field("path_htlc_minimum_msat", &self.path_htlc_minimum_msat)
-                       .field("cltv_expiry_delta", &self.candidate.cltv_expiry_delta())
-                       .finish()
+                       .field("cltv_expiry_delta", &self.candidate.cltv_expiry_delta());
+               #[cfg(all(not(feature = "_bench_unstable"), any(test, fuzzing)))]
+               let debug_struct = debug_struct
+                       .field("value_contribution_msat", &self.value_contribution_msat);
+               debug_struct.finish()
        }
 }
 
@@ -647,16 +676,11 @@ pub fn find_route<L: Deref, S: Score>(
 ) -> Result<Route, LightningError>
 where L::Target: Logger {
        let network_graph = network.read_only();
-       match get_route(
-               our_node_pubkey, &route_params.payment_params, &network_graph, first_hops, route_params.final_value_msat,
-               route_params.final_cltv_expiry_delta, logger, scorer, random_seed_bytes
-       ) {
-               Ok(mut route) => {
-                       add_random_cltv_offset(&mut route, &route_params.payment_params, &network_graph, random_seed_bytes);
-                       Ok(route)
-               },
-               Err(err) => Err(err),
-       }
+       let mut route = get_route(our_node_pubkey, &route_params.payment_params, &network_graph, first_hops,
+               route_params.final_value_msat, route_params.final_cltv_expiry_delta, logger, scorer,
+               random_seed_bytes)?;
+       add_random_cltv_offset(&mut route, &route_params.payment_params, &network_graph, random_seed_bytes);
+       Ok(route)
 }
 
 pub(crate) fn get_route<L: Deref, S: Score>(
@@ -796,11 +820,11 @@ where L::Target: Logger {
        // The main heap containing all candidate next-hops sorted by their score (max(A* fee,
        // htlc_minimum)). Ideally this would be a heap which allowed cheap score reduction instead of
        // adding duplicate entries when we find a better path to a given node.
-       let mut targets = BinaryHeap::new();
+       let mut targets: BinaryHeap<RouteGraphNode> = BinaryHeap::new();
 
        // Map from node_id to information about the best current path to that node, including feerate
        // information.
-       let mut dist = HashMap::with_capacity(network_nodes.len());
+       let mut dist: HashMap<NodeId, PathBuildingHop> = HashMap::with_capacity(network_nodes.len());
 
        // During routing, if we ignore a path due to an htlc_minimum_msat limit, we set this,
        // indicating that we may wish to try again with a higher value, potentially paying to meet an
@@ -815,12 +839,12 @@ where L::Target: Logger {
        let recommended_value_msat = final_value_msat * ROUTE_CAPACITY_PROVISION_FACTOR as u64;
        let mut path_value_msat = final_value_msat;
 
-       // We don't want multiple paths (as per MPP) share liquidity of the same channels.
-       // This map allows paths to be aware of the channel use by other paths in the same call.
-       // This would help to make a better path finding decisions and not "overbook" channels.
-       // It is unaware of the directions (except for `next_outbound_htlc_limit_msat` in
-       // `first_hops`).
-       let mut bookkept_channels_liquidity_available_msat = HashMap::with_capacity(network_nodes.len());
+       // Keep track of how much liquidity has been used in selected channels. Used to determine
+       // if the channel can be used by additional MPP paths or to inform path finding decisions. It is
+       // aware of direction *only* to ensure that the correct htlc_maximum_msat value is used. Hence,
+       // liquidity used in one direction will not offset any used in the opposite direction.
+       let mut used_channel_liquidities: HashMap<(u64, bool), u64> =
+               HashMap::with_capacity(network_nodes.len());
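
Because the map is keyed by `(short_channel_id, source < target)`, the two directions of a channel are tracked in independent buckets. A small standalone sketch of that keying scheme, using plain `u64`s as stand-ins for the ordered `NodeId`s:

    use std::collections::HashMap;

    fn main() {
        let mut used_channel_liquidities: HashMap<(u64, bool), u64> = HashMap::new();
        let scid = 42u64;
        let (node_a, node_b) = (1u64, 2u64); // stand-ins for NodeIds, which order the same way

        // Spend 5_000 msat in the a->b direction, then 3_000 msat in the b->a direction.
        *used_channel_liquidities.entry((scid, node_a < node_b)).or_insert(0) += 5_000;
        *used_channel_liquidities.entry((scid, node_b < node_a)).or_insert(0) += 3_000;

        // The directions are tracked separately, so one never offsets the other.
        assert_eq!(used_channel_liquidities[&(scid, true)], 5_000);
        assert_eq!(used_channel_liquidities[&(scid, false)], 3_000);
    }
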
 
        // Keeping track of how much value we already collected across other paths. Helps to decide:
        // - how much a new path should be transferring (upper bound);
@@ -860,8 +884,8 @@ where L::Target: Logger {
                // since that value has to be transferred over this channel.
                // Returns whether this channel caused an update to `targets`.
                ( $candidate: expr, $src_node_id: expr, $dest_node_id: expr, $next_hops_fee_msat: expr,
-                  $next_hops_value_contribution: expr, $next_hops_path_htlc_minimum_msat: expr,
-                  $next_hops_path_penalty_msat: expr, $next_hops_cltv_delta: expr ) => { {
+                       $next_hops_value_contribution: expr, $next_hops_path_htlc_minimum_msat: expr,
+                       $next_hops_path_penalty_msat: expr, $next_hops_cltv_delta: expr, $next_hops_path_length: expr ) => { {
                        // We "return" whether we updated the path at the end, via this:
                        let mut did_add_update_path_to_src_node = false;
                        // Channels to self should not be used. This is more of belt-and-suspenders, because in
@@ -870,9 +894,7 @@ where L::Target: Logger {
                        // - for first and last hops early in get_route
                        if $src_node_id != $dest_node_id {
                                let short_channel_id = $candidate.short_channel_id();
-                               let available_liquidity_msat = bookkept_channels_liquidity_available_msat
-                                       .entry(short_channel_id)
-                                       .or_insert_with(|| $candidate.effective_capacity().as_msat());
+                               let htlc_maximum_msat = $candidate.htlc_maximum_msat();
 
                                // It is tricky to subtract $next_hops_fee_msat from available liquidity here.
                                // It may be misleading because we might later choose to reduce the value transferred
@@ -881,7 +903,14 @@ where L::Target: Logger {
                                // fees caused by one expensive channel, but then this channel could have been used
                                // if the amount being transferred over this path is lower.
                                // We do this for now, but this is a subject for removal.
-                               if let Some(available_value_contribution_msat) = available_liquidity_msat.checked_sub($next_hops_fee_msat) {
+                               if let Some(mut available_value_contribution_msat) = htlc_maximum_msat.checked_sub($next_hops_fee_msat) {
+                                       let used_liquidity_msat = used_channel_liquidities
+                                               .get(&(short_channel_id, $src_node_id < $dest_node_id))
+                                               .map_or(0, |used_liquidity_msat| {
+                                                       available_value_contribution_msat = available_value_contribution_msat
+                                                               .saturating_sub(*used_liquidity_msat);
+                                                       *used_liquidity_msat
+                                               });
 
                                        // Routing Fragmentation Mitigation heuristic:
                                        //
@@ -905,6 +934,9 @@ where L::Target: Logger {
                                        };
                                        // Verify the liquidity offered by this channel complies to the minimal contribution.
                                        let contributes_sufficient_value = available_value_contribution_msat >= minimal_value_contribution_msat;
+                                       // Do not consider candidate hops that would exceed the maximum path length.
+                                       let path_length_to_node = $next_hops_path_length + 1;
+                                       let doesnt_exceed_max_path_length = path_length_to_node <= MAX_PATH_LENGTH_ESTIMATE;
 
                                        // Do not consider candidates that exceed the maximum total cltv expiry limit.
                                        // In order to already account for some of the privacy enhancing random CLTV
@@ -939,9 +971,11 @@ where L::Target: Logger {
                                        // bother considering this channel. If retrying with recommended_value_msat may
                                        // allow us to hit the HTLC minimum limit, set htlc_minimum_limit so that we go
                                        // around again with a higher amount.
-                                       if contributes_sufficient_value && doesnt_exceed_cltv_delta_limit && may_overpay_to_meet_path_minimum_msat {
+                                       if contributes_sufficient_value && doesnt_exceed_max_path_length &&
+                                               doesnt_exceed_cltv_delta_limit && may_overpay_to_meet_path_minimum_msat {
                                                hit_minimum_limit = true;
-                                       } else if contributes_sufficient_value && doesnt_exceed_cltv_delta_limit && over_path_minimum_msat {
+                                       } else if contributes_sufficient_value && doesnt_exceed_max_path_length &&
+                                               doesnt_exceed_cltv_delta_limit && over_path_minimum_msat {
                                                // Note that low contribution here (limited by available_liquidity_msat)
                                                // might violate htlc_minimum_msat on the hops which are next along the
                                                // payment path (upstream to the payee). To avoid that, we recompute
@@ -1027,9 +1061,16 @@ where L::Target: Logger {
                                                                }
                                                        }
 
-                                                       let path_penalty_msat = $next_hops_path_penalty_msat.saturating_add(
-                                                               scorer.channel_penalty_msat(short_channel_id, amount_to_transfer_over_msat,
-                                                                       *available_liquidity_msat, &$src_node_id, &$dest_node_id));
+                                                       let channel_usage = ChannelUsage {
+                                                               amount_msat: amount_to_transfer_over_msat,
+                                                               inflight_htlc_msat: used_liquidity_msat,
+                                                               effective_capacity: $candidate.effective_capacity(),
+                                                       };
+                                                       let channel_penalty_msat = scorer.channel_penalty_msat(
+                                                               short_channel_id, &$src_node_id, &$dest_node_id, channel_usage
+                                                       );
+                                                       let path_penalty_msat = $next_hops_path_penalty_msat
+                                                               .saturating_add(channel_penalty_msat);
                                                        let new_graph_node = RouteGraphNode {
                                                                node_id: $src_node_id,
                                                                lowest_fee_to_peer_through_node: total_fee_msat,
@@ -1038,6 +1079,7 @@ where L::Target: Logger {
                                                                value_contribution_msat: value_contribution_msat,
                                                                path_htlc_minimum_msat,
                                                                path_penalty_msat,
+                                                               path_length_to_node,
                                                        };
 
                                                        // Update the way of reaching $src_node_id with the given short_channel_id (from $dest_node_id),
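
With this change the scorer receives a single `ChannelUsage` value instead of separate amount and capacity arguments. A rough illustration of a user-side `Score` under the new signature (not the crate's shipped scorer; the penalty values are arbitrary); bindings builds would additionally need a `Writeable` impl, as the `HopScorer` later in this patch shows:

    use lightning::routing::network_graph::NodeId;
    use lightning::routing::router::RouteHop;
    use lightning::routing::scoring::{ChannelUsage, Score};

    /// A toy scorer: charge 1_000 msat per channel, plus a flat surcharge whenever the amount
    /// plus in-flight HTLCs would commit more than half of the channel's effective capacity.
    struct HalfCapacityScorer;

    impl Score for HalfCapacityScorer {
        fn channel_penalty_msat(&self, _short_channel_id: u64, _source: &NodeId,
            _target: &NodeId, usage: ChannelUsage) -> u64
        {
            let committed = usage.amount_msat.saturating_add(usage.inflight_htlc_msat);
            if committed.saturating_mul(2) > usage.effective_capacity.as_msat() {
                50_000
            } else {
                1_000
            }
        }

        fn payment_path_failed(&mut self, _path: &[&RouteHop], _short_channel_id: u64) {}
        fn payment_path_successful(&mut self, _path: &[&RouteHop]) {}
    }
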
@@ -1121,7 +1163,9 @@ where L::Target: Logger {
        // meaning how much will be paid in fees after this node (to the best of our knowledge).
        // This data can later be helpful to optimize routing (pay lower fees).
        macro_rules! add_entries_to_cheapest_to_target_node {
-               ( $node: expr, $node_id: expr, $fee_to_target_msat: expr, $next_hops_value_contribution: expr, $next_hops_path_htlc_minimum_msat: expr, $next_hops_path_penalty_msat: expr, $next_hops_cltv_delta: expr ) => {
+               ( $node: expr, $node_id: expr, $fee_to_target_msat: expr, $next_hops_value_contribution: expr,
+                 $next_hops_path_htlc_minimum_msat: expr, $next_hops_path_penalty_msat: expr,
+                 $next_hops_cltv_delta: expr, $next_hops_path_length: expr ) => {
                        let skip_node = if let Some(elem) = dist.get_mut(&$node_id) {
                                let was_processed = elem.was_processed;
                                elem.was_processed = true;
@@ -1138,7 +1182,10 @@ where L::Target: Logger {
                                if let Some(first_channels) = first_hop_targets.get(&$node_id) {
                                        for details in first_channels {
                                                let candidate = CandidateRouteHop::FirstHop { details };
-                                               add_entry!(candidate, our_node_id, $node_id, $fee_to_target_msat, $next_hops_value_contribution, $next_hops_path_htlc_minimum_msat, $next_hops_path_penalty_msat, $next_hops_cltv_delta);
+                                               add_entry!(candidate, our_node_id, $node_id, $fee_to_target_msat,
+                                                       $next_hops_value_contribution,
+                                                       $next_hops_path_htlc_minimum_msat, $next_hops_path_penalty_msat,
+                                                       $next_hops_cltv_delta, $next_hops_path_length);
                                        }
                                }
 
@@ -1161,7 +1208,12 @@ where L::Target: Logger {
                                                                                        info: directed_channel.with_update().unwrap(),
                                                                                        short_channel_id: *chan_id,
                                                                                };
-                                                                               add_entry!(candidate, *source, $node_id, $fee_to_target_msat, $next_hops_value_contribution, $next_hops_path_htlc_minimum_msat, $next_hops_path_penalty_msat, $next_hops_cltv_delta);
+                                                                               add_entry!(candidate, *source, $node_id,
+                                                                                       $fee_to_target_msat,
+                                                                                       $next_hops_value_contribution,
+                                                                                       $next_hops_path_htlc_minimum_msat,
+                                                                                       $next_hops_path_penalty_msat,
+                                                                                       $next_hops_cltv_delta, $next_hops_path_length);
                                                                        }
                                                                }
                                                        }
@@ -1176,9 +1228,8 @@ where L::Target: Logger {
 
        // TODO: diversify by nodes (so that all paths aren't doomed if one node is offline).
        'paths_collection: loop {
-               // For every new path, start from scratch, except
-               // bookkept_channels_liquidity_available_msat, which will improve
-               // the further iterations of path finding. Also don't erase first_hop_targets.
+               // For every new path, start from scratch, except for used_channel_liquidities, which
+               // helps to avoid reusing previously selected paths in future iterations.
                targets.clear();
                dist.clear();
                hit_minimum_limit = false;
@@ -1188,8 +1239,10 @@ where L::Target: Logger {
                if let Some(first_channels) = first_hop_targets.get(&payee_node_id) {
                        for details in first_channels {
                                let candidate = CandidateRouteHop::FirstHop { details };
-                               let added = add_entry!(candidate, our_node_id, payee_node_id, 0, path_value_msat, 0, 0u64, 0);
-                               log_trace!(logger, "{} direct route to payee via SCID {}", if added { "Added" } else { "Skipped" }, candidate.short_channel_id());
+                               let added = add_entry!(candidate, our_node_id, payee_node_id, 0, path_value_msat,
+                                                                       0, 0u64, 0, 0);
+                               log_trace!(logger, "{} direct route to payee via SCID {}",
+                                               if added { "Added" } else { "Skipped" }, candidate.short_channel_id());
                        }
                }
 
@@ -1202,7 +1255,7 @@ where L::Target: Logger {
                        // If not, targets.pop() will not even let us enter the loop in step 2.
                        None => {},
                        Some(node) => {
-                               add_entries_to_cheapest_to_target_node!(node, payee_node_id, 0, path_value_msat, 0, 0u64, 0);
+                               add_entries_to_cheapest_to_target_node!(node, payee_node_id, 0, path_value_msat, 0, 0u64, 0, 0);
                        },
                }
 
@@ -1229,6 +1282,7 @@ where L::Target: Logger {
                                let mut aggregate_next_hops_path_htlc_minimum_msat: u64 = 0;
                                let mut aggregate_next_hops_path_penalty_msat: u64 = 0;
                                let mut aggregate_next_hops_cltv_delta: u32 = 0;
+                               let mut aggregate_next_hops_path_length: u8 = 0;
 
                                for (idx, (hop, prev_hop_id)) in hop_iter.zip(prev_hop_iter).enumerate() {
                                        let source = NodeId::from_pubkey(&hop.src_node_id);
@@ -1242,25 +1296,45 @@ where L::Target: Logger {
                                                        short_channel_id: hop.short_channel_id,
                                                })
                                                .unwrap_or_else(|| CandidateRouteHop::PrivateHop { hint: hop });
-                                       let capacity_msat = candidate.effective_capacity().as_msat();
+
+                                       if !add_entry!(candidate, source, target, aggregate_next_hops_fee_msat,
+                                                               path_value_msat, aggregate_next_hops_path_htlc_minimum_msat,
+                                                               aggregate_next_hops_path_penalty_msat,
+                                                               aggregate_next_hops_cltv_delta, aggregate_next_hops_path_length) {
+                                               // If this hop was not used then there is no use checking the preceding
+                                               // hops in the RouteHint. We can break by just searching for a direct
+                                               // channel between last checked hop and first_hop_targets.
+                                               hop_used = false;
+                                       }
+
+                                       let used_liquidity_msat = used_channel_liquidities
+                                               .get(&(hop.short_channel_id, source < target)).copied().unwrap_or(0);
+                                       let channel_usage = ChannelUsage {
+                                               amount_msat: final_value_msat + aggregate_next_hops_fee_msat,
+                                               inflight_htlc_msat: used_liquidity_msat,
+                                               effective_capacity: candidate.effective_capacity(),
+                                       };
+                                       let channel_penalty_msat = scorer.channel_penalty_msat(
+                                               hop.short_channel_id, &source, &target, channel_usage
+                                       );
                                        aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat
-                                               .saturating_add(scorer.channel_penalty_msat(hop.short_channel_id, final_value_msat, capacity_msat, &source, &target));
+                                               .saturating_add(channel_penalty_msat);
 
                                        aggregate_next_hops_cltv_delta = aggregate_next_hops_cltv_delta
                                                .saturating_add(hop.cltv_expiry_delta as u32);
 
-                                       if !add_entry!(candidate, source, target, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat, aggregate_next_hops_cltv_delta) {
-                                               // If this hop was not used then there is no use checking the preceding hops
-                                               // in the RouteHint. We can break by just searching for a direct channel between
-                                               // last checked hop and first_hop_targets
-                                               hop_used = false;
-                                       }
+                                       aggregate_next_hops_path_length = aggregate_next_hops_path_length
+                                               .saturating_add(1);
 
                                        // Searching for a direct channel between last checked hop and first_hop_targets
                                        if let Some(first_channels) = first_hop_targets.get(&NodeId::from_pubkey(&prev_hop_id)) {
                                                for details in first_channels {
                                                        let candidate = CandidateRouteHop::FirstHop { details };
-                                                       add_entry!(candidate, our_node_id, NodeId::from_pubkey(&prev_hop_id), aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat, aggregate_next_hops_cltv_delta);
+                                                       add_entry!(candidate, our_node_id, NodeId::from_pubkey(&prev_hop_id),
+                                                               aggregate_next_hops_fee_msat, path_value_msat,
+                                                               aggregate_next_hops_path_htlc_minimum_msat,
+                                                               aggregate_next_hops_path_penalty_msat, aggregate_next_hops_cltv_delta,
+                                                               aggregate_next_hops_path_length);
                                                }
                                        }
 
@@ -1271,7 +1345,7 @@ where L::Target: Logger {
                                        // In the next values of the iterator, the aggregate fees already reflects
                                        // the sum of value sent from payer (final_value_msat) and routing fees
                                        // for the last node in the RouteHint. We need to just add the fees to
-                                       // route through the current node so that the preceeding node (next iteration)
+                                       // route through the current node so that the preceding node (next iteration)
                                        // can use it.
                                        let hops_fee = compute_fees(aggregate_next_hops_fee_msat + final_value_msat, hop.fees)
                                                .map_or(None, |inc| inc.checked_add(aggregate_next_hops_fee_msat));
@@ -1296,7 +1370,13 @@ where L::Target: Logger {
                                                if let Some(first_channels) = first_hop_targets.get(&NodeId::from_pubkey(&hop.src_node_id)) {
                                                        for details in first_channels {
                                                                let candidate = CandidateRouteHop::FirstHop { details };
-                                                               add_entry!(candidate, our_node_id, NodeId::from_pubkey(&hop.src_node_id), aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat, aggregate_next_hops_cltv_delta);
+                                                               add_entry!(candidate, our_node_id,
+                                                                       NodeId::from_pubkey(&hop.src_node_id),
+                                                                       aggregate_next_hops_fee_msat, path_value_msat,
+                                                                       aggregate_next_hops_path_htlc_minimum_msat,
+                                                                       aggregate_next_hops_path_penalty_msat,
+                                                                       aggregate_next_hops_cltv_delta,
+                                                                       aggregate_next_hops_path_length);
                                                        }
                                                }
                                        }
@@ -1319,13 +1399,13 @@ where L::Target: Logger {
                // Both these cases (and other cases except reaching recommended_value_msat) mean that
                // paths_collection will be stopped because found_new_path==false.
                // This is not necessarily a routing failure.
-               'path_construction: while let Some(RouteGraphNode { node_id, lowest_fee_to_node, total_cltv_delta, value_contribution_msat, path_htlc_minimum_msat, path_penalty_msat, .. }) = targets.pop() {
+               'path_construction: while let Some(RouteGraphNode { node_id, lowest_fee_to_node, total_cltv_delta, value_contribution_msat, path_htlc_minimum_msat, path_penalty_msat, path_length_to_node, .. }) = targets.pop() {
 
                        // Since we're going payee-to-payer, hitting our node as a target means we should stop
                        // traversing the graph and arrange the path out of what we found.
                        if node_id == our_node_id {
                                let mut new_entry = dist.remove(&our_node_id).unwrap();
-                               let mut ordered_hops = vec!((new_entry.clone(), default_node_features.clone()));
+                               let mut ordered_hops: Vec<(PathBuildingHop, NodeFeatures)> = vec!((new_entry.clone(), default_node_features.clone()));
 
                                'path_walk: loop {
                                        let mut features_set = false;
@@ -1397,26 +1477,30 @@ where L::Target: Logger {
                                // Remember that we used these channels so that we don't rely
                                // on the same liquidity in future paths.
                                let mut prevented_redundant_path_selection = false;
-                               for (payment_hop, _) in payment_path.hops.iter() {
-                                       let channel_liquidity_available_msat = bookkept_channels_liquidity_available_msat.get_mut(&payment_hop.candidate.short_channel_id()).unwrap();
-                                       let mut spent_on_hop_msat = value_contribution_msat;
-                                       let next_hops_fee_msat = payment_hop.next_hops_fee_msat;
-                                       spent_on_hop_msat += next_hops_fee_msat;
-                                       if spent_on_hop_msat == *channel_liquidity_available_msat {
+                               let prev_hop_iter = core::iter::once(&our_node_id)
+                                       .chain(payment_path.hops.iter().map(|(hop, _)| &hop.node_id));
+                               for (prev_hop, (hop, _)) in prev_hop_iter.zip(payment_path.hops.iter()) {
+                                       let spent_on_hop_msat = value_contribution_msat + hop.next_hops_fee_msat;
+                                       let used_liquidity_msat = used_channel_liquidities
+                                               .entry((hop.candidate.short_channel_id(), *prev_hop < hop.node_id))
+                                               .and_modify(|used_liquidity_msat| *used_liquidity_msat += spent_on_hop_msat)
+                                               .or_insert(spent_on_hop_msat);
+                                       if *used_liquidity_msat == hop.candidate.htlc_maximum_msat() {
                                                // If this path used all of this channel's available liquidity, we know
                                                // this path will not be selected again in the next loop iteration.
                                                prevented_redundant_path_selection = true;
                                        }
-                                       *channel_liquidity_available_msat -= spent_on_hop_msat;
+                                       debug_assert!(*used_liquidity_msat <= hop.candidate.htlc_maximum_msat());
                                }
                                if !prevented_redundant_path_selection {
                                        // If we weren't capped by hitting a liquidity limit on a channel in the path,
                                        // we'll probably end up picking the same path again on the next iteration.
                                        // Decrease the available liquidity of a hop in the middle of the path.
                                        let victim_scid = payment_path.hops[(payment_path.hops.len()) / 2].0.candidate.short_channel_id();
+                                       let exhausted = u64::max_value();
                                        log_trace!(logger, "Disabling channel {} for future path building iterations to avoid duplicates.", victim_scid);
-                                       let victim_liquidity = bookkept_channels_liquidity_available_msat.get_mut(&victim_scid).unwrap();
-                                       *victim_liquidity = 0;
+                                       *used_channel_liquidities.entry((victim_scid, false)).or_default() = exhausted;
+                                       *used_channel_liquidities.entry((victim_scid, true)).or_default() = exhausted;
                                }
 
                                // Track the total amount all our collected paths allow to send so that we:
@@ -1441,7 +1525,9 @@ where L::Target: Logger {
                        match network_nodes.get(&node_id) {
                                None => {},
                                Some(node) => {
-                                       add_entries_to_cheapest_to_target_node!(node, node_id, lowest_fee_to_node, value_contribution_msat, path_htlc_minimum_msat, path_penalty_msat, total_cltv_delta);
+                                       add_entries_to_cheapest_to_target_node!(node, node_id, lowest_fee_to_node,
+                                               value_contribution_msat, path_htlc_minimum_msat, path_penalty_msat,
+                                               total_cltv_delta, path_length_to_node);
                                },
                        }
                }
@@ -1612,7 +1698,9 @@ where L::Target: Logger {
 // destination, if the remaining CLTV expiry delta exactly matches a feasible path in the network
 // graph. In order to improve privacy, this method obfuscates the CLTV expiry deltas along the
 // payment path by adding a randomized 'shadow route' offset to the final hop.
-fn add_random_cltv_offset(route: &mut Route, payment_params: &PaymentParameters, network_graph: &ReadOnlyNetworkGraph, random_seed_bytes: &[u8; 32]) {
+fn add_random_cltv_offset(route: &mut Route, payment_params: &PaymentParameters,
+       network_graph: &ReadOnlyNetworkGraph, random_seed_bytes: &[u8; 32]
+) {
        let network_channels = network_graph.channels();
        let network_nodes = network_graph.nodes();
 
@@ -1694,16 +1782,92 @@ fn add_random_cltv_offset(route: &mut Route, payment_params: &PaymentParameters,
        }
 }
 
+/// Construct a route from us (payer) to the target node (payee) via the given hops (which should
+/// exclude the payer, but include the payee). This may be useful, e.g., for probing the chosen path.
+///
+/// Re-uses logic from `find_route`, so the restrictions described there also apply here.
+pub fn build_route_from_hops<L: Deref>(
+       our_node_pubkey: &PublicKey, hops: &[PublicKey], route_params: &RouteParameters, network: &NetworkGraph,
+       logger: L, random_seed_bytes: &[u8; 32]
+) -> Result<Route, LightningError>
+where L::Target: Logger {
+       let network_graph = network.read_only();
+       let mut route = build_route_from_hops_internal(
+               our_node_pubkey, hops, &route_params.payment_params, &network_graph,
+               route_params.final_value_msat, route_params.final_cltv_expiry_delta, logger, random_seed_bytes)?;
+       add_random_cltv_offset(&mut route, &route_params.payment_params, &network_graph, random_seed_bytes);
+       Ok(route)
+}
+
+fn build_route_from_hops_internal<L: Deref>(
+       our_node_pubkey: &PublicKey, hops: &[PublicKey], payment_params: &PaymentParameters,
+       network_graph: &ReadOnlyNetworkGraph, final_value_msat: u64, final_cltv_expiry_delta: u32,
+       logger: L, random_seed_bytes: &[u8; 32]
+) -> Result<Route, LightningError> where L::Target: Logger {
+
+       struct HopScorer {
+               our_node_id: NodeId,
+               hop_ids: [Option<NodeId>; MAX_PATH_LENGTH_ESTIMATE as usize],
+       }
+
+       impl Score for HopScorer {
+               fn channel_penalty_msat(&self, _short_channel_id: u64, source: &NodeId, target: &NodeId,
+                       _usage: ChannelUsage) -> u64
+               {
+                       let mut cur_id = self.our_node_id;
+                       for i in 0..self.hop_ids.len() {
+                               if let Some(next_id) = self.hop_ids[i] {
+                                       if cur_id == *source && next_id == *target {
+                                               return 0;
+                                       }
+                                       cur_id = next_id;
+                               } else {
+                                       break;
+                               }
+                       }
+                       u64::max_value()
+               }
+
+               fn payment_path_failed(&mut self, _path: &[&RouteHop], _short_channel_id: u64) {}
+
+               fn payment_path_successful(&mut self, _path: &[&RouteHop]) {}
+       }
+
+       impl<'a> Writeable for HopScorer {
+               #[inline]
+               fn write<W: Writer>(&self, _w: &mut W) -> Result<(), io::Error> {
+                       unreachable!();
+               }
+       }
+
+       if hops.len() > MAX_PATH_LENGTH_ESTIMATE.into() {
+               return Err(LightningError{err: "Cannot build a route exceeding the maximum path length.".to_owned(), action: ErrorAction::IgnoreError});
+       }
+
+       let our_node_id = NodeId::from_pubkey(our_node_pubkey);
+       let mut hop_ids = [None; MAX_PATH_LENGTH_ESTIMATE as usize];
+       for i in 0..hops.len() {
+               hop_ids[i] = Some(NodeId::from_pubkey(&hops[i]));
+       }
+
+       let scorer = HopScorer { our_node_id, hop_ids };
+
+       get_route(our_node_pubkey, payment_params, network_graph, None, final_value_msat,
+               final_cltv_expiry_delta, logger, &scorer, random_seed_bytes)
+}
+
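
A hedged usage sketch for the new public `build_route_from_hops`, e.g. to probe a specific path. The payer key, the hop list, a populated `network_graph`, a `logger`, and the `random_seed_bytes` are assumed to exist in the caller, and the numeric values are placeholders:

    use lightning::routing::router::{build_route_from_hops, PaymentParameters, RouteParameters};

    // `hops` lists the hop public keys, excluding the payer and ending with the payee,
    // and may contain at most MAX_PATH_LENGTH_ESTIMATE (19) entries.
    let payee_pubkey = *hops.last().expect("at least the payee");
    let route_params = RouteParameters {
        payment_params: PaymentParameters::from_node_id(payee_pubkey),
        final_value_msat: 100_000,
        final_cltv_expiry_delta: 42,
    };
    let route = build_route_from_hops(&payer_pubkey, &hops, &route_params, &network_graph,
        &logger, &random_seed_bytes).expect("no route along the given hops");
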
 #[cfg(test)]
 mod tests {
        use routing::network_graph::{NetworkGraph, NetGraphMsgHandler, NodeId};
-       use routing::router::{get_route, add_random_cltv_offset, default_node_features, PaymentParameters, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees, DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA};
-       use routing::scoring::Score;
+       use routing::router::{get_route, build_route_from_hops_internal, add_random_cltv_offset, default_node_features,
+               PaymentParameters, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees,
+               DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA, MAX_PATH_LENGTH_ESTIMATE};
+       use routing::scoring::{ChannelUsage, Score};
        use chain::transaction::OutPoint;
        use chain::keysinterface::KeysInterface;
        use ln::features::{ChannelFeatures, InitFeatures, InvoiceFeatures, NodeFeatures};
        use ln::msgs::{ErrorAction, LightningError, OptionalField, UnsignedChannelAnnouncement, ChannelAnnouncement, RoutingMessageHandler,
-          NodeAnnouncement, UnsignedNodeAnnouncement, ChannelUpdate, UnsignedChannelUpdate};
+               NodeAnnouncement, UnsignedNodeAnnouncement, ChannelUpdate, UnsignedChannelUpdate};
        use ln::channelmanager;
        use util::test_utils;
        use util::chacha20::ChaCha20;
@@ -1836,7 +2000,7 @@ mod tests {
        }
 
        fn get_nodes(secp_ctx: &Secp256k1<All>) -> (SecretKey, PublicKey, Vec<SecretKey>, Vec<PublicKey>) {
-               let privkeys: Vec<SecretKey> = (2..10).map(|i| {
+               let privkeys: Vec<SecretKey> = (2..22).map(|i| {
                        SecretKey::from_slice(&hex::decode(format!("{:02x}", i).repeat(32)).unwrap()[..]).unwrap()
                }).collect();
 
@@ -1863,6 +2027,57 @@ mod tests {
                }
        }
 
+       fn build_line_graph() -> (
+               Secp256k1<All>, sync::Arc<NetworkGraph>, NetGraphMsgHandler<sync::Arc<NetworkGraph>,
+               sync::Arc<test_utils::TestChainSource>, sync::Arc<crate::util::test_utils::TestLogger>>,
+               sync::Arc<test_utils::TestChainSource>, sync::Arc<test_utils::TestLogger>,
+       ) {
+               let secp_ctx = Secp256k1::new();
+               let logger = Arc::new(test_utils::TestLogger::new());
+               let chain_monitor = Arc::new(test_utils::TestChainSource::new(Network::Testnet));
+               let network_graph = Arc::new(NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash()));
+               let net_graph_msg_handler = NetGraphMsgHandler::new(Arc::clone(&network_graph), None, Arc::clone(&logger));
+
+               // Build network from our_id to node 19:
+               // our_id -1(1)2- node0 -1(2)2- node1 - ... - node19
+               let (our_privkey, _, privkeys, _) = get_nodes(&secp_ctx);
+
+               for (idx, (cur_privkey, next_privkey)) in core::iter::once(&our_privkey)
+                       .chain(privkeys.iter()).zip(privkeys.iter()).enumerate() {
+                       let cur_short_channel_id = (idx as u64) + 1;
+                       add_channel(&net_graph_msg_handler, &secp_ctx, &cur_privkey, &next_privkey,
+                               ChannelFeatures::from_le_bytes(id_to_feature_flags(1)), cur_short_channel_id);
+                       update_channel(&net_graph_msg_handler, &secp_ctx, &cur_privkey, UnsignedChannelUpdate {
+                               chain_hash: genesis_block(Network::Testnet).header.block_hash(),
+                               short_channel_id: cur_short_channel_id,
+                               timestamp: idx as u32,
+                               flags: 0,
+                               cltv_expiry_delta: 0,
+                               htlc_minimum_msat: 0,
+                               htlc_maximum_msat: OptionalField::Absent,
+                               fee_base_msat: 0,
+                               fee_proportional_millionths: 0,
+                               excess_data: Vec::new()
+                       });
+                       update_channel(&net_graph_msg_handler, &secp_ctx, &next_privkey, UnsignedChannelUpdate {
+                               chain_hash: genesis_block(Network::Testnet).header.block_hash(),
+                               short_channel_id: cur_short_channel_id,
+                               timestamp: (idx as u32)+1,
+                               flags: 1,
+                               cltv_expiry_delta: 0,
+                               htlc_minimum_msat: 0,
+                               htlc_maximum_msat: OptionalField::Absent,
+                               fee_base_msat: 0,
+                               fee_proportional_millionths: 0,
+                               excess_data: Vec::new()
+                       });
+                       add_or_update_node(&net_graph_msg_handler, &secp_ctx, next_privkey,
+                               NodeFeatures::from_le_bytes(id_to_feature_flags(1)), 0);
+               }
+
+               (secp_ctx, network_graph, net_graph_msg_handler, chain_monitor, logger)
+       }
+
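The iterator chain above pairs each key with its successor: prepending `our_privkey` and zipping against the node keys yields the edges (our_id, node0), (node0, node1), ..., (node18, node19), i.e. 20 channels in a line. A standalone sketch of the same pattern on plain integers, for illustration only:

	// 0 stands in for our_id, 1..=20 for node0..node19.
	let ours = 0u32;
	let nodes: Vec<u32> = (1..=20).collect();
	let edges: Vec<(u32, u32)> = core::iter::once(&ours)
		.chain(nodes.iter())
		.zip(nodes.iter())
		.map(|(a, b)| (*a, *b))
		.collect();
	assert_eq!(edges.first(), Some(&(0, 1)));  // our_id -> node0
	assert_eq!(edges.last(), Some(&(19, 20))); // node18 -> node19
	assert_eq!(edges.len(), 20);               // 20 channels in the line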
        fn build_graph() -> (
                Secp256k1<All>,
                sync::Arc<NetworkGraph>,
@@ -5039,7 +5254,7 @@ mod tests {
                fn write<W: Writer>(&self, _w: &mut W) -> Result<(), ::io::Error> { unimplemented!() }
        }
        impl Score for BadChannelScorer {
-               fn channel_penalty_msat(&self, short_channel_id: u64, _send_amt: u64, _capacity_msat: u64, _source: &NodeId, _target: &NodeId) -> u64 {
+               fn channel_penalty_msat(&self, short_channel_id: u64, _: &NodeId, _: &NodeId, _: ChannelUsage) -> u64 {
                        if short_channel_id == self.short_channel_id { u64::max_value() } else { 0 }
                }
 
@@ -5057,7 +5272,7 @@ mod tests {
        }
 
        impl Score for BadNodeScorer {
-               fn channel_penalty_msat(&self, _short_channel_id: u64, _send_amt: u64, _capacity_msat: u64, _source: &NodeId, target: &NodeId) -> u64 {
+               fn channel_penalty_msat(&self, _: u64, _: &NodeId, target: &NodeId, _: ChannelUsage) -> u64 {
                        if *target == self.node_id { u64::max_value() } else { 0 }
                }
 
@@ -5213,6 +5428,35 @@ mod tests {
                }
        }
 
+       #[test]
+       fn limits_path_length() {
+               let (secp_ctx, network, _, _, logger) = build_line_graph();
+               let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
+               let network_graph = network.read_only();
+
+               let scorer = test_utils::TestScorer::with_penalty(0);
+               let keys_manager = test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
+               let random_seed_bytes = keys_manager.get_secure_random_bytes();
+
+               // First check we can actually create a long route on this graph.
+               let feasible_payment_params = PaymentParameters::from_node_id(nodes[18]);
+               let route = get_route(&our_id, &feasible_payment_params, &network_graph, None, 100, 0,
+                       Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap();
+               let path = route.paths[0].iter().map(|hop| hop.short_channel_id).collect::<Vec<_>>();
+               assert!(path.len() == MAX_PATH_LENGTH_ESTIMATE.into());
+
+               // But we can't create a path surpassing the MAX_PATH_LENGTH_ESTIMATE limit.
+               let fail_payment_params = PaymentParameters::from_node_id(nodes[19]);
+               match get_route(&our_id, &fail_payment_params, &network_graph, None, 100, 0,
+                       Arc::clone(&logger), &scorer, &random_seed_bytes)
+               {
+                       Err(LightningError { err, .. } ) => {
+                               assert_eq!(err, "Failed to find a path to the given destination");
+                       },
+                       Ok(_) => panic!("Expected error"),
+               }
+       }
+
        #[test]
        fn adds_and_limits_cltv_offset() {
                let (secp_ctx, network_graph, _, _, logger) = build_graph();
@@ -5313,6 +5557,26 @@ mod tests {
                assert!(path_plausibility.iter().all(|x| *x));
        }
 
+       #[test]
+       fn builds_correct_path_from_hops() {
+               let (secp_ctx, network, _, _, logger) = build_graph();
+               let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
+               let network_graph = network.read_only();
+
+               let keys_manager = test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
+               let random_seed_bytes = keys_manager.get_secure_random_bytes();
+
+               let payment_params = PaymentParameters::from_node_id(nodes[3]);
+               let hops = [nodes[1], nodes[2], nodes[4], nodes[3]];
+               let route = build_route_from_hops_internal(&our_id, &hops, &payment_params,
+                        &network_graph, 100, 0, Arc::clone(&logger), &random_seed_bytes).unwrap();
+               let route_hop_pubkeys = route.paths[0].iter().map(|hop| hop.pubkey).collect::<Vec<_>>();
+               assert_eq!(hops.len(), route.paths[0].len());
+               for (idx, hop_pubkey) in hops.iter().enumerate() {
+                       assert!(*hop_pubkey == route_hop_pubkeys[idx]);
+               }
+       }
+
        #[cfg(not(feature = "no-std"))]
        pub(super) fn random_init_seed() -> u64 {
                // Because the default HashMap in std pulls OS randomness, we can use it as a (bad) RNG.
index 4c47aac47b64c2e78df2b39d0e740083023ecadc..38e3c838425f5957fddb0a525dc15ca8cbc24eca 100644 (file)
 //! [`find_route`]: crate::routing::router::find_route
 
 use ln::msgs::DecodeError;
-use routing::network_graph::{NetworkGraph, NodeId};
+use routing::network_graph::{EffectiveCapacity, NetworkGraph, NodeId};
 use routing::router::RouteHop;
 use util::ser::{Readable, ReadableArgs, Writeable, Writer};
 use util::logger::Logger;
+use util::time::Time;
 
 use prelude::*;
 use core::fmt;
@@ -92,7 +93,9 @@ pub trait Score $(: $supertrait)* {
 	/// such as chain data, network gossip, or invoice hints. For invoice hints, a capacity near
        /// [`u64::max_value`] is given to indicate sufficient capacity for the invoice's full amount.
        /// Thus, implementations should be overflow-safe.
-       fn channel_penalty_msat(&self, short_channel_id: u64, send_amt_msat: u64, capacity_msat: u64, source: &NodeId, target: &NodeId) -> u64;
+       fn channel_penalty_msat(
+               &self, short_channel_id: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage
+       ) -> u64;
 
        /// Handles updating channel penalties after failing to route through a channel.
        fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64);
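For context (a hedged sketch, not part of this patch): under the new signature an implementor reads the amount, in-flight HTLCs, and effective capacity from `ChannelUsage` instead of separate `send_amt_msat`/`capacity_msat` parameters. A minimal constant-penalty scorer mirroring `FixedPenaltyScorer`; whether an additional serialization supertrait also applies depends on the build configuration.

	struct ConstPenalty(u64); // hypothetical example type

	impl Score for ConstPenalty {
		fn channel_penalty_msat(
			&self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId, _usage: ChannelUsage
		) -> u64 { self.0 }
		fn payment_path_failed(&mut self, _path: &[&RouteHop], _short_channel_id: u64) {}
		fn payment_path_successful(&mut self, _path: &[&RouteHop]) {}
	}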
@@ -102,8 +105,10 @@ pub trait Score $(: $supertrait)* {
 }
 
 impl<S: Score, T: DerefMut<Target=S> $(+ $supertrait)*> Score for T {
-       fn channel_penalty_msat(&self, short_channel_id: u64, send_amt_msat: u64, capacity_msat: u64, source: &NodeId, target: &NodeId) -> u64 {
-               self.deref().channel_penalty_msat(short_channel_id, send_amt_msat, capacity_msat, source, target)
+       fn channel_penalty_msat(
+               &self, short_channel_id: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage
+       ) -> u64 {
+               self.deref().channel_penalty_msat(short_channel_id, source, target, usage)
        }
 
        fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64) {
@@ -201,6 +206,20 @@ impl<'a, S: Writeable> Writeable for MutexGuard<'a, S> {
        }
 }
 
+/// Proposed use of a channel passed as a parameter to [`Score::channel_penalty_msat`].
+#[derive(Clone, Copy)]
+pub struct ChannelUsage {
+       /// The amount to send through the channel, denominated in millisatoshis.
+       pub amount_msat: u64,
+
+       /// Total amount, denominated in millisatoshis, already allocated to send through the channel
+       /// as part of a multi-path payment.
+       pub inflight_htlc_msat: u64,
+
+       /// The effective capacity of the channel.
+       pub effective_capacity: EffectiveCapacity,
+}
+
 #[derive(Clone)]
 /// [`Score`] implementation that uses a fixed penalty.
 pub struct FixedPenaltyScorer {
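Callers describe the proposed use of a channel with the new struct above. A construction sketch mirroring how the updated tests build it (`scorer`, `source`, and `target` are assumed from the surrounding test setup); since `ChannelUsage` derives `Copy`, the same value can be reused across calls or tweaked with struct-update syntax, as the tests do.

	let usage = ChannelUsage {
		amount_msat: 250_000,
		inflight_htlc_msat: 0,
		effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_024_000 },
	};
	let penalty = scorer.channel_penalty_msat(42, &source, &target, usage);
	// Retry with a larger amount without rebuilding the whole struct.
	let penalty_larger = scorer.channel_penalty_msat(42, &source, &target,
		ChannelUsage { amount_msat: 500_000, ..usage });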
@@ -215,7 +234,7 @@ impl FixedPenaltyScorer {
 }
 
 impl Score for FixedPenaltyScorer {
-       fn channel_penalty_msat(&self, _: u64, _: u64, _: u64, _: &NodeId, _: &NodeId) -> u64 {
+       fn channel_penalty_msat(&self, _: u64, _: &NodeId, _: &NodeId, _: ChannelUsage) -> u64 {
                self.penalty_msat
        }
 
@@ -262,7 +281,9 @@ pub type Scorer = ScorerUsingTime::<ConfiguredTime>;
 #[cfg(not(feature = "no-std"))]
 type ConfiguredTime = std::time::Instant;
 #[cfg(feature = "no-std")]
-type ConfiguredTime = time::Eternity;
+use util::time::Eternity;
+#[cfg(feature = "no-std")]
+type ConfiguredTime = Eternity;
 
 // Note that ideally we'd hide ScorerUsingTime from public view by sealing it as well, but rustdoc
 // doesn't handle this well - instead exposing a `Scorer` which has no trait implementation(s) or
@@ -404,13 +425,16 @@ impl Default for ScoringParameters {
 
 impl<T: Time> Score for ScorerUsingTime<T> {
        fn channel_penalty_msat(
-               &self, short_channel_id: u64, send_amt_msat: u64, capacity_msat: u64, _source: &NodeId, _target: &NodeId
+               &self, short_channel_id: u64, _source: &NodeId, _target: &NodeId, usage: ChannelUsage
        ) -> u64 {
                let failure_penalty_msat = self.channel_failures
                        .get(&short_channel_id)
                        .map_or(0, |value| value.decayed_penalty_msat(self.params.failure_penalty_half_life));
 
                let mut penalty_msat = self.params.base_penalty_msat + failure_penalty_msat;
+               let send_amt_msat = usage.amount_msat;
+               let capacity_msat = usage.effective_capacity.as_msat()
+                       .saturating_sub(usage.inflight_htlc_msat);
                let send_1024ths = send_amt_msat.checked_mul(1024).unwrap_or(u64::max_value()) / capacity_msat;
                if send_1024ths > self.params.overuse_penalty_start_1024th as u64 {
                        penalty_msat = penalty_msat.checked_add(
@@ -877,10 +901,20 @@ impl<L: DerefMut<Target = u64>, T: Time, U: DerefMut<Target = T>> DirectedChanne
 
 impl<G: Deref<Target = NetworkGraph>, L: Deref, T: Time> Score for ProbabilisticScorerUsingTime<G, L, T> where L::Target: Logger {
        fn channel_penalty_msat(
-               &self, short_channel_id: u64, amount_msat: u64, capacity_msat: u64, source: &NodeId,
-               target: &NodeId
+               &self, short_channel_id: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage
        ) -> u64 {
+               if let EffectiveCapacity::ExactLiquidity { liquidity_msat } = usage.effective_capacity {
+                       if usage.amount_msat > liquidity_msat {
+                               return u64::max_value();
+                       } else {
+                               return self.params.base_penalty_msat;
+                       };
+               }
+
                let liquidity_offset_half_life = self.params.liquidity_offset_half_life;
+               let amount_msat = usage.amount_msat;
+               let capacity_msat = usage.effective_capacity.as_msat()
+                       .saturating_sub(usage.inflight_htlc_msat);
                self.channel_liquidities
                        .get(&short_channel_id)
                        .unwrap_or(&ChannelLiquidity::new())
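The early return added above makes a channel whose available liquidity is known exactly score all-or-nothing: the base penalty if the amount fits, unusable otherwise. A hedged sketch of the expected behaviour (`scorer`, `params`, `source`, and `target` assumed from the surrounding test setup):

	let usage = ChannelUsage {
		amount_msat: 750,
		inflight_htlc_msat: 0,
		effective_capacity: EffectiveCapacity::ExactLiquidity { liquidity_msat: 1_000 },
	};
	// Amount fits within the known liquidity: only the configured base penalty applies.
	assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), params.base_penalty_msat);

	// Amount exceeds the known liquidity: the channel is treated as unusable.
	let usage = ChannelUsage { amount_msat: 1_001, ..usage };
	assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), u64::max_value());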
@@ -1327,88 +1361,16 @@ impl<T: Time> Readable for ChannelLiquidity<T> {
        }
 }
 
-pub(crate) mod time {
-       use core::ops::Sub;
-       use core::time::Duration;
-       /// A measurement of time.
-       pub trait Time: Copy + Sub<Duration, Output = Self> where Self: Sized {
-               /// Returns an instance corresponding to the current moment.
-               fn now() -> Self;
-
-               /// Returns the amount of time elapsed since `self` was created.
-               fn elapsed(&self) -> Duration;
-
-               /// Returns the amount of time passed between `earlier` and `self`.
-               fn duration_since(&self, earlier: Self) -> Duration;
-
-               /// Returns the amount of time passed since the beginning of [`Time`].
-               ///
-               /// Used during (de-)serialization.
-               fn duration_since_epoch() -> Duration;
-       }
-
-       /// A state in which time has no meaning.
-       #[derive(Clone, Copy, Debug, PartialEq, Eq)]
-       pub struct Eternity;
-
-       #[cfg(not(feature = "no-std"))]
-       impl Time for std::time::Instant {
-               fn now() -> Self {
-                       std::time::Instant::now()
-               }
-
-               fn duration_since(&self, earlier: Self) -> Duration {
-                       self.duration_since(earlier)
-               }
-
-               fn duration_since_epoch() -> Duration {
-                       use std::time::SystemTime;
-                       SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap()
-               }
-
-               fn elapsed(&self) -> Duration {
-                       std::time::Instant::elapsed(self)
-               }
-       }
-
-       impl Time for Eternity {
-               fn now() -> Self {
-                       Self
-               }
-
-               fn duration_since(&self, _earlier: Self) -> Duration {
-                       Duration::from_secs(0)
-               }
-
-               fn duration_since_epoch() -> Duration {
-                       Duration::from_secs(0)
-               }
-
-               fn elapsed(&self) -> Duration {
-                       Duration::from_secs(0)
-               }
-       }
-
-       impl Sub<Duration> for Eternity {
-               type Output = Self;
-
-               fn sub(self, _other: Duration) -> Self {
-                       self
-               }
-       }
-}
-
-pub(crate) use self::time::Time;
-
 #[cfg(test)]
 mod tests {
-       use super::{ChannelLiquidity, ProbabilisticScoringParameters, ProbabilisticScorerUsingTime, ScoringParameters, ScorerUsingTime, Time};
-       use super::time::Eternity;
+       use super::{ChannelLiquidity, ProbabilisticScoringParameters, ProbabilisticScorerUsingTime, ScoringParameters, ScorerUsingTime};
+       use util::time::Time;
+       use util::time::tests::SinceEpoch;
 
        use ln::features::{ChannelFeatures, NodeFeatures};
        use ln::msgs::{ChannelAnnouncement, ChannelUpdate, OptionalField, UnsignedChannelAnnouncement, UnsignedChannelUpdate};
-       use routing::scoring::Score;
-       use routing::network_graph::{NetworkGraph, NodeId};
+       use routing::scoring::{ChannelUsage, Score};
+       use routing::network_graph::{EffectiveCapacity, NetworkGraph, NodeId};
        use routing::router::RouteHop;
        use util::ser::{Readable, ReadableArgs, Writeable};
        use util::test_utils::TestLogger;
@@ -1418,80 +1380,9 @@ mod tests {
        use bitcoin::hashes::sha256d::Hash as Sha256dHash;
        use bitcoin::network::constants::Network;
        use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey};
-       use core::cell::Cell;
-       use core::ops::Sub;
        use core::time::Duration;
        use io;
 
-       // `Time` tests
-
-       /// Time that can be advanced manually in tests.
-       #[derive(Clone, Copy, Debug, PartialEq, Eq)]
-       struct SinceEpoch(Duration);
-
-       impl SinceEpoch {
-               thread_local! {
-                       static ELAPSED: Cell<Duration> = core::cell::Cell::new(Duration::from_secs(0));
-               }
-
-               fn advance(duration: Duration) {
-                       Self::ELAPSED.with(|elapsed| elapsed.set(elapsed.get() + duration))
-               }
-       }
-
-       impl Time for SinceEpoch {
-               fn now() -> Self {
-                       Self(Self::duration_since_epoch())
-               }
-
-               fn duration_since(&self, earlier: Self) -> Duration {
-                       self.0 - earlier.0
-               }
-
-               fn duration_since_epoch() -> Duration {
-                       Self::ELAPSED.with(|elapsed| elapsed.get())
-               }
-
-               fn elapsed(&self) -> Duration {
-                       Self::duration_since_epoch() - self.0
-               }
-       }
-
-       impl Sub<Duration> for SinceEpoch {
-               type Output = Self;
-
-               fn sub(self, other: Duration) -> Self {
-                       Self(self.0 - other)
-               }
-       }
-
-       #[test]
-       fn time_passes_when_advanced() {
-               let now = SinceEpoch::now();
-               assert_eq!(now.elapsed(), Duration::from_secs(0));
-
-               SinceEpoch::advance(Duration::from_secs(1));
-               SinceEpoch::advance(Duration::from_secs(1));
-
-               let elapsed = now.elapsed();
-               let later = SinceEpoch::now();
-
-               assert_eq!(elapsed, Duration::from_secs(2));
-               assert_eq!(later - elapsed, now);
-       }
-
-       #[test]
-       fn time_never_passes_in_an_eternity() {
-               let now = Eternity::now();
-               let elapsed = now.elapsed();
-               let later = Eternity::now();
-
-               assert_eq!(now.elapsed(), Duration::from_secs(0));
-               assert_eq!(later - elapsed, now);
-       }
-
-       // `Scorer` tests
-
        /// A scorer for testing with time that can be manually advanced.
        type Scorer = ScorerUsingTime::<SinceEpoch>;
 
@@ -1532,10 +1423,13 @@ mod tests {
                });
                let source = source_node_id();
                let target = target_node_id();
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
+               let usage = ChannelUsage {
+                       amount_msat: 1, inflight_htlc_msat: 0, effective_capacity: EffectiveCapacity::Unknown
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_000);
 
                SinceEpoch::advance(Duration::from_secs(1));
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_000);
        }
 
        #[test]
@@ -1549,16 +1443,19 @@ mod tests {
                });
                let source = source_node_id();
                let target = target_node_id();
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
+               let usage = ChannelUsage {
+                       amount_msat: 1, inflight_htlc_msat: 0, effective_capacity: EffectiveCapacity::Unknown
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_000);
 
                scorer.payment_path_failed(&[], 42);
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_064);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_064);
 
                scorer.payment_path_failed(&[], 42);
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_128);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_128);
 
                scorer.payment_path_failed(&[], 42);
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_192);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_192);
        }
 
        #[test]
@@ -1572,25 +1469,28 @@ mod tests {
                });
                let source = source_node_id();
                let target = target_node_id();
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
+               let usage = ChannelUsage {
+                       amount_msat: 1, inflight_htlc_msat: 0, effective_capacity: EffectiveCapacity::Unknown
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_000);
 
                scorer.payment_path_failed(&[], 42);
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_512);
 
                SinceEpoch::advance(Duration::from_secs(9));
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_512);
 
                SinceEpoch::advance(Duration::from_secs(1));
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_256);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_256);
 
                SinceEpoch::advance(Duration::from_secs(10 * 8));
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_001);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_001);
 
                SinceEpoch::advance(Duration::from_secs(10));
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_000);
 
                SinceEpoch::advance(Duration::from_secs(10));
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_000);
        }
 
        #[test]
@@ -1604,18 +1504,21 @@ mod tests {
                });
                let source = source_node_id();
                let target = target_node_id();
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
+               let usage = ChannelUsage {
+                       amount_msat: 1, inflight_htlc_msat: 0, effective_capacity: EffectiveCapacity::Unknown
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_000);
 
                scorer.payment_path_failed(&[], 42);
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_512);
 
 		// An unchecked right shift of 64 bits or more in ChannelFailure::decayed_penalty_msat would
 		// cause an overflow.
                SinceEpoch::advance(Duration::from_secs(10 * 64));
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_000);
 
                SinceEpoch::advance(Duration::from_secs(10));
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_000);
        }
 
        #[test]
@@ -1629,19 +1532,22 @@ mod tests {
                });
                let source = source_node_id();
                let target = target_node_id();
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
+               let usage = ChannelUsage {
+                       amount_msat: 1, inflight_htlc_msat: 0, effective_capacity: EffectiveCapacity::Unknown
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_000);
 
                scorer.payment_path_failed(&[], 42);
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_512);
 
                SinceEpoch::advance(Duration::from_secs(10));
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_256);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_256);
 
                scorer.payment_path_failed(&[], 42);
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_768);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_768);
 
                SinceEpoch::advance(Duration::from_secs(10));
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_384);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_384);
        }
 
        #[test]
@@ -1655,13 +1561,16 @@ mod tests {
                });
                let source = source_node_id();
                let target = target_node_id();
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
+               let usage = ChannelUsage {
+                       amount_msat: 1, inflight_htlc_msat: 0, effective_capacity: EffectiveCapacity::Unknown
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_000);
 
                scorer.payment_path_failed(&[], 42);
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_512);
 
                SinceEpoch::advance(Duration::from_secs(10));
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_256);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_256);
 
                let hop = RouteHop {
                        pubkey: PublicKey::from_slice(target.as_slice()).unwrap(),
@@ -1672,10 +1581,10 @@ mod tests {
                        cltv_expiry_delta: 18,
                };
                scorer.payment_path_successful(&[&hop]);
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_128);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_128);
 
                SinceEpoch::advance(Duration::from_secs(10));
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_064);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_064);
        }
 
        #[test]
@@ -1689,22 +1598,25 @@ mod tests {
                });
                let source = source_node_id();
                let target = target_node_id();
+               let usage = ChannelUsage {
+                       amount_msat: 1, inflight_htlc_msat: 0, effective_capacity: EffectiveCapacity::Unknown
+               };
 
                scorer.payment_path_failed(&[], 42);
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_512);
 
                SinceEpoch::advance(Duration::from_secs(10));
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_256);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_256);
 
                scorer.payment_path_failed(&[], 43);
-               assert_eq!(scorer.channel_penalty_msat(43, 1, 1, &source, &target), 1_512);
+               assert_eq!(scorer.channel_penalty_msat(43, &source, &target, usage), 1_512);
 
                let mut serialized_scorer = Vec::new();
                scorer.write(&mut serialized_scorer).unwrap();
 
                let deserialized_scorer = <Scorer>::read(&mut io::Cursor::new(&serialized_scorer)).unwrap();
-               assert_eq!(deserialized_scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_256);
-               assert_eq!(deserialized_scorer.channel_penalty_msat(43, 1, 1, &source, &target), 1_512);
+               assert_eq!(deserialized_scorer.channel_penalty_msat(42, &source, &target, usage), 1_256);
+               assert_eq!(deserialized_scorer.channel_penalty_msat(43, &source, &target, usage), 1_512);
        }
 
        #[test]
@@ -1718,9 +1630,12 @@ mod tests {
                });
                let source = source_node_id();
                let target = target_node_id();
+               let usage = ChannelUsage {
+                       amount_msat: 1, inflight_htlc_msat: 0, effective_capacity: EffectiveCapacity::Unknown
+               };
 
                scorer.payment_path_failed(&[], 42);
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_512);
 
                let mut serialized_scorer = Vec::new();
                scorer.write(&mut serialized_scorer).unwrap();
@@ -1728,10 +1643,10 @@ mod tests {
                SinceEpoch::advance(Duration::from_secs(10));
 
                let deserialized_scorer = <Scorer>::read(&mut io::Cursor::new(&serialized_scorer)).unwrap();
-               assert_eq!(deserialized_scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_256);
+               assert_eq!(deserialized_scorer.channel_penalty_msat(42, &source, &target, usage), 1_256);
 
                SinceEpoch::advance(Duration::from_secs(10));
-               assert_eq!(deserialized_scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_128);
+               assert_eq!(deserialized_scorer.channel_penalty_msat(42, &source, &target, usage), 1_128);
        }
 
        #[test]
@@ -1746,11 +1661,24 @@ mod tests {
                let source = source_node_id();
                let target = target_node_id();
 
-               assert_eq!(scorer.channel_penalty_msat(42, 1_000, 1_024_000, &source, &target), 0);
-               assert_eq!(scorer.channel_penalty_msat(42, 256_999, 1_024_000, &source, &target), 0);
-               assert_eq!(scorer.channel_penalty_msat(42, 257_000, 1_024_000, &source, &target), 100);
-               assert_eq!(scorer.channel_penalty_msat(42, 258_000, 1_024_000, &source, &target), 200);
-               assert_eq!(scorer.channel_penalty_msat(42, 512_000, 1_024_000, &source, &target), 256 * 100);
+               let usage = ChannelUsage {
+                       amount_msat: 1_000,
+                       inflight_htlc_msat: 0,
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_024_000 },
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 0);
+
+               let usage = ChannelUsage { amount_msat: 256_999, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 0);
+
+               let usage = ChannelUsage { amount_msat: 257_000, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 100);
+
+               let usage = ChannelUsage { amount_msat: 258_000, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 200);
+
+               let usage = ChannelUsage { amount_msat: 512_000, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 256 * 100);
        }
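Reading those expectations back into the ratio computed in `ScorerUsingTime::channel_penalty_msat`: assuming this test configures `overuse_penalty_start_1024th = 256`, `overuse_penalty_msat_per_1024th = 100`, and zero base and failure penalties, the whole penalty comes from the overuse term.

	// Hedged worked example of the overuse term only (parameter values assumed above).
	let capacity_msat = 1_024_000u64;
	let start_1024th = 256u64;
	let per_1024th = 100u64;
	let send_1024ths = 257_000u64 * 1024 / capacity_msat; // = 257
	let penalty = if send_1024ths > start_1024th {
		(send_1024ths - start_1024th) * per_1024th
	} else { 0 };
	assert_eq!(penalty, 100); // matches the 257_000 msat assertion above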
 
        // `ProbabilisticScorer` tests
@@ -2083,18 +2011,37 @@ mod tests {
                let source = source_node_id();
                let target = target_node_id();
 
-               assert_eq!(scorer.channel_penalty_msat(42, 1_024, 1_024_000, &source, &target), 0);
-               assert_eq!(scorer.channel_penalty_msat(42, 10_240, 1_024_000, &source, &target), 0);
-               assert_eq!(scorer.channel_penalty_msat(42, 102_400, 1_024_000, &source, &target), 47);
-               assert_eq!(scorer.channel_penalty_msat(42, 1_024_000, 1_024_000, &source, &target), 2_000);
-
-               assert_eq!(scorer.channel_penalty_msat(42, 128, 1_024, &source, &target), 58);
-               assert_eq!(scorer.channel_penalty_msat(42, 256, 1_024, &source, &target), 125);
-               assert_eq!(scorer.channel_penalty_msat(42, 374, 1_024, &source, &target), 198);
-               assert_eq!(scorer.channel_penalty_msat(42, 512, 1_024, &source, &target), 300);
-               assert_eq!(scorer.channel_penalty_msat(42, 640, 1_024, &source, &target), 425);
-               assert_eq!(scorer.channel_penalty_msat(42, 768, 1_024, &source, &target), 602);
-               assert_eq!(scorer.channel_penalty_msat(42, 896, 1_024, &source, &target), 902);
+               let usage = ChannelUsage {
+                       amount_msat: 1_024,
+                       inflight_htlc_msat: 0,
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_024_000 },
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 0);
+               let usage = ChannelUsage { amount_msat: 10_240, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 0);
+               let usage = ChannelUsage { amount_msat: 102_400, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 47);
+               let usage = ChannelUsage { amount_msat: 1_024_000, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 2_000);
+
+               let usage = ChannelUsage {
+                       amount_msat: 128,
+                       inflight_htlc_msat: 0,
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_024 },
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 58);
+               let usage = ChannelUsage { amount_msat: 256, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 125);
+               let usage = ChannelUsage { amount_msat: 374, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 198);
+               let usage = ChannelUsage { amount_msat: 512, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 300);
+               let usage = ChannelUsage { amount_msat: 640, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 425);
+               let usage = ChannelUsage { amount_msat: 768, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 602);
+               let usage = ChannelUsage { amount_msat: 896, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 902);
        }
 
        #[test]
@@ -2114,10 +2061,17 @@ mod tests {
                let source = source_node_id();
                let target = target_node_id();
 
-               assert_eq!(scorer.channel_penalty_msat(42, 39, 100, &source, &target), 0);
-               assert_ne!(scorer.channel_penalty_msat(42, 50, 100, &source, &target), 0);
-               assert_ne!(scorer.channel_penalty_msat(42, 50, 100, &source, &target), u64::max_value());
-               assert_eq!(scorer.channel_penalty_msat(42, 61, 100, &source, &target), u64::max_value());
+               let usage = ChannelUsage {
+                       amount_msat: 39,
+                       inflight_htlc_msat: 0,
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 100 },
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 0);
+               let usage = ChannelUsage { amount_msat: 50, ..usage };
+               assert_ne!(scorer.channel_penalty_msat(42, &source, &target, usage), 0);
+               assert_ne!(scorer.channel_penalty_msat(42, &source, &target, usage), u64::max_value());
+               let usage = ChannelUsage { amount_msat: 61, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), u64::max_value());
        }
 
        #[test]
@@ -2131,16 +2085,21 @@ mod tests {
                let mut scorer = ProbabilisticScorer::new(params, &network_graph, &logger);
                let sender = sender_node_id();
                let source = source_node_id();
+               let usage = ChannelUsage {
+                       amount_msat: 500,
+                       inflight_htlc_msat: 0,
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_000 },
+               };
                let failed_path = payment_path_for_amount(500);
                let successful_path = payment_path_for_amount(200);
 
-               assert_eq!(scorer.channel_penalty_msat(41, 500, 1_000, &sender, &source), 301);
+               assert_eq!(scorer.channel_penalty_msat(41, &sender, &source, usage), 301);
 
                scorer.payment_path_failed(&failed_path.iter().collect::<Vec<_>>(), 41);
-               assert_eq!(scorer.channel_penalty_msat(41, 500, 1_000, &sender, &source), 301);
+               assert_eq!(scorer.channel_penalty_msat(41, &sender, &source, usage), 301);
 
                scorer.payment_path_successful(&successful_path.iter().collect::<Vec<_>>());
-               assert_eq!(scorer.channel_penalty_msat(41, 500, 1_000, &sender, &source), 301);
+               assert_eq!(scorer.channel_penalty_msat(41, &sender, &source, usage), 301);
        }
 
        #[test]
@@ -2156,15 +2115,25 @@ mod tests {
                let target = target_node_id();
                let path = payment_path_for_amount(500);
 
-               assert_eq!(scorer.channel_penalty_msat(42, 250, 1_000, &source, &target), 128);
-               assert_eq!(scorer.channel_penalty_msat(42, 500, 1_000, &source, &target), 301);
-               assert_eq!(scorer.channel_penalty_msat(42, 750, 1_000, &source, &target), 602);
+               let usage = ChannelUsage {
+                       amount_msat: 250,
+                       inflight_htlc_msat: 0,
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_000 },
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 128);
+               let usage = ChannelUsage { amount_msat: 500, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 301);
+               let usage = ChannelUsage { amount_msat: 750, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 602);
 
                scorer.payment_path_failed(&path.iter().collect::<Vec<_>>(), 43);
 
-               assert_eq!(scorer.channel_penalty_msat(42, 250, 1_000, &source, &target), 0);
-               assert_eq!(scorer.channel_penalty_msat(42, 500, 1_000, &source, &target), 0);
-               assert_eq!(scorer.channel_penalty_msat(42, 750, 1_000, &source, &target), 300);
+               let usage = ChannelUsage { amount_msat: 250, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 0);
+               let usage = ChannelUsage { amount_msat: 500, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 0);
+               let usage = ChannelUsage { amount_msat: 750, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 300);
        }
 
        #[test]
@@ -2180,15 +2149,25 @@ mod tests {
                let target = target_node_id();
                let path = payment_path_for_amount(500);
 
-               assert_eq!(scorer.channel_penalty_msat(42, 250, 1_000, &source, &target), 128);
-               assert_eq!(scorer.channel_penalty_msat(42, 500, 1_000, &source, &target), 301);
-               assert_eq!(scorer.channel_penalty_msat(42, 750, 1_000, &source, &target), 602);
+               let usage = ChannelUsage {
+                       amount_msat: 250,
+                       inflight_htlc_msat: 0,
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_000 },
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 128);
+               let usage = ChannelUsage { amount_msat: 500, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 301);
+               let usage = ChannelUsage { amount_msat: 750, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 602);
 
                scorer.payment_path_failed(&path.iter().collect::<Vec<_>>(), 42);
 
-               assert_eq!(scorer.channel_penalty_msat(42, 250, 1_000, &source, &target), 300);
-               assert_eq!(scorer.channel_penalty_msat(42, 500, 1_000, &source, &target), u64::max_value());
-               assert_eq!(scorer.channel_penalty_msat(42, 750, 1_000, &source, &target), u64::max_value());
+               let usage = ChannelUsage { amount_msat: 250, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 300);
+               let usage = ChannelUsage { amount_msat: 500, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), u64::max_value());
+               let usage = ChannelUsage { amount_msat: 750, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), u64::max_value());
        }
 
        #[test]
@@ -2204,17 +2183,22 @@ mod tests {
                let source = source_node_id();
                let target = target_node_id();
                let recipient = recipient_node_id();
+               let usage = ChannelUsage {
+                       amount_msat: 250,
+                       inflight_htlc_msat: 0,
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_000 },
+               };
                let path = payment_path_for_amount(500);
 
-               assert_eq!(scorer.channel_penalty_msat(41, 250, 1_000, &sender, &source), 128);
-               assert_eq!(scorer.channel_penalty_msat(42, 250, 1_000, &source, &target), 128);
-               assert_eq!(scorer.channel_penalty_msat(43, 250, 1_000, &target, &recipient), 128);
+               assert_eq!(scorer.channel_penalty_msat(41, &sender, &source, usage), 128);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 128);
+               assert_eq!(scorer.channel_penalty_msat(43, &target, &recipient, usage), 128);
 
                scorer.payment_path_successful(&path.iter().collect::<Vec<_>>());
 
-               assert_eq!(scorer.channel_penalty_msat(41, 250, 1_000, &sender, &source), 128);
-               assert_eq!(scorer.channel_penalty_msat(42, 250, 1_000, &source, &target), 300);
-               assert_eq!(scorer.channel_penalty_msat(43, 250, 1_000, &target, &recipient), 300);
+               assert_eq!(scorer.channel_penalty_msat(41, &sender, &source, usage), 128);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 300);
+               assert_eq!(scorer.channel_penalty_msat(43, &target, &recipient, usage), 300);
        }
 
        #[test]
@@ -2230,44 +2214,70 @@ mod tests {
                let source = source_node_id();
                let target = target_node_id();
 
-               assert_eq!(scorer.channel_penalty_msat(42, 0, 1_024, &source, &target), 0);
-               assert_eq!(scorer.channel_penalty_msat(42, 1_024, 1_024, &source, &target), 2_000);
+               let usage = ChannelUsage {
+                       amount_msat: 0,
+                       inflight_htlc_msat: 0,
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_024 },
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 0);
+               let usage = ChannelUsage { amount_msat: 1_024, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 2_000);
 
                scorer.payment_path_failed(&payment_path_for_amount(768).iter().collect::<Vec<_>>(), 42);
                scorer.payment_path_failed(&payment_path_for_amount(128).iter().collect::<Vec<_>>(), 43);
 
-               assert_eq!(scorer.channel_penalty_msat(42, 128, 1_024, &source, &target), 0);
-               assert_eq!(scorer.channel_penalty_msat(42, 256, 1_024, &source, &target), 93);
-               assert_eq!(scorer.channel_penalty_msat(42, 768, 1_024, &source, &target), 1_479);
-               assert_eq!(scorer.channel_penalty_msat(42, 896, 1_024, &source, &target), u64::max_value());
+               let usage = ChannelUsage { amount_msat: 128, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 0);
+               let usage = ChannelUsage { amount_msat: 256, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 93);
+               let usage = ChannelUsage { amount_msat: 768, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_479);
+               let usage = ChannelUsage { amount_msat: 896, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), u64::max_value());
 
                SinceEpoch::advance(Duration::from_secs(9));
-               assert_eq!(scorer.channel_penalty_msat(42, 128, 1_024, &source, &target), 0);
-               assert_eq!(scorer.channel_penalty_msat(42, 256, 1_024, &source, &target), 93);
-               assert_eq!(scorer.channel_penalty_msat(42, 768, 1_024, &source, &target), 1_479);
-               assert_eq!(scorer.channel_penalty_msat(42, 896, 1_024, &source, &target), u64::max_value());
+               let usage = ChannelUsage { amount_msat: 128, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 0);
+               let usage = ChannelUsage { amount_msat: 256, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 93);
+               let usage = ChannelUsage { amount_msat: 768, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_479);
+               let usage = ChannelUsage { amount_msat: 896, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), u64::max_value());
 
                SinceEpoch::advance(Duration::from_secs(1));
-               assert_eq!(scorer.channel_penalty_msat(42, 64, 1_024, &source, &target), 0);
-               assert_eq!(scorer.channel_penalty_msat(42, 128, 1_024, &source, &target), 34);
-               assert_eq!(scorer.channel_penalty_msat(42, 896, 1_024, &source, &target), 1_970);
-               assert_eq!(scorer.channel_penalty_msat(42, 960, 1_024, &source, &target), u64::max_value());
+               let usage = ChannelUsage { amount_msat: 64, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 0);
+               let usage = ChannelUsage { amount_msat: 128, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 34);
+               let usage = ChannelUsage { amount_msat: 896, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1_970);
+               let usage = ChannelUsage { amount_msat: 960, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), u64::max_value());
 
                // Fully decay liquidity lower bound.
                SinceEpoch::advance(Duration::from_secs(10 * 7));
-               assert_eq!(scorer.channel_penalty_msat(42, 0, 1_024, &source, &target), 0);
-               assert_eq!(scorer.channel_penalty_msat(42, 1, 1_024, &source, &target), 0);
-               assert_eq!(scorer.channel_penalty_msat(42, 1_023, 1_024, &source, &target), 2_000);
-               assert_eq!(scorer.channel_penalty_msat(42, 1_024, 1_024, &source, &target), 2_000);
+               let usage = ChannelUsage { amount_msat: 0, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 0);
+               let usage = ChannelUsage { amount_msat: 1, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 0);
+               let usage = ChannelUsage { amount_msat: 1_023, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 2_000);
+               let usage = ChannelUsage { amount_msat: 1_024, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 2_000);
 
                // Fully decay liquidity upper bound.
                SinceEpoch::advance(Duration::from_secs(10));
-               assert_eq!(scorer.channel_penalty_msat(42, 0, 1_024, &source, &target), 0);
-               assert_eq!(scorer.channel_penalty_msat(42, 1_024, 1_024, &source, &target), 2_000);
+               let usage = ChannelUsage { amount_msat: 0, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 0);
+               let usage = ChannelUsage { amount_msat: 1_024, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 2_000);
 
                SinceEpoch::advance(Duration::from_secs(10));
-               assert_eq!(scorer.channel_penalty_msat(42, 0, 1_024, &source, &target), 0);
-               assert_eq!(scorer.channel_penalty_msat(42, 1_024, 1_024, &source, &target), 2_000);
+               let usage = ChannelUsage { amount_msat: 0, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 0);
+               let usage = ChannelUsage { amount_msat: 1_024, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 2_000);
        }
 
        #[test]
@@ -2282,18 +2292,23 @@ mod tests {
                let mut scorer = ProbabilisticScorer::new(params, &network_graph, &logger);
                let source = source_node_id();
                let target = target_node_id();
-               assert_eq!(scorer.channel_penalty_msat(42, 256, 1_024, &source, &target), 125);
+               let usage = ChannelUsage {
+                       amount_msat: 256,
+                       inflight_htlc_msat: 0,
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_024 },
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 125);
 
                scorer.payment_path_failed(&payment_path_for_amount(512).iter().collect::<Vec<_>>(), 42);
-               assert_eq!(scorer.channel_penalty_msat(42, 256, 1_024, &source, &target), 281);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 281);
 
 		// An unchecked right shift of 64 bits or more in DirectedChannelLiquidity::decayed_offset_msat
 		// would cause an overflow.
                SinceEpoch::advance(Duration::from_secs(10 * 64));
-               assert_eq!(scorer.channel_penalty_msat(42, 256, 1_024, &source, &target), 125);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 125);
 
                SinceEpoch::advance(Duration::from_secs(10));
-               assert_eq!(scorer.channel_penalty_msat(42, 256, 1_024, &source, &target), 125);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 125);
        }
 
        #[test]
@@ -2308,31 +2323,36 @@ mod tests {
                let mut scorer = ProbabilisticScorer::new(params, &network_graph, &logger);
                let source = source_node_id();
                let target = target_node_id();
+               let usage = ChannelUsage {
+                       amount_msat: 512,
+                       inflight_htlc_msat: 0,
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_024 },
+               };
 
-               assert_eq!(scorer.channel_penalty_msat(42, 512, 1_024, &source, &target), 300);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 300);
 
                // More knowledge gives higher confidence (256, 768), meaning a lower penalty.
                scorer.payment_path_failed(&payment_path_for_amount(768).iter().collect::<Vec<_>>(), 42);
                scorer.payment_path_failed(&payment_path_for_amount(256).iter().collect::<Vec<_>>(), 43);
-               assert_eq!(scorer.channel_penalty_msat(42, 512, 1_024, &source, &target), 281);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 281);
 
                // Decaying knowledge gives less confidence (128, 896), meaning a higher penalty.
                SinceEpoch::advance(Duration::from_secs(10));
-               assert_eq!(scorer.channel_penalty_msat(42, 512, 1_024, &source, &target), 291);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 291);
 
                // Reducing the upper bound gives more confidence (128, 832) that the payment amount (512)
                // is closer to the upper bound, meaning a higher penalty.
                scorer.payment_path_successful(&payment_path_for_amount(64).iter().collect::<Vec<_>>());
-               assert_eq!(scorer.channel_penalty_msat(42, 512, 1_024, &source, &target), 331);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 331);
 
                // Increasing the lower bound gives more confidence (256, 832) that the payment amount (512)
                // is closer to the lower bound, meaning a lower penalty.
                scorer.payment_path_failed(&payment_path_for_amount(256).iter().collect::<Vec<_>>(), 43);
-               assert_eq!(scorer.channel_penalty_msat(42, 512, 1_024, &source, &target), 245);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 245);
 
                // Further decaying affects the lower bound more than the upper bound (128, 928).
                SinceEpoch::advance(Duration::from_secs(10));
-               assert_eq!(scorer.channel_penalty_msat(42, 512, 1_024, &source, &target), 280);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 280);
        }
 
        #[test]
@@ -2347,15 +2367,20 @@ mod tests {
                let mut scorer = ProbabilisticScorer::new(params, &network_graph, &logger);
                let source = source_node_id();
                let target = target_node_id();
+               let usage = ChannelUsage {
+                       amount_msat: 500,
+                       inflight_htlc_msat: 0,
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_000 },
+               };
 
                scorer.payment_path_failed(&payment_path_for_amount(500).iter().collect::<Vec<_>>(), 42);
-               assert_eq!(scorer.channel_penalty_msat(42, 500, 1_000, &source, &target), u64::max_value());
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), u64::max_value());
 
                SinceEpoch::advance(Duration::from_secs(10));
-               assert_eq!(scorer.channel_penalty_msat(42, 500, 1_000, &source, &target), 473);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 473);
 
                scorer.payment_path_failed(&payment_path_for_amount(250).iter().collect::<Vec<_>>(), 43);
-               assert_eq!(scorer.channel_penalty_msat(42, 500, 1_000, &source, &target), 300);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 300);
 
                let mut serialized_scorer = Vec::new();
                scorer.write(&mut serialized_scorer).unwrap();
@@ -2363,7 +2388,7 @@ mod tests {
                let mut serialized_scorer = io::Cursor::new(&serialized_scorer);
                let deserialized_scorer =
                        <ProbabilisticScorer>::read(&mut serialized_scorer, (params, &network_graph, &logger)).unwrap();
-               assert_eq!(deserialized_scorer.channel_penalty_msat(42, 500, 1_000, &source, &target), 300);
+               assert_eq!(deserialized_scorer.channel_penalty_msat(42, &source, &target, usage), 300);
        }
 
        #[test]
@@ -2378,9 +2403,14 @@ mod tests {
                let mut scorer = ProbabilisticScorer::new(params, &network_graph, &logger);
                let source = source_node_id();
                let target = target_node_id();
+               let usage = ChannelUsage {
+                       amount_msat: 500,
+                       inflight_htlc_msat: 0,
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_000 },
+               };
 
                scorer.payment_path_failed(&payment_path_for_amount(500).iter().collect::<Vec<_>>(), 42);
-               assert_eq!(scorer.channel_penalty_msat(42, 500, 1_000, &source, &target), u64::max_value());
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), u64::max_value());
 
                let mut serialized_scorer = Vec::new();
                scorer.write(&mut serialized_scorer).unwrap();
@@ -2390,13 +2420,13 @@ mod tests {
                let mut serialized_scorer = io::Cursor::new(&serialized_scorer);
                let deserialized_scorer =
                        <ProbabilisticScorer>::read(&mut serialized_scorer, (params, &network_graph, &logger)).unwrap();
-               assert_eq!(deserialized_scorer.channel_penalty_msat(42, 500, 1_000, &source, &target), 473);
+               assert_eq!(deserialized_scorer.channel_penalty_msat(42, &source, &target, usage), 473);
 
                scorer.payment_path_failed(&payment_path_for_amount(250).iter().collect::<Vec<_>>(), 43);
-               assert_eq!(scorer.channel_penalty_msat(42, 500, 1_000, &source, &target), 300);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 300);
 
                SinceEpoch::advance(Duration::from_secs(10));
-               assert_eq!(deserialized_scorer.channel_penalty_msat(42, 500, 1_000, &source, &target), 365);
+               assert_eq!(deserialized_scorer.channel_penalty_msat(42, &source, &target, usage), 365);
        }
 
        #[test]
@@ -2410,17 +2440,52 @@ mod tests {
                let source = source_node_id();
                let target = target_node_id();
 
-               assert_eq!(scorer.channel_penalty_msat(42, 100_000_000, 950_000_000, &source, &target), 3613);
-               assert_eq!(scorer.channel_penalty_msat(42, 100_000_000, 1_950_000_000, &source, &target), 1977);
-               assert_eq!(scorer.channel_penalty_msat(42, 100_000_000, 2_950_000_000, &source, &target), 1474);
-               assert_eq!(scorer.channel_penalty_msat(42, 100_000_000, 3_950_000_000, &source, &target), 1223);
-               assert_eq!(scorer.channel_penalty_msat(42, 100_000_000, 4_950_000_000, &source, &target), 877);
-               assert_eq!(scorer.channel_penalty_msat(42, 100_000_000, 5_950_000_000, &source, &target), 845);
-               assert_eq!(scorer.channel_penalty_msat(42, 100_000_000, 6_950_000_000, &source, &target), 500);
-               assert_eq!(scorer.channel_penalty_msat(42, 100_000_000, 7_450_000_000, &source, &target), 500);
-               assert_eq!(scorer.channel_penalty_msat(42, 100_000_000, 7_950_000_000, &source, &target), 500);
-               assert_eq!(scorer.channel_penalty_msat(42, 100_000_000, 8_950_000_000, &source, &target), 500);
-               assert_eq!(scorer.channel_penalty_msat(42, 100_000_000, 9_950_000_000, &source, &target), 500);
+               let usage = ChannelUsage {
+                       amount_msat: 100_000_000,
+                       inflight_htlc_msat: 0,
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 950_000_000 },
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 3613);
+               let usage = ChannelUsage {
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_950_000_000 }, ..usage
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1977);
+               let usage = ChannelUsage {
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 2_950_000_000 }, ..usage
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1474);
+               let usage = ChannelUsage {
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 3_950_000_000 }, ..usage
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 1223);
+               let usage = ChannelUsage {
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 4_950_000_000 }, ..usage
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 877);
+               let usage = ChannelUsage {
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 5_950_000_000 }, ..usage
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 845);
+               let usage = ChannelUsage {
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 6_950_000_000 }, ..usage
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 500);
+               let usage = ChannelUsage {
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 7_450_000_000 }, ..usage
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 500);
+               let usage = ChannelUsage {
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 7_950_000_000 }, ..usage
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 500);
+               let usage = ChannelUsage {
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 8_950_000_000 }, ..usage
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 500);
+               let usage = ChannelUsage {
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 9_950_000_000 }, ..usage
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 500);
        }
 
        #[test]
@@ -2429,19 +2494,24 @@ mod tests {
                let network_graph = network_graph();
                let source = source_node_id();
                let target = target_node_id();
+               let usage = ChannelUsage {
+                       amount_msat: 128,
+                       inflight_htlc_msat: 0,
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_024 },
+               };
 
                let params = ProbabilisticScoringParameters {
                        liquidity_penalty_multiplier_msat: 1_000,
                        ..ProbabilisticScoringParameters::zero_penalty()
                };
                let scorer = ProbabilisticScorer::new(params, &network_graph, &logger);
-               assert_eq!(scorer.channel_penalty_msat(42, 128, 1_024, &source, &target), 58);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 58);
 
                let params = ProbabilisticScoringParameters {
                        base_penalty_msat: 500, liquidity_penalty_multiplier_msat: 1_000, ..Default::default()
                };
                let scorer = ProbabilisticScorer::new(params, &network_graph, &logger);
-               assert_eq!(scorer.channel_penalty_msat(42, 128, 1_024, &source, &target), 558);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 558);
        }
 
        #[test]
@@ -2450,6 +2520,11 @@ mod tests {
                let network_graph = network_graph();
                let source = source_node_id();
                let target = target_node_id();
+               let usage = ChannelUsage {
+                       amount_msat: 512_000,
+                       inflight_htlc_msat: 0,
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_024_000 },
+               };
 
                let params = ProbabilisticScoringParameters {
                        liquidity_penalty_multiplier_msat: 1_000,
@@ -2457,7 +2532,7 @@ mod tests {
                        ..ProbabilisticScoringParameters::zero_penalty()
                };
                let scorer = ProbabilisticScorer::new(params, &network_graph, &logger);
-               assert_eq!(scorer.channel_penalty_msat(42, 512_000, 1_024_000, &source, &target), 300);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 300);
 
                let params = ProbabilisticScoringParameters {
                        liquidity_penalty_multiplier_msat: 1_000,
@@ -2465,7 +2540,7 @@ mod tests {
                        ..ProbabilisticScoringParameters::zero_penalty()
                };
                let scorer = ProbabilisticScorer::new(params, &network_graph, &logger);
-               assert_eq!(scorer.channel_penalty_msat(42, 512_000, 1_024_000, &source, &target), 337);
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 337);
        }
 
        #[test]
@@ -2474,15 +2549,61 @@ mod tests {
                let network_graph = network_graph();
                let source = source_node_id();
                let target = target_node_id();
+               let usage = ChannelUsage {
+                       amount_msat: u64::max_value(),
+                       inflight_htlc_msat: 0,
+                       effective_capacity: EffectiveCapacity::Infinite,
+               };
 
                let params = ProbabilisticScoringParameters {
                        liquidity_penalty_multiplier_msat: 40_000,
                        ..ProbabilisticScoringParameters::zero_penalty()
                };
                let scorer = ProbabilisticScorer::new(params, &network_graph, &logger);
-               assert_eq!(
-                       scorer.channel_penalty_msat(42, u64::max_value(), u64::max_value(), &source, &target),
-                       80_000,
-               );
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 80_000);
+       }
+
+       #[test]
+       fn accounts_for_inflight_htlc_usage() {
+               let network_graph = network_graph();
+               let logger = TestLogger::new();
+               let params = ProbabilisticScoringParameters::default();
+               let scorer = ProbabilisticScorer::new(params, &network_graph, &logger);
+               let source = source_node_id();
+               let target = target_node_id();
+
+               let usage = ChannelUsage {
+                       amount_msat: 750,
+                       inflight_htlc_msat: 0,
+                       effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_000 },
+               };
+               assert_ne!(scorer.channel_penalty_msat(42, &source, &target, usage), u64::max_value());
+
+               let usage = ChannelUsage { inflight_htlc_msat: 251, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), u64::max_value());
+       }
+
+       #[test]
+       fn removes_uncertainty_when_exact_liquidity_known() {
+               let network_graph = network_graph();
+               let logger = TestLogger::new();
+               let params = ProbabilisticScoringParameters::default();
+               let scorer = ProbabilisticScorer::new(params, &network_graph, &logger);
+               let source = source_node_id();
+               let target = target_node_id();
+
+               let base_penalty_msat = params.base_penalty_msat;
+               let usage = ChannelUsage {
+                       amount_msat: 750,
+                       inflight_htlc_msat: 0,
+                       effective_capacity: EffectiveCapacity::ExactLiquidity { liquidity_msat: 1_000 },
+               };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), base_penalty_msat);
+
+               let usage = ChannelUsage { amount_msat: 1_000, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), base_penalty_msat);
+
+               let usage = ChannelUsage { amount_msat: 1_001, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), u64::max_value());
        }
 }
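The tests above track the scorer's new signature: rather than separate amount and capacity arguments, `channel_penalty_msat` now takes a `ChannelUsage` describing the amount being attempted, the liquidity already consumed by in-flight HTLCs, and the best available capacity estimate. A rough sketch of how a caller might build one; the import paths, helper name, and the 250_000 / 1_000_000 figures are illustrative assumptions, not taken from the patch:

	// Sketch only: paths and numbers are assumptions for illustration.
	use lightning::routing::network_graph::{EffectiveCapacity, NodeId};
	use lightning::routing::scoring::{ChannelUsage, Score};

	fn hop_penalty<S: Score>(scorer: &S, scid: u64, source: &NodeId, target: &NodeId) -> u64 {
		let usage = ChannelUsage {
			amount_msat: 250_000,      // amount we are trying to route over this hop
			inflight_htlc_msat: 0,     // liquidity already used by our in-flight HTLCs
			effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_000_000 },
		};
		scorer.channel_penalty_msat(scid, source, target, usage)
	}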
index 6fff0cf055c165f05d5df2b2cfeef814ac5ef1a2..ef0c619b02e9f7a1b03cd6dd14b439e7799264dc 100644 (file)
@@ -161,8 +161,15 @@ pub enum Event {
        /// [`ChannelManager::funding_transaction_generated`]: crate::ln::channelmanager::ChannelManager::funding_transaction_generated
        FundingGenerationReady {
                /// The random channel_id we picked which you'll need to pass into
-               /// ChannelManager::funding_transaction_generated.
+               /// [`ChannelManager::funding_transaction_generated`].
+               ///
+               /// [`ChannelManager::funding_transaction_generated`]: crate::ln::channelmanager::ChannelManager::funding_transaction_generated
                temporary_channel_id: [u8; 32],
+               /// The counterparty's node_id, which you'll need to pass back into
+               /// [`ChannelManager::funding_transaction_generated`].
+               ///
+               /// [`ChannelManager::funding_transaction_generated`]: crate::ln::channelmanager::ChannelManager::funding_transaction_generated
+               counterparty_node_id: PublicKey,
                /// The value, in satoshis, that the output should have.
                channel_value_satoshis: u64,
                /// The script which should be used in the transaction output.
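With the new `counterparty_node_id` field, an event handler has everything it needs to hand the funding transaction back. A hedged sketch of pulling both identifiers out of the event; the helper name is hypothetical and the actual call into `ChannelManager::funding_transaction_generated` is elided:

	use bitcoin::secp256k1::PublicKey;
	use lightning::util::events::Event;

	// Hypothetical helper: extract the values that must be passed back to
	// ChannelManager::funding_transaction_generated once the funding tx is built.
	fn funding_params(event: &Event) -> Option<([u8; 32], PublicKey)> {
		match event {
			Event::FundingGenerationReady { temporary_channel_id, counterparty_node_id, .. } =>
				Some((*temporary_channel_id, *counterparty_node_id)),
			_ => None,
		}
	}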
@@ -362,9 +369,12 @@ pub enum Event {
        /// This event is generated when a payment has been successfully forwarded through us and a
        /// forwarding fee earned.
        PaymentForwarded {
-               /// The channel between the source node and us. Optional because versions prior to 0.0.107
-               /// do not serialize this field.
-               source_channel_id: Option<[u8; 32]>,
+               /// The incoming channel between the previous node and us. This is only `None` for events
+               /// generated or serialized by versions prior to 0.0.107.
+               prev_channel_id: Option<[u8; 32]>,
+               /// The outgoing channel between the next node and us. This is only `None` for events
+               /// generated or serialized by versions prior to 0.0.107.
+               next_channel_id: Option<[u8; 32]>,
                /// The fee, in milli-satoshis, which was earned as a result of the payment.
                ///
                /// Note that if we force-closed the channel over which we forwarded an HTLC while the HTLC
@@ -424,13 +434,21 @@ pub enum Event {
                /// The temporary channel ID of the channel requested to be opened.
                ///
                /// When responding to the request, the `temporary_channel_id` should be passed
-               /// back to the ChannelManager with [`ChannelManager::accept_inbound_channel`] to accept,
-               /// or to [`ChannelManager::force_close_channel`] to reject.
+               /// back to the ChannelManager through [`ChannelManager::accept_inbound_channel`] to accept,
+               /// or through [`ChannelManager::force_close_channel`] to reject.
                ///
                /// [`ChannelManager::accept_inbound_channel`]: crate::ln::channelmanager::ChannelManager::accept_inbound_channel
                /// [`ChannelManager::force_close_channel`]: crate::ln::channelmanager::ChannelManager::force_close_channel
                temporary_channel_id: [u8; 32],
                /// The node_id of the counterparty requesting to open the channel.
+               ///
+               /// When responding to the request, the `counterparty_node_id` should be passed
+               /// back to the `ChannelManager` through [`ChannelManager::accept_inbound_channel`] to
+               /// accept the request, or through [`ChannelManager::force_close_channel`] to reject the
+               /// request.
+               ///
+               /// [`ChannelManager::accept_inbound_channel`]: crate::ln::channelmanager::ChannelManager::accept_inbound_channel
+               /// [`ChannelManager::force_close_channel`]: crate::ln::channelmanager::ChannelManager::force_close_channel
                counterparty_node_id: PublicKey,
                /// The channel value of the requested channel.
                funding_satoshis: u64,
@@ -522,12 +540,13 @@ impl Writeable for Event {
                                        (0, VecWriteWrapper(outputs), required),
                                });
                        },
-                       &Event::PaymentForwarded { fee_earned_msat, source_channel_id, claim_from_onchain_tx } => {
+                       &Event::PaymentForwarded { fee_earned_msat, prev_channel_id, claim_from_onchain_tx, next_channel_id } => {
                                7u8.write(writer)?;
                                write_tlv_fields!(writer, {
                                        (0, fee_earned_msat, option),
-                                       (1, source_channel_id, option),
+                                       (1, prev_channel_id, option),
                                        (2, claim_from_onchain_tx, required),
+                                       (3, next_channel_id, option),
                                });
                        },
                        &Event::ChannelClosed { ref channel_id, ref user_channel_id, ref reason } => {
@@ -687,14 +706,16 @@ impl MaybeReadable for Event {
                        7u8 => {
                                let f = || {
                                        let mut fee_earned_msat = None;
-                                       let mut source_channel_id = None;
+                                       let mut prev_channel_id = None;
                                        let mut claim_from_onchain_tx = false;
+                                       let mut next_channel_id = None;
                                        read_tlv_fields!(reader, {
                                                (0, fee_earned_msat, option),
-                                               (1, source_channel_id, option),
+                                               (1, prev_channel_id, option),
                                                (2, claim_from_onchain_tx, required),
+                                               (3, next_channel_id, option),
                                        });
-                                       Ok(Some(Event::PaymentForwarded { fee_earned_msat, source_channel_id, claim_from_onchain_tx }))
+                                       Ok(Some(Event::PaymentForwarded { fee_earned_msat, prev_channel_id, claim_from_onchain_tx, next_channel_id }))
                                };
                                f()
                        },
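The rename from `source_channel_id` is paired with a new TLV (type 3) for the outgoing channel, so events written by 0.0.106 and earlier still deserialize, with both fields reading as `None`. A rough sketch of consuming the renamed fields in an event handler (the function name is illustrative):

	use lightning::util::events::Event;

	fn note_forward(event: &Event) {
		if let Event::PaymentForwarded { fee_earned_msat, prev_channel_id, next_channel_id, claim_from_onchain_tx } = event {
			// Both channel ids are None for events persisted by versions prior to 0.0.107.
			println!("forwarded {:?} -> {:?}: earned {:?} msat (on-chain claim: {})",
				prev_channel_id, next_channel_id, fee_earned_msat, claim_from_onchain_tx);
		}
	}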
index b7ee02d2c1f5724ba0b0b6309ce2eb35ad321c1d..c6181ab269a57a19a8d5ee8c9f225358d6695947 100644 (file)
@@ -36,6 +36,7 @@ pub(crate) mod poly1305;
 pub(crate) mod chacha20poly1305rfc;
 pub(crate) mod transaction_utils;
 pub(crate) mod scid_utils;
+pub(crate) mod time;
 
 /// Logging macro utilities.
 #[macro_use]
index 69fd14640ae605bc2d481f40760ee48d3a502eb6..428adbc5e6638b5a0489eeb855d80f7c42f5ce86 100644 (file)
@@ -301,7 +301,7 @@ impl Readable for U48 {
 /// encoded in several different ways, which we must check for at deserialization-time. Thus, if
 /// you're looking for an example of a variable-length integer to use for your own project, move
 /// along, this is a rather poor design.
-pub(crate) struct BigSize(pub u64);
+pub struct BigSize(pub u64);
 impl Writeable for BigSize {
        #[inline]
        fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
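Making `BigSize` public exposes the BOLT-style variable-length integer used throughout the TLV encoding: values below 0xfd occupy a single byte, while larger values get a 0xfd, 0xfe, or 0xff marker followed by 2, 4, or 8 bytes. A small sketch of the resulting on-wire lengths (my own illustration, not part of the patch):

	use lightning::util::ser::{BigSize, Writeable};

	fn bigsize_lengths() {
		assert_eq!(BigSize(252).encode().len(), 1);               // < 0xfd: single byte
		assert_eq!(BigSize(0xffff).encode().len(), 3);            // 0xfd marker + u16
		assert_eq!(BigSize(0xffff_ffff).encode().len(), 5);       // 0xfe marker + u32
		assert_eq!(BigSize(u64::max_value()).encode().len(), 9);  // 0xff marker + u64
	}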
index ac1baaf0d423ac2987757e0a951475536e62a6f6..32ef172b19a609be60c4670eb94281b6b3aeab3b 100644 (file)
@@ -164,7 +164,7 @@ impl<'a> chain::Watch<EnforcingSigner> for TestChainMonitor<'a> {
                update_res
        }
 
-       fn release_pending_monitor_events(&self) -> Vec<MonitorEvent> {
+       fn release_pending_monitor_events(&self) -> Vec<(OutPoint, Vec<MonitorEvent>)> {
                return self.chain_monitor.release_pending_monitor_events();
        }
 }
diff --git a/lightning/src/util/time.rs b/lightning/src/util/time.rs
new file mode 100644 (file)
index 0000000..d3768aa
--- /dev/null
@@ -0,0 +1,152 @@
+// This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
+// or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
+// You may not use this file except in accordance with one or both of these
+// licenses.
+
+//! [`Time`] trait and different implementations. Currently, it's mainly used in tests so we can
+//! manually advance time.
+//! Other crates may symlink this file to use it while the [`Time`] trait is sealed here.
+
+use core::ops::Sub;
+use core::time::Duration;
+
+/// A measurement of time.
+pub trait Time: Copy + Sub<Duration, Output = Self> where Self: Sized {
+       /// Returns an instance corresponding to the current moment.
+       fn now() -> Self;
+
+       /// Returns the amount of time elapsed since `self` was created.
+       fn elapsed(&self) -> Duration;
+
+       /// Returns the amount of time passed between `earlier` and `self`.
+       fn duration_since(&self, earlier: Self) -> Duration;
+
+       /// Returns the amount of time passed since the beginning of [`Time`].
+       ///
+       /// Used during (de-)serialization.
+       fn duration_since_epoch() -> Duration;
+}
+
+/// A state in which time has no meaning.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub struct Eternity;
+
+impl Time for Eternity {
+       fn now() -> Self {
+               Self
+       }
+
+       fn duration_since(&self, _earlier: Self) -> Duration {
+               Duration::from_secs(0)
+       }
+
+       fn duration_since_epoch() -> Duration {
+               Duration::from_secs(0)
+       }
+
+       fn elapsed(&self) -> Duration {
+               Duration::from_secs(0)
+       }
+}
+
+impl Sub<Duration> for Eternity {
+       type Output = Self;
+
+       fn sub(self, _other: Duration) -> Self {
+               self
+       }
+}
+
+#[cfg(not(feature = "no-std"))]
+impl Time for std::time::Instant {
+       fn now() -> Self {
+               std::time::Instant::now()
+       }
+
+       fn duration_since(&self, earlier: Self) -> Duration {
+               self.duration_since(earlier)
+       }
+
+       fn duration_since_epoch() -> Duration {
+               use std::time::SystemTime;
+               SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap()
+       }
+       fn elapsed(&self) -> Duration {
+               std::time::Instant::elapsed(self)
+       }
+}
+
+#[cfg(test)]
+pub mod tests {
+       use super::{Time, Eternity};
+
+       use core::time::Duration;
+       use core::ops::Sub;
+       use core::cell::Cell;
+
+       /// Time that can be advanced manually in tests.
+       #[derive(Clone, Copy, Debug, PartialEq, Eq)]
+       pub struct SinceEpoch(Duration);
+
+       impl SinceEpoch {
+               thread_local! {
+                       static ELAPSED: Cell<Duration> = core::cell::Cell::new(Duration::from_secs(0));
+               }
+
+               pub fn advance(duration: Duration) {
+                       Self::ELAPSED.with(|elapsed| elapsed.set(elapsed.get() + duration))
+               }
+       }
+
+       impl Time for SinceEpoch {
+               fn now() -> Self {
+                       Self(Self::duration_since_epoch())
+               }
+
+               fn duration_since(&self, earlier: Self) -> Duration {
+                       self.0 - earlier.0
+               }
+
+               fn duration_since_epoch() -> Duration {
+                       Self::ELAPSED.with(|elapsed| elapsed.get())
+               }
+
+               fn elapsed(&self) -> Duration {
+                       Self::duration_since_epoch() - self.0
+               }
+       }
+
+       impl Sub<Duration> for SinceEpoch {
+               type Output = Self;
+
+               fn sub(self, other: Duration) -> Self {
+                       Self(self.0 - other)
+               }
+       }
+
+       #[test]
+       fn time_passes_when_advanced() {
+               let now = SinceEpoch::now();
+               assert_eq!(now.elapsed(), Duration::from_secs(0));
+
+               SinceEpoch::advance(Duration::from_secs(1));
+               SinceEpoch::advance(Duration::from_secs(1));
+
+               let elapsed = now.elapsed();
+               let later = SinceEpoch::now();
+
+               assert_eq!(elapsed, Duration::from_secs(2));
+               assert_eq!(later - elapsed, now);
+       }
+
+       #[test]
+       fn time_never_passes_in_an_eternity() {
+               let now = Eternity::now();
+               let elapsed = now.elapsed();
+               let later = Eternity::now();
+
+               assert_eq!(now.elapsed(), Duration::from_secs(0));
+               assert_eq!(later - elapsed, now);
+       }
+}
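The point of the trait is that time-dependent logic can be written once, generic over `Time`, and then exercised in tests with a manually advanced clock (like `SinceEpoch` above) while release builds use `std::time::Instant`. A minimal in-crate sketch of that pattern; the `Stopwatch` helper is hypothetical, and since the module is `pub(crate)`, other crates would symlink the file as noted in the header comment:

	use core::time::Duration;
	use crate::util::time::Time; // in-crate path; other crates symlink time.rs instead

	/// Hypothetical helper, generic over the clock so tests can drive it deterministically.
	struct Stopwatch<T: Time> {
		started_at: T,
	}

	impl<T: Time> Stopwatch<T> {
		fn start() -> Self { Stopwatch { started_at: T::now() } }
		fn split(&self) -> Duration { self.started_at.elapsed() }
	}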