Add intermediate `ConstructedTransaction`
authorDuncan Dean <git@dunxen.dev>
Fri, 12 Apr 2024 13:37:17 +0000 (15:37 +0200)
committerDuncan Dean <git@dunxen.dev>
Tue, 23 Apr 2024 09:33:05 +0000 (11:33 +0200)
lightning/src/ln/interactivetxs.rs

index 7d2fb1983641dc971d3debe141d30a5b35181a5a..5f01bcd163ba3ba410e023334f424fca3762c8e8 100644 (file)
@@ -15,7 +15,8 @@ use bitcoin::blockdata::constants::WITNESS_SCALE_FACTOR;
 use bitcoin::consensus::Encodable;
 use bitcoin::policy::MAX_STANDARD_TX_WEIGHT;
 use bitcoin::{
-       absolute::LockTime as AbsoluteLockTime, OutPoint, ScriptBuf, Sequence, Transaction, TxIn, TxOut,
+       absolute::LockTime as AbsoluteLockTime, OutPoint, ScriptBuf, Sequence, Transaction, TxIn,
+       TxOut, Weight,
 };
 
 use crate::chain::chaininterface::fee_for_weight;
@@ -77,7 +78,7 @@ impl SerialIdExt for SerialId {
 }
 
 #[derive(Debug, Clone, PartialEq)]
-pub enum AbortReason {
+pub(crate) enum AbortReason {
        InvalidStateTransition,
        UnexpectedCounterpartyMessage,
        ReceivedTooManyTxAddInputs,
@@ -97,53 +98,183 @@ pub enum AbortReason {
        InvalidTx,
 }
 
-#[derive(Debug)]
-pub struct TxInputWithPrevOutput {
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub(crate) struct InteractiveTxInput {
+       serial_id: SerialId,
        input: TxIn,
        prev_output: TxOut,
 }
 
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub(crate) struct InteractiveTxOutput {
+       serial_id: SerialId,
+       tx_out: TxOut,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub(crate) struct ConstructedTransaction {
+       holder_is_initiator: bool,
+
+       inputs: Vec<InteractiveTxInput>,
+       outputs: Vec<InteractiveTxOutput>,
+
+       local_inputs_value_satoshis: u64,
+       local_outputs_value_satoshis: u64,
+
+       remote_inputs_value_satoshis: u64,
+       remote_outputs_value_satoshis: u64,
+
+       lock_time: AbsoluteLockTime,
+}
+
+impl ConstructedTransaction {
+       fn new(context: NegotiationContext) -> Self {
+               let local_inputs_value_satoshis = context
+                       .inputs
+                       .iter()
+                       .filter(|(serial_id, _)| {
+                               !is_serial_id_valid_for_counterparty(context.holder_is_initiator, serial_id)
+                       })
+                       .fold(0u64, |value, (_, input)| value.saturating_add(input.prev_output.value));
+
+               let local_outputs_value_satoshis = context
+                       .outputs
+                       .iter()
+                       .filter(|(serial_id, _)| {
+                               !is_serial_id_valid_for_counterparty(context.holder_is_initiator, serial_id)
+                       })
+                       .fold(0u64, |value, (_, output)| value.saturating_add(output.tx_out.value));
+
+               Self {
+                       holder_is_initiator: context.holder_is_initiator,
+
+                       local_inputs_value_satoshis,
+                       local_outputs_value_satoshis,
+
+                       remote_inputs_value_satoshis: context.remote_inputs_value(),
+                       remote_outputs_value_satoshis: context.remote_outputs_value(),
+
+                       inputs: context.inputs.into_values().collect(),
+                       outputs: context.outputs.into_values().collect(),
+
+                       lock_time: context.tx_locktime,
+               }
+       }
+
+       pub fn weight(&self) -> Weight {
+               let inputs_weight = self.inputs.iter().fold(
+                       Weight::from_wu(0),
+                       |weight, InteractiveTxInput { prev_output, .. }| {
+                               weight.checked_add(estimate_input_weight(prev_output)).unwrap_or(Weight::MAX)
+                       },
+               );
+               let outputs_weight = self.outputs.iter().fold(
+                       Weight::from_wu(0),
+                       |weight, InteractiveTxOutput { tx_out, .. }| {
+                               weight.checked_add(get_output_weight(&tx_out.script_pubkey)).unwrap_or(Weight::MAX)
+                       },
+               );
+               Weight::from_wu(TX_COMMON_FIELDS_WEIGHT)
+                       .checked_add(inputs_weight)
+                       .and_then(|weight| weight.checked_add(outputs_weight))
+                       .unwrap_or(Weight::MAX)
+       }
+
+       pub fn into_unsigned_tx(self) -> Transaction {
+               // Inputs and outputs must be sorted by serial_id
+               let ConstructedTransaction { mut inputs, mut outputs, .. } = self;
+
+               inputs.sort_unstable_by_key(|InteractiveTxInput { serial_id, .. }| *serial_id);
+               outputs.sort_unstable_by_key(|InteractiveTxOutput { serial_id, .. }| *serial_id);
+
+               let input: Vec<TxIn> =
+                       inputs.into_iter().map(|InteractiveTxInput { input, .. }| input).collect();
+               let output: Vec<TxOut> =
+                       outputs.into_iter().map(|InteractiveTxOutput { tx_out, .. }| tx_out).collect();
+
+               Transaction { version: 2, lock_time: self.lock_time, input, output }
+       }
+}
+
 #[derive(Debug)]
 struct NegotiationContext {
        holder_is_initiator: bool,
        received_tx_add_input_count: u16,
        received_tx_add_output_count: u16,
-       inputs: HashMap<SerialId, TxInputWithPrevOutput>,
+       inputs: HashMap<SerialId, InteractiveTxInput>,
        prevtx_outpoints: HashSet<OutPoint>,
-       outputs: HashMap<SerialId, TxOut>,
+       outputs: HashMap<SerialId, InteractiveTxOutput>,
        tx_locktime: AbsoluteLockTime,
        feerate_sat_per_kw: u32,
 }
 
-pub(crate) fn get_output_weight(script_pubkey: &ScriptBuf) -> u64 {
-       (8 /* value */ + script_pubkey.consensus_encode(&mut sink()).unwrap() as u64)
-               * WITNESS_SCALE_FACTOR as u64
+pub(crate) fn estimate_input_weight(prev_output: &TxOut) -> Weight {
+       Weight::from_wu(if prev_output.script_pubkey.is_v0_p2wpkh() {
+               P2WPKH_INPUT_WEIGHT_LOWER_BOUND
+       } else if prev_output.script_pubkey.is_v0_p2wsh() {
+               P2WSH_INPUT_WEIGHT_LOWER_BOUND
+       } else if prev_output.script_pubkey.is_v1_p2tr() {
+               P2TR_INPUT_WEIGHT_LOWER_BOUND
+       } else {
+               UNKNOWN_SEGWIT_VERSION_INPUT_WEIGHT_LOWER_BOUND
+       })
+}
+
+pub(crate) fn get_output_weight(script_pubkey: &ScriptBuf) -> Weight {
+       Weight::from_wu(
+               (8 /* value */ + script_pubkey.consensus_encode(&mut sink()).unwrap() as u64)
+                       * WITNESS_SCALE_FACTOR as u64,
+       )
+}
+
+fn is_serial_id_valid_for_counterparty(holder_is_initiator: bool, serial_id: &SerialId) -> bool {
+       // A received `SerialId`'s parity must match the role of the counterparty.
+       holder_is_initiator == serial_id.is_for_non_initiator()
 }
 
 impl NegotiationContext {
        fn is_serial_id_valid_for_counterparty(&self, serial_id: &SerialId) -> bool {
-               // A received `SerialId`'s parity must match the role of the counterparty.
-               self.holder_is_initiator == serial_id.is_for_non_initiator()
-       }
-
-       fn total_input_and_output_count(&self) -> usize {
-               self.inputs.len().saturating_add(self.outputs.len())
+               is_serial_id_valid_for_counterparty(self.holder_is_initiator, serial_id)
        }
 
-       fn counterparty_inputs_contributed(
-               &self,
-       ) -> impl Iterator<Item = &TxInputWithPrevOutput> + Clone {
+       fn remote_inputs_value(&self) -> u64 {
                self.inputs
                        .iter()
-                       .filter(move |(serial_id, _)| self.is_serial_id_valid_for_counterparty(serial_id))
-                       .map(|(_, input_with_prevout)| input_with_prevout)
+                       .filter(|(serial_id, _)| self.is_serial_id_valid_for_counterparty(serial_id))
+                       .fold(0u64, |acc, (_, InteractiveTxInput { prev_output, .. })| {
+                               acc.saturating_add(prev_output.value)
+                       })
        }
 
-       fn counterparty_outputs_contributed(&self) -> impl Iterator<Item = &TxOut> + Clone {
+       fn remote_outputs_value(&self) -> u64 {
                self.outputs
                        .iter()
-                       .filter(move |(serial_id, _)| self.is_serial_id_valid_for_counterparty(serial_id))
-                       .map(|(_, output)| output)
+                       .filter(|(serial_id, _)| self.is_serial_id_valid_for_counterparty(serial_id))
+                       .fold(0u64, |acc, (_, InteractiveTxOutput { tx_out, .. })| {
+                               acc.saturating_add(tx_out.value)
+                       })
+       }
+
+       fn remote_inputs_weight(&self) -> Weight {
+               Weight::from_wu(
+                       self.inputs
+                               .iter()
+                               .filter(|(serial_id, _)| self.is_serial_id_valid_for_counterparty(serial_id))
+                               .fold(0u64, |weight, (_, InteractiveTxInput { prev_output, .. })| {
+                                       weight.saturating_add(estimate_input_weight(prev_output).to_wu())
+                               }),
+               )
+       }
+
+       fn remote_outputs_weight(&self) -> Weight {
+               Weight::from_wu(
+                       self.outputs
+                               .iter()
+                               .filter(|(serial_id, _)| self.is_serial_id_valid_for_counterparty(serial_id))
+                               .fold(0u64, |weight, (_, InteractiveTxOutput { tx_out, .. })| {
+                                       weight.saturating_add(get_output_weight(&tx_out.script_pubkey).to_wu())
+                               }),
+               )
        }
 
        fn received_tx_add_input(&mut self, msg: &msgs::TxAddInput) -> Result<(), AbortReason> {
@@ -213,7 +344,8 @@ impl NegotiationContext {
                        },
                        hash_map::Entry::Vacant(entry) => {
                                let prev_outpoint = OutPoint { txid, vout: msg.prevtx_out };
-                               entry.insert(TxInputWithPrevOutput {
+                               entry.insert(InteractiveTxInput {
+                                       serial_id: msg.serial_id,
                                        input: TxIn {
                                                previous_output: prev_outpoint,
                                                sequence: Sequence(msg.sequence),
@@ -269,7 +401,7 @@ impl NegotiationContext {
                // bitcoin supply.
                let mut outputs_value: u64 = 0;
                for output in self.outputs.iter() {
-                       outputs_value = outputs_value.saturating_add(output.1.value);
+                       outputs_value = outputs_value.saturating_add(output.1.tx_out.value);
                }
                if outputs_value.saturating_add(msg.sats) > TOTAL_BITCOIN_SUPPLY_SATOSHIS {
                        // The receiving node:
@@ -306,7 +438,10 @@ impl NegotiationContext {
                                Err(AbortReason::DuplicateSerialId)
                        },
                        hash_map::Entry::Vacant(entry) => {
-                               entry.insert(TxOut { value: msg.sats, script_pubkey: msg.script.clone() });
+                               entry.insert(InteractiveTxOutput {
+                                       serial_id: msg.serial_id,
+                                       tx_out: TxOut { value: msg.sats, script_pubkey: msg.script.clone() },
+                               });
                                Ok(())
                        },
                }
@@ -340,13 +475,21 @@ impl NegotiationContext {
                        // We have added an input that already exists
                        return Err(AbortReason::PrevTxOutInvalid);
                }
-               self.inputs.insert(msg.serial_id, TxInputWithPrevOutput { input, prev_output });
+               self.inputs.insert(
+                       msg.serial_id,
+                       InteractiveTxInput { serial_id: msg.serial_id, input, prev_output },
+               );
                Ok(())
        }
 
        fn sent_tx_add_output(&mut self, msg: &msgs::TxAddOutput) -> Result<(), AbortReason> {
-               self.outputs
-                       .insert(msg.serial_id, TxOut { value: msg.sats, script_pubkey: msg.script.clone() });
+               self.outputs.insert(
+                       msg.serial_id,
+                       InteractiveTxOutput {
+                               serial_id: msg.serial_id,
+                               tx_out: TxOut { value: msg.sats, script_pubkey: msg.script.clone() },
+                       },
+               );
                Ok(())
        }
 
@@ -361,31 +504,12 @@ impl NegotiationContext {
        }
 
        fn check_counterparty_fees(
-               &self, counterparty_inputs_value: u64, counterparty_outputs_value: u64,
+               &self, counterparty_fees_contributed: u64,
        ) -> Result<(), AbortReason> {
-               let mut counterparty_weight_contributed: u64 = self
-                       .counterparty_outputs_contributed()
-                       .map(|output| get_output_weight(&output.script_pubkey))
-                       .sum();
-               // We don't know the counterparty's witnesses ahead of time obviously, so we use the lower bounds
-               // specified in BOLT 3.
-               let mut total_inputs_weight: u64 = 0;
-               for TxInputWithPrevOutput { prev_output, .. } in self.counterparty_inputs_contributed() {
-                       total_inputs_weight =
-                               total_inputs_weight.saturating_add(if prev_output.script_pubkey.is_v0_p2wpkh() {
-                                       P2WPKH_INPUT_WEIGHT_LOWER_BOUND
-                               } else if prev_output.script_pubkey.is_v0_p2wsh() {
-                                       P2WSH_INPUT_WEIGHT_LOWER_BOUND
-                               } else if prev_output.script_pubkey.is_v1_p2tr() {
-                                       P2TR_INPUT_WEIGHT_LOWER_BOUND
-                               } else {
-                                       UNKNOWN_SEGWIT_VERSION_INPUT_WEIGHT_LOWER_BOUND
-                               });
-               }
-               counterparty_weight_contributed =
-                       counterparty_weight_contributed.saturating_add(total_inputs_weight);
-               let counterparty_fees_contributed =
-                       counterparty_inputs_value.saturating_sub(counterparty_outputs_value);
+               let counterparty_weight_contributed = self
+                       .remote_inputs_weight()
+                       .to_wu()
+                       .saturating_add(self.remote_outputs_weight().to_wu());
                let mut required_counterparty_contribution_fee =
                        fee_for_weight(self.feerate_sat_per_kw, counterparty_weight_contributed);
                if !self.holder_is_initiator {
@@ -402,21 +526,14 @@ impl NegotiationContext {
                Ok(())
        }
 
-       fn build_transaction(self) -> Result<Transaction, AbortReason> {
+       fn validate_tx(self) -> Result<ConstructedTransaction, AbortReason> {
                // The receiving node:
                // MUST fail the negotiation if:
 
                // - the peer's total input satoshis is less than their outputs
-               let mut counterparty_inputs_value: u64 = 0;
-               let mut counterparty_outputs_value: u64 = 0;
-               for input in self.counterparty_inputs_contributed() {
-                       counterparty_inputs_value =
-                               counterparty_inputs_value.saturating_add(input.prev_output.value);
-               }
-               for output in self.counterparty_outputs_contributed() {
-                       counterparty_outputs_value = counterparty_outputs_value.saturating_add(output.value);
-               }
-               if counterparty_inputs_value < counterparty_outputs_value {
+               let remote_inputs_value = self.remote_inputs_value();
+               let remote_outputs_value = self.remote_outputs_value();
+               if remote_inputs_value < remote_outputs_value {
                        return Err(AbortReason::OutputsValueExceedsInputsValue);
                }
 
@@ -429,25 +546,15 @@ impl NegotiationContext {
                }
 
                // - the peer's paid feerate does not meet or exceed the agreed feerate (based on the minimum fee).
-               self.check_counterparty_fees(counterparty_inputs_value, counterparty_outputs_value)?;
+               self.check_counterparty_fees(remote_inputs_value.saturating_sub(remote_outputs_value))?;
 
-               // Inputs and outputs must be sorted by serial_id
-               let mut inputs = self.inputs.into_iter().collect::<Vec<_>>();
-               let mut outputs = self.outputs.into_iter().collect::<Vec<_>>();
-               inputs.sort_unstable_by_key(|(serial_id, _)| *serial_id);
-               outputs.sort_unstable_by_key(|(serial_id, _)| *serial_id);
+               let constructed_tx = ConstructedTransaction::new(self);
 
-               let tx_to_validate = Transaction {
-                       version: 2,
-                       lock_time: self.tx_locktime,
-                       input: inputs.into_iter().map(|(_, input)| input.input).collect(),
-                       output: outputs.into_iter().map(|(_, output)| output).collect(),
-               };
-               if tx_to_validate.weight().to_wu() > MAX_STANDARD_TX_WEIGHT as u64 {
+               if constructed_tx.weight().to_wu() > MAX_STANDARD_TX_WEIGHT as u64 {
                        return Err(AbortReason::TransactionTooLarge);
                }
 
-               Ok(tx_to_validate)
+               Ok(constructed_tx)
        }
 }
 
@@ -535,7 +642,7 @@ define_state!(
        ReceivedTxComplete,
        "We have received a `tx_complete` message and the counterparty is awaiting ours."
 );
-define_state!(NegotiationComplete, Transaction, "We have exchanged consecutive `tx_complete` messages with the counterparty and the transaction negotiation is complete.");
+define_state!(NegotiationComplete, ConstructedTransaction, "We have exchanged consecutive `tx_complete` messages with the counterparty and the transaction negotiation is complete.");
 define_state!(
        NegotiationAborted,
        AbortReason,
@@ -577,7 +684,7 @@ macro_rules! define_state_transitions {
                impl StateTransition<NegotiationComplete, &msgs::TxComplete> for $tx_complete_state {
                        fn transition(self, _data: &msgs::TxComplete) -> StateTransitionResult<NegotiationComplete> {
                                let context = self.into_negotiation_context();
-                               let tx = context.build_transaction()?;
+                               let tx = context.validate_tx()?;
                                Ok(NegotiationComplete(tx))
                        }
                }
@@ -715,14 +822,14 @@ impl StateMachine {
        ]);
 }
 
-pub struct InteractiveTxConstructor {
+pub(crate) struct InteractiveTxConstructor {
        state_machine: StateMachine,
        channel_id: ChannelId,
        inputs_to_contribute: Vec<(SerialId, TxIn, TransactionU16LenLimited)>,
        outputs_to_contribute: Vec<(SerialId, TxOut)>,
 }
 
-pub enum InteractiveTxMessageSend {
+pub(crate) enum InteractiveTxMessageSend {
        TxAddInput(msgs::TxAddInput),
        TxAddOutput(msgs::TxAddOutput),
        TxComplete(msgs::TxComplete),
@@ -754,10 +861,10 @@ where
        serial_id
 }
 
-pub enum HandleTxCompleteValue {
+pub(crate) enum HandleTxCompleteValue {
        SendTxMessage(InteractiveTxMessageSend),
-       SendTxComplete(InteractiveTxMessageSend, Transaction),
-       NegotiationComplete(Transaction),
+       SendTxComplete(InteractiveTxMessageSend, ConstructedTransaction),
+       NegotiationComplete(ConstructedTransaction),
 }
 
 impl InteractiveTxConstructor {
@@ -1107,7 +1214,7 @@ mod tests {
                }
                assert!(message_send_a.is_none());
                assert!(message_send_b.is_none());
-               assert_eq!(final_tx_a, final_tx_b);
+               assert_eq!(final_tx_a.unwrap().into_unsigned_tx(), final_tx_b.unwrap().into_unsigned_tx());
                assert!(session.expect_error.is_none(), "Test: {}", session.description);
        }
 
@@ -1280,7 +1387,7 @@ mod tests {
                let p2wpkh_fee = fee_for_weight(TEST_FEERATE_SATS_PER_KW, P2WPKH_INPUT_WEIGHT_LOWER_BOUND);
                let outputs_fee = fee_for_weight(
                        TEST_FEERATE_SATS_PER_KW,
-                       get_output_weight(&generate_p2wpkh_script_pubkey()),
+                       get_output_weight(&generate_p2wpkh_script_pubkey()).to_wu(),
                );
                let tx_common_fields_fee =
                        fee_for_weight(TEST_FEERATE_SATS_PER_KW, TX_COMMON_FIELDS_WEIGHT);