Merge pull request #2963 from jkczyz/2024-03-a-channel-manager
[rust-lightning] / lightning / src / ln / interactivetxs.rs
1 // This file is Copyright its original authors, visible in version control
2 // history.
3 //
4 // This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
5 // or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
7 // You may not use this file except in accordance with one or both of these
8 // licenses.
9
10 use crate::io_extras::sink;
11 use crate::prelude::*;
12 use core::ops::Deref;
13
14 use bitcoin::blockdata::constants::WITNESS_SCALE_FACTOR;
15 use bitcoin::consensus::Encodable;
16 use bitcoin::policy::MAX_STANDARD_TX_WEIGHT;
17 use bitcoin::{
18         absolute::LockTime as AbsoluteLockTime, OutPoint, Sequence, Transaction, TxIn, TxOut,
19 };
20
21 use crate::chain::chaininterface::fee_for_weight;
22 use crate::events::bump_transaction::{BASE_INPUT_WEIGHT, EMPTY_SCRIPT_SIG_WEIGHT};
23 use crate::ln::channel::TOTAL_BITCOIN_SUPPLY_SATOSHIS;
24 use crate::ln::msgs::SerialId;
25 use crate::ln::{msgs, ChannelId};
26 use crate::sign::EntropySource;
27 use crate::util::ser::TransactionU16LenLimited;
28
/// Cap on the number of `tx_add_input` messages accepted from the counterparty within a single
/// negotiation. Once the received count exceeds this value the negotiation MUST be failed.
const MAX_RECEIVED_TX_ADD_INPUT_COUNT: u16 = 4096;

/// Cap on the number of `tx_add_output` messages accepted from the counterparty within a single
/// negotiation. Once the received count exceeds this value the negotiation MUST be failed.
const MAX_RECEIVED_TX_ADD_OUTPUT_COUNT: u16 = 4096;

