Correctly fail back on outbound channel check for blinded HTLC
[rust-lightning] / lightning / src / ln / channel.rs
index 8a50d47cd20607cc72131d52d5544eb1884f60e0..46209a0c402d3a181c699579eb21daad00b50631 100644 (file)
@@ -230,6 +230,7 @@ struct OutboundHTLCOutput {
        payment_hash: PaymentHash,
        state: OutboundHTLCState,
        source: HTLCSource,
+       blinding_point: Option<PublicKey>,
        skimmed_fee_msat: Option<u64>,
 }
 
@@ -244,6 +245,7 @@ enum HTLCUpdateAwaitingACK {
                onion_routing_packet: msgs::OnionPacket,
                // The extra fee we're skimming off the top of this HTLC.
                skimmed_fee_msat: Option<u64>,
+               blinding_point: Option<PublicKey>,
        },
        ClaimHTLC {
                payment_preimage: PaymentPreimage,
@@ -2145,6 +2147,7 @@ impl<SP: Deref> ChannelContext<SP> where SP::Target: SignerProvider  {
                                        .map(|(sig, _)| sig).ok()?
                        },
                        // TODO (taproot|arik)
+                       #[cfg(taproot)]
                        _ => todo!()
                };
 
@@ -2199,6 +2202,7 @@ impl<SP: Deref> ChannelContext<SP> where SP::Target: SignerProvider  {
                                (counterparty_initial_commitment_tx, funding_signed)
                        },
                        // TODO (taproot|arik)
+                       #[cfg(taproot)]
                        _ => todo!()
                }
        }
@@ -3354,11 +3358,12 @@ impl<SP: Deref> Channel<SP> where
                                match &htlc_update {
                                        &HTLCUpdateAwaitingACK::AddHTLC {
                                                amount_msat, cltv_expiry, ref payment_hash, ref source, ref onion_routing_packet,
-                                               skimmed_fee_msat, ..
+                                               skimmed_fee_msat, blinding_point, ..
                                        } => {
-                                               match self.send_htlc(amount_msat, *payment_hash, cltv_expiry, source.clone(),
-                                                       onion_routing_packet.clone(), false, skimmed_fee_msat, fee_estimator, logger)
-                                               {
+                                               match self.send_htlc(
+                                                       amount_msat, *payment_hash, cltv_expiry, source.clone(), onion_routing_packet.clone(),
+                                                       false, skimmed_fee_msat, blinding_point, fee_estimator, logger
+                                               ) {
                                                        Ok(_) => update_add_count += 1,
                                                        Err(e) => {
                                                                match e {
@@ -3492,6 +3497,7 @@ impl<SP: Deref> Channel<SP> where
                                ).map_err(|_| ChannelError::Close("Failed to validate revocation from peer".to_owned()))?;
                        },
                        // TODO (taproot|arik)
+                       #[cfg(taproot)]
                        _ => todo!()
                };
 
@@ -4073,6 +4079,7 @@ impl<SP: Deref> Channel<SP> where
                                        cltv_expiry: htlc.cltv_expiry,
                                        onion_routing_packet: (**onion_packet).clone(),
                                        skimmed_fee_msat: htlc.skimmed_fee_msat,
+                                       blinding_point: htlc.blinding_point,
                                });
                        }
                }
@@ -4173,6 +4180,7 @@ impl<SP: Deref> Channel<SP> where
                        return Err(ChannelError::Close("Peer sent an invalid channel_reestablish to force close in a non-standard way".to_owned()));
                }
 
