Accept and track counterparty inbound forwarding fees in Channel
[rust-lightning] / lightning / src / ln / channelmanager.rs
index e499576a3f8c6aa8c89abc1117cce541628b3598..7be1fe43fd206bb2c211334a113284fa56726d82 100644 (file)
@@ -898,10 +898,14 @@ pub(crate) const IDEMPOTENCY_TIMEOUT_TICKS: u8 = 7;
 /// Information needed for constructing an invoice route hint for this channel.
 #[derive(Clone, Debug, PartialEq)]
 pub struct CounterpartyForwardingInfo {
-       /// Base routing fee in millisatoshis.
+       /// Base outbound routing fee in millisatoshis.
        pub fee_base_msat: u32,
-       /// Amount in millionths of a satoshi the channel will charge per transferred satoshi.
+       /// Amount in millionths of a satoshi the channel will charge per outbound transferred satoshi.
        pub fee_proportional_millionths: u32,
+       /// Base inbound routing fee in millisatoshis.
+       pub inbound_fee_base_msat: i32,
+       /// Amount in millionths of a satoshi the channel will charge per inbound transferred satoshi.
+       pub inbound_fee_proportional_millionths: i32,
        /// The minimum difference in cltv_expiry between an ingoing HTLC and its outgoing counterpart,
        /// such that the outgoing HTLC is forwardable to this counterparty. See `msgs::ChannelUpdate`'s
        /// `cltv_expiry_delta` for more details.
@@ -4755,6 +4759,22 @@ where
                Ok(NotifyOption::DoPersist)
        }
 
+       fn internal_inbound_fees_update(&self, counterparty_node_id: &PublicKey, msg: &msgs::InboundFeesUpdate) -> Result<(), MsgHandleErrInternal> {
+               let mut channel_state_lock = self.channel_state.lock().unwrap();
+               let channel_state = &mut *channel_state_lock;
+               match channel_state.by_id.entry(msg.channel_id) {
+                       hash_map::Entry::Occupied(mut chan) => {
+                               if chan.get().get_counterparty_node_id() != *counterparty_node_id {
+                                       return Err(MsgHandleErrInternal::send_err_msg_no_close("Got a inbound_fees_update for a channel from the wrong node".to_owned(), msg.channel_id));
+                               }
+                               try_chan_entry!(self, chan.get_mut().inbound_fees_update(&msg), chan);
+                               Ok(())
+                       },
+                       hash_map::Entry::Vacant(_) =>
+                               Err(MsgHandleErrInternal::send_err_msg_no_close("Got a inbound_fees_update with no known channel".to_owned(), msg.channel_id))
+               }
+       }
+
        fn internal_channel_reestablish(&self, counterparty_node_id: &PublicKey, msg: &msgs::ChannelReestablish) -> Result<(), MsgHandleErrInternal> {
                let htlc_forwards;
                let need_lnd_workaround = {
@@ -5790,7 +5810,8 @@ where
        }
 
        fn handle_inbound_fees_update(&self, counterparty_node_id: &PublicKey, msg: &msgs::InboundFeesUpdate) {
-               // TODO
+               let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
+               let _ = handle_error!(self, self.internal_inbound_fees_update(counterparty_node_id, msg), *counterparty_node_id);
        }
 
        fn handle_channel_reestablish(&self, counterparty_node_id: &PublicKey, msg: &msgs::ChannelReestablish) {
@@ -6004,7 +6025,9 @@ const SERIALIZATION_VERSION: u8 = 1;
 const MIN_SERIALIZATION_VERSION: u8 = 1;
 
 impl_writeable_tlv_based!(CounterpartyForwardingInfo, {
+       (1, inbound_fee_base_msat, (default_value, 0)),
        (2, fee_base_msat, required),
+       (3, inbound_fee_proportional_millionths, (default_value, 0)),
        (4, fee_proportional_millionths, required),
        (6, cltv_expiry_delta, required),
 });