Add WithContext and Tests
authorhenghonglee <henghong.lee@gmail.com>
Mon, 4 Sep 2023 18:37:39 +0000 (02:37 +0800)
committerJeffrey Czyz <jkczyz@gmail.com>
Fri, 1 Dec 2023 17:30:19 +0000 (11:30 -0600)
lightning/src/ln/channel.rs
lightning/src/util/logger.rs
lightning/src/util/test_utils.rs

index 001c4d8c9632fcdcace6b5592b2b29b25ede25de..f3f46fe3b0a427f2cccabd37755fd26e3709569d 100644 (file)
@@ -42,7 +42,7 @@ use crate::sign::{EntropySource, ChannelSigner, SignerProvider, NodeSigner, Reci
 use crate::events::ClosureReason;
 use crate::routing::gossip::NodeId;
 use crate::util::ser::{Readable, ReadableArgs, Writeable, Writer};
-use crate::util::logger::Logger;
+use crate::util::logger::{Logger, WithContext};
 use crate::util::errors::APIError;
 use crate::util::config::{UserConfig, ChannelConfig, LegacyChannelConfig, ChannelHandshakeConfig, ChannelHandshakeLimits, MaxDustHTLCExposure};
 use crate::util::scid_utils::scid_from_parts;
@@ -6463,6 +6463,7 @@ impl<SP: Deref> InboundV1Channel<SP> where SP::Target: SignerProvider {
                          F::Target: FeeEstimator,
                          L::Target: Logger,
        {
+               let logger = WithContext::from(logger, Some(counterparty_node_id), Some(msg.temporary_channel_id));
                let announced_channel = if (msg.channel_flags & 1) == 1 { true } else { false };
 
                // First check the channel type is known, failing before we do anything else if we don't
@@ -6529,7 +6530,7 @@ impl<SP: Deref> InboundV1Channel<SP> where SP::Target: SignerProvider {
                if msg.htlc_minimum_msat >= full_channel_value_msat {
                        return Err(ChannelError::Close(format!("Minimum htlc value ({}) was larger than full channel value ({})", msg.htlc_minimum_msat, full_channel_value_msat)));
                }
-               Channel::<SP>::check_remote_fee(&channel_type, fee_estimator, msg.feerate_per_kw, None, logger)?;
+               Channel::<SP>::check_remote_fee(&channel_type, fee_estimator, msg.feerate_per_kw, None, &&logger)?;
 
                let max_counterparty_selected_contest_delay = u16::min(config.channel_handshake_limits.their_to_self_delay, MAX_LOCAL_BREAKDOWN_TIMEOUT);
                if msg.to_self_delay > max_counterparty_selected_contest_delay {
index 4018b8bf7af4180be712fb7c433ccfe69b2e2bd0..f1534933792949a41318c3b80e0cc1847cbf63f4 100644 (file)
@@ -18,6 +18,7 @@ use bitcoin::secp256k1::PublicKey;
 
 use core::cmp;
 use core::fmt;
+use core::ops::Deref;
 
 use crate::ln::ChannelId;
 #[cfg(c_bindings)]
@@ -152,6 +153,39 @@ pub trait Logger {
        fn log(&self, record: Record);
 }
 
+/// Adds relevant context to a [`Record`] before passing it to the wrapped [`Logger`].
+pub struct WithContext<'a, L: Deref> where L::Target: Logger {
+       /// The logger to delegate to after adding context to the record.
+       logger: &'a L,
+       /// The node id of the peer pertaining to the logged record.
+       peer_id: Option<PublicKey>,
+       /// The channel id of the channel pertaining to the logged record.
+       channel_id: Option<ChannelId>,
+}
+
+impl<'a, L: Deref> Logger for WithContext<'a, L> where L::Target: Logger {
+       fn log(&self, mut record: Record) {
+               if self.peer_id.is_some() {
+                       record.peer_id = self.peer_id
+               };
+               if self.channel_id.is_some() {
+                       record.channel_id = self.channel_id;
+               }
+               self.logger.log(record)
+       }
+}
+
+impl<'a, L: Deref> WithContext<'a, L> where L::Target: Logger {
+       /// Wraps the given logger, providing additional context to any logged records.
+       pub fn from(logger: &'a L, peer_id: Option<PublicKey>, channel_id: Option<ChannelId>) -> Self {
+               WithContext {
+                       logger,
+                       peer_id,
+                       channel_id,
+               }
+       }
+}
+
 /// Wrapper for logging a [`PublicKey`] in hex format.
 ///
 /// This is not exported to bindings users as fmt can't be used in C
@@ -202,7 +236,9 @@ impl<T: fmt::Display, I: core::iter::Iterator<Item = T> + Clone> fmt::Display fo
 
 #[cfg(test)]
 mod tests {
-       use crate::util::logger::{Logger, Level};
+       use bitcoin::secp256k1::{PublicKey, SecretKey, Secp256k1};
+       use crate::ln::ChannelId;
+       use crate::util::logger::{Logger, Level, WithContext};
        use crate::util::test_utils::TestLogger;
        use crate::sync::Arc;
 
@@ -243,6 +279,41 @@ mod tests {
                wrapper.call_macros();
        }
 
+       #[test]
+       fn test_logging_with_context() {
+               let logger = &TestLogger::new();
+               let secp_ctx = Secp256k1::new();
+               let pk = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&[42; 32]).unwrap());
+               let context_logger = WithContext::from(&logger, Some(pk), Some(ChannelId([0; 32])));
+               log_error!(context_logger, "This is an error");
+               log_warn!(context_logger, "This is an error");
+               log_debug!(context_logger, "This is an error");
+               log_trace!(context_logger, "This is an error");
+               log_gossip!(context_logger, "This is an error");
+               log_info!(context_logger, "This is an error");
+               logger.assert_log_context_contains(
+                       "lightning::util::logger::tests", Some(pk), Some(ChannelId([0;32])), 6
+               );
+       }
+
+       #[test]
+       fn test_logging_with_multiple_wrapped_context() {
+               let logger = &TestLogger::new();
+               let secp_ctx = Secp256k1::new();
+               let pk = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&[42; 32]).unwrap());
+               let context_logger = &WithContext::from(&logger, None, Some(ChannelId([0; 32])));
+               let full_context_logger = WithContext::from(&context_logger, Some(pk), None);
+               log_error!(full_context_logger, "This is an error");
+               log_warn!(full_context_logger, "This is an error");
+               log_debug!(full_context_logger, "This is an error");
+               log_trace!(full_context_logger, "This is an error");
+               log_gossip!(full_context_logger, "This is an error");
+               log_info!(full_context_logger, "This is an error");
+               logger.assert_log_context_contains(
+                       "lightning::util::logger::tests", Some(pk), Some(ChannelId([0;32])), 6
+               );
+       }
+
        #[test]
        fn test_log_ordering() {
                assert!(Level::Error > Level::Warn);
index 2c34bc92c7e2eb5b38901cb7df06de73e0435d46..0606e36e50e01d88496351e18dc7c019f94aae1e 100644 (file)
@@ -931,6 +931,7 @@ pub struct TestLogger {
        level: Level,
        pub(crate) id: String,
        pub lines: Mutex<HashMap<(String, String), usize>>,
+       pub context: Mutex<HashMap<(String, Option<PublicKey>, Option<ChannelId>), usize>>,
 }
 
 impl TestLogger {
@@ -941,7 +942,8 @@ impl TestLogger {
                TestLogger {
                        level: Level::Trace,
                        id,
-                       lines: Mutex::new(HashMap::new())
+                       lines: Mutex::new(HashMap::new()),
+                       context: Mutex::new(HashMap::new()),
                }
        }
        pub fn enable(&mut self, level: Level) {
@@ -976,11 +978,23 @@ impl TestLogger {
                }).map(|(_, c) | { c }).sum();
                assert_eq!(l, count)
        }
+
+       pub fn assert_log_context_contains(
+               &self, module: &str, peer_id: Option<PublicKey>, channel_id: Option<ChannelId>, count: usize
+       ) {
+               let context_entries = self.context.lock().unwrap();
+               let l: usize = context_entries.iter()
+                       .filter(|&(&(ref m, ref p, ref c), _)| m == module && *p == peer_id && *c == channel_id)
+                       .map(|(_, c) | c)
+                       .sum();
+               assert_eq!(l, count)
+       }
 }
 
 impl Logger for TestLogger {
        fn log(&self, record: Record) {
                *self.lines.lock().unwrap().entry((record.module_path.to_string(), format!("{}", record.args))).or_insert(0) += 1;
+               *self.context.lock().unwrap().entry((record.module_path.to_string(), record.peer_id, record.channel_id)).or_insert(0) += 1;
                if record.level >= self.level {
                        #[cfg(all(not(ldk_bench), feature = "std"))] {
                                let pfx = format!("{} {} [{}:{}]", self.id, record.level.to_string(), record.module_path, record.line);