+               let our_commitment_transaction = INITIAL_COMMITMENT_NUMBER - self.context.cur_holder_commitment_transaction_number - 1;
                if msg.next_remote_commitment_number > 0 {
                        let expected_point = self.context.holder_signer.as_ref().get_per_commitment_point(INITIAL_COMMITMENT_NUMBER - msg.next_remote_commitment_number + 1, &self.context.secp_ctx);
                        let given_secret = SecretKey::from_slice(&msg.your_last_per_commitment_secret)
@@ -4180,7 +4188,7 @@ impl<SP: Deref> Channel<SP> where
                        if expected_point != PublicKey::from_secret_key(&self.context.secp_ctx, &given_secret) {
                                return Err(ChannelError::Close("Peer sent a garbage channel_reestablish with secret key not matching the commitment height provided".to_owned()));
                        }
-                       if msg.next_remote_commitment_number > INITIAL_COMMITMENT_NUMBER - self.context.cur_holder_commitment_transaction_number {
+                       if msg.next_remote_commitment_number > our_commitment_transaction {
                                macro_rules! log_and_panic {
                                        ($err_msg: expr) => {
                                                log_error!(logger, $err_msg, &self.context.channel_id, log_pubkey!(self.context.counterparty_node_id));
@@ -4200,11 +4208,12 @@ impl<SP: Deref> Channel<SP> where
 
                // Before we change the state of the channel, we check if the peer is sending a very old
                // commitment transaction number, if yes we send a warning message.
-               let our_commitment_transaction = INITIAL_COMMITMENT_NUMBER - self.context.cur_holder_commitment_transaction_number - 1;
-               if  msg.next_remote_commitment_number + 1 < our_commitment_transaction {
-                       return Err(
-                               ChannelError::Warn(format!("Peer attempted to reestablish channel with a very old local commitment transaction: {} (received) vs {} (expected)", msg.next_remote_commitment_number, our_commitment_transaction))
-                       );
+               if msg.next_remote_commitment_number + 1 < our_commitment_transaction {
+                       return Err(ChannelError::Warn(format!(
+                               "Peer attempted to reestablish channel with a very old local commitment transaction: {} (received) vs {} (expected)",
+                               msg.next_remote_commitment_number,
+                               our_commitment_transaction
+                       )));
                }
 
                // Go ahead and unmark PeerDisconnected as various calls we may make check for it (and all
@@ -4246,11 +4255,11 @@ impl<SP: Deref> Channel<SP> where
                        });
                }
 
-               let required_revoke = if msg.next_remote_commitment_number + 1 == INITIAL_COMMITMENT_NUMBER - self.context.cur_holder_commitment_transaction_number {
+               let required_revoke = if msg.next_remote_commitment_number == our_commitment_transaction {
                        // Remote isn't waiting on any RevokeAndACK from us!
                        // Note that if we need to repeat our ChannelReady we'll do that in the next if block.
                        None
-               } else if msg.next_remote_commitment_number + 1 == (INITIAL_COMMITMENT_NUMBER - 1) - self.context.cur_holder_commitment_transaction_number {
+               } else if msg.next_remote_commitment_number + 1 == our_commitment_transaction {
                        if self.context.channel_state & (ChannelState::MonitorUpdateInProgress as u32) != 0 {
                                self.context.monitor_pending_revoke_and_ack = true;
                                None
@@ -4258,7 +4267,12 @@ impl<SP: Deref> Channel<SP> where
                                Some(self.get_last_revoke_and_ack())
                        }
                } else {
-                       return Err(ChannelError::Close("Peer attempted to reestablish channel with a very old local commitment transaction".to_owned()));
+                       debug_assert!(false, "All values should have been handled in the four cases above");
+                       return Err(ChannelError::Close(format!(
+                               "Peer attempted to reestablish channel expecting a future local commitment transaction: {} (received) vs {} (expected)",
+                               msg.next_remote_commitment_number,
+                               our_commitment_transaction
+                       )));
                };
 
                // We increment cur_counterparty_commitment_transaction_number only upon receipt of
@@ -4316,8 +4330,18 @@ impl<SP: Deref> Channel<SP> where
                                        order: self.context.resend_order.clone(),
                                })
                        }
+               } else if msg.next_local_commitment_number < next_counterparty_commitment_number {
+                       Err(ChannelError::Close(format!(
+                               "Peer attempted to reestablish channel with a very old remote commitment transaction: {} (received) vs {} (expected)",
+                               msg.next_local_commitment_number,
+                               next_counterparty_commitment_number,
+                       )))
                } else {
-                       Err(ChannelError::Close("Peer attempted to reestablish channel with a very old remote commitment transaction".to_owned()))
+                       Err(ChannelError::Close(format!(
+                               "Peer attempted to reestablish channel with a future remote commitment transaction: {} (received) vs {} (expected)",
+                               msg.next_local_commitment_number,
+                               next_counterparty_commitment_number,
+                       )))
                }
        }
 
@@ -4447,6 +4471,7 @@ impl<SP: Deref> Channel<SP> where
                                }), None, None))
                        },
                        // TODO (taproot|arik)
+                       #[cfg(taproot)]
                        _ => todo!()
                }
        }
@@ -4700,6 +4725,7 @@ impl<SP: Deref> Channel<SP> where
                                                }), signed_tx, shutdown_result))
                                        },
                                        // TODO (taproot|arik)
