]> git.bitcoin.ninja Git - rust-lightning/commitdiff
Avoid unnecessarily alloc'ing a new buffer when decrypting messages
authorMatt Corallo <git@bluematt.me>
Mon, 6 Nov 2023 16:57:13 +0000 (16:57 +0000)
committerMatt Corallo <git@bluematt.me>
Tue, 7 Nov 2023 18:13:23 +0000 (18:13 +0000)
When decrypting P2P messages, we already have a read buffer that we
read the message into. There's no reason to allocate a new `Vec` to
store the decrypted message when we can just overwrite the read
buffer and call it a day.

fuzz/src/peer_crypt.rs
lightning/src/ln/peer_channel_encryptor.rs
lightning/src/ln/peer_handler.rs

index f6df392fcef5407623b6c0f0cfcf9b8690cdbc75..4f96849871bfd4de842c2c58b177296cbbaf4c43 100644 (file)
@@ -74,6 +74,7 @@ pub fn do_test(data: &[u8]) {
                assert!(crypter.is_ready_for_encryption());
                crypter
        };
+       let mut buf = [0; 65536 + 16];
        loop {
                if get_slice!(1)[0] == 0 {
                        crypter.encrypt_buffer(get_slice!(slice_to_be16(get_slice!(2))));
@@ -82,7 +83,8 @@ pub fn do_test(data: &[u8]) {
                                Ok(len) => len,
                                Err(_) => return,
                        };
-                       match crypter.decrypt_message(get_slice!(len as usize + 16)) {
+                       buf.copy_from_slice(&get_slice!(len as usize + 16));
+                       match crypter.decrypt_message(&mut buf[..len as usize + 16]) {
                                Ok(_) => {},
                                Err(_) => return,
                        }
index a34b31a1bb31d77364e3e67e8adfcb86039495fe..8b276990cb689846680fa7de24070a4a433325de 100644 (file)
@@ -169,6 +169,18 @@ impl PeerChannelEncryptor {
                res.extend_from_slice(&tag);
        }
 
+       fn decrypt_in_place_with_ad(inout: &mut [u8], n: u64, key: &[u8; 32], h: &[u8]) -> Result<(), LightningError> {
+               let mut nonce = [0; 12];
+               nonce[4..].copy_from_slice(&n.to_le_bytes()[..]);
+
+               let mut chacha = ChaCha20Poly1305RFC::new(key, &nonce, h);
+               let (inout, tag) = inout.split_at_mut(inout.len() - 16);
+               if chacha.check_decrypt_in_place(inout, tag).is_err() {
+                       return Err(LightningError{err: "Bad MAC".to_owned(), action: msgs::ErrorAction::DisconnectPeer{ msg: None }});
+               }
+               Ok(())
+       }
+
        #[inline]
        fn decrypt_with_ad(res: &mut[u8], n: u64, key: &[u8; 32], h: &[u8], cyphertext: &[u8]) -> Result<(), LightningError> {
                let mut nonce = [0; 12];
@@ -505,21 +517,20 @@ impl PeerChannelEncryptor {
                }
        }
 
-       /// Decrypts the given message.
+       /// Decrypts the given message up to msg.len() - 16. Bytes after msg.len() - 16 will be left
+       /// undefined (as they contain the Poly1305 tag bytes).
+       ///
        /// panics if msg.len() > 65535 + 16
-       pub fn decrypt_message(&mut self, msg: &[u8]) -> Result<Vec<u8>, LightningError> {
+       pub fn decrypt_message(&mut self, msg: &mut [u8]) -> Result<(), LightningError> {
                if msg.len() > LN_MAX_MSG_LEN + 16 {
                        panic!("Attempted to decrypt message longer than 65535 + 16 bytes!");
                }
 
                match self.noise_state {
                        NoiseState::Finished { sk: _, sn: _, sck: _, ref rk, ref mut rn, rck: _ } => {
-                               let mut res = Vec::with_capacity(msg.len() - 16);
-                               res.resize(msg.len() - 16, 0);
-                               Self::decrypt_with_ad(&mut res[..], *rn, rk, &[0; 0], msg)?;
+                               Self::decrypt_in_place_with_ad(&mut msg[..], *rn, rk, &[0; 0])?;
                                *rn += 1;
-
-                               Ok(res)
+                               Ok(())
                        },
                        _ => panic!("Tried to decrypt a message prior to noise handshake completion"),
                }
@@ -764,12 +775,11 @@ mod tests {
 
                for i in 0..1005 {
                        let msg = [0x68, 0x65, 0x6c, 0x6c, 0x6f];
-                       let res = outbound_peer.encrypt_buffer(&msg);
+                       let mut res = outbound_peer.encrypt_buffer(&msg);
                        assert_eq!(res.len(), 5 + 2*16 + 2);
 
                        let len_header = res[0..2+16].to_vec();
                        assert_eq!(inbound_peer.decrypt_length_header(&len_header[..]).unwrap() as usize, msg.len());
-                       assert_eq!(inbound_peer.decrypt_message(&res[2+16..]).unwrap()[..], msg[..]);
 
                        if i == 0 {
                                assert_eq!(res, hex::decode("cf2b30ddf0cf3f80e7c35a6e6730b59fe802473180f396d88a8fb0db8cbcf25d2f214cf9ea1d95").unwrap());
@@ -784,6 +794,9 @@ mod tests {
                        } else if i == 1001 {
                                assert_eq!(res, hex::decode("2ecd8c8a5629d0d02ab457a0fdd0f7b90a192cd46be5ecb6ca570bfc5e268338b1a16cf4ef2d36").unwrap());
                        }
+
+                       inbound_peer.decrypt_message(&mut res[2+16..]).unwrap();
+                       assert_eq!(res[2 + 16..res.len() - 16], msg[..]);
                }
        }
 
@@ -807,7 +820,7 @@ mod tests {
                let mut inbound_peer = get_inbound_peer_for_test_vectors();
 
                // MSG should not exceed LN_MAX_MSG_LEN + 16
-               let msg = [4u8; LN_MAX_MSG_LEN + 17];
-               inbound_peer.decrypt_message(&msg).unwrap();
+               let mut msg = [4u8; LN_MAX_MSG_LEN + 17];
+               inbound_peer.decrypt_message(&mut msg).unwrap();
        }
 }
index 5f0d88a95273734d6242b3485e42e9575e270a98..a1a4d4b26729c6822227b7d58f576ed0e1725ad1 100644 (file)
@@ -1402,17 +1402,18 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                                                        }
                                                                        peer.pending_read_is_header = false;
                                                                } else {
-                                                                       let msg_data = try_potential_handleerror!(peer,
-                                                                               peer.channel_encryptor.decrypt_message(&peer.pending_read_buffer[..]));
-                                                                       assert!(msg_data.len() >= 2);
+                                                                       debug_assert!(peer.pending_read_buffer.len() >= 2 + 16);
+                                                                       try_potential_handleerror!(peer,
+                                                                               peer.channel_encryptor.decrypt_message(&mut peer.pending_read_buffer[..]));
+
+                                                                       let mut reader = io::Cursor::new(&peer.pending_read_buffer[..peer.pending_read_buffer.len() - 16]);
+                                                                       let message_result = wire::read(&mut reader, &*self.message_handler.custom_message_handler);
 
                                                                        // Reset read buffer
                                                                        if peer.pending_read_buffer.capacity() > 8192 { peer.pending_read_buffer = Vec::new(); }
                                                                        peer.pending_read_buffer.resize(18, 0);
                                                                        peer.pending_read_is_header = true;
 
-                                                                       let mut reader = io::Cursor::new(&msg_data[..]);
-                                                                       let message_result = wire::read(&mut reader, &*self.message_handler.custom_message_handler);
                                                                        let message = match message_result {
                                                                                Ok(x) => x,
                                                                                Err(e) => {