Merge pull request #2072 from jkczyz/2023-01-fix-scoring-div-by-zero
authorMatt Corallo <649246+TheBlueMatt@users.noreply.github.com>
Sat, 4 Mar 2023 00:06:29 +0000 (00:06 +0000)
committerGitHub <noreply@github.com>
Sat, 4 Mar 2023 00:06:29 +0000 (00:06 +0000)
Fix division by zero in `ProbabilisticScorer`

28 files changed:
.github/workflows/build.yml
fuzz/src/full_stack.rs
fuzz/src/process_network_graph.rs
lightning-background-processor/src/lib.rs
lightning-block-sync/src/lib.rs
lightning-block-sync/src/rpc.rs
lightning-rapid-gossip-sync/src/lib.rs
lightning-rapid-gossip-sync/src/processing.rs
lightning-transaction-sync/Cargo.toml
lightning-transaction-sync/src/error.rs
lightning/src/chain/channelmonitor.rs
lightning/src/chain/keysinterface.rs
lightning/src/ln/channel.rs
lightning/src/ln/channelmanager.rs
lightning/src/ln/onion_route_tests.rs
lightning/src/ln/outbound_payment.rs
lightning/src/ln/payment_tests.rs
lightning/src/ln/peer_handler.rs
lightning/src/ln/priv_short_conf_tests.rs
lightning/src/sync/debug_sync.rs
lightning/src/sync/fairrwlock.rs
lightning/src/sync/nostd_sync.rs
lightning/src/sync/test_lockorder_checks.rs
lightning/src/util/indexed_map.rs
lightning/src/util/macro_logger.rs
lightning/src/util/ser.rs
lightning/src/util/wakers.rs
no-std-check/Cargo.toml

index 2bb21c2cdbb39b87c5d9f08f8d4fe73d9d4707aa..1c1472a1007cd214ff8248e186cbf5cfab1d8d89 100644 (file)
@@ -168,7 +168,7 @@ jobs:
           done
           # check no-std compatibility across dependencies
           cd no-std-check
-          cargo check --verbose --color always
+          cargo check --verbose --color always --features lightning-transaction-sync
       - name: Build no-std-check on Rust ${{ matrix.toolchain }} for ARM Embedded
         if: "matrix.build-no-std && matrix.platform == 'ubuntu-latest'"
         run: |
index b96de0b9782443f9c4fd72da7b861a2c6727a420..05ae32e4ca0e080348955b51cc7d22dc3e71e66e 100644 (file)
@@ -42,7 +42,7 @@ use lightning::ln::msgs::{self, DecodeError};
 use lightning::ln::script::ShutdownScript;
 use lightning::routing::gossip::{P2PGossipSync, NetworkGraph};
 use lightning::routing::utxo::UtxoLookup;
-use lightning::routing::router::{find_route, InFlightHtlcs, PaymentParameters, Route, RouteHop, RouteParameters, Router};
+use lightning::routing::router::{find_route, InFlightHtlcs, PaymentParameters, Route, RouteParameters, Router};
 use lightning::routing::scoring::FixedPenaltyScorer;
 use lightning::util::config::UserConfig;
 use lightning::util::errors::APIError;
index c900a7d38d5ac0529b0912e05717f9d0f0e4c693..b4c6a29e8a99f47744fdb9fd0caf2468c7604e57 100644 (file)
@@ -7,7 +7,7 @@ use crate::utils::test_logger;
 fn do_test<Out: test_logger::Output>(data: &[u8], out: Out) {
        let logger = test_logger::TestLogger::new("".to_owned(), out);
        let network_graph = lightning::routing::gossip::NetworkGraph::new(bitcoin::Network::Bitcoin, &logger);
-       let rapid_sync = RapidGossipSync::new(&network_graph);
+       let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
        let _ = rapid_sync.update_network_graph(data);
 }
 
index 8711a4aeb5898e393f173886ca8e20de4b9b0d6f..a6de0a62fb07583ae29741b87d209c0e66ee7c5d 100644 (file)
@@ -33,7 +33,9 @@ use lightning::routing::gossip::{NetworkGraph, P2PGossipSync};
 use lightning::routing::utxo::UtxoLookup;
 use lightning::routing::router::Router;
 use lightning::routing::scoring::{Score, WriteableScore};
-use lightning::util::events::{Event, EventHandler, EventsProvider, PathFailure};
+use lightning::util::events::{Event, PathFailure};
+#[cfg(feature = "std")]
+use lightning::util::events::{EventHandler, EventsProvider};
 use lightning::util::logger::Logger;
 use lightning::util::persist::Persister;
 use lightning_rapid_gossip_sync::RapidGossipSync;
@@ -953,7 +955,7 @@ mod tests {
                        let params = ChainParameters { network, best_block };
                        let manager = Arc::new(ChannelManager::new(fee_estimator.clone(), chain_monitor.clone(), tx_broadcaster.clone(), router.clone(), logger.clone(), keys_manager.clone(), keys_manager.clone(), keys_manager.clone(), UserConfig::default(), params));
                        let p2p_gossip_sync = Arc::new(P2PGossipSync::new(network_graph.clone(), Some(chain_source.clone()), logger.clone()));
-                       let rapid_gossip_sync = Arc::new(RapidGossipSync::new(network_graph.clone()));
+                       let rapid_gossip_sync = Arc::new(RapidGossipSync::new(network_graph.clone(), logger.clone()));
                        let msg_handler = MessageHandler { chan_handler: Arc::new(test_utils::TestChannelMessageHandler::new()), route_handler: Arc::new(test_utils::TestRoutingMessageHandler::new()), onion_message_handler: IgnoringMessageHandler{}};
                        let peer_manager = Arc::new(PeerManager::new(msg_handler, 0, &seed, logger.clone(), IgnoringMessageHandler{}, keys_manager.clone()));
                        let node = Node { node: manager, p2p_gossip_sync, rapid_gossip_sync, peer_manager, chain_monitor, persister, tx_broadcaster, network_graph, logger, best_block, scorer };
index 189a68be0654dab1453ef459d3046d5d0001c17d..0a7c655147ff0c85faf1297a2830d95b1f5746ce 100644 (file)
@@ -132,6 +132,9 @@ impl BlockSourceError {
        }
 
        /// Converts the error into the underlying error.
+       ///
+       /// May contain an [`std::io::Error`] from the [`BlockSource`]. See implementations for further
+       /// details, if any.
        pub fn into_inner(self) -> Box<dyn std::error::Error + Send + Sync> {
                self.error
        }
index f04769560246f8537e1e022efa22c8b7a815eab4..e1dc43c8f28d65511c45ce32591ef9eb5b933e8b 100644 (file)
@@ -13,9 +13,31 @@ use serde_json;
 
 use std::convert::TryFrom;
 use std::convert::TryInto;
+use std::error::Error;
+use std::fmt;
 use std::sync::atomic::{AtomicUsize, Ordering};
 
+/// An error returned by the RPC server.
+#[derive(Debug)]
+pub struct RpcError {
+       /// The error code.
+       pub code: i64,
+       /// The error message.
+       pub message: String,
+}
+
+impl fmt::Display for RpcError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "RPC error {}: {}", self.code, self.message)
+    }
+}
+
+impl Error for RpcError {}
+
 /// A simple RPC client for calling methods using HTTP `POST`.
