Persist update_add sender skimmed fee in Channel
authorValentine Wallace <vwallace@protonmail.com>
Fri, 9 Jun 2023 11:42:07 +0000 (13:42 +0200)
committerValentine Wallace <vwallace@protonmail.com>
Tue, 20 Jun 2023 21:57:38 +0000 (17:57 -0400)
lightning/src/ln/channel.rs

index e24fed12aa7985ce4a855f2ce65cc3a7e47c314b..3a81b8d17f855f353ffafbdf0ffb70710a48e72f 100644 (file)
@@ -224,6 +224,7 @@ struct OutboundHTLCOutput {
        payment_hash: PaymentHash,
        state: OutboundHTLCState,
        source: HTLCSource,
+       skimmed_fee_msat: Option<u64>,
 }
 
 /// See AwaitingRemoteRevoke ChannelState for more info
@@ -235,6 +236,8 @@ enum HTLCUpdateAwaitingACK {
                payment_hash: PaymentHash,
                source: HTLCSource,
                onion_routing_packet: msgs::OnionPacket,
+               // The extra fee we're skimming off the top of this HTLC.
+               skimmed_fee_msat: Option<u64>,
        },
        ClaimHTLC {
                payment_preimage: PaymentPreimage,
@@ -5126,6 +5129,7 @@ impl<Signer: WriteableEcdsaChannelSigner> Channel<Signer> {
                                cltv_expiry,
                                source,
                                onion_routing_packet,
+                               skimmed_fee_msat: None,
                        });
                        return Ok(None);
                }
@@ -5137,6 +5141,7 @@ impl<Signer: WriteableEcdsaChannelSigner> Channel<Signer> {
                        cltv_expiry,
                        state: OutboundHTLCState::LocalAnnounced(Box::new(onion_routing_packet.clone())),
                        source,
+                       skimmed_fee_msat: None,
                });
 
                let res = msgs::UpdateAddHTLC {
@@ -6611,9 +6616,10 @@ impl<Signer: WriteableEcdsaChannelSigner> Writeable for Channel<Signer> {
                }
 
                let mut preimages: Vec<&Option<PaymentPreimage>> = vec![];
+               let mut pending_outbound_skimmed_fees: Vec<Option<u64>> = Vec::new();
 
                (self.context.pending_outbound_htlcs.len() as u64).write(writer)?;
-               for htlc in self.context.pending_outbound_htlcs.iter() {
+               for (idx, htlc) in self.context.pending_outbound_htlcs.iter().enumerate() {
                        htlc.htlc_id.write(writer)?;
                        htlc.amount_msat.write(writer)?;
                        htlc.cltv_expiry.write(writer)?;
@@ -6649,18 +6655,37 @@ impl<Signer: WriteableEcdsaChannelSigner> Writeable for Channel<Signer> {
                                        reason.write(writer)?;
                                }
                        }
+                       if let Some(skimmed_fee) = htlc.skimmed_fee_msat {
+                               if pending_outbound_skimmed_fees.is_empty() {
+                                       for _ in 0..idx { pending_outbound_skimmed_fees.push(None); }
+                               }
+                               pending_outbound_skimmed_fees.push(Some(skimmed_fee));
+                       } else if !pending_outbound_skimmed_fees.is_empty() {
+                               pending_outbound_skimmed_fees.push(None);
+                       }
                }
 
+               let mut holding_cell_skimmed_fees: Vec<Option<u64>> = Vec::new();
                (self.context.holding_cell_htlc_updates.len() as u64).write(writer)?;
-               for update in self.context.holding_cell_htlc_updates.iter() {
+               for (idx, update) in self.context.holding_cell_htlc_updates.iter().enumerate() {
                        match update {
-                               &HTLCUpdateAwaitingACK::AddHTLC { ref amount_msat, ref cltv_expiry, ref payment_hash, ref source, ref onion_routing_packet } => {
+                               &HTLCUpdateAwaitingACK::AddHTLC {
+                                       ref amount_msat, ref cltv_expiry, ref payment_hash, ref source, ref onion_routing_packet,
+                                       skimmed_fee_msat,
+                               } => {
                                        0u8.write(writer)?;
                                        amount_msat.write(writer)?;
                                        cltv_expiry.write(writer)?;
                                        payment_hash.write(writer)?;
                                        source.write(writer)?;
                                        onion_routing_packet.write(writer)?;
+
+                                       if let Some(skimmed_fee) = skimmed_fee_msat {
+                                               if holding_cell_skimmed_fees.is_empty() {
+                                                       for _ in 0..idx { holding_cell_skimmed_fees.push(None); }
+                                               }
+                                               holding_cell_skimmed_fees.push(Some(skimmed_fee));
+                                       } else if !holding_cell_skimmed_fees.is_empty() { holding_cell_skimmed_fees.push(None); }
                                },
                                &HTLCUpdateAwaitingACK::ClaimHTLC { ref payment_preimage, ref htlc_id } => {
                                        1u8.write(writer)?;
@@ -6827,6 +6852,8 @@ impl<Signer: WriteableEcdsaChannelSigner> Writeable for Channel<Signer> {
                        (29, self.context.temporary_channel_id, option),
                        (31, channel_pending_event_emitted, option),
                        (33, self.context.pending_monitor_updates, vec_type),
+                       (35, pending_outbound_skimmed_fees, optional_vec),
+                       (37, holding_cell_skimmed_fees, optional_vec),
                });
 
                Ok(())
@@ -6937,6 +6964,7 @@ impl<'a, 'b, 'c, ES: Deref, SP: Deref> ReadableArgs<(&'a ES, &'b SP, u32, &'c Ch
                                        },
                                        _ => return Err(DecodeError::InvalidValue),
                                },
+                               skimmed_fee_msat: None,
                        });
                }
 