+                                       #[cfg(taproot)]
                                        _ => todo!()
                                }
                        }
@@ -5333,6 +5359,7 @@ impl<SP: Deref> Channel<SP> where
                                })
                        },
                        // TODO (taproot|arik)
+                       #[cfg(taproot)]
                        _ => todo!()
                }
        }
@@ -5362,6 +5389,7 @@ impl<SP: Deref> Channel<SP> where
                                        })
                                },
                                // TODO (taproot|arik)
+                               #[cfg(taproot)]
                                _ => todo!()
                        }
                } else {
@@ -5480,13 +5508,13 @@ impl<SP: Deref> Channel<SP> where
        pub fn queue_add_htlc<F: Deref, L: Deref>(
                &mut self, amount_msat: u64, payment_hash: PaymentHash, cltv_expiry: u32, source: HTLCSource,
                onion_routing_packet: msgs::OnionPacket, skimmed_fee_msat: Option<u64>,
-               fee_estimator: &LowerBoundedFeeEstimator<F>, logger: &L
+               blinding_point: Option<PublicKey>, fee_estimator: &LowerBoundedFeeEstimator<F>, logger: &L
        ) -> Result<(), ChannelError>
        where F::Target: FeeEstimator, L::Target: Logger
        {
                self
                        .send_htlc(amount_msat, payment_hash, cltv_expiry, source, onion_routing_packet, true,
-                               skimmed_fee_msat, fee_estimator, logger)
+                               skimmed_fee_msat, blinding_point, fee_estimator, logger)
                        .map(|msg_opt| assert!(msg_opt.is_none(), "We forced holding cell?"))
                        .map_err(|err| {
                                if let ChannelError::Ignore(_) = err { /* fine */ }
@@ -5514,7 +5542,8 @@ impl<SP: Deref> Channel<SP> where
        fn send_htlc<F: Deref, L: Deref>(
                &mut self, amount_msat: u64, payment_hash: PaymentHash, cltv_expiry: u32, source: HTLCSource,
                onion_routing_packet: msgs::OnionPacket, mut force_holding_cell: bool,
-               skimmed_fee_msat: Option<u64>, fee_estimator: &LowerBoundedFeeEstimator<F>, logger: &L
+               skimmed_fee_msat: Option<u64>, blinding_point: Option<PublicKey>,
+               fee_estimator: &LowerBoundedFeeEstimator<F>, logger: &L
        ) -> Result<Option<msgs::UpdateAddHTLC>, ChannelError>
        where F::Target: FeeEstimator, L::Target: Logger
        {
@@ -5571,6 +5600,7 @@ impl<SP: Deref> Channel<SP> where
                                source,
                                onion_routing_packet,
                                skimmed_fee_msat,
+                               blinding_point,
                        });
                        return Ok(None);
                }
@@ -5582,6 +5612,7 @@ impl<SP: Deref> Channel<SP> where
                        cltv_expiry,
                        state: OutboundHTLCState::LocalAnnounced(Box::new(onion_routing_packet.clone())),
                        source,
+                       blinding_point,
                        skimmed_fee_msat,
                });
 
@@ -5593,6 +5624,7 @@ impl<SP: Deref> Channel<SP> where
                        cltv_expiry,
                        onion_routing_packet,
                        skimmed_fee_msat,
+                       blinding_point,
                };
                self.context.next_holder_htlc_id += 1;
 
@@ -5737,6 +5769,7 @@ impl<SP: Deref> Channel<SP> where
                                }, (counterparty_commitment_txid, commitment_stats.htlcs_included)))
                        },
                        // TODO (taproot|arik)
+                       #[cfg(taproot)]
                        _ => todo!()
                }
        }
