Merge pull request #1860 from wpaulino/open-channel-anchors-support
authorMatt Corallo <649246+TheBlueMatt@users.noreply.github.com>
Thu, 19 Jan 2023 01:00:44 +0000 (01:00 +0000)
committerGitHub <noreply@github.com>
Thu, 19 Jan 2023 01:00:44 +0000 (01:00 +0000)
Support opening anchor channels and test end-to-end unilateral close

16 files changed:
fuzz/src/chanmon_consistency.rs
fuzz/src/utils/test_persister.rs
lightning-background-processor/Cargo.toml
lightning-background-processor/src/lib.rs
lightning-invoice/src/payment.rs
lightning-net-tokio/src/lib.rs
lightning/src/chain/chainmonitor.rs
lightning/src/chain/mod.rs
lightning/src/ln/chanmon_update_fail_tests.rs
lightning/src/ln/channelmanager.rs
lightning/src/ln/functional_tests.rs
lightning/src/util/persist.rs
lightning/src/util/ser.rs
lightning/src/util/ser_macros.rs
lightning/src/util/test_utils.rs
no-std-check/Cargo.toml

index 67e7ab1b781bc6d03de28f1ff98f2fb3d0157341..3c174f3093f2e20da4ec78f71558367af5b715f1 100644 (file)
@@ -152,7 +152,7 @@ impl chain::Watch<EnforcingSigner> for TestChainMonitor {
                self.chain_monitor.watch_channel(funding_txo, monitor)
        }
 
