Move the two-AtomicUsize counter in peer_handler to a util struct
authorMatt Corallo <git@bluematt.me>
Fri, 8 Oct 2021 22:54:32 +0000 (22:54 +0000)
committerMatt Corallo <git@bluematt.me>
Mon, 18 Oct 2021 22:04:56 +0000 (22:04 +0000)
We also take this opportunity to drop byte_utils::le64_to_array, as
our MSRV now supports the native to_le_bytes() call.

lightning/src/ln/peer_handler.rs
lightning/src/util/atomic_counter.rs [new file with mode: 0644]
lightning/src/util/byte_utils.rs
lightning/src/util/chacha20poly1305rfc.rs
lightning/src/util/mod.rs

index 1815d4a350cc0289100d207b9fa2a98029397b1d..ebb0322f810a901c9e46a3e78dadce6b54e99611 100644 (file)
@@ -24,7 +24,7 @@ use ln::channelmanager::{SimpleArcChannelManager, SimpleRefChannelManager};
 use util::ser::{VecWriter, Writeable, Writer};
 use ln::peer_channel_encryptor::{PeerChannelEncryptor,NextNoiseStep};
 use ln::wire;
-use util::byte_utils;
+use util::atomic_counter::AtomicCounter;
 use util::events::{MessageSendEvent, MessageSendEventsProvider};
 use util::logger::Logger;
 use routing::network_graph::NetGraphMsgHandler;
@@ -33,7 +33,6 @@ use prelude::*;
 use io;
 use alloc::collections::LinkedList;
 use sync::{Arc, Mutex};
-use core::sync::atomic::{AtomicUsize, Ordering};
 use core::{cmp, hash, fmt, mem};
 use core::ops::Deref;
 use core::convert::Infallible;
@@ -343,12 +342,6 @@ struct PeerHolder<Descriptor: SocketDescriptor> {
        node_id_to_descriptor: HashMap<PublicKey, Descriptor>,
 }
 
-#[cfg(not(any(target_pointer_width = "32", target_pointer_width = "64")))]
-fn _check_usize_is_32_or_64() {
-       // See below, less than 32 bit pointers may be unsafe here!
-       unsafe { mem::transmute::<*const usize, [u8; 4]>(panic!()); }
-}
-
 /// SimpleArcPeerManager is useful when you need a PeerManager with a static lifetime, e.g.
 /// when you're using lightning-net-tokio (since tokio::spawn requires parameters with static
 /// lifetimes). Other times you can afford a reference, which is more efficient, in which case
@@ -394,10 +387,7 @@ pub struct PeerManager<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: De
        ephemeral_key_midstate: Sha256Engine,
        custom_message_handler: CMH,
 
-       // Usize needs to be at least 32 bits to avoid overflowing both low and high. If usize is 64
-       // bits we will never realistically count into high:
-       peer_counter_low: AtomicUsize,
-       peer_counter_high: AtomicUsize,
+       peer_counter: AtomicCounter,
 
        logger: L,
 }
@@ -485,8 +475,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
                        }),
                        our_node_secret,
                        ephemeral_key_midstate,
-                       peer_counter_low: AtomicUsize::new(0),
-                       peer_counter_high: AtomicUsize::new(0),
+                       peer_counter: AtomicCounter::new(),
                        logger,
                        custom_message_handler,
                }
@@ -509,14 +498,8 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
 
        fn get_ephemeral_key(&self) -> SecretKey {
                let mut ephemeral_hash = self.ephemeral_key_midstate.clone();
-               let low = self.peer_counter_low.fetch_add(1, Ordering::AcqRel);
-               let high = if low == 0 {
-                       self.peer_counter_high.fetch_add(1, Ordering::AcqRel)
-               } else {
-                       self.peer_counter_high.load(Ordering::Acquire)
-               };
-               ephemeral_hash.input(&byte_utils::le64_to_array(low as u64));
-               ephemeral_hash.input(&byte_utils::le64_to_array(high as u64));
+               let counter = self.peer_counter.get_increment();
+               ephemeral_hash.input(&counter.to_le_bytes());
                SecretKey::from_slice(&Sha256::from_engine(ephemeral_hash).into_inner()).expect("You broke SHA-256!")
        }
 