+///
+/// Implements [`BlockSource`] and may return an `Err` containing [`RpcError`]. See
+/// [`RpcClient::call_method`] for details.
 pub struct RpcClient {
        basic_auth: String,
        endpoint: HttpEndpoint,
@@ -38,6 +60,9 @@ impl RpcClient {
        }
 
        /// Calls a method with the response encoded in JSON format and interpreted as type `T`.
+       ///
+       /// When an `Err` is returned, [`std::io::Error::into_inner`] may contain an [`RpcError`] if
+       /// [`std::io::Error::kind`] is [`std::io::ErrorKind::Other`].
        pub async fn call_method<T>(&self, method: &str, params: &[serde_json::Value]) -> std::io::Result<T>
        where JsonResponse: TryFrom<Vec<u8>, Error = std::io::Error> + TryInto<T, Error = std::io::Error> {
                let host = format!("{}:{}", self.endpoint.host(), self.endpoint.port());
@@ -69,8 +94,11 @@ impl RpcClient {
                let error = &response["error"];
                if !error.is_null() {
                        // TODO: Examine error code for a more precise std::io::ErrorKind.
-                       let message = error["message"].as_str().unwrap_or("unknown error");
-                       return Err(std::io::Error::new(std::io::ErrorKind::Other, message));
+                       let rpc_error = RpcError { 
+                               code: error["code"].as_i64().unwrap_or(-1), 
+                               message: error["message"].as_str().unwrap_or("unknown error").to_string() 
+                       };
+                       return Err(std::io::Error::new(std::io::ErrorKind::Other, rpc_error));
                }
 
                let result = &mut response["result"];
@@ -163,7 +191,9 @@ mod tests {
                match client.call_method::<u64>("getblock", &[invalid_block_hash]).await {
                        Err(e) => {
                                assert_eq!(e.kind(), std::io::ErrorKind::Other);
-                               assert_eq!(e.get_ref().unwrap().to_string(), "invalid parameter");
+                               let rpc_error: Box<RpcError> = e.into_inner().unwrap().downcast().unwrap();
+                               assert_eq!(rpc_error.code, -8);
+                               assert_eq!(rpc_error.message, "invalid parameter");
                        },
                        Ok(_) => panic!("Expected error"),
                }
index 3bceb2e28e9239c0a7e8b34790275d4f51f064e2..af235b1c4224d990f2a1d71881ab23e659d61e28 100644 (file)
@@ -54,7 +54,7 @@
 //! # let logger = FakeLogger {};
 //!
 //! let network_graph = NetworkGraph::new(Network::Bitcoin, &logger);
-//! let rapid_sync = RapidGossipSync::new(&network_graph);
+//! let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
 //! let snapshot_contents: &[u8] = &[0; 0];
 //! let new_last_sync_timestamp_result = rapid_sync.update_network_graph(snapshot_contents);
 //! ```
@@ -94,14 +94,16 @@ mod processing;
 pub struct RapidGossipSync<NG: Deref<Target=NetworkGraph<L>>, L: Deref>
 where L::Target: Logger {
        network_graph: NG,
+       logger: L,
        is_initial_sync_complete: AtomicBool
 }
 
 impl<NG: Deref<Target=NetworkGraph<L>>, L: Deref> RapidGossipSync<NG, L> where L::Target: Logger {
        /// Instantiate a new [`RapidGossipSync`] instance.
-       pub fn new(network_graph: NG) -> Self {
+       pub fn new(network_graph: NG, logger: L) -> Self {
                Self {
                        network_graph,
+                       logger,
                        is_initial_sync_complete: AtomicBool::new(false)
                }
        }
@@ -228,7 +230,7 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                let sync_result = rapid_sync.sync_network_graph_with_file_path(&graph_sync_test_file);
 
                if sync_result.is_err() {
@@ -260,7 +262,7 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                let start = std::time::Instant::now();
                let sync_result = rapid_sync
                        .sync_network_graph_with_file_path("./res/full_graph.lngossip");
@@ -299,7 +301,7 @@ pub mod bench {
                let logger = TestLogger::new();
                b.iter(|| {
                        let network_graph = NetworkGraph::new(Network::Bitcoin, &logger);
-                       let rapid_sync = RapidGossipSync::new(&network_graph);
+                       let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                        let sync_result = rapid_sync.sync_network_graph_with_file_path("./res/full_graph.lngossip");
                        if let Err(crate::error::GraphSyncError::DecodeError(DecodeError::Io(io_error))) = &sync_result {
                                let error_string = format!("Input file lightning-rapid-gossip-sync/res/full_graph.lngossip is missing! Download it from https://bitcoin.ninja/ldk-compressed_graph-bc08df7542-2022-05-05.bin\n\n{:?}", io_error);
index 4b6de04c6556a5ef302d4dced78f1b833f5e5380..8d36dfe38844f75d189e4c27896b3f3a4d3e4891 100644 (file)
@@ -10,6 +10,7 @@ use lightning::ln::msgs::{
 };
 use lightning::routing::gossip::NetworkGraph;
 use lightning::util::logger::Logger;
+use lightning::{log_warn, log_trace, log_given_level};
 use lightning::util::ser::{BigSize, Readable};
 use lightning::io;
 
@@ -41,7 +42,7 @@ impl<NG: Deref<Target=NetworkGraph<L>>, L: Deref> RapidGossipSync<NG, L> where L
                &self,
                read_cursor: &mut R,
        ) -> Result<u32, GraphSyncError> {
-               #[allow(unused_mut)]
+               #[allow(unused_mut, unused_assignments)]
                let mut current_time_unix = None;
                #[cfg(all(feature = "std", not(test)))]
                {
@@ -120,6 +121,7 @@ impl<NG: Deref<Target=NetworkGraph<L>>, L: Deref> RapidGossipSync<NG, L> where L
                                if let ErrorAction::IgnoreDuplicateGossip = lightning_error.action {
                                        // everything is fine, just a duplicate channel announcement
                                } else {
+                                       log_warn!(self.logger, "Failed to process channel announcement: {:?}", lightning_error);
                                        return Err(lightning_error.into());
                                }
                        }
@@ -169,24 +171,19 @@ impl<NG: Deref<Target=NetworkGraph<L>>, L: Deref> RapidGossipSync<NG, L> where L
                        if (channel_flags & 0b_1000_0000) != 0 {
                                // incremental update, field flags will indicate mutated values
                                let read_only_network_graph = network_graph.read_only();
-                               if let Some(channel) = read_only_network_graph
-                                       .channels()
-                                       .get(&short_channel_id) {
-
-                                       let directional_info = channel
-                                               .get_directional_info(channel_flags)
-                                               .ok_or(LightningError {
-                                                       err: "Couldn't find previous directional data for update".to_owned(),
-                                                       action: ErrorAction::IgnoreError,
-                                               })?;
-
+                               if let Some(directional_info) =
+                                       read_only_network_graph.channels().get(&short_channel_id)
+                                       .and_then(|channel| channel.get_directional_info(channel_flags))
+                               {
                                        synthetic_update.cltv_expiry_delta = directional_info.cltv_expiry_delta;
                                        synthetic_update.htlc_minimum_msat = directional_info.htlc_minimum_msat;
                                        synthetic_update.htlc_maximum_msat = directional_info.htlc_maximum_msat;
                                        synthetic_update.fee_base_msat = directional_info.fees.base_msat;
                                        synthetic_update.fee_proportional_millionths = directional_info.fees.proportional_millionths;
-
                                } else {
+                                       log_trace!(self.logger,
+                                               "Skipping application of channel update for chan {} with flags {} as original data is missing.",
+                                               short_channel_id, channel_flags);
                                        skip_update_for_unknown_channel = true;
                                }
                        };
@@ -223,7 +220,9 @@ impl<NG: Deref<Target=NetworkGraph<L>>, L: Deref> RapidGossipSync<NG, L> where L
                        match network_graph.update_channel_unsigned(&synthetic_update) {
                                Ok(_) => {},
                                Err(LightningError { action: ErrorAction::IgnoreDuplicateGossip, .. }) => {},
-                               Err(LightningError { action: ErrorAction::IgnoreAndLog(_), .. }) => {},
+                               Err(LightningError { action: ErrorAction::IgnoreAndLog(level), err }) => {
+                                       log_given_level!(self.logger, level, "Failed to apply channel update: {:?}", err);
+                               },
                                Err(LightningError { action: ErrorAction::IgnoreError, .. }) => {},
                                Err(e) => return Err(e.into()),
                        }
@@ -287,7 +286,7 @@ mod tests {
                        0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 58, 85, 116, 216, 255, 2, 68, 226, 0, 6, 11, 0, 1, 24, 0,
                        0, 3, 232, 0, 0, 0,
                ];
-               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                let update_result = rapid_sync.update_network_graph(&example_input[..]);
                assert!(update_result.is_err());
                if let Err(GraphSyncError::DecodeError(DecodeError::ShortRead)) = update_result {
@@ -312,7 +311,7 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                let update_result = rapid_sync.update_network_graph(&incremental_update_input[..]);
                assert!(update_result.is_ok());
        }
@@ -340,17 +339,8 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let rapid_sync = RapidGossipSync::new(&network_graph);
-               let update_result = rapid_sync.update_network_graph(&announced_update_input[..]);
-               assert!(update_result.is_err());
-               if let Err(GraphSyncError::LightningError(lightning_error)) = update_result {
-                       assert_eq!(
-                               lightning_error.err,
-                               "Couldn't find previous directional data for update"
-                       );
-               } else {
-                       panic!("Unexpected update result: {:?}", update_result)
-               }
+               let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
+               rapid_sync.update_network_graph(&announced_update_input[..]).unwrap();
        }
 
        #[test]
@@ -376,7 +366,7 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                let initialization_result = rapid_sync.update_network_graph(&initialization_input[..]);
                if initialization_result.is_err() {
                        panic!(
@@ -405,16 +395,7 @@ mod tests {
                        0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 8, 153, 192, 0, 2, 27, 0, 0, 136, 0, 0, 0, 221, 255, 2,
                        68, 226, 0, 6, 11, 0, 1, 128,
                ];
-               let update_result = rapid_sync.update_network_graph(&opposite_direction_incremental_update_input[..]);
-               assert!(update_result.is_err());
-               if let Err(GraphSyncError::LightningError(lightning_error)) = update_result {
-                       assert_eq!(
-                               lightning_error.err,
-                               "Couldn't find previous directional data for update"
-                       );
-               } else {
-                       panic!("Unexpected update result: {:?}", update_result)
-               }
+               rapid_sync.update_network_graph(&opposite_direction_incremental_update_input[..]).unwrap();
        }
 
        #[test]
@@ -442,7 +423,7 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                let initialization_result = rapid_sync.update_network_graph(&initialization_input[..]);
                assert!(initialization_result.is_ok());
 
@@ -501,7 +482,7 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                let initialization_result = rapid_sync.update_network_graph(&initialization_input[..]);
                assert!(initialization_result.is_ok());
 
@@ -526,7 +507,7 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                let update_result = rapid_sync.update_network_graph(&VALID_RGS_BINARY);
                if update_result.is_err() {
                        panic!("Unexpected update result: {:?}", update_result)
@@ -557,7 +538,7 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                // this is mostly for checking uint underflow issues before the fuzzer does
                let update_result = rapid_sync.update_network_graph_no_std(&VALID_RGS_BINARY, Some(0));
                assert!(update_result.is_ok());
@@ -576,7 +557,7 @@ mod tests {
                        let network_graph = NetworkGraph::new(Network::Bitcoin, &logger);
                        assert_eq!(network_graph.read_only().channels().len(), 0);
 
-                       let rapid_sync = RapidGossipSync::new(&network_graph);
+                       let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                        let update_result = rapid_sync.update_network_graph_no_std(&VALID_RGS_BINARY, Some(latest_succeeding_time));
                        assert!(update_result.is_ok());
                        assert_eq!(network_graph.read_only().channels().len(), 2);
@@ -586,7 +567,7 @@ mod tests {
                        let network_graph = NetworkGraph::new(Network::Bitcoin, &logger);
                        assert_eq!(network_graph.read_only().channels().len(), 0);
 
-                       let rapid_sync = RapidGossipSync::new(&network_graph);
+                       let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                        let update_result = rapid_sync.update_network_graph_no_std(&VALID_RGS_BINARY, Some(earliest_failing_time));
                        assert!(update_result.is_err());
                        if let Err(GraphSyncError::LightningError(lightning_error)) = update_result {
@@ -622,7 +603,7 @@ mod tests {
 
                let logger = TestLogger::new();
                let network_graph = NetworkGraph::new(Network::Bitcoin, &logger);
-               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let rapid_sync = RapidGossipSync::new(&network_graph, &logger);
                let update_result = rapid_sync.update_network_graph(&unknown_version_input[..]);
 
                assert!(update_result.is_err());
index ae29753ecc4381663dad9599353dc888afabf449..9dfce6c2ad18a8c50bf9eea4c7fb0748fa1c2da6 100644 (file)
@@ -20,13 +20,14 @@ esplora-blocking = ["esplora-client/blocking"]
 async-interface = []
 
 [dependencies]
-lightning = { version = "0.0.113", path = "../lightning" }
-bitcoin = "0.29.0"
+lightning = { version = "0.0.113", path = "../lightning", default-features = false }
+bitcoin = { version = "0.29.0", default-features = false }
 bdk-macros = "0.6"
 futures = { version = "0.3", optional = true }
 esplora-client = { version = "0.3.0", default-features = false, optional = true }
 
 [dev-dependencies]
+lightning = { version = "0.0.113", path = "../lightning", features = ["std"] }
 electrsd = { version = "0.22.0", features = ["legacy", "esplora_a33e97e1", "bitcoind_23_0"] }
 electrum-client = "0.12.0"
 once_cell = "1.16.0"
index 0a529d063ec81187336df0ba6b58055093ecca7d..73d9de70169d6ea559121c2aaf9fc5656b7d1f5c 100644 (file)
@@ -22,7 +22,7 @@ impl fmt::Display for TxSyncError {
 pub(crate) enum InternalError {
        /// A transaction sync failed and needs to be retried eventually.
        Failed,
-       /// An inconsisteny was encounterd during transaction sync.
+       /// An inconsistency was encountered during transaction sync.
        Inconsistency,
 }
 
@@ -32,7 +32,7 @@ impl fmt::Display for InternalError {
                match *self {
                        Self::Failed => write!(f, "Failed to conduct transaction sync."),
                        Self::Inconsistency => {
-                               write!(f, "Encountered an inconsisteny during transaction sync.")
+                               write!(f, "Encountered an inconsistency during transaction sync.")
                        }
                }
        }
index a664c7c794efe019dd82e909c132c1fd839b9ccf..cd18ffad12d385747da26ee6e7222f9dc457b5f3 100644 (file)
@@ -37,7 +37,7 @@ use crate::ln::{PaymentHash, PaymentPreimage};
 use crate::ln::msgs::DecodeError;
 use crate::ln::chan_utils;
 use crate::ln::chan_utils::{CounterpartyCommitmentSecrets, HTLCOutputInCommitment, HTLCClaim, ChannelTransactionParameters, HolderCommitmentTransaction};
-use crate::ln::channelmanager::HTLCSource;
+use crate::ln::channelmanager::{HTLCSource, SentHTLCId};
 use crate::chain;
 use crate::chain::{BestBlock, WatchedOutput};
 use crate::chain::chaininterface::{BroadcasterInterface, FeeEstimator, LowerBoundedFeeEstimator};
@@ -494,6 +494,7 @@ pub(crate) enum ChannelMonitorUpdateStep {
        LatestHolderCommitmentTXInfo {
                commitment_tx: HolderCommitmentTransaction,
                htlc_outputs: Vec<(HTLCOutputInCommitment, Option<Signature>, Option<HTLCSource>)>,
+               claimed_htlcs: Vec<(SentHTLCId, PaymentPreimage)>,
        },
        LatestCounterpartyCommitmentTXInfo {
                commitment_txid: Txid,
@@ -536,6 +537,7 @@ impl ChannelMonitorUpdateStep {
 impl_writeable_tlv_based_enum_upgradable!(ChannelMonitorUpdateStep,
        (0, LatestHolderCommitmentTXInfo) => {
                (0, commitment_tx, required),
+               (1, claimed_htlcs, vec_type),
                (2, htlc_outputs, vec_type),
        },
        (1, LatestCounterpartyCommitmentTXInfo) => {
@@ -750,6 +752,8 @@ pub(crate) struct ChannelMonitorImpl<Signer: WriteableEcdsaChannelSigner> {
        /// Serialized to disk but should generally not be sent to Watchtowers.
        counterparty_hash_commitment_number: HashMap<PaymentHash, u64>,
 
+       counterparty_fulfilled_htlcs: HashMap<SentHTLCId, PaymentPreimage>,
+
        // We store two holder commitment transactions to avoid any race conditions where we may update
        // some monitors (potentially on watchtowers) but then fail to update others, resulting in the
        // various monitors for one channel being out of sync, and us broadcasting a holder
@@ -1033,6 +1037,7 @@ impl<Signer: WriteableEcdsaChannelSigner> Writeable for ChannelMonitorImpl<Signe
                        (9, self.counterparty_node_id, option),
                        (11, self.confirmed_commitment_tx_counterparty_output, option),
                        (13, self.spendable_txids_confirmed, vec_type),
+                       (15, self.counterparty_fulfilled_htlcs, required),
                });
 
                Ok(())
@@ -1120,6 +1125,7 @@ impl<Signer: WriteableEcdsaChannelSigner> ChannelMonitor<Signer> {
                        counterparty_claimable_outpoints: HashMap::new(),
                        counterparty_commitment_txn_on_chain: HashMap::new(),
                        counterparty_hash_commitment_number: HashMap::new(),
+                       counterparty_fulfilled_htlcs: HashMap::new(),
 
                        prev_holder_signed_commitment_tx: None,
                        current_holder_commitment_tx: holder_commitment_tx,
@@ -1174,7 +1180,7 @@ impl<Signer: WriteableEcdsaChannelSigner> ChannelMonitor<Signer> {
                &self, holder_commitment_tx: HolderCommitmentTransaction,
                htlc_outputs: Vec<(HTLCOutputInCommitment, Option<Signature>, Option<HTLCSource>)>,
        ) -> Result<(), ()> {
-               self.inner.lock().unwrap().provide_latest_holder_commitment_tx(holder_commitment_tx, htlc_outputs).map_err(|_| ())
+               self.inner.lock().unwrap().provide_latest_holder_commitment_tx(holder_commitment_tx, htlc_outputs, &Vec::new()).map_err(|_| ())
        }
 
        /// This is used to provide payment preimage(s) out-of-band during startup without updating the
@@ -1810,9 +1816,10 @@ impl<Signer: WriteableEcdsaChannelSigner> ChannelMonitor<Signer> {
        /// `ChannelMonitor`. This is used to determine if an HTLC was removed from the channel prior
        /// to the `ChannelManager` having been persisted.
        ///
-       /// This is similar to [`Self::get_pending_outbound_htlcs`] except it includes HTLCs which were
-       /// resolved by this `ChannelMonitor`.
-       pub(crate) fn get_all_current_outbound_htlcs(&self) -> HashMap<HTLCSource, HTLCOutputInCommitment> {
+       /// This is similar to [`Self::get_pending_or_resolved_outbound_htlcs`] except it includes
+       /// HTLCs which were resolved on-chain (i.e. where the final HTLC resolution was done by an
+       /// event from this `ChannelMonitor`).
+       pub(crate) fn get_all_current_outbound_htlcs(&self) -> HashMap<HTLCSource, (HTLCOutputInCommitment, Option<PaymentPreimage>)> {
                let mut res = HashMap::new();
                // Just examine the available counterparty commitment transactions. See docs on
                // `fail_unbroadcast_htlcs`, below, for justification.
@@ -1822,7 +1829,8 @@ impl<Signer: WriteableEcdsaChannelSigner> ChannelMonitor<Signer> {
                                if let Some(ref latest_outpoints) = us.counterparty_claimable_outpoints.get($txid) {
                                        for &(ref htlc, ref source_option) in latest_outpoints.iter() {
                                                if let &Some(ref source) = source_option {
-                                                       res.insert((**source).clone(), htlc.clone());
+                                                       res.insert((**source).clone(), (htlc.clone(),
+                                                               us.counterparty_fulfilled_htlcs.get(&SentHTLCId::from_source(source)).cloned()));
                                                }
                                        }
                                }
@@ -1837,9 +1845,14 @@ impl<Signer: WriteableEcdsaChannelSigner> ChannelMonitor<Signer> {
                res
        }
 
-       /// Gets the set of outbound HTLCs which are pending resolution in this channel.
+       /// Gets the set of outbound HTLCs which are pending resolution in this channel or which were
+       /// resolved with a preimage from our counterparty.
+       ///
        /// This is used to reconstruct pending outbound payments on restart in the ChannelManager.
-       pub(crate) fn get_pending_outbound_htlcs(&self) -> HashMap<HTLCSource, HTLCOutputInCommitment> {
+       ///
+       /// Currently, the preimage is unused, however if it is present in the relevant internal state
+       /// an HTLC is always included even if it has been resolved.
+       pub(crate) fn get_pending_or_resolved_outbound_htlcs(&self) -> HashMap<HTLCSource, (HTLCOutputInCommitment, Option<PaymentPreimage>)> {
                let us = self.inner.lock().unwrap();
                // We're only concerned with the confirmation count of HTLC transactions, and don't
                // actually care how many confirmations a commitment transaction may or may not have. Thus,
@@ -1887,8 +1900,10 @@ impl<Signer: WriteableEcdsaChannelSigner> ChannelMonitor<Signer> {
                                                                Some(commitment_tx_output_idx) == htlc.transaction_output_index
                                                        } else { false }
                                                });
-                                               if !htlc_update_confd {
-                                                       res.insert(source.clone(), htlc.clone());
+                                               let counterparty_resolved_preimage_opt =
+                                                       us.counterparty_fulfilled_htlcs.get(&SentHTLCId::from_source(source)).cloned();
+                                               if !htlc_update_confd || counterparty_resolved_preimage_opt.is_some() {
+                                                       res.insert(source.clone(), (htlc.clone(), counterparty_resolved_preimage_opt));
                                                }
                                        }
                                }
@@ -1970,6 +1985,9 @@ macro_rules! fail_unbroadcast_htlcs {
                                                                }
                                                        }
                                                        if matched_htlc { continue; }
+                                                       if $self.counterparty_fulfilled_htlcs.get(&SentHTLCId::from_source(source)).is_some() {
+                                                               continue;
+                                                       }
                                                        $self.onchain_events_awaiting_threshold_conf.retain(|ref entry| {
                                                                if entry.height != $commitment_tx_conf_height { return true; }
                                                                match entry.event {
@@ -2041,8 +2059,23 @@ impl<Signer: WriteableEcdsaChannelSigner> ChannelMonitorImpl<Signer> {
                // Prune HTLCs from the previous counterparty commitment tx so we don't generate failure/fulfill
                // events for now-revoked/fulfilled HTLCs.
                if let Some(txid) = self.prev_counterparty_commitment_txid.take() {
-                       for &mut (_, ref mut source) in self.counterparty_claimable_outpoints.get_mut(&txid).unwrap() {
-                               *source = None;
+                       if self.current_counterparty_commitment_txid.unwrap() != txid {
+                               let cur_claimables = self.counterparty_claimable_outpoints.get(
+                                       &self.current_counterparty_commitment_txid.unwrap()).unwrap();
+                               for (_, ref source_opt) in self.counterparty_claimable_outpoints.get(&txid).unwrap() {
+                                       if let Some(source) = source_opt {
+                                               if !cur_claimables.iter()
+                                                       .any(|(_, cur_source_opt)| cur_source_opt == source_opt)
+                                               {
+                                                       self.counterparty_fulfilled_htlcs.remove(&SentHTLCId::from_source(source));
+                                               }
+                                       }
+                               }
+                               for &mut (_, ref mut source_opt) in self.counterparty_claimable_outpoints.get_mut(&txid).unwrap() {
+                                       *source_opt = None;
+                               }
+                       } else {
+                               assert!(cfg!(fuzzing), "Commitment txids are unique outside of fuzzing, where hashes can collide");
                        }
                }
 
@@ -2127,28 +2160,37 @@ impl<Signer: WriteableEcdsaChannelSigner> ChannelMonitorImpl<Signer> {
        /// is important that any clones of this channel monitor (including remote clones) by kept
        /// up-to-date as our holder commitment transaction is updated.
        /// Panics if set_on_holder_tx_csv has never been called.
-       fn provide_latest_holder_commitment_tx(&mut self, holder_commitment_tx: HolderCommitmentTransaction, htlc_outputs: Vec<(HTLCOutputInCommitment, Option<Signature>, Option<HTLCSource>)>) -> Result<(), &'static str> {
-               // block for Rust 1.34 compat
-               let mut new_holder_commitment_tx = {
-                       let trusted_tx = holder_commitment_tx.trust();
-                       let txid = trusted_tx.txid();
-                       let tx_keys = trusted_tx.keys();
-                       self.current_holder_commitment_number = trusted_tx.commitment_number();
-                       HolderSignedTx {
-                               txid,
-                               revocation_key: tx_keys.revocation_key,
-                               a_htlc_key: tx_keys.broadcaster_htlc_key,
-                               b_htlc_key: tx_keys.countersignatory_htlc_key,
-                               delayed_payment_key: tx_keys.broadcaster_delayed_payment_key,
-                               per_commitment_point: tx_keys.per_commitment_point,
-                               htlc_outputs,
-                               to_self_value_sat: holder_commitment_tx.to_broadcaster_value_sat(),
-                               feerate_per_kw: trusted_tx.feerate_per_kw(),
-                       }
+       fn provide_latest_holder_commitment_tx(&mut self, holder_commitment_tx: HolderCommitmentTransaction, htlc_outputs: Vec<(HTLCOutputInCommitment, Option<Signature>, Option<HTLCSource>)>, claimed_htlcs: &[(SentHTLCId, PaymentPreimage)]) -> Result<(), &'static str> {
+               let trusted_tx = holder_commitment_tx.trust();
+               let txid = trusted_tx.txid();
+               let tx_keys = trusted_tx.keys();
+               self.current_holder_commitment_number = trusted_tx.commitment_number();
+               let mut new_holder_commitment_tx = HolderSignedTx {
+                       txid,
+                       revocation_key: tx_keys.revocation_key,
+                       a_htlc_key: tx_keys.broadcaster_htlc_key,
+                       b_htlc_key: tx_keys.countersignatory_htlc_key,
+                       delayed_payment_key: tx_keys.broadcaster_delayed_payment_key,
+                       per_commitment_point: tx_keys.per_commitment_point,
+                       htlc_outputs,
+                       to_self_value_sat: holder_commitment_tx.to_broadcaster_value_sat(),
+                       feerate_per_kw: trusted_tx.feerate_per_kw(),
                };
                self.onchain_tx_handler.provide_latest_holder_tx(holder_commitment_tx);
                mem::swap(&mut new_holder_commitment_tx, &mut self.current_holder_commitment_tx);
                self.prev_holder_signed_commitment_tx = Some(new_holder_commitment_tx);
+               for (claimed_htlc_id, claimed_preimage) in claimed_htlcs {
+                       #[cfg(debug_assertions)] {
+                               let cur_counterparty_htlcs = self.counterparty_claimable_outpoints.get(
+                                               &self.current_counterparty_commitment_txid.unwrap()).unwrap();
+                               assert!(cur_counterparty_htlcs.iter().any(|(_, source_opt)| {
+                                       if let Some(source) = source_opt {
+                                               SentHTLCId::from_source(source) == *claimed_htlc_id
+                                       } else { false }
+                               }));
+                       }
+                       self.counterparty_fulfilled_htlcs.insert(*claimed_htlc_id, *claimed_preimage);
+               }
                if self.holder_tx_signed {
                        return Err("Latest holder commitment signed has already been signed, update is rejected");
                }
@@ -2243,10 +2285,10 @@ impl<Signer: WriteableEcdsaChannelSigner> ChannelMonitorImpl<Signer> {
                let bounded_fee_estimator = LowerBoundedFeeEstimator::new(&*fee_estimator);
                for update in updates.updates.iter() {
                        match update {
-                               ChannelMonitorUpdateStep::LatestHolderCommitmentTXInfo { commitment_tx, htlc_outputs } => {
+                               ChannelMonitorUpdateStep::LatestHolderCommitmentTXInfo { commitment_tx, htlc_outputs, claimed_htlcs } => {
                                        log_trace!(logger, "Updating ChannelMonitor with latest holder commitment transaction info");
                                        if self.lockdown_from_offchain { panic!(); }
-                                       if let Err(e) = self.provide_latest_holder_commitment_tx(commitment_tx.clone(), htlc_outputs.clone()) {
+                                       if let Err(e) = self.provide_latest_holder_commitment_tx(commitment_tx.clone(), htlc_outputs.clone(), &claimed_htlcs) {
                                                log_error!(logger, "Providing latest holder commitment transaction failed/was refused:");
                                                log_error!(logger, "    {}", e);
                                                ret = Err(());
@@ -3868,6 +3910,7 @@ impl<'a, 'b, ES: EntropySource, SP: SignerProvider> ReadableArgs<(&'a ES, &'b SP
                let mut counterparty_node_id = None;
                let mut confirmed_commitment_tx_counterparty_output = None;
                let mut spendable_txids_confirmed = Some(Vec::new());
+               let mut counterparty_fulfilled_htlcs = Some(HashMap::new());
                read_tlv_fields!(reader, {
                        (1, funding_spend_confirmed, option),
                        (3, htlcs_resolved_on_chain, vec_type),
@@ -3876,6 +3919,7 @@ impl<'a, 'b, ES: EntropySource, SP: SignerProvider> ReadableArgs<(&'a ES, &'b SP
                        (9, counterparty_node_id, option),
                        (11, confirmed_commitment_tx_counterparty_output, option),
                        (13, spendable_txids_confirmed, vec_type),
+                       (15, counterparty_fulfilled_htlcs, option),
                });
 
                Ok((best_block.block_hash(), ChannelMonitor::from_impl(ChannelMonitorImpl {
@@ -3904,6 +3948,7 @@ impl<'a, 'b, ES: EntropySource, SP: SignerProvider> ReadableArgs<(&'a ES, &'b SP
                        counterparty_claimable_outpoints,
                        counterparty_commitment_txn_on_chain,
                        counterparty_hash_commitment_number,
+                       counterparty_fulfilled_htlcs: counterparty_fulfilled_htlcs.unwrap(),
 
                        prev_holder_signed_commitment_tx,
                        current_holder_commitment_tx,
@@ -4077,7 +4122,6 @@ mod tests {
                let fee_estimator = TestFeeEstimator { sat_per_kw: Mutex::new(253) };
 
                let dummy_key = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&[42; 32]).unwrap());
-               let dummy_tx = Transaction { version: 0, lock_time: PackedLockTime::ZERO, input: Vec::new(), output: Vec::new() };
 
                let mut preimages = Vec::new();
                {
@@ -4167,11 +4211,10 @@ mod tests {
                                                  HolderCommitmentTransaction::dummy(), best_block, dummy_key);
 
                monitor.provide_latest_holder_commitment_tx(HolderCommitmentTransaction::dummy(), preimages_to_holder_htlcs!(preimages[0..10])).unwrap();
-               let dummy_txid = dummy_tx.txid();
-               monitor.provide_latest_counterparty_commitment_tx(dummy_txid, preimages_slice_to_htlc_outputs!(preimages[5..15]), 281474976710655, dummy_key, &logger);
-               monitor.provide_latest_counterparty_commitment_tx(dummy_txid, preimages_slice_to_htlc_outputs!(preimages[15..20]), 281474976710654, dummy_key, &logger);
-               monitor.provide_latest_counterparty_commitment_tx(dummy_txid, preimages_slice_to_htlc_outputs!(preimages[17..20]), 281474976710653, dummy_key, &logger);
-               monitor.provide_latest_counterparty_commitment_tx(dummy_txid, preimages_slice_to_htlc_outputs!(preimages[18..20]), 281474976710652, dummy_key, &logger);
+               monitor.provide_latest_counterparty_commitment_tx(Txid::from_inner(Sha256::hash(b"1").into_inner()),
+                       preimages_slice_to_htlc_outputs!(preimages[5..15]), 281474976710655, dummy_key, &logger);
+               monitor.provide_latest_counterparty_commitment_tx(Txid::from_inner(Sha256::hash(b"2").into_inner()),
+                       preimages_slice_to_htlc_outputs!(preimages[15..20]), 281474976710654, dummy_key, &logger);
                for &(ref preimage, ref hash) in preimages.iter() {
                        let bounded_fee_estimator = LowerBoundedFeeEstimator::new(&fee_estimator);
                        monitor.provide_payment_preimage(hash, preimage, &broadcaster, &bounded_fee_estimator, &logger);
@@ -4185,6 +4228,9 @@ mod tests {
                test_preimages_exist!(&preimages[0..10], monitor);
                test_preimages_exist!(&preimages[15..20], monitor);
 
+               monitor.provide_latest_counterparty_commitment_tx(Txid::from_inner(Sha256::hash(b"3").into_inner()),
+                       preimages_slice_to_htlc_outputs!(preimages[17..20]), 281474976710653, dummy_key, &logger);
+
                // Now provide a further secret, pruning preimages 15-17
                secret[0..32].clone_from_slice(&hex::decode("c7518c8ae4660ed02894df8976fa1a3659c1a8b4b5bec0c4b872abeba4cb8964").unwrap());
                monitor.provide_secret(281474976710654, secret.clone()).unwrap();
@@ -4192,6 +4238,9 @@ mod tests {
                test_preimages_exist!(&preimages[0..10], monitor);
                test_preimages_exist!(&preimages[17..20], monitor);
 
+               monitor.provide_latest_counterparty_commitment_tx(Txid::from_inner(Sha256::hash(b"4").into_inner()),
+                       preimages_slice_to_htlc_outputs!(preimages[18..20]), 281474976710652, dummy_key, &logger);
+
                // Now update holder commitment tx info, pruning only element 18 as we still care about the
                // previous commitment tx's preimages too
                monitor.provide_latest_holder_commitment_tx(HolderCommitmentTransaction::dummy(), preimages_to_holder_htlcs!(preimages[0..5])).unwrap();
index a2611e7df87f79fefad954aa3b9133addaeb650e..21331fff435e601a255a6944c04a8fff820db0bc 100644 (file)
@@ -1064,6 +1064,12 @@ impl KeysManager {
                        Err(_) => panic!("Your rng is busted"),
                }
        }
+
+       /// Gets the "node_id" secret key used to sign gossip announcements, decode onion data, etc.
+       pub fn get_node_secret_key(&self) -> SecretKey {
+               self.node_secret
+       }
+
        /// Derive an old [`WriteableEcdsaChannelSigner`] containing per-channel secrets based on a key derivation parameters.
        pub fn derive_channel_keys(&self, channel_value_satoshis: u64, params: &[u8; 32]) -> InMemorySigner {
                let chan_id = u64::from_be_bytes(params[0..8].try_into().unwrap());
@@ -1458,6 +1464,17 @@ impl PhantomKeysManager {
        pub fn derive_channel_keys(&self, channel_value_satoshis: u64, params: &[u8; 32]) -> InMemorySigner {
                self.inner.derive_channel_keys(channel_value_satoshis, params)
        }
+
+       /// Gets the "node_id" secret key used to sign gossip announcements, decode onion data, etc.
+       pub fn get_node_secret_key(&self) -> SecretKey {
+               self.inner.get_node_secret_key()
+       }
+
+       /// Gets the "node_id" secret key of the phantom node used to sign invoices, decode the
+       /// last-hop onion data, etc.
+       pub fn get_phantom_node_secret_key(&self) -> SecretKey {
+               self.phantom_secret
+       }
 }
 
 // Ensure that EcdsaChannelSigner can have a vtable
index 2b9920aa2bae50879764b632f6f071588a14f89d..9af0f6379e5a1672e851a61ef0c54d717235a1a5 100644 (file)
@@ -27,7 +27,7 @@ use crate::ln::features::{ChannelTypeFeatures, InitFeatures};
 use crate::ln::msgs;
 use crate::ln::msgs::{DecodeError, OptionalField, DataLossProtect};
 use crate::ln::script::{self, ShutdownScript};
-use crate::ln::channelmanager::{self, CounterpartyForwardingInfo, PendingHTLCStatus, HTLCSource, HTLCFailureMsg, PendingHTLCInfo, RAACommitmentOrder, BREAKDOWN_TIMEOUT, MIN_CLTV_EXPIRY_DELTA, MAX_LOCAL_BREAKDOWN_TIMEOUT};
+use crate::ln::channelmanager::{self, CounterpartyForwardingInfo, PendingHTLCStatus, HTLCSource, SentHTLCId, HTLCFailureMsg, PendingHTLCInfo, RAACommitmentOrder, BREAKDOWN_TIMEOUT, MIN_CLTV_EXPIRY_DELTA, MAX_LOCAL_BREAKDOWN_TIMEOUT};
 use crate::ln::chan_utils::{CounterpartyCommitmentSecrets, TxCreationKeys, HTLCOutputInCommitment, htlc_success_tx_weight, htlc_timeout_tx_weight, make_funding_redeemscript, ChannelPublicKeys, CommitmentTransaction, HolderCommitmentTransaction, ChannelTransactionParameters, CounterpartyChannelTransactionParameters, MAX_HTLCS, get_commitment_transaction_number_obscure_factor, ClosingTransaction};
 use crate::ln::chan_utils;
 use crate::ln::onion_utils::HTLCFailReason;
@@ -192,6 +192,7 @@ enum OutboundHTLCState {
 
 #[derive(Clone)]
 enum OutboundHTLCOutcome {
+       /// LDK version 0.0.105+ will always fill in the preimage here.
        Success(Option<PaymentPreimage>),
        Failure(HTLCFailReason),
 }
@@ -2483,6 +2484,11 @@ impl<Signer: WriteableEcdsaChannelSigner> Channel<Signer> {
                                        // If they haven't ever sent an updated point, the point they send should match
                                        // the current one.
                                        self.counterparty_cur_commitment_point
+                               } else if self.cur_counterparty_commitment_transaction_number == INITIAL_COMMITMENT_NUMBER - 2 {
+                                       // If we've advanced the commitment number once, the second commitment point is
+                                       // at `counterparty_prev_commitment_point`, which is not yet revoked.
+                                       debug_assert!(self.counterparty_prev_commitment_point.is_some());
+                                       self.counterparty_prev_commitment_point
                                } else {
                                        // If they have sent updated points, channel_ready is always supposed to match
                                        // their "first" point, which we re-derive here.
@@ -3159,15 +3165,6 @@ impl<Signer: WriteableEcdsaChannelSigner> Channel<Signer> {
                        }
                }
 
-               self.latest_monitor_update_id += 1;
-               let mut monitor_update = ChannelMonitorUpdate {
-                       update_id: self.latest_monitor_update_id,
-                       updates: vec![ChannelMonitorUpdateStep::LatestHolderCommitmentTXInfo {
-                               commitment_tx: holder_commitment_tx,
-                               htlc_outputs: htlcs_and_sigs
-                       }]
-               };
-
                for htlc in self.pending_inbound_htlcs.iter_mut() {
                        let new_forward = if let &InboundHTLCState::RemoteAnnounced(ref forward_info) = &htlc.state {
                                Some(forward_info.clone())
@@ -3179,6 +3176,7 @@ impl<Signer: WriteableEcdsaChannelSigner> Channel<Signer> {
                                need_commitment = true;
                        }
                }
+               let mut claimed_htlcs = Vec::new();
                for htlc in self.pending_outbound_htlcs.iter_mut() {
                        if let &mut OutboundHTLCState::RemoteRemoved(ref mut outcome) = &mut htlc.state {
                                log_trace!(logger, "Updating HTLC {} to AwaitingRemoteRevokeToRemove due to commitment_signed in channel {}.",
@@ -3186,11 +3184,30 @@ impl<Signer: WriteableEcdsaChannelSigner> Channel<Signer> {
                                // Grab the preimage, if it exists, instead of cloning
                                let mut reason = OutboundHTLCOutcome::Success(None);
                                mem::swap(outcome, &mut reason);
+                               if let OutboundHTLCOutcome::Success(Some(preimage)) = reason {
+                                       // If a user (a) receives an HTLC claim using LDK 0.0.104 or before, then (b)
+                                       // upgrades to LDK 0.0.114 or later before the HTLC is fully resolved, we could
+                                       // have a `Success(None)` reason. In this case we could forget some HTLC
+                                       // claims, but such an upgrade is unlikely and including claimed HTLCs here
+                                       // fixes a bug which the user was exposed to on 0.0.104 when they started the
+                                       // claim anyway.
+                                       claimed_htlcs.push((SentHTLCId::from_source(&htlc.source), preimage));
+                               }
                                htlc.state = OutboundHTLCState::AwaitingRemoteRevokeToRemove(reason);
                                need_commitment = true;
                        }
                }
 
+               self.latest_monitor_update_id += 1;
+               let mut monitor_update = ChannelMonitorUpdate {
+                       update_id: self.latest_monitor_update_id,
+                       updates: vec![ChannelMonitorUpdateStep::LatestHolderCommitmentTXInfo {
+                               commitment_tx: holder_commitment_tx,
+                               htlc_outputs: htlcs_and_sigs,
+                               claimed_htlcs,
+                       }]
+               };
+
                self.cur_holder_commitment_transaction_number -= 1;
                // Note that if we need_commitment & !AwaitingRemoteRevoke we'll call
                // build_commitment_no_status_check() next which will reset this to RAAFirst.
index 0757e117ce2e7660648856053070c12773b2b6b1..62629fc548679dfa56b9a58ae711704942657e06 100644 (file)
 //! responsible for tracking which channels are open, HTLCs are in flight and reestablishing those
 //! upon reconnect to the relevant peer(s).
 //!
-//! It does not manage routing logic (see [`find_route`] for that) nor does it manage constructing
+//! It does not manage routing logic (see [`Router`] for that) nor does it manage constructing
 //! on-chain transactions (it only monitors the chain to watch for any force-closes that might
 //! imply it needs to fail HTLCs/payments/channels it manages).
-//!
-//! [`find_route`]: crate::routing::router::find_route
 
 use bitcoin::blockdata::block::BlockHeader;
 use bitcoin::blockdata::transaction::Transaction;
@@ -234,6 +232,36 @@ impl Readable for InterceptId {
                Ok(InterceptId(buf))
        }
 }
+
+#[derive(Clone, Copy, PartialEq, Eq, Hash)]
+/// Uniquely describes an HTLC by its source. Just the guaranteed-unique subset of [`HTLCSource`].
+pub(crate) enum SentHTLCId {
+       PreviousHopData { short_channel_id: u64, htlc_id: u64 },
+       OutboundRoute { session_priv: SecretKey },
+}
+impl SentHTLCId {
+       pub(crate) fn from_source(source: &HTLCSource) -> Self {
+               match source {
+                       HTLCSource::PreviousHopData(hop_data) => Self::PreviousHopData {
+                               short_channel_id: hop_data.short_channel_id,
+                               htlc_id: hop_data.htlc_id,
+                       },
+                       HTLCSource::OutboundRoute { session_priv, .. } =>
+                               Self::OutboundRoute { session_priv: *session_priv },
+               }
+       }
+}
+impl_writeable_tlv_based_enum!(SentHTLCId,
+       (0, PreviousHopData) => {
+               (0, short_channel_id, required),
+               (2, htlc_id, required),
+       },
+       (2, OutboundRoute) => {
+               (0, session_priv, required),
+       };
+);
+
+
 /// Tracks the inbound corresponding to an outbound HTLC
 #[allow(clippy::derive_hash_xor_eq)] // Our Hash is faithful to the data, we just don't have SecretKey::hash
 #[derive(Clone, PartialEq, Eq)]
@@ -1745,14 +1773,12 @@ where
                self.list_channels_with_filter(|_| true)
        }
 
-       /// Gets the list of usable channels, in random order. Useful as an argument to [`find_route`]
-       /// to ensure non-announced channels are used.
+       /// Gets the list of usable channels, in random order. Useful as an argument to
+       /// [`Router::find_route`] to ensure non-announced channels are used.
        ///
        /// These are guaranteed to have their [`ChannelDetails::is_usable`] value set to true, see the
        /// documentation for [`ChannelDetails::is_usable`] for more info on exactly what the criteria
        /// are.
-       ///
-       /// [`find_route`]: crate::routing::router::find_route
        pub fn list_usable_channels(&self) -> Vec<ChannelDetails> {
                // Note we use is_live here instead of usable which leads to somewhat confused
                // internal/external nomenclature, but that's ok cause that's probably what the user
@@ -3653,14 +3679,14 @@ where
        /// [`events::Event::PaymentClaimed`] events even for payments you intend to fail, especially on
        /// startup during which time claims that were in-progress at shutdown may be replayed.
        pub fn fail_htlc_backwards(&self, payment_hash: &PaymentHash) {
-               self.fail_htlc_backwards_with_reason(payment_hash, &FailureCode::IncorrectOrUnknownPaymentDetails);
+               self.fail_htlc_backwards_with_reason(payment_hash, FailureCode::IncorrectOrUnknownPaymentDetails);
        }
 
        /// This is a variant of [`ChannelManager::fail_htlc_backwards`] that allows you to specify the
        /// reason for the failure.
        ///
        /// See [`FailureCode`] for valid failure codes.
-       pub fn fail_htlc_backwards_with_reason(&self, payment_hash: &PaymentHash, failure_code: &FailureCode) {
+       pub fn fail_htlc_backwards_with_reason(&self, payment_hash: &PaymentHash, failure_code: FailureCode) {
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
 
                let removed_source = self.claimable_payments.lock().unwrap().claimable_htlcs.remove(payment_hash);
@@ -3675,14 +3701,14 @@ where
        }
 
        /// Gets error data to form an [`HTLCFailReason`] given a [`FailureCode`] and [`ClaimableHTLC`].
-       fn get_htlc_fail_reason_from_failure_code(&self, failure_code: &FailureCode, htlc: &ClaimableHTLC) -> HTLCFailReason {
+       fn get_htlc_fail_reason_from_failure_code(&self, failure_code: FailureCode, htlc: &ClaimableHTLC) -> HTLCFailReason {
                match failure_code {
-                       FailureCode::TemporaryNodeFailure => HTLCFailReason::from_failure_code(*failure_code as u16),
-                       FailureCode::RequiredNodeFeatureMissing => HTLCFailReason::from_failure_code(*failure_code as u16),
+                       FailureCode::TemporaryNodeFailure => HTLCFailReason::from_failure_code(failure_code as u16),
+                       FailureCode::RequiredNodeFeatureMissing => HTLCFailReason::from_failure_code(failure_code as u16),
                        FailureCode::IncorrectOrUnknownPaymentDetails => {
                                let mut htlc_msat_height_data = htlc.value.to_be_bytes().to_vec();
                                htlc_msat_height_data.extend_from_slice(&self.best_block.read().unwrap().height().to_be_bytes());
-                               HTLCFailReason::reason(*failure_code as u16, htlc_msat_height_data)
+                               HTLCFailReason::reason(failure_code as u16, htlc_msat_height_data)
                        }
                }
        }
@@ -3986,7 +4012,7 @@ where
                        None => None
                };
 
-               let mut peer_state_opt = counterparty_node_id_opt.as_ref().map(
+               let peer_state_opt = counterparty_node_id_opt.as_ref().map(
                        |counterparty_node_id| per_peer_state.get(counterparty_node_id).map(
                                |peer_mutex| peer_mutex.lock().unwrap()
                        )
@@ -7445,6 +7471,10 @@ where
                        probing_cookie_secret = Some(args.entropy_source.get_secure_random_bytes());
                }
 
+               if !channel_closures.is_empty() {
+                       pending_events_read.append(&mut channel_closures);
+               }
+
                if pending_outbound_payments.is_none() && pending_outbound_payments_no_retry.is_none() {
                        pending_outbound_payments = Some(pending_outbound_payments_compat);
                } else if pending_outbound_payments.is_none() {
@@ -7453,7 +7483,13 @@ where
                                outbounds.insert(id, PendingOutboundPayment::Legacy { session_privs });
                        }
                        pending_outbound_payments = Some(outbounds);
-               } else {
+               }
+               let pending_outbounds = OutboundPayments {
+                       pending_outbound_payments: Mutex::new(pending_outbound_payments.unwrap()),
+                       retry_lock: Mutex::new(())
+               };
+
+               {
                        // If we're tracking pending payments, ensure we haven't lost any by looking at the
                        // ChannelMonitor data for any channels for which we do not have authorative state
                        // (i.e. those for which we just force-closed above or we otherwise don't have a
@@ -7464,16 +7500,17 @@ where
                        // 0.0.102+
                        for (_, monitor) in args.channel_monitors.iter() {
                                if id_to_peer.get(&monitor.get_funding_txo().0.to_channel_id()).is_none() {
-                                       for (htlc_source, htlc) in monitor.get_pending_outbound_htlcs() {
+                                       for (htlc_source, (htlc, _)) in monitor.get_pending_or_resolved_outbound_htlcs() {
                                                if let HTLCSource::OutboundRoute { payment_id, session_priv, path, payment_secret, .. } = htlc_source {
                                                        if path.is_empty() {
                                                                log_error!(args.logger, "Got an empty path for a pending payment");
                                                                return Err(DecodeError::InvalidValue);
                                                        }
+
                                                        let path_amt = path.last().unwrap().fee_msat;
                                                        let mut session_priv_bytes = [0; 32];
                                                        session_priv_bytes[..].copy_from_slice(&session_priv[..]);
-                                                       match pending_outbound_payments.as_mut().unwrap().entry(payment_id) {
+                                                       match pending_outbounds.pending_outbound_payments.lock().unwrap().entry(payment_id) {
                                                                hash_map::Entry::Occupied(mut entry) => {
                                                                        let newly_added = entry.get_mut().insert(session_priv_bytes, &path);
                                                                        log_info!(args.logger, "{} a pending payment path for {} msat for session priv {} on an existing pending payment with payment hash {}",
@@ -7500,51 +7537,64 @@ where
                                                        }
                                                }
                                        }
-                                       for (htlc_source, htlc) in monitor.get_all_current_outbound_htlcs() {
-                                               if let HTLCSource::PreviousHopData(prev_hop_data) = htlc_source {
-                                                       let pending_forward_matches_htlc = |info: &PendingAddHTLCInfo| {
-                                                               info.prev_funding_outpoint == prev_hop_data.outpoint &&
-                                                                       info.prev_htlc_id == prev_hop_data.htlc_id
-                                                       };
-                                                       // The ChannelMonitor is now responsible for this HTLC's
-                                                       // failure/success and will let us know what its outcome is. If we
-                                                       // still have an entry for this HTLC in `forward_htlcs` or
-                                                       // `pending_intercepted_htlcs`, we were apparently not persisted after
-                                                       // the monitor was when forwarding the payment.
-                                                       forward_htlcs.retain(|_, forwards| {
-                                                               forwards.retain(|forward| {
-                                                                       if let HTLCForwardInfo::AddHTLC(htlc_info) = forward {
-                                                                               if pending_forward_matches_htlc(&htlc_info) {
-                                                                                       log_info!(args.logger, "Removing pending to-forward HTLC with hash {} as it was forwarded to the closed channel {}",
-                                                                                               log_bytes!(htlc.payment_hash.0), log_bytes!(monitor.get_funding_txo().0.to_channel_id()));
-                                                                                       false
+                                       for (htlc_source, (htlc, preimage_opt)) in monitor.get_all_current_outbound_htlcs() {
+                                               match htlc_source {
+                                                       HTLCSource::PreviousHopData(prev_hop_data) => {
+                                                               let pending_forward_matches_htlc = |info: &PendingAddHTLCInfo| {
+                                                                       info.prev_funding_outpoint == prev_hop_data.outpoint &&
+                                                                               info.prev_htlc_id == prev_hop_data.htlc_id
+                                                               };
+                                                               // The ChannelMonitor is now responsible for this HTLC's
+                                                               // failure/success and will let us know what its outcome is. If we
+                                                               // still have an entry for this HTLC in `forward_htlcs` or
+                                                               // `pending_intercepted_htlcs`, we were apparently not persisted after
+                                                               // the monitor was when forwarding the payment.
+                                                               forward_htlcs.retain(|_, forwards| {
+                                                                       forwards.retain(|forward| {
+                                                                               if let HTLCForwardInfo::AddHTLC(htlc_info) = forward {
+                                                                                       if pending_forward_matches_htlc(&htlc_info) {
+                                                                                               log_info!(args.logger, "Removing pending to-forward HTLC with hash {} as it was forwarded to the closed channel {}",
+                                                                                                       log_bytes!(htlc.payment_hash.0), log_bytes!(monitor.get_funding_txo().0.to_channel_id()));
+                                                                                               false
+                                                                                       } else { true }
                                                                                } else { true }
+                                                                       });
+                                                                       !forwards.is_empty()
+                                                               });
+                                                               pending_intercepted_htlcs.as_mut().unwrap().retain(|intercepted_id, htlc_info| {
+                                                                       if pending_forward_matches_htlc(&htlc_info) {
+                                                                               log_info!(args.logger, "Removing pending intercepted HTLC with hash {} as it was forwarded to the closed channel {}",
+                                                                                       log_bytes!(htlc.payment_hash.0), log_bytes!(monitor.get_funding_txo().0.to_channel_id()));
+                                                                               pending_events_read.retain(|event| {
+                                                                                       if let Event::HTLCIntercepted { intercept_id: ev_id, .. } = event {
+                                                                                               intercepted_id != ev_id
+                                                                                       } else { true }
+                                                                               });
+                                                                               false
                                                                        } else { true }
                                                                });
-                                                               !forwards.is_empty()
-                                                       });
-                                                       pending_intercepted_htlcs.as_mut().unwrap().retain(|intercepted_id, htlc_info| {
-                                                               if pending_forward_matches_htlc(&htlc_info) {
-                                                                       log_info!(args.logger, "Removing pending intercepted HTLC with hash {} as it was forwarded to the closed channel {}",
-                                                                               log_bytes!(htlc.payment_hash.0), log_bytes!(monitor.get_funding_txo().0.to_channel_id()));
-                                                                       pending_events_read.retain(|event| {
-                                                                               if let Event::HTLCIntercepted { intercept_id: ev_id, .. } = event {
-                                                                                       intercepted_id != ev_id
-                                                                               } else { true }
-                                                                       });
-                                                                       false
-                                                               } else { true }
-                                                       });
+                                                       },
+                                                       HTLCSource::OutboundRoute { payment_id, session_priv, path, .. } => {
+                                                               if let Some(preimage) = preimage_opt {
+                                                                       let pending_events = Mutex::new(pending_events_read);
+                                                                       // Note that we set `from_onchain` to "false" here,
+                                                                       // deliberately keeping the pending payment around forever.
+                                                                       // Given it should only occur when we have a channel we're
+                                                                       // force-closing for being stale that's okay.
+                                                                       // The alternative would be to wipe the state when claiming,
+                                                                       // generating a `PaymentPathSuccessful` event but regenerating
+                                                                       // it and the `PaymentSent` on every restart until the
+                                                                       // `ChannelMonitor` is removed.
+                                                                       pending_outbounds.claim_htlc(payment_id, preimage, session_priv, path, false, &pending_events, &args.logger);
+                                                                       pending_events_read = pending_events.into_inner().unwrap();
+                                                               }
+                                                       },
                                                }
                                        }
                                }
                        }
                }
 
-               let pending_outbounds = OutboundPayments {
-                       pending_outbound_payments: Mutex::new(pending_outbound_payments.unwrap()),
-                       retry_lock: Mutex::new(())
-               };
                if !forward_htlcs.is_empty() || pending_outbounds.needs_abandon() {
                        // If we have pending HTLCs to forward, assume we either dropped a
                        // `PendingHTLCsForwardable` or the user received it but never processed it as they
@@ -7602,10 +7652,6 @@ where
                let mut secp_ctx = Secp256k1::new();
                secp_ctx.seeded_randomize(&args.entropy_source.get_secure_random_bytes());
 
-               if !channel_closures.is_empty() {
-                       pending_events_read.append(&mut channel_closures);
-               }
-
                let our_network_pubkey = match args.node_signer.get_node_id(Recipient::Node) {
                        Ok(key) => key,
                        Err(()) => return Err(DecodeError::InvalidValue)
index 6ea8ab4ce1b28e837f69678ea037bbba86aaf7e4..36e9ddb067a87ad574fac2ddd9a604352e1da190 100644 (file)
@@ -848,7 +848,7 @@ fn do_test_fail_htlc_backwards_with_reason(failure_code: FailureCode) {
 
        expect_pending_htlcs_forwardable!(nodes[1]);
        expect_payment_claimable!(nodes[1], payment_hash, payment_secret, payment_amount);
-       nodes[1].node.fail_htlc_backwards_with_reason(&payment_hash, &failure_code);
+       nodes[1].node.fail_htlc_backwards_with_reason(&payment_hash, failure_code);
 
        expect_pending_htlcs_forwardable_and_htlc_handling_failed!(nodes[1], vec![HTLCDestination::FailedPayment { payment_hash: payment_hash }]);
        check_added_monitors!(nodes[1], 1);
index 33ccecb4ca858b17c45321b3a17c2303e6931587..86b4de768f830172d1caf411b5670667be60923e 100644 (file)
@@ -276,7 +276,11 @@ pub(crate) struct PaymentAttemptsUsingTime<T: Time> {
        /// it means the result of the first attempt is not known yet.
        pub(crate) count: usize,
        /// This field is only used when retry is `Retry::Timeout` which is only build with feature std
-       first_attempted_at: T
+       #[cfg(not(feature = "no-std"))]
+       first_attempted_at: T,
+       #[cfg(feature = "no-std")]
+       phantom: core::marker::PhantomData<T>,
+
 }
 
 #[cfg(not(any(feature = "no-std", test)))]
@@ -290,7 +294,10 @@ impl<T: Time> PaymentAttemptsUsingTime<T> {
        pub(crate) fn new() -> Self {
                PaymentAttemptsUsingTime {
                        count: 0,
-                       first_attempted_at: T::now()
+                       #[cfg(not(feature = "no-std"))]
+                       first_attempted_at: T::now(),
+                       #[cfg(feature = "no-std")]
+                       phantom: core::marker::PhantomData,
                }
        }
 }
index c4cd0fc1b09d1ca691654c7fbc32e256bb39ed41..15361b98ad71fd9ee1193891aa71f4d45ef5d9a3 100644 (file)
@@ -2748,3 +2748,83 @@ fn test_threaded_payment_retries() {
                }
        }
 }
+
+fn do_no_missing_sent_on_midpoint_reload(persist_manager_with_payment: bool) {
+       // Test that if we reload in the middle of an HTLC claim commitment signed dance we'll still
+       // receive the PaymentSent event even if the ChannelManager had no idea about the payment when
+       // it was last persisted.
+       let chanmon_cfgs = create_chanmon_cfgs(2);
+       let node_cfgs = create_node_cfgs(2, &chanmon_cfgs);
+       let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
+       let (persister_a, persister_b, persister_c);
+       let (chain_monitor_a, chain_monitor_b, chain_monitor_c);
+       let (nodes_0_deserialized, nodes_0_deserialized_b, nodes_0_deserialized_c);
+       let mut nodes = create_network(2, &node_cfgs, &node_chanmgrs);
+
+       let chan_id = create_announced_chan_between_nodes(&nodes, 0, 1).2;
+
+       let mut nodes_0_serialized = Vec::new();
+       if !persist_manager_with_payment {
+               nodes_0_serialized = nodes[0].node.encode();
+       }
+
+       let (our_payment_preimage, our_payment_hash, _) = route_payment(&nodes[0], &[&nodes[1]], 1_000_000);
+
+       if persist_manager_with_payment {
+               nodes_0_serialized = nodes[0].node.encode();
+       }
+
+       nodes[1].node.claim_funds(our_payment_preimage);
+       check_added_monitors!(nodes[1], 1);
+       expect_payment_claimed!(nodes[1], our_payment_hash, 1_000_000);
+
+       let updates = get_htlc_update_msgs!(nodes[1], nodes[0].node.get_our_node_id());
+       nodes[0].node.handle_update_fulfill_htlc(&nodes[1].node.get_our_node_id(), &updates.update_fulfill_htlcs[0]);
+       nodes[0].node.handle_commitment_signed(&nodes[1].node.get_our_node_id(), &updates.commitment_signed);
+       check_added_monitors!(nodes[0], 1);
+
+       // The ChannelMonitor should always be the latest version, as we're required to persist it
+       // during the commitment signed handling.
+       let chan_0_monitor_serialized = get_monitor!(nodes[0], chan_id).encode();
+       reload_node!(nodes[0], test_default_channel_config(), &nodes_0_serialized, &[&chan_0_monitor_serialized], persister_a, chain_monitor_a, nodes_0_deserialized);
+
+       let events = nodes[0].node.get_and_clear_pending_events();
+       assert_eq!(events.len(), 2);
+       if let Event::ChannelClosed { reason: ClosureReason::OutdatedChannelManager, .. } = events[0] {} else { panic!(); }
+       if let Event::PaymentSent { payment_preimage, .. } = events[1] { assert_eq!(payment_preimage, our_payment_preimage); } else { panic!(); }
+       // Note that we don't get a PaymentPathSuccessful here as we leave the HTLC pending to avoid
+       // the double-claim that would otherwise appear at the end of this test.
+       let as_broadcasted_txn = nodes[0].tx_broadcaster.txn_broadcasted.lock().unwrap().split_off(0);
+       assert_eq!(as_broadcasted_txn.len(), 1);
+
+       // Ensure that, even after some time, if we restart we still include *something* in the current
+       // `ChannelManager` which prevents a `PaymentFailed` when we restart even if pending resolved
+       // payments have since been timed out thanks to `IDEMPOTENCY_TIMEOUT_TICKS`.
+       // A naive implementation of the fix here would wipe the pending payments set, causing a
+       // failure event when we restart.
+       for _ in 0..(IDEMPOTENCY_TIMEOUT_TICKS * 2) { nodes[0].node.timer_tick_occurred(); }
+
+       let chan_0_monitor_serialized = get_monitor!(nodes[0], chan_id).encode();
+       reload_node!(nodes[0], test_default_channel_config(), &nodes[0].node.encode(), &[&chan_0_monitor_serialized], persister_b, chain_monitor_b, nodes_0_deserialized_b);
+       let events = nodes[0].node.get_and_clear_pending_events();
+       assert!(events.is_empty());
+
+       // Ensure that we don't generate any further events even after the channel-closing commitment
+       // transaction is confirmed on-chain.
+       confirm_transaction(&nodes[0], &as_broadcasted_txn[0]);
+       for _ in 0..(IDEMPOTENCY_TIMEOUT_TICKS * 2) { nodes[0].node.timer_tick_occurred(); }
+
+       let events = nodes[0].node.get_and_clear_pending_events();
+       assert!(events.is_empty());
+
+       let chan_0_monitor_serialized = get_monitor!(nodes[0], chan_id).encode();
+       reload_node!(nodes[0], test_default_channel_config(), &nodes[0].node.encode(), &[&chan_0_monitor_serialized], persister_c, chain_monitor_c, nodes_0_deserialized_c);
+       let events = nodes[0].node.get_and_clear_pending_events();
+       assert!(events.is_empty());
+}
+
+#[test]
+fn no_missing_sent_on_midpoint_reload() {
+       do_no_missing_sent_on_midpoint_reload(false);
+       do_no_missing_sent_on_midpoint_reload(true);
+}
index e9eaf33e8840b021afe7fbcfc60afec3e0fa5b3d..41bbf5b4908bba966f83d274dce8b2f3a2bb8ebf 100644 (file)
@@ -815,34 +815,40 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                let pending_read_buffer = [0; 50].to_vec(); // Noise act two is 50 bytes
 
                let mut peers = self.peers.write().unwrap();
-               if peers.insert(descriptor, Mutex::new(Peer {
-                       channel_encryptor: peer_encryptor,
-                       their_node_id: None,
-                       their_features: None,
-                       their_net_address: remote_network_address,
-
-                       pending_outbound_buffer: LinkedList::new(),
-                       pending_outbound_buffer_first_msg_offset: 0,
-                       gossip_broadcast_buffer: LinkedList::new(),
-                       awaiting_write_event: false,
-
-                       pending_read_buffer,
-                       pending_read_buffer_pos: 0,
-                       pending_read_is_header: false,
-
-                       sync_status: InitSyncTracker::NoSyncRequested,
-
-                       msgs_sent_since_pong: 0,
-                       awaiting_pong_timer_tick_intervals: 0,
-                       received_message_since_timer_tick: false,
-                       sent_gossip_timestamp_filter: false,
-
-                       received_channel_announce_since_backlogged: false,
-                       inbound_connection: false,
-               })).is_some() {
-                       panic!("PeerManager driver duplicated descriptors!");
-               };
-               Ok(res)
+               match peers.entry(descriptor) {
+                       hash_map::Entry::Occupied(_) => {
+                               debug_assert!(false, "PeerManager driver duplicated descriptors!");
+                               Err(PeerHandleError {})
+                       },
+                       hash_map::Entry::Vacant(e) => {
+                               e.insert(Mutex::new(Peer {
+                                       channel_encryptor: peer_encryptor,
+                                       their_node_id: None,
+                                       their_features: None,
+                                       their_net_address: remote_network_address,
+
+                                       pending_outbound_buffer: LinkedList::new(),
+                                       pending_outbound_buffer_first_msg_offset: 0,
+                                       gossip_broadcast_buffer: LinkedList::new(),
+                                       awaiting_write_event: false,
+
+                                       pending_read_buffer,
+                                       pending_read_buffer_pos: 0,
+                                       pending_read_is_header: false,
+
+                                       sync_status: InitSyncTracker::NoSyncRequested,
+
+                                       msgs_sent_since_pong: 0,
+                                       awaiting_pong_timer_tick_intervals: 0,
+                                       received_message_since_timer_tick: false,
+                                       sent_gossip_timestamp_filter: false,
+
+                                       received_channel_announce_since_backlogged: false,
+                                       inbound_connection: false,
+                               }));
+                               Ok(res)
+                       }
+               }
        }
 
        /// Indicates a new inbound connection has been established to a node with an optional remote
@@ -865,34 +871,40 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                let pending_read_buffer = [0; 50].to_vec(); // Noise act one is 50 bytes
 
                let mut peers = self.peers.write().unwrap();
-               if peers.insert(descriptor, Mutex::new(Peer {
-                       channel_encryptor: peer_encryptor,
-                       their_node_id: None,
-                       their_features: None,
-                       their_net_address: remote_network_address,
-
-                       pending_outbound_buffer: LinkedList::new(),
-                       pending_outbound_buffer_first_msg_offset: 0,
-                       gossip_broadcast_buffer: LinkedList::new(),
-                       awaiting_write_event: false,
-
-                       pending_read_buffer,
-                       pending_read_buffer_pos: 0,
-                       pending_read_is_header: false,
-
-                       sync_status: InitSyncTracker::NoSyncRequested,
-
-                       msgs_sent_since_pong: 0,
-                       awaiting_pong_timer_tick_intervals: 0,
-                       received_message_since_timer_tick: false,
-                       sent_gossip_timestamp_filter: false,
-
-                       received_channel_announce_since_backlogged: false,
-                       inbound_connection: true,
-               })).is_some() {
-                       panic!("PeerManager driver duplicated descriptors!");
-               };
-               Ok(())
+               match peers.entry(descriptor) {
+                       hash_map::Entry::Occupied(_) => {
+                               debug_assert!(false, "PeerManager driver duplicated descriptors!");
+                               Err(PeerHandleError {})
+                       },
+                       hash_map::Entry::Vacant(e) => {
+                               e.insert(Mutex::new(Peer {
+                                       channel_encryptor: peer_encryptor,
+                                       their_node_id: None,
+                                       their_features: None,
+                                       their_net_address: remote_network_address,
+
+                                       pending_outbound_buffer: LinkedList::new(),
+                                       pending_outbound_buffer_first_msg_offset: 0,
+                                       gossip_broadcast_buffer: LinkedList::new(),
+                                       awaiting_write_event: false,
+
+                                       pending_read_buffer,
+                                       pending_read_buffer_pos: 0,
+                                       pending_read_is_header: false,
+
+                                       sync_status: InitSyncTracker::NoSyncRequested,
+
+                                       msgs_sent_since_pong: 0,
+                                       awaiting_pong_timer_tick_intervals: 0,
+                                       received_message_since_timer_tick: false,
+                                       sent_gossip_timestamp_filter: false,
+
+                                       received_channel_announce_since_backlogged: false,
+                                       inbound_connection: true,
+                               }));
+                               Ok(())
+                       }
+               }
        }
 
        fn peer_should_read(&self, peer: &mut Peer) -> bool {
@@ -1141,9 +1153,13 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                                macro_rules! insert_node_id {
                                                        () => {
                                                                match self.node_id_to_descriptor.lock().unwrap().entry(peer.their_node_id.unwrap().0) {
-                                                                       hash_map::Entry::Occupied(_) => {
+                                                                       hash_map::Entry::Occupied(e) => {
                                                                                log_trace!(self.logger, "Got second connection with {}, closing", log_pubkey!(peer.their_node_id.unwrap().0));
                                                                                peer.their_node_id = None; // Unset so that we don't generate a peer_disconnected event
+                                                                               // Check that the peers map is consistent with the
+                                                                               // node_id_to_descriptor map, as this has been broken
+                                                                               // before.
+                                                                               debug_assert!(peers.get(e.get()).is_some());
                                                                                return Err(PeerHandleError { })
                                                                        },
                                                                        hash_map::Entry::Vacant(entry) => {
@@ -1913,7 +1929,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                                        self.do_attempt_write_data(&mut descriptor, &mut *peer, false);
                                                }
                                                self.do_disconnect(descriptor, &*peer, "DisconnectPeer HandleError");
-                                       }
+                                       } else { debug_assert!(false, "Missing connection for peer"); }
                                }
                        }
                }
@@ -1951,11 +1967,11 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                        },
                        Some(peer_lock) => {
                                let peer = peer_lock.lock().unwrap();
-                               if !peer.handshake_complete() { return; }
-                               debug_assert!(peer.their_node_id.is_some());
                                if let Some((node_id, _)) = peer.their_node_id {
                                        log_trace!(self.logger, "Handling disconnection of peer {}", log_pubkey!(node_id));
-                                       self.node_id_to_descriptor.lock().unwrap().remove(&node_id);
+                                       let removed = self.node_id_to_descriptor.lock().unwrap().remove(&node_id);
+                                       debug_assert!(removed.is_some(), "descriptor maps should be consistent");
+                                       if !peer.handshake_complete() { return; }
                                        self.message_handler.chan_handler.peer_disconnected(&node_id);
                                        self.message_handler.onion_message_handler.peer_disconnected(&node_id);
                                }
@@ -2188,12 +2204,13 @@ mod tests {
 
        use crate::prelude::*;
        use crate::sync::{Arc, Mutex};
-       use core::sync::atomic::Ordering;
+       use core::sync::atomic::{AtomicBool, Ordering};
 
        #[derive(Clone)]
        struct FileDescriptor {
                fd: u16,
                outbound_data: Arc<Mutex<Vec<u8>>>,
+               disconnect: Arc<AtomicBool>,
        }
        impl PartialEq for FileDescriptor {
                fn eq(&self, other: &Self) -> bool {
@@ -2213,7 +2230,7 @@ mod tests {
                        data.len()
                }
 
-               fn disconnect_socket(&mut self) {}
+               fn disconnect_socket(&mut self) { self.disconnect.store(true, Ordering::Release); }
        }
 
        struct PeerManagerCfg {
@@ -2254,10 +2271,16 @@ mod tests {
 
        fn establish_connection<'a>(peer_a: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, IgnoringMessageHandler, &'a test_utils::TestLogger, IgnoringMessageHandler, &'a test_utils::TestNodeSigner>, peer_b: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, IgnoringMessageHandler, &'a test_utils::TestLogger, IgnoringMessageHandler, &'a test_utils::TestNodeSigner>) -> (FileDescriptor, FileDescriptor) {
                let id_a = peer_a.node_signer.get_node_id(Recipient::Node).unwrap();
-               let mut fd_a = FileDescriptor { fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())) };
+               let mut fd_a = FileDescriptor {
+                       fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())),
+                       disconnect: Arc::new(AtomicBool::new(false)),
+               };
                let addr_a = NetAddress::IPv4{addr: [127, 0, 0, 1], port: 1000};
                let id_b = peer_b.node_signer.get_node_id(Recipient::Node).unwrap();
-               let mut fd_b = FileDescriptor { fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())) };
+               let mut fd_b = FileDescriptor {
+                       fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())),
+                       disconnect: Arc::new(AtomicBool::new(false)),
+               };
                let addr_b = NetAddress::IPv4{addr: [127, 0, 0, 1], port: 1001};
                let initial_data = peer_b.new_outbound_connection(id_a, fd_b.clone(), Some(addr_a.clone())).unwrap();
                peer_a.new_inbound_connection(fd_a.clone(), Some(addr_b.clone())).unwrap();
@@ -2281,6 +2304,84 @@ mod tests {
                (fd_a.clone(), fd_b.clone())
        }
 
+       #[test]
+       #[cfg(feature = "std")]
+       fn fuzz_threaded_connections() {
+               // Spawn two threads which repeatedly connect two peers together, leading to "got second
+               // connection with peer" disconnections and rapid reconnect. This previously found an issue
+               // with our internal map consistency, and is a generally good smoke test of disconnection.
+               let cfgs = Arc::new(create_peermgr_cfgs(2));
+               // Until we have std::thread::scoped we have to unsafe { turn off the borrow checker }.
+               let peers = Arc::new(create_network(2, unsafe { &*(&*cfgs as *const _) as &'static _ }));
+
+               let start_time = std::time::Instant::now();
+               macro_rules! spawn_thread { ($id: expr) => { {
+                       let peers = Arc::clone(&peers);
+                       let cfgs = Arc::clone(&cfgs);
+                       std::thread::spawn(move || {
+                               let mut ctr = 0;
+                               while start_time.elapsed() < std::time::Duration::from_secs(1) {
+                                       let id_a = peers[0].node_signer.get_node_id(Recipient::Node).unwrap();
+                                       let mut fd_a = FileDescriptor {
+                                               fd: $id  + ctr * 3, outbound_data: Arc::new(Mutex::new(Vec::new())),
+                                               disconnect: Arc::new(AtomicBool::new(false)),
+                                       };
+                                       let addr_a = NetAddress::IPv4{addr: [127, 0, 0, 1], port: 1000};
+                                       let mut fd_b = FileDescriptor {
+                                               fd: $id + ctr * 3, outbound_data: Arc::new(Mutex::new(Vec::new())),
+                                               disconnect: Arc::new(AtomicBool::new(false)),
+                                       };
+                                       let addr_b = NetAddress::IPv4{addr: [127, 0, 0, 1], port: 1001};
+                                       let initial_data = peers[1].new_outbound_connection(id_a, fd_b.clone(), Some(addr_a.clone())).unwrap();
+                                       peers[0].new_inbound_connection(fd_a.clone(), Some(addr_b.clone())).unwrap();
+                                       assert_eq!(peers[0].read_event(&mut fd_a, &initial_data).unwrap(), false);
+
+                                       while start_time.elapsed() < std::time::Duration::from_secs(1) {
+                                               peers[0].process_events();
+                                               if fd_a.disconnect.load(Ordering::Acquire) { break; }
+                                               let a_data = fd_a.outbound_data.lock().unwrap().split_off(0);
+                                               if peers[1].read_event(&mut fd_b, &a_data).is_err() { break; }
+
+                                               peers[1].process_events();
+                                               if fd_b.disconnect.load(Ordering::Acquire) { break; }
+                                               let b_data = fd_b.outbound_data.lock().unwrap().split_off(0);
+                                               if peers[0].read_event(&mut fd_a, &b_data).is_err() { break; }
+
+                                               cfgs[0].chan_handler.pending_events.lock().unwrap()
+                                                       .push(crate::util::events::MessageSendEvent::SendShutdown {
+                                                               node_id: peers[1].node_signer.get_node_id(Recipient::Node).unwrap(),
+                                                               msg: msgs::Shutdown {
+                                                                       channel_id: [0; 32],
+                                                                       scriptpubkey: bitcoin::Script::new(),
+                                                               },
+                                                       });
+                                               cfgs[1].chan_handler.pending_events.lock().unwrap()
+                                                       .push(crate::util::events::MessageSendEvent::SendShutdown {
+                                                               node_id: peers[0].node_signer.get_node_id(Recipient::Node).unwrap(),
+                                                               msg: msgs::Shutdown {
+                                                                       channel_id: [0; 32],
+                                                                       scriptpubkey: bitcoin::Script::new(),
+                                                               },
+                                                       });
+
+                                               peers[0].timer_tick_occurred();
+                                               peers[1].timer_tick_occurred();
+                                       }
+
+                                       peers[0].socket_disconnected(&fd_a);
+                                       peers[1].socket_disconnected(&fd_b);
+                                       ctr += 1;
+                                       std::thread::sleep(std::time::Duration::from_micros(1));
+                               }
+                       })
+               } } }
+               let thrd_a = spawn_thread!(1);
+               let thrd_b = spawn_thread!(2);
+
+               thrd_a.join().unwrap();
+               thrd_b.join().unwrap();
+       }
+
        #[test]
        fn test_disconnect_peer() {
                // Simple test which builds a network of PeerManager, connects and brings them to NoiseState::Finished and
@@ -2337,7 +2438,10 @@ mod tests {
                let cfgs = create_peermgr_cfgs(2);
                let peers = create_network(2, &cfgs);
 
-               let mut fd_dup = FileDescriptor { fd: 3, outbound_data: Arc::new(Mutex::new(Vec::new())) };
+               let mut fd_dup = FileDescriptor {
+                       fd: 3, outbound_data: Arc::new(Mutex::new(Vec::new())),
+                       disconnect: Arc::new(AtomicBool::new(false)),
+               };
                let addr_dup = NetAddress::IPv4{addr: [127, 0, 0, 1], port: 1003};
                let id_a = cfgs[0].node_signer.get_node_id(Recipient::Node).unwrap();
                peers[0].new_inbound_connection(fd_dup.clone(), Some(addr_dup.clone())).unwrap();
@@ -2441,8 +2545,14 @@ mod tests {
                let peers = create_network(2, &cfgs);
 
                let a_id = peers[0].node_signer.get_node_id(Recipient::Node).unwrap();
-               let mut fd_a = FileDescriptor { fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())) };
-               let mut fd_b = FileDescriptor { fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())) };
+               let mut fd_a = FileDescriptor {
+                       fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())),
+                       disconnect: Arc::new(AtomicBool::new(false)),
+               };
+               let mut fd_b = FileDescriptor {
+                       fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())),
+                       disconnect: Arc::new(AtomicBool::new(false)),
+               };
                let initial_data = peers[1].new_outbound_connection(a_id, fd_b.clone(), None).unwrap();
                peers[0].new_inbound_connection(fd_a.clone(), None).unwrap();
 
index 7636f5c63641edc5e18cb5d7f69d3375a9fa235b..f563a63cfd1a41b2363ac5421c53d39b3effa7a2 100644 (file)
@@ -242,6 +242,13 @@ fn test_routed_scid_alias() {
        check_added_monitors!(nodes[0], 1);
 
        pass_along_route(&nodes[0], &[&[&nodes[1], &nodes[2]]], 100_000, payment_hash, payment_secret);
+
+       as_channel_ready.short_channel_id_alias = Some(0xeadbeef);
+       nodes[2].node.handle_channel_ready(&nodes[1].node.get_our_node_id(), &as_channel_ready);
+       // Note that we always respond to a channel_ready with a channel_update. Not a lot of reason
+       // to bother updating that code, so just drop the message here.
+       get_event_msg!(nodes[2], MessageSendEvent::SendChannelUpdate, nodes[1].node.get_our_node_id());
+
        claim_payment(&nodes[0], &[&nodes[1], &nodes[2]], payment_preimage);
 
        // Now test that if a peer sends us a second channel_ready after the channel is operational we
index 5b6acbcadd5bf38686f8ec43dc761eae6a9e3ba1..11824d5bc73182ebf3d77fc53dd7a53b2be59045 100644 (file)
@@ -129,7 +129,7 @@ impl LockMetadata {
                        // For each lock which is currently locked, check that no lock's locked-before
                        // set includes the lock we're about to lock, which would imply a lockorder
                        // inversion.
-                       for (locked_idx, locked) in held.borrow().iter() {
+                       for (locked_idx, _locked) in held.borrow().iter() {
                                if *locked_idx == this.lock_idx {
                                        // Note that with `feature = "backtrace"` set, we may be looking at different
                                        // instances of the same lock. Still, doing so is quite risky, a total order
@@ -143,7 +143,7 @@ impl LockMetadata {
                                        panic!("Tried to acquire a lock while it was held!");
                                }
                        }
-                       for (locked_idx, locked) in held.borrow().iter() {
+                       for (_locked_idx, locked) in held.borrow().iter() {
                                for (locked_dep_idx, _locked_dep) in locked.locked_before.lock().unwrap().iter() {
                                        if *locked_dep_idx == this.lock_idx && *locked_dep_idx != locked.lock_idx {
                                                #[cfg(feature = "backtrace")]
@@ -201,6 +201,11 @@ pub struct Mutex<T: Sized> {
        inner: StdMutex<T>,
        deps: Arc<LockMetadata>,
 }
+impl<T: Sized> Mutex<T> {
+       pub(crate) fn into_inner(self) -> LockResult<T> {
+               self.inner.into_inner().map_err(|_| ())
+       }
+}
 
 #[must_use = "if unused the Mutex will immediately unlock"]
 pub struct MutexGuard<'a, T: Sized + 'a> {
index de609d5b3d711059568daca1e9d408be80891321..23b8c23db282b20ad5c90d3fc937e367e6b6dac6 100644 (file)
@@ -45,6 +45,7 @@ impl<T> FairRwLock<T> {
                self.lock.read()
        }
 
+       #[allow(dead_code)]
        pub fn try_write(&self) -> TryLockResult<RwLockWriteGuard<'_, T>> {
                self.lock.try_write()
        }
index 858f60db5b5b46a5bb703f2d454d8d2b2c1f2934..17307997d8176cc2b82a0d9559307971f3e2c252 100644 (file)
@@ -60,6 +60,10 @@ impl<T> Mutex<T> {
        pub fn try_lock<'a>(&'a self) -> LockResult<MutexGuard<'a, T>> {
                Ok(MutexGuard { lock: self.inner.borrow_mut() })
        }
+
+       pub fn into_inner(self) -> LockResult<T> {
+               Ok(self.inner.into_inner())
+       }
 }
 
 impl<'a, T: 'a> LockTestExt<'a> for Mutex<T> {
index 6d72410bd596341fd5246760f03046884b93e3b5..96e497d44392985cd33b31f873751b0cc030f1d6 100644 (file)
@@ -3,7 +3,6 @@ use crate::sync::debug_sync::{RwLock, Mutex};
 use super::{LockHeldState, LockTestExt};
 
 use std::sync::Arc;
-use std::thread;
 
 #[test]
 #[should_panic]
index 3d45172517db3e706dcbf4884dd75923c872c8a5..2b5bbac0ddc583e188a9b36b2833cbb42bb74eea 100644 (file)
@@ -21,6 +21,8 @@ use core::ops::{Bound, RangeBounds};
 /// actually backed by a [`HashMap`], with some additional tracking to ensure we can iterate over
 /// keys in the order defined by [`Ord`].
 ///
+/// (C-not exported) as bindings provide alternate accessors rather than exposing maps directly.
+///
 /// [`BTreeMap`]: alloc::collections::BTreeMap
 #[derive(Clone, Debug, Eq)]
 pub struct IndexedMap<K: Hash + Ord, V> {
@@ -147,6 +149,8 @@ impl<K: Hash + Ord + PartialEq, V: PartialEq> PartialEq for IndexedMap<K, V> {
 }
 
 /// An iterator over a range of values in an [`IndexedMap`]
+///
+/// (C-not exported) as bindings provide alternate accessors rather than exposing maps directly.
 pub struct Range<'a, K: Hash + Ord, V> {
        inner_range: Iter<'a, K>,
        map: &'a HashMap<K, V>,
@@ -161,6 +165,8 @@ impl<'a, K: Hash + Ord, V: 'a> Iterator for Range<'a, K, V> {
 }
 
 /// An [`Entry`] for a key which currently has no value
+///
+/// (C-not exported) as bindings provide alternate accessors rather than exposing maps directly.
 pub struct VacantEntry<'a, K: Hash + Ord, V> {
        #[cfg(feature = "hashbrown")]
        underlying_entry: hash_map::VacantEntry<'a, K, V, hash_map::DefaultHashBuilder>,
@@ -171,6 +177,8 @@ pub struct VacantEntry<'a, K: Hash + Ord, V> {
 }
 
 /// An [`Entry`] for an existing key-value pair
+///
+/// (C-not exported) as bindings provide alternate accessors rather than exposing maps directly.
 pub struct OccupiedEntry<'a, K: Hash + Ord, V> {
        #[cfg(feature = "hashbrown")]
        underlying_entry: hash_map::OccupiedEntry<'a, K, V, hash_map::DefaultHashBuilder>,
@@ -181,6 +189,8 @@ pub struct OccupiedEntry<'a, K: Hash + Ord, V> {
 
 /// A mutable reference to a position in the map. This can be used to reference, add, or update the
 /// value at a fixed key.
+///
+/// (C-not exported) as bindings provide alternate accessors rather than exposing maps directly.
 pub enum Entry<'a, K: Hash + Ord, V> {
        /// A mutable reference to a position within the map where there is no value.
        Vacant(VacantEntry<'a, K, V>),
index e83e6e2ee48ea27d40327c4e56cb50e2827e94fc..6e98272f3171d2bc47082c4171b6e4013800e911 100644 (file)
@@ -167,17 +167,17 @@ macro_rules! log_given_level {
        ($logger: expr, $lvl:expr, $($arg:tt)+) => (
                match $lvl {
                        #[cfg(not(any(feature = "max_level_off")))]
-                       $crate::util::logger::Level::Error => log_internal!($logger, $lvl, $($arg)*),
+                       $crate::util::logger::Level::Error => $crate::log_internal!($logger, $lvl, $($arg)*),
                        #[cfg(not(any(feature = "max_level_off", feature = "max_level_error")))]
-                       $crate::util::logger::Level::Warn => log_internal!($logger, $lvl, $($arg)*),
+                       $crate::util::logger::Level::Warn => $crate::log_internal!($logger, $lvl, $($arg)*),
                        #[cfg(not(any(feature = "max_level_off", feature = "max_level_error", feature = "max_level_warn")))]
-                       $crate::util::logger::Level::Info => log_internal!($logger, $lvl, $($arg)*),
+                       $crate::util::logger::Level::Info => $crate::log_internal!($logger, $lvl, $($arg)*),
                        #[cfg(not(any(feature = "max_level_off", feature = "max_level_error", feature = "max_level_warn", feature = "max_level_info")))]
-                       $crate::util::logger::Level::Debug => log_internal!($logger, $lvl, $($arg)*),
+                       $crate::util::logger::Level::Debug => $crate::log_internal!($logger, $lvl, $($arg)*),
                        #[cfg(not(any(feature = "max_level_off", feature = "max_level_error", feature = "max_level_warn", feature = "max_level_info", feature = "max_level_debug")))]
-                       $crate::util::logger::Level::Trace => log_internal!($logger, $lvl, $($arg)*),
+                       $crate::util::logger::Level::Trace => $crate::log_internal!($logger, $lvl, $($arg)*),
                        #[cfg(not(any(feature = "max_level_off", feature = "max_level_error", feature = "max_level_warn", feature = "max_level_info", feature = "max_level_debug", feature = "max_level_trace")))]
-                       $crate::util::logger::Level::Gossip => log_internal!($logger, $lvl, $($arg)*),
+                       $crate::util::logger::Level::Gossip => $crate::log_internal!($logger, $lvl, $($arg)*),
 
                        #[cfg(any(feature = "max_level_off", feature = "max_level_error", feature = "max_level_warn", feature = "max_level_info", feature = "max_level_debug", feature = "max_level_trace"))]
                        _ => {
@@ -191,7 +191,7 @@ macro_rules! log_given_level {
 #[macro_export]
 macro_rules! log_error {
        ($logger: expr, $($arg:tt)*) => (
-               log_given_level!($logger, $crate::util::logger::Level::Error, $($arg)*);
+               $crate::log_given_level!($logger, $crate::util::logger::Level::Error, $($arg)*);
        )
 }
 
@@ -199,7 +199,7 @@ macro_rules! log_error {
 #[macro_export]
 macro_rules! log_warn {
        ($logger: expr, $($arg:tt)*) => (
-               log_given_level!($logger, $crate::util::logger::Level::Warn, $($arg)*);
+               $crate::log_given_level!($logger, $crate::util::logger::Level::Warn, $($arg)*);
        )
 }
 
@@ -207,7 +207,7 @@ macro_rules! log_warn {
 #[macro_export]
 macro_rules! log_info {
        ($logger: expr, $($arg:tt)*) => (
-               log_given_level!($logger, $crate::util::logger::Level::Info, $($arg)*);
+               $crate::log_given_level!($logger, $crate::util::logger::Level::Info, $($arg)*);
        )
 }
 
@@ -215,7 +215,7 @@ macro_rules! log_info {
 #[macro_export]
 macro_rules! log_debug {
        ($logger: expr, $($arg:tt)*) => (
-               log_given_level!($logger, $crate::util::logger::Level::Debug, $($arg)*);
+               $crate::log_given_level!($logger, $crate::util::logger::Level::Debug, $($arg)*);
        )
 }
 
@@ -223,7 +223,7 @@ macro_rules! log_debug {
 #[macro_export]
 macro_rules! log_trace {
        ($logger: expr, $($arg:tt)*) => (
-               log_given_level!($logger, $crate::util::logger::Level::Trace, $($arg)*)
+               $crate::log_given_level!($logger, $crate::util::logger::Level::Trace, $($arg)*)
        )
 }
 
@@ -231,6 +231,6 @@ macro_rules! log_trace {
 #[macro_export]
 macro_rules! log_gossip {
        ($logger: expr, $($arg:tt)*) => (
-               log_given_level!($logger, $crate::util::logger::Level::Gossip, $($arg)*);
+               $crate::log_given_level!($logger, $crate::util::logger::Level::Gossip, $($arg)*);
        )
 }
index 14c25775174b19dc50f4dabe5ecb434ce0ea259b..bef192585b5e486696068174b92539c0faa895ff 100644 (file)
@@ -89,6 +89,8 @@ impl Writer for VecWriter {
 
 /// Writer that only tracks the amount of data written - useful if you need to calculate the length
 /// of some data when serialized but don't yet need the full data.
+///
+/// (C-not exported) as manual TLV building is not currently supported in bindings
 pub struct LengthCalculatingWriter(pub usize);
 impl Writer for LengthCalculatingWriter {
        #[inline]
@@ -100,6 +102,8 @@ impl Writer for LengthCalculatingWriter {
 
 /// Essentially [`std::io::Take`] but a bit simpler and with a method to walk the underlying stream
 /// forward to ensure we always consume exactly the fixed length specified.
+///
+/// (C-not exported) as manual TLV building is not currently supported in bindings
 pub struct FixedLengthReader<R: Read> {
        read: R,
        bytes_read: u64,
@@ -155,6 +159,8 @@ impl<R: Read> LengthRead for FixedLengthReader<R> {
 
 /// A [`Read`] implementation which tracks whether any bytes have been read at all. This allows us to distinguish
 /// between "EOF reached before we started" and "EOF reached mid-read".
+///
+/// (C-not exported) as manual TLV building is not currently supported in bindings
 pub struct ReadTrackingReader<R: Read> {
        read: R,
        /// Returns whether we have read from this reader or not yet.
@@ -289,6 +295,8 @@ impl<T: Readable> MaybeReadable for T {
 }
 
 /// Wrapper to read a required (non-optional) TLV record.
+///
+/// (C-not exported) as manual TLV building is not currently supported in bindings
 pub struct RequiredWrapper<T>(pub Option<T>);
 impl<T: Readable> Readable for RequiredWrapper<T> {
        #[inline]
@@ -311,6 +319,8 @@ impl<T> From<T> for RequiredWrapper<T> {
 
 /// Wrapper to read a required (non-optional) TLV record that may have been upgraded without
 /// backwards compat.
+///
+/// (C-not exported) as manual TLV building is not currently supported in bindings
 pub struct UpgradableRequired<T: MaybeReadable>(pub Option<T>);
 impl<T: MaybeReadable> MaybeReadable for UpgradableRequired<T> {
        #[inline]
@@ -591,6 +601,8 @@ impl Readable for [u16; 8] {
 
 /// A type for variable-length values within TLV record where the length is encoded as part of the record.
 /// Used to prevent encoding the length twice.
+///
+/// (C-not exported) as manual TLV building is not currently supported in bindings
 pub struct WithoutLength<T>(pub T);
 
 impl Writeable for WithoutLength<&String> {
index fdbc22f116600b7162bdbdcafa4a6f314af944e1..f86fc376cee0202323b9924057157aaf8379f86a 100644 (file)
@@ -105,7 +105,10 @@ impl Notifier {
        pub(crate) fn notify(&self) {
                let mut lock = self.notify_pending.lock().unwrap();
                if let Some(future_state) = &lock.1 {
-                       future_state.lock().unwrap().complete();
+                       if future_state.lock().unwrap().complete() {
+                               lock.1 = None;
+                               return;
+                       }
                }
                lock.0 = true;
                mem::drop(lock);
@@ -161,12 +164,13 @@ pub(crate) struct FutureState {
 }
 
 impl FutureState {
-       fn complete(&mut self) {
+       fn complete(&mut self) -> bool {
                for (counts_as_call, callback) in self.callbacks.drain(..) {
                        callback.call();
                        self.callbacks_made |= counts_as_call;
                }
                self.complete = true;
+               self.callbacks_made
        }
 }
 
@@ -469,4 +473,63 @@ mod tests {
                assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
                assert!(!notifier.wait_timeout(Duration::from_millis(1)));
        }
+
+       #[test]
+       fn test_poll_post_notify_completes() {
+               // Tests that if we have a future state that has completed, and we haven't yet requested a
+               // new future, if we get a notify prior to requesting that second future it is generated
+               // pre-completed.
+               let notifier = Notifier::new();
+
+               notifier.notify();
+               let mut future = notifier.get_future();
+               let (woken, waker) = create_waker();
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
+               assert!(!woken.load(Ordering::SeqCst));
+
+               notifier.notify();
+               let mut future = notifier.get_future();
+               let (woken, waker) = create_waker();
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
+               assert!(!woken.load(Ordering::SeqCst));
+
+               let mut future = notifier.get_future();
+               let (woken, waker) = create_waker();
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Pending);
+               assert!(!woken.load(Ordering::SeqCst));
+
+               notifier.notify();
+               assert!(woken.load(Ordering::SeqCst));
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
+       }
+
+       #[test]
+       fn test_poll_post_notify_completes_initial_notified() {
+               // Identical to the previous test, but the first future completes via a wake rather than an
+               // immediate `Poll::Ready`.
+               let notifier = Notifier::new();
+
+               let mut future = notifier.get_future();
+               let (woken, waker) = create_waker();
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Pending);
+
+               notifier.notify();
+               assert!(woken.load(Ordering::SeqCst));
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
+
+               notifier.notify();
+               let mut future = notifier.get_future();
+               let (woken, waker) = create_waker();
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
+               assert!(!woken.load(Ordering::SeqCst));
+
+               let mut future = notifier.get_future();
+               let (woken, waker) = create_waker();
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Pending);
+               assert!(!woken.load(Ordering::SeqCst));
+
+               notifier.notify();
+               assert!(woken.load(Ordering::SeqCst));
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
+       }
 }
index cb22dcea9bfa7b71541404934b32d64c998d59db..16d2fc110e2a42252dbc83afb623f6a76bff570d 100644 (file)
@@ -11,3 +11,7 @@ lightning = { path = "../lightning", default-features = false }
 lightning-invoice = { path = "../lightning-invoice", default-features = false }
 lightning-rapid-gossip-sync = { path = "../lightning-rapid-gossip-sync", default-features = false }
 lightning-background-processor = { path = "../lightning-background-processor", features = ["futures"], default-features = false }
+
+# Obviously lightning-transaction-sync doesn't support no-std, but it should build
+# even if lightning is built with no-std.
+lightning-transaction-sync = { path = "../lightning-transaction-sync", optional = true }