X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=lightning%2Fsrc%2Frouting%2Fnetwork_graph.rs;h=dd6d7d7ed734aadbd7f89f93308aefd0223a9f79;hb=0e3bf19b664ad3f9dd252c9ed50ed365ff827524;hp=35a7e868f04744304e01fa62f5bffc035b242b75;hpb=e553d2c2c0280bc4e82b4685679e79307146b129;p=rust-lightning diff --git a/lightning/src/routing/network_graph.rs b/lightning/src/routing/network_graph.rs index 35a7e868..dd6d7d7e 100644 --- a/lightning/src/routing/network_graph.rs +++ b/lightning/src/routing/network_graph.rs @@ -13,7 +13,7 @@ use chain::chaininterface::{ChainError, ChainWatchInterface}; use ln::features::{ChannelFeatures, NodeFeatures}; use ln::msgs::{DecodeError,ErrorAction,LightningError,RoutingMessageHandler,NetAddress}; use ln::msgs; -use util::ser::{Writeable, Readable, Writer, ReadableArgs}; +use util::ser::{Writeable, Readable, Writer}; use util::logger::Logger; use std::cmp; @@ -23,7 +23,11 @@ use std::collections::BTreeMap; use std::collections::btree_map::Entry as BtreeEntry; use std; -/// Receives network updates from peers to track view of the network. +/// Receives and validates network updates from peers, +/// stores authentic and relevant data as a network graph. +/// This network graph is then used for routing payments. +/// Provides interface to help with initial routing sync by +/// serving historical announcements. pub struct NetGraphMsgHandler { secp_ctx: Secp256k1, /// Representation of the payment channel network @@ -34,7 +38,11 @@ pub struct NetGraphMsgHandler { } impl NetGraphMsgHandler { - /// Creates a new tracker of the actual state of the network of channels and nodes. + /// Creates a new tracker of the actual state of the network of channels and nodes, + /// assuming a fresh network graph. + /// Chain monitor is used to make sure announced channels exist on-chain, + /// channel data is correct, and that the announcement is signed with + /// channel owners' keys. pub fn new(chain_monitor: Arc, logger: Arc) -> Self { NetGraphMsgHandler { secp_ctx: Secp256k1::verification_only(), @@ -48,16 +56,16 @@ impl NetGraphMsgHandler { } } - /// Get network addresses by node id - pub fn get_addresses(&self, pubkey: &PublicKey) -> Option> { - let network = self.network_graph.read().unwrap(); - network.get_nodes().get(pubkey).map(|n| n.addresses.clone()) - } - - /// Dumps the entire network view of this NetGraphMsgHandler to the logger provided in the constructor at - /// level Trace - pub fn trace_state(&self) { - log_trace!(self, "{}", self.network_graph.read().unwrap()); + /// Creates a new tracker of the actual state of the network of channels and nodes, + /// assuming an existing Network Graph. + pub fn from_net_graph(chain_monitor: Arc, logger: Arc, network_graph: RwLock) -> Self { + NetGraphMsgHandler { + secp_ctx: Secp256k1::verification_only(), + network_graph: network_graph, + full_syncs_requested: AtomicUsize::new(0), + chain_monitor, + logger: logger.clone(), + } } } @@ -73,7 +81,7 @@ macro_rules! secp_verify_sig { impl RoutingMessageHandler for NetGraphMsgHandler { fn handle_node_announcement(&self, msg: &msgs::NodeAnnouncement) -> Result { - self.network_graph.write().unwrap().process_node_announcement(msg, Some(&self.secp_ctx)) + self.network_graph.write().unwrap().update_node_from_announcement(msg, Some(&self.secp_ctx)) } fn handle_channel_announcement(&self, msg: &msgs::ChannelAnnouncement) -> Result { @@ -106,7 +114,7 @@ impl RoutingMessageHandler for NetGraphMsgHandler { return Err(LightningError{err: "Channel announced without corresponding UTXO entry", action: ErrorAction::IgnoreError}); }, }; - let result = self.network_graph.write().unwrap().process_channel_announcement(msg, checked_utxo, Some(&self.secp_ctx)); + let result = self.network_graph.write().unwrap().update_channel_from_announcement(msg, checked_utxo, Some(&self.secp_ctx)); log_trace!(self, "Added channel_announcement for {}{}", msg.contents.short_channel_id, if !msg.contents.excess_data.is_empty() { " with excess uninterpreted data!" } else { "" }); result } @@ -114,19 +122,19 @@ impl RoutingMessageHandler for NetGraphMsgHandler { fn handle_htlc_fail_channel_update(&self, update: &msgs::HTLCFailChannelUpdate) { match update { &msgs::HTLCFailChannelUpdate::ChannelUpdateMessage { ref msg } => { - let _ = self.network_graph.write().unwrap().process_channel_update(msg, Some(&self.secp_ctx)); + let _ = self.network_graph.write().unwrap().update_channel(msg, Some(&self.secp_ctx)); }, &msgs::HTLCFailChannelUpdate::ChannelClosed { ref short_channel_id, ref is_permanent } => { - self.network_graph.write().unwrap().process_channel_closing(short_channel_id, &is_permanent); + self.network_graph.write().unwrap().close_channel_from_update(short_channel_id, &is_permanent); }, &msgs::HTLCFailChannelUpdate::NodeFailure { ref node_id, ref is_permanent } => { - self.network_graph.write().unwrap().process_node_failure(node_id, &is_permanent); + self.network_graph.write().unwrap().fail_node(node_id, &is_permanent); }, } } fn handle_channel_update(&self, msg: &msgs::ChannelUpdate) -> Result { - self.network_graph.write().unwrap().process_channel_update(msg, Some(&self.secp_ctx)) + self.network_graph.write().unwrap().update_channel(msg, Some(&self.secp_ctx)) } fn get_next_channel_announcements(&self, starting_point: u64, batch_amount: u8) -> Vec<(msgs::ChannelAnnouncement, Option, Option)> { @@ -136,9 +144,16 @@ impl RoutingMessageHandler for NetGraphMsgHandler { while result.len() < batch_amount as usize { if let Some((_, ref chan)) = iter.next() { if chan.announcement_message.is_some() { - result.push((chan.announcement_message.clone().unwrap(), - chan.one_to_two.last_update_message.clone(), - chan.two_to_one.last_update_message.clone())); + let chan_announcement = chan.announcement_message.clone().unwrap(); + let mut one_to_two_announcement: Option = None; + let mut two_to_one_announcement: Option = None; + if let Some(one_to_two) = chan.one_to_two.as_ref() { + one_to_two_announcement = one_to_two.last_update_message.clone(); + } + if let Some(two_to_one) = chan.two_to_one.as_ref() { + two_to_one_announcement = two_to_one.last_update_message.clone(); + } + result.push((chan_announcement, one_to_two_announcement, two_to_one_announcement)); } else { // TODO: We may end up sending un-announced channel_updates if we are sending // initial sync data while receiving announce/updates for this channel. @@ -162,8 +177,10 @@ impl RoutingMessageHandler for NetGraphMsgHandler { }; while result.len() < batch_amount as usize { if let Some((_, ref node)) = iter.next() { - if node.announcement_message.is_some() { - result.push(node.announcement_message.clone().unwrap()); + if let Some(node_info) = node.announcement_info.as_ref() { + if node_info.announcement_message.is_some() { + result.push(node_info.announcement_message.clone().unwrap()); + } } } else { return result; @@ -184,63 +201,16 @@ impl RoutingMessageHandler for NetGraphMsgHandler { } } - -const SERIALIZATION_VERSION: u8 = 1; -const MIN_SERIALIZATION_VERSION: u8 = 1; - -impl Writeable for NetGraphMsgHandler { - fn write(&self, writer: &mut W) -> Result<(), ::std::io::Error> { - writer.write_all(&[SERIALIZATION_VERSION; 1])?; - writer.write_all(&[MIN_SERIALIZATION_VERSION; 1])?; - - let network = self.network_graph.read().unwrap(); - network.write(writer)?; - Ok(()) - } -} - -/// Arguments for the creation of a NetGraphMsgHandler that are not deserialized. -/// At a high-level, the process for deserializing a NetGraphMsgHandler and resuming normal operation is: -/// 1) Deserialize the NetGraphMsgHandler by filling in this struct and calling ::read(reaser, args). -/// 2) Register the new NetGraphMsgHandler with your ChainWatchInterface -pub struct NetGraphMsgHandlerReadArgs { - /// The ChainWatchInterface for use in the NetGraphMsgHandler in the future. - /// - /// No calls to the ChainWatchInterface will be made during deserialization. - pub chain_monitor: Arc, - /// The Logger for use in the ChannelManager and which may be used to log information during - /// deserialization. - pub logger: Arc, -} - -impl ReadableArgs for NetGraphMsgHandler { - fn read(reader: &mut R, args: NetGraphMsgHandlerReadArgs) -> Result { - let _ver: u8 = Readable::read(reader)?; - let min_ver: u8 = Readable::read(reader)?; - if min_ver > SERIALIZATION_VERSION { - return Err(DecodeError::UnknownVersion); - } - let network_graph = Readable::read(reader)?; - Ok(NetGraphMsgHandler { - secp_ctx: Secp256k1::verification_only(), - network_graph: RwLock::new(network_graph), - chain_monitor: args.chain_monitor, - full_syncs_requested: AtomicUsize::new(0), - logger: args.logger.clone(), - }) - } -} - -#[derive(PartialEq)] -/// Details regarding one direction of a channel +#[derive(PartialEq, Debug)] +/// Details about one direction of a channel. Received +/// within a channel update. pub struct DirectionalChannelInfo { - /// A node from which the channel direction starts - pub src_node_id: PublicKey, - /// When the last update to the channel direction was issued + /// When the last update to the channel direction was issued. + /// Value is opaque, as set in the announcement. pub last_update: u32, - /// Whether the channel can be currently used for payments + /// Whether the channel can be currently used for payments (in this one direction). pub enabled: bool, - /// The difference in CLTV values between the source and the destination node of the channel + /// The difference in CLTV values that you must have when routing through this channel. pub cltv_expiry_delta: u16, /// The minimum value, which must be relayed to the next hop via the channel pub htlc_minimum_msat: u64, @@ -252,13 +222,12 @@ pub struct DirectionalChannelInfo { impl std::fmt::Display for DirectionalChannelInfo { fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> { - write!(f, "src_node_id {}, last_update {}, enabled {}, cltv_expiry_delta {}, htlc_minimum_msat {}, fees {:?}", log_pubkey!(self.src_node_id), self.last_update, self.enabled, self.cltv_expiry_delta, self.htlc_minimum_msat, self.fees)?; + write!(f, "last_update {}, enabled {}, cltv_expiry_delta {}, htlc_minimum_msat {}, fees {:?}", self.last_update, self.enabled, self.cltv_expiry_delta, self.htlc_minimum_msat, self.fees)?; Ok(()) } } impl_writeable!(DirectionalChannelInfo, 0, { - src_node_id, last_update, enabled, cltv_expiry_delta, @@ -268,30 +237,39 @@ impl_writeable!(DirectionalChannelInfo, 0, { }); #[derive(PartialEq)] -/// Details regarding a channel (both directions) +/// Details about a channel (both directions). +/// Received within a channel announcement. pub struct ChannelInfo { /// Protocol features of a channel communicated during its announcement pub features: ChannelFeatures, - /// Details regarding one of the directions of a channel - pub one_to_two: DirectionalChannelInfo, - /// Details regarding another direction of a channel - pub two_to_one: DirectionalChannelInfo, + /// Source node of the first direction of a channel + pub node_one: PublicKey, + /// Details about the first direction of a channel + pub one_to_two: Option, + /// Source node of the second direction of a channel + pub node_two: PublicKey, + /// Details about the second direction of a channel + pub two_to_one: Option, /// An initial announcement of the channel - //this is cached here so we can send out it later if required by initial routing sync - //keep an eye on this to see if the extra memory is a problem + /// Mostly redundant with the data we store in fields explicitly. + /// Everything else is useful only for sending out for initial routing sync. + /// Not stored if contains excess data to prevent DoS. pub announcement_message: Option, } impl std::fmt::Display for ChannelInfo { fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> { - write!(f, "features: {}, one_to_two: {}, two_to_one: {}", log_bytes!(self.features.encode()), self.one_to_two, self.two_to_one)?; + write!(f, "features: {}, node_one: {}, one_to_two: {:?}, node_two: {}, two_to_one: {:?}", + log_bytes!(self.features.encode()), log_pubkey!(self.node_one), self.one_to_two, log_pubkey!(self.node_two), self.two_to_one)?; Ok(()) } } impl_writeable!(ChannelInfo, 0, { features, + node_one, one_to_two, + node_two, two_to_one, announcement_message }); @@ -300,9 +278,10 @@ impl_writeable!(ChannelInfo, 0, { /// Fees for routing via a given channel or a node #[derive(Eq, PartialEq, Copy, Clone, Debug)] pub struct RoutingFees { - /// Flat routing fee + /// Flat routing fee in satoshis pub base_msat: u32, - /// Liquidity-based routing fee + /// Liquidity-based routing fee in millionths of a routed amount. + /// In other words, 10000 is 1%. pub proportional_millionths: u32, } @@ -325,47 +304,31 @@ impl Writeable for RoutingFees { } } - -#[derive(PartialEq)] -/// Details regarding a node in the network -pub struct NodeInfo { - /// All valid channels a node has announced - pub channels: Vec, - /// Lowest fees enabling routing via any of the known channels to a node - pub lowest_inbound_channel_fees: Option, +#[derive(PartialEq, Debug)] +/// Information received in the latest node_announcement from this node. +pub struct NodeAnnouncementInfo { /// Protocol features the node announced support for - pub features: NodeFeatures, - /// When the last known update to the node state was issued - /// Unlike for channels, we may have a NodeInfo entry before having received a node_update. - /// Thus, we have to be able to capture "no update has been received", which we do with an - /// Option here. - pub last_update: Option, + pub features: NodeFeatures, + /// When the last known update to the node state was issued. + /// Value is opaque, as set in the announcement. + pub last_update: u32, /// Color assigned to the node pub rgb: [u8; 3], - /// Moniker assigned to the node + /// Moniker assigned to the node. + /// May be invalid or malicious (eg control chars), + /// should not be exposed to the user. pub alias: [u8; 32], /// Internet-level addresses via which one can connect to the node pub addresses: Vec, /// An initial announcement of the node - //this is cached here so we can send out it later if required by initial routing sync - //keep an eye on this to see if the extra memory is a problem - pub announcement_message: Option, -} - -impl std::fmt::Display for NodeInfo { - fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> { - write!(f, "features: {}, last_update: {:?}, lowest_inbound_channel_fees: {:?}, channels: {:?}", log_bytes!(self.features.encode()), self.last_update, self.lowest_inbound_channel_fees, &self.channels[..])?; - Ok(()) - } + /// Mostly redundant with the data we store in fields explicitly. + /// Everything else is useful only for sending out for initial routing sync. + /// Not stored if contains excess data to prevent DoS. + pub announcement_message: Option } -impl Writeable for NodeInfo { +impl Writeable for NodeAnnouncementInfo { fn write(&self, writer: &mut W) -> Result<(), ::std::io::Error> { - (self.channels.len() as u64).write(writer)?; - for ref chan in self.channels.iter() { - chan.write(writer)?; - } - self.lowest_inbound_channel_fees.write(writer)?; self.features.write(writer)?; self.last_update.write(writer)?; self.rgb.write(writer)?; @@ -379,16 +342,8 @@ impl Writeable for NodeInfo { } } -const MAX_ALLOC_SIZE: u64 = 64*1024; - -impl Readable for NodeInfo { - fn read(reader: &mut R) -> Result { - let channels_count: u64 = Readable::read(reader)?; - let mut channels = Vec::with_capacity(cmp::min(channels_count, MAX_ALLOC_SIZE / 8) as usize); - for _ in 0..channels_count { - channels.push(Readable::read(reader)?); - } - let lowest_inbound_channel_fees = Readable::read(reader)?; +impl Readable for NodeAnnouncementInfo { + fn read(reader: &mut R) -> Result { let features = Readable::read(reader)?; let last_update = Readable::read(reader)?; let rgb = Readable::read(reader)?; @@ -404,9 +359,7 @@ impl Readable for NodeInfo { } } let announcement_message = Readable::read(reader)?; - Ok(NodeInfo { - channels, - lowest_inbound_channel_fees, + Ok(NodeAnnouncementInfo { features, last_update, rgb, @@ -417,6 +370,60 @@ impl Readable for NodeInfo { } } +#[derive(PartialEq)] +/// Details about a node in the network, known from the network announcement. +pub struct NodeInfo { + /// All valid channels a node has announced + pub channels: Vec, + /// Lowest fees enabling routing via any of the known channels to a node. + /// The two fields (flat and proportional fee) are independent, + /// meaning they don't have to refer to the same channel. + pub lowest_inbound_channel_fees: Option, + /// More information about a node from node_announcement. + /// Optional because we store a Node entry after learning about it from + /// a channel announcement, but before receiving a node announcement. + pub announcement_info: Option +} + +impl std::fmt::Display for NodeInfo { + fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> { + write!(f, "lowest_inbound_channel_fees: {:?}, channels: {:?}, announcement_info: {:?}", + self.lowest_inbound_channel_fees, &self.channels[..], self.announcement_info)?; + Ok(()) + } +} + +impl Writeable for NodeInfo { + fn write(&self, writer: &mut W) -> Result<(), ::std::io::Error> { + (self.channels.len() as u64).write(writer)?; + for ref chan in self.channels.iter() { + chan.write(writer)?; + } + self.lowest_inbound_channel_fees.write(writer)?; + self.announcement_info.write(writer)?; + Ok(()) + } +} + +const MAX_ALLOC_SIZE: u64 = 64*1024; + +impl Readable for NodeInfo { + fn read(reader: &mut R) -> Result { + let channels_count: u64 = Readable::read(reader)?; + let mut channels = Vec::with_capacity(cmp::min(channels_count, MAX_ALLOC_SIZE / 8) as usize); + for _ in 0..channels_count { + channels.push(Readable::read(reader)?); + } + let lowest_inbound_channel_fees = Readable::read(reader)?; + let announcement_info = Readable::read(reader)?; + Ok(NodeInfo { + channels, + lowest_inbound_channel_fees, + announcement_info, + }) + } +} + /// Represents the network as nodes and channels between them #[derive(PartialEq)] pub struct NetworkGraph { @@ -478,12 +485,26 @@ impl std::fmt::Display for NetworkGraph { } impl NetworkGraph { - /// Returns all known valid channels + /// Returns all known valid channels' short ids along with announced channel info. pub fn get_channels<'a>(&'a self) -> &'a BTreeMap { &self.channels } - /// Returns all known nodes + /// Returns all known nodes' public keys along with announced node info. pub fn get_nodes<'a>(&'a self) -> &'a BTreeMap { &self.nodes } - fn process_node_announcement(&mut self, msg: &msgs::NodeAnnouncement, secp_ctx: Option<&Secp256k1>) -> Result { + /// Get network addresses by node id. + /// Returns None if the requested node is completely unknown, + /// or if node announcement for the node was never received. + pub fn get_addresses<'a>(&'a self, pubkey: &PublicKey) -> Option<&'a Vec> { + if let Some(node) = self.nodes.get(pubkey) { + if let Some(node_info) = node.announcement_info.as_ref() { + return Some(&node_info.addresses) + } + } + None + } + + /// For an already known node (from channel announcements), update its stored properties from a given node announcement + /// Announcement signatures are checked here only if Secp256k1 object is provided. + fn update_node_from_announcement(&mut self, msg: &msgs::NodeAnnouncement, secp_ctx: Option<&Secp256k1>) -> Result { if let Some(sig_verifier) = secp_ctx { let msg_hash = hash_to_message!(&Sha256dHash::hash(&msg.contents.encode()[..])[..]); secp_verify_sig!(sig_verifier, &msg_hash, &msg.signature, &msg.contents.node_id); @@ -492,27 +513,34 @@ impl NetworkGraph { match self.nodes.get_mut(&msg.contents.node_id) { None => Err(LightningError{err: "No existing channels for node_announcement", action: ErrorAction::IgnoreError}), Some(node) => { - match node.last_update { - Some(last_update) => if last_update >= msg.contents.timestamp { + if let Some(node_info) = node.announcement_info.as_ref() { + if node_info.last_update >= msg.contents.timestamp { return Err(LightningError{err: "Update older than last processed update", action: ErrorAction::IgnoreError}); - }, - None => {}, + } } - node.features = msg.contents.features.clone(); - node.last_update = Some(msg.contents.timestamp); - node.rgb = msg.contents.rgb; - node.alias = msg.contents.alias; - node.addresses = msg.contents.addresses.clone(); - let should_relay = msg.contents.excess_data.is_empty() && msg.contents.excess_address_data.is_empty(); - node.announcement_message = if should_relay { Some(msg.clone()) } else { None }; + node.announcement_info = Some(NodeAnnouncementInfo { + features: msg.contents.features.clone(), + last_update: msg.contents.timestamp, + rgb: msg.contents.rgb, + alias: msg.contents.alias, + addresses: msg.contents.addresses.clone(), + announcement_message: if should_relay { Some(msg.clone()) } else { None }, + }); + Ok(should_relay) } } } - fn process_channel_announcement(&mut self, msg: &msgs::ChannelAnnouncement, checked_utxo: bool, secp_ctx: Option<&Secp256k1>) -> Result { + /// For a new or already known (from previous announcement) channel, store or update channel info. + /// Also store nodes (if not stored yet) the channel is between, and make node aware of this channel. + /// Checking utxo on-chain is useful if we receive an update for already known channel id, + /// which is probably result of a reorg. In that case, we update channel info only if the + /// utxo was checked, otherwise stick to the existing update, to prevent DoS risks. + /// Announcement signatures are checked here only if Secp256k1 object is provided. + fn update_channel_from_announcement(&mut self, msg: &msgs::ChannelAnnouncement, checked_utxo: bool, secp_ctx: Option<&Secp256k1>) -> Result { if let Some(sig_verifier) = secp_ctx { let msg_hash = hash_to_message!(&Sha256dHash::hash(&msg.contents.encode()[..])[..]); secp_verify_sig!(sig_verifier, &msg_hash, &msg.node_signature_1, &msg.contents.node_id_1); @@ -525,30 +553,10 @@ impl NetworkGraph { let chan_info = ChannelInfo { features: msg.contents.features.clone(), - one_to_two: DirectionalChannelInfo { - src_node_id: msg.contents.node_id_1.clone(), - last_update: 0, - enabled: false, - cltv_expiry_delta: u16::max_value(), - htlc_minimum_msat: u64::max_value(), - fees: RoutingFees { - base_msat: u32::max_value(), - proportional_millionths: u32::max_value(), - }, - last_update_message: None, - }, - two_to_one: DirectionalChannelInfo { - src_node_id: msg.contents.node_id_2.clone(), - last_update: 0, - enabled: false, - cltv_expiry_delta: u16::max_value(), - htlc_minimum_msat: u64::max_value(), - fees: RoutingFees { - base_msat: u32::max_value(), - proportional_millionths: u32::max_value(), - }, - last_update_message: None, - }, + node_one: msg.contents.node_id_1.clone(), + one_to_two: None, + node_two: msg.contents.node_id_2.clone(), + two_to_one: None, announcement_message: if should_relay { Some(msg.clone()) } else { None }, }; @@ -587,12 +595,7 @@ impl NetworkGraph { node_entry.insert(NodeInfo { channels: vec!(msg.contents.short_channel_id), lowest_inbound_channel_fees: None, - features: NodeFeatures::empty(), - last_update: None, - rgb: [0; 3], - alias: [0; 32], - addresses: Vec::new(), - announcement_message: None, + announcement_info: None, }); } } @@ -605,20 +608,28 @@ impl NetworkGraph { Ok(should_relay) } - fn process_channel_closing(&mut self, short_channel_id: &u64, is_permanent: &bool) { + /// Close a channel if a corresponding HTLC fail was sent. + /// If permanent, removes a channel from the local storage. + /// May cause the removal of nodes too, if this was their last channel. + /// If not permanent, makes channels unavailable for routing. + pub fn close_channel_from_update(&mut self, short_channel_id: &u64, is_permanent: &bool) { if *is_permanent { if let Some(chan) = self.channels.remove(short_channel_id) { Self::remove_channel_in_nodes(&mut self.nodes, &chan, *short_channel_id); } } else { if let Some(chan) = self.channels.get_mut(&short_channel_id) { - chan.one_to_two.enabled = false; - chan.two_to_one.enabled = false; + if let Some(one_to_two) = chan.one_to_two.as_mut() { + one_to_two.enabled = false; + } + if let Some(two_to_one) = chan.two_to_one.as_mut() { + two_to_one.enabled = false; + } } } } - fn process_node_failure(&mut self, _node_id: &PublicKey, is_permanent: &bool) { + fn fail_node(&mut self, _node_id: &PublicKey, is_permanent: &bool) { if *is_permanent { // TODO: Wholly remove the node } else { @@ -626,7 +637,9 @@ impl NetworkGraph { } } - fn process_channel_update(&mut self, msg: &msgs::ChannelUpdate, secp_ctx: Option<&Secp256k1>) -> Result { + /// For an already known (from announcement) channel, update info about one of the directions of a channel. + /// Announcement signatures are checked here only if Secp256k1 object is provided. + fn update_channel(&mut self, msg: &msgs::ChannelUpdate, secp_ctx: Option<&Secp256k1>) -> Result { let dest_node_id; let chan_enabled = msg.contents.flags & (1 << 1) != (1 << 1); let chan_was_enabled; @@ -635,37 +648,50 @@ impl NetworkGraph { None => return Err(LightningError{err: "Couldn't find channel for update", action: ErrorAction::IgnoreError}), Some(channel) => { macro_rules! maybe_update_channel_info { - ( $target: expr) => { - if $target.last_update >= msg.contents.timestamp { - return Err(LightningError{err: "Update older than last processed update", action: ErrorAction::IgnoreError}); + ( $target: expr, $src_node: expr) => { + if let Some(existing_chan_info) = $target.as_ref() { + if existing_chan_info.last_update >= msg.contents.timestamp { + return Err(LightningError{err: "Update older than last processed update", action: ErrorAction::IgnoreError}); + } + chan_was_enabled = existing_chan_info.enabled; + } else { + chan_was_enabled = false; } - chan_was_enabled = $target.enabled; - $target.last_update = msg.contents.timestamp; - $target.enabled = chan_enabled; - $target.cltv_expiry_delta = msg.contents.cltv_expiry_delta; - $target.htlc_minimum_msat = msg.contents.htlc_minimum_msat; - $target.fees.base_msat = msg.contents.fee_base_msat; - $target.fees.proportional_millionths = msg.contents.fee_proportional_millionths; - $target.last_update_message = if msg.contents.excess_data.is_empty() { + + let last_update_message = if msg.contents.excess_data.is_empty() { Some(msg.clone()) } else { None }; + + let updated_channel_dir_info = DirectionalChannelInfo { + enabled: chan_enabled, + last_update: msg.contents.timestamp, + cltv_expiry_delta: msg.contents.cltv_expiry_delta, + htlc_minimum_msat: msg.contents.htlc_minimum_msat, + fees: RoutingFees { + base_msat: msg.contents.fee_base_msat, + proportional_millionths: msg.contents.fee_proportional_millionths, + }, + last_update_message + }; + $target = Some(updated_channel_dir_info); } } + let msg_hash = hash_to_message!(&Sha256dHash::hash(&msg.contents.encode()[..])[..]); if msg.contents.flags & 1 == 1 { - dest_node_id = channel.one_to_two.src_node_id.clone(); + dest_node_id = channel.node_one.clone(); if let Some(sig_verifier) = secp_ctx { - secp_verify_sig!(sig_verifier, &msg_hash, &msg.signature, &channel.two_to_one.src_node_id); + secp_verify_sig!(sig_verifier, &msg_hash, &msg.signature, &channel.node_two); } - maybe_update_channel_info!(channel.two_to_one); + maybe_update_channel_info!(channel.two_to_one, channel.node_two); } else { - dest_node_id = channel.two_to_one.src_node_id.clone(); + dest_node_id = channel.node_two.clone(); if let Some(sig_verifier) = secp_ctx { - secp_verify_sig!(sig_verifier, &msg_hash, &msg.signature, &channel.one_to_two.src_node_id); + secp_verify_sig!(sig_verifier, &msg_hash, &msg.signature, &channel.node_one); } - maybe_update_channel_info!(channel.one_to_two); + maybe_update_channel_info!(channel.one_to_two, channel.node_one); } } } @@ -691,13 +717,15 @@ impl NetworkGraph { for chan_id in node.channels.iter() { let chan = self.channels.get(chan_id).unwrap(); - if chan.one_to_two.src_node_id == dest_node_id { - lowest_inbound_channel_fee_base_msat = cmp::min(lowest_inbound_channel_fee_base_msat, chan.two_to_one.fees.base_msat); - lowest_inbound_channel_fee_proportional_millionths = cmp::min(lowest_inbound_channel_fee_proportional_millionths, chan.two_to_one.fees.proportional_millionths); + // Since direction was enabled, the channel indeed had directional info + let chan_info; + if chan.node_one == dest_node_id { + chan_info = chan.two_to_one.as_ref().unwrap(); } else { - lowest_inbound_channel_fee_base_msat = cmp::min(lowest_inbound_channel_fee_base_msat, chan.one_to_two.fees.base_msat); - lowest_inbound_channel_fee_proportional_millionths = cmp::min(lowest_inbound_channel_fee_proportional_millionths, chan.one_to_two.fees.proportional_millionths); + chan_info = chan.one_to_two.as_ref().unwrap(); } + lowest_inbound_channel_fee_base_msat = cmp::min(lowest_inbound_channel_fee_base_msat, chan_info.fees.base_msat); + lowest_inbound_channel_fee_proportional_millionths = cmp::min(lowest_inbound_channel_fee_proportional_millionths, chan_info.fees.proportional_millionths); } } @@ -729,8 +757,9 @@ impl NetworkGraph { } } } - remove_from_node!(chan.one_to_two.src_node_id); - remove_from_node!(chan.two_to_one.src_node_id); + + remove_from_node!(chan.node_one); + remove_from_node!(chan.node_two); } } @@ -1141,8 +1170,8 @@ mod tests { match network.get_channels().get(&short_channel_id) { None => panic!(), Some(channel_info) => { - assert_eq!(channel_info.one_to_two.cltv_expiry_delta, 144); - assert_eq!(channel_info.two_to_one.cltv_expiry_delta, u16::max_value()); + assert_eq!(channel_info.one_to_two.as_ref().unwrap().cltv_expiry_delta, 144); + assert!(channel_info.two_to_one.is_none()); } } } @@ -1247,6 +1276,38 @@ mod tests { Err(_) => panic!() }; + let unsigned_channel_update = UnsignedChannelUpdate { + chain_hash, + short_channel_id, + timestamp: 100, + flags: 0, + cltv_expiry_delta: 144, + htlc_minimum_msat: 1000000, + fee_base_msat: 10000, + fee_proportional_millionths: 20, + excess_data: Vec::new() + }; + let msghash = hash_to_message!(&Sha256dHash::hash(&unsigned_channel_update.encode()[..])[..]); + let valid_channel_update = ChannelUpdate { + signature: secp_ctx.sign(&msghash, node_1_privkey), + contents: unsigned_channel_update.clone() + }; + + match net_graph_msg_handler.handle_channel_update(&valid_channel_update) { + Ok(res) => assert!(res), + _ => panic!() + }; + } + + // Non-permanent closing just disables a channel + { + let network = net_graph_msg_handler.network_graph.read().unwrap(); + match network.get_channels().get(&short_channel_id) { + None => panic!(), + Some(channel_info) => { + assert!(channel_info.one_to_two.is_some()); + } + } } let channel_close_msg = HTLCFailChannelUpdate::ChannelClosed { @@ -1262,8 +1323,7 @@ mod tests { match network.get_channels().get(&short_channel_id) { None => panic!(), Some(channel_info) => { - assert!(!channel_info.one_to_two.enabled); - assert!(!channel_info.two_to_one.enabled); + assert!(!channel_info.one_to_two.as_ref().unwrap().enabled); } } }