@@ -5754,7 +5787,7 @@ impl<SP: Deref> Channel<SP> where
        where F::Target: FeeEstimator, L::Target: Logger
        {
                let send_res = self.send_htlc(amount_msat, payment_hash, cltv_expiry, source,
-                       onion_routing_packet, false, skimmed_fee_msat, fee_estimator, logger);
+                       onion_routing_packet, false, skimmed_fee_msat, None, fee_estimator, logger);
                if let Err(e) = &send_res { if let ChannelError::Ignore(_) = e {} else { debug_assert!(false, "Sending cannot trigger channel failure"); } }
                match send_res? {
                        Some(_) => {
@@ -7064,6 +7097,7 @@ impl<SP: Deref> Writeable for Channel<SP> where SP::Target: SignerProvider {
 
                let mut preimages: Vec<&Option<PaymentPreimage>> = vec![];
                let mut pending_outbound_skimmed_fees: Vec<Option<u64>> = Vec::new();
+               let mut pending_outbound_blinding_points: Vec<Option<PublicKey>> = Vec::new();
 
                (self.context.pending_outbound_htlcs.len() as u64).write(writer)?;
                for (idx, htlc) in self.context.pending_outbound_htlcs.iter().enumerate() {
@@ -7110,15 +7144,17 @@ impl<SP: Deref> Writeable for Channel<SP> where SP::Target: SignerProvider {
                        } else if !pending_outbound_skimmed_fees.is_empty() {
                                pending_outbound_skimmed_fees.push(None);
                        }
+                       pending_outbound_blinding_points.push(htlc.blinding_point);
                }
 
                let mut holding_cell_skimmed_fees: Vec<Option<u64>> = Vec::new();
+               let mut holding_cell_blinding_points: Vec<Option<PublicKey>> = Vec::new();
                (self.context.holding_cell_htlc_updates.len() as u64).write(writer)?;
                for (idx, update) in self.context.holding_cell_htlc_updates.iter().enumerate() {
                        match update {
                                &HTLCUpdateAwaitingACK::AddHTLC {
                                        ref amount_msat, ref cltv_expiry, ref payment_hash, ref source, ref onion_routing_packet,
-                                       skimmed_fee_msat,
+                                       blinding_point, skimmed_fee_msat,
                                } => {
                                        0u8.write(writer)?;
                                        amount_msat.write(writer)?;
@@ -7133,6 +7169,8 @@ impl<SP: Deref> Writeable for Channel<SP> where SP::Target: SignerProvider {
                                                }
                                                holding_cell_skimmed_fees.push(Some(skimmed_fee));
                                        } else if !holding_cell_skimmed_fees.is_empty() { holding_cell_skimmed_fees.push(None); }
+
+                                       holding_cell_blinding_points.push(blinding_point);
                                },
                                &HTLCUpdateAwaitingACK::ClaimHTLC { ref payment_preimage, ref htlc_id } => {
                                        1u8.write(writer)?;
@@ -7302,6 +7340,8 @@ impl<SP: Deref> Writeable for Channel<SP> where SP::Target: SignerProvider {
                        (35, pending_outbound_skimmed_fees, optional_vec),
                        (37, holding_cell_skimmed_fees, optional_vec),
                        (38, self.context.is_batch_funding, option),
+                       (39, pending_outbound_blinding_points, optional_vec),
+                       (41, holding_cell_blinding_points, optional_vec),
                });
 
                Ok(())
@@ -7413,6 +7453,7 @@ impl<'a, 'b, 'c, ES: Deref, SP: Deref> ReadableArgs<(&'a ES, &'b SP, u32, &'c Ch
                                        _ => return Err(DecodeError::InvalidValue),
                                },
                                skimmed_fee_msat: None,
+                               blinding_point: None,
                        });
                }
 
@@ -7427,6 +7468,7 @@ impl<'a, 'b, 'c, ES: Deref, SP: Deref> ReadableArgs<(&'a ES, &'b SP, u32, &'c Ch
                                        source: Readable::read(reader)?,
                                        onion_routing_packet: Readable::read(reader)?,
                                        skimmed_fee_msat: None,