diff --git a/lightning/src/util/atomic_counter.rs b/lightning/src/util/atomic_counter.rs
new file mode 100644 (file)
index 0000000..81cc1f4
--- /dev/null
@@ -0,0 +1,31 @@
+//! A simple atomic counter that uses AtomicUsize to give a u64 counter.
+
+#[cfg(not(any(target_pointer_width = "32", target_pointer_width = "64")))]
+compile_error!("We need at least 32-bit pointers for atomic counter (and to have enough memory to run LDK)");
+
+use core::sync::atomic::{AtomicUsize, Ordering};
+
+pub(crate) struct AtomicCounter {
+       // Usize needs to be at least 32 bits to avoid overflowing both low and high. If usize is 64
+       // bits we will never realistically count into high:
+       counter_low: AtomicUsize,
+       counter_high: AtomicUsize,
+}
+
+impl AtomicCounter {
+       pub(crate) fn new() -> Self {
+               Self {
+                       counter_low: AtomicUsize::new(0),
+                       counter_high: AtomicUsize::new(0),
+               }
+       }
+       pub(crate) fn get_increment(&self) -> u64 {
+               let low = self.counter_low.fetch_add(1, Ordering::AcqRel) as u64;
+               let high = if low == 0 {
+                       self.counter_high.fetch_add(1, Ordering::AcqRel) as u64
+               } else {
+                       self.counter_high.load(Ordering::Acquire) as u64
+               };
+               (high << 32) | low
+       }
+}
index 0c6530f29e0eb1cb8b937f9842b9c83eb3d504be..1ab6384e3b8ebe997b3e4c511ddcaf721946e935 100644 (file)
@@ -70,20 +70,6 @@ pub fn be64_to_array(u: u64) -> [u8; 8] {
        v
 }
 
-#[inline]
-pub fn le64_to_array(u: u64) -> [u8; 8] {
-       let mut v = [0; 8];
-       v[0] = ((u >> 8*0) & 0xff) as u8;
-       v[1] = ((u >> 8*1) & 0xff) as u8;
-       v[2] = ((u >> 8*2) & 0xff) as u8;
-       v[3] = ((u >> 8*3) & 0xff) as u8;
-       v[4] = ((u >> 8*4) & 0xff) as u8;
-       v[5] = ((u >> 8*5) & 0xff) as u8;
-       v[6] = ((u >> 8*6) & 0xff) as u8;
-       v[7] = ((u >> 8*7) & 0xff) as u8;
-       v
-}
-
 #[cfg(test)]
 mod tests {
        use super::*;
@@ -96,6 +82,5 @@ mod tests {
                assert_eq!(be32_to_array(0xdeadbeef), [0xde, 0xad, 0xbe, 0xef]);
                assert_eq!(be48_to_array(0xdeadbeef1bad), [0xde, 0xad, 0xbe, 0xef, 0x1b, 0xad]);
                assert_eq!(be64_to_array(0xdeadbeef1bad1dea), [0xde, 0xad, 0xbe, 0xef, 0x1b, 0xad, 0x1d, 0xea]);
-               assert_eq!(le64_to_array(0xdeadbeef1bad1dea), [0xea, 0x1d, 0xad, 0x1b, 0xef, 0xbe, 0xad, 0xde]);
        }
 }
index 3908116cccc85f03624c73cdb73d2ead507f0712..fdd51e757b5bd12f1b973db7737efb802c3d2284 100644 (file)
@@ -16,8 +16,6 @@ mod real_chachapoly {
        use util::poly1305::Poly1305;
        use bitcoin::hashes::cmp::fixed_time_eq;
 
-       use util::byte_utils;
-
        #[derive(Clone, Copy)]
        pub struct ChaCha20Poly1305RFC {
                cipher: ChaCha20,
@@ -67,8 +65,8 @@ mod real_chachapoly {
                        self.mac.input(output);
                        ChaCha20Poly1305RFC::pad_mac_16(&mut self.mac, self.data_len);
                        self.finished = true;
-                       self.mac.input(&byte_utils::le64_to_array(self.aad_len));
-                       self.mac.input(&byte_utils::le64_to_array(self.data_len as u64));
+                       self.mac.input(&self.aad_len.to_le_bytes());
+                       self.mac.input(&(self.data_len as u64).to_le_bytes());
                        self.mac.raw_result(out_tag);
                }
 
@@ -82,8 +80,8 @@ mod real_chachapoly {
 
                        self.data_len += input.len();
                        ChaCha20Poly1305RFC::pad_mac_16(&mut self.mac, self.data_len);
-                       self.mac.input(&byte_utils::le64_to_array(self.aad_len));
-                       self.mac.input(&byte_utils::le64_to_array(self.data_len as u64));
+                       self.mac.input(&self.aad_len.to_le_bytes());
+                       self.mac.input(&(self.data_len as u64).to_le_bytes());
 
                        let mut calc_tag =  [0u8; 16];
                        self.mac.raw_result(&mut calc_tag);
index cc0c3192a859fc79de02fe4bef36e8cfd2c2a2e0..34e66190121da07dd050f5d4451bae58115888c8 100644 (file)
@@ -20,6 +20,7 @@ pub mod errors;
 pub mod ser;
 pub mod message_signing;
 
+pub(crate) mod atomic_counter;
 pub(crate) mod byte_utils;
 pub(crate) mod chacha20;
 #[cfg(feature = "fuzztarget")]