Merge pull request #2067 from benthecarman/fix-typos
authorMatt Corallo <649246+TheBlueMatt@users.noreply.github.com>
Fri, 3 Mar 2023 20:00:12 +0000 (20:00 +0000)
committerGitHub <noreply@github.com>
Fri, 3 Mar 2023 20:00:12 +0000 (20:00 +0000)
Fix typos in lightning-transaction-sync

32 files changed:
.github/workflows/build.yml
fuzz/src/full_stack.rs
fuzz/src/process_network_graph.rs
fuzz/src/router.rs
lightning-background-processor/src/lib.rs
lightning-block-sync/src/rpc.rs
lightning-invoice/src/payment.rs
lightning-invoice/src/utils.rs
lightning-rapid-gossip-sync/src/lib.rs
lightning-rapid-gossip-sync/src/processing.rs
lightning/src/chain/channelmonitor.rs
lightning/src/ln/chanmon_update_fail_tests.rs
lightning/src/ln/channel.rs
lightning/src/ln/channelmanager.rs
lightning/src/ln/functional_test_utils.rs
lightning/src/ln/functional_tests.rs
lightning/src/ln/onion_route_tests.rs
lightning/src/ln/outbound_payment.rs
lightning/src/ln/payment_tests.rs
lightning/src/ln/peer_handler.rs
lightning/src/routing/router.rs
lightning/src/routing/utxo.rs
lightning/src/sync/debug_sync.rs
lightning/src/sync/fairrwlock.rs
lightning/src/sync/mod.rs
lightning/src/sync/nostd_sync.rs
lightning/src/sync/test_lockorder_checks.rs
lightning/src/util/indexed_map.rs
lightning/src/util/macro_logger.rs
lightning/src/util/ser.rs
lightning/src/util/ser_macros.rs
lightning/src/util/wakers.rs

index 9fb37f1b4a27571fd098c2f66f55a0d0d3ecae63..2bb21c2cdbb39b87c5d9f08f8d4fe73d9d4707aa 100644 (file)
@@ -141,8 +141,9 @@ jobs:
           cargo test --verbose --color always --features esplora-async
       - name: Test backtrace-debug builds on Rust ${{ matrix.toolchain }}
         if: "matrix.toolchain == 'stable'"
+        shell: bash # Default on Winblows is powershell
         run: |
-          cd lightning && cargo test --verbose --color always --features backtrace
+          cd lightning && RUST_BACKTRACE=1 cargo test --verbose --color always --features backtrace
       - name: Test on Rust ${{ matrix.toolchain }} with net-tokio
         if: "matrix.build-net-tokio && !matrix.coverage"
         run: cargo test --verbose --color always
