From d5fd8d96e107ed2f60cd79a48465f948258ae243 Mon Sep 17 00:00:00 2001 From: bmancini55 Date: Tue, 9 Mar 2021 15:34:52 -0500 Subject: [PATCH] Improve short_channel_id utils Modifies scid_from_parts to use u64 inputs allowing untruncated validation. Adds public constants for limits. --- lightning/src/routing/network_graph.rs | 6 +++--- lightning/src/util/scid_utils.rs | 30 ++++++++++++++++++++------ 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/lightning/src/routing/network_graph.rs b/lightning/src/routing/network_graph.rs index 34859cf07..b8655fb0a 100644 --- a/lightning/src/routing/network_graph.rs +++ b/lightning/src/routing/network_graph.rs @@ -30,7 +30,7 @@ use ln::msgs; use util::ser::{Writeable, Readable, Writer}; use util::logger::Logger; use util::events::{MessageSendEvent, MessageSendEventsProvider}; -use util::scid_utils::{block_from_scid, scid_from_parts}; +use util::scid_utils::{block_from_scid, scid_from_parts, MAX_SCID_BLOCK}; use std::{cmp, fmt}; use std::sync::{RwLock, RwLockReadGuard}; @@ -329,11 +329,11 @@ impl RoutingMessageHandler for N let network_graph = self.network_graph.read().unwrap(); - let start_scid = scid_from_parts(msg.first_blocknum, 0, 0); + let start_scid = scid_from_parts(msg.first_blocknum as u64, 0, 0); // We receive valid queries with end_blocknum that would overflow SCID conversion. // Manually cap the ending block to avoid this overflow. - let exclusive_end_scid = scid_from_parts(cmp::min(msg.end_blocknum(), 0xffffff), 0, 0); + let exclusive_end_scid = scid_from_parts(cmp::min(msg.end_blocknum() as u64, MAX_SCID_BLOCK), 0, 0); // Per spec, we must reply to a query. Send an empty message when things are invalid. if msg.chain_hash != network_graph.genesis_hash || start_scid.is_err() || exclusive_end_scid.is_err() { diff --git a/lightning/src/util/scid_utils.rs b/lightning/src/util/scid_utils.rs index 2b9e6fed2..7902a5271 100644 --- a/lightning/src/util/scid_utils.rs +++ b/lightning/src/util/scid_utils.rs @@ -7,32 +7,47 @@ // You may not use this file except in accordance with one or both of these // licenses. +/// Maximum block height that can be used in a `short_channel_id`. This +/// value is based on the 3-bytes available for block height. +pub const MAX_SCID_BLOCK: u64 = 0x00ffffff; + +/// Maximum transaction index that can be used in a `short_channel_id`. +/// This value is based on the 3-bytes available for tx index. +pub const MAX_SCID_TX_INDEX: u64 = 0x00ffffff; + +/// Maximum vout index that can be used in a `short_channel_id`. This +/// value is based on the 2-bytes available for the vout index. +pub const MAX_SCID_VOUT_INDEX: u64 = 0xffff; + /// A `short_channel_id` construction error #[derive(Debug, PartialEq)] pub enum ShortChannelIdError { BlockOverflow, TxIndexOverflow, + VoutIndexOverflow, } /// Extracts the block height (most significant 3-bytes) from the `short_channel_id` -#[allow(dead_code)] pub fn block_from_scid(short_channel_id: &u64) -> u32 { return (short_channel_id >> 40) as u32; } /// Constructs a `short_channel_id` using the components pieces. Results in an error -/// if the block height or tx index overflow the 3-bytes for each component. -#[allow(dead_code)] -pub fn scid_from_parts(block: u32, tx_index: u32, vout_index: u16) -> Result { - if block > 0x00ffffff { +/// if the block height, tx index, or vout index overflow the maximum sizes. +pub fn scid_from_parts(block: u64, tx_index: u64, vout_index: u64) -> Result { + if block > MAX_SCID_BLOCK { return Err(ShortChannelIdError::BlockOverflow); } - if tx_index > 0x00ffffff { + if tx_index > MAX_SCID_TX_INDEX { return Err(ShortChannelIdError::TxIndexOverflow); } - Ok(((block as u64) << 40) | ((tx_index as u64) << 16) | (vout_index as u64)) + if vout_index > MAX_SCID_VOUT_INDEX { + return Err(ShortChannelIdError::VoutIndexOverflow); + } + + Ok((block << 40) | (tx_index << 16) | vout_index) } #[cfg(test)] @@ -56,5 +71,6 @@ mod tests { assert_eq!(scid_from_parts(0x00ffffff, 0x00ffffff, 0xffff).unwrap(), 0xffffff_ffffff_ffff); assert_eq!(scid_from_parts(0x01ffffff, 0x00000000, 0x0000).err().unwrap(), ShortChannelIdError::BlockOverflow); assert_eq!(scid_from_parts(0x00000000, 0x01ffffff, 0x0000).err().unwrap(), ShortChannelIdError::TxIndexOverflow); + assert_eq!(scid_from_parts(0x00000000, 0x00000000, 0x010000).err().unwrap(), ShortChannelIdError::VoutIndexOverflow); } } -- 2.39.5