/// Largest number of inputs, and likewise of outputs, that the negotiated transaction may hold.
/// Exceeding either limit MUST fail the negotiation.
const MAX_INPUTS_OUTPUTS_COUNT: usize = 252;
40
41 trait SerialIdExt {
42         fn is_for_initiator(&self) -> bool;
43         fn is_for_non_initiator(&self) -> bool;
44 }
45
46 impl SerialIdExt for SerialId {
47         fn is_for_initiator(&self) -> bool {
48                 self % 2 == 0
49         }
50
51         fn is_for_non_initiator(&self) -> bool {
52                 !self.is_for_initiator()
53         }
54 }
55
/// The reason an interactive transaction negotiation was aborted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AbortReason {
        /// An illegal state transition was attempted.
        // NOTE(review): not produced by the code visible in this part of the file; confirm the exact
        // trigger against the callers.
        InvalidStateTransition,
        /// A counterparty message arrived while the state machine was in a state that has no
        /// transition for it.
        UnexpectedCounterpartyMessage,
        /// More than `MAX_RECEIVED_TX_ADD_INPUT_COUNT` `tx_add_input` messages were received.
        ReceivedTooManyTxAddInputs,
        /// More than `MAX_RECEIVED_TX_ADD_OUTPUT_COUNT` `tx_add_output` messages were received.
        ReceivedTooManyTxAddOutputs,
        /// A received input's `sequence` was `0xFFFFFFFE` or `0xFFFFFFFF`.
        IncorrectInputSequenceValue,
        /// A received `serial_id` had the parity reserved for our side of the negotiation.
        IncorrectSerialIdParity,
        /// A removal referenced a `serial_id` that is not currently part of the transaction.
        SerialIdUnknown,
        /// An addition reused a `serial_id` that is already part of the transaction.
        DuplicateSerialId,
        /// The previous output referenced by a received input is unusable: its index is out of
        /// range for `prevtx`, its script is not a witness program, or the outpoint is already
        /// spent by another input of the transaction.
        PrevTxOutInvalid,
        /// Adding the received output would push the total output value above
        /// `TOTAL_BITCOIN_SUPPLY_SATOSHIS`.
        ExceededMaximumSatsAllowed,
        /// The transaction has more than `MAX_INPUTS_OUTPUTS_COUNT` inputs or outputs.
        ExceededNumberOfInputsOrOutputs,
        /// The transaction weight exceeds `MAX_STANDARD_TX_WEIGHT`.
        TransactionTooLarge,
        /// A received output's value is below the dust limit of its script.
        BelowDustLimit,
        /// A received output's script is neither P2WPKH, P2WSH, nor a witness program of version 1
        /// or above.
        InvalidOutputScript,
        /// The counterparty's contributed fees do not cover the fee required for their share of the
        /// transaction weight at the negotiated feerate.
        InsufficientFees,
        /// The value of the counterparty's outputs exceeds the value available to them as inputs.
        OutputsValueExceedsInputsValue,
        /// The negotiated transaction is invalid.
        // NOTE(review): not produced by the code visible in this part of the file; confirm the exact
        // trigger against the callers.
        InvalidTx,
}
76
77 #[derive(Debug)]
78 pub struct TxInputWithPrevOutput {
79         input: TxIn,
80         prev_output: TxOut,
81 }
82
83 #[derive(Debug)]
84 struct NegotiationContext {
85         holder_is_initiator: bool,
86         received_tx_add_input_count: u16,
87         received_tx_add_output_count: u16,
88         inputs: HashMap<SerialId, TxInputWithPrevOutput>,
89         prevtx_outpoints: HashSet<OutPoint>,
90         outputs: HashMap<SerialId, TxOut>,
91         tx_locktime: AbsoluteLockTime,
92         feerate_sat_per_kw: u32,
93         to_remote_value_satoshis: u64,
94 }
95
96 impl NegotiationContext {
97         fn is_serial_id_valid_for_counterparty(&self, serial_id: &SerialId) -> bool {
98                 // A received `SerialId`'s parity must match the role of the counterparty.
99                 self.holder_is_initiator == serial_id.is_for_non_initiator()
100         }
101
102         fn total_input_and_output_count(&self) -> usize {
103                 self.inputs.len().saturating_add(self.outputs.len())
104         }
105
106         fn counterparty_inputs_contributed(
107                 &self,
108         ) -> impl Iterator<Item = &TxInputWithPrevOutput> + Clone {
109                 self.inputs
110                         .iter()
111                         .filter(move |(serial_id, _)| self.is_serial_id_valid_for_counterparty(serial_id))
112                         .map(|(_, input_with_prevout)| input_with_prevout)
113         }
114
115         fn counterparty_outputs_contributed(&self) -> impl Iterator<Item = &TxOut> + Clone {
116                 self.outputs
117                         .iter()
118                         .filter(move |(serial_id, _)| self.is_serial_id_valid_for_counterparty(serial_id))
119                         .map(|(_, output)| output)
120         }
121
122         fn received_tx_add_input(&mut self, msg: &msgs::TxAddInput) -> Result<(), AbortReason> {
123                 // The interactive-txs spec calls for us to fail negotiation if the `prevtx` we receive is
124                 // invalid. However, we would not need to account for this explicit negotiation failure
125                 // mode here since `PeerManager` would already disconnect the peer if the `prevtx` is
126                 // invalid; implicitly ending the negotiation.
127
128                 if !self.is_serial_id_valid_for_counterparty(&msg.serial_id) {
129                         // The receiving node:
130                         //  - MUST fail the negotiation if:
131                         //     - the `serial_id` has the wrong parity
132                         return Err(AbortReason::IncorrectSerialIdParity);
133                 }
134
135                 self.received_tx_add_input_count += 1;
136                 if self.received_tx_add_input_count > MAX_RECEIVED_TX_ADD_INPUT_COUNT {
137                         // The receiving node:
138                         //  - MUST fail the negotiation if:
139                         //     - if has received 4096 `tx_add_input` messages during this negotiation
140                         return Err(AbortReason::ReceivedTooManyTxAddInputs);
141                 }
142
143                 if msg.sequence >= 0xFFFFFFFE {
144                         // The receiving node:
145                         //  - MUST fail the negotiation if:
146                         //    - `sequence` is set to `0xFFFFFFFE` or `0xFFFFFFFF`
147                         return Err(AbortReason::IncorrectInputSequenceValue);
148                 }
149
150                 let transaction = msg.prevtx.as_transaction();
151                 let txid = transaction.txid();
152
153                 if let Some(tx_out) = transaction.output.get(msg.prevtx_out as usize) {
154                         if !tx_out.script_pubkey.is_witness_program() {
155                                 // The receiving node:
156                                 //  - MUST fail the negotiation if:
157                                 //     - the `scriptPubKey` is not a witness program
158                                 return Err(AbortReason::PrevTxOutInvalid);
159                         }
160
161                         if !self.prevtx_outpoints.insert(OutPoint { txid, vout: msg.prevtx_out }) {
162                                 // The receiving node:
163                                 //  - MUST fail the negotiation if:
164                                 //     - the `prevtx` and `prevtx_vout` are identical to a previously added
165                                 //       (and not removed) input's
166                                 return Err(AbortReason::PrevTxOutInvalid);
167                         }
168                 } else {
169                         // The receiving node:
170                         //  - MUST fail the negotiation if:
171                         //     - `prevtx_vout` is greater or equal to the number of outputs on `prevtx`
172                         return Err(AbortReason::PrevTxOutInvalid);
173                 }
174
175                 let prev_out = if let Some(prev_out) = transaction.output.get(msg.prevtx_out as usize) {
176                         prev_out.clone()
177                 } else {
178                         return Err(AbortReason::PrevTxOutInvalid);
179                 };
180                 if self.inputs.iter().any(|(serial_id, _)| *serial_id == msg.serial_id) {
181                         // The receiving node:
182                         //  - MUST fail the negotiation if:
183                         //    - the `serial_id` is already included in the transaction
184                         return Err(AbortReason::DuplicateSerialId);
185                 }
186                 let prev_outpoint = OutPoint { txid, vout: msg.prevtx_out };
187                 self.inputs.entry(msg.serial_id).or_insert_with(|| TxInputWithPrevOutput {
188                         input: TxIn {
189                                 previous_output: prev_outpoint.clone(),
190                                 sequence: Sequence(msg.sequence),
191                                 ..Default::default()
192                         },
193                         prev_output: prev_out,
194                 });
195                 self.prevtx_outpoints.insert(prev_outpoint);
196                 Ok(())
197         }
198
199         fn received_tx_remove_input(&mut self, msg: &msgs::TxRemoveInput) -> Result<(), AbortReason> {
200                 if !self.is_serial_id_valid_for_counterparty(&msg.serial_id) {
201                         return Err(AbortReason::IncorrectSerialIdParity);
202                 }
203
204                 self.inputs
205                         .remove(&msg.serial_id)
206                         // The receiving node:
207                         //  - MUST fail the negotiation if:
208                         //    - the input or output identified by the `serial_id` was not added by the sender
209                         //    - the `serial_id` does not correspond to a currently added input
210                         .ok_or(AbortReason::SerialIdUnknown)
211                         .map(|_| ())
212         }
213
214         fn received_tx_add_output(&mut self, msg: &msgs::TxAddOutput) -> Result<(), AbortReason> {
215                 // The receiving node:
216                 //  - MUST fail the negotiation if:
217                 //     - the serial_id has the wrong parity
218                 if !self.is_serial_id_valid_for_counterparty(&msg.serial_id) {
219                         return Err(AbortReason::IncorrectSerialIdParity);
220                 }
221
222                 self.received_tx_add_output_count += 1;
223                 if self.received_tx_add_output_count > MAX_RECEIVED_TX_ADD_OUTPUT_COUNT {
224                         // The receiving node:
225                         //  - MUST fail the negotiation if:
226                         //     - if has received 4096 `tx_add_output` messages during this negotiation
227                         return Err(AbortReason::ReceivedTooManyTxAddOutputs);
228                 }
229
230                 if msg.sats < msg.script.dust_value().to_sat() {
231                         // The receiving node:
232                         // - MUST fail the negotiation if:
233                         //              - the sats amount is less than the dust_limit
234                         return Err(AbortReason::BelowDustLimit);
235                 }
236
237                 // Check that adding this output would not cause the total output value to exceed the total
238                 // bitcoin supply.
239                 let mut outputs_value: u64 = 0;
240                 for output in self.outputs.iter() {
241                         outputs_value = outputs_value.saturating_add(output.1.value);
242                 }
243                 if outputs_value.saturating_add(msg.sats) > TOTAL_BITCOIN_SUPPLY_SATOSHIS {
244                         // The receiving node:
245                         // - MUST fail the negotiation if:
246                         //              - the sats amount is greater than 2,100,000,000,000,000 (TOTAL_BITCOIN_SUPPLY_SATOSHIS)
247                         return Err(AbortReason::ExceededMaximumSatsAllowed);
248                 }
249
250                 // The receiving node:
251                 //   - MUST accept P2WSH, P2WPKH, P2TR scripts
252                 //   - MAY fail the negotiation if script is non-standard
253                 //
254                 // We can actually be a bit looser than the above as only witness version 0 has special
255                 // length-based standardness constraints to match similar consensus rules. All witness scripts
256                 // with witness versions V1 and up are always considered standard. Yes, the scripts can be
257                 // anyone-can-spend-able, but if our counterparty wants to add an output like that then it's none
258                 // of our concern really Â¯\_(ツ)_/¯
259                 if !msg.script.is_v0_p2wpkh()
260                         && !msg.script.is_v0_p2wsh()
261                         && msg.script.witness_version().map(|v| v.to_num() < 1).unwrap_or(true)
262                 {
263                         return Err(AbortReason::InvalidOutputScript);
264                 }
265
266                 if self.outputs.iter().any(|(serial_id, _)| *serial_id == msg.serial_id) {
267                         // The receiving node:
268                         //  - MUST fail the negotiation if:
269                         //    - the `serial_id` is already included in the transaction
270                         return Err(AbortReason::DuplicateSerialId);
271                 }
272
273                 let output = TxOut { value: msg.sats, script_pubkey: msg.script.clone() };
274                 self.outputs.entry(msg.serial_id).or_insert(output);
275                 Ok(())
276         }
277
278         fn received_tx_remove_output(&mut self, msg: &msgs::TxRemoveOutput) -> Result<(), AbortReason> {
279                 if !self.is_serial_id_valid_for_counterparty(&msg.serial_id) {
280                         return Err(AbortReason::IncorrectSerialIdParity);
281                 }
282                 if let Some(_) = self.outputs.remove(&msg.serial_id) {
283                         Ok(())
284                 } else {
285                         // The receiving node:
286                         //  - MUST fail the negotiation if:
287                         //    - the input or output identified by the `serial_id` was not added by the sender
288                         //    - the `serial_id` does not correspond to a currently added input
289                         Err(AbortReason::SerialIdUnknown)
290                 }
291         }
292
293         fn sent_tx_add_input(&mut self, msg: &msgs::TxAddInput) {
294                 let tx = msg.prevtx.as_transaction();
295                 let input = TxIn {
296                         previous_output: OutPoint { txid: tx.txid(), vout: msg.prevtx_out },
297                         sequence: Sequence(msg.sequence),
298                         ..Default::default()
299                 };
300                 debug_assert!((msg.prevtx_out as usize) < tx.output.len());
301                 let prev_output = &tx.output[msg.prevtx_out as usize];
302                 self.prevtx_outpoints.insert(input.previous_output.clone());
303                 self.inputs.insert(
304                         msg.serial_id,
305                         TxInputWithPrevOutput { input, prev_output: prev_output.clone() },
306                 );
307         }
308
309         fn sent_tx_add_output(&mut self, msg: &msgs::TxAddOutput) {
310                 self.outputs
311                         .insert(msg.serial_id, TxOut { value: msg.sats, script_pubkey: msg.script.clone() });
312         }
313
314         fn sent_tx_remove_input(&mut self, msg: &msgs::TxRemoveInput) {
315                 self.inputs.remove(&msg.serial_id);
316         }
317
318         fn sent_tx_remove_output(&mut self, msg: &msgs::TxRemoveOutput) {
319                 self.outputs.remove(&msg.serial_id);
320         }
321
322         fn build_transaction(self) -> Result<Transaction, AbortReason> {
323                 // The receiving node:
324                 // MUST fail the negotiation if:
325
326                 // - the peer's total input satoshis is less than their outputs
327                 let mut counterparty_inputs_value: u64 = 0;
328                 let mut counterparty_outputs_value: u64 = 0;
329                 for input in self.counterparty_inputs_contributed() {
330                         counterparty_inputs_value =
331                                 counterparty_inputs_value.saturating_add(input.prev_output.value);
332                 }
333                 for output in self.counterparty_outputs_contributed() {
334                         counterparty_outputs_value = counterparty_outputs_value.saturating_add(output.value);
335                 }
336                 // ...actually the counterparty might be splicing out, so that their balance also contributes
337                 // to the total input value.
338                 if counterparty_inputs_value.saturating_add(self.to_remote_value_satoshis)
339                         < counterparty_outputs_value
340                 {
341                         return Err(AbortReason::OutputsValueExceedsInputsValue);
342                 }
343
344                 // - there are more than 252 inputs
345                 // - there are more than 252 outputs
346                 if self.inputs.len() > MAX_INPUTS_OUTPUTS_COUNT
347                         || self.outputs.len() > MAX_INPUTS_OUTPUTS_COUNT
348                 {
349                         return Err(AbortReason::ExceededNumberOfInputsOrOutputs);
350                 }
351
352                 // TODO: How do we enforce their fees cover the witness without knowing its expected length?
353                 const INPUT_WEIGHT: u64 = BASE_INPUT_WEIGHT + EMPTY_SCRIPT_SIG_WEIGHT;
354
355                 // - the peer's paid feerate does not meet or exceed the agreed feerate (based on the minimum fee).
356                 let counterparty_output_weight_contributed: u64 = self
357                         .counterparty_outputs_contributed()
358                         .map(|output| {
359                                 (8 /* value */ + output.script_pubkey.consensus_encode(&mut sink()).unwrap() as u64)
360                                         * WITNESS_SCALE_FACTOR as u64
361                         })
362                         .sum();
363                 let counterparty_weight_contributed = counterparty_output_weight_contributed
364                         + self.counterparty_inputs_contributed().count() as u64 * INPUT_WEIGHT;
365                 let counterparty_fees_contributed =
366                         counterparty_inputs_value.saturating_sub(counterparty_outputs_value);
367                 let mut required_counterparty_contribution_fee =
368                         fee_for_weight(self.feerate_sat_per_kw, counterparty_weight_contributed);
369                 if !self.holder_is_initiator {
370                         // if is the non-initiator:
371                         //      - the initiator's fees do not cover the common fields (version, segwit marker + flag,
372                         //              input count, output count, locktime)
373                         let tx_common_fields_weight =
374                         (4 /* version */ + 4 /* locktime */ + 1 /* input count */ + 1 /* output count */) *
375                             WITNESS_SCALE_FACTOR as u64 + 2 /* segwit marker + flag */;
376                         let tx_common_fields_fee =
377                                 fee_for_weight(self.feerate_sat_per_kw, tx_common_fields_weight);
378                         required_counterparty_contribution_fee += tx_common_fields_fee;
379                 }
380                 if counterparty_fees_contributed < required_counterparty_contribution_fee {
381                         return Err(AbortReason::InsufficientFees);
382                 }
383
384                 // Inputs and outputs must be sorted by serial_id
385                 let mut inputs = self.inputs.into_iter().collect::<Vec<_>>();
386                 let mut outputs = self.outputs.into_iter().collect::<Vec<_>>();
387                 inputs.sort_unstable_by_key(|(serial_id, _)| *serial_id);
388                 outputs.sort_unstable_by_key(|(serial_id, _)| *serial_id);
389
390                 let tx_to_validate = Transaction {
391                         version: 2,
392                         lock_time: self.tx_locktime,
393                         input: inputs.into_iter().map(|(_, input)| input.input).collect(),
394                         output: outputs.into_iter().map(|(_, output)| output).collect(),
395                 };
396                 if tx_to_validate.weight().to_wu() > MAX_STANDARD_TX_WEIGHT as u64 {
397                         return Err(AbortReason::TransactionTooLarge);
398                 }
399
400                 Ok(tx_to_validate)
401         }
402 }
403
404 // The interactive transaction construction protocol allows two peers to collaboratively build a
405 // transaction for broadcast.
406 //
407 // The protocol is turn-based, so we define different states here that we store depending on whose
408 // turn it is to send the next message. The states are defined so that their types ensure we only
409 // perform actions (only send messages) via defined state transitions that do not violate the
410 // protocol.
411 //
412 // An example of a full negotiation and associated states follows:
413 //
414 //     +------------+                         +------------------+---- Holder state after message sent/received ----+
415 //     |            |--(1)- tx_add_input ---->|                  |                  SentChangeMsg                   +
416 //     |            |<-(2)- tx_complete ------|                  |                ReceivedTxComplete                +
417 //     |            |--(3)- tx_add_output --->|                  |                  SentChangeMsg                   +
418 //     |            |<-(4)- tx_complete ------|                  |                ReceivedTxComplete                +
419 //     |            |--(5)- tx_add_input ---->|                  |                  SentChangeMsg                   +
420 //     |   Holder   |<-(6)- tx_add_input -----|   Counterparty   |                ReceivedChangeMsg                 +
421 //     |            |--(7)- tx_remove_output >|                  |                  SentChangeMsg                   +
422 //     |            |<-(8)- tx_add_output ----|                  |                ReceivedChangeMsg                 +
423 //     |            |--(9)- tx_complete ----->|                  |                  SentTxComplete                  +
424 //     |            |<-(10) tx_complete ------|                  |                NegotiationComplete               +
425 //     +------------+                         +------------------+--------------------------------------------------+
426
427 /// Negotiation states that can send & receive `tx_(add|remove)_(input|output)` and `tx_complete`
428 trait State {}
429
430 /// Category of states where we have sent some message to the counterparty, and we are waiting for
431 /// a response.
432 trait SentMsgState: State {
433         fn into_negotiation_context(self) -> NegotiationContext;
434 }
435
436 /// Category of states that our counterparty has put us in after we receive a message from them.
437 trait ReceivedMsgState: State {
438         fn into_negotiation_context(self) -> NegotiationContext;
439 }
440
// Helper that declares a state struct and implements the state traits above for it. It is used
// for every state defined below the macro.
//
// The tagged arms must stay ahead of the generic one: a tag such as `SENT_MSG_STATE` is itself an
// identifier and would otherwise be captured by the generic arm.
macro_rules! define_state {
        // A state wrapping a `NegotiationContext`, entered after we sent a message.
        (SENT_MSG_STATE, $name: ident, $docs: expr) => {
                define_state!($name, NegotiationContext, $docs);
                impl SentMsgState for $name {
                        fn into_negotiation_context(self) -> NegotiationContext {
                                let Self(context) = self;
                                context
                        }
                }
        };
        // A state wrapping a `NegotiationContext`, entered after we received a message.
        (RECEIVED_MSG_STATE, $name: ident, $docs: expr) => {
                define_state!($name, NegotiationContext, $docs);
                impl ReceivedMsgState for $name {
                        fn into_negotiation_context(self) -> NegotiationContext {
                                let Self(context) = self;
                                context
                        }
                }
        };
        // Generic form: a documented newtype around `$wrapped` that implements `State`.
        ($name: ident, $wrapped: ident, $docs: expr) => {
                #[doc = $docs]
                #[derive(Debug)]
                struct $name($wrapped);
                impl State for $name {}
        };
}
467
468 define_state!(
469         SENT_MSG_STATE,
470         SentChangeMsg,
471         "We have sent a message to the counterparty that has affected our negotiation state."
472 );
473 define_state!(
474         SENT_MSG_STATE,
475         SentTxComplete,
476         "We have sent a `tx_complete` message and are awaiting the counterparty's."
477 );
478 define_state!(
479         RECEIVED_MSG_STATE,
480         ReceivedChangeMsg,
481         "We have received a message from the counterparty that has affected our negotiation state."
482 );
483 define_state!(
484         RECEIVED_MSG_STATE,
485         ReceivedTxComplete,
486         "We have received a `tx_complete` message and the counterparty is awaiting ours."
487 );
488 define_state!(NegotiationComplete, Transaction, "We have exchanged consecutive `tx_complete` messages with the counterparty and the transaction negotiation is complete.");
489 define_state!(
490         NegotiationAborted,
491         AbortReason,
492         "The negotiation has failed and cannot be continued."
493 );
494
495 type StateTransitionResult<S> = Result<S, AbortReason>;
496
497 trait StateTransition<NewState: State, TransitionData> {
498         fn transition(self, data: TransitionData) -> StateTransitionResult<NewState>;
499 }
500
// This macro helps define the legal transitions between the states above by implementing
// the `StateTransition` trait for each of the states that follow this declaration.
//
// - `SENT_MSG_STATE`: for every `SentMsgState`, receiving a change message validates it through
//   the named `NegotiationContext` method and moves to `ReceivedChangeMsg`, or aborts.
// - `RECEIVED_MSG_STATE`: for every `ReceivedMsgState`, sending a change message records it
//   through the named `NegotiationContext` method and moves to `SentChangeMsg`.
// - `TX_COMPLETE`: a `tx_complete` moves `$from_state` to `$tx_complete_state`; a further
//   `tx_complete` from there builds the transaction and moves to `NegotiationComplete`.
macro_rules! define_state_transitions {
        (SENT_MSG_STATE, [$(DATA $data: ty, TRANSITION $transition: ident),+]) => {
                $(
                        impl<S: SentMsgState> StateTransition<ReceivedChangeMsg, $data> for S {
                                fn transition(self, data: $data) -> StateTransitionResult<ReceivedChangeMsg> {
                                        let mut context = self.into_negotiation_context();
                                        context.$transition(data)?;
                                        Ok(ReceivedChangeMsg(context))
                                }
                        }
                )+
        };
        (RECEIVED_MSG_STATE, [$(DATA $data: ty, TRANSITION $transition: ident),+]) => {
                $(
                        impl<S: ReceivedMsgState> StateTransition<SentChangeMsg, $data> for S {
                                fn transition(self, data: $data) -> StateTransitionResult<SentChangeMsg> {
                                        let mut context = self.into_negotiation_context();
                                        // Our own messages are recorded without validation, so this cannot fail.
                                        context.$transition(data);
                                        Ok(SentChangeMsg(context))
                                }
                        }
                )+
        };
        (TX_COMPLETE, $from_state: ident, $tx_complete_state: ident) => {
                impl StateTransition<NegotiationComplete, &msgs::TxComplete> for $tx_complete_state {
                        fn transition(self, _data: &msgs::TxComplete) -> StateTransitionResult<NegotiationComplete> {
                                let context = self.into_negotiation_context();
                                let tx = context.build_transaction()?;
                                Ok(NegotiationComplete(tx))
                        }
                }

                impl StateTransition<$tx_complete_state, &msgs::TxComplete> for $from_state {
                        fn transition(self, _data: &msgs::TxComplete) -> StateTransitionResult<$tx_complete_state> {
                                Ok($tx_complete_state(self.into_negotiation_context()))
                        }
                }
        };
}
542
543 // State transitions when we have sent our counterparty some messages and are waiting for them
544 // to respond.
545 define_state_transitions!(SENT_MSG_STATE, [
546         DATA &msgs::TxAddInput, TRANSITION received_tx_add_input,
547         DATA &msgs::TxRemoveInput, TRANSITION received_tx_remove_input,
548         DATA &msgs::TxAddOutput, TRANSITION received_tx_add_output,
549         DATA &msgs::TxRemoveOutput, TRANSITION received_tx_remove_output
550 ]);
551 // State transitions when we have received some messages from our counterparty and we should
552 // respond.
553 define_state_transitions!(RECEIVED_MSG_STATE, [
554         DATA &msgs::TxAddInput, TRANSITION sent_tx_add_input,
555         DATA &msgs::TxRemoveInput, TRANSITION sent_tx_remove_input,
556         DATA &msgs::TxAddOutput, TRANSITION sent_tx_add_output,
557         DATA &msgs::TxRemoveOutput, TRANSITION sent_tx_remove_output
558 ]);
559 define_state_transitions!(TX_COMPLETE, SentChangeMsg, ReceivedTxComplete);
560 define_state_transitions!(TX_COMPLETE, ReceivedChangeMsg, SentTxComplete);
561
562 #[derive(Debug)]
563 enum StateMachine {
564         Indeterminate,
565         SentChangeMsg(SentChangeMsg),
566         ReceivedChangeMsg(ReceivedChangeMsg),
567         SentTxComplete(SentTxComplete),
568         ReceivedTxComplete(ReceivedTxComplete),
569         NegotiationComplete(NegotiationComplete),
570         NegotiationAborted(NegotiationAborted),
571 }
572
573 impl Default for StateMachine {
574         fn default() -> Self {
575                 Self::Indeterminate
576         }
577 }
578
// The `StateMachine` internally executes the actual transition between two states and keeps
// track of the current state. This macro defines _how_ those state transitions happen to
// update the internal state.
//
// It generates a method `$transition(self, msg)` that, for each listed `FROM`/`TO` pair, runs
// the typed transition and wraps the result in the `TO` variant. A failed transition, or a call
// made while in any state that is not listed, yields `NegotiationAborted`.
macro_rules! define_state_machine_transitions {
        ($transition: ident, $msg: ty, [$(FROM $from_state: ident, TO $to_state: ident),+]) => {
                fn $transition(self, msg: $msg) -> StateMachine {
                        match self {
                                $(
                                        Self::$from_state(s) => match s.transition(msg) {
                                                Ok(new_state) => StateMachine::$to_state(new_state),
                                                Err(abort_reason) => StateMachine::NegotiationAborted(NegotiationAborted(abort_reason)),
                                        },
                                )+
                                _ => StateMachine::NegotiationAborted(NegotiationAborted(AbortReason::UnexpectedCounterpartyMessage)),
                        }
                }
        };
}
597
598 impl StateMachine {
599         fn new(
600                 feerate_sat_per_kw: u32, is_initiator: bool, tx_locktime: AbsoluteLockTime,
601                 to_remote_value_satoshis: u64,
602         ) -> Self {
603                 let context = NegotiationContext {
604                         tx_locktime,
605                         holder_is_initiator: is_initiator,
606                         received_tx_add_input_count: 0,
607                         received_tx_add_output_count: 0,
608                         inputs: new_hash_map(),
609                         prevtx_outpoints: new_hash_set(),
610                         outputs: new_hash_map(),
611                         feerate_sat_per_kw,
612                         to_remote_value_satoshis,
613                 };
614                 if is_initiator {
615                         Self::ReceivedChangeMsg(ReceivedChangeMsg(context))
616                 } else {
617                         Self::SentChangeMsg(SentChangeMsg(context))
618                 }
619         }
620
621         // TxAddInput
622         define_state_machine_transitions!(sent_tx_add_input, &msgs::TxAddInput, [
623                 FROM ReceivedChangeMsg, TO SentChangeMsg,
624                 FROM ReceivedTxComplete, TO SentChangeMsg
625         ]);
626         define_state_machine_transitions!(received_tx_add_input, &msgs::TxAddInput, [
627                 FROM SentChangeMsg, TO ReceivedChangeMsg,
628                 FROM SentTxComplete, TO ReceivedChangeMsg
629         ]);
630
631         // TxAddOutput
632         define_state_machine_transitions!(sent_tx_add_output, &msgs::TxAddOutput, [
633                 FROM ReceivedChangeMsg, TO SentChangeMsg,
634                 FROM ReceivedTxComplete, TO SentChangeMsg
635         ]);
636         define_state_machine_transitions!(received_tx_add_output, &msgs::TxAddOutput, [
637                 FROM SentChangeMsg, TO ReceivedChangeMsg,
638                 FROM SentTxComplete, TO ReceivedChangeMsg
639         ]);
640
641         // TxRemoveInput
642         define_state_machine_transitions!(sent_tx_remove_input, &msgs::TxRemoveInput, [
643                 FROM ReceivedChangeMsg, TO SentChangeMsg,
644                 FROM ReceivedTxComplete, TO SentChangeMsg
645         ]);
646         define_state_machine_transitions!(received_tx_remove_input, &msgs::TxRemoveInput, [
647                 FROM SentChangeMsg, TO ReceivedChangeMsg,
648                 FROM SentTxComplete, TO ReceivedChangeMsg
649         ]);
650
651         // TxRemoveOutput
652         define_state_machine_transitions!(sent_tx_remove_output, &msgs::TxRemoveOutput, [
653                 FROM ReceivedChangeMsg, TO SentChangeMsg,
654                 FROM ReceivedTxComplete, TO SentChangeMsg
655         ]);
656         define_state_machine_transitions!(received_tx_remove_output, &msgs::TxRemoveOutput, [
657                 FROM SentChangeMsg, TO ReceivedChangeMsg,
658                 FROM SentTxComplete, TO ReceivedChangeMsg
659         ]);
660
661         // TxComplete
662         define_state_machine_transitions!(sent_tx_complete, &msgs::TxComplete, [
663                 FROM ReceivedChangeMsg, TO SentTxComplete,
664                 FROM ReceivedTxComplete, TO NegotiationComplete
665         ]);
666         define_state_machine_transitions!(received_tx_complete, &msgs::TxComplete, [
667                 FROM SentChangeMsg, TO ReceivedTxComplete,
668                 FROM SentTxComplete, TO NegotiationComplete
669         ]);
670 }
671
672 pub struct InteractiveTxConstructor {
673         state_machine: StateMachine,
674         channel_id: ChannelId,
675         inputs_to_contribute: Vec<(SerialId, TxIn, TransactionU16LenLimited)>,
676         outputs_to_contribute: Vec<(SerialId, TxOut)>,
677 }
678
679 pub enum InteractiveTxMessageSend {
680         TxAddInput(msgs::TxAddInput),
681         TxAddOutput(msgs::TxAddOutput),
682         TxComplete(msgs::TxComplete),
683 }
684
// This macro executes a state machine transition based on a provided action.
//
// `$self` must expose a `state_machine` field. The current state is moved out with
// `core::mem::take` (leaving the `Default` state behind for the duration of the call), the
// `$transition` method is invoked on it with `$msg`, and the returned state is stored back.
//
// The macro evaluates to `Err(reason)` with a clone of the abort reason when the stored state is
// `StateMachine::NegotiationAborted`, and to `Ok(())` otherwise.
macro_rules! do_state_transition {
        ($self: ident, $transition: ident, $msg: expr) => {{
                $self.state_machine = core::mem::take(&mut $self.state_machine).$transition($msg);
                if let StateMachine::NegotiationAborted(aborted) = &$self.state_machine {
                        Err(aborted.0.clone())
                } else {
                        Ok(())
                }
        }};
}
696
697 fn generate_holder_serial_id<ES: Deref>(entropy_source: &ES, is_initiator: bool) -> SerialId
698 where
699         ES::Target: EntropySource,
700 {
701         let rand_bytes = entropy_source.get_secure_random_bytes();
702         let mut serial_id_bytes = [0u8; 8];
703         serial_id_bytes.copy_from_slice(&rand_bytes[..8]);
704         let mut serial_id = u64::from_be_bytes(serial_id_bytes);
705         if serial_id.is_for_initiator() != is_initiator {
706                 serial_id ^= 1;
707         }
708         serial_id
709 }
710
711 pub enum HandleTxCompleteValue {
712         SendTxMessage(InteractiveTxMessageSend),
713         SendTxComplete(InteractiveTxMessageSend, Transaction),
714         NegotiationComplete(Transaction),
715 }
716
717 impl InteractiveTxConstructor {
718         /// Instantiates a new `InteractiveTxConstructor`.
719         ///
720         /// If this is for a dual_funded channel then the `to_remote_value_satoshis` parameter should be set
721         /// to zero.
722         ///
723         /// A tuple is returned containing the newly instantiate `InteractiveTxConstructor` and optionally
724         /// an initial wrapped `Tx_` message which the holder needs to send to the counterparty.
725         pub fn new<ES: Deref>(
726                 entropy_source: &ES, channel_id: ChannelId, feerate_sat_per_kw: u32, is_initiator: bool,
727                 funding_tx_locktime: AbsoluteLockTime,
728                 inputs_to_contribute: Vec<(TxIn, TransactionU16LenLimited)>,
729                 outputs_to_contribute: Vec<TxOut>, to_remote_value_satoshis: u64,
730         ) -> (Self, Option<InteractiveTxMessageSend>)
731         where
732                 ES::Target: EntropySource,
733         {
734                 let state_machine = StateMachine::new(
735                         feerate_sat_per_kw,
736                         is_initiator,
737                         funding_tx_locktime,
738                         to_remote_value_satoshis,
739                 );
740                 let mut inputs_to_contribute: Vec<(SerialId, TxIn, TransactionU16LenLimited)> =
741                         inputs_to_contribute
742                                 .into_iter()
743                                 .map(|(input, tx)| {
744                                         let serial_id = generate_holder_serial_id(entropy_source, is_initiator);
745                                         (serial_id, input, tx)
746                                 })
747                                 .collect();
748                 // We'll sort by the randomly generated serial IDs, effectively shuffling the order of the inputs
749                 // as the user passed them to us to avoid leaking any potential categorization of transactions
750                 // before we pass any of the inputs to the counterparty.
751                 inputs_to_contribute.sort_unstable_by_key(|(serial_id, _, _)| *serial_id);
752                 let mut outputs_to_contribute: Vec<(SerialId, TxOut)> = outputs_to_contribute
753                         .into_iter()
754                         .map(|output| {
755                                 let serial_id = generate_holder_serial_id(entropy_source, is_initiator);
756                                 (serial_id, output)
757                         })
758                         .collect();
759                 // In the same manner and for the same rationale as the inputs above, we'll shuffle the outputs.
760                 outputs_to_contribute.sort_unstable_by_key(|(serial_id, _)| *serial_id);
761                 let mut constructor =
762                         Self { state_machine, channel_id, inputs_to_contribute, outputs_to_contribute };
763                 let message_send = if is_initiator {
764                         match constructor.maybe_send_message() {
765                                 Ok(msg_send) => Some(msg_send),
766                                 Err(_) => {
767                                         debug_assert!(
768                                                 false,
769                                                 "We should always be able to start our state machine successfully"
770                                         );
771                                         None
772                                 },
773                         }
774                 } else {
775                         None
776                 };
777                 (constructor, message_send)
778         }
779
780         fn maybe_send_message(&mut self) -> Result<InteractiveTxMessageSend, AbortReason> {
781                 // We first attempt to send inputs we want to add, then outputs. Once we are done sending
782                 // them both, then we always send tx_complete.
783                 if let Some((serial_id, input, prevtx)) = self.inputs_to_contribute.pop() {
784                         let msg = msgs::TxAddInput {
785                                 channel_id: self.channel_id,
786                                 serial_id,
787                                 prevtx,
788                                 prevtx_out: input.previous_output.vout,
789                                 sequence: input.sequence.to_consensus_u32(),
790                         };
791                         do_state_transition!(self, sent_tx_add_input, &msg)?;
792                         Ok(InteractiveTxMessageSend::TxAddInput(msg))
793                 } else if let Some((serial_id, output)) = self.outputs_to_contribute.pop() {
794                         let msg = msgs::TxAddOutput {
795                                 channel_id: self.channel_id,
796                                 serial_id,
797                                 sats: output.value,
798                                 script: output.script_pubkey,
799                         };
800                         do_state_transition!(self, sent_tx_add_output, &msg)?;
801                         Ok(InteractiveTxMessageSend::TxAddOutput(msg))
802                 } else {
803                         let msg = msgs::TxComplete { channel_id: self.channel_id };
804                         do_state_transition!(self, sent_tx_complete, &msg)?;
805                         Ok(InteractiveTxMessageSend::TxComplete(msg))
806                 }
807         }
808
809         pub fn handle_tx_add_input(
810                 &mut self, msg: &msgs::TxAddInput,
811         ) -> Result<InteractiveTxMessageSend, AbortReason> {
812                 do_state_transition!(self, received_tx_add_input, msg)?;
813                 self.maybe_send_message()
814         }
815
816         pub fn handle_tx_remove_input(
817                 &mut self, msg: &msgs::TxRemoveInput,
818         ) -> Result<InteractiveTxMessageSend, AbortReason> {
819                 do_state_transition!(self, received_tx_remove_input, msg)?;
820                 self.maybe_send_message()
821         }
822
823         pub fn handle_tx_add_output(
824                 &mut self, msg: &msgs::TxAddOutput,
825         ) -> Result<InteractiveTxMessageSend, AbortReason> {
826                 do_state_transition!(self, received_tx_add_output, msg)?;
827                 self.maybe_send_message()
828         }
829
830         pub fn handle_tx_remove_output(
831                 &mut self, msg: &msgs::TxRemoveOutput,
832         ) -> Result<InteractiveTxMessageSend, AbortReason> {
833                 do_state_transition!(self, received_tx_remove_output, msg)?;
834                 self.maybe_send_message()
835         }
836
837         pub fn handle_tx_complete(
838                 &mut self, msg: &msgs::TxComplete,
839         ) -> Result<HandleTxCompleteValue, AbortReason> {
840                 do_state_transition!(self, received_tx_complete, msg)?;
841                 match &self.state_machine {
842                         StateMachine::ReceivedTxComplete(_) => {
843                                 let msg_send = self.maybe_send_message()?;
844                                 return match &self.state_machine {
845                                         StateMachine::NegotiationComplete(s) => {
846                                                 Ok(HandleTxCompleteValue::SendTxComplete(msg_send, s.0.clone()))
847                                         },
848                                         StateMachine::SentChangeMsg(_) => {
849                                                 Ok(HandleTxCompleteValue::SendTxMessage(msg_send))
850                                         }, // We either had an input or output to contribute.
851                                         _ => {
852                                                 debug_assert!(false, "We cannot transition to any other states after receiving `tx_complete` and responding");
853                                                 return Err(AbortReason::InvalidStateTransition);
854                                         },
855                                 };
856                         },
857                         StateMachine::NegotiationComplete(s) => {
858                                 Ok(HandleTxCompleteValue::NegotiationComplete(s.0.clone()))
859                         },
860                         _ => {
861                                 debug_assert!(
862                                         false,
863                                         "We cannot transition to any other states after receiving `tx_complete`"
864                                 );
865                                 Err(AbortReason::InvalidStateTransition)
866                         },
867                 }
868         }
869 }
870
871 #[cfg(test)]
872 mod tests {
873         use crate::chain::chaininterface::FEERATE_FLOOR_SATS_PER_KW;
874         use crate::ln::channel::TOTAL_BITCOIN_SUPPLY_SATOSHIS;
875         use crate::ln::interactivetxs::{
876                 generate_holder_serial_id, AbortReason, HandleTxCompleteValue, InteractiveTxConstructor,
877                 InteractiveTxMessageSend, MAX_INPUTS_OUTPUTS_COUNT, MAX_RECEIVED_TX_ADD_INPUT_COUNT,
878                 MAX_RECEIVED_TX_ADD_OUTPUT_COUNT,
879         };
880         use crate::ln::ChannelId;
881         use crate::sign::EntropySource;
882         use crate::util::atomic_counter::AtomicCounter;
883         use crate::util::ser::TransactionU16LenLimited;
884         use bitcoin::blockdata::opcodes;
885         use bitcoin::blockdata::script::Builder;
886         use bitcoin::{
887                 absolute::LockTime as AbsoluteLockTime, OutPoint, Sequence, Transaction, TxIn, TxOut,
888         };
889         use core::ops::Deref;
890
891         // A simple entropy source that works based on an atomic counter.
892         struct TestEntropySource(AtomicCounter);
893         impl EntropySource for TestEntropySource {
894                 fn get_secure_random_bytes(&self) -> [u8; 32] {
895                         let mut res = [0u8; 32];
896                         let increment = self.0.get_increment();
897                         for i in 0..32 {
898                                 // Rotate the increment value by 'i' bits to the right, to avoid clashes
899                                 // when `generate_local_serial_id` does a parity flip on consecutive calls for the
900                                 // same party.
901                                 let rotated_increment = increment.rotate_right(i as u32);
902                                 res[i] = (rotated_increment & 0xff) as u8;
903                         }
904                         res
905                 }
906         }
907
908         // An entropy source that deliberately returns you the same seed every time. We use this
909         // to test if the constructor would catch inputs/outputs that are attempting to be added
910         // with duplicate serial ids.
911         struct DuplicateEntropySource;
912         impl EntropySource for DuplicateEntropySource {
913                 fn get_secure_random_bytes(&self) -> [u8; 32] {
914                         let mut res = [0u8; 32];
915                         let count = 1u64;
916                         res[0..8].copy_from_slice(&count.to_be_bytes());
917                         res
918                 }
919         }
920
921         #[derive(Debug, PartialEq, Eq)]
922         enum ErrorCulprit {
923                 NodeA,
924                 NodeB,
925                 // Some error values are only checked at the end of the negotiation and are not easy to attribute
926                 // to a particular party. Both parties would indicate an `AbortReason` in this case.
927                 // e.g. Exceeded max inputs and outputs after negotiation.
928                 Indeterminate,
929         }
930
931         struct TestSession {
932                 inputs_a: Vec<(TxIn, TransactionU16LenLimited)>,
933                 outputs_a: Vec<TxOut>,
934                 inputs_b: Vec<(TxIn, TransactionU16LenLimited)>,
935                 outputs_b: Vec<TxOut>,
936                 expect_error: Option<(AbortReason, ErrorCulprit)>,
937         }
938
939         fn do_test_interactive_tx_constructor(session: TestSession) {
940                 let entropy_source = TestEntropySource(AtomicCounter::new());
941                 do_test_interactive_tx_constructor_internal(session, &&entropy_source);
942         }
943
944         fn do_test_interactive_tx_constructor_with_entropy_source<ES: Deref>(
945                 session: TestSession, entropy_source: ES,
946         ) where
947                 ES::Target: EntropySource,
948         {
949                 do_test_interactive_tx_constructor_internal(session, &entropy_source);
950         }
951
952         fn do_test_interactive_tx_constructor_internal<ES: Deref>(
953                 session: TestSession, entropy_source: &ES,
954         ) where
955                 ES::Target: EntropySource,
956         {
957                 let channel_id = ChannelId(entropy_source.get_secure_random_bytes());
958                 let tx_locktime = AbsoluteLockTime::from_height(1337).unwrap();
959
960                 let (mut constructor_a, first_message_a) = InteractiveTxConstructor::new(
961                         entropy_source,
962                         channel_id,
963                         FEERATE_FLOOR_SATS_PER_KW * 10,
964                         true,
965                         tx_locktime,
966                         session.inputs_a,
967                         session.outputs_a,
968                         0,
969                 );
970                 let (mut constructor_b, first_message_b) = InteractiveTxConstructor::new(
971                         entropy_source,
972                         channel_id,
973                         FEERATE_FLOOR_SATS_PER_KW * 10,
974                         false,
975                         tx_locktime,
976                         session.inputs_b,
977                         session.outputs_b,
978                         0,
979                 );
980
981                 let handle_message_send =
982                         |msg: InteractiveTxMessageSend, for_constructor: &mut InteractiveTxConstructor| {
983                                 match msg {
984                                         InteractiveTxMessageSend::TxAddInput(msg) => for_constructor
985                                                 .handle_tx_add_input(&msg)
986                                                 .map(|msg_send| (Some(msg_send), None)),
987                                         InteractiveTxMessageSend::TxAddOutput(msg) => for_constructor
988                                                 .handle_tx_add_output(&msg)
989                                                 .map(|msg_send| (Some(msg_send), None)),
990                                         InteractiveTxMessageSend::TxComplete(msg) => {
991                                                 for_constructor.handle_tx_complete(&msg).map(|value| match value {
992                                                         HandleTxCompleteValue::SendTxMessage(msg_send) => {
993                                                                 (Some(msg_send), None)
994                                                         },
995                                                         HandleTxCompleteValue::SendTxComplete(msg_send, tx) => {
996                                                                 (Some(msg_send), Some(tx))
997                                                         },
998                                                         HandleTxCompleteValue::NegotiationComplete(tx) => (None, Some(tx)),
999                                                 })
1000                                         },
1001                                 }
1002                         };
1003
1004                 assert!(first_message_b.is_none());
1005                 let mut message_send_a = first_message_a;
1006                 let mut message_send_b = None;
1007                 let mut final_tx_a = None;
1008                 let mut final_tx_b = None;
1009                 while final_tx_a.is_none() || final_tx_b.is_none() {
1010                         if let Some(message_send_a) = message_send_a.take() {
1011                                 match handle_message_send(message_send_a, &mut constructor_b) {
1012                                         Ok((msg_send, final_tx)) => {
1013                                                 message_send_b = msg_send;
1014                                                 final_tx_b = final_tx;
1015                                         },
1016                                         Err(abort_reason) => {
1017                                                 let error_culprit = match abort_reason {
1018                                                         AbortReason::ExceededNumberOfInputsOrOutputs => {
1019                                                                 ErrorCulprit::Indeterminate
1020                                                         },
1021                                                         _ => ErrorCulprit::NodeA,
1022                                                 };
1023                                                 assert_eq!(Some((abort_reason, error_culprit)), session.expect_error);
1024                                                 assert!(message_send_b.is_none());
1025                                                 return;
1026                                         },
1027                                 }
1028                         }
1029                         if let Some(message_send_b) = message_send_b.take() {
1030                                 match handle_message_send(message_send_b, &mut constructor_a) {
1031                                         Ok((msg_send, final_tx)) => {
1032                                                 message_send_a = msg_send;
1033                                                 final_tx_a = final_tx;
1034                                         },
1035                                         Err(abort_reason) => {
1036                                                 let error_culprit = match abort_reason {
1037                                                         AbortReason::ExceededNumberOfInputsOrOutputs => {
1038                                                                 ErrorCulprit::Indeterminate
1039                                                         },
1040                                                         _ => ErrorCulprit::NodeB,
1041                                                 };
1042                                                 assert_eq!(Some((abort_reason, error_culprit)), session.expect_error);
1043                                                 assert!(message_send_a.is_none());
1044                                                 return;
1045                                         },
1046                                 }
1047                         }
1048                 }
1049                 assert!(message_send_a.is_none());
1050                 assert!(message_send_b.is_none());
1051                 assert_eq!(final_tx_a, final_tx_b);
1052                 assert!(session.expect_error.is_none());
1053         }
1054
1055         fn generate_tx(values: &[u64]) -> Transaction {
1056                 generate_tx_with_locktime(values, 1337)
1057         }
1058
1059         fn generate_tx_with_locktime(values: &[u64], locktime: u32) -> Transaction {
1060                 Transaction {
1061                         version: 2,
1062                         lock_time: AbsoluteLockTime::from_height(locktime).unwrap(),
1063                         input: vec![TxIn { ..Default::default() }],
1064                         output: values
1065                                 .iter()
1066                                 .map(|value| TxOut {
1067                                         value: *value,
1068                                         script_pubkey: Builder::new()
1069                                                 .push_opcode(opcodes::OP_TRUE)
1070                                                 .into_script()
1071                                                 .to_v0_p2wsh(),
1072                                 })
1073                                 .collect(),
1074                 }
1075         }
1076
1077         fn generate_inputs(values: &[u64]) -> Vec<(TxIn, TransactionU16LenLimited)> {
1078                 let tx = generate_tx(values);
1079                 let txid = tx.txid();
1080                 tx.output
1081                         .iter()
1082                         .enumerate()
1083                         .map(|(idx, _)| {
1084                                 let input = TxIn {
1085                                         previous_output: OutPoint { txid, vout: idx as u32 },
1086                                         script_sig: Default::default(),
1087                                         sequence: Sequence::ENABLE_RBF_NO_LOCKTIME,
1088                                         witness: Default::default(),
1089                                 };
1090                                 (input, TransactionU16LenLimited::new(tx.clone()).unwrap())
1091                         })
1092                         .collect()
1093         }
1094
1095         fn generate_outputs(values: &[u64]) -> Vec<TxOut> {
1096                 values
1097                         .iter()
1098                         .map(|value| TxOut {
1099                                 value: *value,
1100                                 script_pubkey: Builder::new()
1101                                         .push_opcode(opcodes::OP_TRUE)
1102                                         .into_script()
1103                                         .to_v0_p2wsh(),
1104                         })
1105                         .collect()
1106         }
1107
1108         fn generate_fixed_number_of_inputs(count: u16) -> Vec<(TxIn, TransactionU16LenLimited)> {
1109                 // Generate transactions with a total `count` number of outputs such that no transaction has a
1110                 // serialized length greater than u16::MAX.
1111                 let max_outputs_per_prevtx = 1_500;
1112                 let mut remaining = count;
1113                 let mut inputs: Vec<(TxIn, TransactionU16LenLimited)> = Vec::with_capacity(count as usize);
1114
1115                 while remaining > 0 {
1116                         let tx_output_count = remaining.min(max_outputs_per_prevtx);
1117                         remaining -= tx_output_count;
1118
1119                         // Use unique locktime for each tx so outpoints are different across transactions
1120                         let tx = generate_tx_with_locktime(
1121                                 &vec![1_000_000; tx_output_count as usize],
1122                                 (1337 + remaining).into(),
1123                         );
1124                         let txid = tx.txid();
1125
1126                         let mut temp: Vec<(TxIn, TransactionU16LenLimited)> = tx
1127                                 .output
1128                                 .iter()
1129                                 .enumerate()
1130                                 .map(|(idx, _)| {
1131                                         let input = TxIn {
1132                                                 previous_output: OutPoint { txid, vout: idx as u32 },
1133                                                 script_sig: Default::default(),
1134                                                 sequence: Sequence::ENABLE_RBF_NO_LOCKTIME,
1135                                                 witness: Default::default(),
1136                                         };
1137                                         (input, TransactionU16LenLimited::new(tx.clone()).unwrap())
1138                                 })
1139                                 .collect();
1140
1141                         inputs.append(&mut temp);
1142                 }
1143
1144                 inputs
1145         }
1146
1147         fn generate_fixed_number_of_outputs(count: u16) -> Vec<TxOut> {
1148                 // Set a constant value for each TxOut
1149                 generate_outputs(&vec![1_000_000; count as usize])
1150         }
1151
1152         fn generate_non_witness_output(value: u64) -> TxOut {
1153                 TxOut {
1154                         value,
1155                         script_pubkey: Builder::new().push_opcode(opcodes::OP_TRUE).into_script().to_p2sh(),
1156                 }
1157         }
1158
1159         #[test]
1160         fn test_interactive_tx_constructor() {
1161                 // No contributions.
1162                 do_test_interactive_tx_constructor(TestSession {
1163                         inputs_a: vec![],
1164                         outputs_a: vec![],
1165                         inputs_b: vec![],
1166                         outputs_b: vec![],
1167                         expect_error: Some((AbortReason::InsufficientFees, ErrorCulprit::NodeA)),
1168                 });
1169                 // Single contribution, no initiator inputs.
1170                 do_test_interactive_tx_constructor(TestSession {
1171                         inputs_a: vec![],
1172                         outputs_a: generate_outputs(&[1_000_000]),
1173                         inputs_b: vec![],
1174                         outputs_b: vec![],
1175                         expect_error: Some((AbortReason::OutputsValueExceedsInputsValue, ErrorCulprit::NodeA)),
1176                 });
1177                 // Single contribution, no initiator outputs.
1178                 do_test_interactive_tx_constructor(TestSession {
1179                         inputs_a: generate_inputs(&[1_000_000]),
1180                         outputs_a: vec![],
1181                         inputs_b: vec![],
1182                         outputs_b: vec![],
1183                         expect_error: None,
1184                 });
1185                 // Single contribution, insufficient fees.
1186                 do_test_interactive_tx_constructor(TestSession {
1187                         inputs_a: generate_inputs(&[1_000_000]),
1188                         outputs_a: generate_outputs(&[1_000_000]),
1189                         inputs_b: vec![],
1190                         outputs_b: vec![],
1191                         expect_error: Some((AbortReason::InsufficientFees, ErrorCulprit::NodeA)),
1192                 });
1193                 // Initiator contributes sufficient fees, but non-initiator does not.
1194                 do_test_interactive_tx_constructor(TestSession {
1195                         inputs_a: generate_inputs(&[1_000_000]),
1196                         outputs_a: vec![],
1197                         inputs_b: generate_inputs(&[100_000]),
1198                         outputs_b: generate_outputs(&[100_000]),
1199                         expect_error: Some((AbortReason::InsufficientFees, ErrorCulprit::NodeB)),
1200                 });
1201                 // Multi-input-output contributions from both sides.
1202                 do_test_interactive_tx_constructor(TestSession {
1203                         inputs_a: generate_inputs(&[1_000_000, 1_000_000]),
1204                         outputs_a: generate_outputs(&[1_000_000, 200_000]),
1205                         inputs_b: generate_inputs(&[1_000_000, 500_000]),
1206                         outputs_b: generate_outputs(&[1_000_000, 400_000]),
1207                         expect_error: None,
1208                 });
1209
1210                 // Prevout from initiator is not a witness program
1211                 let non_segwit_output_tx = {
1212                         let mut tx = generate_tx(&[1_000_000]);
1213                         tx.output.push(TxOut {
1214                                 script_pubkey: Builder::new()
1215                                         .push_opcode(opcodes::all::OP_RETURN)
1216                                         .into_script()
1217                                         .to_p2sh(),
1218                                 ..Default::default()
1219                         });
1220
1221                         TransactionU16LenLimited::new(tx).unwrap()
1222                 };
1223                 let non_segwit_input = TxIn {
1224                         previous_output: OutPoint {
1225                                 txid: non_segwit_output_tx.as_transaction().txid(),
1226                                 vout: 1,
1227                         },
1228                         sequence: Sequence::ENABLE_RBF_NO_LOCKTIME,
1229                         ..Default::default()
1230                 };
1231                 do_test_interactive_tx_constructor(TestSession {
1232                         inputs_a: vec![(non_segwit_input, non_segwit_output_tx)],
1233                         outputs_a: vec![],
1234                         inputs_b: vec![],
1235                         outputs_b: vec![],
1236                         expect_error: Some((AbortReason::PrevTxOutInvalid, ErrorCulprit::NodeA)),
1237                 });
1238
1239                 // Invalid input sequence from initiator.
1240                 let tx = TransactionU16LenLimited::new(generate_tx(&[1_000_000])).unwrap();
1241                 let invalid_sequence_input = TxIn {
1242                         previous_output: OutPoint { txid: tx.as_transaction().txid(), vout: 0 },
1243                         ..Default::default()
1244                 };
1245                 do_test_interactive_tx_constructor(TestSession {
1246                         inputs_a: vec![(invalid_sequence_input, tx.clone())],
1247                         outputs_a: generate_outputs(&[1_000_000]),
1248                         inputs_b: vec![],
1249                         outputs_b: vec![],
1250                         expect_error: Some((AbortReason::IncorrectInputSequenceValue, ErrorCulprit::NodeA)),
1251                 });
1252                 // Duplicate prevout from initiator.
1253                 let duplicate_input = TxIn {
1254                         previous_output: OutPoint { txid: tx.as_transaction().txid(), vout: 0 },
1255                         sequence: Sequence::ENABLE_RBF_NO_LOCKTIME,
1256                         ..Default::default()
1257                 };
1258                 do_test_interactive_tx_constructor(TestSession {
1259                         inputs_a: vec![(duplicate_input.clone(), tx.clone()), (duplicate_input, tx.clone())],
1260                         outputs_a: generate_outputs(&[1_000_000]),
1261                         inputs_b: vec![],
1262                         outputs_b: vec![],
1263                         expect_error: Some((AbortReason::PrevTxOutInvalid, ErrorCulprit::NodeA)),
1264                 });
1265                 // Non-initiator uses same prevout as initiator.
1266                 let duplicate_input = TxIn {
1267                         previous_output: OutPoint { txid: tx.as_transaction().txid(), vout: 0 },
1268                         sequence: Sequence::ENABLE_RBF_NO_LOCKTIME,
1269                         ..Default::default()
1270                 };
1271                 do_test_interactive_tx_constructor(TestSession {
1272                         inputs_a: vec![(duplicate_input.clone(), tx.clone())],
1273                         outputs_a: generate_outputs(&[1_000_000]),
1274                         inputs_b: vec![(duplicate_input.clone(), tx.clone())],
1275                         outputs_b: vec![],
1276                         expect_error: Some((AbortReason::PrevTxOutInvalid, ErrorCulprit::NodeB)),
1277                 });
1278                 // Initiator sends too many TxAddInputs
1279                 do_test_interactive_tx_constructor(TestSession {
1280                         inputs_a: generate_fixed_number_of_inputs(MAX_RECEIVED_TX_ADD_INPUT_COUNT + 1),
1281                         outputs_a: vec![],
1282                         inputs_b: vec![],
1283                         outputs_b: vec![],
1284                         expect_error: Some((AbortReason::ReceivedTooManyTxAddInputs, ErrorCulprit::NodeA)),
1285                 });
1286                 // Attempt to queue up two inputs with duplicate serial ids. We use a deliberately bad
1287                 // entropy source, `DuplicateEntropySource` to simulate this.
1288                 do_test_interactive_tx_constructor_with_entropy_source(
1289                         TestSession {
1290                                 inputs_a: generate_fixed_number_of_inputs(2),
1291                                 outputs_a: vec![],
1292                                 inputs_b: vec![],
1293                                 outputs_b: vec![],
1294                                 expect_error: Some((AbortReason::DuplicateSerialId, ErrorCulprit::NodeA)),
1295                         },
1296                         &DuplicateEntropySource,
1297                 );
1298                 // Initiator sends too many TxAddOutputs.
1299                 do_test_interactive_tx_constructor(TestSession {
1300                         inputs_a: vec![],
1301                         outputs_a: generate_fixed_number_of_outputs(MAX_RECEIVED_TX_ADD_OUTPUT_COUNT + 1),
1302                         inputs_b: vec![],
1303                         outputs_b: vec![],
1304                         expect_error: Some((AbortReason::ReceivedTooManyTxAddOutputs, ErrorCulprit::NodeA)),
1305                 });
1306                 // Initiator sends an output below dust value.
1307                 do_test_interactive_tx_constructor(TestSession {
1308                         inputs_a: vec![],
1309                         outputs_a: generate_outputs(&[1]),
1310                         inputs_b: vec![],
1311                         outputs_b: vec![],
1312                         expect_error: Some((AbortReason::BelowDustLimit, ErrorCulprit::NodeA)),
1313                 });
1314                 // Initiator sends an output above maximum sats allowed.
1315                 do_test_interactive_tx_constructor(TestSession {
1316                         inputs_a: vec![],
1317                         outputs_a: generate_outputs(&[TOTAL_BITCOIN_SUPPLY_SATOSHIS + 1]),
1318                         inputs_b: vec![],
1319                         outputs_b: vec![],
1320                         expect_error: Some((AbortReason::ExceededMaximumSatsAllowed, ErrorCulprit::NodeA)),
1321                 });
1322                 // Initiator sends an output without a witness program.
1323                 do_test_interactive_tx_constructor(TestSession {
1324                         inputs_a: vec![],
1325                         outputs_a: vec![generate_non_witness_output(1_000_000)],
1326                         inputs_b: vec![],
1327                         outputs_b: vec![],
1328                         expect_error: Some((AbortReason::InvalidOutputScript, ErrorCulprit::NodeA)),
1329                 });
1330                 // Attempt to queue up two outputs with duplicate serial ids. We use a deliberately bad
1331                 // entropy source, `DuplicateEntropySource` to simulate this.
1332                 do_test_interactive_tx_constructor_with_entropy_source(
1333                         TestSession {
1334                                 inputs_a: vec![],
1335                                 outputs_a: generate_fixed_number_of_outputs(2),
1336                                 inputs_b: vec![],
1337                                 outputs_b: vec![],
1338                                 expect_error: Some((AbortReason::DuplicateSerialId, ErrorCulprit::NodeA)),
1339                         },
1340                         &DuplicateEntropySource,
1341                 );
1342
1343                 // Peer contributed more output value than inputs
1344                 do_test_interactive_tx_constructor(TestSession {
1345                         inputs_a: generate_inputs(&[100_000]),
1346                         outputs_a: generate_outputs(&[1_000_000]),
1347                         inputs_b: vec![],
1348                         outputs_b: vec![],
1349                         expect_error: Some((AbortReason::OutputsValueExceedsInputsValue, ErrorCulprit::NodeA)),
1350                 });
1351
1352                 // Peer contributed more than allowed number of inputs.
1353                 do_test_interactive_tx_constructor(TestSession {
1354                         inputs_a: generate_fixed_number_of_inputs(MAX_INPUTS_OUTPUTS_COUNT as u16 + 1),
1355                         outputs_a: vec![],
1356                         inputs_b: vec![],
1357                         outputs_b: vec![],
1358                         expect_error: Some((
1359                                 AbortReason::ExceededNumberOfInputsOrOutputs,
1360                                 ErrorCulprit::Indeterminate,
1361                         )),
1362                 });
1363                 // Peer contributed more than allowed number of outputs.
1364                 do_test_interactive_tx_constructor(TestSession {
1365                         inputs_a: generate_inputs(&[TOTAL_BITCOIN_SUPPLY_SATOSHIS]),
1366                         outputs_a: generate_fixed_number_of_outputs(MAX_INPUTS_OUTPUTS_COUNT as u16 + 1),
1367                         inputs_b: vec![],
1368                         outputs_b: vec![],
1369                         expect_error: Some((
1370                                 AbortReason::ExceededNumberOfInputsOrOutputs,
1371                                 ErrorCulprit::Indeterminate,
1372                         )),
1373                 });
1374         }
1375
1376         #[test]
1377         fn test_generate_local_serial_id() {
1378                 let entropy_source = TestEntropySource(AtomicCounter::new());
1379
1380                 // Initiators should have even serial id, non-initiators should have odd serial id.
1381                 assert_eq!(generate_holder_serial_id(&&entropy_source, true) % 2, 0);
1382                 assert_eq!(generate_holder_serial_id(&&entropy_source, false) % 2, 1)
1383         }
1384 }