index e7d10a341629db7d5e190945e51f8e7959fc02a2..b96de0b9782443f9c4fd72da7b861a2c6727a420 100644 (file)
@@ -513,7 +513,6 @@ pub fn do_test(data: &[u8], logger: &Arc<dyn Logger>) {
                                let params = RouteParameters {
                                        payment_params,
                                        final_value_msat,
-                                       final_cltv_expiry_delta: 42,
                                };
                                let random_seed_bytes: [u8; 32] = keys_manager.get_secure_random_bytes();
                                let route = match find_route(&our_id, &params, &network_graph, None, Arc::clone(&logger), &scorer, &random_seed_bytes) {
@@ -537,7 +536,6 @@ pub fn do_test(data: &[u8], logger: &Arc<dyn Logger>) {
                                let params = RouteParameters {
                                        payment_params,
                                        final_value_msat,
-                                       final_cltv_expiry_delta: 42,
                                };
                                let random_seed_bytes: [u8; 32] = keys_manager.get_secure_random_bytes();
                                let mut route = match find_route(&our_id, &params, &network_graph, None, Arc::clone(&logger), &scorer, &random_seed_bytes) {
index c900a7d38d5ac0529b0912e05717f9d0f0e4c693..b4c6a29e8a99f47744fdb9fd0caf2468c7604e57 100644 (file)
@@ -7,7 +7,7 @@ use crate::utils::test_logger;
 fn do_test<Out: test_logger::Output>(data: &[u8], out: Out) {
        let logger = test_logger::TestLogger::new("".to_owned(), out);
        let network_graph = lightning::routing::gossip::NetworkGraph::new(bitcoin::Network::Bitcoin, &logger);
-       let rapid_sync = RapidGossipSync::new(&network_graph);
+       let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
        let _ = rapid_sync.update_network_graph(data);
 }
 
index 93de35c5d094783d8bdc4e9b6c6a935b39088f34..4c228cc731bc2d0981ad77c9a0e2ff444f243302 100644 (file)
@@ -301,7 +301,6 @@ pub fn do_test<Out: test_logger::Output>(data: &[u8], out: Out) {
                                                payment_params: PaymentParameters::from_node_id(*target, final_cltv_expiry_delta)
                                                        .with_route_hints(last_hops.clone()),
                                                final_value_msat,
-                                               final_cltv_expiry_delta,
                                        };
                                        let _ = find_route(&our_pubkey, &route_params, &net_graph,
                                                first_hops.map(|c| c.iter().collect::<Vec<_>>()).as_ref().map(|a| a.as_slice()),
index 8711a4aeb5898e393f173886ca8e20de4b9b0d6f..f7f1296b98e66e737c6c81283a8a7e459a7b8160 100644 (file)
@@ -953,7 +953,7 @@ mod tests {
                        let params = ChainParameters { network, best_block };
                        let manager = Arc::new(ChannelManager::new(fee_estimator.clone(), chain_monitor.clone(), tx_broadcaster.clone(), router.clone(), logger.clone(), keys_manager.clone(), keys_manager.clone(), keys_manager.clone(), UserConfig::default(), params));
                        let p2p_gossip_sync = Arc::new(P2PGossipSync::new(network_graph.clone(), Some(chain_source.clone()), logger.clone()));
-                       let rapid_gossip_sync = Arc::new(RapidGossipSync::new(network_graph.clone()));
+                       let rapid_gossip_sync = Arc::new(RapidGossipSync::new(network_graph.clone(), logger.clone()));
                        let msg_handler = MessageHandler { chan_handler: Arc::new(test_utils::TestChannelMessageHandler::new()), route_handler: Arc::new(test_utils::TestRoutingMessageHandler::new()), onion_message_handler: IgnoringMessageHandler{}};
                        let peer_manager = Arc::new(PeerManager::new(msg_handler, 0, &seed, logger.clone(), IgnoringMessageHandler{}, keys_manager.clone()));
                        let node = Node { node: manager, p2p_gossip_sync, rapid_gossip_sync, peer_manager, chain_monitor, persister, tx_broadcaster, network_graph, logger, best_block, scorer };
index f04769560246f8537e1e022efa22c8b7a815eab4..6b4397a6b0fbe87d19a130fa18002a27042c9864 100644 (file)
@@ -13,8 +13,27 @@ use serde_json;
 
 use std::convert::TryFrom;
 use std::convert::TryInto;
+use std::error::Error;
+use std::fmt;
 use std::sync::atomic::{AtomicUsize, Ordering};
 
+/// An error returned by the RPC server.
+#[derive(Debug)]
+pub struct RpcError {
+       /// The error code.
+       pub code: i64,
+       /// The error message.
+       pub message: String,
+}
+
+impl fmt::Display for RpcError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "RPC error {}: {}", self.code, self.message)
+    }
+}
+
+impl Error for RpcError {}
+
 /// A simple RPC client for calling methods using HTTP `POST`.
 pub struct RpcClient {
        basic_auth: String,
@@ -69,8 +88,11 @@ impl RpcClient {
                let error = &response["error"];
                if !error.is_null() {
                        // TODO: Examine error code for a more precise std::io::ErrorKind.
-                       let message = error["message"].as_str().unwrap_or("unknown error");
-                       return Err(std::io::Error::new(std::io::ErrorKind::Other, message));
+                       let rpc_error = RpcError { 
+                               code: error["code"].as_i64().unwrap_or(-1), 
+                               message: error["message"].as_str().unwrap_or("unknown error").to_string() 
+                       };
+                       return Err(std::io::Error::new(std::io::ErrorKind::Other, rpc_error));
                }
 
                let result = &mut response["result"];
@@ -163,7 +185,9 @@ mod tests {
                match client.call_method::<u64>("getblock", &[invalid_block_hash]).await {
                        Err(e) => {
                                assert_eq!(e.kind(), std::io::ErrorKind::Other);
-                               assert_eq!(e.get_ref().unwrap().to_string(), "invalid parameter");
+                               let rpc_error: Box<RpcError> = e.into_inner().unwrap().downcast().unwrap();
+                               assert_eq!(rpc_error.code, -8);
+                               assert_eq!(rpc_error.message, "invalid parameter");
                        },
                        Ok(_) => panic!("Expected error"),
                }
index db47f9954800a2296785bdabd622aa31b7bd41c4..981fd2e6ba49225c0e41fb2923ee56e480aea0b7 100644 (file)
@@ -156,7 +156,6 @@ fn pay_invoice_using_amount<P: Deref>(
        let route_params = RouteParameters {
                payment_params,
                final_value_msat: amount_msats,
-               final_cltv_expiry_delta: invoice.min_final_cltv_expiry_delta() as u32,
        };
 
        payer.send_payment(payment_hash, &payment_secret, payment_id, route_params, retry_strategy)
index 868f0b297b1162da0551c8923b4ced23ab880ccb..11a779b1028a8ddceb825cc833cceb57856d52f0 100644 (file)
@@ -686,7 +686,6 @@ mod test {
                let route_params = RouteParameters {
                        payment_params,
                        final_value_msat: invoice.amount_milli_satoshis().unwrap(),
-                       final_cltv_expiry_delta: invoice.min_final_cltv_expiry_delta() as u32,
                };
                let first_hops = nodes[0].node.list_usable_channels();
                let network_graph = &node_cfgs[0].network_graph;
@@ -1050,7 +1049,6 @@ mod test {
                let params = RouteParameters {
                        payment_params,
                        final_value_msat: invoice.amount_milli_satoshis().unwrap(),
-                       final_cltv_expiry_delta: invoice.min_final_cltv_expiry_delta() as u32,
                };
                let first_hops = nodes[0].node.list_usable_channels();
                let network_graph = &node_cfgs[0].network_graph;
index 3bceb2e28e9239c0a7e8b34790275d4f51f064e2..af235b1c4224d990f2a1d71881ab23e659d61e28 100644 (file)
@@ -54,7 +54,7 @@
 //! # let logger = FakeLogger {};
 //!
 //! let network_graph = NetworkGraph::new(Network::Bitcoin, &logger);
-//! let rapid_sync = RapidGossipSync::new(&network_graph);
+//! let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
 //! let snapshot_contents: &[u8] = &[0; 0];
 //! let new_last_sync_timestamp_result = rapid_sync.update_network_graph(snapshot_contents);
 //! ```
@@ -94,14 +94,16 @@ mod processing;
 pub struct RapidGossipSync<NG: Deref<Target=NetworkGraph<L>>, L: Deref>
 where L::Target: Logger {
        network_graph: NG,
+       logger: L,
        is_initial_sync_complete: AtomicBool
 }
 
 impl<NG: Deref<Target=NetworkGraph<L>>, L: Deref> RapidGossipSync<NG, L> where L::Target: Logger {
        /// Instantiate a new [`RapidGossipSync`] instance.
-       pub fn new(network_graph: NG) -> Self {
+       pub fn new(network_graph: NG, logger: L) -> Self {
                Self {
                        network_graph,
+                       logger,
                        is_initial_sync_complete: AtomicBool::new(false)
                }
        }
@@ -228,7 +230,7 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                let sync_result = rapid_sync.sync_network_graph_with_file_path(&graph_sync_test_file);
 
                if sync_result.is_err() {
@@ -260,7 +262,7 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                let start = std::time::Instant::now();
                let sync_result = rapid_sync
                        .sync_network_graph_with_file_path("./res/full_graph.lngossip");
@@ -299,7 +301,7 @@ pub mod bench {
                let logger = TestLogger::new();
                b.iter(|| {
                        let network_graph = NetworkGraph::new(Network::Bitcoin, &logger);
-                       let rapid_sync = RapidGossipSync::new(&network_graph);
+                       let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                        let sync_result = rapid_sync.sync_network_graph_with_file_path("./res/full_graph.lngossip");
                        if let Err(crate::error::GraphSyncError::DecodeError(DecodeError::Io(io_error))) = &sync_result {
                                let error_string = format!("Input file lightning-rapid-gossip-sync/res/full_graph.lngossip is missing! Download it from https://bitcoin.ninja/ldk-compressed_graph-bc08df7542-2022-05-05.bin\n\n{:?}", io_error);
index 4b6de04c6556a5ef302d4dced78f1b833f5e5380..8a52d234388052a38b540bfa30a6621387007ec5 100644 (file)
@@ -10,6 +10,7 @@ use lightning::ln::msgs::{
 };
 use lightning::routing::gossip::NetworkGraph;
 use lightning::util::logger::Logger;
+use lightning::{log_warn, log_trace, log_given_level};
 use lightning::util::ser::{BigSize, Readable};
 use lightning::io;
 
@@ -120,6 +121,7 @@ impl<NG: Deref<Target=NetworkGraph<L>>, L: Deref> RapidGossipSync<NG, L> where L
                                if let ErrorAction::IgnoreDuplicateGossip = lightning_error.action {
                                        // everything is fine, just a duplicate channel announcement
                                } else {
+                                       log_warn!(self.logger, "Failed to process channel announcement: {:?}", lightning_error);
                                        return Err(lightning_error.into());
                                }
                        }
@@ -169,24 +171,19 @@ impl<NG: Deref<Target=NetworkGraph<L>>, L: Deref> RapidGossipSync<NG, L> where L
                        if (channel_flags & 0b_1000_0000) != 0 {
                                // incremental update, field flags will indicate mutated values
                                let read_only_network_graph = network_graph.read_only();
-                               if let Some(channel) = read_only_network_graph
-                                       .channels()
-                                       .get(&short_channel_id) {
-
-                                       let directional_info = channel
-                                               .get_directional_info(channel_flags)
-                                               .ok_or(LightningError {
-                                                       err: "Couldn't find previous directional data for update".to_owned(),
-                                                       action: ErrorAction::IgnoreError,
-                                               })?;
-
+                               if let Some(directional_info) =
+                                       read_only_network_graph.channels().get(&short_channel_id)
+                                       .and_then(|channel| channel.get_directional_info(channel_flags))
+                               {
                                        synthetic_update.cltv_expiry_delta = directional_info.cltv_expiry_delta;
                                        synthetic_update.htlc_minimum_msat = directional_info.htlc_minimum_msat;
                                        synthetic_update.htlc_maximum_msat = directional_info.htlc_maximum_msat;
                                        synthetic_update.fee_base_msat = directional_info.fees.base_msat;
                                        synthetic_update.fee_proportional_millionths = directional_info.fees.proportional_millionths;
-
                                } else {
+                                       log_trace!(self.logger,
+                                               "Skipping application of channel update for chan {} with flags {} as original data is missing.",
+                                               short_channel_id, channel_flags);
                                        skip_update_for_unknown_channel = true;
                                }
                        };
@@ -223,7 +220,9 @@ impl<NG: Deref<Target=NetworkGraph<L>>, L: Deref> RapidGossipSync<NG, L> where L
                        match network_graph.update_channel_unsigned(&synthetic_update) {
                                Ok(_) => {},
                                Err(LightningError { action: ErrorAction::IgnoreDuplicateGossip, .. }) => {},
-                               Err(LightningError { action: ErrorAction::IgnoreAndLog(_), .. }) => {},
+                               Err(LightningError { action: ErrorAction::IgnoreAndLog(level), err }) => {
+                                       log_given_level!(self.logger, level, "Failed to apply channel update: {:?}", err);
+                               },
                                Err(LightningError { action: ErrorAction::IgnoreError, .. }) => {},
                                Err(e) => return Err(e.into()),
                        }
@@ -287,7 +286,7 @@ mod tests {
                        0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 58, 85, 116, 216, 255, 2, 68, 226, 0, 6, 11, 0, 1, 24, 0,
                        0, 3, 232, 0, 0, 0,
                ];
-               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                let update_result = rapid_sync.update_network_graph(&example_input[..]);
                assert!(update_result.is_err());
                if let Err(GraphSyncError::DecodeError(DecodeError::ShortRead)) = update_result {
@@ -312,7 +311,7 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                let update_result = rapid_sync.update_network_graph(&incremental_update_input[..]);
                assert!(update_result.is_ok());
        }
@@ -340,17 +339,8 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let rapid_sync = RapidGossipSync::new(&network_graph);
-               let update_result = rapid_sync.update_network_graph(&announced_update_input[..]);
-               assert!(update_result.is_err());
-               if let Err(GraphSyncError::LightningError(lightning_error)) = update_result {
-                       assert_eq!(
-                               lightning_error.err,
-                               "Couldn't find previous directional data for update"
-                       );
-               } else {
-                       panic!("Unexpected update result: {:?}", update_result)
-               }
+               let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
+               rapid_sync.update_network_graph(&announced_update_input[..]).unwrap();
        }
 
        #[test]
@@ -376,7 +366,7 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                let initialization_result = rapid_sync.update_network_graph(&initialization_input[..]);
                if initialization_result.is_err() {
                        panic!(
@@ -405,16 +395,7 @@ mod tests {
                        0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 8, 153, 192, 0, 2, 27, 0, 0, 136, 0, 0, 0, 221, 255, 2,
                        68, 226, 0, 6, 11, 0, 1, 128,
                ];
-               let update_result = rapid_sync.update_network_graph(&opposite_direction_incremental_update_input[..]);
-               assert!(update_result.is_err());
-               if let Err(GraphSyncError::LightningError(lightning_error)) = update_result {
-                       assert_eq!(
-                               lightning_error.err,
-                               "Couldn't find previous directional data for update"
-                       );
-               } else {
-                       panic!("Unexpected update result: {:?}", update_result)
-               }
+               rapid_sync.update_network_graph(&opposite_direction_incremental_update_input[..]).unwrap();
        }
 
        #[test]
@@ -442,7 +423,7 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                let initialization_result = rapid_sync.update_network_graph(&initialization_input[..]);
                assert!(initialization_result.is_ok());
 
@@ -501,7 +482,7 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                let initialization_result = rapid_sync.update_network_graph(&initialization_input[..]);
                assert!(initialization_result.is_ok());
 
@@ -526,7 +507,7 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                let update_result = rapid_sync.update_network_graph(&VALID_RGS_BINARY);
                if update_result.is_err() {
                        panic!("Unexpected update result: {:?}", update_result)
@@ -557,7 +538,7 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                // this is mostly for checking uint underflow issues before the fuzzer does
                let update_result = rapid_sync.update_network_graph_no_std(&VALID_RGS_BINARY, Some(0));
                assert!(update_result.is_ok());
@@ -576,7 +557,7 @@ mod tests {
                        let network_graph = NetworkGraph::new(Network::Bitcoin, &logger);
                        assert_eq!(network_graph.read_only().channels().len(), 0);
 
-                       let rapid_sync = RapidGossipSync::new(&network_graph);
+                       let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                        let update_result = rapid_sync.update_network_graph_no_std(&VALID_RGS_BINARY, Some(latest_succeeding_time));
                        assert!(update_result.is_ok());
                        assert_eq!(network_graph.read_only().channels().len(), 2);
@@ -586,7 +567,7 @@ mod tests {
                        let network_graph = NetworkGraph::new(Network::Bitcoin, &logger);
                        assert_eq!(network_graph.read_only().channels().len(), 0);
 
-                       let rapid_sync = RapidGossipSync::new(&network_graph);
+                       let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                        let update_result = rapid_sync.update_network_graph_no_std(&VALID_RGS_BINARY, Some(earliest_failing_time));
                        assert!(update_result.is_err());
                        if let Err(GraphSyncError::LightningError(lightning_error)) = update_result {
@@ -622,7 +603,7 @@ mod tests {
 
                let logger = TestLogger::new();
                let network_graph = NetworkGraph::new(Network::Bitcoin, &logger);
-               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                let update_result = rapid_sync.update_network_graph(&unknown_version_input[..]);
 
                assert!(update_result.is_err());
index 8b281db6030dba2f681720a91eba819918d20d3f..cd18ffad12d385747da26ee6e7222f9dc457b5f3 100644 (file)
@@ -37,7 +37,7 @@ use crate::ln::{PaymentHash, PaymentPreimage};
 use crate::ln::msgs::DecodeError;
 use crate::ln::chan_utils;
 use crate::ln::chan_utils::{CounterpartyCommitmentSecrets, HTLCOutputInCommitment, HTLCClaim, ChannelTransactionParameters, HolderCommitmentTransaction};
-use crate::ln::channelmanager::HTLCSource;
+use crate::ln::channelmanager::{HTLCSource, SentHTLCId};
 use crate::chain;
 use crate::chain::{BestBlock, WatchedOutput};
 use crate::chain::chaininterface::{BroadcasterInterface, FeeEstimator, LowerBoundedFeeEstimator};
@@ -60,7 +60,7 @@ use core::{cmp, mem};
 use crate::io::{self, Error};
 use core::convert::TryInto;
 use core::ops::Deref;
-use crate::sync::Mutex;
+use crate::sync::{Mutex, LockTestExt};
 
 /// An update generated by the underlying channel itself which contains some new information the
 /// [`ChannelMonitor`] should be made aware of.
@@ -494,6 +494,7 @@ pub(crate) enum ChannelMonitorUpdateStep {
        LatestHolderCommitmentTXInfo {
                commitment_tx: HolderCommitmentTransaction,
                htlc_outputs: Vec<(HTLCOutputInCommitment, Option<Signature>, Option<HTLCSource>)>,
+               claimed_htlcs: Vec<(SentHTLCId, PaymentPreimage)>,
        },
        LatestCounterpartyCommitmentTXInfo {
                commitment_txid: Txid,
@@ -536,6 +537,7 @@ impl ChannelMonitorUpdateStep {
 impl_writeable_tlv_based_enum_upgradable!(ChannelMonitorUpdateStep,
        (0, LatestHolderCommitmentTXInfo) => {
                (0, commitment_tx, required),
+               (1, claimed_htlcs, vec_type),
                (2, htlc_outputs, vec_type),
        },
        (1, LatestCounterpartyCommitmentTXInfo) => {
@@ -750,6 +752,8 @@ pub(crate) struct ChannelMonitorImpl<Signer: WriteableEcdsaChannelSigner> {
        /// Serialized to disk but should generally not be sent to Watchtowers.
        counterparty_hash_commitment_number: HashMap<PaymentHash, u64>,
 
+       counterparty_fulfilled_htlcs: HashMap<SentHTLCId, PaymentPreimage>,
+
        // We store two holder commitment transactions to avoid any race conditions where we may update
        // some monitors (potentially on watchtowers) but then fail to update others, resulting in the
        // various monitors for one channel being out of sync, and us broadcasting a holder
@@ -851,9 +855,13 @@ pub type TransactionOutputs = (Txid, Vec<(u32, TxOut)>);
 
 impl<Signer: WriteableEcdsaChannelSigner> PartialEq for ChannelMonitor<Signer> where Signer: PartialEq {
        fn eq(&self, other: &Self) -> bool {
-               let inner = self.inner.lock().unwrap();
-               let other = other.inner.lock().unwrap();
-               inner.eq(&other)
+               // We need some kind of total lockorder. Absent a better idea, we sort by position in
+               // memory and take locks in that order (assuming that we can't move within memory while a
+               // lock is held).
+               let ord = ((self as *const _) as usize) < ((other as *const _) as usize);
+               let a = if ord { self.inner.unsafe_well_ordered_double_lock_self() } else { other.inner.unsafe_well_ordered_double_lock_self() };
+               let b = if ord { other.inner.unsafe_well_ordered_double_lock_self() } else { self.inner.unsafe_well_ordered_double_lock_self() };
+               a.eq(&b)
        }
 }
 
@@ -1029,6 +1037,7 @@ impl<Signer: WriteableEcdsaChannelSigner> Writeable for ChannelMonitorImpl<Signe
                        (9, self.counterparty_node_id, option),
                        (11, self.confirmed_commitment_tx_counterparty_output, option),
                        (13, self.spendable_txids_confirmed, vec_type),
+                       (15, self.counterparty_fulfilled_htlcs, required),
                });
 
                Ok(())
@@ -1116,6 +1125,7 @@ impl<Signer: WriteableEcdsaChannelSigner> ChannelMonitor<Signer> {
                        counterparty_claimable_outpoints: HashMap::new(),
                        counterparty_commitment_txn_on_chain: HashMap::new(),
                        counterparty_hash_commitment_number: HashMap::new(),
+                       counterparty_fulfilled_htlcs: HashMap::new(),
 
                        prev_holder_signed_commitment_tx: None,
                        current_holder_commitment_tx: holder_commitment_tx,
@@ -1170,7 +1180,7 @@ impl<Signer: WriteableEcdsaChannelSigner> ChannelMonitor<Signer> {
                &self, holder_commitment_tx: HolderCommitmentTransaction,
                htlc_outputs: Vec<(HTLCOutputInCommitment, Option<Signature>, Option<HTLCSource>)>,
        ) -> Result<(), ()> {
-               self.inner.lock().unwrap().provide_latest_holder_commitment_tx(holder_commitment_tx, htlc_outputs).map_err(|_| ())
+               self.inner.lock().unwrap().provide_latest_holder_commitment_tx(holder_commitment_tx, htlc_outputs, &Vec::new()).map_err(|_| ())
        }
 
        /// This is used to provide payment preimage(s) out-of-band during startup without updating the
@@ -1806,9 +1816,10 @@ impl<Signer: WriteableEcdsaChannelSigner> ChannelMonitor<Signer> {
        /// `ChannelMonitor`. This is used to determine if an HTLC was removed from the channel prior
        /// to the `ChannelManager` having been persisted.
        ///
-       /// This is similar to [`Self::get_pending_outbound_htlcs`] except it includes HTLCs which were
-       /// resolved by this `ChannelMonitor`.
-       pub(crate) fn get_all_current_outbound_htlcs(&self) -> HashMap<HTLCSource, HTLCOutputInCommitment> {
+       /// This is similar to [`Self::get_pending_or_resolved_outbound_htlcs`] except it includes
+       /// HTLCs which were resolved on-chain (i.e. where the final HTLC resolution was done by an
+       /// event from this `ChannelMonitor`).
+       pub(crate) fn get_all_current_outbound_htlcs(&self) -> HashMap<HTLCSource, (HTLCOutputInCommitment, Option<PaymentPreimage>)> {
                let mut res = HashMap::new();
                // Just examine the available counterparty commitment transactions. See docs on
                // `fail_unbroadcast_htlcs`, below, for justification.
@@ -1818,7 +1829,8 @@ impl<Signer: WriteableEcdsaChannelSigner> ChannelMonitor<Signer> {
                                if let Some(ref latest_outpoints) = us.counterparty_claimable_outpoints.get($txid) {
                                        for &(ref htlc, ref source_option) in latest_outpoints.iter() {
                                                if let &Some(ref source) = source_option {
-                                                       res.insert((**source).clone(), htlc.clone());
+                                                       res.insert((**source).clone(), (htlc.clone(),
+                                                               us.counterparty_fulfilled_htlcs.get(&SentHTLCId::from_source(source)).cloned()));
                                                }
                                        }
                                }
@@ -1833,9 +1845,14 @@ impl<Signer: WriteableEcdsaChannelSigner> ChannelMonitor<Signer> {
                res
        }
 
-       /// Gets the set of outbound HTLCs which are pending resolution in this channel.
+       /// Gets the set of outbound HTLCs which are pending resolution in this channel or which were
+       /// resolved with a preimage from our counterparty.
+       ///
        /// This is used to reconstruct pending outbound payments on restart in the ChannelManager.
-       pub(crate) fn get_pending_outbound_htlcs(&self) -> HashMap<HTLCSource, HTLCOutputInCommitment> {
+       ///
+       /// Currently, the preimage is unused, however if it is present in the relevant internal state
+       /// an HTLC is always included even if it has been resolved.
+       pub(crate) fn get_pending_or_resolved_outbound_htlcs(&self) -> HashMap<HTLCSource, (HTLCOutputInCommitment, Option<PaymentPreimage>)> {
                let us = self.inner.lock().unwrap();
                // We're only concerned with the confirmation count of HTLC transactions, and don't
                // actually care how many confirmations a commitment transaction may or may not have. Thus,
@@ -1883,8 +1900,10 @@ impl<Signer: WriteableEcdsaChannelSigner> ChannelMonitor<Signer> {
                                                                Some(commitment_tx_output_idx) == htlc.transaction_output_index
                                                        } else { false }
                                                });
-                                               if !htlc_update_confd {
-                                                       res.insert(source.clone(), htlc.clone());
+                                               let counterparty_resolved_preimage_opt =
+                                                       us.counterparty_fulfilled_htlcs.get(&SentHTLCId::from_source(source)).cloned();
+                                               if !htlc_update_confd || counterparty_resolved_preimage_opt.is_some() {
+                                                       res.insert(source.clone(), (htlc.clone(), counterparty_resolved_preimage_opt));
                                                }
                                        }
                                }
@@ -1966,6 +1985,9 @@ macro_rules! fail_unbroadcast_htlcs {
                                                                }
                                                        }
                                                        if matched_htlc { continue; }
+                                                       if $self.counterparty_fulfilled_htlcs.get(&SentHTLCId::from_source(source)).is_some() {
+                                                               continue;
+                                                       }
                                                        $self.onchain_events_awaiting_threshold_conf.retain(|ref entry| {
                                                                if entry.height != $commitment_tx_conf_height { return true; }
                                                                match entry.event {
@@ -2037,8 +2059,23 @@ impl<Signer: WriteableEcdsaChannelSigner> ChannelMonitorImpl<Signer> {
                // Prune HTLCs from the previous counterparty commitment tx so we don't generate failure/fulfill
                // events for now-revoked/fulfilled HTLCs.
                if let Some(txid) = self.prev_counterparty_commitment_txid.take() {
-                       for &mut (_, ref mut source) in self.counterparty_claimable_outpoints.get_mut(&txid).unwrap() {
-                               *source = None;
+                       if self.current_counterparty_commitment_txid.unwrap() != txid {
+                               let cur_claimables = self.counterparty_claimable_outpoints.get(
+                                       &self.current_counterparty_commitment_txid.unwrap()).unwrap();
+                               for (_, ref source_opt) in self.counterparty_claimable_outpoints.get(&txid).unwrap() {
+                                       if let Some(source) = source_opt {
+                                               if !cur_claimables.iter()
+                                                       .any(|(_, cur_source_opt)| cur_source_opt == source_opt)
+                                               {
+                                                       self.counterparty_fulfilled_htlcs.remove(&SentHTLCId::from_source(source));
+                                               }
+                                       }
+                               }
+                               for &mut (_, ref mut source_opt) in self.counterparty_claimable_outpoints.get_mut(&txid).unwrap() {
+                                       *source_opt = None;
+                               }
+                       } else {
+                               assert!(cfg!(fuzzing), "Commitment txids are unique outside of fuzzing, where hashes can collide");
                        }
                }
 
@@ -2123,28 +2160,37 @@ impl<Signer: WriteableEcdsaChannelSigner> ChannelMonitorImpl<Signer> {
        /// is important that any clones of this channel monitor (including remote clones) by kept
        /// up-to-date as our holder commitment transaction is updated.
        /// Panics if set_on_holder_tx_csv has never been called.
-       fn provide_latest_holder_commitment_tx(&mut self, holder_commitment_tx: HolderCommitmentTransaction, htlc_outputs: Vec<(HTLCOutputInCommitment, Option<Signature>, Option<HTLCSource>)>) -> Result<(), &'static str> {
-               // block for Rust 1.34 compat
-               let mut new_holder_commitment_tx = {
-                       let trusted_tx = holder_commitment_tx.trust();
-                       let txid = trusted_tx.txid();
-                       let tx_keys = trusted_tx.keys();
-                       self.current_holder_commitment_number = trusted_tx.commitment_number();
-                       HolderSignedTx {
-                               txid,
-                               revocation_key: tx_keys.revocation_key,
-                               a_htlc_key: tx_keys.broadcaster_htlc_key,
-                               b_htlc_key: tx_keys.countersignatory_htlc_key,
-                               delayed_payment_key: tx_keys.broadcaster_delayed_payment_key,
-                               per_commitment_point: tx_keys.per_commitment_point,
-                               htlc_outputs,
-                               to_self_value_sat: holder_commitment_tx.to_broadcaster_value_sat(),
-                               feerate_per_kw: trusted_tx.feerate_per_kw(),
-                       }
+       fn provide_latest_holder_commitment_tx(&mut self, holder_commitment_tx: HolderCommitmentTransaction, htlc_outputs: Vec<(HTLCOutputInCommitment, Option<Signature>, Option<HTLCSource>)>, claimed_htlcs: &[(SentHTLCId, PaymentPreimage)]) -> Result<(), &'static str> {
+               let trusted_tx = holder_commitment_tx.trust();
+               let txid = trusted_tx.txid();
+               let tx_keys = trusted_tx.keys();
+               self.current_holder_commitment_number = trusted_tx.commitment_number();
+               let mut new_holder_commitment_tx = HolderSignedTx {
+                       txid,
+                       revocation_key: tx_keys.revocation_key,
+                       a_htlc_key: tx_keys.broadcaster_htlc_key,
+                       b_htlc_key: tx_keys.countersignatory_htlc_key,
+                       delayed_payment_key: tx_keys.broadcaster_delayed_payment_key,
+                       per_commitment_point: tx_keys.per_commitment_point,
+                       htlc_outputs,
+                       to_self_value_sat: holder_commitment_tx.to_broadcaster_value_sat(),
+                       feerate_per_kw: trusted_tx.feerate_per_kw(),
                };
                self.onchain_tx_handler.provide_latest_holder_tx(holder_commitment_tx);
                mem::swap(&mut new_holder_commitment_tx, &mut self.current_holder_commitment_tx);
                self.prev_holder_signed_commitment_tx = Some(new_holder_commitment_tx);
+               for (claimed_htlc_id, claimed_preimage) in claimed_htlcs {
+                       #[cfg(debug_assertions)] {
+                               let cur_counterparty_htlcs = self.counterparty_claimable_outpoints.get(
+                                               &self.current_counterparty_commitment_txid.unwrap()).unwrap();
+                               assert!(cur_counterparty_htlcs.iter().any(|(_, source_opt)| {
+                                       if let Some(source) = source_opt {
+                                               SentHTLCId::from_source(source) == *claimed_htlc_id
+                                       } else { false }
+                               }));
+                       }
+                       self.counterparty_fulfilled_htlcs.insert(*claimed_htlc_id, *claimed_preimage);
+               }
                if self.holder_tx_signed {
                        return Err("Latest holder commitment signed has already been signed, update is rejected");
                }
@@ -2239,10 +2285,10 @@ impl<Signer: WriteableEcdsaChannelSigner> ChannelMonitorImpl<Signer> {
                let bounded_fee_estimator = LowerBoundedFeeEstimator::new(&*fee_estimator);
                for update in updates.updates.iter() {
                        match update {
-                               ChannelMonitorUpdateStep::LatestHolderCommitmentTXInfo { commitment_tx, htlc_outputs } => {
+                               ChannelMonitorUpdateStep::LatestHolderCommitmentTXInfo { commitment_tx, htlc_outputs, claimed_htlcs } => {
                                        log_trace!(logger, "Updating ChannelMonitor with latest holder commitment transaction info");
                                        if self.lockdown_from_offchain { panic!(); }
-                                       if let Err(e) = self.provide_latest_holder_commitment_tx(commitment_tx.clone(), htlc_outputs.clone()) {
+                                       if let Err(e) = self.provide_latest_holder_commitment_tx(commitment_tx.clone(), htlc_outputs.clone(), &claimed_htlcs) {
                                                log_error!(logger, "Providing latest holder commitment transaction failed/was refused:");
                                                log_error!(logger, "    {}", e);
                                                ret = Err(());
@@ -3864,6 +3910,7 @@ impl<'a, 'b, ES: EntropySource, SP: SignerProvider> ReadableArgs<(&'a ES, &'b SP
                let mut counterparty_node_id = None;
                let mut confirmed_commitment_tx_counterparty_output = None;
                let mut spendable_txids_confirmed = Some(Vec::new());
+               let mut counterparty_fulfilled_htlcs = Some(HashMap::new());
                read_tlv_fields!(reader, {
                        (1, funding_spend_confirmed, option),
                        (3, htlcs_resolved_on_chain, vec_type),
@@ -3872,6 +3919,7 @@ impl<'a, 'b, ES: EntropySource, SP: SignerProvider> ReadableArgs<(&'a ES, &'b SP
                        (9, counterparty_node_id, option),
                        (11, confirmed_commitment_tx_counterparty_output, option),
                        (13, spendable_txids_confirmed, vec_type),
+                       (15, counterparty_fulfilled_htlcs, option),
                });
 
                Ok((best_block.block_hash(), ChannelMonitor::from_impl(ChannelMonitorImpl {
@@ -3900,6 +3948,7 @@ impl<'a, 'b, ES: EntropySource, SP: SignerProvider> ReadableArgs<(&'a ES, &'b SP
                        counterparty_claimable_outpoints,
                        counterparty_commitment_txn_on_chain,
                        counterparty_hash_commitment_number,
+                       counterparty_fulfilled_htlcs: counterparty_fulfilled_htlcs.unwrap(),
 
                        prev_holder_signed_commitment_tx,
                        current_holder_commitment_tx,
@@ -4066,11 +4115,13 @@ mod tests {
        fn test_prune_preimages() {
                let secp_ctx = Secp256k1::new();
                let logger = Arc::new(TestLogger::new());
-               let broadcaster = Arc::new(TestBroadcaster{txn_broadcasted: Mutex::new(Vec::new()), blocks: Arc::new(Mutex::new(Vec::new()))});
+               let broadcaster = Arc::new(TestBroadcaster {
+                       txn_broadcasted: Mutex::new(Vec::new()),
+                       blocks: Arc::new(Mutex::new(Vec::new()))
+               });
                let fee_estimator = TestFeeEstimator { sat_per_kw: Mutex::new(253) };
 
                let dummy_key = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&[42; 32]).unwrap());
-               let dummy_tx = Transaction { version: 0, lock_time: PackedLockTime::ZERO, input: Vec::new(), output: Vec::new() };
 
                let mut preimages = Vec::new();
                {
@@ -4160,11 +4211,10 @@ mod tests {
                                                  HolderCommitmentTransaction::dummy(), best_block, dummy_key);
 
                monitor.provide_latest_holder_commitment_tx(HolderCommitmentTransaction::dummy(), preimages_to_holder_htlcs!(preimages[0..10])).unwrap();
-               let dummy_txid = dummy_tx.txid();
-               monitor.provide_latest_counterparty_commitment_tx(dummy_txid, preimages_slice_to_htlc_outputs!(preimages[5..15]), 281474976710655, dummy_key, &logger);
-               monitor.provide_latest_counterparty_commitment_tx(dummy_txid, preimages_slice_to_htlc_outputs!(preimages[15..20]), 281474976710654, dummy_key, &logger);
-               monitor.provide_latest_counterparty_commitment_tx(dummy_txid, preimages_slice_to_htlc_outputs!(preimages[17..20]), 281474976710653, dummy_key, &logger);
-               monitor.provide_latest_counterparty_commitment_tx(dummy_txid, preimages_slice_to_htlc_outputs!(preimages[18..20]), 281474976710652, dummy_key, &logger);
+               monitor.provide_latest_counterparty_commitment_tx(Txid::from_inner(Sha256::hash(b"1").into_inner()),
+                       preimages_slice_to_htlc_outputs!(preimages[5..15]), 281474976710655, dummy_key, &logger);
+               monitor.provide_latest_counterparty_commitment_tx(Txid::from_inner(Sha256::hash(b"2").into_inner()),
+                       preimages_slice_to_htlc_outputs!(preimages[15..20]), 281474976710654, dummy_key, &logger);
                for &(ref preimage, ref hash) in preimages.iter() {
                        let bounded_fee_estimator = LowerBoundedFeeEstimator::new(&fee_estimator);
                        monitor.provide_payment_preimage(hash, preimage, &broadcaster, &bounded_fee_estimator, &logger);
@@ -4178,6 +4228,9 @@ mod tests {
                test_preimages_exist!(&preimages[0..10], monitor);
                test_preimages_exist!(&preimages[15..20], monitor);
 
+               monitor.provide_latest_counterparty_commitment_tx(Txid::from_inner(Sha256::hash(b"3").into_inner()),
+                       preimages_slice_to_htlc_outputs!(preimages[17..20]), 281474976710653, dummy_key, &logger);
+
                // Now provide a further secret, pruning preimages 15-17
                secret[0..32].clone_from_slice(&hex::decode("c7518c8ae4660ed02894df8976fa1a3659c1a8b4b5bec0c4b872abeba4cb8964").unwrap());
                monitor.provide_secret(281474976710654, secret.clone()).unwrap();
@@ -4185,6 +4238,9 @@ mod tests {
                test_preimages_exist!(&preimages[0..10], monitor);
                test_preimages_exist!(&preimages[17..20], monitor);
 
+               monitor.provide_latest_counterparty_commitment_tx(Txid::from_inner(Sha256::hash(b"4").into_inner()),
+                       preimages_slice_to_htlc_outputs!(preimages[18..20]), 281474976710652, dummy_key, &logger);
+
                // Now update holder commitment tx info, pruning only element 18 as we still care about the
                // previous commitment tx's preimages too
                monitor.provide_latest_holder_commitment_tx(HolderCommitmentTransaction::dummy(), preimages_to_holder_htlcs!(preimages[0..5])).unwrap();
index 1532884fa605c5ba363720aa6701b602c83d92a2..18e461f61a3c0da0ca74c8579542e7571bc39bf2 100644 (file)
@@ -108,12 +108,13 @@ fn test_monitor_and_persister_update_fail() {
                blocks: Arc::new(Mutex::new(vec![(genesis_block(Network::Testnet), 200); 200])),
        };
        let chain_mon = {
-               let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap();
-               let mut w = test_utils::TestVecWriter(Vec::new());
-               monitor.write(&mut w).unwrap();
-               let new_monitor = <(BlockHash, ChannelMonitor<EnforcingSigner>)>::read(
-                       &mut io::Cursor::new(&w.0), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1;
-               assert!(new_monitor == *monitor);
+               let new_monitor = {
+                       let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap();
+                       let new_monitor = <(BlockHash, ChannelMonitor<EnforcingSigner>)>::read(
+                               &mut io::Cursor::new(&monitor.encode()), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1;
+                       assert!(new_monitor == *monitor);
+                       new_monitor
+               };
                let chain_mon = test_utils::TestChainMonitor::new(Some(&chain_source), &tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister, &node_cfgs[0].keys_manager);
                assert_eq!(chain_mon.watch_channel(outpoint, new_monitor), ChannelMonitorUpdateStatus::Completed);
                chain_mon
@@ -1426,9 +1427,11 @@ fn monitor_failed_no_reestablish_response() {
        {
                let mut node_0_per_peer_lock;
                let mut node_0_peer_state_lock;
+               get_channel_ref!(nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, channel_id).announcement_sigs_state = AnnouncementSigsState::PeerReceived;
+       }
+       {
                let mut node_1_per_peer_lock;
                let mut node_1_peer_state_lock;
-               get_channel_ref!(nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, channel_id).announcement_sigs_state = AnnouncementSigsState::PeerReceived;
                get_channel_ref!(nodes[1], nodes[0], node_1_per_peer_lock, node_1_peer_state_lock, channel_id).announcement_sigs_state = AnnouncementSigsState::PeerReceived;
        }
 
index 2b9920aa2bae50879764b632f6f071588a14f89d..842c91180868c3d21a95feff20b303aba7c5288d 100644 (file)
@@ -27,7 +27,7 @@ use crate::ln::features::{ChannelTypeFeatures, InitFeatures};
 use crate::ln::msgs;
 use crate::ln::msgs::{DecodeError, OptionalField, DataLossProtect};
 use crate::ln::script::{self, ShutdownScript};
-use crate::ln::channelmanager::{self, CounterpartyForwardingInfo, PendingHTLCStatus, HTLCSource, HTLCFailureMsg, PendingHTLCInfo, RAACommitmentOrder, BREAKDOWN_TIMEOUT, MIN_CLTV_EXPIRY_DELTA, MAX_LOCAL_BREAKDOWN_TIMEOUT};
+use crate::ln::channelmanager::{self, CounterpartyForwardingInfo, PendingHTLCStatus, HTLCSource, SentHTLCId, HTLCFailureMsg, PendingHTLCInfo, RAACommitmentOrder, BREAKDOWN_TIMEOUT, MIN_CLTV_EXPIRY_DELTA, MAX_LOCAL_BREAKDOWN_TIMEOUT};
 use crate::ln::chan_utils::{CounterpartyCommitmentSecrets, TxCreationKeys, HTLCOutputInCommitment, htlc_success_tx_weight, htlc_timeout_tx_weight, make_funding_redeemscript, ChannelPublicKeys, CommitmentTransaction, HolderCommitmentTransaction, ChannelTransactionParameters, CounterpartyChannelTransactionParameters, MAX_HTLCS, get_commitment_transaction_number_obscure_factor, ClosingTransaction};
 use crate::ln::chan_utils;
 use crate::ln::onion_utils::HTLCFailReason;
@@ -192,6 +192,7 @@ enum OutboundHTLCState {
 
 #[derive(Clone)]
 enum OutboundHTLCOutcome {
+       /// LDK version 0.0.105+ will always fill in the preimage here.
        Success(Option<PaymentPreimage>),
        Failure(HTLCFailReason),
 }
@@ -3159,15 +3160,6 @@ impl<Signer: WriteableEcdsaChannelSigner> Channel<Signer> {
                        }
                }
 
-               self.latest_monitor_update_id += 1;
-               let mut monitor_update = ChannelMonitorUpdate {
-                       update_id: self.latest_monitor_update_id,
-                       updates: vec![ChannelMonitorUpdateStep::LatestHolderCommitmentTXInfo {
-                               commitment_tx: holder_commitment_tx,
-                               htlc_outputs: htlcs_and_sigs
-                       }]
-               };
-
                for htlc in self.pending_inbound_htlcs.iter_mut() {
                        let new_forward = if let &InboundHTLCState::RemoteAnnounced(ref forward_info) = &htlc.state {
                                Some(forward_info.clone())
@@ -3179,6 +3171,7 @@ impl<Signer: WriteableEcdsaChannelSigner> Channel<Signer> {
                                need_commitment = true;
                        }
                }
+               let mut claimed_htlcs = Vec::new();
                for htlc in self.pending_outbound_htlcs.iter_mut() {
                        if let &mut OutboundHTLCState::RemoteRemoved(ref mut outcome) = &mut htlc.state {
                                log_trace!(logger, "Updating HTLC {} to AwaitingRemoteRevokeToRemove due to commitment_signed in channel {}.",
@@ -3186,11 +3179,30 @@ impl<Signer: WriteableEcdsaChannelSigner> Channel<Signer> {
                                // Grab the preimage, if it exists, instead of cloning
                                let mut reason = OutboundHTLCOutcome::Success(None);
                                mem::swap(outcome, &mut reason);
+                               if let OutboundHTLCOutcome::Success(Some(preimage)) = reason {
+                                       // If a user (a) receives an HTLC claim using LDK 0.0.104 or before, then (b)
+                                       // upgrades to LDK 0.0.114 or later before the HTLC is fully resolved, we could
+                                       // have a `Success(None)` reason. In this case we could forget some HTLC
+                                       // claims, but such an upgrade is unlikely and including claimed HTLCs here
+                                       // fixes a bug which the user was exposed to on 0.0.104 when they started the
+                                       // claim anyway.
+                                       claimed_htlcs.push((SentHTLCId::from_source(&htlc.source), preimage));
+                               }
                                htlc.state = OutboundHTLCState::AwaitingRemoteRevokeToRemove(reason);
                                need_commitment = true;
                        }
                }
 
+               self.latest_monitor_update_id += 1;
+               let mut monitor_update = ChannelMonitorUpdate {
+                       update_id: self.latest_monitor_update_id,
+                       updates: vec![ChannelMonitorUpdateStep::LatestHolderCommitmentTXInfo {
+                               commitment_tx: holder_commitment_tx,
+                               htlc_outputs: htlcs_and_sigs,
+                               claimed_htlcs,
+                       }]
+               };
+
                self.cur_holder_commitment_transaction_number -= 1;
                // Note that if we need_commitment & !AwaitingRemoteRevoke we'll call
                // build_commitment_no_status_check() next which will reset this to RAAFirst.
index f07dc86f34d438b184f4eef2ad2aa7d932f8ff59..82b572cda84c77e5610ce4ad7a90dfb8a2d05de7 100644 (file)
@@ -234,6 +234,36 @@ impl Readable for InterceptId {
                Ok(InterceptId(buf))
        }
 }
+
+#[derive(Clone, Copy, PartialEq, Eq, Hash)]
+/// Uniquely describes an HTLC by its source. Just the guaranteed-unique subset of [`HTLCSource`].
+pub(crate) enum SentHTLCId {
+       PreviousHopData { short_channel_id: u64, htlc_id: u64 },
+       OutboundRoute { session_priv: SecretKey },
+}
+impl SentHTLCId {
+       pub(crate) fn from_source(source: &HTLCSource) -> Self {
+               match source {
+                       HTLCSource::PreviousHopData(hop_data) => Self::PreviousHopData {
+                               short_channel_id: hop_data.short_channel_id,
+                               htlc_id: hop_data.htlc_id,
+                       },
+                       HTLCSource::OutboundRoute { session_priv, .. } =>
+                               Self::OutboundRoute { session_priv: *session_priv },
+               }
+       }
+}
+impl_writeable_tlv_based_enum!(SentHTLCId,
+       (0, PreviousHopData) => {
+               (0, short_channel_id, required),
+               (2, htlc_id, required),
+       },
+       (2, OutboundRoute) => {
+               (0, session_priv, required),
+       };
+);
+
+
 /// Tracks the inbound corresponding to an outbound HTLC
 #[allow(clippy::derive_hash_xor_eq)] // Our Hash is faithful to the data, we just don't have SecretKey::hash
 #[derive(Clone, PartialEq, Eq)]
@@ -1417,7 +1447,7 @@ macro_rules! emit_channel_ready_event {
 }
 
 macro_rules! handle_monitor_update_completion {
-       ($self: ident, $update_id: expr, $peer_state_lock: expr, $peer_state: expr, $chan: expr) => { {
+       ($self: ident, $update_id: expr, $peer_state_lock: expr, $peer_state: expr, $per_peer_state_lock: expr, $chan: expr) => { {
                let mut updates = $chan.monitor_updating_restored(&$self.logger,
                        &$self.node_signer, $self.genesis_hash, &$self.default_configuration,
                        $self.best_block.read().unwrap().height());
@@ -1450,6 +1480,7 @@ macro_rules! handle_monitor_update_completion {
 
                let channel_id = $chan.channel_id();
                core::mem::drop($peer_state_lock);
+               core::mem::drop($per_peer_state_lock);
 
                $self.handle_monitor_update_completion_actions(update_actions);
 
@@ -1465,7 +1496,7 @@ macro_rules! handle_monitor_update_completion {
 }
 
 macro_rules! handle_new_monitor_update {
-       ($self: ident, $update_res: expr, $update_id: expr, $peer_state_lock: expr, $peer_state: expr, $chan: expr, MANUALLY_REMOVING, $remove: expr) => { {
+       ($self: ident, $update_res: expr, $update_id: expr, $peer_state_lock: expr, $peer_state: expr, $per_peer_state_lock: expr, $chan: expr, MANUALLY_REMOVING, $remove: expr) => { {
                // update_maps_on_chan_removal needs to be able to take id_to_peer, so make sure we can in
                // any case so that it won't deadlock.
                debug_assert!($self.id_to_peer.try_lock().is_ok());
@@ -1492,14 +1523,14 @@ macro_rules! handle_new_monitor_update {
                                        .update_id == $update_id) &&
                                        $chan.get_latest_monitor_update_id() == $update_id
                                {
-                                       handle_monitor_update_completion!($self, $update_id, $peer_state_lock, $peer_state, $chan);
+                                       handle_monitor_update_completion!($self, $update_id, $peer_state_lock, $peer_state, $per_peer_state_lock, $chan);
                                }
                                Ok(())
                        },
                }
        } };
-       ($self: ident, $update_res: expr, $update_id: expr, $peer_state_lock: expr, $peer_state: expr, $chan_entry: expr) => {
-               handle_new_monitor_update!($self, $update_res, $update_id, $peer_state_lock, $peer_state, $chan_entry.get_mut(), MANUALLY_REMOVING, $chan_entry.remove_entry())
+       ($self: ident, $update_res: expr, $update_id: expr, $peer_state_lock: expr, $peer_state: expr, $per_peer_state_lock: expr, $chan_entry: expr) => {
+               handle_new_monitor_update!($self, $update_res, $update_id, $peer_state_lock, $peer_state, $per_peer_state_lock, $chan_entry.get_mut(), MANUALLY_REMOVING, $chan_entry.remove_entry())
        }
 }
 
@@ -1835,7 +1866,7 @@ where
                                        if let Some(monitor_update) = monitor_update_opt.take() {
                                                let update_id = monitor_update.update_id;
                                                let update_res = self.chain_monitor.update_channel(funding_txo_opt.unwrap(), monitor_update);
-                                               break handle_new_monitor_update!(self, update_res, update_id, peer_state_lock, peer_state, chan_entry);
+                                               break handle_new_monitor_update!(self, update_res, update_id, peer_state_lock, peer_state, per_peer_state, chan_entry);
                                        }
 
                                        if chan_entry.get().is_shutdown() {
@@ -2413,8 +2444,16 @@ where
                })
        }
 
-       // Only public for testing, this should otherwise never be called direcly
-       pub(crate) fn send_payment_along_path(&self, path: &Vec<RouteHop>, payment_params: &Option<PaymentParameters>, payment_hash: &PaymentHash, payment_secret: &Option<PaymentSecret>, total_value: u64, cur_height: u32, payment_id: PaymentId, keysend_preimage: &Option<PaymentPreimage>, session_priv_bytes: [u8; 32]) -> Result<(), APIError> {
+       #[cfg(test)]
+       pub(crate) fn test_send_payment_along_path(&self, path: &Vec<RouteHop>, payment_params: &Option<PaymentParameters>, payment_hash: &PaymentHash, payment_secret: &Option<PaymentSecret>, total_value: u64, cur_height: u32, payment_id: PaymentId, keysend_preimage: &Option<PaymentPreimage>, session_priv_bytes: [u8; 32]) -> Result<(), APIError> {
+               let _lck = self.total_consistency_lock.read().unwrap();
+               self.send_payment_along_path(path, payment_params, payment_hash, payment_secret, total_value, cur_height, payment_id, keysend_preimage, session_priv_bytes)
+       }
+
+       fn send_payment_along_path(&self, path: &Vec<RouteHop>, payment_params: &Option<PaymentParameters>, payment_hash: &PaymentHash, payment_secret: &Option<PaymentSecret>, total_value: u64, cur_height: u32, payment_id: PaymentId, keysend_preimage: &Option<PaymentPreimage>, session_priv_bytes: [u8; 32]) -> Result<(), APIError> {
+               // The top-level caller should hold the total_consistency_lock read lock.
+               debug_assert!(self.total_consistency_lock.try_write().is_err());
+
                log_trace!(self.logger, "Attempting to send payment for path with next hop {}", path.first().unwrap().short_channel_id);
                let prng_seed = self.entropy_source.get_secure_random_bytes();
                let session_priv = SecretKey::from_slice(&session_priv_bytes[..]).expect("RNG is busted");
@@ -2427,8 +2466,6 @@ where
                }
                let onion_packet = onion_utils::construct_onion_packet(onion_payloads, onion_keys, prng_seed, payment_hash);
 
-               let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
-
                let err: Result<(), _> = loop {
                        let (counterparty_node_id, id) = match self.short_to_chan_info.read().unwrap().get(&path.first().unwrap().short_channel_id) {
                                None => return Err(APIError::ChannelUnavailable{err: "No channel available with first hop!".to_owned()}),
@@ -2458,7 +2495,7 @@ where
                                        Some(monitor_update) => {
                                                let update_id = monitor_update.update_id;
                                                let update_res = self.chain_monitor.update_channel(funding_txo, monitor_update);
-                                               if let Err(e) = handle_new_monitor_update!(self, update_res, update_id, peer_state_lock, peer_state, chan) {
+                                               if let Err(e) = handle_new_monitor_update!(self, update_res, update_id, peer_state_lock, peer_state, per_peer_state, chan) {
                                                        break Err(e);
                                                }
                                                if update_res == ChannelMonitorUpdateStatus::InProgress {
@@ -2555,6 +2592,7 @@ where
        /// [`ChannelMonitorUpdateStatus::InProgress`]: crate::chain::ChannelMonitorUpdateStatus::InProgress
        pub fn send_payment(&self, route: &Route, payment_hash: PaymentHash, payment_secret: &Option<PaymentSecret>, payment_id: PaymentId) -> Result<(), PaymentSendFailure> {
                let best_block_height = self.best_block.read().unwrap().height();
+               let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
                self.pending_outbound_payments
                        .send_payment_with_route(route, payment_hash, payment_secret, payment_id, &self.entropy_source, &self.node_signer, best_block_height,
                                |path, payment_params, payment_hash, payment_secret, total_value, cur_height, payment_id, keysend_preimage, session_priv|
@@ -2565,6 +2603,7 @@ where
        /// `route_params` and retry failed payment paths based on `retry_strategy`.
        pub fn send_payment_with_retry(&self, payment_hash: PaymentHash, payment_secret: &Option<PaymentSecret>, payment_id: PaymentId, route_params: RouteParameters, retry_strategy: Retry) -> Result<(), RetryableSendFailure> {
                let best_block_height = self.best_block.read().unwrap().height();
+               let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
                self.pending_outbound_payments
                        .send_payment(payment_hash, payment_secret, payment_id, retry_strategy, route_params,
                                &self.router, self.list_usable_channels(), || self.compute_inflight_htlcs(),
@@ -2577,6 +2616,7 @@ where
        #[cfg(test)]
        fn test_send_payment_internal(&self, route: &Route, payment_hash: PaymentHash, payment_secret: &Option<PaymentSecret>, keysend_preimage: Option<PaymentPreimage>, payment_id: PaymentId, recv_value_msat: Option<u64>, onion_session_privs: Vec<[u8; 32]>) -> Result<(), PaymentSendFailure> {
                let best_block_height = self.best_block.read().unwrap().height();
+               let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
                self.pending_outbound_payments.test_send_payment_internal(route, payment_hash, payment_secret, keysend_preimage, payment_id, recv_value_msat, onion_session_privs, &self.node_signer, best_block_height,
                        |path, payment_params, payment_hash, payment_secret, total_value, cur_height, payment_id, keysend_preimage, session_priv|
                        self.send_payment_along_path(path, payment_params, payment_hash, payment_secret, total_value, cur_height, payment_id, keysend_preimage, session_priv))
@@ -2627,6 +2667,7 @@ where
        /// [`send_payment`]: Self::send_payment
        pub fn send_spontaneous_payment(&self, route: &Route, payment_preimage: Option<PaymentPreimage>, payment_id: PaymentId) -> Result<PaymentHash, PaymentSendFailure> {
                let best_block_height = self.best_block.read().unwrap().height();
+               let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
                self.pending_outbound_payments.send_spontaneous_payment_with_route(
                        route, payment_preimage, payment_id, &self.entropy_source, &self.node_signer,
                        best_block_height,
@@ -2643,6 +2684,7 @@ where
        /// [`PaymentParameters::for_keysend`]: crate::routing::router::PaymentParameters::for_keysend
        pub fn send_spontaneous_payment_with_retry(&self, payment_preimage: Option<PaymentPreimage>, payment_id: PaymentId, route_params: RouteParameters, retry_strategy: Retry) -> Result<PaymentHash, RetryableSendFailure> {
                let best_block_height = self.best_block.read().unwrap().height();
+               let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
                self.pending_outbound_payments.send_spontaneous_payment(payment_preimage, payment_id,
                        retry_strategy, route_params, &self.router, self.list_usable_channels(),
                        || self.compute_inflight_htlcs(),  &self.entropy_source, &self.node_signer, best_block_height,
@@ -2656,6 +2698,7 @@ where
        /// us to easily discern them from real payments.
        pub fn send_probe(&self, hops: Vec<RouteHop>) -> Result<(PaymentHash, PaymentId), PaymentSendFailure> {
                let best_block_height = self.best_block.read().unwrap().height();
+               let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
                self.pending_outbound_payments.send_probe(hops, self.probing_cookie_secret, &self.entropy_source, &self.node_signer, best_block_height,
                        |path, payment_params, payment_hash, payment_secret, total_value, cur_height, payment_id, keysend_preimage, session_priv|
                        self.send_payment_along_path(path, payment_params, payment_hash, payment_secret, total_value, cur_height, payment_id, keysend_preimage, session_priv))
@@ -3640,14 +3683,14 @@ where
        /// [`events::Event::PaymentClaimed`] events even for payments you intend to fail, especially on
        /// startup during which time claims that were in-progress at shutdown may be replayed.
        pub fn fail_htlc_backwards(&self, payment_hash: &PaymentHash) {
-               self.fail_htlc_backwards_with_reason(payment_hash, &FailureCode::IncorrectOrUnknownPaymentDetails);
+               self.fail_htlc_backwards_with_reason(payment_hash, FailureCode::IncorrectOrUnknownPaymentDetails);
        }
 
        /// This is a variant of [`ChannelManager::fail_htlc_backwards`] that allows you to specify the
        /// reason for the failure.
        ///
        /// See [`FailureCode`] for valid failure codes.
-       pub fn fail_htlc_backwards_with_reason(&self, payment_hash: &PaymentHash, failure_code: &FailureCode) {
+       pub fn fail_htlc_backwards_with_reason(&self, payment_hash: &PaymentHash, failure_code: FailureCode) {
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
 
                let removed_source = self.claimable_payments.lock().unwrap().claimable_htlcs.remove(payment_hash);
@@ -3662,14 +3705,14 @@ where
        }
 
        /// Gets error data to form an [`HTLCFailReason`] given a [`FailureCode`] and [`ClaimableHTLC`].
-       fn get_htlc_fail_reason_from_failure_code(&self, failure_code: &FailureCode, htlc: &ClaimableHTLC) -> HTLCFailReason {
+       fn get_htlc_fail_reason_from_failure_code(&self, failure_code: FailureCode, htlc: &ClaimableHTLC) -> HTLCFailReason {
                match failure_code {
-                       FailureCode::TemporaryNodeFailure => HTLCFailReason::from_failure_code(*failure_code as u16),
-                       FailureCode::RequiredNodeFeatureMissing => HTLCFailReason::from_failure_code(*failure_code as u16),
+                       FailureCode::TemporaryNodeFailure => HTLCFailReason::from_failure_code(failure_code as u16),
+                       FailureCode::RequiredNodeFeatureMissing => HTLCFailReason::from_failure_code(failure_code as u16),
                        FailureCode::IncorrectOrUnknownPaymentDetails => {
                                let mut htlc_msat_height_data = htlc.value.to_be_bytes().to_vec();
                                htlc_msat_height_data.extend_from_slice(&self.best_block.read().unwrap().height().to_be_bytes());
-                               HTLCFailReason::reason(*failure_code as u16, htlc_msat_height_data)
+                               HTLCFailReason::reason(failure_code as u16, htlc_msat_height_data)
                        }
                }
        }
@@ -3979,7 +4022,8 @@ where
                        )
                ).unwrap_or(None);
 
-               if let Some(mut peer_state_lock) = peer_state_opt.take() {
+               if peer_state_opt.is_some() {
+                       let mut peer_state_lock = peer_state_opt.unwrap();
                        let peer_state = &mut *peer_state_lock;
                        if let hash_map::Entry::Occupied(mut chan) = peer_state.channel_by_id.entry(chan_id) {
                                let counterparty_node_id = chan.get().get_counterparty_node_id();
@@ -3994,7 +4038,7 @@ where
                                        let update_id = monitor_update.update_id;
                                        let update_res = self.chain_monitor.update_channel(prev_hop.outpoint, monitor_update);
                                        let res = handle_new_monitor_update!(self, update_res, update_id, peer_state_lock,
-                                               peer_state, chan);
+                                               peer_state, per_peer_state, chan);
                                        if let Err(e) = res {
                                                // TODO: This is a *critical* error - we probably updated the outbound edge
                                                // of the HTLC's monitor with a preimage. We should retry this monitor
@@ -4164,7 +4208,7 @@ where
        }
 
        fn channel_monitor_updated(&self, funding_txo: &OutPoint, highest_applied_update_id: u64, counterparty_node_id: Option<&PublicKey>) {
-               let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
+               debug_assert!(self.total_consistency_lock.try_write().is_err()); // Caller holds read lock
 
                let counterparty_node_id = match counterparty_node_id {
                        Some(cp_id) => cp_id.clone(),
@@ -4195,7 +4239,7 @@ where
                if !channel.get().is_awaiting_monitor_update() || channel.get().get_latest_monitor_update_id() != highest_applied_update_id {
                        return;
                }
-               handle_monitor_update_completion!(self, highest_applied_update_id, peer_state_lock, peer_state, channel.get_mut());
+               handle_monitor_update_completion!(self, highest_applied_update_id, peer_state_lock, peer_state, per_peer_state, channel.get_mut());
        }
 
        /// Accepts a request to open a channel after a [`Event::OpenChannelRequest`].
@@ -4501,7 +4545,8 @@ where
                                let monitor_res = self.chain_monitor.watch_channel(monitor.get_funding_txo().0, monitor);
 
                                let chan = e.insert(chan);
-                               let mut res = handle_new_monitor_update!(self, monitor_res, 0, peer_state_lock, peer_state, chan, MANUALLY_REMOVING, { peer_state.channel_by_id.remove(&new_channel_id) });
+                               let mut res = handle_new_monitor_update!(self, monitor_res, 0, peer_state_lock, peer_state,
+                                       per_peer_state, chan, MANUALLY_REMOVING, { peer_state.channel_by_id.remove(&new_channel_id) });
 
                                // Note that we reply with the new channel_id in error messages if we gave up on the
                                // channel, not the temporary_channel_id. This is compatible with ourselves, but the
@@ -4534,7 +4579,7 @@ where
                                let monitor = try_chan_entry!(self,
                                        chan.get_mut().funding_signed(&msg, best_block, &self.signer_provider, &self.logger), chan);
                                let update_res = self.chain_monitor.watch_channel(chan.get().get_funding_txo().unwrap(), monitor);
-                               let mut res = handle_new_monitor_update!(self, update_res, 0, peer_state_lock, peer_state, chan);
+                               let mut res = handle_new_monitor_update!(self, update_res, 0, peer_state_lock, peer_state, per_peer_state, chan);
                                if let Err(MsgHandleErrInternal { ref mut shutdown_finish, .. }) = res {
                                        // We weren't able to watch the channel to begin with, so no updates should be made on
                                        // it. Previously, full_stack_target found an (unreachable) panic when the
@@ -4630,7 +4675,7 @@ where
                                        if let Some(monitor_update) = monitor_update_opt {
                                                let update_id = monitor_update.update_id;
                                                let update_res = self.chain_monitor.update_channel(funding_txo_opt.unwrap(), monitor_update);
-                                               break handle_new_monitor_update!(self, update_res, update_id, peer_state_lock, peer_state, chan_entry);
+                                               break handle_new_monitor_update!(self, update_res, update_id, peer_state_lock, peer_state, per_peer_state, chan_entry);
                                        }
                                        break Ok(());
                                },
@@ -4822,7 +4867,7 @@ where
                                let update_res = self.chain_monitor.update_channel(funding_txo.unwrap(), monitor_update);
                                let update_id = monitor_update.update_id;
                                handle_new_monitor_update!(self, update_res, update_id, peer_state_lock,
-                                       peer_state, chan)
+                                       peer_state, per_peer_state, chan)
                        },
                        hash_map::Entry::Vacant(_) => return Err(MsgHandleErrInternal::send_err_msg_no_close(format!("Got a message for a channel from the wrong node! No such channel for the passed counterparty_node_id {}", counterparty_node_id), msg.channel_id))
                }
@@ -4928,12 +4973,11 @@ where
        fn internal_revoke_and_ack(&self, counterparty_node_id: &PublicKey, msg: &msgs::RevokeAndACK) -> Result<(), MsgHandleErrInternal> {
                let (htlcs_to_fail, res) = {
                        let per_peer_state = self.per_peer_state.read().unwrap();
-                       let peer_state_mutex = per_peer_state.get(counterparty_node_id)
+                       let mut peer_state_lock = per_peer_state.get(counterparty_node_id)
                                .ok_or_else(|| {
                                        debug_assert!(false);
                                        MsgHandleErrInternal::send_err_msg_no_close(format!("Can't find a peer matching the passed counterparty node_id {}", counterparty_node_id), msg.channel_id)
-                               })?;
-                       let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+                               }).map(|mtx| mtx.lock().unwrap())?;
                        let peer_state = &mut *peer_state_lock;
                        match peer_state.channel_by_id.entry(msg.channel_id) {
                                hash_map::Entry::Occupied(mut chan) => {
@@ -4941,8 +4985,8 @@ where
                                        let (htlcs_to_fail, monitor_update) = try_chan_entry!(self, chan.get_mut().revoke_and_ack(&msg, &self.logger), chan);
                                        let update_res = self.chain_monitor.update_channel(funding_txo.unwrap(), monitor_update);
                                        let update_id = monitor_update.update_id;
-                                       let res = handle_new_monitor_update!(self, update_res, update_id, peer_state_lock,
-                                               peer_state, chan);
+                                       let res = handle_new_monitor_update!(self, update_res, update_id,
+                                               peer_state_lock, peer_state, per_peer_state, chan);
                                        (htlcs_to_fail, res)
                                },
                                hash_map::Entry::Vacant(_) => return Err(MsgHandleErrInternal::send_err_msg_no_close(format!("Got a message for a channel from the wrong node! No such channel for the passed counterparty_node_id {}", counterparty_node_id), msg.channel_id))
@@ -5104,6 +5148,8 @@ where
 
        /// Process pending events from the `chain::Watch`, returning whether any events were processed.
        fn process_pending_monitor_events(&self) -> bool {
+               debug_assert!(self.total_consistency_lock.try_write().is_err()); // Caller holds read lock
+
                let mut failed_channels = Vec::new();
                let mut pending_monitor_events = self.chain_monitor.release_pending_monitor_events();
                let has_pending_monitor_events = !pending_monitor_events.is_empty();
@@ -5181,7 +5227,13 @@ where
        /// update events as a separate process method here.
        #[cfg(fuzzing)]
        pub fn process_monitor_events(&self) {
-               self.process_pending_monitor_events();
+               PersistenceNotifierGuard::optionally_notify(&self.total_consistency_lock, &self.persistence_notifier, || {
+                       if self.process_pending_monitor_events() {
+                               NotifyOption::DoPersist
+                       } else {
+                               NotifyOption::SkipPersist
+                       }
+               });
        }
 
        /// Check the holding cell in each channel and free any pending HTLCs in them if possible.
@@ -5191,38 +5243,45 @@ where
                let mut has_monitor_update = false;
                let mut failed_htlcs = Vec::new();
                let mut handle_errors = Vec::new();
-               let per_peer_state = self.per_peer_state.read().unwrap();
-
-               for (_cp_id, peer_state_mutex) in per_peer_state.iter() {
-                       'chan_loop: loop {
-                               let mut peer_state_lock = peer_state_mutex.lock().unwrap();
-                               let peer_state: &mut PeerState<_> = &mut *peer_state_lock;
-                               for (channel_id, chan) in peer_state.channel_by_id.iter_mut() {
-                                       let counterparty_node_id = chan.get_counterparty_node_id();
-                                       let funding_txo = chan.get_funding_txo();
-                                       let (monitor_opt, holding_cell_failed_htlcs) =
-                                               chan.maybe_free_holding_cell_htlcs(&self.logger);
-                                       if !holding_cell_failed_htlcs.is_empty() {
-                                               failed_htlcs.push((holding_cell_failed_htlcs, *channel_id, counterparty_node_id));
-                                       }
-                                       if let Some(monitor_update) = monitor_opt {
-                                               has_monitor_update = true;
 
-                                               let update_res = self.chain_monitor.update_channel(
-                                                       funding_txo.expect("channel is live"), monitor_update);
-                                               let update_id = monitor_update.update_id;
-                                               let channel_id: [u8; 32] = *channel_id;
-                                               let res = handle_new_monitor_update!(self, update_res, update_id,
-                                                       peer_state_lock, peer_state, chan, MANUALLY_REMOVING,
-                                                       peer_state.channel_by_id.remove(&channel_id));
-                                               if res.is_err() {
-                                                       handle_errors.push((counterparty_node_id, res));
+               // Walk our list of channels and find any that need to update. Note that when we do find an
+               // update, if it includes actions that must be taken afterwards, we have to drop the
+               // per-peer state lock as well as the top level per_peer_state lock. Thus, we loop until we
+               // manage to go through all our peers without finding a single channel to update.
+               'peer_loop: loop {
+                       let per_peer_state = self.per_peer_state.read().unwrap();
+                       for (_cp_id, peer_state_mutex) in per_peer_state.iter() {
+                               'chan_loop: loop {
+                                       let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+                                       let peer_state: &mut PeerState<_> = &mut *peer_state_lock;
+                                       for (channel_id, chan) in peer_state.channel_by_id.iter_mut() {
+                                               let counterparty_node_id = chan.get_counterparty_node_id();
+                                               let funding_txo = chan.get_funding_txo();
+                                               let (monitor_opt, holding_cell_failed_htlcs) =
+                                                       chan.maybe_free_holding_cell_htlcs(&self.logger);
+                                               if !holding_cell_failed_htlcs.is_empty() {
+                                                       failed_htlcs.push((holding_cell_failed_htlcs, *channel_id, counterparty_node_id));
+                                               }
+                                               if let Some(monitor_update) = monitor_opt {
+                                                       has_monitor_update = true;
+
+                                                       let update_res = self.chain_monitor.update_channel(
+                                                               funding_txo.expect("channel is live"), monitor_update);
+                                                       let update_id = monitor_update.update_id;
+                                                       let channel_id: [u8; 32] = *channel_id;
+                                                       let res = handle_new_monitor_update!(self, update_res, update_id,
+                                                               peer_state_lock, peer_state, per_peer_state, chan, MANUALLY_REMOVING,
+                                                               peer_state.channel_by_id.remove(&channel_id));
+                                                       if res.is_err() {
+                                                               handle_errors.push((counterparty_node_id, res));
+                                                       }
+                                                       continue 'peer_loop;
                                                }
-                                               continue 'chan_loop;
                                        }
+                                       break 'chan_loop;
                                }
-                               break 'chan_loop;
                        }
+                       break 'peer_loop;
                }
 
                let has_update = has_monitor_update || !failed_htlcs.is_empty() || !handle_errors.is_empty();
@@ -6745,27 +6804,36 @@ impl Readable for HTLCSource {
                        0 => {
                                let mut session_priv: crate::util::ser::RequiredWrapper<SecretKey> = crate::util::ser::RequiredWrapper(None);
                                let mut first_hop_htlc_msat: u64 = 0;
-                               let mut path = Some(Vec::new());
+                               let mut path: Option<Vec<RouteHop>> = Some(Vec::new());
                                let mut payment_id = None;
                                let mut payment_secret = None;
-                               let mut payment_params = None;
+                               let mut payment_params: Option<PaymentParameters> = None;
                                read_tlv_fields!(reader, {
                                        (0, session_priv, required),
                                        (1, payment_id, option),
                                        (2, first_hop_htlc_msat, required),
                                        (3, payment_secret, option),
                                        (4, path, vec_type),
-                                       (5, payment_params, option),
+                                       (5, payment_params, (option: ReadableArgs, 0)),
                                });
                                if payment_id.is_none() {
                                        // For backwards compat, if there was no payment_id written, use the session_priv bytes
                                        // instead.
                                        payment_id = Some(PaymentId(*session_priv.0.unwrap().as_ref()));
                                }
+                               if path.is_none() || path.as_ref().unwrap().is_empty() {
+                                       return Err(DecodeError::InvalidValue);
+                               }
+                               let path = path.unwrap();
+                               if let Some(params) = payment_params.as_mut() {
+                                       if params.final_cltv_expiry_delta == 0 {
+                                               params.final_cltv_expiry_delta = path.last().unwrap().cltv_expiry_delta;
+                                       }
+                               }
                                Ok(HTLCSource::OutboundRoute {
                                        session_priv: session_priv.0.unwrap(),
                                        first_hop_htlc_msat,
-                                       path: path.unwrap(),
+                                       path,
                                        payment_id: payment_id.unwrap(),
                                        payment_secret,
                                        payment_params,
@@ -6912,7 +6980,10 @@ where
                let mut monitor_update_blocked_actions_per_peer = None;
                let mut peer_states = Vec::new();
                for (_, peer_state_mutex) in per_peer_state.iter() {
-                       peer_states.push(peer_state_mutex.lock().unwrap());
+                       // Because we're holding the owning `per_peer_state` write lock here there's no chance
+                       // of a lockorder violation deadlock - no other thread can be holding any
+                       // per_peer_state lock at all.
+                       peer_states.push(peer_state_mutex.unsafe_well_ordered_double_lock_self());
                }
 
                (serializable_peer_count).write(writer)?;
@@ -7404,6 +7475,10 @@ where
                        probing_cookie_secret = Some(args.entropy_source.get_secure_random_bytes());
                }
 
+               if !channel_closures.is_empty() {
+                       pending_events_read.append(&mut channel_closures);
+               }
+
                if pending_outbound_payments.is_none() && pending_outbound_payments_no_retry.is_none() {
                        pending_outbound_payments = Some(pending_outbound_payments_compat);
                } else if pending_outbound_payments.is_none() {
@@ -7412,7 +7487,13 @@ where
                                outbounds.insert(id, PendingOutboundPayment::Legacy { session_privs });
                        }
                        pending_outbound_payments = Some(outbounds);
-               } else {
+               }
+               let pending_outbounds = OutboundPayments {
+                       pending_outbound_payments: Mutex::new(pending_outbound_payments.unwrap()),
+                       retry_lock: Mutex::new(())
+               };
+
+               {
                        // If we're tracking pending payments, ensure we haven't lost any by looking at the
                        // ChannelMonitor data for any channels for which we do not have authorative state
                        // (i.e. those for which we just force-closed above or we otherwise don't have a
@@ -7423,16 +7504,17 @@ where
                        // 0.0.102+
                        for (_, monitor) in args.channel_monitors.iter() {
                                if id_to_peer.get(&monitor.get_funding_txo().0.to_channel_id()).is_none() {
-                                       for (htlc_source, htlc) in monitor.get_pending_outbound_htlcs() {
+                                       for (htlc_source, (htlc, _)) in monitor.get_pending_or_resolved_outbound_htlcs() {
                                                if let HTLCSource::OutboundRoute { payment_id, session_priv, path, payment_secret, .. } = htlc_source {
                                                        if path.is_empty() {
                                                                log_error!(args.logger, "Got an empty path for a pending payment");
                                                                return Err(DecodeError::InvalidValue);
                                                        }
+
                                                        let path_amt = path.last().unwrap().fee_msat;
                                                        let mut session_priv_bytes = [0; 32];
                                                        session_priv_bytes[..].copy_from_slice(&session_priv[..]);
-                                                       match pending_outbound_payments.as_mut().unwrap().entry(payment_id) {
+                                                       match pending_outbounds.pending_outbound_payments.lock().unwrap().entry(payment_id) {
                                                                hash_map::Entry::Occupied(mut entry) => {
                                                                        let newly_added = entry.get_mut().insert(session_priv_bytes, &path);
                                                                        log_info!(args.logger, "{} a pending payment path for {} msat for session priv {} on an existing pending payment with payment hash {}",
@@ -7459,48 +7541,64 @@ where
                                                        }
                                                }
                                        }
-                                       for (htlc_source, htlc) in monitor.get_all_current_outbound_htlcs() {
-                                               if let HTLCSource::PreviousHopData(prev_hop_data) = htlc_source {
-                                                       let pending_forward_matches_htlc = |info: &PendingAddHTLCInfo| {
-                                                               info.prev_funding_outpoint == prev_hop_data.outpoint &&
-                                                                       info.prev_htlc_id == prev_hop_data.htlc_id
-                                                       };
-                                                       // The ChannelMonitor is now responsible for this HTLC's
-                                                       // failure/success and will let us know what its outcome is. If we
-                                                       // still have an entry for this HTLC in `forward_htlcs` or
-                                                       // `pending_intercepted_htlcs`, we were apparently not persisted after
-                                                       // the monitor was when forwarding the payment.
-                                                       forward_htlcs.retain(|_, forwards| {
-                                                               forwards.retain(|forward| {
-                                                                       if let HTLCForwardInfo::AddHTLC(htlc_info) = forward {
-                                                                               if pending_forward_matches_htlc(&htlc_info) {
-                                                                                       log_info!(args.logger, "Removing pending to-forward HTLC with hash {} as it was forwarded to the closed channel {}",
-                                                                                               log_bytes!(htlc.payment_hash.0), log_bytes!(monitor.get_funding_txo().0.to_channel_id()));
-                                                                                       false
+                                       for (htlc_source, (htlc, preimage_opt)) in monitor.get_all_current_outbound_htlcs() {
+                                               match htlc_source {
+                                                       HTLCSource::PreviousHopData(prev_hop_data) => {
+                                                               let pending_forward_matches_htlc = |info: &PendingAddHTLCInfo| {
+                                                                       info.prev_funding_outpoint == prev_hop_data.outpoint &&
+                                                                               info.prev_htlc_id == prev_hop_data.htlc_id
+                                                               };
+                                                               // The ChannelMonitor is now responsible for this HTLC's
+                                                               // failure/success and will let us know what its outcome is. If we
+                                                               // still have an entry for this HTLC in `forward_htlcs` or
+                                                               // `pending_intercepted_htlcs`, we were apparently not persisted after
+                                                               // the monitor was when forwarding the payment.
+                                                               forward_htlcs.retain(|_, forwards| {
+                                                                       forwards.retain(|forward| {
+                                                                               if let HTLCForwardInfo::AddHTLC(htlc_info) = forward {
+                                                                                       if pending_forward_matches_htlc(&htlc_info) {
+                                                                                               log_info!(args.logger, "Removing pending to-forward HTLC with hash {} as it was forwarded to the closed channel {}",
+                                                                                                       log_bytes!(htlc.payment_hash.0), log_bytes!(monitor.get_funding_txo().0.to_channel_id()));
+                                                                                               false
+                                                                                       } else { true }
                                                                                } else { true }
+                                                                       });
+                                                                       !forwards.is_empty()
+                                                               });
+                                                               pending_intercepted_htlcs.as_mut().unwrap().retain(|intercepted_id, htlc_info| {
+                                                                       if pending_forward_matches_htlc(&htlc_info) {
+                                                                               log_info!(args.logger, "Removing pending intercepted HTLC with hash {} as it was forwarded to the closed channel {}",
+                                                                                       log_bytes!(htlc.payment_hash.0), log_bytes!(monitor.get_funding_txo().0.to_channel_id()));
+                                                                               pending_events_read.retain(|event| {
+                                                                                       if let Event::HTLCIntercepted { intercept_id: ev_id, .. } = event {
+                                                                                               intercepted_id != ev_id
+                                                                                       } else { true }
+                                                                               });
+                                                                               false
                                                                        } else { true }
                                                                });
-                                                               !forwards.is_empty()
-                                                       });
-                                                       pending_intercepted_htlcs.as_mut().unwrap().retain(|intercepted_id, htlc_info| {
-                                                               if pending_forward_matches_htlc(&htlc_info) {
-                                                                       log_info!(args.logger, "Removing pending intercepted HTLC with hash {} as it was forwarded to the closed channel {}",
-                                                                               log_bytes!(htlc.payment_hash.0), log_bytes!(monitor.get_funding_txo().0.to_channel_id()));
-                                                                       pending_events_read.retain(|event| {
-                                                                               if let Event::HTLCIntercepted { intercept_id: ev_id, .. } = event {
-                                                                                       intercepted_id != ev_id
-                                                                               } else { true }
-                                                                       });
-                                                                       false
-                                                               } else { true }
-                                                       });
+                                                       },
+                                                       HTLCSource::OutboundRoute { payment_id, session_priv, path, .. } => {
+                                                               if let Some(preimage) = preimage_opt {
+                                                                       let pending_events = Mutex::new(pending_events_read);
+                                                                       // Note that we set `from_onchain` to "false" here,
+                                                                       // deliberately keeping the pending payment around forever.
+                                                                       // Given it should only occur when we have a channel we're
+                                                                       // force-closing for being stale that's okay.
+                                                                       // The alternative would be to wipe the state when claiming,
+                                                                       // generating a `PaymentPathSuccessful` event but regenerating
+                                                                       // it and the `PaymentSent` on every restart until the
+                                                                       // `ChannelMonitor` is removed.
+                                                                       pending_outbounds.claim_htlc(payment_id, preimage, session_priv, path, false, &pending_events, &args.logger);
+                                                                       pending_events_read = pending_events.into_inner().unwrap();
+                                                               }
+                                                       },
                                                }
                                        }
                                }
                        }
                }
 
-               let pending_outbounds = OutboundPayments { pending_outbound_payments: Mutex::new(pending_outbound_payments.unwrap()), retry_lock: Mutex::new(()) };
                if !forward_htlcs.is_empty() || pending_outbounds.needs_abandon() {
                        // If we have pending HTLCs to forward, assume we either dropped a
                        // `PendingHTLCsForwardable` or the user received it but never processed it as they
@@ -7558,10 +7656,6 @@ where
                let mut secp_ctx = Secp256k1::new();
                secp_ctx.seeded_randomize(&args.entropy_source.get_secure_random_bytes());
 
-               if !channel_closures.is_empty() {
-                       pending_events_read.append(&mut channel_closures);
-               }
-
                let our_network_pubkey = match args.node_signer.get_node_id(Recipient::Node) {
                        Ok(key) => key,
                        Err(()) => return Err(DecodeError::InvalidValue)
@@ -7841,7 +7935,7 @@ mod tests {
                // indicates there are more HTLCs coming.
                let cur_height = CHAN_CONFIRM_DEPTH + 1; // route_payment calls send_payment, which adds 1 to the current height. So we do the same here to match.
                let session_privs = nodes[0].node.test_add_new_pending_payment(our_payment_hash, Some(payment_secret), payment_id, &mpp_route).unwrap();
-               nodes[0].node.send_payment_along_path(&mpp_route.paths[0], &route.payment_params, &our_payment_hash, &Some(payment_secret), 200_000, cur_height, payment_id, &None, session_privs[0]).unwrap();
+               nodes[0].node.test_send_payment_along_path(&mpp_route.paths[0], &route.payment_params, &our_payment_hash, &Some(payment_secret), 200_000, cur_height, payment_id, &None, session_privs[0]).unwrap();
                check_added_monitors!(nodes[0], 1);
                let mut events = nodes[0].node.get_and_clear_pending_msg_events();
                assert_eq!(events.len(), 1);
@@ -7871,7 +7965,7 @@ mod tests {
                expect_payment_failed!(nodes[0], our_payment_hash, true);
 
                // Send the second half of the original MPP payment.
-               nodes[0].node.send_payment_along_path(&mpp_route.paths[1], &route.payment_params, &our_payment_hash, &Some(payment_secret), 200_000, cur_height, payment_id, &None, session_privs[1]).unwrap();
+               nodes[0].node.test_send_payment_along_path(&mpp_route.paths[1], &route.payment_params, &our_payment_hash, &Some(payment_secret), 200_000, cur_height, payment_id, &None, session_privs[1]).unwrap();
                check_added_monitors!(nodes[0], 1);
                let mut events = nodes[0].node.get_and_clear_pending_msg_events();
                assert_eq!(events.len(), 1);
@@ -7963,7 +8057,6 @@ mod tests {
                let route_params = RouteParameters {
                        payment_params: PaymentParameters::for_keysend(expected_route.last().unwrap().node.get_our_node_id(), TEST_FINAL_CLTV),
                        final_value_msat: 100_000,
-                       final_cltv_expiry_delta: TEST_FINAL_CLTV,
                };
                let route = find_route(
                        &nodes[0].node.get_our_node_id(), &route_params, &nodes[0].network_graph,
@@ -8054,7 +8147,6 @@ mod tests {
                let route_params = RouteParameters {
                        payment_params: PaymentParameters::for_keysend(payee_pubkey, 40),
                        final_value_msat: 10_000,
-                       final_cltv_expiry_delta: 40,
                };
                let network_graph = nodes[0].network_graph.clone();
                let first_hops = nodes[0].node.list_usable_channels();
@@ -8097,7 +8189,6 @@ mod tests {
                let route_params = RouteParameters {
                        payment_params: PaymentParameters::for_keysend(payee_pubkey, 40),
                        final_value_msat: 10_000,
-                       final_cltv_expiry_delta: 40,
                };
                let network_graph = nodes[0].network_graph.clone();
                let first_hops = nodes[0].node.list_usable_channels();
@@ -8251,10 +8342,10 @@ mod tests {
                        let nodes_0_lock = nodes[0].node.id_to_peer.lock().unwrap();
                        assert_eq!(nodes_0_lock.len(), 1);
                        assert!(nodes_0_lock.contains_key(channel_id));
-
-                       assert_eq!(nodes[1].node.id_to_peer.lock().unwrap().len(), 0);
                }
 
+               assert_eq!(nodes[1].node.id_to_peer.lock().unwrap().len(), 0);
+
                let funding_created_msg = get_event_msg!(nodes[0], MessageSendEvent::SendFundingCreated, nodes[1].node.get_our_node_id());
 
                nodes[1].node.handle_funding_created(&nodes[0].node.get_our_node_id(), &funding_created_msg);
@@ -8262,7 +8353,9 @@ mod tests {
                        let nodes_0_lock = nodes[0].node.id_to_peer.lock().unwrap();
                        assert_eq!(nodes_0_lock.len(), 1);
                        assert!(nodes_0_lock.contains_key(channel_id));
+               }
 
+               {
                        // Assert that `nodes[1]`'s `id_to_peer` map is populated with the channel as soon as
                        // as it has the funding transaction.
                        let nodes_1_lock = nodes[1].node.id_to_peer.lock().unwrap();
@@ -8292,7 +8385,9 @@ mod tests {
                        let nodes_0_lock = nodes[0].node.id_to_peer.lock().unwrap();
                        assert_eq!(nodes_0_lock.len(), 1);
                        assert!(nodes_0_lock.contains_key(channel_id));
+               }
 
+               {
                        // At this stage, `nodes[1]` has proposed a fee for the closing transaction in the
                        // `handle_closing_signed` call above. As `nodes[1]` has not yet received the signature
                        // from `nodes[0]` for the closing transaction with the proposed fee, the channel is
index b424916991e1462a8c216c2711295976c4385411..96ce9312eed3acec30cb9c7e4db7a8114fa804eb 100644 (file)
@@ -44,7 +44,7 @@ use crate::io;
 use crate::prelude::*;
 use core::cell::RefCell;
 use alloc::rc::Rc;
-use crate::sync::{Arc, Mutex};
+use crate::sync::{Arc, Mutex, LockTestExt};
 use core::mem;
 use core::iter::repeat;
 use bitcoin::{PackedLockTime, TxMerkleNode};
@@ -466,8 +466,8 @@ impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> {
                                        panic!();
                                }
                        }
-                       assert_eq!(*chain_source.watched_txn.lock().unwrap(), *self.chain_source.watched_txn.lock().unwrap());
-                       assert_eq!(*chain_source.watched_outputs.lock().unwrap(), *self.chain_source.watched_outputs.lock().unwrap());
+                       assert_eq!(*chain_source.watched_txn.unsafe_well_ordered_double_lock_self(), *self.chain_source.watched_txn.unsafe_well_ordered_double_lock_self());
+                       assert_eq!(*chain_source.watched_outputs.unsafe_well_ordered_double_lock_self(), *self.chain_source.watched_outputs.unsafe_well_ordered_double_lock_self());
                }
        }
 }
@@ -2151,9 +2151,10 @@ pub fn route_over_limit<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_rou
                assert!(err.contains("Cannot send value that would put us over the max HTLC value in flight our peer will accept")));
 }
 
-pub fn send_payment<'a, 'b, 'c>(origin: &Node<'a, 'b, 'c>, expected_route: &[&Node<'a, 'b, 'c>], recv_value: u64)  {
-       let our_payment_preimage = route_payment(&origin, expected_route, recv_value).0;
-       claim_payment(&origin, expected_route, our_payment_preimage);
+pub fn send_payment<'a, 'b, 'c>(origin: &Node<'a, 'b, 'c>, expected_route: &[&Node<'a, 'b, 'c>], recv_value: u64) -> (PaymentPreimage, PaymentHash, PaymentSecret) {
+       let res = route_payment(&origin, expected_route, recv_value);
+       claim_payment(&origin, expected_route, res.0);
+       res
 }
 
 pub fn fail_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_paths: &[&[&Node<'a, 'b, 'c>]], skip_last: bool, our_payment_hash: PaymentHash) {
index 538f51c3c591558298a8f8eef7871a3bfdd0ae43..454fbe2b7819889e3246314a90ac82eed17f40f9 100644 (file)
@@ -4083,7 +4083,7 @@ fn do_test_htlc_timeout(send_partial_mpp: bool) {
                let cur_height = CHAN_CONFIRM_DEPTH + 1; // route_payment calls send_payment, which adds 1 to the current height. So we do the same here to match.
                let payment_id = PaymentId([42; 32]);
                let session_privs = nodes[0].node.test_add_new_pending_payment(our_payment_hash, Some(payment_secret), payment_id, &route).unwrap();
-               nodes[0].node.send_payment_along_path(&route.paths[0], &route.payment_params, &our_payment_hash, &Some(payment_secret), 200_000, cur_height, payment_id, &None, session_privs[0]).unwrap();
+               nodes[0].node.test_send_payment_along_path(&route.paths[0], &route.payment_params, &our_payment_hash, &Some(payment_secret), 200_000, cur_height, payment_id, &None, session_privs[0]).unwrap();
                check_added_monitors!(nodes[0], 1);
                let mut events = nodes[0].node.get_and_clear_pending_msg_events();
                assert_eq!(events.len(), 1);
@@ -8150,12 +8150,13 @@ fn test_update_err_monitor_lockdown() {
        let logger = test_utils::TestLogger::with_id(format!("node {}", 0));
        let persister = test_utils::TestPersister::new();
        let watchtower = {
-               let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap();
-               let mut w = test_utils::TestVecWriter(Vec::new());
-               monitor.write(&mut w).unwrap();
-               let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingSigner>)>::read(
-                               &mut io::Cursor::new(&w.0), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1;
-               assert!(new_monitor == *monitor);
+               let new_monitor = {
+                       let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap();
+                       let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingSigner>)>::read(
+                                       &mut io::Cursor::new(&monitor.encode()), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1;
+                       assert!(new_monitor == *monitor);
+                       new_monitor
+               };
                let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister, &node_cfgs[0].keys_manager);
                assert_eq!(watchtower.watch_channel(outpoint, new_monitor), ChannelMonitorUpdateStatus::Completed);
                watchtower
@@ -8217,12 +8218,13 @@ fn test_concurrent_monitor_claim() {
        let logger = test_utils::TestLogger::with_id(format!("node {}", "Alice"));
        let persister = test_utils::TestPersister::new();
        let watchtower_alice = {
-               let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap();
-               let mut w = test_utils::TestVecWriter(Vec::new());
-               monitor.write(&mut w).unwrap();
-               let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingSigner>)>::read(
-                               &mut io::Cursor::new(&w.0), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1;
-               assert!(new_monitor == *monitor);
+               let new_monitor = {
+                       let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap();
+                       let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingSigner>)>::read(
+                                       &mut io::Cursor::new(&monitor.encode()), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1;
+                       assert!(new_monitor == *monitor);
+                       new_monitor
+               };
                let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister, &node_cfgs[0].keys_manager);
                assert_eq!(watchtower.watch_channel(outpoint, new_monitor), ChannelMonitorUpdateStatus::Completed);
                watchtower
@@ -8246,12 +8248,13 @@ fn test_concurrent_monitor_claim() {
        let logger = test_utils::TestLogger::with_id(format!("node {}", "Bob"));
        let persister = test_utils::TestPersister::new();
        let watchtower_bob = {
-               let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap();
-               let mut w = test_utils::TestVecWriter(Vec::new());
-               monitor.write(&mut w).unwrap();
-               let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingSigner>)>::read(
-                               &mut io::Cursor::new(&w.0), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1;
-               assert!(new_monitor == *monitor);
+               let new_monitor = {
+                       let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap();
+                       let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingSigner>)>::read(
+                                       &mut io::Cursor::new(&monitor.encode()), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1;
+                       assert!(new_monitor == *monitor);
+                       new_monitor
+               };
                let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister, &node_cfgs[0].keys_manager);
                assert_eq!(watchtower.watch_channel(outpoint, new_monitor), ChannelMonitorUpdateStatus::Completed);
                watchtower
@@ -9141,20 +9144,20 @@ fn test_inconsistent_mpp_params() {
                dup_route.paths.push(route.paths[1].clone());
                nodes[0].node.test_add_new_pending_payment(our_payment_hash, Some(our_payment_secret), payment_id, &dup_route).unwrap()
        };
-       {
-               nodes[0].node.send_payment_along_path(&route.paths[0], &payment_params_opt, &our_payment_hash, &Some(our_payment_secret), 15_000_000, cur_height, payment_id, &None, session_privs[0]).unwrap();
-               check_added_monitors!(nodes[0], 1);
+       nodes[0].node.test_send_payment_along_path(&route.paths[0], &payment_params_opt, &our_payment_hash, &Some(our_payment_secret), 15_000_000, cur_height, payment_id, &None, session_privs[0]).unwrap();
+       check_added_monitors!(nodes[0], 1);
 
+       {
                let mut events = nodes[0].node.get_and_clear_pending_msg_events();
                assert_eq!(events.len(), 1);
                pass_along_path(&nodes[0], &[&nodes[1], &nodes[3]], 15_000_000, our_payment_hash, Some(our_payment_secret), events.pop().unwrap(), false, None);
        }
        assert!(nodes[3].node.get_and_clear_pending_events().is_empty());
 
-       {
-               nodes[0].node.send_payment_along_path(&route.paths[1], &payment_params_opt, &our_payment_hash, &Some(our_payment_secret), 14_000_000, cur_height, payment_id, &None, session_privs[1]).unwrap();
-               check_added_monitors!(nodes[0], 1);
+       nodes[0].node.test_send_payment_along_path(&route.paths[1], &payment_params_opt, &our_payment_hash, &Some(our_payment_secret), 14_000_000, cur_height, payment_id, &None, session_privs[1]).unwrap();
+       check_added_monitors!(nodes[0], 1);
 
+       {
                let mut events = nodes[0].node.get_and_clear_pending_msg_events();
                assert_eq!(events.len(), 1);
                let payment_event = SendEvent::from_event(events.pop().unwrap());
@@ -9197,7 +9200,7 @@ fn test_inconsistent_mpp_params() {
 
        expect_payment_failed_conditions(&nodes[0], our_payment_hash, true, PaymentFailedConditions::new().mpp_parts_remain());
 
-       nodes[0].node.send_payment_along_path(&route.paths[1], &payment_params_opt, &our_payment_hash, &Some(our_payment_secret), 15_000_000, cur_height, payment_id, &None, session_privs[2]).unwrap();
+       nodes[0].node.test_send_payment_along_path(&route.paths[1], &payment_params_opt, &our_payment_hash, &Some(our_payment_secret), 15_000_000, cur_height, payment_id, &None, session_privs[2]).unwrap();
        check_added_monitors!(nodes[0], 1);
 
        let mut events = nodes[0].node.get_and_clear_pending_msg_events();
@@ -9241,7 +9244,6 @@ fn test_keysend_payments_to_public_node() {
        let route_params = RouteParameters {
                payment_params: PaymentParameters::for_keysend(payee_pubkey, 40),
                final_value_msat: 10000,
-               final_cltv_expiry_delta: 40,
        };
        let scorer = test_utils::TestScorer::new();
        let random_seed_bytes = chanmon_cfgs[1].keys_manager.get_secure_random_bytes();
@@ -9272,7 +9274,6 @@ fn test_keysend_payments_to_private_node() {
        let route_params = RouteParameters {
                payment_params: PaymentParameters::for_keysend(payee_pubkey, 40),
                final_value_msat: 10000,
-               final_cltv_expiry_delta: 40,
        };
        let network_graph = nodes[0].network_graph.clone();
        let first_hops = nodes[0].node.list_usable_channels();
index 6ea8ab4ce1b28e837f69678ea037bbba86aaf7e4..36e9ddb067a87ad574fac2ddd9a604352e1da190 100644 (file)
@@ -848,7 +848,7 @@ fn do_test_fail_htlc_backwards_with_reason(failure_code: FailureCode) {
 
        expect_pending_htlcs_forwardable!(nodes[1]);
        expect_payment_claimable!(nodes[1], payment_hash, payment_secret, payment_amount);
-       nodes[1].node.fail_htlc_backwards_with_reason(&payment_hash, &failure_code);
+       nodes[1].node.fail_htlc_backwards_with_reason(&payment_hash, failure_code);
 
        expect_pending_htlcs_forwardable_and_htlc_handling_failed!(nodes[1], vec![HTLCDestination::FailedPayment { payment_hash: payment_hash }]);
        check_added_monitors!(nodes[1], 1);
index 15bba61dda408e208e6f9d9bf5d341ce116812be..33ccecb4ca858b17c45321b3a17c2303e6931587 100644 (file)
@@ -16,7 +16,6 @@ use bitcoin::secp256k1::{self, Secp256k1, SecretKey};
 use crate::chain::keysinterface::{EntropySource, NodeSigner, Recipient};
 use crate::ln::{PaymentHash, PaymentPreimage, PaymentSecret};
 use crate::ln::channelmanager::{ChannelDetails, HTLCSource, IDEMPOTENCY_TIMEOUT_TICKS, PaymentId};
-use crate::ln::channelmanager::MIN_FINAL_CLTV_EXPIRY_DELTA as LDK_DEFAULT_MIN_FINAL_CLTV_EXPIRY_DELTA;
 use crate::ln::onion_utils::HTLCFailReason;
 use crate::routing::router::{InFlightHtlcs, PaymentParameters, Route, RouteHop, RouteParameters, RoutePath, Router};
 use crate::util::errors::APIError;
@@ -25,6 +24,7 @@ use crate::util::logger::Logger;
 use crate::util::time::Time;
 #[cfg(all(not(feature = "no-std"), test))]
 use crate::util::time::tests::SinceEpoch;
+use crate::util::ser::ReadableArgs;
 
 use core::cmp;
 use core::fmt::{self, Display, Formatter};
@@ -528,12 +528,6 @@ impl OutboundPayments {
                                                if pending_amt_msat < total_msat {
                                                        retry_id_route_params = Some((*pmt_id, RouteParameters {
                                                                final_value_msat: *total_msat - *pending_amt_msat,
-                                                               final_cltv_expiry_delta:
-                                                                       if let Some(delta) = params.final_cltv_expiry_delta { delta }
-                                                                       else {
-                                                                               debug_assert!(false, "We always set the final_cltv_expiry_delta when a path fails");
-                                                                               LDK_DEFAULT_MIN_FINAL_CLTV_EXPIRY_DELTA.into()
-                                                                       },
                                                                payment_params: params.clone(),
                                                        }));
                                                        break
@@ -976,9 +970,6 @@ impl OutboundPayments {
                                                Some(RouteParameters {
                                                        payment_params: payment_params.clone(),
                                                        final_value_msat: pending_amt_unsent,
-                                                       final_cltv_expiry_delta:
-                                                               if let Some(delta) = payment_params.final_cltv_expiry_delta { delta }
-                                                               else { max_unsent_cltv_delta },
                                                })
                                        } else { None }
                                } else { None },
@@ -1179,23 +1170,14 @@ impl OutboundPayments {
                        // `payment_params`) back to the user.
                        let path_last_hop = path.last().expect("Outbound payments must have had a valid path");
                        if let Some(params) = payment.get_mut().payment_parameters() {
-                               if params.final_cltv_expiry_delta.is_none() {
-                                       // This should be rare, but a user could provide None for the payment data, and
-                                       // we need it when we go to retry the payment, so fill it in.
-                                       params.final_cltv_expiry_delta = Some(path_last_hop.cltv_expiry_delta);
-                               }
                                retry = Some(RouteParameters {
                                        payment_params: params.clone(),
                                        final_value_msat: path_last_hop.fee_msat,
-                                       final_cltv_expiry_delta: params.final_cltv_expiry_delta.unwrap(),
                                });
                        } else if let Some(params) = payment_params {
                                retry = Some(RouteParameters {
                                        payment_params: params.clone(),
                                        final_value_msat: path_last_hop.fee_msat,
-                                       final_cltv_expiry_delta:
-                                               if let Some(delta) = params.final_cltv_expiry_delta { delta }
-                                               else { path_last_hop.cltv_expiry_delta },
                                });
                        }
 
@@ -1330,7 +1312,9 @@ impl_writeable_tlv_based_enum_upgradable!(PendingOutboundPayment,
                (0, session_privs, required),
                (1, pending_fee_msat, option),
                (2, payment_hash, required),
-               (3, payment_params, option),
+               // Note that while we "default" payment_param's final CLTV expiry delta to 0 we should
+               // never see it - `payment_params` was added here after the field was added/required.
+               (3, payment_params, (option: ReadableArgs, 0)),
                (4, payment_secret, option),
                (5, keysend_preimage, option),
                (6, total_msat, required),
@@ -1386,7 +1370,6 @@ mod tests {
                let expired_route_params = RouteParameters {
                        payment_params,
                        final_value_msat: 0,
-                       final_cltv_expiry_delta: 0,
                };
                let pending_events = Mutex::new(Vec::new());
                if on_retry {
@@ -1428,7 +1411,6 @@ mod tests {
                let route_params = RouteParameters {
                        payment_params,
                        final_value_msat: 0,
-                       final_cltv_expiry_delta: 0,
                };
                router.expect_find_route(route_params.clone(),
                        Err(LightningError { err: String::new(), action: ErrorAction::IgnoreError }));
@@ -1471,7 +1453,6 @@ mod tests {
                let route_params = RouteParameters {
                        payment_params: payment_params.clone(),
                        final_value_msat: 0,
-                       final_cltv_expiry_delta: 0,
                };
                let failed_scid = 42;
                let route = Route {
index 4aa14dfa831fd32be9fabdd665dbc0033143e403..15361b98ad71fd9ee1193891aa71f4d45ef5d9a3 100644 (file)
@@ -98,7 +98,6 @@ fn mpp_retry() {
        let mut route_params = RouteParameters {
                payment_params: route.payment_params.clone().unwrap(),
                final_value_msat: amt_msat,
-               final_cltv_expiry_delta: TEST_FINAL_CLTV,
        };
 
        nodes[0].router.expect_find_route(route_params.clone(), Ok(route.clone()));
@@ -297,7 +296,6 @@ fn do_retry_with_no_persist(confirm_before_reload: bool) {
        let route_params = RouteParameters {
                payment_params: route.payment_params.clone().unwrap(),
                final_value_msat: amt_msat,
-               final_cltv_expiry_delta: TEST_FINAL_CLTV,
        };
        nodes[0].node.send_payment_with_retry(payment_hash, &Some(payment_secret), PaymentId(payment_hash.0), route_params, Retry::Attempts(1)).unwrap();
        check_added_monitors!(nodes[0], 1);
@@ -1194,33 +1192,31 @@ fn test_trivial_inflight_htlc_tracking(){
        let (_, _, chan_2_id, _) = create_announced_chan_between_nodes(&nodes, 1, 2);
 
        // Send and claim the payment. Inflight HTLCs should be empty.
-       let (route, payment_hash, payment_preimage, payment_secret) = get_route_and_payment_hash!(nodes[0], nodes[2], 500000);
-       nodes[0].node.send_payment(&route, payment_hash, &Some(payment_secret), PaymentId(payment_hash.0)).unwrap();
-       check_added_monitors!(nodes[0], 1);
-       pass_along_route(&nodes[0], &[&vec!(&nodes[1], &nodes[2])[..]], 500000, payment_hash, payment_secret);
-       claim_payment(&nodes[0], &vec!(&nodes[1], &nodes[2])[..], payment_preimage);
+       let payment_hash = send_payment(&nodes[0], &[&nodes[1], &nodes[2]], 500000).1;
+       let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs();
        {
-               let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs();
-
                let mut node_0_per_peer_lock;
                let mut node_0_peer_state_lock;
-               let mut node_1_per_peer_lock;
-               let mut node_1_peer_state_lock;
                let channel_1 =  get_channel_ref!(&nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, chan_1_id);
-               let channel_2 =  get_channel_ref!(&nodes[1], nodes[2], node_1_per_peer_lock, node_1_peer_state_lock, chan_2_id);
 
                let chan_1_used_liquidity = inflight_htlcs.used_liquidity_msat(
                        &NodeId::from_pubkey(&nodes[0].node.get_our_node_id()) ,
                        &NodeId::from_pubkey(&nodes[1].node.get_our_node_id()),
                        channel_1.get_short_channel_id().unwrap()
                );
+               assert_eq!(chan_1_used_liquidity, None);
+       }
+       {
+               let mut node_1_per_peer_lock;
+               let mut node_1_peer_state_lock;
+               let channel_2 =  get_channel_ref!(&nodes[1], nodes[2], node_1_per_peer_lock, node_1_peer_state_lock, chan_2_id);
+
                let chan_2_used_liquidity = inflight_htlcs.used_liquidity_msat(
                        &NodeId::from_pubkey(&nodes[1].node.get_our_node_id()) ,
                        &NodeId::from_pubkey(&nodes[2].node.get_our_node_id()),
                        channel_2.get_short_channel_id().unwrap()
                );
 
-               assert_eq!(chan_1_used_liquidity, None);
                assert_eq!(chan_2_used_liquidity, None);
        }
        let pending_payments = nodes[0].node.list_recent_payments();
@@ -1233,30 +1229,32 @@ fn test_trivial_inflight_htlc_tracking(){
        }
 
        // Send the payment, but do not claim it. Our inflight HTLCs should contain the pending payment.
-       let (payment_preimage, payment_hash,  _) = route_payment(&nodes[0], &vec!(&nodes[1], &nodes[2])[..], 500000);
+       let (payment_preimage, payment_hash,  _) = route_payment(&nodes[0], &[&nodes[1], &nodes[2]], 500000);
+       let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs();
        {
-               let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs();
-
                let mut node_0_per_peer_lock;
                let mut node_0_peer_state_lock;
-               let mut node_1_per_peer_lock;
-               let mut node_1_peer_state_lock;
                let channel_1 =  get_channel_ref!(&nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, chan_1_id);
-               let channel_2 =  get_channel_ref!(&nodes[1], nodes[2], node_1_per_peer_lock, node_1_peer_state_lock, chan_2_id);
 
                let chan_1_used_liquidity = inflight_htlcs.used_liquidity_msat(
                        &NodeId::from_pubkey(&nodes[0].node.get_our_node_id()) ,
                        &NodeId::from_pubkey(&nodes[1].node.get_our_node_id()),
                        channel_1.get_short_channel_id().unwrap()
                );
+               // First hop accounts for expected 1000 msat fee
+               assert_eq!(chan_1_used_liquidity, Some(501000));
+       }
+       {
+               let mut node_1_per_peer_lock;
+               let mut node_1_peer_state_lock;
+               let channel_2 =  get_channel_ref!(&nodes[1], nodes[2], node_1_per_peer_lock, node_1_peer_state_lock, chan_2_id);
+
                let chan_2_used_liquidity = inflight_htlcs.used_liquidity_msat(
                        &NodeId::from_pubkey(&nodes[1].node.get_our_node_id()) ,
                        &NodeId::from_pubkey(&nodes[2].node.get_our_node_id()),
                        channel_2.get_short_channel_id().unwrap()
                );
 
-               // First hop accounts for expected 1000 msat fee
-               assert_eq!(chan_1_used_liquidity, Some(501000));
                assert_eq!(chan_2_used_liquidity, Some(500000));
        }
        let pending_payments = nodes[0].node.list_recent_payments();
@@ -1271,28 +1269,29 @@ fn test_trivial_inflight_htlc_tracking(){
                nodes[0].node.timer_tick_occurred();
        }
 
+       let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs();
        {
-               let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs();
-
                let mut node_0_per_peer_lock;
                let mut node_0_peer_state_lock;
-               let mut node_1_per_peer_lock;
-               let mut node_1_peer_state_lock;
                let channel_1 =  get_channel_ref!(&nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, chan_1_id);
-               let channel_2 =  get_channel_ref!(&nodes[1], nodes[2], node_1_per_peer_lock, node_1_peer_state_lock, chan_2_id);
 
                let chan_1_used_liquidity = inflight_htlcs.used_liquidity_msat(
                        &NodeId::from_pubkey(&nodes[0].node.get_our_node_id()) ,
                        &NodeId::from_pubkey(&nodes[1].node.get_our_node_id()),
                        channel_1.get_short_channel_id().unwrap()
                );
+               assert_eq!(chan_1_used_liquidity, None);
+       }
+       {
+               let mut node_1_per_peer_lock;
+               let mut node_1_peer_state_lock;
+               let channel_2 =  get_channel_ref!(&nodes[1], nodes[2], node_1_per_peer_lock, node_1_peer_state_lock, chan_2_id);
+
                let chan_2_used_liquidity = inflight_htlcs.used_liquidity_msat(
                        &NodeId::from_pubkey(&nodes[1].node.get_our_node_id()) ,
                        &NodeId::from_pubkey(&nodes[2].node.get_our_node_id()),
                        channel_2.get_short_channel_id().unwrap()
                );
-
-               assert_eq!(chan_1_used_liquidity, None);
                assert_eq!(chan_2_used_liquidity, None);
        }
 
@@ -1387,12 +1386,12 @@ fn do_test_intercepted_payment(test: InterceptTest) {
        let route_params = RouteParameters {
                payment_params,
                final_value_msat: amt_msat,
-               final_cltv_expiry_delta: TEST_FINAL_CLTV,
        };
        let route = get_route(
                &nodes[0].node.get_our_node_id(), &route_params.payment_params,
                &nodes[0].network_graph.read_only(), None, route_params.final_value_msat,
-               route_params.final_cltv_expiry_delta, nodes[0].logger, &scorer, &random_seed_bytes
+               route_params.payment_params.final_cltv_expiry_delta, nodes[0].logger, &scorer,
+               &random_seed_bytes,
        ).unwrap();
 
        let (payment_hash, payment_secret) = nodes[2].node.create_inbound_payment(Some(amt_msat), 60 * 60, None).unwrap();
@@ -1577,7 +1576,6 @@ fn do_automatic_retries(test: AutoRetry) {
        let route_params = RouteParameters {
                payment_params,
                final_value_msat: amt_msat,
-               final_cltv_expiry_delta: TEST_FINAL_CLTV,
        };
        let (_, payment_hash, payment_preimage, payment_secret) = get_route_and_payment_hash!(nodes[0], nodes[2], amt_msat);
 
@@ -1787,7 +1785,6 @@ fn auto_retry_partial_failure() {
        let route_params = RouteParameters {
                payment_params,
                final_value_msat: amt_msat,
-               final_cltv_expiry_delta: TEST_FINAL_CLTV,
        };
 
        // Ensure the first monitor update (for the initial send path1 over chan_1) succeeds, but the
@@ -1860,12 +1857,12 @@ fn auto_retry_partial_failure() {
        let mut payment_params = route_params.payment_params.clone();
        payment_params.previously_failed_channels.push(chan_2_id);
        nodes[0].router.expect_find_route(RouteParameters {
-                       payment_params, final_value_msat: amt_msat / 2, final_cltv_expiry_delta: TEST_FINAL_CLTV
+                       payment_params, final_value_msat: amt_msat / 2,
                }, Ok(retry_1_route));
        let mut payment_params = route_params.payment_params.clone();
        payment_params.previously_failed_channels.push(chan_3_id);
        nodes[0].router.expect_find_route(RouteParameters {
-                       payment_params, final_value_msat: amt_msat / 4, final_cltv_expiry_delta: TEST_FINAL_CLTV
+                       payment_params, final_value_msat: amt_msat / 4,
                }, Ok(retry_2_route));
 
        // Send a payment that will partially fail on send, then partially fail on retry, then succeed.
@@ -1999,7 +1996,6 @@ fn auto_retry_zero_attempts_send_error() {
        let route_params = RouteParameters {
                payment_params,
                final_value_msat: amt_msat,
-               final_cltv_expiry_delta: TEST_FINAL_CLTV,
        };
 
        chanmon_cfgs[0].persister.set_update_ret(ChannelMonitorUpdateStatus::PermanentFailure);
@@ -2039,7 +2035,6 @@ fn fails_paying_after_rejected_by_payee() {
        let route_params = RouteParameters {
                payment_params,
                final_value_msat: amt_msat,
-               final_cltv_expiry_delta: TEST_FINAL_CLTV,
        };
 
        nodes[0].node.send_payment_with_retry(payment_hash, &Some(payment_secret), PaymentId(payment_hash.0), route_params, Retry::Attempts(1)).unwrap();
@@ -2086,7 +2081,6 @@ fn retry_multi_path_single_failed_payment() {
        let route_params = RouteParameters {
                payment_params: payment_params.clone(),
                final_value_msat: amt_msat,
-               final_cltv_expiry_delta: TEST_FINAL_CLTV,
        };
 
        let chans = nodes[0].node.list_usable_channels();
@@ -2121,7 +2115,7 @@ fn retry_multi_path_single_failed_payment() {
                        payment_params: pay_params,
                        // Note that the second request here requests the amount we originally failed to send,
                        // not the amount remaining on the full payment, which should be changed.
-                       final_value_msat: 100_000_001, final_cltv_expiry_delta: TEST_FINAL_CLTV
+                       final_value_msat: 100_000_001,
                }, Ok(route.clone()));
 
        {
@@ -2180,7 +2174,6 @@ fn immediate_retry_on_failure() {
        let route_params = RouteParameters {
                payment_params,
                final_value_msat: amt_msat,
-               final_cltv_expiry_delta: TEST_FINAL_CLTV,
        };
 
        let chans = nodes[0].node.list_usable_channels();
@@ -2207,7 +2200,6 @@ fn immediate_retry_on_failure() {
        pay_params.previously_failed_channels.push(chans[0].short_channel_id.unwrap());
        nodes[0].router.expect_find_route(RouteParameters {
                        payment_params: pay_params, final_value_msat: amt_msat,
-                       final_cltv_expiry_delta: TEST_FINAL_CLTV
                }, Ok(route.clone()));
 
        nodes[0].node.send_payment_with_retry(payment_hash, &Some(payment_secret), PaymentId(payment_hash.0), route_params, Retry::Attempts(1)).unwrap();
@@ -2270,7 +2262,6 @@ fn no_extra_retries_on_back_to_back_fail() {
        let route_params = RouteParameters {
                payment_params,
                final_value_msat: amt_msat,
-               final_cltv_expiry_delta: TEST_FINAL_CLTV,
        };
 
        let mut route = Route {
@@ -2316,7 +2307,7 @@ fn no_extra_retries_on_back_to_back_fail() {
        route.paths[0][1].fee_msat = amt_msat;
        nodes[0].router.expect_find_route(RouteParameters {
                        payment_params: second_payment_params,
-                       final_value_msat: amt_msat, final_cltv_expiry_delta: TEST_FINAL_CLTV,
+                       final_value_msat: amt_msat,
                }, Ok(route.clone()));
 
        nodes[0].node.send_payment_with_retry(payment_hash, &Some(payment_secret), PaymentId(payment_hash.0), route_params, Retry::Attempts(1)).unwrap();
@@ -2471,7 +2462,6 @@ fn test_simple_partial_retry() {
        let route_params = RouteParameters {
                payment_params,
                final_value_msat: amt_msat,
-               final_cltv_expiry_delta: TEST_FINAL_CLTV,
        };
 
        let mut route = Route {
@@ -2516,7 +2506,7 @@ fn test_simple_partial_retry() {
        route.paths.remove(0);
        nodes[0].router.expect_find_route(RouteParameters {
                        payment_params: second_payment_params,
-                       final_value_msat: amt_msat / 2, final_cltv_expiry_delta: TEST_FINAL_CLTV,
+                       final_value_msat: amt_msat / 2,
                }, Ok(route.clone()));
 
        nodes[0].node.send_payment_with_retry(payment_hash, &Some(payment_secret), PaymentId(payment_hash.0), route_params, Retry::Attempts(1)).unwrap();
@@ -2637,7 +2627,6 @@ fn test_threaded_payment_retries() {
        let mut route_params = RouteParameters {
                payment_params,
                final_value_msat: amt_msat,
-               final_cltv_expiry_delta: TEST_FINAL_CLTV,
        };
 
        let mut route = Route {
@@ -2759,3 +2748,83 @@ fn test_threaded_payment_retries() {
                }
        }
 }
+
+fn do_no_missing_sent_on_midpoint_reload(persist_manager_with_payment: bool) {
+       // Test that if we reload in the middle of an HTLC claim commitment signed dance we'll still
+       // receive the PaymentSent event even if the ChannelManager had no idea about the payment when
+       // it was last persisted.
+       let chanmon_cfgs = create_chanmon_cfgs(2);
+       let node_cfgs = create_node_cfgs(2, &chanmon_cfgs);
+       let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
+       let (persister_a, persister_b, persister_c);
+       let (chain_monitor_a, chain_monitor_b, chain_monitor_c);
+       let (nodes_0_deserialized, nodes_0_deserialized_b, nodes_0_deserialized_c);
+       let mut nodes = create_network(2, &node_cfgs, &node_chanmgrs);
+
+       let chan_id = create_announced_chan_between_nodes(&nodes, 0, 1).2;
+
+       let mut nodes_0_serialized = Vec::new();
+       if !persist_manager_with_payment {
+               nodes_0_serialized = nodes[0].node.encode();
+       }
+
+       let (our_payment_preimage, our_payment_hash, _) = route_payment(&nodes[0], &[&nodes[1]], 1_000_000);
+
+       if persist_manager_with_payment {
+               nodes_0_serialized = nodes[0].node.encode();
+       }
+
+       nodes[1].node.claim_funds(our_payment_preimage);
+       check_added_monitors!(nodes[1], 1);
+       expect_payment_claimed!(nodes[1], our_payment_hash, 1_000_000);
+
+       let updates = get_htlc_update_msgs!(nodes[1], nodes[0].node.get_our_node_id());
+       nodes[0].node.handle_update_fulfill_htlc(&nodes[1].node.get_our_node_id(), &updates.update_fulfill_htlcs[0]);
+       nodes[0].node.handle_commitment_signed(&nodes[1].node.get_our_node_id(), &updates.commitment_signed);
+       check_added_monitors!(nodes[0], 1);
+
+       // The ChannelMonitor should always be the latest version, as we're required to persist it
+       // during the commitment signed handling.
+       let chan_0_monitor_serialized = get_monitor!(nodes[0], chan_id).encode();
+       reload_node!(nodes[0], test_default_channel_config(), &nodes_0_serialized, &[&chan_0_monitor_serialized], persister_a, chain_monitor_a, nodes_0_deserialized);
+
+       let events = nodes[0].node.get_and_clear_pending_events();
+       assert_eq!(events.len(), 2);
+       if let Event::ChannelClosed { reason: ClosureReason::OutdatedChannelManager, .. } = events[0] {} else { panic!(); }
+       if let Event::PaymentSent { payment_preimage, .. } = events[1] { assert_eq!(payment_preimage, our_payment_preimage); } else { panic!(); }
+       // Note that we don't get a PaymentPathSuccessful here as we leave the HTLC pending to avoid
+       // the double-claim that would otherwise appear at the end of this test.
+       let as_broadcasted_txn = nodes[0].tx_broadcaster.txn_broadcasted.lock().unwrap().split_off(0);
+       assert_eq!(as_broadcasted_txn.len(), 1);
+
+       // Ensure that, even after some time, if we restart we still include *something* in the current
+       // `ChannelManager` which prevents a `PaymentFailed` when we restart even if pending resolved
+       // payments have since been timed out thanks to `IDEMPOTENCY_TIMEOUT_TICKS`.
+       // A naive implementation of the fix here would wipe the pending payments set, causing a
+       // failure event when we restart.
+       for _ in 0..(IDEMPOTENCY_TIMEOUT_TICKS * 2) { nodes[0].node.timer_tick_occurred(); }
+
+       let chan_0_monitor_serialized = get_monitor!(nodes[0], chan_id).encode();
+       reload_node!(nodes[0], test_default_channel_config(), &nodes[0].node.encode(), &[&chan_0_monitor_serialized], persister_b, chain_monitor_b, nodes_0_deserialized_b);
+       let events = nodes[0].node.get_and_clear_pending_events();
+       assert!(events.is_empty());
+
+       // Ensure that we don't generate any further events even after the channel-closing commitment
+       // transaction is confirmed on-chain.
+       confirm_transaction(&nodes[0], &as_broadcasted_txn[0]);
+       for _ in 0..(IDEMPOTENCY_TIMEOUT_TICKS * 2) { nodes[0].node.timer_tick_occurred(); }
+
+       let events = nodes[0].node.get_and_clear_pending_events();
+       assert!(events.is_empty());
+
+       let chan_0_monitor_serialized = get_monitor!(nodes[0], chan_id).encode();
+       reload_node!(nodes[0], test_default_channel_config(), &nodes[0].node.encode(), &[&chan_0_monitor_serialized], persister_c, chain_monitor_c, nodes_0_deserialized_c);
+       let events = nodes[0].node.get_and_clear_pending_events();
+       assert!(events.is_empty());
+}
+
+#[test]
+fn no_missing_sent_on_midpoint_reload() {
+       do_no_missing_sent_on_midpoint_reload(false);
+       do_no_missing_sent_on_midpoint_reload(true);
+}
index e9eaf33e8840b021afe7fbcfc60afec3e0fa5b3d..41bbf5b4908bba966f83d274dce8b2f3a2bb8ebf 100644 (file)
@@ -815,34 +815,40 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                let pending_read_buffer = [0; 50].to_vec(); // Noise act two is 50 bytes
 
                let mut peers = self.peers.write().unwrap();
-               if peers.insert(descriptor, Mutex::new(Peer {
-                       channel_encryptor: peer_encryptor,
-                       their_node_id: None,
-                       their_features: None,
-                       their_net_address: remote_network_address,
-
-                       pending_outbound_buffer: LinkedList::new(),
-                       pending_outbound_buffer_first_msg_offset: 0,
-                       gossip_broadcast_buffer: LinkedList::new(),
-                       awaiting_write_event: false,
-
-                       pending_read_buffer,
-                       pending_read_buffer_pos: 0,
-                       pending_read_is_header: false,
-
-                       sync_status: InitSyncTracker::NoSyncRequested,
-
-                       msgs_sent_since_pong: 0,
-                       awaiting_pong_timer_tick_intervals: 0,
-                       received_message_since_timer_tick: false,
-                       sent_gossip_timestamp_filter: false,
-
-                       received_channel_announce_since_backlogged: false,
-                       inbound_connection: false,
-               })).is_some() {
-                       panic!("PeerManager driver duplicated descriptors!");
-               };
-               Ok(res)
+               match peers.entry(descriptor) {
+                       hash_map::Entry::Occupied(_) => {
+                               debug_assert!(false, "PeerManager driver duplicated descriptors!");
+                               Err(PeerHandleError {})
+                       },
+                       hash_map::Entry::Vacant(e) => {
+                               e.insert(Mutex::new(Peer {
+                                       channel_encryptor: peer_encryptor,
+                                       their_node_id: None,
+                                       their_features: None,
+                                       their_net_address: remote_network_address,
+
+                                       pending_outbound_buffer: LinkedList::new(),
+                                       pending_outbound_buffer_first_msg_offset: 0,
+                                       gossip_broadcast_buffer: LinkedList::new(),
+                                       awaiting_write_event: false,
+
+                                       pending_read_buffer,
+                                       pending_read_buffer_pos: 0,
+                                       pending_read_is_header: false,
+
+                                       sync_status: InitSyncTracker::NoSyncRequested,
+
+                                       msgs_sent_since_pong: 0,
+                                       awaiting_pong_timer_tick_intervals: 0,
+                                       received_message_since_timer_tick: false,
+                                       sent_gossip_timestamp_filter: false,
+
+                                       received_channel_announce_since_backlogged: false,
+                                       inbound_connection: false,
+                               }));
+                               Ok(res)
+                       }
+               }
        }
 
        /// Indicates a new inbound connection has been established to a node with an optional remote
@@ -865,34 +871,40 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                let pending_read_buffer = [0; 50].to_vec(); // Noise act one is 50 bytes
 
                let mut peers = self.peers.write().unwrap();
-               if peers.insert(descriptor, Mutex::new(Peer {
-                       channel_encryptor: peer_encryptor,
-                       their_node_id: None,
-                       their_features: None,
-                       their_net_address: remote_network_address,
-
-                       pending_outbound_buffer: LinkedList::new(),
-                       pending_outbound_buffer_first_msg_offset: 0,
-                       gossip_broadcast_buffer: LinkedList::new(),
-                       awaiting_write_event: false,
-
-                       pending_read_buffer,
-                       pending_read_buffer_pos: 0,
-                       pending_read_is_header: false,
-
-                       sync_status: InitSyncTracker::NoSyncRequested,
-
-                       msgs_sent_since_pong: 0,
-                       awaiting_pong_timer_tick_intervals: 0,
-                       received_message_since_timer_tick: false,
-                       sent_gossip_timestamp_filter: false,
-
-                       received_channel_announce_since_backlogged: false,
-                       inbound_connection: true,
-               })).is_some() {
-                       panic!("PeerManager driver duplicated descriptors!");
-               };
-               Ok(())
+               match peers.entry(descriptor) {
+                       hash_map::Entry::Occupied(_) => {
+                               debug_assert!(false, "PeerManager driver duplicated descriptors!");
+                               Err(PeerHandleError {})
+                       },
+                       hash_map::Entry::Vacant(e) => {
+                               e.insert(Mutex::new(Peer {
+                                       channel_encryptor: peer_encryptor,
+                                       their_node_id: None,
+                                       their_features: None,
+                                       their_net_address: remote_network_address,
+
+                                       pending_outbound_buffer: LinkedList::new(),
+                                       pending_outbound_buffer_first_msg_offset: 0,
+                                       gossip_broadcast_buffer: LinkedList::new(),
+                                       awaiting_write_event: false,
+
+                                       pending_read_buffer,
+                                       pending_read_buffer_pos: 0,
+                                       pending_read_is_header: false,
+
+                                       sync_status: InitSyncTracker::NoSyncRequested,
+
+                                       msgs_sent_since_pong: 0,
+                                       awaiting_pong_timer_tick_intervals: 0,
+                                       received_message_since_timer_tick: false,
+                                       sent_gossip_timestamp_filter: false,
+
+                                       received_channel_announce_since_backlogged: false,
+                                       inbound_connection: true,
+                               }));
+                               Ok(())
+                       }
+               }
        }
 
        fn peer_should_read(&self, peer: &mut Peer) -> bool {
@@ -1141,9 +1153,13 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                                macro_rules! insert_node_id {
                                                        () => {
                                                                match self.node_id_to_descriptor.lock().unwrap().entry(peer.their_node_id.unwrap().0) {
-                                                                       hash_map::Entry::Occupied(_) => {
+                                                                       hash_map::Entry::Occupied(e) => {
                                                                                log_trace!(self.logger, "Got second connection with {}, closing", log_pubkey!(peer.their_node_id.unwrap().0));
                                                                                peer.their_node_id = None; // Unset so that we don't generate a peer_disconnected event
+                                                                               // Check that the peers map is consistent with the
+                                                                               // node_id_to_descriptor map, as this has been broken
+                                                                               // before.
+                                                                               debug_assert!(peers.get(e.get()).is_some());
                                                                                return Err(PeerHandleError { })
                                                                        },
                                                                        hash_map::Entry::Vacant(entry) => {
@@ -1913,7 +1929,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                                        self.do_attempt_write_data(&mut descriptor, &mut *peer, false);
                                                }
                                                self.do_disconnect(descriptor, &*peer, "DisconnectPeer HandleError");
-                                       }
+                                       } else { debug_assert!(false, "Missing connection for peer"); }
                                }
                        }
                }
@@ -1951,11 +1967,11 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                        },
                        Some(peer_lock) => {
                                let peer = peer_lock.lock().unwrap();
-                               if !peer.handshake_complete() { return; }
-                               debug_assert!(peer.their_node_id.is_some());
                                if let Some((node_id, _)) = peer.their_node_id {
                                        log_trace!(self.logger, "Handling disconnection of peer {}", log_pubkey!(node_id));
-                                       self.node_id_to_descriptor.lock().unwrap().remove(&node_id);
+                                       let removed = self.node_id_to_descriptor.lock().unwrap().remove(&node_id);
+                                       debug_assert!(removed.is_some(), "descriptor maps should be consistent");
+                                       if !peer.handshake_complete() { return; }
                                        self.message_handler.chan_handler.peer_disconnected(&node_id);
                                        self.message_handler.onion_message_handler.peer_disconnected(&node_id);
                                }
@@ -2188,12 +2204,13 @@ mod tests {
 
        use crate::prelude::*;
        use crate::sync::{Arc, Mutex};
-       use core::sync::atomic::Ordering;
+       use core::sync::atomic::{AtomicBool, Ordering};
 
        #[derive(Clone)]
        struct FileDescriptor {
                fd: u16,
                outbound_data: Arc<Mutex<Vec<u8>>>,
+               disconnect: Arc<AtomicBool>,
        }
        impl PartialEq for FileDescriptor {
                fn eq(&self, other: &Self) -> bool {
@@ -2213,7 +2230,7 @@ mod tests {
                        data.len()
                }
 
-               fn disconnect_socket(&mut self) {}
+               fn disconnect_socket(&mut self) { self.disconnect.store(true, Ordering::Release); }
        }
 
        struct PeerManagerCfg {
@@ -2254,10 +2271,16 @@ mod tests {
 
        fn establish_connection<'a>(peer_a: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, IgnoringMessageHandler, &'a test_utils::TestLogger, IgnoringMessageHandler, &'a test_utils::TestNodeSigner>, peer_b: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, IgnoringMessageHandler, &'a test_utils::TestLogger, IgnoringMessageHandler, &'a test_utils::TestNodeSigner>) -> (FileDescriptor, FileDescriptor) {
                let id_a = peer_a.node_signer.get_node_id(Recipient::Node).unwrap();
-               let mut fd_a = FileDescriptor { fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())) };
+               let mut fd_a = FileDescriptor {
+                       fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())),
+                       disconnect: Arc::new(AtomicBool::new(false)),
+               };
                let addr_a = NetAddress::IPv4{addr: [127, 0, 0, 1], port: 1000};
                let id_b = peer_b.node_signer.get_node_id(Recipient::Node).unwrap();
-               let mut fd_b = FileDescriptor { fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())) };
+               let mut fd_b = FileDescriptor {
+                       fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())),
+                       disconnect: Arc::new(AtomicBool::new(false)),
+               };
                let addr_b = NetAddress::IPv4{addr: [127, 0, 0, 1], port: 1001};
                let initial_data = peer_b.new_outbound_connection(id_a, fd_b.clone(), Some(addr_a.clone())).unwrap();
                peer_a.new_inbound_connection(fd_a.clone(), Some(addr_b.clone())).unwrap();
@@ -2281,6 +2304,84 @@ mod tests {
                (fd_a.clone(), fd_b.clone())
        }
 
+       #[test]
+       #[cfg(feature = "std")]
+       fn fuzz_threaded_connections() {
+               // Spawn two threads which repeatedly connect two peers together, leading to "got second
+               // connection with peer" disconnections and rapid reconnect. This previously found an issue
+               // with our internal map consistency, and is a generally good smoke test of disconnection.
+               let cfgs = Arc::new(create_peermgr_cfgs(2));
+               // Until we have std::thread::scoped we have to unsafe { turn off the borrow checker }.
+               let peers = Arc::new(create_network(2, unsafe { &*(&*cfgs as *const _) as &'static _ }));
+
+               let start_time = std::time::Instant::now();
+               macro_rules! spawn_thread { ($id: expr) => { {
+                       let peers = Arc::clone(&peers);
+                       let cfgs = Arc::clone(&cfgs);
+                       std::thread::spawn(move || {
+                               let mut ctr = 0;
+                               while start_time.elapsed() < std::time::Duration::from_secs(1) {
+                                       let id_a = peers[0].node_signer.get_node_id(Recipient::Node).unwrap();
+                                       let mut fd_a = FileDescriptor {
+                                               fd: $id  + ctr * 3, outbound_data: Arc::new(Mutex::new(Vec::new())),
+                                               disconnect: Arc::new(AtomicBool::new(false)),
+                                       };
+                                       let addr_a = NetAddress::IPv4{addr: [127, 0, 0, 1], port: 1000};
+                                       let mut fd_b = FileDescriptor {
+                                               fd: $id + ctr * 3, outbound_data: Arc::new(Mutex::new(Vec::new())),
+                                               disconnect: Arc::new(AtomicBool::new(false)),
+                                       };
+                                       let addr_b = NetAddress::IPv4{addr: [127, 0, 0, 1], port: 1001};
+                                       let initial_data = peers[1].new_outbound_connection(id_a, fd_b.clone(), Some(addr_a.clone())).unwrap();
+                                       peers[0].new_inbound_connection(fd_a.clone(), Some(addr_b.clone())).unwrap();
+                                       assert_eq!(peers[0].read_event(&mut fd_a, &initial_data).unwrap(), false);
+
+                                       while start_time.elapsed() < std::time::Duration::from_secs(1) {
+                                               peers[0].process_events();
+                                               if fd_a.disconnect.load(Ordering::Acquire) { break; }
+                                               let a_data = fd_a.outbound_data.lock().unwrap().split_off(0);
+                                               if peers[1].read_event(&mut fd_b, &a_data).is_err() { break; }
+
+                                               peers[1].process_events();
+                                               if fd_b.disconnect.load(Ordering::Acquire) { break; }
+                                               let b_data = fd_b.outbound_data.lock().unwrap().split_off(0);
+                                               if peers[0].read_event(&mut fd_a, &b_data).is_err() { break; }
+
+                                               cfgs[0].chan_handler.pending_events.lock().unwrap()
+                                                       .push(crate::util::events::MessageSendEvent::SendShutdown {
+                                                               node_id: peers[1].node_signer.get_node_id(Recipient::Node).unwrap(),
+                                                               msg: msgs::Shutdown {
+                                                                       channel_id: [0; 32],
+                                                                       scriptpubkey: bitcoin::Script::new(),
+                                                               },
+                                                       });
+                                               cfgs[1].chan_handler.pending_events.lock().unwrap()
+                                                       .push(crate::util::events::MessageSendEvent::SendShutdown {
+                                                               node_id: peers[0].node_signer.get_node_id(Recipient::Node).unwrap(),
+                                                               msg: msgs::Shutdown {
+                                                                       channel_id: [0; 32],
+                                                                       scriptpubkey: bitcoin::Script::new(),
+                                                               },
+                                                       });
+
+                                               peers[0].timer_tick_occurred();
+                                               peers[1].timer_tick_occurred();
+                                       }
+
+                                       peers[0].socket_disconnected(&fd_a);
+                                       peers[1].socket_disconnected(&fd_b);
+                                       ctr += 1;
+                                       std::thread::sleep(std::time::Duration::from_micros(1));
+                               }
+                       })
+               } } }
+               let thrd_a = spawn_thread!(1);
+               let thrd_b = spawn_thread!(2);
+
+               thrd_a.join().unwrap();
+               thrd_b.join().unwrap();
+       }
+
        #[test]
        fn test_disconnect_peer() {
                // Simple test which builds a network of PeerManager, connects and brings them to NoiseState::Finished and
@@ -2337,7 +2438,10 @@ mod tests {
                let cfgs = create_peermgr_cfgs(2);
                let peers = create_network(2, &cfgs);
 
-               let mut fd_dup = FileDescriptor { fd: 3, outbound_data: Arc::new(Mutex::new(Vec::new())) };
+               let mut fd_dup = FileDescriptor {
+                       fd: 3, outbound_data: Arc::new(Mutex::new(Vec::new())),
+                       disconnect: Arc::new(AtomicBool::new(false)),
+               };
                let addr_dup = NetAddress::IPv4{addr: [127, 0, 0, 1], port: 1003};
                let id_a = cfgs[0].node_signer.get_node_id(Recipient::Node).unwrap();
                peers[0].new_inbound_connection(fd_dup.clone(), Some(addr_dup.clone())).unwrap();
@@ -2441,8 +2545,14 @@ mod tests {
                let peers = create_network(2, &cfgs);
 
                let a_id = peers[0].node_signer.get_node_id(Recipient::Node).unwrap();
-               let mut fd_a = FileDescriptor { fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())) };
-               let mut fd_b = FileDescriptor { fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())) };
+               let mut fd_a = FileDescriptor {
+                       fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())),
+                       disconnect: Arc::new(AtomicBool::new(false)),
+               };
+               let mut fd_b = FileDescriptor {
+                       fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())),
+                       disconnect: Arc::new(AtomicBool::new(false)),
+               };
                let initial_data = peers[1].new_outbound_connection(a_id, fd_b.clone(), None).unwrap();
                peers[0].new_inbound_connection(fd_a.clone(), None).unwrap();
 
index b456e15a2d13cdce9b6090889e1e886ae17a997c..9682503ee45c22d5062a099f5c1af015bbe3ffc0 100644 (file)
@@ -19,7 +19,7 @@ use crate::ln::features::{ChannelFeatures, InvoiceFeatures, NodeFeatures};
 use crate::ln::msgs::{DecodeError, ErrorAction, LightningError, MAX_VALUE_MSAT};
 use crate::routing::gossip::{DirectedChannelInfo, EffectiveCapacity, ReadOnlyNetworkGraph, NetworkGraph, NodeId, RoutingFees};
 use crate::routing::scoring::{ChannelUsage, LockableScore, Score};
-use crate::util::ser::{Writeable, Readable, Writer};
+use crate::util::ser::{Writeable, Readable, ReadableArgs, Writer};
 use crate::util::logger::{Level, Logger};
 use crate::util::chacha20::ChaCha20;
 
@@ -313,18 +313,23 @@ impl Readable for Route {
        fn read<R: io::Read>(reader: &mut R) -> Result<Route, DecodeError> {
                let _ver = read_ver_prefix!(reader, SERIALIZATION_VERSION);
                let path_count: u64 = Readable::read(reader)?;
+               if path_count == 0 { return Err(DecodeError::InvalidValue); }
                let mut paths = Vec::with_capacity(cmp::min(path_count, 128) as usize);
+               let mut min_final_cltv_expiry_delta = u32::max_value();
                for _ in 0..path_count {
                        let hop_count: u8 = Readable::read(reader)?;
-                       let mut hops = Vec::with_capacity(hop_count as usize);
+                       let mut hops: Vec<RouteHop> = Vec::with_capacity(hop_count as usize);
                        for _ in 0..hop_count {
                                hops.push(Readable::read(reader)?);
                        }
+                       if hops.is_empty() { return Err(DecodeError::InvalidValue); }
+                       min_final_cltv_expiry_delta =
+                               cmp::min(min_final_cltv_expiry_delta, hops.last().unwrap().cltv_expiry_delta);
                        paths.push(hops);
                }
                let mut payment_params = None;
                read_tlv_fields!(reader, {
-                       (1, payment_params, option),
+                       (1, payment_params, (option: ReadableArgs, min_final_cltv_expiry_delta)),
                });
                Ok(Route { paths, payment_params })
        }
@@ -343,19 +348,38 @@ pub struct RouteParameters {
 
        /// The amount in msats sent on the failed payment path.
        pub final_value_msat: u64,
+}
 
-       /// The CLTV on the final hop of the failed payment path.
-       ///
-       /// This field is deprecated, [`PaymentParameters::final_cltv_expiry_delta`] should be used
-       /// instead, if available.
-       pub final_cltv_expiry_delta: u32,
+impl Writeable for RouteParameters {
+       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
+               write_tlv_fields!(writer, {
+                       (0, self.payment_params, required),
+                       (2, self.final_value_msat, required),
+                       // LDK versions prior to 0.0.114 had the `final_cltv_expiry_delta` parameter in
+                       // `RouteParameters` directly. For compatibility, we write it here.
+                       (4, self.payment_params.final_cltv_expiry_delta, required),
+               });
+               Ok(())
+       }
 }
 
-impl_writeable_tlv_based!(RouteParameters, {
-       (0, payment_params, required),
-       (2, final_value_msat, required),
-       (4, final_cltv_expiry_delta, required),
-});
+impl Readable for RouteParameters {
+       fn read<R: io::Read>(reader: &mut R) -> Result<Self, DecodeError> {
+               _init_and_read_tlv_fields!(reader, {
+                       (0, payment_params, (required: ReadableArgs, 0)),
+                       (2, final_value_msat, required),
+                       (4, final_cltv_expiry_delta, required),
+               });
+               let mut payment_params: PaymentParameters = payment_params.0.unwrap();
+               if payment_params.final_cltv_expiry_delta == 0 {
+                       payment_params.final_cltv_expiry_delta = final_cltv_expiry_delta.0.unwrap();
+               }
+               Ok(Self {
+                       payment_params,
+                       final_value_msat: final_value_msat.0.unwrap(),
+               })
+       }
+}
 
 /// Maximum total CTLV difference we allow for a full payment path.
 pub const DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA: u32 = 1008;
@@ -429,23 +453,54 @@ pub struct PaymentParameters {
        /// these SCIDs.
        pub previously_failed_channels: Vec<u64>,
 
-       /// The minimum CLTV delta at the end of the route.
-       ///
-       /// This field should always be set to `Some` and may be required in a future release.
-       pub final_cltv_expiry_delta: Option<u32>,
+       /// The minimum CLTV delta at the end of the route. This value must not be zero.
+       pub final_cltv_expiry_delta: u32,
+}
+
+impl Writeable for PaymentParameters {
+       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
+               write_tlv_fields!(writer, {
+                       (0, self.payee_pubkey, required),
+                       (1, self.max_total_cltv_expiry_delta, required),
+                       (2, self.features, option),
+                       (3, self.max_path_count, required),
+                       (4, self.route_hints, vec_type),
+                       (5, self.max_channel_saturation_power_of_half, required),
+                       (6, self.expiry_time, option),
+                       (7, self.previously_failed_channels, vec_type),
+                       (9, self.final_cltv_expiry_delta, required),
+               });
+               Ok(())
+       }
+}
+
+impl ReadableArgs<u32> for PaymentParameters {
+       fn read<R: io::Read>(reader: &mut R, default_final_cltv_expiry_delta: u32) -> Result<Self, DecodeError> {
+               _init_and_read_tlv_fields!(reader, {
+                       (0, payee_pubkey, required),
+                       (1, max_total_cltv_expiry_delta, (default_value, DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA)),
+                       (2, features, option),
+                       (3, max_path_count, (default_value, DEFAULT_MAX_PATH_COUNT)),
+                       (4, route_hints, vec_type),
+                       (5, max_channel_saturation_power_of_half, (default_value, 2)),
+                       (6, expiry_time, option),
+                       (7, previously_failed_channels, vec_type),
+                       (9, final_cltv_expiry_delta, (default_value, default_final_cltv_expiry_delta)),
+               });
+               Ok(Self {
+                       payee_pubkey: _init_tlv_based_struct_field!(payee_pubkey, required),
+                       max_total_cltv_expiry_delta: _init_tlv_based_struct_field!(max_total_cltv_expiry_delta, (default_value, unused)),
+                       features,
+                       max_path_count: _init_tlv_based_struct_field!(max_path_count, (default_value, unused)),
+                       route_hints: route_hints.unwrap_or(Vec::new()),
+                       max_channel_saturation_power_of_half: _init_tlv_based_struct_field!(max_channel_saturation_power_of_half, (default_value, unused)),
+                       expiry_time,
+                       previously_failed_channels: previously_failed_channels.unwrap_or(Vec::new()),
+                       final_cltv_expiry_delta: _init_tlv_based_struct_field!(final_cltv_expiry_delta, (default_value, unused)),
+               })
+       }
 }
 
-impl_writeable_tlv_based!(PaymentParameters, {
-       (0, payee_pubkey, required),
-       (1, max_total_cltv_expiry_delta, (default_value, DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA)),
-       (2, features, option),
-       (3, max_path_count, (default_value, DEFAULT_MAX_PATH_COUNT)),
-       (4, route_hints, vec_type),
-       (5, max_channel_saturation_power_of_half, (default_value, 2)),
-       (6, expiry_time, option),
-       (7, previously_failed_channels, vec_type),
-       (9, final_cltv_expiry_delta, option),
-});
 
 impl PaymentParameters {
        /// Creates a payee with the node id of the given `pubkey`.
@@ -462,7 +517,7 @@ impl PaymentParameters {
                        max_path_count: DEFAULT_MAX_PATH_COUNT,
                        max_channel_saturation_power_of_half: 2,
                        previously_failed_channels: Vec::new(),
-                       final_cltv_expiry_delta: Some(final_cltv_expiry_delta),
+                       final_cltv_expiry_delta,
                }
        }
 
@@ -937,9 +992,7 @@ pub fn find_route<L: Deref, GL: Deref, S: Score>(
 ) -> Result<Route, LightningError>
 where L::Target: Logger, GL::Target: Logger {
        let graph_lock = network_graph.read_only();
-       let final_cltv_expiry_delta =
-               if let Some(delta) = route_params.payment_params.final_cltv_expiry_delta { delta }
-               else { route_params.final_cltv_expiry_delta };
+       let final_cltv_expiry_delta = route_params.payment_params.final_cltv_expiry_delta;
        let mut route = get_route(our_node_pubkey, &route_params.payment_params, &graph_lock, first_hops,
                route_params.final_value_msat, final_cltv_expiry_delta, logger, scorer,
                random_seed_bytes)?;
@@ -978,9 +1031,9 @@ where L::Target: Logger {
        if payment_params.max_total_cltv_expiry_delta <= final_cltv_expiry_delta {
                return Err(LightningError{err: "Can't find a route where the maximum total CLTV expiry delta is below the final CLTV expiry.".to_owned(), action: ErrorAction::IgnoreError});
        }
-       if let Some(delta) = payment_params.final_cltv_expiry_delta {
-               debug_assert_eq!(delta, final_cltv_expiry_delta);
-       }
+
+       // TODO: Remove the explicit final_cltv_expiry_delta parameter
+       debug_assert_eq!(final_cltv_expiry_delta, payment_params.final_cltv_expiry_delta);
 
        // The general routing idea is the following:
        // 1. Fill first/last hops communicated by the caller.
@@ -2019,7 +2072,8 @@ where L::Target: Logger, GL::Target: Logger {
        let graph_lock = network_graph.read_only();
        let mut route = build_route_from_hops_internal(
                our_node_pubkey, hops, &route_params.payment_params, &graph_lock,
-               route_params.final_value_msat, route_params.final_cltv_expiry_delta, logger, random_seed_bytes)?;
+               route_params.final_value_msat, route_params.payment_params.final_cltv_expiry_delta,
+               logger, random_seed_bytes)?;
        add_random_cltv_offset(&mut route, &route_params.payment_params, &graph_lock, random_seed_bytes);
        Ok(route)
 }
index 0e5bd8d64f78a48a9ec5b808aa200a022167edb6..74abd4276432b0d0d048d6791fe0ffb648b4f54d 100644 (file)
@@ -26,7 +26,7 @@ use crate::util::ser::Writeable;
 use crate::prelude::*;
 
 use alloc::sync::{Arc, Weak};
-use crate::sync::Mutex;
+use crate::sync::{Mutex, LockTestExt};
 use core::ops::Deref;
 
 /// An error when accessing the chain via [`UtxoLookup`].
@@ -404,7 +404,10 @@ impl PendingChecks {
                                // lookup if we haven't gotten that far yet).
                                match Weak::upgrade(&e.get()) {
                                        Some(pending_msgs) => {
-                                               let pending_matches = match &pending_msgs.lock().unwrap().channel_announce {
+                                               // This may be called with the mutex held on a different UtxoMessages
+                                               // struct, however in that case we have a global lockorder of new messages
+                                               // -> old messages, which makes this safe.
+                                               let pending_matches = match &pending_msgs.unsafe_well_ordered_double_lock_self().channel_announce {
                                                        Some(ChannelAnnouncement::Full(pending_msg)) => Some(pending_msg) == full_msg,
                                                        Some(ChannelAnnouncement::Unsigned(pending_msg)) => pending_msg == msg,
                                                        None => {
@@ -745,10 +748,11 @@ mod tests {
                        Ok(TxOut { value: 1_000_000, script_pubkey: good_script }));
 
                assert_eq!(chan_update_a.contents.timestamp, chan_update_b.contents.timestamp);
-               assert!(network_graph.read_only().channels()
+               let graph_lock = network_graph.read_only();
+               assert!(graph_lock.channels()
                                .get(&valid_announcement.contents.short_channel_id).as_ref().unwrap()
                                .one_to_two.as_ref().unwrap().last_update !=
-                       network_graph.read_only().channels()
+                       graph_lock.channels()
                                .get(&valid_announcement.contents.short_channel_id).as_ref().unwrap()
                                .two_to_one.as_ref().unwrap().last_update);
        }
index 5631093723733f16f4d7447c0865f58219d1cf40..dd106a9e897e0125b3daf4913d3f65f15c350762 100644 (file)
@@ -75,7 +75,7 @@ struct LockDep {
 }
 
 #[cfg(feature = "backtrace")]
-fn get_construction_location(backtrace: &Backtrace) -> String {
+fn get_construction_location(backtrace: &Backtrace) -> (String, Option<u32>) {
        // Find the first frame that is after `debug_sync` (or that is in our tests) and use
        // that as the mutex construction site. Note that the first few frames may be in
        // the `backtrace` crate, so we have to ignore those.
@@ -86,13 +86,7 @@ fn get_construction_location(backtrace: &Backtrace) -> String {
                        let symbol_name = symbol.name().unwrap().as_str().unwrap();
                        if !sync_mutex_constr_regex.is_match(symbol_name) {
                                if found_debug_sync {
-                                       if let Some(col) = symbol.colno() {
-                                               return format!("{}:{}:{}", symbol.filename().unwrap().display(), symbol.lineno().unwrap(), col);
-                                       } else {
-                                               // Windows debug symbols don't support column numbers, so fall back to
-                                               // line numbers only if no `colno` is available
-                                               return format!("{}:{}", symbol.filename().unwrap().display(), symbol.lineno().unwrap());
-                                       }
+                                       return (format!("{}:{}", symbol.filename().unwrap().display(), symbol.lineno().unwrap()), symbol.colno());
                                }
                        } else { found_debug_sync = true; }
                }
@@ -113,42 +107,51 @@ impl LockMetadata {
 
                #[cfg(feature = "backtrace")]
                {
-                       let lock_constr_location = get_construction_location(&res._lock_construction_bt);
+                       let (lock_constr_location, lock_constr_colno) =
+                               get_construction_location(&res._lock_construction_bt);
                        LOCKS_INIT.call_once(|| { unsafe { LOCKS = Some(StdMutex::new(HashMap::new())); } });
                        let mut locks = unsafe { LOCKS.as_ref() }.unwrap().lock().unwrap();
                        match locks.entry(lock_constr_location) {
-                               hash_map::Entry::Occupied(e) => return Arc::clone(e.get()),
+                               hash_map::Entry::Occupied(e) => {
+                                       assert_eq!(lock_constr_colno,
+                                               get_construction_location(&e.get()._lock_construction_bt).1,
+                                               "Because Windows doesn't support column number results in backtraces, we cannot construct two mutexes on the same line or we risk lockorder detection false positives.");
+                                       return Arc::clone(e.get())
+                               },
                                hash_map::Entry::Vacant(e) => { e.insert(Arc::clone(&res)); },
                        }
                }
                res
        }
 
-       // Returns whether we were a recursive lock (only relevant for read)
-       fn _pre_lock(this: &Arc<LockMetadata>, read: bool) -> bool {
-               let mut inserted = false;
+       fn pre_lock(this: &Arc<LockMetadata>, _double_lock_self_allowed: bool) {
                LOCKS_HELD.with(|held| {
                        // For each lock which is currently locked, check that no lock's locked-before
                        // set includes the lock we're about to lock, which would imply a lockorder
                        // inversion.
-                       for (locked_idx, _locked) in held.borrow().iter() {
-                               if read && *locked_idx == this.lock_idx {
-                                       // Recursive read locks are explicitly allowed
-                                       return;
+                       for (locked_idx, locked) in held.borrow().iter() {
+                               if *locked_idx == this.lock_idx {
+                                       // Note that with `feature = "backtrace"` set, we may be looking at different
+                                       // instances of the same lock. Still, doing so is quite risky, a total order
+                                       // must be maintained, and doing so across a set of otherwise-identical mutexes
+                                       // is fraught with issues.
+                                       #[cfg(feature = "backtrace")]
+                                       debug_assert!(_double_lock_self_allowed,
+                                               "Tried to acquire a lock while it was held!\nLock constructed at {}",
+                                               get_construction_location(&this._lock_construction_bt).0);
+                                       #[cfg(not(feature = "backtrace"))]
+                                       panic!("Tried to acquire a lock while it was held!");
                                }
                        }
                        for (locked_idx, locked) in held.borrow().iter() {
-                               if !read && *locked_idx == this.lock_idx {
-                                       // With `feature = "backtrace"` set, we may be looking at different instances
-                                       // of the same lock.
-                                       debug_assert!(cfg!(feature = "backtrace"), "Tried to acquire a lock while it was held!");
-                               }
                                for (locked_dep_idx, _locked_dep) in locked.locked_before.lock().unwrap().iter() {
                                        if *locked_dep_idx == this.lock_idx && *locked_dep_idx != locked.lock_idx {
                                                #[cfg(feature = "backtrace")]
                                                panic!("Tried to violate existing lockorder.\nMutex that should be locked after the current lock was created at the following backtrace.\nNote that to get a backtrace for the lockorder violation, you should set RUST_BACKTRACE=1\nLock being taken constructed at: {} ({}):\n{:?}\nLock constructed at: {} ({})\n{:?}\n\nLock dep created at:\n{:?}\n\n",
-                                                       get_construction_location(&this._lock_construction_bt), this.lock_idx, this._lock_construction_bt,
-                                                       get_construction_location(&locked._lock_construction_bt), locked.lock_idx, locked._lock_construction_bt,
+                                                       get_construction_location(&this._lock_construction_bt).0,
+                                                       this.lock_idx, this._lock_construction_bt,
+                                                       get_construction_location(&locked._lock_construction_bt).0,
+                                                       locked.lock_idx, locked._lock_construction_bt,
                                                        _locked_dep._lockdep_trace);
                                                #[cfg(not(feature = "backtrace"))]
                                                panic!("Tried to violate existing lockorder. Build with the backtrace feature for more info.");
@@ -162,14 +165,9 @@ impl LockMetadata {
                                }
                        }
                        held.borrow_mut().insert(this.lock_idx, Arc::clone(this));
-                       inserted = true;
                });
-               inserted
        }
 
-       fn pre_lock(this: &Arc<LockMetadata>) { Self::_pre_lock(this, false); }
-       fn pre_read_lock(this: &Arc<LockMetadata>) -> bool { Self::_pre_lock(this, true) }
-
        fn held_by_thread(this: &Arc<LockMetadata>) -> LockHeldState {
                let mut res = LockHeldState::NotHeldByThread;
                LOCKS_HELD.with(|held| {
@@ -203,6 +201,11 @@ pub struct Mutex<T: Sized> {
        inner: StdMutex<T>,
        deps: Arc<LockMetadata>,
 }
+impl<T: Sized> Mutex<T> {
+       pub(crate) fn into_inner(self) -> LockResult<T> {
+               self.inner.into_inner().map_err(|_| ())
+       }
+}
 
 #[must_use = "if unused the Mutex will immediately unlock"]
 pub struct MutexGuard<'a, T: Sized + 'a> {
@@ -249,7 +252,7 @@ impl<T> Mutex<T> {
        }
 
        pub fn lock<'a>(&'a self) -> LockResult<MutexGuard<'a, T>> {
-               LockMetadata::pre_lock(&self.deps);
+               LockMetadata::pre_lock(&self.deps, false);
                self.inner.lock().map(|lock| MutexGuard { mutex: self, lock }).map_err(|_| ())
        }
 
@@ -262,11 +265,17 @@ impl<T> Mutex<T> {
        }
 }
 
-impl <T> LockTestExt for Mutex<T> {
+impl<'a, T: 'a> LockTestExt<'a> for Mutex<T> {
        #[inline]
        fn held_by_thread(&self) -> LockHeldState {
                LockMetadata::held_by_thread(&self.deps)
        }
+       type ExclLock = MutexGuard<'a, T>;
+       #[inline]
+       fn unsafe_well_ordered_double_lock_self(&'a self) -> MutexGuard<T> {
+               LockMetadata::pre_lock(&self.deps, true);
+               self.inner.lock().map(|lock| MutexGuard { mutex: self, lock }).unwrap()
+       }
 }
 
 pub struct RwLock<T: Sized> {
@@ -276,7 +285,6 @@ pub struct RwLock<T: Sized> {
 
 pub struct RwLockReadGuard<'a, T: Sized + 'a> {
        lock: &'a RwLock<T>,
-       first_lock: bool,
        guard: StdRwLockReadGuard<'a, T>,
 }
 
@@ -295,12 +303,6 @@ impl<T: Sized> Deref for RwLockReadGuard<'_, T> {
 
 impl<T: Sized> Drop for RwLockReadGuard<'_, T> {
        fn drop(&mut self) {
-               if !self.first_lock {
-                       // Note that its not strictly true that the first taken read lock will get unlocked
-                       // last, but in practice our locks are always taken as RAII, so it should basically
-                       // always be true.
-                       return;
-               }
                LOCKS_HELD.with(|held| {
                        held.borrow_mut().remove(&self.lock.deps.lock_idx);
                });
@@ -335,12 +337,16 @@ impl<T> RwLock<T> {
        }
 
        pub fn read<'a>(&'a self) -> LockResult<RwLockReadGuard<'a, T>> {
-               let first_lock = LockMetadata::pre_read_lock(&self.deps);
-               self.inner.read().map(|guard| RwLockReadGuard { lock: self, guard, first_lock }).map_err(|_| ())
+               // Note that while we could be taking a recursive read lock here, Rust's `RwLock` may
+               // deadlock trying to take a second read lock if another thread is waiting on the write
+               // lock. This behavior is platform dependent, but our in-tree `FairRwLock` guarantees
+               // such a deadlock.
+               LockMetadata::pre_lock(&self.deps, false);
+               self.inner.read().map(|guard| RwLockReadGuard { lock: self, guard }).map_err(|_| ())
        }
 
        pub fn write<'a>(&'a self) -> LockResult<RwLockWriteGuard<'a, T>> {
-               LockMetadata::pre_lock(&self.deps);
+               LockMetadata::pre_lock(&self.deps, false);
                self.inner.write().map(|guard| RwLockWriteGuard { lock: self, guard }).map_err(|_| ())
        }
 
@@ -353,11 +359,17 @@ impl<T> RwLock<T> {
        }
 }
 
-impl <T> LockTestExt for RwLock<T> {
+impl<'a, T: 'a> LockTestExt<'a> for RwLock<T> {
        #[inline]
        fn held_by_thread(&self) -> LockHeldState {
                LockMetadata::held_by_thread(&self.deps)
        }
+       type ExclLock = RwLockWriteGuard<'a, T>;
+       #[inline]
+       fn unsafe_well_ordered_double_lock_self(&'a self) -> RwLockWriteGuard<'a, T> {
+               LockMetadata::pre_lock(&self.deps, true);
+               self.inner.write().map(|guard| RwLockWriteGuard { lock: self, guard }).unwrap()
+       }
 }
 
 pub type FairRwLock<T> = RwLock<T>;
index a9519ac240cde1e24fb3dd09e37e36eddc552216..de609d5b3d711059568daca1e9d408be80891321 100644 (file)
@@ -50,10 +50,15 @@ impl<T> FairRwLock<T> {
        }
 }
 
-impl<T> LockTestExt for FairRwLock<T> {
+impl<'a, T: 'a> LockTestExt<'a> for FairRwLock<T> {
        #[inline]
        fn held_by_thread(&self) -> LockHeldState {
                // fairrwlock is only built in non-test modes, so we should never support tests.
                LockHeldState::Unsupported
        }
+       type ExclLock = RwLockWriteGuard<'a, T>;
+       #[inline]
+       fn unsafe_well_ordered_double_lock_self(&'a self) -> RwLockWriteGuard<'a, T> {
+               self.write().unwrap()
+       }
 }
index 50ef40e295f50d0a0d9095a0f3dfe4c73916a340..1b2b9a739b8c5d3078204917f55b0fef86e87f13 100644 (file)
@@ -7,8 +7,17 @@ pub(crate) enum LockHeldState {
        Unsupported,
 }
 
-pub(crate) trait LockTestExt {
+pub(crate) trait LockTestExt<'a> {
        fn held_by_thread(&self) -> LockHeldState;
+       type ExclLock;
+       /// If two instances of the same mutex are being taken at the same time, it's very easy to have
+       /// a lockorder inversion and risk deadlock. Thus, we default to disabling such locks.
+       ///
+       /// However, sometimes they cannot be avoided. In such cases, this method exists to take a
+       /// mutex while avoiding a test failure. It is deliberately verbose and includes the term
+       /// "unsafe" to indicate that special care needs to be taken to ensure no deadlocks are
+       /// possible.
+       fn unsafe_well_ordered_double_lock_self(&'a self) -> Self::ExclLock;
 }
 
 #[cfg(all(feature = "std", not(feature = "_bench_unstable"), test))]
@@ -27,13 +36,19 @@ pub use {std::sync::{Arc, Mutex, Condvar, MutexGuard, RwLock, RwLockReadGuard, R
 #[cfg(all(feature = "std", any(feature = "_bench_unstable", not(test))))]
 mod ext_impl {
        use super::*;
-       impl<T> LockTestExt for Mutex<T> {
+       impl<'a, T: 'a> LockTestExt<'a> for Mutex<T> {
                #[inline]
                fn held_by_thread(&self) -> LockHeldState { LockHeldState::Unsupported }
+               type ExclLock = MutexGuard<'a, T>;
+               #[inline]
+               fn unsafe_well_ordered_double_lock_self(&'a self) -> MutexGuard<T> { self.lock().unwrap() }
        }
-       impl<T> LockTestExt for RwLock<T> {
+       impl<'a, T: 'a> LockTestExt<'a> for RwLock<T> {
                #[inline]
                fn held_by_thread(&self) -> LockHeldState { LockHeldState::Unsupported }
+               type ExclLock = RwLockWriteGuard<'a, T>;
+               #[inline]
+               fn unsafe_well_ordered_double_lock_self(&'a self) -> RwLockWriteGuard<T> { self.write().unwrap() }
        }
 }
 
index e17aa6ab15faa5dc33b66dbc1b80ab79fd36cb18..17307997d8176cc2b82a0d9559307971f3e2c252 100644 (file)
@@ -60,14 +60,21 @@ impl<T> Mutex<T> {
        pub fn try_lock<'a>(&'a self) -> LockResult<MutexGuard<'a, T>> {
                Ok(MutexGuard { lock: self.inner.borrow_mut() })
        }
+
+       pub fn into_inner(self) -> LockResult<T> {
+               Ok(self.inner.into_inner())
+       }
 }
 
-impl<T> LockTestExt for Mutex<T> {
+impl<'a, T: 'a> LockTestExt<'a> for Mutex<T> {
        #[inline]
        fn held_by_thread(&self) -> LockHeldState {
                if self.lock().is_err() { return LockHeldState::HeldByThread; }
                else { return LockHeldState::NotHeldByThread; }
        }
+       type ExclLock = MutexGuard<'a, T>;
+       #[inline]
+       fn unsafe_well_ordered_double_lock_self(&'a self) -> MutexGuard<T> { self.lock().unwrap() }
 }
 
 pub struct RwLock<T: ?Sized> {
@@ -125,12 +132,15 @@ impl<T> RwLock<T> {
        }
 }
 
-impl<T> LockTestExt for RwLock<T> {
+impl<'a, T: 'a> LockTestExt<'a> for RwLock<T> {
        #[inline]
        fn held_by_thread(&self) -> LockHeldState {
                if self.write().is_err() { return LockHeldState::HeldByThread; }
                else { return LockHeldState::NotHeldByThread; }
        }
+       type ExclLock = RwLockWriteGuard<'a, T>;
+       #[inline]
+       fn unsafe_well_ordered_double_lock_self(&'a self) -> RwLockWriteGuard<T> { self.write().unwrap() }
 }
 
 pub type FairRwLock<T> = RwLock<T>;
index a3f746b11dc80dd32bd7badf5f9786acf6426a89..6d72410bd596341fd5246760f03046884b93e3b5 100644 (file)
@@ -15,6 +15,8 @@ fn recursive_lock_fail() {
 }
 
 #[test]
+#[should_panic]
+#[cfg(not(feature = "backtrace"))]
 fn recursive_read() {
        let lock = RwLock::new(());
        let _a = lock.read().unwrap();
@@ -66,23 +68,6 @@ fn read_lockorder_fail() {
        }
 }
 
-#[test]
-fn read_recursive_no_lockorder() {
-       // Like the above, but note that no lockorder is implied when we recursively read-lock a
-       // RwLock, causing this to pass just fine.
-       let a = RwLock::new(());
-       let b = RwLock::new(());
-       let _outer = a.read().unwrap();
-       {
-               let _a = a.read().unwrap();
-               let _b = b.read().unwrap();
-       }
-       {
-               let _b = b.read().unwrap();
-               let _a = a.read().unwrap();
-       }
-}
-
 #[test]
 #[should_panic]
 fn read_write_lockorder_fail() {
index 3d45172517db3e706dcbf4884dd75923c872c8a5..2b5bbac0ddc583e188a9b36b2833cbb42bb74eea 100644 (file)
@@ -21,6 +21,8 @@ use core::ops::{Bound, RangeBounds};
 /// actually backed by a [`HashMap`], with some additional tracking to ensure we can iterate over
 /// keys in the order defined by [`Ord`].
 ///
+/// (C-not exported) as bindings provide alternate accessors rather than exposing maps directly.
+///
 /// [`BTreeMap`]: alloc::collections::BTreeMap
 #[derive(Clone, Debug, Eq)]
 pub struct IndexedMap<K: Hash + Ord, V> {
@@ -147,6 +149,8 @@ impl<K: Hash + Ord + PartialEq, V: PartialEq> PartialEq for IndexedMap<K, V> {
 }
 
 /// An iterator over a range of values in an [`IndexedMap`]
+///
+/// (C-not exported) as bindings provide alternate accessors rather than exposing maps directly.
 pub struct Range<'a, K: Hash + Ord, V> {
        inner_range: Iter<'a, K>,
        map: &'a HashMap<K, V>,
@@ -161,6 +165,8 @@ impl<'a, K: Hash + Ord, V: 'a> Iterator for Range<'a, K, V> {
 }
 
 /// An [`Entry`] for a key which currently has no value
+///
+/// (C-not exported) as bindings provide alternate accessors rather than exposing maps directly.
 pub struct VacantEntry<'a, K: Hash + Ord, V> {
        #[cfg(feature = "hashbrown")]
        underlying_entry: hash_map::VacantEntry<'a, K, V, hash_map::DefaultHashBuilder>,
@@ -171,6 +177,8 @@ pub struct VacantEntry<'a, K: Hash + Ord, V> {
 }
 
 /// An [`Entry`] for an existing key-value pair
+///
+/// (C-not exported) as bindings provide alternate accessors rather than exposing maps directly.
 pub struct OccupiedEntry<'a, K: Hash + Ord, V> {
        #[cfg(feature = "hashbrown")]
        underlying_entry: hash_map::OccupiedEntry<'a, K, V, hash_map::DefaultHashBuilder>,
@@ -181,6 +189,8 @@ pub struct OccupiedEntry<'a, K: Hash + Ord, V> {
 
 /// A mutable reference to a position in the map. This can be used to reference, add, or update the
 /// value at a fixed key.
+///
+/// (C-not exported) as bindings provide alternate accessors rather than exposing maps directly.
 pub enum Entry<'a, K: Hash + Ord, V> {
        /// A mutable reference to a position within the map where there is no value.
        Vacant(VacantEntry<'a, K, V>),
index e83e6e2ee48ea27d40327c4e56cb50e2827e94fc..6e98272f3171d2bc47082c4171b6e4013800e911 100644 (file)
@@ -167,17 +167,17 @@ macro_rules! log_given_level {
        ($logger: expr, $lvl:expr, $($arg:tt)+) => (
                match $lvl {
                        #[cfg(not(any(feature = "max_level_off")))]
-                       $crate::util::logger::Level::Error => log_internal!($logger, $lvl, $($arg)*),
+                       $crate::util::logger::Level::Error => $crate::log_internal!($logger, $lvl, $($arg)*),
                        #[cfg(not(any(feature = "max_level_off", feature = "max_level_error")))]
-                       $crate::util::logger::Level::Warn => log_internal!($logger, $lvl, $($arg)*),
+                       $crate::util::logger::Level::Warn => $crate::log_internal!($logger, $lvl, $($arg)*),
                        #[cfg(not(any(feature = "max_level_off", feature = "max_level_error", feature = "max_level_warn")))]
-                       $crate::util::logger::Level::Info => log_internal!($logger, $lvl, $($arg)*),
+                       $crate::util::logger::Level::Info => $crate::log_internal!($logger, $lvl, $($arg)*),
                        #[cfg(not(any(feature = "max_level_off", feature = "max_level_error", feature = "max_level_warn", feature = "max_level_info")))]
-                       $crate::util::logger::Level::Debug => log_internal!($logger, $lvl, $($arg)*),
+                       $crate::util::logger::Level::Debug => $crate::log_internal!($logger, $lvl, $($arg)*),
                        #[cfg(not(any(feature = "max_level_off", feature = "max_level_error", feature = "max_level_warn", feature = "max_level_info", feature = "max_level_debug")))]
-                       $crate::util::logger::Level::Trace => log_internal!($logger, $lvl, $($arg)*),
+                       $crate::util::logger::Level::Trace => $crate::log_internal!($logger, $lvl, $($arg)*),
                        #[cfg(not(any(feature = "max_level_off", feature = "max_level_error", feature = "max_level_warn", feature = "max_level_info", feature = "max_level_debug", feature = "max_level_trace")))]
-                       $crate::util::logger::Level::Gossip => log_internal!($logger, $lvl, $($arg)*),
+                       $crate::util::logger::Level::Gossip => $crate::log_internal!($logger, $lvl, $($arg)*),
 
                        #[cfg(any(feature = "max_level_off", feature = "max_level_error", feature = "max_level_warn", feature = "max_level_info", feature = "max_level_debug", feature = "max_level_trace"))]
                        _ => {
@@ -191,7 +191,7 @@ macro_rules! log_given_level {
 #[macro_export]
 macro_rules! log_error {
        ($logger: expr, $($arg:tt)*) => (
-               log_given_level!($logger, $crate::util::logger::Level::Error, $($arg)*);
+               $crate::log_given_level!($logger, $crate::util::logger::Level::Error, $($arg)*);
        )
 }
 
@@ -199,7 +199,7 @@ macro_rules! log_error {
 #[macro_export]
 macro_rules! log_warn {
        ($logger: expr, $($arg:tt)*) => (
-               log_given_level!($logger, $crate::util::logger::Level::Warn, $($arg)*);
+               $crate::log_given_level!($logger, $crate::util::logger::Level::Warn, $($arg)*);
        )
 }
 
@@ -207,7 +207,7 @@ macro_rules! log_warn {
 #[macro_export]
 macro_rules! log_info {
        ($logger: expr, $($arg:tt)*) => (
-               log_given_level!($logger, $crate::util::logger::Level::Info, $($arg)*);
+               $crate::log_given_level!($logger, $crate::util::logger::Level::Info, $($arg)*);
        )
 }
 
@@ -215,7 +215,7 @@ macro_rules! log_info {
 #[macro_export]
 macro_rules! log_debug {
        ($logger: expr, $($arg:tt)*) => (
-               log_given_level!($logger, $crate::util::logger::Level::Debug, $($arg)*);
+               $crate::log_given_level!($logger, $crate::util::logger::Level::Debug, $($arg)*);
        )
 }
 
@@ -223,7 +223,7 @@ macro_rules! log_debug {
 #[macro_export]
 macro_rules! log_trace {
        ($logger: expr, $($arg:tt)*) => (
-               log_given_level!($logger, $crate::util::logger::Level::Trace, $($arg)*)
+               $crate::log_given_level!($logger, $crate::util::logger::Level::Trace, $($arg)*)
        )
 }
 
@@ -231,6 +231,6 @@ macro_rules! log_trace {
 #[macro_export]
 macro_rules! log_gossip {
        ($logger: expr, $($arg:tt)*) => (
-               log_given_level!($logger, $crate::util::logger::Level::Gossip, $($arg)*);
+               $crate::log_given_level!($logger, $crate::util::logger::Level::Gossip, $($arg)*);
        )
 }
index 5adf5758131bb65e9929042a267162e3665c58e2..bef192585b5e486696068174b92539c0faa895ff 100644 (file)
@@ -89,6 +89,8 @@ impl Writer for VecWriter {
 
 /// Writer that only tracks the amount of data written - useful if you need to calculate the length
 /// of some data when serialized but don't yet need the full data.
+///
+/// (C-not exported) as manual TLV building is not currently supported in bindings
 pub struct LengthCalculatingWriter(pub usize);
 impl Writer for LengthCalculatingWriter {
        #[inline]
@@ -100,6 +102,8 @@ impl Writer for LengthCalculatingWriter {
 
 /// Essentially [`std::io::Take`] but a bit simpler and with a method to walk the underlying stream
 /// forward to ensure we always consume exactly the fixed length specified.
+///
+/// (C-not exported) as manual TLV building is not currently supported in bindings
 pub struct FixedLengthReader<R: Read> {
        read: R,
        bytes_read: u64,
@@ -155,6 +159,8 @@ impl<R: Read> LengthRead for FixedLengthReader<R> {
 
 /// A [`Read`] implementation which tracks whether any bytes have been read at all. This allows us to distinguish
 /// between "EOF reached before we started" and "EOF reached mid-read".
+///
+/// (C-not exported) as manual TLV building is not currently supported in bindings
 pub struct ReadTrackingReader<R: Read> {
        read: R,
        /// Returns whether we have read from this reader or not yet.
@@ -289,22 +295,32 @@ impl<T: Readable> MaybeReadable for T {
 }
 
 /// Wrapper to read a required (non-optional) TLV record.
-pub struct RequiredWrapper<T: Readable>(pub Option<T>);
+///
+/// (C-not exported) as manual TLV building is not currently supported in bindings
+pub struct RequiredWrapper<T>(pub Option<T>);
 impl<T: Readable> Readable for RequiredWrapper<T> {
        #[inline]
        fn read<R: Read>(reader: &mut R) -> Result<Self, DecodeError> {
                Ok(Self(Some(Readable::read(reader)?)))
        }
 }
+impl<A, T: ReadableArgs<A>> ReadableArgs<A> for RequiredWrapper<T> {
+       #[inline]
+       fn read<R: Read>(reader: &mut R, args: A) -> Result<Self, DecodeError> {
+               Ok(Self(Some(ReadableArgs::read(reader, args)?)))
+       }
+}
 /// When handling `default_values`, we want to map the default-value T directly
 /// to a `RequiredWrapper<T>` in a way that works for `field: T = t;` as
 /// well. Thus, we assume `Into<T> for T` does nothing and use that.
-impl<T: Readable> From<T> for RequiredWrapper<T> {
+impl<T> From<T> for RequiredWrapper<T> {
        fn from(t: T) -> RequiredWrapper<T> { RequiredWrapper(Some(t)) }
 }
 
 /// Wrapper to read a required (non-optional) TLV record that may have been upgraded without
 /// backwards compat.
+///
+/// (C-not exported) as manual TLV building is not currently supported in bindings
 pub struct UpgradableRequired<T: MaybeReadable>(pub Option<T>);
 impl<T: MaybeReadable> MaybeReadable for UpgradableRequired<T> {
        #[inline]
@@ -585,6 +601,8 @@ impl Readable for [u16; 8] {
 
 /// A type for variable-length values within TLV record where the length is encoded as part of the record.
 /// Used to prevent encoding the length twice.
+///
+/// (C-not exported) as manual TLV building is not currently supported in bindings
 pub struct WithoutLength<T>(pub T);
 
 impl Writeable for WithoutLength<&String> {
index 5e55ab1d3b8182baf80f87e0e1a929d90380057c..59402f6acc8e0eab43810eb896a678186e2e9efc 100644 (file)
@@ -51,6 +51,10 @@ macro_rules! _encode_tlv {
        ($stream: expr, $type: expr, $field: expr, (option, encoding: $fieldty: ty)) => {
                $crate::_encode_tlv!($stream, $type, $field, option);
        };
+       ($stream: expr, $type: expr, $field: expr, (option: $trait: ident $(, $read_arg: expr)?)) => {
+               // Just a read-mapped type
+               $crate::_encode_tlv!($stream, $type, $field, option);
+       };
 }
 
 /// Panics if the last seen TLV type is not numerically less than the TLV type currently being checked.
@@ -161,6 +165,9 @@ macro_rules! _get_varint_length_prefixed_tlv_length {
                        $len.0 += field_len;
                }
        };
+       ($len: expr, $type: expr, $field: expr, (option: $trait: ident $(, $read_arg: expr)?)) => {
+               $crate::_get_varint_length_prefixed_tlv_length!($len, $type, $field, option);
+       };
        ($len: expr, $type: expr, $field: expr, upgradable_required) => {
                $crate::_get_varint_length_prefixed_tlv_length!($len, $type, $field, required);
        };
@@ -210,6 +217,9 @@ macro_rules! _check_decoded_tlv_order {
                        return Err(DecodeError::InvalidValue);
                }
        }};
+       ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, (required: $trait: ident $(, $read_arg: expr)?)) => {{
+               $crate::_check_decoded_tlv_order!($last_seen_type, $typ, $type, $field, required);
+       }};
        ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, option) => {{
                // no-op
        }};
@@ -252,6 +262,9 @@ macro_rules! _check_missing_tlv {
                        return Err(DecodeError::InvalidValue);
                }
        }};
+       ($last_seen_type: expr, $type: expr, $field: ident, (required: $trait: ident $(, $read_arg: expr)?)) => {{
+               $crate::_check_missing_tlv!($last_seen_type, $type, $field, required);
+       }};
        ($last_seen_type: expr, $type: expr, $field: ident, vec_type) => {{
                // no-op
        }};
@@ -285,6 +298,9 @@ macro_rules! _decode_tlv {
        ($reader: expr, $field: ident, required) => {{
                $field = $crate::util::ser::Readable::read(&mut $reader)?;
        }};
+       ($reader: expr, $field: ident, (required: $trait: ident $(, $read_arg: expr)?)) => {{
+               $field = $trait::read(&mut $reader $(, $read_arg)*)?;
+       }};
        ($reader: expr, $field: ident, vec_type) => {{
                let f: $crate::util::ser::WithoutLength<Vec<_>> = $crate::util::ser::Readable::read(&mut $reader)?;
                $field = Some(f.0);
@@ -644,6 +660,9 @@ macro_rules! _init_tlv_based_struct_field {
        ($field: ident, option) => {
                $field
        };
+       ($field: ident, (option: $trait: ident $(, $read_arg: expr)?)) => {
+               $crate::_init_tlv_based_struct_field!($field, option)
+       };
        ($field: ident, upgradable_required) => {
                $field.0.unwrap()
        };
@@ -673,12 +692,18 @@ macro_rules! _init_tlv_field_var {
        ($field: ident, required) => {
                let mut $field = $crate::util::ser::RequiredWrapper(None);
        };
+       ($field: ident, (required: $trait: ident $(, $read_arg: expr)?)) => {
+               $crate::_init_tlv_field_var!($field, required);
+       };
        ($field: ident, vec_type) => {
                let mut $field = Some(Vec::new());
        };
        ($field: ident, option) => {
                let mut $field = None;
        };
+       ($field: ident, (option: $trait: ident $(, $read_arg: expr)?)) => {
+               $crate::_init_tlv_field_var!($field, option);
+       };
        ($field: ident, upgradable_required) => {
                let mut $field = $crate::util::ser::UpgradableRequired(None);
        };
index fdbc22f116600b7162bdbdcafa4a6f314af944e1..f86fc376cee0202323b9924057157aaf8379f86a 100644 (file)
@@ -105,7 +105,10 @@ impl Notifier {
        pub(crate) fn notify(&self) {
                let mut lock = self.notify_pending.lock().unwrap();
                if let Some(future_state) = &lock.1 {
-                       future_state.lock().unwrap().complete();
+                       if future_state.lock().unwrap().complete() {
+                               lock.1 = None;
+                               return;
+                       }
                }
                lock.0 = true;
                mem::drop(lock);
@@ -161,12 +164,13 @@ pub(crate) struct FutureState {
 }
 
 impl FutureState {
-       fn complete(&mut self) {
+       fn complete(&mut self) -> bool {
                for (counts_as_call, callback) in self.callbacks.drain(..) {
                        callback.call();
                        self.callbacks_made |= counts_as_call;
                }
                self.complete = true;
+               self.callbacks_made
        }
 }
 
@@ -469,4 +473,63 @@ mod tests {
                assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
                assert!(!notifier.wait_timeout(Duration::from_millis(1)));
        }
+
+       #[test]
+       fn test_poll_post_notify_completes() {
+               // Tests that if we have a future state that has completed, and we haven't yet requested a
+               // new future, if we get a notify prior to requesting that second future it is generated
+               // pre-completed.
+               let notifier = Notifier::new();
+
+               notifier.notify();
+               let mut future = notifier.get_future();
+               let (woken, waker) = create_waker();
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
+               assert!(!woken.load(Ordering::SeqCst));
+
+               notifier.notify();
+               let mut future = notifier.get_future();
+               let (woken, waker) = create_waker();
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
+               assert!(!woken.load(Ordering::SeqCst));
+
+               let mut future = notifier.get_future();
+               let (woken, waker) = create_waker();
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Pending);
+               assert!(!woken.load(Ordering::SeqCst));
+
+               notifier.notify();
+               assert!(woken.load(Ordering::SeqCst));
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
+       }
+
+       #[test]
+       fn test_poll_post_notify_completes_initial_notified() {
+               // Identical to the previous test, but the first future completes via a wake rather than an
+               // immediate `Poll::Ready`.
+               let notifier = Notifier::new();
+
+               let mut future = notifier.get_future();
+               let (woken, waker) = create_waker();
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Pending);
+
+               notifier.notify();
+               assert!(woken.load(Ordering::SeqCst));
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
+
+               notifier.notify();
+               let mut future = notifier.get_future();
+               let (woken, waker) = create_waker();
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
+               assert!(!woken.load(Ordering::SeqCst));
+
+               let mut future = notifier.get_future();
+               let (woken, waker) = create_waker();
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Pending);
+               assert!(!woken.load(Ordering::SeqCst));
+
+               notifier.notify();
+               assert!(woken.load(Ordering::SeqCst));
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
+       }
 }