-       fn update_channel(&self, funding_txo: OutPoint, update: channelmonitor::ChannelMonitorUpdate) -> chain::ChannelMonitorUpdateStatus {
+       fn update_channel(&self, funding_txo: OutPoint, update: &channelmonitor::ChannelMonitorUpdate) -> chain::ChannelMonitorUpdateStatus {
                let mut map_lock = self.latest_monitors.lock().unwrap();
                let mut map_entry = match map_lock.entry(funding_txo) {
                        hash_map::Entry::Occupied(entry) => entry,
@@ -160,7 +160,7 @@ impl chain::Watch<EnforcingSigner> for TestChainMonitor {
                };
                let deserialized_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingSigner>)>::
                        read(&mut Cursor::new(&map_entry.get().1), (&*self.keys, &*self.keys)).unwrap().1;
-               deserialized_monitor.update_monitor(&update, &&TestBroadcaster{}, &FuzzEstimator { ret_val: atomic::AtomicU32::new(253) }, &self.logger).unwrap();
+               deserialized_monitor.update_monitor(update, &&TestBroadcaster{}, &FuzzEstimator { ret_val: atomic::AtomicU32::new(253) }, &self.logger).unwrap();
                let mut ser = VecWriter(Vec::new());
                deserialized_monitor.write(&mut ser).unwrap();
                map_entry.insert((update.update_id, ser.0));
index 44675fa787294b4a63144c9a2cddfacc4b9b33a6..e3635297adb037500adafebd98708025c99fa762 100644 (file)
@@ -14,7 +14,7 @@ impl chainmonitor::Persist<EnforcingSigner> for TestPersister {
                self.update_ret.lock().unwrap().clone()
        }
 
-       fn update_persisted_channel(&self, _funding_txo: OutPoint, _update: &Option<channelmonitor::ChannelMonitorUpdate>, _data: &channelmonitor::ChannelMonitor<EnforcingSigner>, _update_id: MonitorUpdateId) -> chain::ChannelMonitorUpdateStatus {
+       fn update_persisted_channel(&self, _funding_txo: OutPoint, _update: Option<&channelmonitor::ChannelMonitorUpdate>, _data: &channelmonitor::ChannelMonitor<EnforcingSigner>, _update_id: MonitorUpdateId) -> chain::ChannelMonitorUpdateStatus {
                self.update_ret.lock().unwrap().clone()
        }
 }
index 343408ea99cdee000e65ed6b7d37d9ea3fa9abe8..96ac33a917e344a7027fa270babebec1774d8639 100644 (file)
@@ -15,11 +15,14 @@ rustdoc-args = ["--cfg", "docsrs"]
 
 [features]
 futures = [ "futures-util" ]
+std = ["lightning/std", "lightning-rapid-gossip-sync/std"]
+
+default = ["std"]
 
 [dependencies]
-bitcoin = "0.29.0"
-lightning = { version = "0.0.113", path = "../lightning", features = ["std"] }
-lightning-rapid-gossip-sync = { version = "0.0.113", path = "../lightning-rapid-gossip-sync" }
+bitcoin = { version = "0.29.0", default-features = false }
+lightning = { version = "0.0.113", path = "../lightning", default-features = false }
+lightning-rapid-gossip-sync = { version = "0.0.113", path = "../lightning-rapid-gossip-sync", default-features = false }
 futures-util = { version = "0.3", default-features = false, features = ["async-await-macro"], optional = true }
 
 [dev-dependencies]
index fe4496afa2681c06ed440af85e5aae1804154ed4..4759c272dc758e2a4bbdaf097fcdc895a2fa01f2 100644 (file)
 
 #![cfg_attr(docsrs, feature(doc_auto_cfg))]
 
+#![cfg_attr(all(not(feature = "std"), not(test)), no_std)]
+
+#[cfg(any(test, feature = "std"))]
+extern crate core;
+
 #[macro_use] extern crate lightning;
 extern crate lightning_rapid_gossip_sync;
 
@@ -28,15 +33,22 @@ use lightning::util::events::{Event, EventHandler, EventsProvider};
 use lightning::util::logger::Logger;
 use lightning::util::persist::Persister;
 use lightning_rapid_gossip_sync::RapidGossipSync;
+use lightning::io;
+
+use core::ops::Deref;
+use core::time::Duration;
+
+#[cfg(feature = "std")]
 use std::sync::Arc;
-use std::sync::atomic::{AtomicBool, Ordering};
-use std::thread;
-use std::thread::JoinHandle;
-use std::time::{Duration, Instant};
-use std::ops::Deref;
+#[cfg(feature = "std")]
+use core::sync::atomic::{AtomicBool, Ordering};
+#[cfg(feature = "std")]
+use std::thread::{self, JoinHandle};
+#[cfg(feature = "std")]
+use std::time::Instant;
 
 #[cfg(feature = "futures")]
-use futures_util::{select_biased, future::FutureExt};
+use futures_util::{select_biased, future::FutureExt, task};
 
 /// `BackgroundProcessor` takes care of tasks that (1) need to happen periodically to keep
 /// Rust-Lightning running properly, and (2) either can or should be run in the background. Its
@@ -62,6 +74,7 @@ use futures_util::{select_biased, future::FutureExt};
 ///
 /// [`ChannelMonitor`]: lightning::chain::channelmonitor::ChannelMonitor
 /// [`Event`]: lightning::util::events::Event
+#[cfg(feature = "std")]
 #[must_use = "BackgroundProcessor will immediately stop on drop. It should be stored until shutdown."]
 pub struct BackgroundProcessor {
        stop_thread: Arc<AtomicBool>,
@@ -207,15 +220,15 @@ macro_rules! define_run_body {
        ($persister: ident, $chain_monitor: ident, $process_chain_monitor_events: expr,
         $channel_manager: ident, $process_channel_manager_events: expr,
         $gossip_sync: ident, $peer_manager: ident, $logger: ident, $scorer: ident,
-        $loop_exit_check: expr, $await: expr)
+        $loop_exit_check: expr, $await: expr, $get_timer: expr, $timer_elapsed: expr)
        => { {
                log_trace!($logger, "Calling ChannelManager's timer_tick_occurred on startup");
                $channel_manager.timer_tick_occurred();
 
-               let mut last_freshness_call = Instant::now();
-               let mut last_ping_call = Instant::now();
-               let mut last_prune_call = Instant::now();
-               let mut last_scorer_persist_call = Instant::now();
+               let mut last_freshness_call = $get_timer(FRESHNESS_TIMER);
+               let mut last_ping_call = $get_timer(PING_TIMER);
+               let mut last_prune_call = $get_timer(FIRST_NETWORK_PRUNE_TIMER);
+               let mut last_scorer_persist_call = $get_timer(SCORER_PERSIST_TIMER);
                let mut have_pruned = false;
 
                loop {
@@ -237,9 +250,9 @@ macro_rules! define_run_body {
 
                        // We wait up to 100ms, but track how long it takes to detect being put to sleep,
                        // see `await_start`'s use below.
-                       let await_start = Instant::now();
+                       let mut await_start = $get_timer(1);
                        let updates_available = $await;
-                       let await_time = await_start.elapsed();
+                       let await_slow = $timer_elapsed(&mut await_start, 1);
 
                        if updates_available {
                                log_trace!($logger, "Persisting ChannelManager...");
@@ -251,12 +264,12 @@ macro_rules! define_run_body {
                                log_trace!($logger, "Terminating background processor.");
                                break;
                        }
-                       if last_freshness_call.elapsed().as_secs() > FRESHNESS_TIMER {
+                       if $timer_elapsed(&mut last_freshness_call, FRESHNESS_TIMER) {
                                log_trace!($logger, "Calling ChannelManager's timer_tick_occurred");
                                $channel_manager.timer_tick_occurred();
-                               last_freshness_call = Instant::now();
+                               last_freshness_call = $get_timer(FRESHNESS_TIMER);
                        }
-                       if await_time > Duration::from_secs(1) {
+                       if await_slow {
                                // On various platforms, we may be starved of CPU cycles for several reasons.
                                // E.g. on iOS, if we've been in the background, we will be entirely paused.
                                // Similarly, if we're on a desktop platform and the device has been asleep, we
@@ -271,40 +284,46 @@ macro_rules! define_run_body {
                                // peers.
                                log_trace!($logger, "100ms sleep took more than a second, disconnecting peers.");
                                $peer_manager.disconnect_all_peers();
-                               last_ping_call = Instant::now();
-                       } else if last_ping_call.elapsed().as_secs() > PING_TIMER {
+                               last_ping_call = $get_timer(PING_TIMER);
+                       } else if $timer_elapsed(&mut last_ping_call, PING_TIMER) {
                                log_trace!($logger, "Calling PeerManager's timer_tick_occurred");
                                $peer_manager.timer_tick_occurred();
-                               last_ping_call = Instant::now();
+                               last_ping_call = $get_timer(PING_TIMER);
                        }
 
                        // Note that we want to run a graph prune once not long after startup before
                        // falling back to our usual hourly prunes. This avoids short-lived clients never
                        // pruning their network graph. We run once 60 seconds after startup before
                        // continuing our normal cadence.
-                       if last_prune_call.elapsed().as_secs() > if have_pruned { NETWORK_PRUNE_TIMER } else { FIRST_NETWORK_PRUNE_TIMER } {
+                       if $timer_elapsed(&mut last_prune_call, if have_pruned { NETWORK_PRUNE_TIMER } else { FIRST_NETWORK_PRUNE_TIMER }) {
                                // The network graph must not be pruned while rapid sync completion is pending
                                if let Some(network_graph) = $gossip_sync.prunable_network_graph() {
-                                       log_trace!($logger, "Pruning and persisting network graph.");
-                                       network_graph.remove_stale_channels_and_tracking();
+                                       #[cfg(feature = "std")] {
+                                               log_trace!($logger, "Pruning and persisting network graph.");
+                                               network_graph.remove_stale_channels_and_tracking();
+                                       }
+                                       #[cfg(not(feature = "std"))] {
+                                               log_warn!($logger, "Not pruning network graph, consider enabling `std` or doing so manually with remove_stale_channels_and_tracking_with_time.");
+                                               log_trace!($logger, "Persisting network graph.");
+                                       }
 
                                        if let Err(e) = $persister.persist_graph(network_graph) {
                                                log_error!($logger, "Error: Failed to persist network graph, check your disk and permissions {}", e)
                                        }
 
-                                       last_prune_call = Instant::now();
+                                       last_prune_call = $get_timer(NETWORK_PRUNE_TIMER);
                                        have_pruned = true;
                                }
                        }
 
-                       if last_scorer_persist_call.elapsed().as_secs() > SCORER_PERSIST_TIMER {
+                       if $timer_elapsed(&mut last_scorer_persist_call, SCORER_PERSIST_TIMER) {
                                if let Some(ref scorer) = $scorer {
                                        log_trace!($logger, "Persisting scorer");
                                        if let Err(e) = $persister.persist_scorer(&scorer) {
                                                log_error!($logger, "Error: Failed to persist scorer, check your disk and permissions {}", e)
                                        }
                                }
-                               last_scorer_persist_call = Instant::now();
+                               last_scorer_persist_call = $get_timer(SCORER_PERSIST_TIMER);
                        }
                }
 
@@ -334,6 +353,11 @@ macro_rules! define_run_body {
 /// future which outputs true, the loop will exit and this function's future will complete.
 ///
 /// See [`BackgroundProcessor::start`] for information on which actions this handles.
+///
+/// Requires the `futures` feature. Note that while this method is available without the `std`
+/// feature, doing so will skip calling [`NetworkGraph::remove_stale_channels_and_tracking`],
+/// you should call [`NetworkGraph::remove_stale_channels_and_tracking_with_time`] regularly
+/// manually instead.
 #[cfg(feature = "futures")]
 pub async fn process_events_async<
        'a,
@@ -364,13 +388,13 @@ pub async fn process_events_async<
        PM: 'static + Deref<Target = PeerManager<Descriptor, CMH, RMH, OMH, L, UMH>> + Send + Sync,
        S: 'static + Deref<Target = SC> + Send + Sync,
        SC: WriteableScore<'a>,
-       SleepFuture: core::future::Future<Output = bool>,
+       SleepFuture: core::future::Future<Output = bool> + core::marker::Unpin,
        Sleeper: Fn(Duration) -> SleepFuture
 >(
        persister: PS, event_handler: EventHandler, chain_monitor: M, channel_manager: CM,
        gossip_sync: GossipSync<PGS, RGS, G, CA, L>, peer_manager: PM, logger: L, scorer: Option<S>,
        sleeper: Sleeper,
-) -> Result<(), std::io::Error>
+) -> Result<(), io::Error>
 where
        CA::Target: 'static + chain::Access,
        CF::Target: 'static + chain::Filter,
@@ -411,9 +435,15 @@ where
                                        false
                                }
                        }
+               }, |t| sleeper(Duration::from_secs(t)),
+               |fut: &mut SleepFuture, _| {
+                       let mut waker = task::noop_waker();
+                       let mut ctx = task::Context::from_waker(&mut waker);
+                       core::pin::Pin::new(fut).poll(&mut ctx).is_ready()
                })
 }
 
+#[cfg(feature = "std")]
 impl BackgroundProcessor {
        /// Start a background thread that takes care of responsibilities enumerated in the [top-level
        /// documentation].
@@ -522,7 +552,8 @@ impl BackgroundProcessor {
                        define_run_body!(persister, chain_monitor, chain_monitor.process_pending_events(&event_handler),
                                channel_manager, channel_manager.process_pending_events(&event_handler),
                                gossip_sync, peer_manager, logger, scorer, stop_thread.load(Ordering::Acquire),
-                               channel_manager.await_persistable_update_timeout(Duration::from_millis(100)))
+                               channel_manager.await_persistable_update_timeout(Duration::from_millis(100)),
+                               |_| Instant::now(), |time: &Instant, dur| time.elapsed().as_secs() > dur)
                });
                Self { stop_thread: stop_thread_clone, thread_handle: Some(handle) }
        }
@@ -568,13 +599,14 @@ impl BackgroundProcessor {
        }
 }
 
+#[cfg(feature = "std")]
 impl Drop for BackgroundProcessor {
        fn drop(&mut self) {
                self.stop_and_join_thread().unwrap();
        }
 }
 
-#[cfg(test)]
+#[cfg(all(feature = "std", test))]
 mod tests {
        use bitcoin::blockdata::block::BlockHeader;
        use bitcoin::blockdata::constants::genesis_block;
@@ -1133,7 +1165,7 @@ mod tests {
                // Initiate the background processors to watch each node.
                let data_dir = nodes[0].persister.get_data_dir();
                let persister = Arc::new(Persister::new(data_dir));
-               let router = DefaultRouter::new(Arc::clone(&nodes[0].network_graph), Arc::clone(&nodes[0].logger), random_seed_bytes, Arc::clone(&nodes[0].scorer));
+               let router = Arc::new(DefaultRouter::new(Arc::clone(&nodes[0].network_graph), Arc::clone(&nodes[0].logger), random_seed_bytes, Arc::clone(&nodes[0].scorer)));
                let invoice_payer = Arc::new(InvoicePayer::new(Arc::clone(&nodes[0].node), router, Arc::clone(&nodes[0].logger), |_: _| {}, Retry::Attempts(2)));
                let event_handler = Arc::clone(&invoice_payer);
                let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()));
index 0958c79649521b4eec769f090705b34792de9f4e..82f1149db1fbf5771d043b6582a33e510f8ab5f6 100644 (file)
 //! # let router = FakeRouter {};
 //! # let scorer = RefCell::new(FakeScorer {});
 //! # let logger = FakeLogger {};
-//! let invoice_payer = InvoicePayer::new(&payer, router, &logger, event_handler, Retry::Attempts(2));
+//! let invoice_payer = InvoicePayer::new(&payer, &router, &logger, event_handler, Retry::Attempts(2));
 //!
 //! let invoice = "...";
 //! if let Ok(invoice) = invoice.parse::<Invoice>() {
@@ -184,12 +184,13 @@ mod sealed {
 /// (C-not exported) generally all users should use the [`InvoicePayer`] type alias.
 pub struct InvoicePayerUsingTime<
        P: Deref,
-       R: Router,
+       R: Deref,
        L: Deref,
        E: sealed::BaseEventHandler,
        T: Time
 > where
        P::Target: Payer,
+       R::Target: Router,
        L::Target: Logger,
 {
        payer: P,
@@ -316,10 +317,11 @@ pub enum PaymentError {
        Sending(PaymentSendFailure),
 }
 
-impl<P: Deref, R: Router, L: Deref, E: sealed::BaseEventHandler, T: Time>
+impl<P: Deref, R: Deref, L: Deref, E: sealed::BaseEventHandler, T: Time>
        InvoicePayerUsingTime<P, R, L, E, T>
 where
        P::Target: Payer,
+       R::Target: Router,
        L::Target: Logger,
 {
        /// Creates an invoice payer that retries failed payment paths.
@@ -630,10 +632,11 @@ fn has_expired(route_params: &RouteParameters) -> bool {
        } else { false }
 }
 
-impl<P: Deref, R: Router, L: Deref, E: sealed::BaseEventHandler, T: Time>
+impl<P: Deref, R: Deref, L: Deref, E: sealed::BaseEventHandler, T: Time>
        InvoicePayerUsingTime<P, R, L, E, T>
 where
        P::Target: Payer,
+       R::Target: Router,
        L::Target: Logger,
 {
        /// Returns a bool indicating whether the processed event should be forwarded to a user-provided
@@ -697,10 +700,11 @@ where
        }
 }
 
-impl<P: Deref, R: Router, L: Deref, E: EventHandler, T: Time>
+impl<P: Deref, R: Deref, L: Deref, E: EventHandler, T: Time>
        EventHandler for InvoicePayerUsingTime<P, R, L, E, T>
 where
        P::Target: Payer,
+       R::Target: Router,
        L::Target: Logger,
 {
        fn handle_event(&self, event: Event) {
@@ -711,10 +715,11 @@ where
        }
 }
 
-impl<P: Deref, R: Router, L: Deref, T: Time, F: Future, H: Fn(Event) -> F>
+impl<P: Deref, R: Deref, L: Deref, T: Time, F: Future, H: Fn(Event) -> F>
        InvoicePayerUsingTime<P, R, L, H, T>
 where
        P::Target: Payer,
+       R::Target: Router,
        L::Target: Logger,
 {
        /// Intercepts events required by the [`InvoicePayer`] and forwards them to the underlying event
@@ -746,7 +751,6 @@ mod tests {
        use secp256k1::{SecretKey, PublicKey, Secp256k1};
        use std::cell::RefCell;
        use std::collections::VecDeque;
-       use std::ops::DerefMut;
        use std::time::{SystemTime, Duration};
        use crate::time_utils::tests::SinceEpoch;
        use crate::DEFAULT_EXPIRY_TIME;
@@ -831,7 +835,7 @@ mod tests {
                let router = TestRouter::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, Retry::Attempts(0));
+                       InvoicePayer::new(&payer, &router, &logger, event_handler, Retry::Attempts(0));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -859,7 +863,7 @@ mod tests {
                let router = TestRouter::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, Retry::Attempts(2));
+                       InvoicePayer::new(&payer, &router, &logger, event_handler, Retry::Attempts(2));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -903,7 +907,7 @@ mod tests {
                let router = TestRouter::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, Retry::Attempts(2));
+                       InvoicePayer::new(&payer, &router, &logger, event_handler, Retry::Attempts(2));
 
                assert!(invoice_payer.pay_invoice(&invoice).is_ok());
        }
@@ -924,7 +928,7 @@ mod tests {
                let router = TestRouter::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, Retry::Attempts(2));
+                       InvoicePayer::new(&payer, &router, &logger, event_handler, Retry::Attempts(2));
 
                let payment_id = Some(PaymentId([1; 32]));
                let event = Event::PaymentPathFailed {
@@ -968,7 +972,7 @@ mod tests {
                let router = TestRouter::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, Retry::Attempts(2));
+                       InvoicePayer::new(&payer, &router, &logger, event_handler, Retry::Attempts(2));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -1027,7 +1031,7 @@ mod tests {
                type InvoicePayerUsingSinceEpoch <P, R, L, E> = InvoicePayerUsingTime::<P, R, L, E, SinceEpoch>;
 
                let invoice_payer =
-                       InvoicePayerUsingSinceEpoch::new(&payer, router, &logger, event_handler, Retry::Timeout(Duration::from_secs(120)));
+                       InvoicePayerUsingSinceEpoch::new(&payer, &router, &logger, event_handler, Retry::Timeout(Duration::from_secs(120)));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -1066,7 +1070,7 @@ mod tests {
                let router = TestRouter::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, Retry::Attempts(2));
+                       InvoicePayer::new(&payer, &router, &logger, event_handler, Retry::Attempts(2));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -1097,7 +1101,7 @@ mod tests {
                let router = TestRouter::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, Retry::Attempts(2));
+                       InvoicePayer::new(&payer, &router, &logger, event_handler, Retry::Attempts(2));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = expired_invoice(payment_preimage);
@@ -1121,7 +1125,7 @@ mod tests {
                let router = TestRouter::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router,  &logger, event_handler, Retry::Attempts(2));
+                       InvoicePayer::new(&payer, &router,  &logger, event_handler, Retry::Attempts(2));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -1161,7 +1165,7 @@ mod tests {
                let router = TestRouter::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, Retry::Attempts(2));
+                       InvoicePayer::new(&payer, &router, &logger, event_handler, Retry::Attempts(2));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -1194,7 +1198,7 @@ mod tests {
                let router = TestRouter::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, Retry::Attempts(2));
+                       InvoicePayer::new(&payer, &router, &logger, event_handler, Retry::Attempts(2));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -1229,7 +1233,7 @@ mod tests {
                let router = TestRouter::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, Retry::Attempts(0));
+                       InvoicePayer::new(&payer, &router, &logger, event_handler, Retry::Attempts(0));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
 
@@ -1267,7 +1271,7 @@ mod tests {
                let router = FailingRouter {};
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, |_: Event| {}, Retry::Attempts(0));
+                       InvoicePayer::new(&payer, &router, &logger, |_: Event| {}, Retry::Attempts(0));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = invoice(payment_preimage);
@@ -1290,7 +1294,7 @@ mod tests {
                let router = TestRouter::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, |_: Event| {}, Retry::Attempts(0));
+                       InvoicePayer::new(&payer, &router, &logger, |_: Event| {}, Retry::Attempts(0));
 
                match invoice_payer.pay_invoice(&invoice) {
                        Err(PaymentError::Sending(_)) => {},
@@ -1313,7 +1317,7 @@ mod tests {
                let router = TestRouter::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, Retry::Attempts(0));
+                       InvoicePayer::new(&payer, &router, &logger, event_handler, Retry::Attempts(0));
 
                let payment_id =
                        Some(invoice_payer.pay_zero_value_invoice(&invoice, final_value_msat).unwrap());
@@ -1335,7 +1339,7 @@ mod tests {
                let router = TestRouter::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router,  &logger, event_handler, Retry::Attempts(0));
+                       InvoicePayer::new(&payer, &router,  &logger, event_handler, Retry::Attempts(0));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = invoice(payment_preimage);
@@ -1365,7 +1369,7 @@ mod tests {
                let router = TestRouter::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, Retry::Attempts(2));
+                       InvoicePayer::new(&payer, &router, &logger, event_handler, Retry::Attempts(2));
 
                let payment_id = Some(invoice_payer.pay_pubkey(
                                pubkey, payment_preimage, final_value_msat, final_cltv_expiry_delta
@@ -1420,7 +1424,7 @@ mod tests {
                let router = TestRouter::new(scorer);
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, Retry::Attempts(2));
+                       InvoicePayer::new(&payer, &router, &logger, event_handler, Retry::Attempts(2));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                let event = Event::PaymentPathFailed {
@@ -1455,7 +1459,7 @@ mod tests {
                let router = TestRouter::new(scorer);
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, Retry::Attempts(2));
+                       InvoicePayer::new(&payer, &router, &logger, event_handler, Retry::Attempts(2));
 
                let payment_id = invoice_payer.pay_invoice(&invoice).unwrap();
                let event = Event::PaymentPathSuccessful {
@@ -1498,7 +1502,7 @@ mod tests {
                let router = TestRouter::new(scorer);
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, Retry::Attempts(0));
+                       InvoicePayer::new(&payer, &router, &logger, event_handler, Retry::Attempts(0));
 
                // Make first invoice payment.
                invoice_payer.pay_invoice(&payment_invoice).unwrap();
@@ -1552,7 +1556,7 @@ mod tests {
                let router = TestRouter::new(scorer);
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, Retry::Attempts(2));
+                       InvoicePayer::new(&payer, &router, &logger, event_handler, Retry::Attempts(2));
 
                // Fail 1st path, leave 2nd path inflight
                let payment_id = Some(invoice_payer.pay_invoice(&payment_invoice).unwrap());
@@ -2077,7 +2081,7 @@ mod tests {
                router.expect_find_route(Ok(route.clone()));
 
                let event_handler = |_: Event| { panic!(); };
-               let invoice_payer = InvoicePayer::new(nodes[0].node, router, nodes[0].logger, event_handler, Retry::Attempts(1));
+               let invoice_payer = InvoicePayer::new(nodes[0].node, &router, nodes[0].logger, event_handler, Retry::Attempts(1));
 
                assert!(invoice_payer.pay_invoice(&create_invoice_from_channelmanager_and_duration_since_epoch(
                        &nodes[1].node, nodes[1].keys_manager, nodes[1].logger, Currency::Bitcoin,
@@ -2122,7 +2126,7 @@ mod tests {
                router.expect_find_route(Ok(route.clone()));
 
                let event_handler = |_: Event| { panic!(); };
-               let invoice_payer = InvoicePayer::new(nodes[0].node, router, nodes[0].logger, event_handler, Retry::Attempts(1));
+               let invoice_payer = InvoicePayer::new(nodes[0].node, &router, nodes[0].logger, event_handler, Retry::Attempts(1));
 
                assert!(invoice_payer.pay_invoice(&create_invoice_from_channelmanager_and_duration_since_epoch(
                        &nodes[1].node, nodes[1].keys_manager, nodes[1].logger, Currency::Bitcoin,
@@ -2203,7 +2207,7 @@ mod tests {
                        let event_checker = expected_events.borrow_mut().pop_front().unwrap();
                        event_checker(event);
                };
-               let invoice_payer = InvoicePayer::new(nodes[0].node, router, nodes[0].logger, event_handler, Retry::Attempts(1));
+               let invoice_payer = InvoicePayer::new(nodes[0].node, &router, nodes[0].logger, event_handler, Retry::Attempts(1));
 
                assert!(invoice_payer.pay_invoice(&create_invoice_from_channelmanager_and_duration_since_epoch(
                        &nodes[1].node, nodes[1].keys_manager, nodes[1].logger, Currency::Bitcoin,
index bd6f13db85cc56e0ae5332ed810e4f3e30feb03c..45fe9b12f61b17a7b078805de1e6344efa91fbf3 100644 (file)
@@ -123,7 +123,11 @@ struct Connection {
        id: u64,
 }
 impl Connection {
-       async fn poll_event_process<CMH, RMH, OMH, L, UMH>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, CMH, RMH, OMH, L, UMH>>, mut event_receiver: mpsc::Receiver<()>) where
+       async fn poll_event_process<PM, CMH, RMH, OMH, L, UMH>(
+               peer_manager: PM,
+               mut event_receiver: mpsc::Receiver<()>,
+       ) where
+                       PM: Deref<Target = peer_handler::PeerManager<SocketDescriptor, CMH, RMH, OMH, L, UMH>> + 'static + Send + Sync,
                        CMH: Deref + 'static + Send + Sync,
                        RMH: Deref + 'static + Send + Sync,
                        OMH: Deref + 'static + Send + Sync,
@@ -134,7 +138,7 @@ impl Connection {
                        OMH::Target: OnionMessageHandler + Send + Sync,
                        L::Target: Logger + Send + Sync,
                        UMH::Target: CustomMessageHandler + Send + Sync,
-    {
+       {
                loop {
                        if event_receiver.recv().await.is_none() {
                                return;
@@ -143,7 +147,14 @@ impl Connection {
                }
        }
 
-       async fn schedule_read<CMH, RMH, OMH, L, UMH>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, CMH, RMH, OMH, L, UMH>>, us: Arc<Mutex<Self>>, mut reader: io::ReadHalf<TcpStream>, mut read_wake_receiver: mpsc::Receiver<()>, mut write_avail_receiver: mpsc::Receiver<()>) where
+       async fn schedule_read<PM, CMH, RMH, OMH, L, UMH>(
+               peer_manager: PM,
+               us: Arc<Mutex<Self>>,
+               mut reader: io::ReadHalf<TcpStream>,
+               mut read_wake_receiver: mpsc::Receiver<()>,
+               mut write_avail_receiver: mpsc::Receiver<()>,
+       ) where
+                       PM: Deref<Target = peer_handler::PeerManager<SocketDescriptor, CMH, RMH, OMH, L, UMH>> + 'static + Send + Sync + Clone,
                        CMH: Deref + 'static + Send + Sync,
                        RMH: Deref + 'static + Send + Sync,
                        OMH: Deref + 'static + Send + Sync,
@@ -154,10 +165,10 @@ impl Connection {
                        OMH::Target: OnionMessageHandler + 'static + Send + Sync,
                        L::Target: Logger + 'static + Send + Sync,
                        UMH::Target: CustomMessageHandler + 'static + Send + Sync,
-        {
+               {
                // Create a waker to wake up poll_event_process, above
                let (event_waker, event_receiver) = mpsc::channel(1);
-               tokio::spawn(Self::poll_event_process(Arc::clone(&peer_manager), event_receiver));
+               tokio::spawn(Self::poll_event_process(peer_manager.clone(), event_receiver));
 
                // 8KB is nice and big but also should never cause any issues with stack overflowing.
                let mut buf = [0; 8192];
@@ -272,7 +283,11 @@ fn get_addr_from_stream(stream: &StdTcpStream) -> Option<NetAddress> {
 /// The returned future will complete when the peer is disconnected and associated handling
 /// futures are freed, though, because all processing futures are spawned with tokio::spawn, you do
 /// not need to poll the provided future in order to make progress.
-pub fn setup_inbound<CMH, RMH, OMH, L, UMH>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, CMH, RMH, OMH, L, UMH>>, stream: StdTcpStream) -> impl std::future::Future<Output=()> where
+pub fn setup_inbound<PM, CMH, RMH, OMH, L, UMH>(
+       peer_manager: PM,
+       stream: StdTcpStream,
+) -> impl std::future::Future<Output=()> where
+               PM: Deref<Target = peer_handler::PeerManager<SocketDescriptor, CMH, RMH, OMH, L, UMH>> + 'static + Send + Sync + Clone,
                CMH: Deref + 'static + Send + Sync,
                RMH: Deref + 'static + Send + Sync,
                OMH: Deref + 'static + Send + Sync,
@@ -321,7 +336,12 @@ pub fn setup_inbound<CMH, RMH, OMH, L, UMH>(peer_manager: Arc<peer_handler::Peer
 /// The returned future will complete when the peer is disconnected and associated handling
 /// futures are freed, though, because all processing futures are spawned with tokio::spawn, you do
 /// not need to poll the provided future in order to make progress.
-pub fn setup_outbound<CMH, RMH, OMH, L, UMH>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, CMH, RMH, OMH, L, UMH>>, their_node_id: PublicKey, stream: StdTcpStream) -> impl std::future::Future<Output=()> where
+pub fn setup_outbound<PM, CMH, RMH, OMH, L, UMH>(
+       peer_manager: PM,
+       their_node_id: PublicKey,
+       stream: StdTcpStream,
+) -> impl std::future::Future<Output=()> where
+               PM: Deref<Target = peer_handler::PeerManager<SocketDescriptor, CMH, RMH, OMH, L, UMH>> + 'static + Send + Sync + Clone,
                CMH: Deref + 'static + Send + Sync,
                RMH: Deref + 'static + Send + Sync,
                OMH: Deref + 'static + Send + Sync,
@@ -399,7 +419,12 @@ pub fn setup_outbound<CMH, RMH, OMH, L, UMH>(peer_manager: Arc<peer_handler::Pee
 /// disconnected and associated handling futures are freed, though, because all processing in said
 /// futures are spawned with tokio::spawn, you do not need to poll the second future in order to
 /// make progress.
-pub async fn connect_outbound<CMH, RMH, OMH, L, UMH>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, CMH, RMH, OMH, L, UMH>>, their_node_id: PublicKey, addr: SocketAddr) -> Option<impl std::future::Future<Output=()>> where
+pub async fn connect_outbound<PM, CMH, RMH, OMH, L, UMH>(
+       peer_manager: PM,
+       their_node_id: PublicKey,
+       addr: SocketAddr,
+) -> Option<impl std::future::Future<Output=()>> where
+               PM: Deref<Target = peer_handler::PeerManager<SocketDescriptor, CMH, RMH, OMH, L, UMH>> + 'static + Send + Sync + Clone,
                CMH: Deref + 'static + Send + Sync,
                RMH: Deref + 'static + Send + Sync,
                OMH: Deref + 'static + Send + Sync,
index 8e97a62630deea7080d38fdc96b97d608225c679..430f6bbac1d2244b7d49c37d10a51153484a5c28 100644 (file)
@@ -144,7 +144,7 @@ pub trait Persist<ChannelSigner: Sign> {
        /// [`ChannelMonitorUpdateStatus`] for requirements when returning errors.
        ///
        /// [`Writeable::write`]: crate::util::ser::Writeable::write
-       fn update_persisted_channel(&self, channel_id: OutPoint, update: &Option<ChannelMonitorUpdate>, data: &ChannelMonitor<ChannelSigner>, update_id: MonitorUpdateId) -> ChannelMonitorUpdateStatus;
+       fn update_persisted_channel(&self, channel_id: OutPoint, update: Option<&ChannelMonitorUpdate>, data: &ChannelMonitor<ChannelSigner>, update_id: MonitorUpdateId) -> ChannelMonitorUpdateStatus;
 }
 
 struct MonitorHolder<ChannelSigner: Sign> {
@@ -294,7 +294,7 @@ where C::Target: chain::Filter,
                                }
 
                                log_trace!(self.logger, "Syncing Channel Monitor for channel {}", log_funding_info!(monitor));
-                               match self.persister.update_persisted_channel(*funding_outpoint, &None, monitor, update_id) {
+                               match self.persister.update_persisted_channel(*funding_outpoint, None, monitor, update_id) {
                                        ChannelMonitorUpdateStatus::Completed =>
                                                log_trace!(self.logger, "Finished syncing Channel Monitor for channel {}", log_funding_info!(monitor)),
                                        ChannelMonitorUpdateStatus::PermanentFailure => {
@@ -646,7 +646,7 @@ where C::Target: chain::Filter,
 
        /// Note that we persist the given `ChannelMonitor` update while holding the
        /// `ChainMonitor` monitors lock.
-       fn update_channel(&self, funding_txo: OutPoint, update: ChannelMonitorUpdate) -> ChannelMonitorUpdateStatus {
+       fn update_channel(&self, funding_txo: OutPoint, update: &ChannelMonitorUpdate) -> ChannelMonitorUpdateStatus {
                // Update the monitor that watches the channel referred to by the given outpoint.
                let monitors = self.monitors.read().unwrap();
                match monitors.get(&funding_txo) {
@@ -664,15 +664,15 @@ where C::Target: chain::Filter,
                        Some(monitor_state) => {
                                let monitor = &monitor_state.monitor;
                                log_trace!(self.logger, "Updating ChannelMonitor for channel {}", log_funding_info!(monitor));
-                               let update_res = monitor.update_monitor(&update, &self.broadcaster, &*self.fee_estimator, &self.logger);
+                               let update_res = monitor.update_monitor(update, &self.broadcaster, &*self.fee_estimator, &self.logger);
                                if update_res.is_err() {
                                        log_error!(self.logger, "Failed to update ChannelMonitor for channel {}.", log_funding_info!(monitor));
                                }
                                // Even if updating the monitor returns an error, the monitor's state will
                                // still be changed. So, persist the updated monitor despite the error.
-                               let update_id = MonitorUpdateId::from_monitor_update(&update);
+                               let update_id = MonitorUpdateId::from_monitor_update(update);
                                let mut pending_monitor_updates = monitor_state.pending_monitor_updates.lock().unwrap();
-                               let persist_res = self.persister.update_persisted_channel(funding_txo, &Some(update), monitor, update_id);
+                               let persist_res = self.persister.update_persisted_channel(funding_txo, Some(update), monitor, update_id);
                                match persist_res {
                                        ChannelMonitorUpdateStatus::InProgress => {
                                                pending_monitor_updates.push(update_id);
index 43add75950d6a50c3c07c5a07e525fd7154bf216..19218ed23274998b9702f74f9cc73e26ddde28a8 100644 (file)
@@ -312,7 +312,7 @@ pub trait Watch<ChannelSigner: Sign> {
        /// [`ChannelMonitorUpdateStatus`] for invariants around returning an error.
        ///
        /// [`update_monitor`]: channelmonitor::ChannelMonitor::update_monitor
-       fn update_channel(&self, funding_txo: OutPoint, update: ChannelMonitorUpdate) -> ChannelMonitorUpdateStatus;
+       fn update_channel(&self, funding_txo: OutPoint, update: &ChannelMonitorUpdate) -> ChannelMonitorUpdateStatus;
 
        /// Returns any monitor events since the last call. Subsequent calls must only return new
        /// events.
index 24f0ffdcd4922d9a6743b67aa9e3c98dfa092e29..385f5b5353ca5f697880c441061de805cfbc2d28 100644 (file)
@@ -147,9 +147,9 @@ fn test_monitor_and_persister_update_fail() {
                        // Check that even though the persister is returning a InProgress,
                        // because the update is bogus, ultimately the error that's returned
                        // should be a PermanentFailure.
-                       if let ChannelMonitorUpdateStatus::PermanentFailure = chain_mon.chain_monitor.update_channel(outpoint, update.clone()) {} else { panic!("Expected monitor error to be permanent"); }
+                       if let ChannelMonitorUpdateStatus::PermanentFailure = chain_mon.chain_monitor.update_channel(outpoint, &update) {} else { panic!("Expected monitor error to be permanent"); }
                        logger.assert_log_regex("lightning::chain::chainmonitor".to_string(), regex::Regex::new("Persistence of ChannelMonitorUpdate for channel [0-9a-f]* in progress").unwrap(), 1);
-                       assert_eq!(nodes[0].chain_monitor.update_channel(outpoint, update), ChannelMonitorUpdateStatus::Completed);
+                       assert_eq!(nodes[0].chain_monitor.update_channel(outpoint, &update), ChannelMonitorUpdateStatus::Completed);
                } else { assert!(false); }
        }
 
index c8798a5cdffe25b7b8c446218a00ad8ca3632e97..af83edc5634d37d76d18c95bf61ede56999cdf23 100644 (file)
@@ -577,13 +577,13 @@ pub type SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, M, T, F, L> = C
 //  |   |
 //  |   |__`pending_intercepted_htlcs`
 //  |
-//  |__`pending_inbound_payments`
+//  |__`per_peer_state`
 //  |   |
-//  |   |__`claimable_payments`
-//  |   |
-//  |   |__`pending_outbound_payments` // This field's struct contains a map of pending outbounds
+//  |   |__`pending_inbound_payments`
+//  |       |
+//  |       |__`claimable_payments`
 //  |       |
-//  |       |__`per_peer_state`
+//  |       |__`pending_outbound_payments` // This field's struct contains a map of pending outbounds
 //  |           |
 //  |           |__`peer_state`
 //  |               |
@@ -1709,7 +1709,7 @@ where
 
                                        // Update the monitor with the shutdown script if necessary.
                                        if let Some(monitor_update) = monitor_update {
-                                               let update_res = self.chain_monitor.update_channel(chan_entry.get().get_funding_txo().unwrap(), monitor_update);
+                                               let update_res = self.chain_monitor.update_channel(chan_entry.get().get_funding_txo().unwrap(), &monitor_update);
                                                let (result, is_permanent) =
                                                        handle_monitor_update_res!(self, update_res, chan_entry.get_mut(), RAACommitmentOrder::CommitmentFirst, chan_entry.key(), NO_UPDATE);
                                                if is_permanent {
@@ -1807,7 +1807,7 @@ where
                        // force-closing. The monitor update on the required in-memory copy should broadcast
                        // the latest local state, which is the best we can do anyway. Thus, it is safe to
                        // ignore the result here.
-                       let _ = self.chain_monitor.update_channel(funding_txo, monitor_update);
+                       let _ = self.chain_monitor.update_channel(funding_txo, &monitor_update);
                }
        }
 
@@ -2336,7 +2336,7 @@ where
                                                chan)
                                } {
                                        Some((update_add, commitment_signed, monitor_update)) => {
-                                               let update_err = self.chain_monitor.update_channel(chan.get().get_funding_txo().unwrap(), monitor_update);
+                                               let update_err = self.chain_monitor.update_channel(chan.get().get_funding_txo().unwrap(), &monitor_update);
                                                let chan_id = chan.get().channel_id();
                                                match (update_err,
                                                        handle_monitor_update_res!(self, update_err, chan,
@@ -3284,7 +3284,7 @@ where
                                BackgroundEvent::ClosingMonitorUpdate((funding_txo, update)) => {
                                        // The channel has already been closed, so no use bothering to care about the
                                        // monitor updating completing.
-                                       let _ = self.chain_monitor.update_channel(funding_txo, update);
+                                       let _ = self.chain_monitor.update_channel(funding_txo, &update);
                                },
                        }
                }
@@ -3570,9 +3570,12 @@ where
                        // Ensure that no peer state channel storage lock is not held when calling this
                        // function.
                        // This ensures that future code doesn't introduce a lock_order requirement for
-                       // `forward_htlcs` to be locked after the `per_peer_state` locks, which calling this
-                       // function with the `per_peer_state` aquired would.
-                       assert!(self.per_peer_state.try_write().is_ok());
+                       // `forward_htlcs` to be locked after the `per_peer_state` peer locks, which calling
+                       // this function with any `per_peer_state` peer lock aquired would.
+                       let per_peer_state = self.per_peer_state.read().unwrap();
+                       for (_, peer) in per_peer_state.iter() {
+                               assert!(peer.try_lock().is_ok());
+                       }
                }
 
                //TODO: There is a timing attack here where if a node fails an HTLC back to us they can
@@ -3807,7 +3810,7 @@ where
                                match chan.get_mut().get_update_fulfill_htlc_and_commit(prev_hop.htlc_id, payment_preimage, &self.logger) {
                                        Ok(msgs_monitor_option) => {
                                                if let UpdateFulfillCommitFetch::NewClaim { msgs, htlc_value_msat, monitor_update } = msgs_monitor_option {
-                                                       match self.chain_monitor.update_channel(chan.get().get_funding_txo().unwrap(), monitor_update) {
+                                                       match self.chain_monitor.update_channel(chan.get().get_funding_txo().unwrap(), &monitor_update) {
                                                                ChannelMonitorUpdateStatus::Completed => {},
                                                                e => {
                                                                        log_given_level!(self.logger, if e == ChannelMonitorUpdateStatus::PermanentFailure { Level::Error } else { Level::Debug },
@@ -3844,7 +3847,7 @@ where
                                                }
                                        },
                                        Err((e, monitor_update)) => {
-                                               match self.chain_monitor.update_channel(chan.get().get_funding_txo().unwrap(), monitor_update) {
+                                               match self.chain_monitor.update_channel(chan.get().get_funding_txo().unwrap(), &monitor_update) {
                                                        ChannelMonitorUpdateStatus::Completed => {},
                                                        e => {
                                                                // TODO: This needs to be handled somehow - if we receive a monitor update
@@ -3880,7 +3883,7 @@ where
                        };
                        // We update the ChannelMonitor on the backward link, after
                        // receiving an `update_fulfill_htlc` from the forward link.
-                       let update_res = self.chain_monitor.update_channel(prev_hop.outpoint, preimage_update);
+                       let update_res = self.chain_monitor.update_channel(prev_hop.outpoint, &preimage_update);
                        if update_res != ChannelMonitorUpdateStatus::Completed {
                                // TODO: This needs to be handled somehow - if we receive a monitor update
                                // with a preimage we *must* somehow manage to propagate it to the upstream
@@ -4449,7 +4452,7 @@ where
 
                                        // Update the monitor with the shutdown script if necessary.
                                        if let Some(monitor_update) = monitor_update {
-                                               let update_res = self.chain_monitor.update_channel(chan_entry.get().get_funding_txo().unwrap(), monitor_update);
+                                               let update_res = self.chain_monitor.update_channel(chan_entry.get().get_funding_txo().unwrap(), &monitor_update);
                                                let (result, is_permanent) =
                                                        handle_monitor_update_res!(self, update_res, chan_entry.get_mut(), RAACommitmentOrder::CommitmentFirst, chan_entry.key(), NO_UPDATE);
                                                if is_permanent {
@@ -4650,13 +4653,13 @@ where
                                                Err((None, e)) => try_chan_entry!(self, Err(e), chan),
                                                Err((Some(update), e)) => {
                                                        assert!(chan.get().is_awaiting_monitor_update());
-                                                       let _ = self.chain_monitor.update_channel(chan.get().get_funding_txo().unwrap(), update);
+                                                       let _ = self.chain_monitor.update_channel(chan.get().get_funding_txo().unwrap(), &update);
                                                        try_chan_entry!(self, Err(e), chan);
                                                        unreachable!();
                                                },
                                                Ok(res) => res
                                        };
-                               let update_res = self.chain_monitor.update_channel(chan.get().get_funding_txo().unwrap(), monitor_update);
+                               let update_res = self.chain_monitor.update_channel(chan.get().get_funding_txo().unwrap(), &monitor_update);
                                if let Err(e) = handle_monitor_update_res!(self, update_res, chan, RAACommitmentOrder::RevokeAndACKFirst, true, commitment_signed.is_some()) {
                                        return Err(e);
                                }
@@ -4792,7 +4795,7 @@ where
                                        let raa_updates = break_chan_entry!(self,
                                                chan.get_mut().revoke_and_ack(&msg, &self.logger), chan);
                                        htlcs_to_fail = raa_updates.holding_cell_failed_htlcs;
-                                       let update_res = self.chain_monitor.update_channel(chan.get().get_funding_txo().unwrap(), raa_updates.monitor_update);
+                                       let update_res = self.chain_monitor.update_channel(chan.get().get_funding_txo().unwrap(), &raa_updates.monitor_update);
                                        if was_paused_for_mon_update {
                                                assert!(update_res != ChannelMonitorUpdateStatus::Completed);
                                                assert!(raa_updates.commitment_update.is_none());
@@ -5097,7 +5100,7 @@ where
                                                                ));
                                                        }
                                                        if let Some((commitment_update, monitor_update)) = commitment_opt {
-                                                               match self.chain_monitor.update_channel(chan.get_funding_txo().unwrap(), monitor_update) {
+                                                               match self.chain_monitor.update_channel(chan.get_funding_txo().unwrap(), &monitor_update) {
                                                                        ChannelMonitorUpdateStatus::Completed => {
                                                                                pending_msg_events.push(events::MessageSendEvent::UpdateHTLCs {
                                                                                        node_id: chan.get_counterparty_node_id(),
@@ -6745,6 +6748,8 @@ where
                        }
                }
 
+               let per_peer_state = self.per_peer_state.write().unwrap();
+
                let pending_inbound_payments = self.pending_inbound_payments.lock().unwrap();
                let claimable_payments = self.claimable_payments.lock().unwrap();
                let pending_outbound_payments = self.pending_outbound_payments.pending_outbound_payments.lock().unwrap();
@@ -6760,7 +6765,6 @@ where
                        htlc_purposes.push(purpose);
                }
 
-               let per_peer_state = self.per_peer_state.write().unwrap();
                (per_peer_state.len() as u64).write(writer)?;
                for (peer_pubkey, peer_state_mutex) in per_peer_state.iter() {
                        peer_pubkey.write(writer)?;
index ba1e5908ed5b41203c29c839c9e8487637db5432..1260561844b21d77f46cde25b961a1ecf16d8bc4 100644 (file)
@@ -8135,8 +8135,8 @@ fn test_update_err_monitor_lockdown() {
                let mut node_0_peer_state_lock;
                let mut channel = get_channel_ref!(nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, chan_1.2);
                if let Ok((_, _, update)) = channel.commitment_signed(&updates.commitment_signed, &node_cfgs[0].logger) {
-                       assert_eq!(watchtower.chain_monitor.update_channel(outpoint, update.clone()), ChannelMonitorUpdateStatus::PermanentFailure);
-                       assert_eq!(nodes[0].chain_monitor.update_channel(outpoint, update), ChannelMonitorUpdateStatus::Completed);
+                       assert_eq!(watchtower.chain_monitor.update_channel(outpoint, &update), ChannelMonitorUpdateStatus::PermanentFailure);
+                       assert_eq!(nodes[0].chain_monitor.update_channel(outpoint, &update), ChannelMonitorUpdateStatus::Completed);
                } else { assert!(false); }
        }
        // Our local monitor is in-sync and hasn't processed yet timeout
@@ -8230,9 +8230,9 @@ fn test_concurrent_monitor_claim() {
                let mut channel = get_channel_ref!(nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, chan_1.2);
                if let Ok((_, _, update)) = channel.commitment_signed(&updates.commitment_signed, &node_cfgs[0].logger) {
                        // Watchtower Alice should already have seen the block and reject the update
-                       assert_eq!(watchtower_alice.chain_monitor.update_channel(outpoint, update.clone()), ChannelMonitorUpdateStatus::PermanentFailure);
-                       assert_eq!(watchtower_bob.chain_monitor.update_channel(outpoint, update.clone()), ChannelMonitorUpdateStatus::Completed);
-                       assert_eq!(nodes[0].chain_monitor.update_channel(outpoint, update), ChannelMonitorUpdateStatus::Completed);
+                       assert_eq!(watchtower_alice.chain_monitor.update_channel(outpoint, &update), ChannelMonitorUpdateStatus::PermanentFailure);
+                       assert_eq!(watchtower_bob.chain_monitor.update_channel(outpoint, &update), ChannelMonitorUpdateStatus::Completed);
+                       assert_eq!(nodes[0].chain_monitor.update_channel(outpoint, &update), ChannelMonitorUpdateStatus::Completed);
                } else { assert!(false); }
        }
        // Our local monitor is in-sync and hasn't processed yet timeout
index 6a2f8803464d72f8413d92ed0e710601eed9b061..2e45685daa00709ecea13cd0e8dc521028b05e16 100644 (file)
@@ -94,7 +94,7 @@ impl<ChannelSigner: Sign, K: KVStorePersister> Persist<ChannelSigner> for K {
                }
        }
 
-       fn update_persisted_channel(&self, funding_txo: OutPoint, _update: &Option<ChannelMonitorUpdate>, monitor: &ChannelMonitor<ChannelSigner>, _update_id: MonitorUpdateId) -> chain::ChannelMonitorUpdateStatus {
+       fn update_persisted_channel(&self, funding_txo: OutPoint, _update: Option<&ChannelMonitorUpdate>, monitor: &ChannelMonitor<ChannelSigner>, _update_id: MonitorUpdateId) -> chain::ChannelMonitorUpdateStatus {
                let key = format!("monitors/{}_{}", funding_txo.txid.to_hex(), funding_txo.index);
                match self.persist(&key, monitor) {
                        Ok(()) => chain::ChannelMonitorUpdateStatus::Completed,
index a656604d458c492c17e8fc9a206c20afe4faa5ef..84d1a2e084feb65e1472969875a586c81ccd08e5 100644 (file)
@@ -22,6 +22,8 @@ use core::cmp;
 use core::convert::TryFrom;
 use core::ops::Deref;
 
+use alloc::collections::BTreeMap;
+
 use bitcoin::secp256k1::{PublicKey, SecretKey};
 use bitcoin::secp256k1::constants::{PUBLIC_KEY_SIZE, SECRET_KEY_SIZE, COMPACT_SIGNATURE_SIZE, SCHNORR_SIGNATURE_SIZE};
 use bitcoin::secp256k1::ecdsa;
@@ -381,6 +383,40 @@ impl Readable for BigSize {
        }
 }
 
+/// The lightning protocol uses u16s for lengths in most cases. As our serialization framework
+/// primarily targets that, we must as well. However, because we may serialize objects that have
+/// more than 65K entries, we need to be able to store larger values. Thus, we define a variable
+/// length integer here that is backwards-compatible for values < 0xffff. We treat 0xffff as
+/// "read eight more bytes".
+///
+/// To ensure we only have one valid encoding per value, we add 0xffff to values written as eight
+/// bytes. Thus, 0xfffe is serialized as 0xfffe, whereas 0xffff is serialized as
+/// 0xffff0000000000000000 (i.e. read-eight-bytes then zero).
+struct CollectionLength(pub u64);
+impl Writeable for CollectionLength {
+       #[inline]
+       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
+               if self.0 < 0xffff {
+                       (self.0 as u16).write(writer)
+               } else {
+                       0xffffu16.write(writer)?;
+                       (self.0 - 0xffff).write(writer)
+               }
+       }
+}
+
+impl Readable for CollectionLength {
+       #[inline]
+       fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
+               let mut val: u64 = <u16 as Readable>::read(r)? as u64;
+               if val == 0xffff {
+                       val = <u64 as Readable>::read(r)?
+                               .checked_add(0xffff).ok_or(DecodeError::InvalidValue)?;
+               }
+               Ok(CollectionLength(val))
+       }
+}
+
 /// In TLV we occasionally send fields which only consist of, or potentially end with, a
 /// variable-length integer which is simply truncated by skipping high zero bytes. This type
 /// encapsulates such integers implementing [`Readable`]/[`Writeable`] for them.
@@ -588,50 +624,54 @@ impl<'a, T> From<&'a Vec<T>> for WithoutLength<&'a Vec<T>> {
        fn from(v: &'a Vec<T>) -> Self { Self(v) }
 }
 
-// HashMap
-impl<K, V> Writeable for HashMap<K, V>
-       where K: Writeable + Eq + Hash,
-             V: Writeable
-{
-       #[inline]
-       fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
-       (self.len() as u16).write(w)?;
-               for (key, value) in self.iter() {
-                       key.write(w)?;
-                       value.write(w)?;
+macro_rules! impl_for_map {
+       ($ty: ident, $keybound: ident, $constr: expr) => {
+               impl<K, V> Writeable for $ty<K, V>
+                       where K: Writeable + Eq + $keybound, V: Writeable
+               {
+                       #[inline]
+                       fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
+                               CollectionLength(self.len() as u64).write(w)?;
+                               for (key, value) in self.iter() {
+                                       key.write(w)?;
+                                       value.write(w)?;
+                               }
+                               Ok(())
+                       }
                }
-               Ok(())
-       }
-}
 
-impl<K, V> Readable for HashMap<K, V>
-       where K: Readable + Eq + Hash,
-             V: MaybeReadable
-{
-       #[inline]
-       fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
-               let len: u16 = Readable::read(r)?;
-               let mut ret = HashMap::with_capacity(len as usize);
-               for _ in 0..len {
-                       let k = K::read(r)?;
-                       let v_opt = V::read(r)?;
-                       if let Some(v) = v_opt {
-                               if ret.insert(k, v).is_some() {
-                                       return Err(DecodeError::InvalidValue);
+               impl<K, V> Readable for $ty<K, V>
+                       where K: Readable + Eq + $keybound, V: MaybeReadable
+               {
+                       #[inline]
+                       fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
+                               let len: CollectionLength = Readable::read(r)?;
+                               let mut ret = $constr(len.0 as usize);
+                               for _ in 0..len.0 {
+                                       let k = K::read(r)?;
+                                       let v_opt = V::read(r)?;
+                                       if let Some(v) = v_opt {
+                                               if ret.insert(k, v).is_some() {
+                                                       return Err(DecodeError::InvalidValue);
+                                               }
+                                       }
                                }
+                               Ok(ret)
                        }
                }
-               Ok(ret)
        }
 }
 
+impl_for_map!(BTreeMap, Ord, |_| BTreeMap::new());
+impl_for_map!(HashMap, Hash, |len| HashMap::with_capacity(len));
+
 // HashSet
 impl<T> Writeable for HashSet<T>
 where T: Writeable + Eq + Hash
 {
        #[inline]
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
-               (self.len() as u16).write(w)?;
+               CollectionLength(self.len() as u64).write(w)?;
                for item in self.iter() {
                        item.write(w)?;
                }
@@ -644,9 +684,9 @@ where T: Readable + Eq + Hash
 {
        #[inline]
        fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
-               let len: u16 = Readable::read(r)?;
-               let mut ret = HashSet::with_capacity(len as usize);
-               for _ in 0..len {
+               let len: CollectionLength = Readable::read(r)?;
+               let mut ret = HashSet::with_capacity(cmp::min(len.0 as usize, MAX_BUF_SIZE / core::mem::size_of::<T>()));
+               for _ in 0..len.0 {
                        if !ret.insert(T::read(r)?) {
                                return Err(DecodeError::InvalidValue)
                        }
@@ -656,51 +696,62 @@ where T: Readable + Eq + Hash
 }
 
 // Vectors
-impl Writeable for Vec<u8> {
-       #[inline]
-       fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
-               (self.len() as u16).write(w)?;
-               w.write_all(&self)
-       }
-}
+macro_rules! impl_for_vec {
+       ($ty: ty $(, $name: ident)*) => {
+               impl<$($name : Writeable),*> Writeable for Vec<$ty> {
+                       #[inline]
+                       fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
+                               CollectionLength(self.len() as u64).write(w)?;
+                               for elem in self.iter() {
+                                       elem.write(w)?;
+                               }
+                               Ok(())
+                       }
+               }
 
-impl Readable for Vec<u8> {
-       #[inline]
-       fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
-               let len: u16 = Readable::read(r)?;
-               let mut ret = Vec::with_capacity(len as usize);
-               ret.resize(len as usize, 0);
-               r.read_exact(&mut ret)?;
-               Ok(ret)
+               impl<$($name : Readable),*> Readable for Vec<$ty> {
+                       #[inline]
+                       fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
+                               let len: CollectionLength = Readable::read(r)?;
+                               let mut ret = Vec::with_capacity(cmp::min(len.0 as usize, MAX_BUF_SIZE / core::mem::size_of::<$ty>()));
+                               for _ in 0..len.0 {
+                                       if let Some(val) = MaybeReadable::read(r)? {
+                                               ret.push(val);
+                                       }
+                               }
+                               Ok(ret)
+                       }
+               }
        }
 }
-impl Writeable for Vec<ecdsa::Signature> {
+
+impl Writeable for Vec<u8> {
        #[inline]
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
-               (self.len() as u16).write(w)?;
-               for e in self.iter() {
-                       e.write(w)?;
-               }
-               Ok(())
+               CollectionLength(self.len() as u64).write(w)?;
+               w.write_all(&self)
        }
 }
 
-impl Readable for Vec<ecdsa::Signature> {
+impl Readable for Vec<u8> {
        #[inline]
        fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
-               let len: u16 = Readable::read(r)?;
-               let byte_size = (len as usize)
-                               .checked_mul(COMPACT_SIGNATURE_SIZE)
-                               .ok_or(DecodeError::BadLengthDescriptor)?;
-               if byte_size > MAX_BUF_SIZE {
-                       return Err(DecodeError::BadLengthDescriptor);
+               let mut len: CollectionLength = Readable::read(r)?;
+               let mut ret = Vec::new();
+               while len.0 > 0 {
+                       let readamt = cmp::min(len.0 as usize, MAX_BUF_SIZE);
+                       let readstart = ret.len();
+                       ret.resize(readstart + readamt, 0);
+                       r.read_exact(&mut ret[readstart..])?;
+                       len.0 -= readamt as u64;
                }
-               let mut ret = Vec::with_capacity(len as usize);
-               for _ in 0..len { ret.push(Readable::read(r)?); }
                Ok(ret)
        }
 }
 
+impl_for_vec!(ecdsa::Signature);
+impl_for_vec!((A, B), A, B);
+
 impl Writeable for Script {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
                (self.len() as u16).write(w)?;
@@ -1028,7 +1079,7 @@ impl Readable for () {
 impl Writeable for String {
        #[inline]
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
-               (self.len() as u16).write(w)?;
+               CollectionLength(self.len() as u64).write(w)?;
                w.write_all(self.as_bytes())
        }
 }
index 8d74a83a37365fc490ccf089449aad42bcb955ea..afd7fcb2e0fc93b3f536a052c15cec841f8b43b0 100644 (file)
@@ -39,6 +39,9 @@ macro_rules! _encode_tlv {
                        field.write($stream)?;
                }
        };
+       ($stream: expr, $type: expr, $field: expr, ignorable) => {
+               $crate::_encode_tlv!($stream, $type, $field, required);
+       };
        ($stream: expr, $type: expr, $field: expr, (option, encoding: ($fieldty: ty, $encoding: ident))) => {
                $crate::_encode_tlv!($stream, $type, $field.map(|f| $encoding(f)), option);
        };
@@ -155,6 +158,9 @@ macro_rules! _get_varint_length_prefixed_tlv_length {
                        $len.0 += field_len;
                }
        };
+       ($len: expr, $type: expr, $field: expr, ignorable) => {
+               $crate::_get_varint_length_prefixed_tlv_length!($len, $type, $field, required);
+       };
 }
 
 /// See the documentation of [`write_tlv_fields`].
@@ -581,6 +587,9 @@ macro_rules! _init_tlv_based_struct_field {
        ($field: ident, option) => {
                $field
        };
+       ($field: ident, ignorable) => {
+               if $field.is_none() { return Ok(None); } else { $field.unwrap() }
+       };
        ($field: ident, required) => {
                $field.0.unwrap()
        };
@@ -610,6 +619,9 @@ macro_rules! _init_tlv_field_var {
        ($field: ident, option) => {
                let mut $field = None;
        };
+       ($field: ident, ignorable) => {
+               let mut $field = None;
+       };
 }
 
 /// Equivalent to running [`_init_tlv_field_var`] then [`read_tlv_fields`].
index 1838b1078a98d42adbd96f15af3ab93cd01f9426..f528ae2d547c61f78c09fbbe64b2a3e3da6aa173 100644 (file)
@@ -184,12 +184,12 @@ impl<'a> chain::Watch<EnforcingSigner> for TestChainMonitor<'a> {
                self.chain_monitor.watch_channel(funding_txo, new_monitor)
        }
 
-       fn update_channel(&self, funding_txo: OutPoint, update: channelmonitor::ChannelMonitorUpdate) -> chain::ChannelMonitorUpdateStatus {
+       fn update_channel(&self, funding_txo: OutPoint, update: &channelmonitor::ChannelMonitorUpdate) -> chain::ChannelMonitorUpdateStatus {
                // Every monitor update should survive roundtrip
                let mut w = TestVecWriter(Vec::new());
                update.write(&mut w).unwrap();
                assert!(channelmonitor::ChannelMonitorUpdate::read(
-                               &mut io::Cursor::new(&w.0)).unwrap() == update);
+                               &mut io::Cursor::new(&w.0)).unwrap() == *update);
 
                self.monitor_updates.lock().unwrap().entry(funding_txo.to_channel_id()).or_insert(Vec::new()).push(update.clone());
 
@@ -202,7 +202,7 @@ impl<'a> chain::Watch<EnforcingSigner> for TestChainMonitor<'a> {
                }
 
                self.latest_monitor_update_id.lock().unwrap().insert(funding_txo.to_channel_id(),
-                       (funding_txo, update.update_id, MonitorUpdateId::from_monitor_update(&update)));
+                       (funding_txo, update.update_id, MonitorUpdateId::from_monitor_update(update)));
                let update_res = self.chain_monitor.update_channel(funding_txo, update);
                // At every point where we get a monitor update, we should be able to send a useful monitor
                // to a watchtower and disk...
@@ -254,7 +254,7 @@ impl<Signer: keysinterface::Sign> chainmonitor::Persist<Signer> for TestPersiste
                chain::ChannelMonitorUpdateStatus::Completed
        }
 
-       fn update_persisted_channel(&self, funding_txo: OutPoint, update: &Option<channelmonitor::ChannelMonitorUpdate>, _data: &channelmonitor::ChannelMonitor<Signer>, update_id: MonitorUpdateId) -> chain::ChannelMonitorUpdateStatus {
+       fn update_persisted_channel(&self, funding_txo: OutPoint, update: Option<&channelmonitor::ChannelMonitorUpdate>, _data: &channelmonitor::ChannelMonitor<Signer>, update_id: MonitorUpdateId) -> chain::ChannelMonitorUpdateStatus {
                let mut ret = chain::ChannelMonitorUpdateStatus::Completed;
                if let Some(update_ret) = self.update_rets.lock().unwrap().pop_front() {
                        ret = update_ret;
index 947838633c50267e12493dc5fb4b9e5a523185ed..cb22dcea9bfa7b71541404934b32d64c998d59db 100644 (file)
@@ -10,3 +10,4 @@ default = ["lightning/no-std", "lightning-invoice/no-std", "lightning-rapid-goss
 lightning = { path = "../lightning", default-features = false }
 lightning-invoice = { path = "../lightning-invoice", default-features = false }
 lightning-rapid-gossip-sync = { path = "../lightning-rapid-gossip-sync", default-features = false }
+lightning-background-processor = { path = "../lightning-background-processor", features = ["futures"], default-features = false }