+                                       blinding_point: None,
                                },
                                1 => HTLCUpdateAwaitingACK::ClaimHTLC {
                                        payment_preimage: Readable::read(reader)?,
@@ -7587,6 +7629,9 @@ impl<'a, 'b, 'c, ES: Deref, SP: Deref> ReadableArgs<(&'a ES, &'b SP, u32, &'c Ch
 
                let mut is_batch_funding: Option<()> = None;
 
+               let mut pending_outbound_blinding_points_opt: Option<Vec<Option<PublicKey>>> = None;
+               let mut holding_cell_blinding_points_opt: Option<Vec<Option<PublicKey>>> = None;
+
                read_tlv_fields!(reader, {
                        (0, announcement_sigs, option),
                        (1, minimum_depth, option),
@@ -7613,6 +7658,8 @@ impl<'a, 'b, 'c, ES: Deref, SP: Deref> ReadableArgs<(&'a ES, &'b SP, u32, &'c Ch
                        (35, pending_outbound_skimmed_fees_opt, optional_vec),
                        (37, holding_cell_skimmed_fees_opt, optional_vec),
                        (38, is_batch_funding, option),
+                       (39, pending_outbound_blinding_points_opt, optional_vec),
+                       (41, holding_cell_blinding_points_opt, optional_vec),
                });
 
                let (channel_keys_id, holder_signer) = if let Some(channel_keys_id) = channel_keys_id {
@@ -7689,6 +7736,24 @@ impl<'a, 'b, 'c, ES: Deref, SP: Deref> ReadableArgs<(&'a ES, &'b SP, u32, &'c Ch
                        // We expect all skimmed fees to be consumed above
                        if iter.next().is_some() { return Err(DecodeError::InvalidValue) }
                }
+               if let Some(blinding_pts) = pending_outbound_blinding_points_opt {
+                       let mut iter = blinding_pts.into_iter();
+                       for htlc in pending_outbound_htlcs.iter_mut() {
+                               htlc.blinding_point = iter.next().ok_or(DecodeError::InvalidValue)?;
+                       }
+                       // We expect all blinding points to be consumed above
+                       if iter.next().is_some() { return Err(DecodeError::InvalidValue) }
+               }
+               if let Some(blinding_pts) = holding_cell_blinding_points_opt {
+                       let mut iter = blinding_pts.into_iter();
+                       for htlc in holding_cell_htlc_updates.iter_mut() {
+                               if let HTLCUpdateAwaitingACK::AddHTLC { ref mut blinding_point, .. } = htlc {
+                                       *blinding_point = iter.next().ok_or(DecodeError::InvalidValue)?;
+                               }
+                       }
+                       // We expect all blinding points to be consumed above
+                       if iter.next().is_some() { return Err(DecodeError::InvalidValue) }
+               }
 
                Ok(Channel {
                        context: ChannelContext {
@@ -8028,6 +8093,7 @@ use crate::ln::channelmanager::{self, HTLCSource, PaymentId};
                                payment_id: PaymentId([42; 32]),
                        },
                        skimmed_fee_msat: None,
+                       blinding_point: None,
                });
 
                // Make sure when Node A calculates their local commitment transaction, none of the HTLCs pass
@@ -8602,6 +8668,7 @@ use crate::ln::channelmanager::{self, HTLCSource, PaymentId};
                                state: OutboundHTLCState::Committed,
                                source: HTLCSource::dummy(),
                                skimmed_fee_msat: None,
+                               blinding_point: None,
                        };
                        out.payment_hash.0 = Sha256::hash(&<Vec<u8>>::from_hex("0202020202020202020202020202020202020202020202020202020202020202").unwrap()).to_byte_array();
                        out
@@ -8615,6 +8682,7 @@ use crate::ln::channelmanager::{self, HTLCSource, PaymentId};
                                state: OutboundHTLCState::Committed,
                                source: HTLCSource::dummy(),
                                skimmed_fee_msat: None,
+                               blinding_point: None,
                        };
                        out.payment_hash.0 = Sha256::hash(&<Vec<u8>>::from_hex("0303030303030303030303030303030303030303030303030303030303030303").unwrap()).to_byte_array();
                        out
@@ -9026,6 +9094,7 @@ use crate::ln::channelmanager::{self, HTLCSource, PaymentId};
                                state: OutboundHTLCState::Committed,
                                source: HTLCSource::dummy(),
                                skimmed_fee_msat: None,
+                               blinding_point: None,
                        };
                        out.payment_hash.0 = Sha256::hash(&<Vec<u8>>::from_hex("0505050505050505050505050505050505050505050505050505050505050505").unwrap()).to_byte_array();
                        out
@@ -9039,6 +9108,7 @@ use crate::ln::channelmanager::{self, HTLCSource, PaymentId};
                                state: OutboundHTLCState::Committed,
                                source: HTLCSource::dummy(),
                                skimmed_fee_msat: None,
+                               blinding_point: None,
                        };
                        out.payment_hash.0 = Sha256::hash(&<Vec<u8>>::from_hex("0505050505050505050505050505050505050505050505050505050505050505").unwrap()).to_byte_array();
                        out
@@ -9105,7 +9175,7 @@ use crate::ln::channelmanager::{self, HTLCSource, PaymentId};
                assert_eq!(chan_utils::build_commitment_secret(&seed, 1),
                           <Vec<u8>>::from_hex("915c75942a26bb3a433a8ce2cb0427c29ec6c1775cfc78328b57f6ba7bfeaa9c").unwrap()[..]);
        }
-       
+
        #[test]
        fn test_key_derivation() {
                // Test vectors from BOLT 3 Appendix E: