7d2fb1983641dc971d3debe141d30a5b35181a5a
[rust-lightning] / lightning / src / ln / interactivetxs.rs
1 // This file is Copyright its original authors, visible in version control
2 // history.
3 //
4 // This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
5 // or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
7 // You may not use this file except in accordance with one or both of these
8 // licenses.
9
10 use crate::io_extras::sink;
11 use crate::prelude::*;
12 use core::ops::Deref;
13
14 use bitcoin::blockdata::constants::WITNESS_SCALE_FACTOR;
15 use bitcoin::consensus::Encodable;
16 use bitcoin::policy::MAX_STANDARD_TX_WEIGHT;
17 use bitcoin::{
18         absolute::LockTime as AbsoluteLockTime, OutPoint, ScriptBuf, Sequence, Transaction, TxIn, TxOut,
19 };
20
21 use crate::chain::chaininterface::fee_for_weight;
22 use crate::events::bump_transaction::{BASE_INPUT_WEIGHT, EMPTY_SCRIPT_SIG_WEIGHT};
23 use crate::ln::channel::TOTAL_BITCOIN_SUPPLY_SATOSHIS;
24 use crate::ln::msgs::SerialId;
25 use crate::ln::{msgs, ChannelId};
26 use crate::sign::{EntropySource, P2TR_KEY_PATH_WITNESS_WEIGHT, P2WPKH_WITNESS_WEIGHT};
27 use crate::util::ser::TransactionU16LenLimited;
28
/// Once this many `tx_add_input` messages have been received during a single negotiation, the
/// negotiation MUST be failed.
const MAX_RECEIVED_TX_ADD_INPUT_COUNT: u16 = 4_096;

/// Once this many `tx_add_output` messages have been received during a single negotiation, the
/// negotiation MUST be failed.
const MAX_RECEIVED_TX_ADD_OUTPUT_COUNT: u16 = 4_096;

/// Upper limit on the number of inputs, and separately on the number of outputs, that the
/// negotiated transaction may contain; exceeding it MUST fail the negotiation.
///
/// NOTE(review): 252 is also the largest count that serializes as a single-byte CompactSize,
/// which matches the 1-byte input/output counts assumed by `TX_COMMON_FIELDS_WEIGHT` — confirm
/// this link is intentional.
const MAX_INPUTS_OUTPUTS_COUNT: usize = 252;
40
41 /// The total weight of the common fields whose fee is paid by the initiator of the interactive
42 /// transaction construction protocol.
43 const TX_COMMON_FIELDS_WEIGHT: u64 = (4 /* version */ + 4 /* locktime */ + 1 /* input count */ +
44         1 /* output count */) * WITNESS_SCALE_FACTOR as u64 + 2 /* segwit marker + flag */;
45
46 // BOLT 3 - Lower bounds for input weights
47
48 /// Lower bound for P2WPKH input weight
49 pub(crate) const P2WPKH_INPUT_WEIGHT_LOWER_BOUND: u64 =
50         BASE_INPUT_WEIGHT + EMPTY_SCRIPT_SIG_WEIGHT + P2WPKH_WITNESS_WEIGHT;
51
52 /// Lower bound for P2WSH input weight is chosen as same as P2WPKH input weight in BOLT 3
53 pub(crate) const P2WSH_INPUT_WEIGHT_LOWER_BOUND: u64 = P2WPKH_INPUT_WEIGHT_LOWER_BOUND;
54
55 /// Lower bound for P2TR input weight is chosen as the key spend path.
56 /// Not specified in BOLT 3, but a reasonable lower bound.
57 pub(crate) const P2TR_INPUT_WEIGHT_LOWER_BOUND: u64 =
58         BASE_INPUT_WEIGHT + EMPTY_SCRIPT_SIG_WEIGHT + P2TR_KEY_PATH_WITNESS_WEIGHT;
59
60 /// Lower bound for unknown segwit version input weight is chosen the same as P2WPKH in BOLT 3
61 pub(crate) const UNKNOWN_SEGWIT_VERSION_INPUT_WEIGHT_LOWER_BOUND: u64 =
62         P2WPKH_INPUT_WEIGHT_LOWER_BOUND;
63
64 trait SerialIdExt {
65         fn is_for_initiator(&self) -> bool;
66         fn is_for_non_initiator(&self) -> bool;
67 }
68
69 impl SerialIdExt for SerialId {
70         fn is_for_initiator(&self) -> bool {
71                 self % 2 == 0
72         }
73
74         fn is_for_non_initiator(&self) -> bool {
75                 !self.is_for_initiator()
76         }
77 }
78
/// The reason an interactive transaction construction negotiation was aborted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AbortReason {
	/// A transition was attempted that the negotiation state machine does not allow.
	/// NOTE(review): not produced by the negotiation-context checks in this module section —
	/// confirm where it is raised.
	InvalidStateTransition,
	/// The counterparty sent a message that was not expected at this point of the negotiation.
	/// NOTE(review): presumably raised by the state machine dispatch — confirm.
	UnexpectedCounterpartyMessage,
	/// The counterparty sent too many `tx_add_input` messages (see
	/// `MAX_RECEIVED_TX_ADD_INPUT_COUNT`).
	ReceivedTooManyTxAddInputs,
	/// The counterparty sent too many `tx_add_output` messages (see
	/// `MAX_RECEIVED_TX_ADD_OUTPUT_COUNT`).
	ReceivedTooManyTxAddOutputs,
	/// A received input's `sequence` was `0xFFFFFFFE` or `0xFFFFFFFF`.
	IncorrectInputSequenceValue,
	/// A received `serial_id` had the parity reserved for the other party.
	IncorrectSerialIdParity,
	/// A removal referenced a `serial_id` with no currently added input or output.
	SerialIdUnknown,
	/// An input or output was added with a `serial_id` that is already in use.
	DuplicateSerialId,
	/// An added input's previous output is invalid. Either `prevtx_out` is out of range for
	/// `prevtx`, or the output is not a witness program (checked for received inputs), or the
	/// outpoint is already spent by another input in the negotiation.
	PrevTxOutInvalid,
	/// Adding an output would push the total output value above the total bitcoin supply.
	ExceededMaximumSatsAllowed,
	/// The transaction has more inputs or more outputs than `MAX_INPUTS_OUTPUTS_COUNT`.
	ExceededNumberOfInputsOrOutputs,
	/// The constructed transaction's weight exceeds `MAX_STANDARD_TX_WEIGHT`.
	TransactionTooLarge,
	/// A received output's value is below the dust limit of its script.
	BelowDustLimit,
	/// A received output's script is neither P2WPKH, P2WSH, nor a witness program of version 1+.
	InvalidOutputScript,
	/// The counterparty's contributed fee does not cover its share at the agreed feerate.
	InsufficientFees,
	/// The counterparty's outputs are worth more than its inputs.
	OutputsValueExceedsInputsValue,
	/// The transaction is invalid.
	/// NOTE(review): not produced by the checks visible here — confirm its exact meaning.
	InvalidTx,
}
99
100 #[derive(Debug)]
101 pub struct TxInputWithPrevOutput {
102         input: TxIn,
103         prev_output: TxOut,
104 }
105
106 #[derive(Debug)]
107 struct NegotiationContext {
108         holder_is_initiator: bool,
109         received_tx_add_input_count: u16,
110         received_tx_add_output_count: u16,
111         inputs: HashMap<SerialId, TxInputWithPrevOutput>,
112         prevtx_outpoints: HashSet<OutPoint>,
113         outputs: HashMap<SerialId, TxOut>,
114         tx_locktime: AbsoluteLockTime,
115         feerate_sat_per_kw: u32,
116 }
117
118 pub(crate) fn get_output_weight(script_pubkey: &ScriptBuf) -> u64 {
119         (8 /* value */ + script_pubkey.consensus_encode(&mut sink()).unwrap() as u64)
120                 * WITNESS_SCALE_FACTOR as u64
121 }
122
123 impl NegotiationContext {
124         fn is_serial_id_valid_for_counterparty(&self, serial_id: &SerialId) -> bool {
125                 // A received `SerialId`'s parity must match the role of the counterparty.
126                 self.holder_is_initiator == serial_id.is_for_non_initiator()
127         }
128
129         fn total_input_and_output_count(&self) -> usize {
130                 self.inputs.len().saturating_add(self.outputs.len())
131         }
132
133         fn counterparty_inputs_contributed(
134                 &self,
135         ) -> impl Iterator<Item = &TxInputWithPrevOutput> + Clone {
136                 self.inputs
137                         .iter()
138                         .filter(move |(serial_id, _)| self.is_serial_id_valid_for_counterparty(serial_id))
139                         .map(|(_, input_with_prevout)| input_with_prevout)
140         }
141
142         fn counterparty_outputs_contributed(&self) -> impl Iterator<Item = &TxOut> + Clone {
143                 self.outputs
144                         .iter()
145                         .filter(move |(serial_id, _)| self.is_serial_id_valid_for_counterparty(serial_id))
146                         .map(|(_, output)| output)
147         }
148
149         fn received_tx_add_input(&mut self, msg: &msgs::TxAddInput) -> Result<(), AbortReason> {
150                 // The interactive-txs spec calls for us to fail negotiation if the `prevtx` we receive is
151                 // invalid. However, we would not need to account for this explicit negotiation failure
152                 // mode here since `PeerManager` would already disconnect the peer if the `prevtx` is
153                 // invalid; implicitly ending the negotiation.
154
155                 if !self.is_serial_id_valid_for_counterparty(&msg.serial_id) {
156                         // The receiving node:
157                         //  - MUST fail the negotiation if:
158                         //     - the `serial_id` has the wrong parity
159                         return Err(AbortReason::IncorrectSerialIdParity);
160                 }
161
162                 self.received_tx_add_input_count += 1;
163                 if self.received_tx_add_input_count > MAX_RECEIVED_TX_ADD_INPUT_COUNT {
164                         // The receiving node:
165                         //  - MUST fail the negotiation if:
166                         //     - if has received 4096 `tx_add_input` messages during this negotiation
167                         return Err(AbortReason::ReceivedTooManyTxAddInputs);
168                 }
169
170                 if msg.sequence >= 0xFFFFFFFE {
171                         // The receiving node:
172                         //  - MUST fail the negotiation if:
173                         //    - `sequence` is set to `0xFFFFFFFE` or `0xFFFFFFFF`
174                         return Err(AbortReason::IncorrectInputSequenceValue);
175                 }
176
177                 let transaction = msg.prevtx.as_transaction();
178                 let txid = transaction.txid();
179
180                 if let Some(tx_out) = transaction.output.get(msg.prevtx_out as usize) {
181                         if !tx_out.script_pubkey.is_witness_program() {
182                                 // The receiving node:
183                                 //  - MUST fail the negotiation if:
184                                 //     - the `scriptPubKey` is not a witness program
185                                 return Err(AbortReason::PrevTxOutInvalid);
186                         }
187
188                         if !self.prevtx_outpoints.insert(OutPoint { txid, vout: msg.prevtx_out }) {
189                                 // The receiving node:
190                                 //  - MUST fail the negotiation if:
191                                 //     - the `prevtx` and `prevtx_vout` are identical to a previously added
192                                 //       (and not removed) input's
193                                 return Err(AbortReason::PrevTxOutInvalid);
194                         }
195                 } else {
196                         // The receiving node:
197                         //  - MUST fail the negotiation if:
198                         //     - `prevtx_vout` is greater or equal to the number of outputs on `prevtx`
199                         return Err(AbortReason::PrevTxOutInvalid);
200                 }
201
202                 let prev_out = if let Some(prev_out) = transaction.output.get(msg.prevtx_out as usize) {
203                         prev_out.clone()
204                 } else {
205                         return Err(AbortReason::PrevTxOutInvalid);
206                 };
207                 match self.inputs.entry(msg.serial_id) {
208                         hash_map::Entry::Occupied(_) => {
209                                 // The receiving node:
210                                 //  - MUST fail the negotiation if:
211                                 //    - the `serial_id` is already included in the transaction
212                                 Err(AbortReason::DuplicateSerialId)
213                         },
214                         hash_map::Entry::Vacant(entry) => {
215                                 let prev_outpoint = OutPoint { txid, vout: msg.prevtx_out };
216                                 entry.insert(TxInputWithPrevOutput {
217                                         input: TxIn {
218                                                 previous_output: prev_outpoint,
219                                                 sequence: Sequence(msg.sequence),
220                                                 ..Default::default()
221                                         },
222                                         prev_output: prev_out,
223                                 });
224                                 self.prevtx_outpoints.insert(prev_outpoint);
225                                 Ok(())
226                         },
227                 }
228         }
229
230         fn received_tx_remove_input(&mut self, msg: &msgs::TxRemoveInput) -> Result<(), AbortReason> {
231                 if !self.is_serial_id_valid_for_counterparty(&msg.serial_id) {
232                         return Err(AbortReason::IncorrectSerialIdParity);
233                 }
234
235                 self.inputs
236                         .remove(&msg.serial_id)
237                         // The receiving node:
238                         //  - MUST fail the negotiation if:
239                         //    - the input or output identified by the `serial_id` was not added by the sender
240                         //    - the `serial_id` does not correspond to a currently added input
241                         .ok_or(AbortReason::SerialIdUnknown)
242                         .map(|_| ())
243         }
244
245         fn received_tx_add_output(&mut self, msg: &msgs::TxAddOutput) -> Result<(), AbortReason> {
246                 // The receiving node:
247                 //  - MUST fail the negotiation if:
248                 //     - the serial_id has the wrong parity
249                 if !self.is_serial_id_valid_for_counterparty(&msg.serial_id) {
250                         return Err(AbortReason::IncorrectSerialIdParity);
251                 }
252
253                 self.received_tx_add_output_count += 1;
254                 if self.received_tx_add_output_count > MAX_RECEIVED_TX_ADD_OUTPUT_COUNT {
255                         // The receiving node:
256                         //  - MUST fail the negotiation if:
257                         //     - if has received 4096 `tx_add_output` messages during this negotiation
258                         return Err(AbortReason::ReceivedTooManyTxAddOutputs);
259                 }
260
261                 if msg.sats < msg.script.dust_value().to_sat() {
262                         // The receiving node:
263                         // - MUST fail the negotiation if:
264                         //              - the sats amount is less than the dust_limit
265                         return Err(AbortReason::BelowDustLimit);
266                 }
267
268                 // Check that adding this output would not cause the total output value to exceed the total
269                 // bitcoin supply.
270                 let mut outputs_value: u64 = 0;
271                 for output in self.outputs.iter() {
272                         outputs_value = outputs_value.saturating_add(output.1.value);
273                 }
274                 if outputs_value.saturating_add(msg.sats) > TOTAL_BITCOIN_SUPPLY_SATOSHIS {
275                         // The receiving node:
276                         // - MUST fail the negotiation if:
277                         //              - the sats amount is greater than 2,100,000,000,000,000 (TOTAL_BITCOIN_SUPPLY_SATOSHIS)
278                         return Err(AbortReason::ExceededMaximumSatsAllowed);
279                 }
280
281                 // The receiving node:
282                 //   - MUST accept P2WSH, P2WPKH, P2TR scripts
283                 //   - MAY fail the negotiation if script is non-standard
284                 //
285                 // We can actually be a bit looser than the above as only witness version 0 has special
286                 // length-based standardness constraints to match similar consensus rules. All witness scripts
287                 // with witness versions V1 and up are always considered standard. Yes, the scripts can be
288                 // anyone-can-spend-able, but if our counterparty wants to add an output like that then it's none
289                 // of our concern really Â¯\_(ツ)_/¯
290                 //
291                 // TODO: The last check would be simplified when https://github.com/rust-bitcoin/rust-bitcoin/commit/1656e1a09a1959230e20af90d20789a4a8f0a31b
292                 // hits the next release of rust-bitcoin.
293                 if !(msg.script.is_v0_p2wpkh()
294                         || msg.script.is_v0_p2wsh()
295                         || (msg.script.is_witness_program()
296                                 && msg.script.witness_version().map(|v| v.to_num() >= 1).unwrap_or(false)))
297                 {
298                         return Err(AbortReason::InvalidOutputScript);
299                 }
300
301                 match self.outputs.entry(msg.serial_id) {
302                         hash_map::Entry::Occupied(_) => {
303                                 // The receiving node:
304                                 //  - MUST fail the negotiation if:
305                                 //    - the `serial_id` is already included in the transaction
306                                 Err(AbortReason::DuplicateSerialId)
307                         },
308                         hash_map::Entry::Vacant(entry) => {
309                                 entry.insert(TxOut { value: msg.sats, script_pubkey: msg.script.clone() });
310                                 Ok(())
311                         },
312                 }
313         }
314
315         fn received_tx_remove_output(&mut self, msg: &msgs::TxRemoveOutput) -> Result<(), AbortReason> {
316                 if !self.is_serial_id_valid_for_counterparty(&msg.serial_id) {
317                         return Err(AbortReason::IncorrectSerialIdParity);
318                 }
319                 if self.outputs.remove(&msg.serial_id).is_some() {
320                         Ok(())
321                 } else {
322                         // The receiving node:
323                         //  - MUST fail the negotiation if:
324                         //    - the input or output identified by the `serial_id` was not added by the sender
325                         //    - the `serial_id` does not correspond to a currently added input
326                         Err(AbortReason::SerialIdUnknown)
327                 }
328         }
329
330         fn sent_tx_add_input(&mut self, msg: &msgs::TxAddInput) -> Result<(), AbortReason> {
331                 let tx = msg.prevtx.as_transaction();
332                 let input = TxIn {
333                         previous_output: OutPoint { txid: tx.txid(), vout: msg.prevtx_out },
334                         sequence: Sequence(msg.sequence),
335                         ..Default::default()
336                 };
337                 let prev_output =
338                         tx.output.get(msg.prevtx_out as usize).ok_or(AbortReason::PrevTxOutInvalid)?.clone();
339                 if !self.prevtx_outpoints.insert(input.previous_output) {
340                         // We have added an input that already exists
341                         return Err(AbortReason::PrevTxOutInvalid);
342                 }
343                 self.inputs.insert(msg.serial_id, TxInputWithPrevOutput { input, prev_output });
344                 Ok(())
345         }
346
347         fn sent_tx_add_output(&mut self, msg: &msgs::TxAddOutput) -> Result<(), AbortReason> {
348                 self.outputs
349                         .insert(msg.serial_id, TxOut { value: msg.sats, script_pubkey: msg.script.clone() });
350                 Ok(())
351         }
352
353         fn sent_tx_remove_input(&mut self, msg: &msgs::TxRemoveInput) -> Result<(), AbortReason> {
354                 self.inputs.remove(&msg.serial_id);
355                 Ok(())
356         }
357
358         fn sent_tx_remove_output(&mut self, msg: &msgs::TxRemoveOutput) -> Result<(), AbortReason> {
359                 self.outputs.remove(&msg.serial_id);
360                 Ok(())
361         }
362
363         fn check_counterparty_fees(
364                 &self, counterparty_inputs_value: u64, counterparty_outputs_value: u64,
365         ) -> Result<(), AbortReason> {
366                 let mut counterparty_weight_contributed: u64 = self
367                         .counterparty_outputs_contributed()
368                         .map(|output| get_output_weight(&output.script_pubkey))
369                         .sum();
370                 // We don't know the counterparty's witnesses ahead of time obviously, so we use the lower bounds
371                 // specified in BOLT 3.
372                 let mut total_inputs_weight: u64 = 0;
373                 for TxInputWithPrevOutput { prev_output, .. } in self.counterparty_inputs_contributed() {
374                         total_inputs_weight =
375                                 total_inputs_weight.saturating_add(if prev_output.script_pubkey.is_v0_p2wpkh() {
376                                         P2WPKH_INPUT_WEIGHT_LOWER_BOUND
377                                 } else if prev_output.script_pubkey.is_v0_p2wsh() {
378                                         P2WSH_INPUT_WEIGHT_LOWER_BOUND
379                                 } else if prev_output.script_pubkey.is_v1_p2tr() {
380                                         P2TR_INPUT_WEIGHT_LOWER_BOUND
381                                 } else {
382                                         UNKNOWN_SEGWIT_VERSION_INPUT_WEIGHT_LOWER_BOUND
383                                 });
384                 }
385                 counterparty_weight_contributed =
386                         counterparty_weight_contributed.saturating_add(total_inputs_weight);
387                 let counterparty_fees_contributed =
388                         counterparty_inputs_value.saturating_sub(counterparty_outputs_value);
389                 let mut required_counterparty_contribution_fee =
390                         fee_for_weight(self.feerate_sat_per_kw, counterparty_weight_contributed);
391                 if !self.holder_is_initiator {
392                         // if is the non-initiator:
393                         //      - the initiator's fees do not cover the common fields (version, segwit marker + flag,
394                         //              input count, output count, locktime)
395                         let tx_common_fields_fee =
396                                 fee_for_weight(self.feerate_sat_per_kw, TX_COMMON_FIELDS_WEIGHT);
397                         required_counterparty_contribution_fee += tx_common_fields_fee;
398                 }
399                 if counterparty_fees_contributed < required_counterparty_contribution_fee {
400                         return Err(AbortReason::InsufficientFees);
401                 }
402                 Ok(())
403         }
404
405         fn build_transaction(self) -> Result<Transaction, AbortReason> {
406                 // The receiving node:
407                 // MUST fail the negotiation if:
408
409                 // - the peer's total input satoshis is less than their outputs
410                 let mut counterparty_inputs_value: u64 = 0;
411                 let mut counterparty_outputs_value: u64 = 0;
412                 for input in self.counterparty_inputs_contributed() {
413                         counterparty_inputs_value =
414                                 counterparty_inputs_value.saturating_add(input.prev_output.value);
415                 }
416                 for output in self.counterparty_outputs_contributed() {
417                         counterparty_outputs_value = counterparty_outputs_value.saturating_add(output.value);
418                 }
419                 if counterparty_inputs_value < counterparty_outputs_value {
420                         return Err(AbortReason::OutputsValueExceedsInputsValue);
421                 }
422
423                 // - there are more than 252 inputs
424                 // - there are more than 252 outputs
425                 if self.inputs.len() > MAX_INPUTS_OUTPUTS_COUNT
426                         || self.outputs.len() > MAX_INPUTS_OUTPUTS_COUNT
427                 {
428                         return Err(AbortReason::ExceededNumberOfInputsOrOutputs);
429                 }
430
431                 // - the peer's paid feerate does not meet or exceed the agreed feerate (based on the minimum fee).
432                 self.check_counterparty_fees(counterparty_inputs_value, counterparty_outputs_value)?;
433
434                 // Inputs and outputs must be sorted by serial_id
435                 let mut inputs = self.inputs.into_iter().collect::<Vec<_>>();
436                 let mut outputs = self.outputs.into_iter().collect::<Vec<_>>();
437                 inputs.sort_unstable_by_key(|(serial_id, _)| *serial_id);
438                 outputs.sort_unstable_by_key(|(serial_id, _)| *serial_id);
439
440                 let tx_to_validate = Transaction {
441                         version: 2,
442                         lock_time: self.tx_locktime,
443                         input: inputs.into_iter().map(|(_, input)| input.input).collect(),
444                         output: outputs.into_iter().map(|(_, output)| output).collect(),
445                 };
446                 if tx_to_validate.weight().to_wu() > MAX_STANDARD_TX_WEIGHT as u64 {
447                         return Err(AbortReason::TransactionTooLarge);
448                 }
449
450                 Ok(tx_to_validate)
451         }
452 }
453
454 // The interactive transaction construction protocol allows two peers to collaboratively build a
455 // transaction for broadcast.
456 //
457 // The protocol is turn-based, so we define different states here that we store depending on whose
458 // turn it is to send the next message. The states are defined so that their types ensure we only
459 // perform actions (only send messages) via defined state transitions that do not violate the
460 // protocol.
461 //
462 // An example of a full negotiation and associated states follows:
463 //
464 //     +------------+                         +------------------+---- Holder state after message sent/received ----+
465 //     |            |--(1)- tx_add_input ---->|                  |                  SentChangeMsg                   +
466 //     |            |<-(2)- tx_complete ------|                  |                ReceivedTxComplete                +
467 //     |            |--(3)- tx_add_output --->|                  |                  SentChangeMsg                   +
468 //     |            |<-(4)- tx_complete ------|                  |                ReceivedTxComplete                +
469 //     |            |--(5)- tx_add_input ---->|                  |                  SentChangeMsg                   +
470 //     |   Holder   |<-(6)- tx_add_input -----|   Counterparty   |                ReceivedChangeMsg                 +
471 //     |            |--(7)- tx_remove_output >|                  |                  SentChangeMsg                   +
472 //     |            |<-(8)- tx_add_output ----|                  |                ReceivedChangeMsg                 +
473 //     |            |--(9)- tx_complete ----->|                  |                  SentTxComplete                  +
474 //     |            |<-(10) tx_complete ------|                  |                NegotiationComplete               +
475 //     +------------+                         +------------------+--------------------------------------------------+
476
477 /// Negotiation states that can send & receive `tx_(add|remove)_(input|output)` and `tx_complete`
478 trait State {}
479
480 /// Category of states where we have sent some message to the counterparty, and we are waiting for
481 /// a response.
482 trait SentMsgState: State {
483         fn into_negotiation_context(self) -> NegotiationContext;
484 }
485
486 /// Category of states that our counterparty has put us in after we receive a message from them.
487 trait ReceivedMsgState: State {
488         fn into_negotiation_context(self) -> NegotiationContext;
489 }
490
// Declares one negotiation state type: a documented newtype implementing `State`.
//
// - `SENT_MSG_STATE, Name, doc` / `RECEIVED_MSG_STATE, Name, doc`: wraps a `NegotiationContext`
//   and also implements `SentMsgState` / `ReceivedMsgState`, handing the context back.
// - `Name, Inner, doc`: wraps an arbitrary `Inner`.
//
// Arm order matters. The tags are plain identifiers, so the untagged arm would also match the
// tagged forms; it must therefore stay last.
macro_rules! define_state {
	(SENT_MSG_STATE, $name: ident, $docs: expr) => {
		define_state!($name, NegotiationContext, $docs);
		impl SentMsgState for $name {
			fn into_negotiation_context(self) -> NegotiationContext {
				let $name(context) = self;
				context
			}
		}
	};
	(RECEIVED_MSG_STATE, $name: ident, $docs: expr) => {
		define_state!($name, NegotiationContext, $docs);
		impl ReceivedMsgState for $name {
			fn into_negotiation_context(self) -> NegotiationContext {
				let $name(context) = self;
				context
			}
		}
	};
	($name: ident, $wrapped: ident, $docs: expr) => {
		#[doc = $docs]
		#[derive(Debug)]
		struct $name($wrapped);
		impl State for $name {}
	};
}
517
518 define_state!(
519         SENT_MSG_STATE,
520         SentChangeMsg,
521         "We have sent a message to the counterparty that has affected our negotiation state."
522 );
523 define_state!(
524         SENT_MSG_STATE,
525         SentTxComplete,
526         "We have sent a `tx_complete` message and are awaiting the counterparty's."
527 );
528 define_state!(
529         RECEIVED_MSG_STATE,
530         ReceivedChangeMsg,
531         "We have received a message from the counterparty that has affected our negotiation state."
532 );
533 define_state!(
534         RECEIVED_MSG_STATE,
535         ReceivedTxComplete,
536         "We have received a `tx_complete` message and the counterparty is awaiting ours."
537 );
538 define_state!(NegotiationComplete, Transaction, "We have exchanged consecutive `tx_complete` messages with the counterparty and the transaction negotiation is complete.");
539 define_state!(
540         NegotiationAborted,
541         AbortReason,
542         "The negotiation has failed and cannot be continued."
543 );
544
545 type StateTransitionResult<S> = Result<S, AbortReason>;
546
547 trait StateTransition<NewState: State, TransitionData> {
548         fn transition(self, data: TransitionData) -> StateTransitionResult<NewState>;
549 }
550
// Implements `StateTransition` for every legal move between the states declared above. It has
// three forms:
//
// - `SENT_MSG_STATE, [DATA <msg type>, TRANSITION <handler>, ...]`: from any `SentMsgState`, a
//   message received from the counterparty is applied with `NegotiationContext::<handler>`. The
//   result is `ReceivedChangeMsg`.
// - `RECEIVED_MSG_STATE, [DATA <msg type>, TRANSITION <handler>, ...]`: from any
//   `ReceivedMsgState`, a message we send is applied with `NegotiationContext::<handler>`. The
//   result is `SentChangeMsg`.
// - `TX_COMPLETE, <before>, <after>`: a `tx_complete` moves `<before>` into `<after>`. A
//   `tx_complete` taken from `<after>` builds the final transaction, yielding
//   `NegotiationComplete`.
//
// Any error from a handler or from `build_transaction` is returned as the transition's
// `AbortReason`.
macro_rules! define_state_transitions {
	(SENT_MSG_STATE, [$(DATA $msg_ty: ty, TRANSITION $handler: ident),+]) => {
		$(
			impl<S: SentMsgState> StateTransition<ReceivedChangeMsg, $msg_ty> for S {
				fn transition(self, msg: $msg_ty) -> StateTransitionResult<ReceivedChangeMsg> {
					let mut context = SentMsgState::into_negotiation_context(self);
					context.$handler(msg)?;
					Ok(ReceivedChangeMsg(context))
				}
			}
		)+
	};
	(RECEIVED_MSG_STATE, [$(DATA $msg_ty: ty, TRANSITION $handler: ident),+]) => {
		$(
			impl<S: ReceivedMsgState> StateTransition<SentChangeMsg, $msg_ty> for S {
				fn transition(self, msg: $msg_ty) -> StateTransitionResult<SentChangeMsg> {
					let mut context = ReceivedMsgState::into_negotiation_context(self);
					context.$handler(msg)?;
					Ok(SentChangeMsg(context))
				}
			}
		)+
	};
	(TX_COMPLETE, $before: ident, $after: ident) => {
		impl StateTransition<NegotiationComplete, &msgs::TxComplete> for $after {
			fn transition(self, _msg: &msgs::TxComplete) -> StateTransitionResult<NegotiationComplete> {
				self.into_negotiation_context().build_transaction().map(NegotiationComplete)
			}
		}

		impl StateTransition<$after, &msgs::TxComplete> for $before {
			fn transition(self, _msg: &msgs::TxComplete) -> StateTransitionResult<$after> {
				Ok($after(self.into_negotiation_context()))
			}
		}
	};
}
592
// State transitions when we have sent our counterparty some messages and are waiting for them
// to respond. Each `DATA`/`TRANSITION` pair names an incoming counterparty message type and the
// `NegotiationContext` method that validates and applies it. (The `SENT_MSG_STATE` macro arm is
// defined earlier in this file; presumably it mirrors the `RECEIVED_MSG_STATE` arm and yields a
// `ReceivedChangeMsg` — confirm against the macro definition.)
define_state_transitions!(SENT_MSG_STATE, [
        DATA &msgs::TxAddInput, TRANSITION received_tx_add_input,
        DATA &msgs::TxRemoveInput, TRANSITION received_tx_remove_input,
        DATA &msgs::TxAddOutput, TRANSITION received_tx_add_output,
        DATA &msgs::TxRemoveOutput, TRANSITION received_tx_remove_output
]);
// State transitions when we have received some messages from our counterparty and we should
// respond. For every listed message type this implements
// `StateTransition<SentChangeMsg, $data>` for all `ReceivedMsgState` states: the named
// `sent_*` context method is applied and the result becomes a `SentChangeMsg`.
define_state_transitions!(RECEIVED_MSG_STATE, [
        DATA &msgs::TxAddInput, TRANSITION sent_tx_add_input,
        DATA &msgs::TxRemoveInput, TRANSITION sent_tx_remove_input,
        DATA &msgs::TxAddOutput, TRANSITION sent_tx_add_output,
        DATA &msgs::TxRemoveOutput, TRANSITION sent_tx_remove_output
]);
// `tx_complete` handling. Each line wires `$from_state -> $tx_complete_state` on a `tx_complete`,
// and `$tx_complete_state -> NegotiationComplete` (building the final transaction) on a second,
// answering `tx_complete`:
// - counterparty sends `tx_complete` after our change message: SentChangeMsg -> ReceivedTxComplete;
// - we send `tx_complete` after their change message: ReceivedChangeMsg -> SentTxComplete.
define_state_transitions!(TX_COMPLETE, SentChangeMsg, ReceivedTxComplete);
define_state_transitions!(TX_COMPLETE, ReceivedChangeMsg, SentTxComplete);
611
/// The current phase of an interactive transaction construction negotiation.
///
/// Each variant wraps the typed state it represents. Transitions between variants are performed
/// by the methods generated with `define_state_machine_transitions!` on `impl StateMachine`.
#[derive(Debug)]
enum StateMachine {
        /// Placeholder left behind while a transition is in progress: `do_state_transition!` moves
        /// the current state out with `core::mem::take` and writes the new state back afterwards.
        Indeterminate,
        /// The holder's last action was sending a change message; it is the counterparty's turn.
        /// The non-initiator also starts here.
        SentChangeMsg(SentChangeMsg),
        /// The counterparty's last action was a change message; it is the holder's turn. The
        /// initiator also starts here.
        ReceivedChangeMsg(ReceivedChangeMsg),
        /// The holder sent `tx_complete`; a `tx_complete` from the counterparty ends negotiation.
        SentTxComplete(SentTxComplete),
        /// The counterparty sent `tx_complete`; a `tx_complete` from the holder ends negotiation.
        ReceivedTxComplete(ReceivedTxComplete),
        /// Both sides sent consecutive `tx_complete`s; wraps the negotiated transaction.
        NegotiationComplete(NegotiationComplete),
        /// Negotiation failed; wraps the `AbortReason`.
        NegotiationAborted(NegotiationAborted),
}

// `Default` is what `core::mem::take` leaves behind in `do_state_transition!`.
impl Default for StateMachine {
        fn default() -> Self {
                Self::Indeterminate
        }
}
626         }
627 }
628
// The `StateMachine` internally executes the actual transition between two states and keeps
// track of the current state. This macro defines _how_ those state transitions happen to
// update the internal state.
//
// Each invocation generates a method `$transition(self, msg: $msg) -> StateMachine` that:
// - for every listed `FROM $from_state, TO $to_state` pair, calls the typed
//   `StateTransition::transition` on the inner state, wrapping success in
//   `StateMachine::$to_state` and failure in `StateMachine::NegotiationAborted`;
// - for any other current state (including `Indeterminate` and an already `NegotiationAborted`
//   state), aborts with `AbortReason::UnexpectedCounterpartyMessage`. This catch-all also
//   applies to holder-side (`sent_*`) methods, and it replaces any earlier abort reason.
macro_rules! define_state_machine_transitions {
        ($transition: ident, $msg: ty, [$(FROM $from_state: ident, TO $to_state: ident),+]) => {
                fn $transition(self, msg: $msg) -> StateMachine {
                        match self {
                                $(
                                        Self::$from_state(s) => match s.transition(msg) {
                                                Ok(new_state) => StateMachine::$to_state(new_state),
                                                Err(abort_reason) => StateMachine::NegotiationAborted(NegotiationAborted(abort_reason)),
                                        }
                                 )*
                                _ => StateMachine::NegotiationAborted(NegotiationAborted(AbortReason::UnexpectedCounterpartyMessage)),
                        }
                }
        };
}
645         };
646 }
647
648 impl StateMachine {
649         fn new(feerate_sat_per_kw: u32, is_initiator: bool, tx_locktime: AbsoluteLockTime) -> Self {
650                 let context = NegotiationContext {
651                         tx_locktime,
652                         holder_is_initiator: is_initiator,
653                         received_tx_add_input_count: 0,
654                         received_tx_add_output_count: 0,
655                         inputs: new_hash_map(),
656                         prevtx_outpoints: new_hash_set(),
657                         outputs: new_hash_map(),
658                         feerate_sat_per_kw,
659                 };
660                 if is_initiator {
661                         Self::ReceivedChangeMsg(ReceivedChangeMsg(context))
662                 } else {
663                         Self::SentChangeMsg(SentChangeMsg(context))
664                 }
665         }
666
667         // TxAddInput
668         define_state_machine_transitions!(sent_tx_add_input, &msgs::TxAddInput, [
669                 FROM ReceivedChangeMsg, TO SentChangeMsg,
670                 FROM ReceivedTxComplete, TO SentChangeMsg
671         ]);
672         define_state_machine_transitions!(received_tx_add_input, &msgs::TxAddInput, [
673                 FROM SentChangeMsg, TO ReceivedChangeMsg,
674                 FROM SentTxComplete, TO ReceivedChangeMsg
675         ]);
676
677         // TxAddOutput
678         define_state_machine_transitions!(sent_tx_add_output, &msgs::TxAddOutput, [
679                 FROM ReceivedChangeMsg, TO SentChangeMsg,
680                 FROM ReceivedTxComplete, TO SentChangeMsg
681         ]);
682         define_state_machine_transitions!(received_tx_add_output, &msgs::TxAddOutput, [
683                 FROM SentChangeMsg, TO ReceivedChangeMsg,
684                 FROM SentTxComplete, TO ReceivedChangeMsg
685         ]);
686
687         // TxRemoveInput
688         define_state_machine_transitions!(sent_tx_remove_input, &msgs::TxRemoveInput, [
689                 FROM ReceivedChangeMsg, TO SentChangeMsg,
690                 FROM ReceivedTxComplete, TO SentChangeMsg
691         ]);
692         define_state_machine_transitions!(received_tx_remove_input, &msgs::TxRemoveInput, [
693                 FROM SentChangeMsg, TO ReceivedChangeMsg,
694                 FROM SentTxComplete, TO ReceivedChangeMsg
695         ]);
696
697         // TxRemoveOutput
698         define_state_machine_transitions!(sent_tx_remove_output, &msgs::TxRemoveOutput, [
699                 FROM ReceivedChangeMsg, TO SentChangeMsg,
700                 FROM ReceivedTxComplete, TO SentChangeMsg
701         ]);
702         define_state_machine_transitions!(received_tx_remove_output, &msgs::TxRemoveOutput, [
703                 FROM SentChangeMsg, TO ReceivedChangeMsg,
704                 FROM SentTxComplete, TO ReceivedChangeMsg
705         ]);
706
707         // TxComplete
708         define_state_machine_transitions!(sent_tx_complete, &msgs::TxComplete, [
709                 FROM ReceivedChangeMsg, TO SentTxComplete,
710                 FROM ReceivedTxComplete, TO NegotiationComplete
711         ]);
712         define_state_machine_transitions!(received_tx_complete, &msgs::TxComplete, [
713                 FROM SentChangeMsg, TO ReceivedTxComplete,
714                 FROM SentTxComplete, TO NegotiationComplete
715         ]);
716 }
717
718 pub struct InteractiveTxConstructor {
719         state_machine: StateMachine,
720         channel_id: ChannelId,
721         inputs_to_contribute: Vec<(SerialId, TxIn, TransactionU16LenLimited)>,
722         outputs_to_contribute: Vec<(SerialId, TxOut)>,
723 }
724
725 pub enum InteractiveTxMessageSend {
726         TxAddInput(msgs::TxAddInput),
727         TxAddOutput(msgs::TxAddOutput),
728         TxComplete(msgs::TxComplete),
729 }
730
731 // This macro executes a state machine transition based on a provided action.
732 macro_rules! do_state_transition {
733         ($self: ident, $transition: ident, $msg: expr) => {{
734                 let state_machine = core::mem::take(&mut $self.state_machine);
735                 $self.state_machine = state_machine.$transition($msg);
736                 match &$self.state_machine {
737                         StateMachine::NegotiationAborted(state) => Err(state.0.clone()),
738                         _ => Ok(()),
739                 }
740         }};
741 }
742
743 fn generate_holder_serial_id<ES: Deref>(entropy_source: &ES, is_initiator: bool) -> SerialId
744 where
745         ES::Target: EntropySource,
746 {
747         let rand_bytes = entropy_source.get_secure_random_bytes();
748         let mut serial_id_bytes = [0u8; 8];
749         serial_id_bytes.copy_from_slice(&rand_bytes[..8]);
750         let mut serial_id = u64::from_be_bytes(serial_id_bytes);
751         if serial_id.is_for_initiator() != is_initiator {
752                 serial_id ^= 1;
753         }
754         serial_id
755 }
756
757 pub enum HandleTxCompleteValue {
758         SendTxMessage(InteractiveTxMessageSend),
759         SendTxComplete(InteractiveTxMessageSend, Transaction),
760         NegotiationComplete(Transaction),
761 }
762
763 impl InteractiveTxConstructor {
764         /// Instantiates a new `InteractiveTxConstructor`.
765         ///
766         /// A tuple is returned containing the newly instantiate `InteractiveTxConstructor` and optionally
767         /// an initial wrapped `Tx_` message which the holder needs to send to the counterparty.
768         pub fn new<ES: Deref>(
769                 entropy_source: &ES, channel_id: ChannelId, feerate_sat_per_kw: u32, is_initiator: bool,
770                 funding_tx_locktime: AbsoluteLockTime,
771                 inputs_to_contribute: Vec<(TxIn, TransactionU16LenLimited)>,
772                 outputs_to_contribute: Vec<TxOut>,
773         ) -> (Self, Option<InteractiveTxMessageSend>)
774         where
775                 ES::Target: EntropySource,
776         {
777                 let state_machine =
778                         StateMachine::new(feerate_sat_per_kw, is_initiator, funding_tx_locktime);
779                 let mut inputs_to_contribute: Vec<(SerialId, TxIn, TransactionU16LenLimited)> =
780                         inputs_to_contribute
781                                 .into_iter()
782                                 .map(|(input, tx)| {
783                                         let serial_id = generate_holder_serial_id(entropy_source, is_initiator);
784                                         (serial_id, input, tx)
785                                 })
786                                 .collect();
787                 // We'll sort by the randomly generated serial IDs, effectively shuffling the order of the inputs
788                 // as the user passed them to us to avoid leaking any potential categorization of transactions
789                 // before we pass any of the inputs to the counterparty.
790                 inputs_to_contribute.sort_unstable_by_key(|(serial_id, _, _)| *serial_id);
791                 let mut outputs_to_contribute: Vec<(SerialId, TxOut)> = outputs_to_contribute
792                         .into_iter()
793                         .map(|output| {
794                                 let serial_id = generate_holder_serial_id(entropy_source, is_initiator);
795                                 (serial_id, output)
796                         })
797                         .collect();
798                 // In the same manner and for the same rationale as the inputs above, we'll shuffle the outputs.
799                 outputs_to_contribute.sort_unstable_by_key(|(serial_id, _)| *serial_id);
800                 let mut constructor =
801                         Self { state_machine, channel_id, inputs_to_contribute, outputs_to_contribute };
802                 let message_send = if is_initiator {
803                         match constructor.maybe_send_message() {
804                                 Ok(msg_send) => Some(msg_send),
805                                 Err(_) => {
806                                         debug_assert!(
807                                                 false,
808                                                 "We should always be able to start our state machine successfully"
809                                         );
810                                         None
811                                 },
812                         }
813                 } else {
814                         None
815                 };
816                 (constructor, message_send)
817         }
818
819         fn maybe_send_message(&mut self) -> Result<InteractiveTxMessageSend, AbortReason> {
820                 // We first attempt to send inputs we want to add, then outputs. Once we are done sending
821                 // them both, then we always send tx_complete.
822                 if let Some((serial_id, input, prevtx)) = self.inputs_to_contribute.pop() {
823                         let msg = msgs::TxAddInput {
824                                 channel_id: self.channel_id,
825                                 serial_id,
826                                 prevtx,
827                                 prevtx_out: input.previous_output.vout,
828                                 sequence: input.sequence.to_consensus_u32(),
829                         };
830                         do_state_transition!(self, sent_tx_add_input, &msg)?;
831                         Ok(InteractiveTxMessageSend::TxAddInput(msg))
832                 } else if let Some((serial_id, output)) = self.outputs_to_contribute.pop() {
833                         let msg = msgs::TxAddOutput {
834                                 channel_id: self.channel_id,
835                                 serial_id,
836                                 sats: output.value,
837                                 script: output.script_pubkey,
838                         };
839                         do_state_transition!(self, sent_tx_add_output, &msg)?;
840                         Ok(InteractiveTxMessageSend::TxAddOutput(msg))
841                 } else {
842                         let msg = msgs::TxComplete { channel_id: self.channel_id };
843                         do_state_transition!(self, sent_tx_complete, &msg)?;
844                         Ok(InteractiveTxMessageSend::TxComplete(msg))
845                 }
846         }
847
848         pub fn handle_tx_add_input(
849                 &mut self, msg: &msgs::TxAddInput,
850         ) -> Result<InteractiveTxMessageSend, AbortReason> {
851                 do_state_transition!(self, received_tx_add_input, msg)?;
852                 self.maybe_send_message()
853         }
854
855         pub fn handle_tx_remove_input(
856                 &mut self, msg: &msgs::TxRemoveInput,
857         ) -> Result<InteractiveTxMessageSend, AbortReason> {
858                 do_state_transition!(self, received_tx_remove_input, msg)?;
859                 self.maybe_send_message()
860         }
861
862         pub fn handle_tx_add_output(
863                 &mut self, msg: &msgs::TxAddOutput,
864         ) -> Result<InteractiveTxMessageSend, AbortReason> {
865                 do_state_transition!(self, received_tx_add_output, msg)?;
866                 self.maybe_send_message()
867         }
868
869         pub fn handle_tx_remove_output(
870                 &mut self, msg: &msgs::TxRemoveOutput,
871         ) -> Result<InteractiveTxMessageSend, AbortReason> {
872                 do_state_transition!(self, received_tx_remove_output, msg)?;
873                 self.maybe_send_message()
874         }
875
876         pub fn handle_tx_complete(
877                 &mut self, msg: &msgs::TxComplete,
878         ) -> Result<HandleTxCompleteValue, AbortReason> {
879                 do_state_transition!(self, received_tx_complete, msg)?;
880                 match &self.state_machine {
881                         StateMachine::ReceivedTxComplete(_) => {
882                                 let msg_send = self.maybe_send_message()?;
883                                 match &self.state_machine {
884                                         StateMachine::NegotiationComplete(s) => {
885                                                 Ok(HandleTxCompleteValue::SendTxComplete(msg_send, s.0.clone()))
886                                         },
887                                         StateMachine::SentChangeMsg(_) => {
888                                                 Ok(HandleTxCompleteValue::SendTxMessage(msg_send))
889                                         }, // We either had an input or output to contribute.
890                                         _ => {
891                                                 debug_assert!(false, "We cannot transition to any other states after receiving `tx_complete` and responding");
892                                                 Err(AbortReason::InvalidStateTransition)
893                                         },
894                                 }
895                         },
896                         StateMachine::NegotiationComplete(s) => {
897                                 Ok(HandleTxCompleteValue::NegotiationComplete(s.0.clone()))
898                         },
899                         _ => {
900                                 debug_assert!(
901                                         false,
902                                         "We cannot transition to any other states after receiving `tx_complete`"
903                                 );
904                                 Err(AbortReason::InvalidStateTransition)
905                         },
906                 }
907         }
908 }
909
910 #[cfg(test)]
911 mod tests {
912         use crate::chain::chaininterface::{fee_for_weight, FEERATE_FLOOR_SATS_PER_KW};
913         use crate::ln::channel::TOTAL_BITCOIN_SUPPLY_SATOSHIS;
914         use crate::ln::interactivetxs::{
915                 generate_holder_serial_id, AbortReason, HandleTxCompleteValue, InteractiveTxConstructor,
916                 InteractiveTxMessageSend, MAX_INPUTS_OUTPUTS_COUNT, MAX_RECEIVED_TX_ADD_INPUT_COUNT,
917                 MAX_RECEIVED_TX_ADD_OUTPUT_COUNT,
918         };
919         use crate::ln::ChannelId;
920         use crate::sign::EntropySource;
921         use crate::util::atomic_counter::AtomicCounter;
922         use crate::util::ser::TransactionU16LenLimited;
923         use bitcoin::blockdata::opcodes;
924         use bitcoin::blockdata::script::Builder;
925         use bitcoin::hashes::Hash;
926         use bitcoin::key::UntweakedPublicKey;
927         use bitcoin::secp256k1::{KeyPair, Secp256k1};
928         use bitcoin::{
929                 absolute::LockTime as AbsoluteLockTime, OutPoint, Sequence, Transaction, TxIn, TxOut,
930         };
931         use bitcoin::{PubkeyHash, ScriptBuf, WPubkeyHash, WScriptHash};
932         use core::ops::Deref;
933
934         use super::{
935                 get_output_weight, P2TR_INPUT_WEIGHT_LOWER_BOUND, P2WPKH_INPUT_WEIGHT_LOWER_BOUND,
936                 P2WSH_INPUT_WEIGHT_LOWER_BOUND, TX_COMMON_FIELDS_WEIGHT,
937         };
938
939         const TEST_FEERATE_SATS_PER_KW: u32 = FEERATE_FLOOR_SATS_PER_KW * 10;
940
941         // A simple entropy source that works based on an atomic counter.
942         struct TestEntropySource(AtomicCounter);
943         impl EntropySource for TestEntropySource {
944                 fn get_secure_random_bytes(&self) -> [u8; 32] {
945                         let mut res = [0u8; 32];
946                         let increment = self.0.get_increment();
947                         for i in 0..32 {
948                                 // Rotate the increment value by 'i' bits to the right, to avoid clashes
949                                 // when `generate_local_serial_id` does a parity flip on consecutive calls for the
950                                 // same party.
951                                 let rotated_increment = increment.rotate_right(i as u32);
952                                 res[i] = (rotated_increment & 0xff) as u8;
953                         }
954                         res
955                 }
956         }
957
958         // An entropy source that deliberately returns you the same seed every time. We use this
959         // to test if the constructor would catch inputs/outputs that are attempting to be added
960         // with duplicate serial ids.
961         struct DuplicateEntropySource;
962         impl EntropySource for DuplicateEntropySource {
963                 fn get_secure_random_bytes(&self) -> [u8; 32] {
964                         let mut res = [0u8; 32];
965                         let count = 1u64;
966                         res[0..8].copy_from_slice(&count.to_be_bytes());
967                         res
968                 }
969         }
970
971         #[derive(Debug, PartialEq, Eq)]
972         enum ErrorCulprit {
973                 NodeA,
974                 NodeB,
975                 // Some error values are only checked at the end of the negotiation and are not easy to attribute
976                 // to a particular party. Both parties would indicate an `AbortReason` in this case.
977                 // e.g. Exceeded max inputs and outputs after negotiation.
978                 Indeterminate,
979         }
980
981         struct TestSession {
982                 description: &'static str,
983                 inputs_a: Vec<(TxIn, TransactionU16LenLimited)>,
984                 outputs_a: Vec<TxOut>,
985                 inputs_b: Vec<(TxIn, TransactionU16LenLimited)>,
986                 outputs_b: Vec<TxOut>,
987                 expect_error: Option<(AbortReason, ErrorCulprit)>,
988         }
989
990         fn do_test_interactive_tx_constructor(session: TestSession) {
991                 let entropy_source = TestEntropySource(AtomicCounter::new());
992                 do_test_interactive_tx_constructor_internal(session, &&entropy_source);
993         }
994
995         fn do_test_interactive_tx_constructor_with_entropy_source<ES: Deref>(
996                 session: TestSession, entropy_source: ES,
997         ) where
998                 ES::Target: EntropySource,
999         {
1000                 do_test_interactive_tx_constructor_internal(session, &entropy_source);
1001         }
1002
1003         fn do_test_interactive_tx_constructor_internal<ES: Deref>(
1004                 session: TestSession, entropy_source: &ES,
1005         ) where
1006                 ES::Target: EntropySource,
1007         {
1008                 let channel_id = ChannelId(entropy_source.get_secure_random_bytes());
1009                 let tx_locktime = AbsoluteLockTime::from_height(1337).unwrap();
1010
1011                 let (mut constructor_a, first_message_a) = InteractiveTxConstructor::new(
1012                         entropy_source,
1013                         channel_id,
1014                         TEST_FEERATE_SATS_PER_KW,
1015                         true,
1016                         tx_locktime,
1017                         session.inputs_a,
1018                         session.outputs_a,
1019                 );
1020                 let (mut constructor_b, first_message_b) = InteractiveTxConstructor::new(
1021                         entropy_source,
1022                         channel_id,
1023                         TEST_FEERATE_SATS_PER_KW,
1024                         false,
1025                         tx_locktime,
1026                         session.inputs_b,
1027                         session.outputs_b,
1028                 );
1029
1030                 let handle_message_send =
1031                         |msg: InteractiveTxMessageSend, for_constructor: &mut InteractiveTxConstructor| {
1032                                 match msg {
1033                                         InteractiveTxMessageSend::TxAddInput(msg) => for_constructor
1034                                                 .handle_tx_add_input(&msg)
1035                                                 .map(|msg_send| (Some(msg_send), None)),
1036                                         InteractiveTxMessageSend::TxAddOutput(msg) => for_constructor
1037                                                 .handle_tx_add_output(&msg)
1038                                                 .map(|msg_send| (Some(msg_send), None)),
1039                                         InteractiveTxMessageSend::TxComplete(msg) => {
1040                                                 for_constructor.handle_tx_complete(&msg).map(|value| match value {
1041                                                         HandleTxCompleteValue::SendTxMessage(msg_send) => {
1042                                                                 (Some(msg_send), None)
1043                                                         },
1044                                                         HandleTxCompleteValue::SendTxComplete(msg_send, tx) => {
1045                                                                 (Some(msg_send), Some(tx))
1046                                                         },
1047                                                         HandleTxCompleteValue::NegotiationComplete(tx) => (None, Some(tx)),
1048                                                 })
1049                                         },
1050                                 }
1051                         };
1052
1053                 assert!(first_message_b.is_none());
1054                 let mut message_send_a = first_message_a;
1055                 let mut message_send_b = None;
1056                 let mut final_tx_a = None;
1057                 let mut final_tx_b = None;
1058                 while final_tx_a.is_none() || final_tx_b.is_none() {
1059                         if let Some(message_send_a) = message_send_a.take() {
1060                                 match handle_message_send(message_send_a, &mut constructor_b) {
1061                                         Ok((msg_send, final_tx)) => {
1062                                                 message_send_b = msg_send;
1063                                                 final_tx_b = final_tx;
1064                                         },
1065                                         Err(abort_reason) => {
1066                                                 let error_culprit = match abort_reason {
1067                                                         AbortReason::ExceededNumberOfInputsOrOutputs => {
1068                                                                 ErrorCulprit::Indeterminate
1069                                                         },
1070                                                         _ => ErrorCulprit::NodeA,
1071                                                 };
1072                                                 assert_eq!(
1073                                                         Some((abort_reason, error_culprit)),
1074                                                         session.expect_error,
1075                                                         "Test: {}",
1076                                                         session.description
1077                                                 );
1078                                                 assert!(message_send_b.is_none());
1079                                                 return;
1080                                         },
1081                                 }
1082                         }
1083                         if let Some(message_send_b) = message_send_b.take() {
1084                                 match handle_message_send(message_send_b, &mut constructor_a) {
1085                                         Ok((msg_send, final_tx)) => {
1086                                                 message_send_a = msg_send;
1087                                                 final_tx_a = final_tx;
1088                                         },
1089                                         Err(abort_reason) => {
1090                                                 let error_culprit = match abort_reason {
1091                                                         AbortReason::ExceededNumberOfInputsOrOutputs => {
1092                                                                 ErrorCulprit::Indeterminate
1093                                                         },
1094                                                         _ => ErrorCulprit::NodeB,
1095                                                 };
1096                                                 assert_eq!(
1097                                                         Some((abort_reason, error_culprit)),
1098                                                         session.expect_error,
1099                                                         "Test: {}",
1100                                                         session.description
1101                                                 );
1102                                                 assert!(message_send_a.is_none());
1103                                                 return;
1104                                         },
1105                                 }
1106                         }
1107                 }
1108                 assert!(message_send_a.is_none());
1109                 assert!(message_send_b.is_none());
1110                 assert_eq!(final_tx_a, final_tx_b);
1111                 assert!(session.expect_error.is_none(), "Test: {}", session.description);
1112         }
1113
1114         #[derive(Debug, Clone, Copy)]
1115         enum TestOutput {
1116                 P2WPKH(u64),
1117                 P2WSH(u64),
1118                 P2TR(u64),
1119                 // Non-witness type to test rejection.
1120                 P2PKH(u64),
1121         }
1122
1123         fn generate_tx(outputs: &[TestOutput]) -> Transaction {
1124                 generate_tx_with_locktime(outputs, 1337)
1125         }
1126
1127         fn generate_txout(output: &TestOutput) -> TxOut {
1128                 let secp_ctx = Secp256k1::new();
1129                 let (value, script_pubkey) = match output {
1130                         TestOutput::P2WPKH(value) => {
1131                                 (*value, ScriptBuf::new_v0_p2wpkh(&WPubkeyHash::from_slice(&[1; 20]).unwrap()))
1132                         },
1133                         TestOutput::P2WSH(value) => {
1134                                 (*value, ScriptBuf::new_v0_p2wsh(&WScriptHash::from_slice(&[2; 32]).unwrap()))
1135                         },
1136                         TestOutput::P2TR(value) => (
1137                                 *value,
1138                                 ScriptBuf::new_v1_p2tr(
1139                                         &secp_ctx,
1140                                         UntweakedPublicKey::from_keypair(
1141                                                 &KeyPair::from_seckey_slice(&secp_ctx, &[3; 32]).unwrap(),
1142                                         )
1143                                         .0,
1144                                         None,
1145                                 ),
1146                         ),
1147                         TestOutput::P2PKH(value) => {
1148                                 (*value, ScriptBuf::new_p2pkh(&PubkeyHash::from_slice(&[4; 20]).unwrap()))
1149                         },
1150                 };
1151
1152                 TxOut { value, script_pubkey }
1153         }
1154
1155         fn generate_tx_with_locktime(outputs: &[TestOutput], locktime: u32) -> Transaction {
1156                 Transaction {
1157                         version: 2,
1158                         lock_time: AbsoluteLockTime::from_height(locktime).unwrap(),
1159                         input: vec![TxIn { ..Default::default() }],
1160                         output: outputs.iter().map(generate_txout).collect(),
1161                 }
1162         }
1163
1164         fn generate_inputs(outputs: &[TestOutput]) -> Vec<(TxIn, TransactionU16LenLimited)> {
1165                 let tx = generate_tx(outputs);
1166                 let txid = tx.txid();
1167                 tx.output
1168                         .iter()
1169                         .enumerate()
1170                         .map(|(idx, _)| {
1171                                 let input = TxIn {
1172                                         previous_output: OutPoint { txid, vout: idx as u32 },
1173                                         script_sig: Default::default(),
1174                                         sequence: Sequence::ENABLE_RBF_NO_LOCKTIME,
1175                                         witness: Default::default(),
1176                                 };
1177                                 (input, TransactionU16LenLimited::new(tx.clone()).unwrap())
1178                         })
1179                         .collect()
1180         }
1181
1182         fn generate_p2wsh_script_pubkey() -> ScriptBuf {
1183                 Builder::new().push_opcode(opcodes::OP_TRUE).into_script().to_v0_p2wsh()
1184         }
1185
1186         fn generate_p2wpkh_script_pubkey() -> ScriptBuf {
1187                 ScriptBuf::new_v0_p2wpkh(&WPubkeyHash::from_slice(&[1; 20]).unwrap())
1188         }
1189
1190         fn generate_outputs(outputs: &[TestOutput]) -> Vec<TxOut> {
1191                 outputs.iter().map(generate_txout).collect()
1192         }
1193
1194         fn generate_fixed_number_of_inputs(count: u16) -> Vec<(TxIn, TransactionU16LenLimited)> {
1195                 // Generate transactions with a total `count` number of outputs such that no transaction has a
1196                 // serialized length greater than u16::MAX.
1197                 let max_outputs_per_prevtx = 1_500;
1198                 let mut remaining = count;
1199                 let mut inputs: Vec<(TxIn, TransactionU16LenLimited)> = Vec::with_capacity(count as usize);
1200
1201                 while remaining > 0 {
1202                         let tx_output_count = remaining.min(max_outputs_per_prevtx);
1203                         remaining -= tx_output_count;
1204
1205                         // Use unique locktime for each tx so outpoints are different across transactions
1206                         let tx = generate_tx_with_locktime(
1207                                 &vec![TestOutput::P2WPKH(1_000_000); tx_output_count as usize],
1208                                 (1337 + remaining).into(),
1209                         );
1210                         let txid = tx.txid();
1211
1212                         let mut temp: Vec<(TxIn, TransactionU16LenLimited)> = tx
1213                                 .output
1214                                 .iter()
1215                                 .enumerate()
1216                                 .map(|(idx, _)| {
1217                                         let input = TxIn {
1218                                                 previous_output: OutPoint { txid, vout: idx as u32 },
1219                                                 script_sig: Default::default(),
1220                                                 sequence: Sequence::ENABLE_RBF_NO_LOCKTIME,
1221                                                 witness: Default::default(),
1222                                         };
1223                                         (input, TransactionU16LenLimited::new(tx.clone()).unwrap())
1224                                 })
1225                                 .collect();
1226
1227                         inputs.append(&mut temp);
1228                 }
1229
1230                 inputs
1231         }
1232
1233         fn generate_fixed_number_of_outputs(count: u16) -> Vec<TxOut> {
1234                 // Set a constant value for each TxOut
1235                 generate_outputs(&vec![TestOutput::P2WPKH(1_000_000); count as usize])
1236         }
1237
1238         fn generate_p2sh_script_pubkey() -> ScriptBuf {
1239                 Builder::new().push_opcode(opcodes::OP_TRUE).into_script().to_p2sh()
1240         }
1241
1242         fn generate_non_witness_output(value: u64) -> TxOut {
1243                 TxOut { value, script_pubkey: generate_p2sh_script_pubkey() }
1244         }
1245
        /// Runs `do_test_interactive_tx_constructor` (and its entropy-source variant) over a series
        /// of `TestSession`s. Each session lists the inputs/outputs contributed by node A (the
        /// initiator) and node B (the non-initiator). It also lists the expected
        /// `(AbortReason, ErrorCulprit)`, or `None` when the negotiation should succeed.
        #[test]
        fn test_interactive_tx_constructor() {
                do_test_interactive_tx_constructor(TestSession {
                        description: "No contributions",
                        inputs_a: vec![],
                        outputs_a: vec![],
                        inputs_b: vec![],
                        outputs_b: vec![],
                        expect_error: Some((AbortReason::InsufficientFees, ErrorCulprit::NodeA)),
                });
                do_test_interactive_tx_constructor(TestSession {
                        description: "Single contribution, no initiator inputs",
                        inputs_a: vec![],
                        outputs_a: generate_outputs(&[TestOutput::P2WPKH(1_000_000)]),
                        inputs_b: vec![],
                        outputs_b: vec![],
                        expect_error: Some((AbortReason::OutputsValueExceedsInputsValue, ErrorCulprit::NodeA)),
                });
                do_test_interactive_tx_constructor(TestSession {
                        description: "Single contribution, no initiator outputs",
                        inputs_a: generate_inputs(&[TestOutput::P2WPKH(1_000_000)]),
                        outputs_a: vec![],
                        inputs_b: vec![],
                        outputs_b: vec![],
                        expect_error: None,
                });
                do_test_interactive_tx_constructor(TestSession {
                        description: "Single contribution, no fees",
                        inputs_a: generate_inputs(&[TestOutput::P2WPKH(1_000_000)]),
                        outputs_a: generate_outputs(&[TestOutput::P2WPKH(1_000_000)]),
                        inputs_b: vec![],
                        outputs_b: vec![],
                        expect_error: Some((AbortReason::InsufficientFees, ErrorCulprit::NodeA)),
                });
                // Fees the initiator must cover at `TEST_FEERATE_SATS_PER_KW`:
                // - one P2WPKH input at its lower-bound weight,
                // - one P2WPKH output,
                // - the transaction's common fields.
                // The "sufficient fees" sessions below spend exactly what is left of the 1_000_000-sat
                // input after these fees; the "insufficient fees" sessions ask for 1 sat more.
                let p2wpkh_fee = fee_for_weight(TEST_FEERATE_SATS_PER_KW, P2WPKH_INPUT_WEIGHT_LOWER_BOUND);
                let outputs_fee = fee_for_weight(
                        TEST_FEERATE_SATS_PER_KW,
                        get_output_weight(&generate_p2wpkh_script_pubkey()),
                );
                let tx_common_fields_fee =
                        fee_for_weight(TEST_FEERATE_SATS_PER_KW, TX_COMMON_FIELDS_WEIGHT);
                do_test_interactive_tx_constructor(TestSession {
                        description: "Single contribution, with P2WPKH input, insufficient fees",
                        inputs_a: generate_inputs(&[TestOutput::P2WPKH(1_000_000)]),
                        outputs_a: generate_outputs(&[TestOutput::P2WPKH(
                                1_000_000 - p2wpkh_fee - outputs_fee - tx_common_fields_fee + 1, /* makes fees insufficient for initiator */
                        )]),
                        inputs_b: vec![],
                        outputs_b: vec![],
                        expect_error: Some((AbortReason::InsufficientFees, ErrorCulprit::NodeA)),
                });
                do_test_interactive_tx_constructor(TestSession {
                        description: "Single contribution with P2WPKH input, sufficient fees",
                        inputs_a: generate_inputs(&[TestOutput::P2WPKH(1_000_000)]),
                        outputs_a: generate_outputs(&[TestOutput::P2WPKH(
                                1_000_000 - p2wpkh_fee - outputs_fee - tx_common_fields_fee,
                        )]),
                        inputs_b: vec![],
                        outputs_b: vec![],
                        expect_error: None,
                });
                // Same fee-boundary checks, but spending a P2WSH input at its lower-bound weight.
                let p2wsh_fee = fee_for_weight(TEST_FEERATE_SATS_PER_KW, P2WSH_INPUT_WEIGHT_LOWER_BOUND);
                do_test_interactive_tx_constructor(TestSession {
                        description: "Single contribution, with P2WSH input, insufficient fees",
                        inputs_a: generate_inputs(&[TestOutput::P2WSH(1_000_000)]),
                        outputs_a: generate_outputs(&[TestOutput::P2WPKH(
                                1_000_000 - p2wsh_fee - outputs_fee - tx_common_fields_fee + 1, /* makes fees insufficient for initiator */
                        )]),
                        inputs_b: vec![],
                        outputs_b: vec![],
                        expect_error: Some((AbortReason::InsufficientFees, ErrorCulprit::NodeA)),
                });
                do_test_interactive_tx_constructor(TestSession {
                        description: "Single contribution with P2WSH input, sufficient fees",
                        inputs_a: generate_inputs(&[TestOutput::P2WSH(1_000_000)]),
                        outputs_a: generate_outputs(&[TestOutput::P2WPKH(
                                1_000_000 - p2wsh_fee - outputs_fee - tx_common_fields_fee,
                        )]),
                        inputs_b: vec![],
                        outputs_b: vec![],
                        expect_error: None,
                });
                // Same fee-boundary checks, but spending a P2TR input at its lower-bound weight.
                let p2tr_fee = fee_for_weight(TEST_FEERATE_SATS_PER_KW, P2TR_INPUT_WEIGHT_LOWER_BOUND);
                do_test_interactive_tx_constructor(TestSession {
                        description: "Single contribution, with P2TR input, insufficient fees",
                        inputs_a: generate_inputs(&[TestOutput::P2TR(1_000_000)]),
                        outputs_a: generate_outputs(&[TestOutput::P2WPKH(
                                1_000_000 - p2tr_fee - outputs_fee - tx_common_fields_fee + 1, /* makes fees insufficient for initiator */
                        )]),
                        inputs_b: vec![],
                        outputs_b: vec![],
                        expect_error: Some((AbortReason::InsufficientFees, ErrorCulprit::NodeA)),
                });
                do_test_interactive_tx_constructor(TestSession {
                        description: "Single contribution with P2TR input, sufficient fees",
                        inputs_a: generate_inputs(&[TestOutput::P2TR(1_000_000)]),
                        outputs_a: generate_outputs(&[TestOutput::P2WPKH(
                                1_000_000 - p2tr_fee - outputs_fee - tx_common_fields_fee,
                        )]),
                        inputs_b: vec![],
                        outputs_b: vec![],
                        expect_error: None,
                });
                do_test_interactive_tx_constructor(TestSession {
                        description: "Initiator contributes sufficient fees, but non-initiator does not",
                        inputs_a: generate_inputs(&[TestOutput::P2WPKH(1_000_000)]),
                        outputs_a: vec![],
                        inputs_b: generate_inputs(&[TestOutput::P2WPKH(100_000)]),
                        outputs_b: generate_outputs(&[TestOutput::P2WPKH(100_000)]),
                        expect_error: Some((AbortReason::InsufficientFees, ErrorCulprit::NodeB)),
                });
                do_test_interactive_tx_constructor(TestSession {
                        description: "Multi-input-output contributions from both sides",
                        inputs_a: generate_inputs(&[TestOutput::P2WPKH(1_000_000); 2]),
                        outputs_a: generate_outputs(&[
                                TestOutput::P2WPKH(1_000_000),
                                TestOutput::P2WPKH(200_000),
                        ]),
                        inputs_b: generate_inputs(&[
                                TestOutput::P2WPKH(1_000_000),
                                TestOutput::P2WPKH(500_000),
                        ]),
                        outputs_b: generate_outputs(&[
                                TestOutput::P2WPKH(1_000_000),
                                TestOutput::P2WPKH(400_000),
                        ]),
                        expect_error: None,
                });

                do_test_interactive_tx_constructor(TestSession {
                        description: "Prevout from initiator is not a witness program",
                        inputs_a: generate_inputs(&[TestOutput::P2PKH(1_000_000)]),
                        outputs_a: vec![],
                        inputs_b: vec![],
                        outputs_b: vec![],
                        expect_error: Some((AbortReason::PrevTxOutInvalid, ErrorCulprit::NodeA)),
                });

                // A single prevtx shared by the sessions below; every input built from it spends
                // its output 0.
                let tx =
                        TransactionU16LenLimited::new(generate_tx(&[TestOutput::P2WPKH(1_000_000)])).unwrap();
                // `sequence` is deliberately left at its `Default` value; this session expects the
                // constructor to reject it with `IncorrectInputSequenceValue`.
                let invalid_sequence_input = TxIn {
                        previous_output: OutPoint { txid: tx.as_transaction().txid(), vout: 0 },
                        ..Default::default()
                };
                do_test_interactive_tx_constructor(TestSession {
                        description: "Invalid input sequence from initiator",
                        inputs_a: vec![(invalid_sequence_input, tx.clone())],
                        outputs_a: generate_outputs(&[TestOutput::P2WPKH(1_000_000)]),
                        inputs_b: vec![],
                        outputs_b: vec![],
                        expect_error: Some((AbortReason::IncorrectInputSequenceValue, ErrorCulprit::NodeA)),
                });
                let duplicate_input = TxIn {
                        previous_output: OutPoint { txid: tx.as_transaction().txid(), vout: 0 },
                        sequence: Sequence::ENABLE_RBF_NO_LOCKTIME,
                        ..Default::default()
                };
                do_test_interactive_tx_constructor(TestSession {
                        description: "Duplicate prevout from initiator",
                        inputs_a: vec![(duplicate_input.clone(), tx.clone()), (duplicate_input, tx.clone())],
                        outputs_a: generate_outputs(&[TestOutput::P2WPKH(1_000_000)]),
                        inputs_b: vec![],
                        outputs_b: vec![],
                        expect_error: Some((AbortReason::PrevTxOutInvalid, ErrorCulprit::NodeB)),
                });
                // The previous `duplicate_input` was moved into `inputs_a` above, so an identical
                // input is rebuilt here for the cross-node duplicate case.
                let duplicate_input = TxIn {
                        previous_output: OutPoint { txid: tx.as_transaction().txid(), vout: 0 },
                        sequence: Sequence::ENABLE_RBF_NO_LOCKTIME,
                        ..Default::default()
                };
                do_test_interactive_tx_constructor(TestSession {
                        description: "Non-initiator uses same prevout as initiator",
                        inputs_a: vec![(duplicate_input.clone(), tx.clone())],
                        outputs_a: generate_outputs(&[TestOutput::P2WPKH(1_000_000)]),
                        inputs_b: vec![(duplicate_input.clone(), tx.clone())],
                        outputs_b: vec![],
                        expect_error: Some((AbortReason::PrevTxOutInvalid, ErrorCulprit::NodeA)),
                });
                // One more `tx_add_input` than the receive limit allows.
                do_test_interactive_tx_constructor(TestSession {
                        description: "Initiator sends too many TxAddInputs",
                        inputs_a: generate_fixed_number_of_inputs(MAX_RECEIVED_TX_ADD_INPUT_COUNT + 1),
                        outputs_a: vec![],
                        inputs_b: vec![],
                        outputs_b: vec![],
                        expect_error: Some((AbortReason::ReceivedTooManyTxAddInputs, ErrorCulprit::NodeA)),
                });
                do_test_interactive_tx_constructor_with_entropy_source(
                        TestSession {
                                // We use a deliberately bad entropy source, `DuplicateEntropySource` to simulate this.
                                description: "Attempt to queue up two inputs with duplicate serial ids",
                                inputs_a: generate_fixed_number_of_inputs(2),
                                outputs_a: vec![],
                                inputs_b: vec![],
                                outputs_b: vec![],
                                expect_error: Some((AbortReason::DuplicateSerialId, ErrorCulprit::NodeA)),
                        },
                        &DuplicateEntropySource,
                );
                // One more `tx_add_output` than the receive limit allows.
                do_test_interactive_tx_constructor(TestSession {
                        description: "Initiator sends too many TxAddOutputs",
                        inputs_a: vec![],
                        outputs_a: generate_fixed_number_of_outputs(MAX_RECEIVED_TX_ADD_OUTPUT_COUNT + 1),
                        inputs_b: vec![],
                        outputs_b: vec![],
                        expect_error: Some((AbortReason::ReceivedTooManyTxAddOutputs, ErrorCulprit::NodeA)),
                });
                // One sat below the dust value of the P2WSH script pubkey.
                do_test_interactive_tx_constructor(TestSession {
                        description: "Initiator sends an output below dust value",
                        inputs_a: vec![],
                        outputs_a: generate_outputs(&[TestOutput::P2WSH(
                                generate_p2wsh_script_pubkey().dust_value().to_sat() - 1,
                        )]),
                        inputs_b: vec![],
                        outputs_b: vec![],
                        expect_error: Some((AbortReason::BelowDustLimit, ErrorCulprit::NodeA)),
                });
                do_test_interactive_tx_constructor(TestSession {
                        description: "Initiator sends an output above maximum sats allowed",
                        inputs_a: vec![],
                        outputs_a: generate_outputs(&[TestOutput::P2WPKH(TOTAL_BITCOIN_SUPPLY_SATOSHIS + 1)]),
                        inputs_b: vec![],
                        outputs_b: vec![],
                        expect_error: Some((AbortReason::ExceededMaximumSatsAllowed, ErrorCulprit::NodeA)),
                });
                do_test_interactive_tx_constructor(TestSession {
                        description: "Initiator sends an output without a witness program",
                        inputs_a: vec![],
                        outputs_a: vec![generate_non_witness_output(1_000_000)],
                        inputs_b: vec![],
                        outputs_b: vec![],
                        expect_error: Some((AbortReason::InvalidOutputScript, ErrorCulprit::NodeA)),
                });
                do_test_interactive_tx_constructor_with_entropy_source(
                        TestSession {
                                // We use a deliberately bad entropy source, `DuplicateEntropySource` to simulate this.
                                description: "Attempt to queue up two outputs with duplicate serial ids",
                                inputs_a: vec![],
                                outputs_a: generate_fixed_number_of_outputs(2),
                                inputs_b: vec![],
                                outputs_b: vec![],
                                expect_error: Some((AbortReason::DuplicateSerialId, ErrorCulprit::NodeA)),
                        },
                        &DuplicateEntropySource,
                );

                do_test_interactive_tx_constructor(TestSession {
                        description: "Peer contributed more output value than inputs",
                        inputs_a: generate_inputs(&[TestOutput::P2WPKH(100_000)]),
                        outputs_a: generate_outputs(&[TestOutput::P2WPKH(1_000_000)]),
                        inputs_b: vec![],
                        outputs_b: vec![],
                        expect_error: Some((AbortReason::OutputsValueExceedsInputsValue, ErrorCulprit::NodeA)),
                });

                // The two sessions below exceed `MAX_INPUTS_OUTPUTS_COUNT` by one. They expect an
                // `Indeterminate` culprit rather than a specific node.
                do_test_interactive_tx_constructor(TestSession {
                        description: "Peer contributed more than allowed number of inputs",
                        inputs_a: generate_fixed_number_of_inputs(MAX_INPUTS_OUTPUTS_COUNT as u16 + 1),
                        outputs_a: vec![],
                        inputs_b: vec![],
                        outputs_b: vec![],
                        expect_error: Some((
                                AbortReason::ExceededNumberOfInputsOrOutputs,
                                ErrorCulprit::Indeterminate,
                        )),
                });
                do_test_interactive_tx_constructor(TestSession {
                        description: "Peer contributed more than allowed number of outputs",
                        inputs_a: generate_inputs(&[TestOutput::P2WPKH(TOTAL_BITCOIN_SUPPLY_SATOSHIS)]),
                        outputs_a: generate_fixed_number_of_outputs(MAX_INPUTS_OUTPUTS_COUNT as u16 + 1),
                        inputs_b: vec![],
                        outputs_b: vec![],
                        expect_error: Some((
                                AbortReason::ExceededNumberOfInputsOrOutputs,
                                ErrorCulprit::Indeterminate,
                        )),
                });
        }
1523
1524         #[test]
1525         fn test_generate_local_serial_id() {
1526                 let entropy_source = TestEntropySource(AtomicCounter::new());
1527
1528                 // Initiators should have even serial id, non-initiators should have odd serial id.
1529                 assert_eq!(generate_holder_serial_id(&&entropy_source, true) % 2, 0);
1530                 assert_eq!(generate_holder_serial_id(&&entropy_source, false) % 2, 1)
1531         }
1532 }