@@ -6950,6 +6978,7 @@ impl<'a, 'b, 'c, ES: Deref, SP: Deref> ReadableArgs<(&'a ES, &'b SP, u32, &'c Ch
                                        payment_hash: Readable::read(reader)?,
                                        source: Readable::read(reader)?,
                                        onion_routing_packet: Readable::read(reader)?,
+                                       skimmed_fee_msat: None,
                                },
                                1 => HTLCUpdateAwaitingACK::ClaimHTLC {
                                        payment_preimage: Readable::read(reader)?,
@@ -7105,6 +7134,9 @@ impl<'a, 'b, 'c, ES: Deref, SP: Deref> ReadableArgs<(&'a ES, &'b SP, u32, &'c Ch
 
                let mut pending_monitor_updates = Some(Vec::new());
 
+               let mut pending_outbound_skimmed_fees_opt: Option<Vec<Option<u64>>> = None;
+               let mut holding_cell_skimmed_fees_opt: Option<Vec<Option<u64>>> = None;
+
                read_tlv_fields!(reader, {
                        (0, announcement_sigs, option),
                        (1, minimum_depth, option),
@@ -7128,6 +7160,8 @@ impl<'a, 'b, 'c, ES: Deref, SP: Deref> ReadableArgs<(&'a ES, &'b SP, u32, &'c Ch
                        (29, temporary_channel_id, option),
                        (31, channel_pending_event_emitted, option),
                        (33, pending_monitor_updates, vec_type),
+                       (35, pending_outbound_skimmed_fees_opt, optional_vec),
+                       (37, holding_cell_skimmed_fees_opt, optional_vec),
                });
 
                let (channel_keys_id, holder_signer) = if let Some(channel_keys_id) = channel_keys_id {
@@ -7182,6 +7216,25 @@ impl<'a, 'b, 'c, ES: Deref, SP: Deref> ReadableArgs<(&'a ES, &'b SP, u32, &'c Ch
 
                let holder_max_accepted_htlcs = holder_max_accepted_htlcs.unwrap_or(DEFAULT_MAX_HTLCS);
 
+               if let Some(skimmed_fees) = pending_outbound_skimmed_fees_opt {
+                       let mut iter = skimmed_fees.into_iter();
+                       for htlc in pending_outbound_htlcs.iter_mut() {
+                               htlc.skimmed_fee_msat = iter.next().ok_or(DecodeError::InvalidValue)?;
+                       }
+                       // We expect all skimmed fees to be consumed above
+                       if iter.next().is_some() { return Err(DecodeError::InvalidValue) }
+               }
+               if let Some(skimmed_fees) = holding_cell_skimmed_fees_opt {
+                       let mut iter = skimmed_fees.into_iter();
+                       for htlc in holding_cell_htlc_updates.iter_mut() {
+                               if let HTLCUpdateAwaitingACK::AddHTLC { ref mut skimmed_fee_msat, .. } = htlc {
+                                       *skimmed_fee_msat = iter.next().ok_or(DecodeError::InvalidValue)?;
+                               }
+                       }
+                       // We expect all skimmed fees to be consumed above
+                       if iter.next().is_some() { return Err(DecodeError::InvalidValue) }
+               }
+
                Ok(Channel {
                        context: ChannelContext {
                                user_id,
@@ -7524,7 +7577,8 @@ mod tests {
                                session_priv: SecretKey::from_slice(&hex::decode("0fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").unwrap()[..]).unwrap(),
                                first_hop_htlc_msat: 548,
                                payment_id: PaymentId([42; 32]),
-                       }
+                       },
+                       skimmed_fee_msat: None,
                });
 
                // Make sure when Node A calculates their local commitment transaction, none of the HTLCs pass
@@ -8081,6 +8135,7 @@ mod tests {
                                payment_hash: PaymentHash([0; 32]),
                                state: OutboundHTLCState::Committed,
                                source: HTLCSource::dummy(),
+                               skimmed_fee_msat: None,
                        };
                        out.payment_hash.0 = Sha256::hash(&hex::decode("0202020202020202020202020202020202020202020202020202020202020202").unwrap()).into_inner();
                        out
@@ -8093,6 +8148,7 @@ mod tests {
                                payment_hash: PaymentHash([0; 32]),
                                state: OutboundHTLCState::Committed,
                                source: HTLCSource::dummy(),
+                               skimmed_fee_msat: None,
                        };
                        out.payment_hash.0 = Sha256::hash(&hex::decode("0303030303030303030303030303030303030303030303030303030303030303").unwrap()).into_inner();
                        out
@@ -8494,6 +8550,7 @@ mod tests {
                                payment_hash: PaymentHash([0; 32]),
                                state: OutboundHTLCState::Committed,
                                source: HTLCSource::dummy(),
+                               skimmed_fee_msat: None,
                        };
                        out.payment_hash.0 = Sha256::hash(&hex::decode("0505050505050505050505050505050505050505050505050505050505050505").unwrap()).into_inner();
                        out
@@ -8506,6 +8563,7 @@ mod tests {
                                payment_hash: PaymentHash([0; 32]),
                                state: OutboundHTLCState::Committed,
                                source: HTLCSource::dummy(),
+                               skimmed_fee_msat: None,
                        };
                        out.payment_hash.0 = Sha256::hash(&hex::decode("0505050505050505050505050505050505050505050505050505050505050505").unwrap()).into_inner();
                        out