a019cf89f230b4ecf7c9e0c7cf5701d252e660ec
[rust-lightning] / lightning / src / ln / interactivetxs.rs
1 // This file is Copyright its original authors, visible in version control
2 // history.
3 //
4 // This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
5 // or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
7 // You may not use this file except in accordance with one or both of these
8 // licenses.
9
10 use crate::io_extras::sink;
11 use crate::prelude::*;
12 use core::ops::Deref;
13
14 use bitcoin::blockdata::constants::WITNESS_SCALE_FACTOR;
15 use bitcoin::consensus::Encodable;
16 use bitcoin::policy::MAX_STANDARD_TX_WEIGHT;
17 use bitcoin::{
18         absolute::LockTime as AbsoluteLockTime, OutPoint, Sequence, Transaction, TxIn, TxOut,
19 };
20
21 use crate::chain::chaininterface::fee_for_weight;
22 use crate::events::bump_transaction::{BASE_INPUT_WEIGHT, EMPTY_SCRIPT_SIG_WEIGHT};
23 use crate::ln::channel::TOTAL_BITCOIN_SUPPLY_SATOSHIS;
24 use crate::ln::msgs::SerialId;
25 use crate::ln::{msgs, ChannelId};
26 use crate::sign::EntropySource;
27 use crate::util::ser::TransactionU16LenLimited;
28
/// The number of received `tx_add_input` messages during a negotiation at which point the
/// negotiation MUST be failed.
///
/// `NegotiationContext::received_tx_add_input` aborts with
/// `AbortReason::ReceivedTooManyTxAddInputs` once its running count exceeds this value.
const MAX_RECEIVED_TX_ADD_INPUT_COUNT: u16 = 4096;

/// The number of received `tx_add_output` messages during a negotiation at which point the
/// negotiation MUST be failed.
///
/// `NegotiationContext::received_tx_add_output` aborts with
/// `AbortReason::ReceivedTooManyTxAddOutputs` once its running count exceeds this value.
const MAX_RECEIVED_TX_ADD_OUTPUT_COUNT: u16 = 4096;

/// The number of inputs or outputs that the state machine can have, before it MUST fail the
/// negotiation.
///
/// The limit is applied to the inputs and to the outputs separately when the transaction is
/// built.
const MAX_INPUTS_OUTPUTS_COUNT: usize = 252;
40
/// Extension methods for deciding which negotiating party a `SerialId` belongs to.
trait SerialIdExt {
        /// Returns `true` if this serial ID is one that the initiator may use.
        fn is_for_initiator(&self) -> bool;
        /// Returns `true` if this serial ID is one that the non-initiator may use.
        fn is_for_non_initiator(&self) -> bool;
}
45
46 impl SerialIdExt for SerialId {
47         fn is_for_initiator(&self) -> bool {
48                 self % 2 == 0
49         }
50
51         fn is_for_non_initiator(&self) -> bool {
52                 !self.is_for_initiator()
53         }
54 }
55
/// The reason an interactive transaction negotiation was aborted.
#[derive(Debug, Clone, PartialEq)]
pub enum AbortReason {
        // NOTE(review): not produced by any code visible in this part of the file; presumably
        // reported when the holder attempts a transition its current state does not allow.
        InvalidStateTransition,
        /// A message arrived while the state machine was in a state with no transition for it.
        UnexpectedCounterpartyMessage,
        /// The count of received `tx_add_input` messages exceeded
        /// `MAX_RECEIVED_TX_ADD_INPUT_COUNT`.
        ReceivedTooManyTxAddInputs,
        /// The count of received `tx_add_output` messages exceeded
        /// `MAX_RECEIVED_TX_ADD_OUTPUT_COUNT`.
        ReceivedTooManyTxAddOutputs,
        /// A received input had its `sequence` set to `0xFFFFFFFE` or `0xFFFFFFFF`.
        IncorrectInputSequenceValue,
        /// A received `serial_id` had the parity reserved for the holder.
        IncorrectSerialIdParity,
        /// A received removal referenced a `serial_id` that is not currently in the transaction.
        SerialIdUnknown,
        /// A received addition reused a `serial_id` that is already in the transaction.
        DuplicateSerialId,
        /// The referenced previous output does not exist on `prevtx`, is not a witness program,
        /// or is already spent by another input of the transaction.
        PrevTxOutInvalid,
        /// Adding a received output would push the total output value above
        /// `TOTAL_BITCOIN_SUPPLY_SATOSHIS`.
        ExceededMaximumSatsAllowed,
        /// The transaction has more than `MAX_INPUTS_OUTPUTS_COUNT` inputs or outputs.
        ExceededNumberOfInputsOrOutputs,
        /// The weight of the built transaction exceeds `MAX_STANDARD_TX_WEIGHT`.
        TransactionTooLarge,
        /// A received output's value is below the dust value of its script.
        BelowDustLimit,
        /// A received output's script is neither P2WPKH, P2WSH, nor a witness program of
        /// version 1 or higher.
        InvalidOutputScript,
        /// The counterparty's inputs do not leave enough over their outputs to pay the fee
        /// required for their contribution.
        InsufficientFees,
        /// The counterparty's outputs are worth more than the counterparty's inputs.
        OutputsValueExceedsInputsValue,
        // NOTE(review): not produced by any code visible in this part of the file.
        InvalidTx,
}
76
/// A transaction input together with the previous output that it spends.
#[derive(Debug)]
pub struct TxInputWithPrevOutput {
        // The input as it will appear in the negotiated transaction.
        input: TxIn,
        // The output of the previous transaction that `input` spends; its value is used for the
        // value and fee checks when the transaction is built.
        prev_output: TxOut,
}
82
/// The mutable data shared by every state of a negotiation: the inputs and outputs added so far
/// plus the parameters the final transaction is checked against.
#[derive(Debug)]
struct NegotiationContext {
        // Whether the holder is the initiator; decides which serial ID parity the counterparty
        // must use.
        holder_is_initiator: bool,
        // Running count of `tx_add_input` messages received with a valid serial ID parity.
        received_tx_add_input_count: u16,
        // Running count of `tx_add_output` messages received with a valid serial ID parity.
        received_tx_add_output_count: u16,
        // Inputs added by either party, keyed by serial ID.
        inputs: HashMap<SerialId, TxInputWithPrevOutput>,
        // Outpoints recorded as inputs are added; used to reject a second input that spends the
        // same outpoint.
        prevtx_outpoints: HashSet<OutPoint>,
        // Outputs added by either party, keyed by serial ID.
        outputs: HashMap<SerialId, TxOut>,
        // Lock time given to the built transaction.
        tx_locktime: AbsoluteLockTime,
        // Feerate passed to `fee_for_weight` when checking the counterparty's fee contribution.
        feerate_sat_per_kw: u32,
}
94
95 impl NegotiationContext {
96         fn is_serial_id_valid_for_counterparty(&self, serial_id: &SerialId) -> bool {
97                 // A received `SerialId`'s parity must match the role of the counterparty.
98                 self.holder_is_initiator == serial_id.is_for_non_initiator()
99         }
100
101         fn total_input_and_output_count(&self) -> usize {
102                 self.inputs.len().saturating_add(self.outputs.len())
103         }
104
105         fn counterparty_inputs_contributed(
106                 &self,
107         ) -> impl Iterator<Item = &TxInputWithPrevOutput> + Clone {
108                 self.inputs
109                         .iter()
110                         .filter(move |(serial_id, _)| self.is_serial_id_valid_for_counterparty(serial_id))
111                         .map(|(_, input_with_prevout)| input_with_prevout)
112         }
113
114         fn counterparty_outputs_contributed(&self) -> impl Iterator<Item = &TxOut> + Clone {
115                 self.outputs
116                         .iter()
117                         .filter(move |(serial_id, _)| self.is_serial_id_valid_for_counterparty(serial_id))
118                         .map(|(_, output)| output)
119         }
120
121         fn received_tx_add_input(&mut self, msg: &msgs::TxAddInput) -> Result<(), AbortReason> {
122                 // The interactive-txs spec calls for us to fail negotiation if the `prevtx` we receive is
123                 // invalid. However, we would not need to account for this explicit negotiation failure
124                 // mode here since `PeerManager` would already disconnect the peer if the `prevtx` is
125                 // invalid; implicitly ending the negotiation.
126
127                 if !self.is_serial_id_valid_for_counterparty(&msg.serial_id) {
128                         // The receiving node:
129                         //  - MUST fail the negotiation if:
130                         //     - the `serial_id` has the wrong parity
131                         return Err(AbortReason::IncorrectSerialIdParity);
132                 }
133
134                 self.received_tx_add_input_count += 1;
135                 if self.received_tx_add_input_count > MAX_RECEIVED_TX_ADD_INPUT_COUNT {
136                         // The receiving node:
137                         //  - MUST fail the negotiation if:
138                         //     - if has received 4096 `tx_add_input` messages during this negotiation
139                         return Err(AbortReason::ReceivedTooManyTxAddInputs);
140                 }
141
142                 if msg.sequence >= 0xFFFFFFFE {
143                         // The receiving node:
144                         //  - MUST fail the negotiation if:
145                         //    - `sequence` is set to `0xFFFFFFFE` or `0xFFFFFFFF`
146                         return Err(AbortReason::IncorrectInputSequenceValue);
147                 }
148
149                 let transaction = msg.prevtx.as_transaction();
150                 let txid = transaction.txid();
151
152                 if let Some(tx_out) = transaction.output.get(msg.prevtx_out as usize) {
153                         if !tx_out.script_pubkey.is_witness_program() {
154                                 // The receiving node:
155                                 //  - MUST fail the negotiation if:
156                                 //     - the `scriptPubKey` is not a witness program
157                                 return Err(AbortReason::PrevTxOutInvalid);
158                         }
159
160                         if !self.prevtx_outpoints.insert(OutPoint { txid, vout: msg.prevtx_out }) {
161                                 // The receiving node:
162                                 //  - MUST fail the negotiation if:
163                                 //     - the `prevtx` and `prevtx_vout` are identical to a previously added
164                                 //       (and not removed) input's
165                                 return Err(AbortReason::PrevTxOutInvalid);
166                         }
167                 } else {
168                         // The receiving node:
169                         //  - MUST fail the negotiation if:
170                         //     - `prevtx_vout` is greater or equal to the number of outputs on `prevtx`
171                         return Err(AbortReason::PrevTxOutInvalid);
172                 }
173
174                 let prev_out = if let Some(prev_out) = transaction.output.get(msg.prevtx_out as usize) {
175                         prev_out.clone()
176                 } else {
177                         return Err(AbortReason::PrevTxOutInvalid);
178                 };
179                 match self.inputs.entry(msg.serial_id) {
180                         hash_map::Entry::Occupied(_) => {
181                                 // The receiving node:
182                                 //  - MUST fail the negotiation if:
183                                 //    - the `serial_id` is already included in the transaction
184                                 Err(AbortReason::DuplicateSerialId)
185                         },
186                         hash_map::Entry::Vacant(entry) => {
187                                 let prev_outpoint = OutPoint { txid, vout: msg.prevtx_out };
188                                 entry.insert(TxInputWithPrevOutput {
189                                         input: TxIn {
190                                                 previous_output: prev_outpoint,
191                                                 sequence: Sequence(msg.sequence),
192                                                 ..Default::default()
193                                         },
194                                         prev_output: prev_out,
195                                 });
196                                 self.prevtx_outpoints.insert(prev_outpoint);
197                                 Ok(())
198                         },
199                 }
200         }
201
202         fn received_tx_remove_input(&mut self, msg: &msgs::TxRemoveInput) -> Result<(), AbortReason> {
203                 if !self.is_serial_id_valid_for_counterparty(&msg.serial_id) {
204                         return Err(AbortReason::IncorrectSerialIdParity);
205                 }
206
207                 self.inputs
208                         .remove(&msg.serial_id)
209                         // The receiving node:
210                         //  - MUST fail the negotiation if:
211                         //    - the input or output identified by the `serial_id` was not added by the sender
212                         //    - the `serial_id` does not correspond to a currently added input
213                         .ok_or(AbortReason::SerialIdUnknown)
214                         .map(|_| ())
215         }
216
217         fn received_tx_add_output(&mut self, msg: &msgs::TxAddOutput) -> Result<(), AbortReason> {
218                 // The receiving node:
219                 //  - MUST fail the negotiation if:
220                 //     - the serial_id has the wrong parity
221                 if !self.is_serial_id_valid_for_counterparty(&msg.serial_id) {
222                         return Err(AbortReason::IncorrectSerialIdParity);
223                 }
224
225                 self.received_tx_add_output_count += 1;
226                 if self.received_tx_add_output_count > MAX_RECEIVED_TX_ADD_OUTPUT_COUNT {
227                         // The receiving node:
228                         //  - MUST fail the negotiation if:
229                         //     - if has received 4096 `tx_add_output` messages during this negotiation
230                         return Err(AbortReason::ReceivedTooManyTxAddOutputs);
231                 }
232
233                 if msg.sats < msg.script.dust_value().to_sat() {
234                         // The receiving node:
235                         // - MUST fail the negotiation if:
236                         //              - the sats amount is less than the dust_limit
237                         return Err(AbortReason::BelowDustLimit);
238                 }
239
240                 // Check that adding this output would not cause the total output value to exceed the total
241                 // bitcoin supply.
242                 let mut outputs_value: u64 = 0;
243                 for output in self.outputs.iter() {
244                         outputs_value = outputs_value.saturating_add(output.1.value);
245                 }
246                 if outputs_value.saturating_add(msg.sats) > TOTAL_BITCOIN_SUPPLY_SATOSHIS {
247                         // The receiving node:
248                         // - MUST fail the negotiation if:
249                         //              - the sats amount is greater than 2,100,000,000,000,000 (TOTAL_BITCOIN_SUPPLY_SATOSHIS)
250                         return Err(AbortReason::ExceededMaximumSatsAllowed);
251                 }
252
253                 // The receiving node:
254                 //   - MUST accept P2WSH, P2WPKH, P2TR scripts
255                 //   - MAY fail the negotiation if script is non-standard
256                 //
257                 // We can actually be a bit looser than the above as only witness version 0 has special
258                 // length-based standardness constraints to match similar consensus rules. All witness scripts
259                 // with witness versions V1 and up are always considered standard. Yes, the scripts can be
260                 // anyone-can-spend-able, but if our counterparty wants to add an output like that then it's none
261                 // of our concern really Â¯\_(ツ)_/¯
262                 //
263                 // TODO: The last check would be simplified when https://github.com/rust-bitcoin/rust-bitcoin/commit/1656e1a09a1959230e20af90d20789a4a8f0a31b
264                 // hits the next release of rust-bitcoin.
265                 if !(msg.script.is_v0_p2wpkh()
266                         || msg.script.is_v0_p2wsh()
267                         || (msg.script.is_witness_program()
268                                 && msg.script.witness_version().map(|v| v.to_num() >= 1).unwrap_or(false)))
269                 {
270                         return Err(AbortReason::InvalidOutputScript);
271                 }
272
273                 match self.outputs.entry(msg.serial_id) {
274                         hash_map::Entry::Occupied(_) => {
275                                 // The receiving node:
276                                 //  - MUST fail the negotiation if:
277                                 //    - the `serial_id` is already included in the transaction
278                                 Err(AbortReason::DuplicateSerialId)
279                         },
280                         hash_map::Entry::Vacant(entry) => {
281                                 entry.insert(TxOut { value: msg.sats, script_pubkey: msg.script.clone() });
282                                 Ok(())
283                         },
284                 }
285         }
286
287         fn received_tx_remove_output(&mut self, msg: &msgs::TxRemoveOutput) -> Result<(), AbortReason> {
288                 if !self.is_serial_id_valid_for_counterparty(&msg.serial_id) {
289                         return Err(AbortReason::IncorrectSerialIdParity);
290                 }
291                 if self.outputs.remove(&msg.serial_id).is_some() {
292                         Ok(())
293                 } else {
294                         // The receiving node:
295                         //  - MUST fail the negotiation if:
296                         //    - the input or output identified by the `serial_id` was not added by the sender
297                         //    - the `serial_id` does not correspond to a currently added input
298                         Err(AbortReason::SerialIdUnknown)
299                 }
300         }
301
302         fn sent_tx_add_input(&mut self, msg: &msgs::TxAddInput) -> Result<(), AbortReason> {
303                 let tx = msg.prevtx.as_transaction();
304                 let input = TxIn {
305                         previous_output: OutPoint { txid: tx.txid(), vout: msg.prevtx_out },
306                         sequence: Sequence(msg.sequence),
307                         ..Default::default()
308                 };
309                 let prev_output =
310                         tx.output.get(msg.prevtx_out as usize).ok_or(AbortReason::PrevTxOutInvalid)?.clone();
311                 if !self.prevtx_outpoints.insert(input.previous_output) {
312                         // We have added an input that already exists
313                         return Err(AbortReason::PrevTxOutInvalid);
314                 }
315                 self.inputs.insert(msg.serial_id, TxInputWithPrevOutput { input, prev_output });
316                 Ok(())
317         }
318
319         fn sent_tx_add_output(&mut self, msg: &msgs::TxAddOutput) -> Result<(), AbortReason> {
320                 self.outputs
321                         .insert(msg.serial_id, TxOut { value: msg.sats, script_pubkey: msg.script.clone() });
322                 Ok(())
323         }
324
325         fn sent_tx_remove_input(&mut self, msg: &msgs::TxRemoveInput) -> Result<(), AbortReason> {
326                 self.inputs.remove(&msg.serial_id);
327                 Ok(())
328         }
329
330         fn sent_tx_remove_output(&mut self, msg: &msgs::TxRemoveOutput) -> Result<(), AbortReason> {
331                 self.outputs.remove(&msg.serial_id);
332                 Ok(())
333         }
334
335         fn build_transaction(self) -> Result<Transaction, AbortReason> {
336                 // The receiving node:
337                 // MUST fail the negotiation if:
338
339                 // - the peer's total input satoshis is less than their outputs
340                 let mut counterparty_inputs_value: u64 = 0;
341                 let mut counterparty_outputs_value: u64 = 0;
342                 for input in self.counterparty_inputs_contributed() {
343                         counterparty_inputs_value =
344                                 counterparty_inputs_value.saturating_add(input.prev_output.value);
345                 }
346                 for output in self.counterparty_outputs_contributed() {
347                         counterparty_outputs_value = counterparty_outputs_value.saturating_add(output.value);
348                 }
349                 if counterparty_inputs_value < counterparty_outputs_value {
350                         return Err(AbortReason::OutputsValueExceedsInputsValue);
351                 }
352
353                 // - there are more than 252 inputs
354                 // - there are more than 252 outputs
355                 if self.inputs.len() > MAX_INPUTS_OUTPUTS_COUNT
356                         || self.outputs.len() > MAX_INPUTS_OUTPUTS_COUNT
357                 {
358                         return Err(AbortReason::ExceededNumberOfInputsOrOutputs);
359                 }
360
361                 // TODO: How do we enforce their fees cover the witness without knowing its expected length?
362                 const INPUT_WEIGHT: u64 = BASE_INPUT_WEIGHT + EMPTY_SCRIPT_SIG_WEIGHT;
363
364                 // - the peer's paid feerate does not meet or exceed the agreed feerate (based on the minimum fee).
365                 let mut counterparty_weight_contributed: u64 = self
366                         .counterparty_outputs_contributed()
367                         .map(|output| {
368                                 (8 /* value */ + output.script_pubkey.consensus_encode(&mut sink()).unwrap() as u64)
369                                         * WITNESS_SCALE_FACTOR as u64
370                         })
371                         .sum();
372                 counterparty_weight_contributed +=
373                         self.counterparty_inputs_contributed().count() as u64 * INPUT_WEIGHT;
374                 let counterparty_fees_contributed =
375                         counterparty_inputs_value.saturating_sub(counterparty_outputs_value);
376                 let mut required_counterparty_contribution_fee =
377                         fee_for_weight(self.feerate_sat_per_kw, counterparty_weight_contributed);
378                 if !self.holder_is_initiator {
379                         // if is the non-initiator:
380                         //      - the initiator's fees do not cover the common fields (version, segwit marker + flag,
381                         //              input count, output count, locktime)
382                         let tx_common_fields_weight =
383                         (4 /* version */ + 4 /* locktime */ + 1 /* input count */ + 1 /* output count */) *
384                             WITNESS_SCALE_FACTOR as u64 + 2 /* segwit marker + flag */;
385                         let tx_common_fields_fee =
386                                 fee_for_weight(self.feerate_sat_per_kw, tx_common_fields_weight);
387                         required_counterparty_contribution_fee += tx_common_fields_fee;
388                 }
389                 if counterparty_fees_contributed < required_counterparty_contribution_fee {
390                         return Err(AbortReason::InsufficientFees);
391                 }
392
393                 // Inputs and outputs must be sorted by serial_id
394                 let mut inputs = self.inputs.into_iter().collect::<Vec<_>>();
395                 let mut outputs = self.outputs.into_iter().collect::<Vec<_>>();
396                 inputs.sort_unstable_by_key(|(serial_id, _)| *serial_id);
397                 outputs.sort_unstable_by_key(|(serial_id, _)| *serial_id);
398
399                 let tx_to_validate = Transaction {
400                         version: 2,
401                         lock_time: self.tx_locktime,
402                         input: inputs.into_iter().map(|(_, input)| input.input).collect(),
403                         output: outputs.into_iter().map(|(_, output)| output).collect(),
404                 };
405                 if tx_to_validate.weight().to_wu() > MAX_STANDARD_TX_WEIGHT as u64 {
406                         return Err(AbortReason::TransactionTooLarge);
407                 }
408
409                 Ok(tx_to_validate)
410         }
411 }
412
413 // The interactive transaction construction protocol allows two peers to collaboratively build a
414 // transaction for broadcast.
415 //
416 // The protocol is turn-based, so we define different states here that we store depending on whose
417 // turn it is to send the next message. The states are defined so that their types ensure we only
418 // perform actions (only send messages) via defined state transitions that do not violate the
419 // protocol.
420 //
421 // An example of a full negotiation and associated states follows:
422 //
423 //     +------------+                         +------------------+---- Holder state after message sent/received ----+
424 //     |            |--(1)- tx_add_input ---->|                  |                  SentChangeMsg                   +
425 //     |            |<-(2)- tx_complete ------|                  |                ReceivedTxComplete                +
426 //     |            |--(3)- tx_add_output --->|                  |                  SentChangeMsg                   +
427 //     |            |<-(4)- tx_complete ------|                  |                ReceivedTxComplete                +
428 //     |            |--(5)- tx_add_input ---->|                  |                  SentChangeMsg                   +
429 //     |   Holder   |<-(6)- tx_add_input -----|   Counterparty   |                ReceivedChangeMsg                 +
430 //     |            |--(7)- tx_remove_output >|                  |                  SentChangeMsg                   +
431 //     |            |<-(8)- tx_add_output ----|                  |                ReceivedChangeMsg                 +
432 //     |            |--(9)- tx_complete ----->|                  |                  SentTxComplete                  +
433 //     |            |<-(10) tx_complete ------|                  |                NegotiationComplete               +
434 //     +------------+                         +------------------+--------------------------------------------------+
435
/// Negotiation states that can send & receive `tx_(add|remove)_(input|output)` and `tx_complete`
///
/// This is a marker trait with no methods; it is the bound placed on the target state of a
/// `StateTransition`.
trait State {}
438
/// Category of states where we have sent some message to the counterparty, and we are waiting for
/// a response.
trait SentMsgState: State {
        /// Consumes the state and returns the negotiation context it holds.
        fn into_negotiation_context(self) -> NegotiationContext;
}
444
/// Category of states that our counterparty has put us in after we receive a message from them.
trait ReceivedMsgState: State {
        /// Consumes the state and returns the negotiation context it holds.
        fn into_negotiation_context(self) -> NegotiationContext;
}
449
// Declares a negotiation state as a tuple struct wrapping a single value and implements the
// state traits above for it.
//
// The `SENT_MSG_STATE` and `RECEIVED_MSG_STATE` forms wrap a `NegotiationContext` and also
// implement `SentMsgState` or `ReceivedMsgState` respectively. The plain form takes the wrapped
// type explicitly and only implements `State`. The plain form must stay last, as its pattern
// would otherwise also match the two keyword forms.
macro_rules! define_state {
        (SENT_MSG_STATE, $name: ident, $docs: expr) => {
                define_state!($name, NegotiationContext, $docs);
                impl SentMsgState for $name {
                        fn into_negotiation_context(self) -> NegotiationContext {
                                let $name(context) = self;
                                context
                        }
                }
        };
        (RECEIVED_MSG_STATE, $name: ident, $docs: expr) => {
                define_state!($name, NegotiationContext, $docs);
                impl ReceivedMsgState for $name {
                        fn into_negotiation_context(self) -> NegotiationContext {
                                let $name(context) = self;
                                context
                        }
                }
        };
        ($name: ident, $payload: ident, $docs: expr) => {
                #[doc = $docs]
                #[derive(Debug)]
                struct $name($payload);
                impl State for $name {}
        };
}
476
// The concrete negotiation states. The first four wrap a `NegotiationContext`; the two terminal
// states wrap the built `Transaction` and the `AbortReason` respectively.
define_state!(
        SENT_MSG_STATE,
        SentChangeMsg,
        "We have sent a message to the counterparty that has affected our negotiation state."
);
define_state!(
        SENT_MSG_STATE,
        SentTxComplete,
        "We have sent a `tx_complete` message and are awaiting the counterparty's."
);
define_state!(
        RECEIVED_MSG_STATE,
        ReceivedChangeMsg,
        "We have received a message from the counterparty that has affected our negotiation state."
);
define_state!(
        RECEIVED_MSG_STATE,
        ReceivedTxComplete,
        "We have received a `tx_complete` message and the counterparty is awaiting ours."
);
define_state!(NegotiationComplete, Transaction, "We have exchanged consecutive `tx_complete` messages with the counterparty and the transaction negotiation is complete.");
define_state!(
        NegotiationAborted,
        AbortReason,
        "The negotiation has failed and cannot be continued."
);
503
/// The outcome of a state transition: the new state, or the reason the negotiation is aborted.
type StateTransitionResult<S> = Result<S, AbortReason>;
505
/// A legal transition from the implementing state into `NewState`, triggered by
/// `TransitionData`.
trait StateTransition<NewState: State, TransitionData> {
        /// Consumes the current state and returns the new state, or the reason for aborting the
        /// negotiation.
        fn transition(self, data: TransitionData) -> StateTransitionResult<NewState>;
}
509
// Implements `StateTransition` for the states above, which is what defines the legal
// transitions between them.
//
//  - `SENT_MSG_STATE`: for every listed message type, any `SentMsgState` moves to
//    `ReceivedChangeMsg` after the named `NegotiationContext` handler accepts the message.
//  - `RECEIVED_MSG_STATE`: for every listed message type, any `ReceivedMsgState` moves to
//    `SentChangeMsg` after the named `NegotiationContext` handler accepts the message.
//  - `TX_COMPLETE`: the first state moves to the second on `tx_complete`, and the second moves
//    to `NegotiationComplete` on the following `tx_complete` if the transaction can be built.
macro_rules! define_state_transitions {
        (SENT_MSG_STATE, [$(DATA $msg_ty: ty, TRANSITION $handler: ident),+]) => {
                $(
                        impl<S: SentMsgState> StateTransition<ReceivedChangeMsg, $msg_ty> for S {
                                fn transition(self, msg: $msg_ty) -> StateTransitionResult<ReceivedChangeMsg> {
                                        let mut negotiation = self.into_negotiation_context();
                                        negotiation.$handler(msg)?;
                                        Ok(ReceivedChangeMsg(negotiation))
                                }
                        }
                )+
        };
        (RECEIVED_MSG_STATE, [$(DATA $msg_ty: ty, TRANSITION $handler: ident),+]) => {
                $(
                        impl<S: ReceivedMsgState> StateTransition<SentChangeMsg, $msg_ty> for S {
                                fn transition(self, msg: $msg_ty) -> StateTransitionResult<SentChangeMsg> {
                                        let mut negotiation = self.into_negotiation_context();
                                        negotiation.$handler(msg)?;
                                        Ok(SentChangeMsg(negotiation))
                                }
                        }
                )+
        };
        (TX_COMPLETE, $change_state: ident, $complete_state: ident) => {
                impl StateTransition<NegotiationComplete, &msgs::TxComplete> for $complete_state {
                        fn transition(self, _msg: &msgs::TxComplete) -> StateTransitionResult<NegotiationComplete> {
                                let negotiation = self.into_negotiation_context();
                                let negotiated_tx = negotiation.build_transaction()?;
                                Ok(NegotiationComplete(negotiated_tx))
                        }
                }

                impl StateTransition<$complete_state, &msgs::TxComplete> for $change_state {
                        fn transition(self, _msg: &msgs::TxComplete) -> StateTransitionResult<$complete_state> {
                                let negotiation = self.into_negotiation_context();
                                Ok($complete_state(negotiation))
                        }
                }
        };
}
551
// State transitions when we have sent our counterparty some messages and are waiting for them
// to respond. Each received message is checked by the matching `received_*` handler.
define_state_transitions!(SENT_MSG_STATE, [
        DATA &msgs::TxAddInput, TRANSITION received_tx_add_input,
        DATA &msgs::TxRemoveInput, TRANSITION received_tx_remove_input,
        DATA &msgs::TxAddOutput, TRANSITION received_tx_add_output,
        DATA &msgs::TxRemoveOutput, TRANSITION received_tx_remove_output
]);
// State transitions when we have received some messages from our counterparty and we should
// respond. Each message we send is recorded by the matching `sent_*` handler.
define_state_transitions!(RECEIVED_MSG_STATE, [
        DATA &msgs::TxAddInput, TRANSITION sent_tx_add_input,
        DATA &msgs::TxRemoveInput, TRANSITION sent_tx_remove_input,
        DATA &msgs::TxAddOutput, TRANSITION sent_tx_add_output,
        DATA &msgs::TxRemoveOutput, TRANSITION sent_tx_remove_output
]);
// A `tx_complete` moves `SentChangeMsg` to `ReceivedTxComplete` and `ReceivedChangeMsg` to
// `SentTxComplete`; a second consecutive `tx_complete` then completes the negotiation.
define_state_transitions!(TX_COMPLETE, SentChangeMsg, ReceivedTxComplete);
define_state_transitions!(TX_COMPLETE, ReceivedChangeMsg, SentTxComplete);
570
/// The current state of a negotiation, with one variant for each state type defined above.
#[derive(Debug)]
enum StateMachine {
        // The `Default` value.
        // NOTE(review): presumably a placeholder left behind while the real state is moved out
        // for a transition; confirm against the code that owns the `StateMachine`.
        Indeterminate,
        SentChangeMsg(SentChangeMsg),
        ReceivedChangeMsg(ReceivedChangeMsg),
        SentTxComplete(SentTxComplete),
        ReceivedTxComplete(ReceivedTxComplete),
        NegotiationComplete(NegotiationComplete),
        NegotiationAborted(NegotiationAborted),
}
581
582 impl Default for StateMachine {
583         fn default() -> Self {
584                 Self::Indeterminate
585         }
586 }
587
// The `StateMachine` keeps track of the current state and executes the actual transition
// between two states. This macro generates one consuming method per message handler: for each
// listed `FROM` state it runs the typed transition and wraps the result in the `TO` variant.
// A failed transition, or a current state that is not listed, yields `NegotiationAborted`.
macro_rules! define_state_machine_transitions {
        ($method: ident, $msg_ty: ty, [$(FROM $from: ident, TO $to: ident),+]) => {
                fn $method(self, msg: $msg_ty) -> StateMachine {
                        match self {
                                $(
                                        Self::$from(current_state) => match current_state.transition(msg) {
                                                Ok(next_state) => StateMachine::$to(next_state),
                                                Err(reason) => StateMachine::NegotiationAborted(NegotiationAborted(reason)),
                                        }
                                )+
                                _ => StateMachine::NegotiationAborted(NegotiationAborted(AbortReason::UnexpectedCounterpartyMessage)),
                        }
                }
        };
}
606
607 impl StateMachine {
608         fn new(feerate_sat_per_kw: u32, is_initiator: bool, tx_locktime: AbsoluteLockTime) -> Self {
609                 let context = NegotiationContext {
610                         tx_locktime,
611                         holder_is_initiator: is_initiator,
612                         received_tx_add_input_count: 0,
613                         received_tx_add_output_count: 0,
614                         inputs: new_hash_map(),
615                         prevtx_outpoints: new_hash_set(),
616                         outputs: new_hash_map(),
617                         feerate_sat_per_kw,
618                 };
619                 if is_initiator {
620                         Self::ReceivedChangeMsg(ReceivedChangeMsg(context))
621                 } else {
622                         Self::SentChangeMsg(SentChangeMsg(context))
623                 }
624         }
625
        // Transition table. Every `sent_*` change message is accepted from `ReceivedChangeMsg` or
        // `ReceivedTxComplete` and lands in `SentChangeMsg`; every `received_*` change message is
        // accepted from `SentChangeMsg` or `SentTxComplete` and lands in `ReceivedChangeMsg`. Calling
        // a transition from any other state yields `NegotiationAborted`.

        // TxAddInput
        define_state_machine_transitions!(sent_tx_add_input, &msgs::TxAddInput, [
                FROM ReceivedChangeMsg, TO SentChangeMsg,
                FROM ReceivedTxComplete, TO SentChangeMsg
        ]);
        define_state_machine_transitions!(received_tx_add_input, &msgs::TxAddInput, [
                FROM SentChangeMsg, TO ReceivedChangeMsg,
                FROM SentTxComplete, TO ReceivedChangeMsg
        ]);

        // TxAddOutput
        define_state_machine_transitions!(sent_tx_add_output, &msgs::TxAddOutput, [
                FROM ReceivedChangeMsg, TO SentChangeMsg,
                FROM ReceivedTxComplete, TO SentChangeMsg
        ]);
        define_state_machine_transitions!(received_tx_add_output, &msgs::TxAddOutput, [
                FROM SentChangeMsg, TO ReceivedChangeMsg,
                FROM SentTxComplete, TO ReceivedChangeMsg
        ]);

        // TxRemoveInput
        define_state_machine_transitions!(sent_tx_remove_input, &msgs::TxRemoveInput, [
                FROM ReceivedChangeMsg, TO SentChangeMsg,
                FROM ReceivedTxComplete, TO SentChangeMsg
        ]);
        define_state_machine_transitions!(received_tx_remove_input, &msgs::TxRemoveInput, [
                FROM SentChangeMsg, TO ReceivedChangeMsg,
                FROM SentTxComplete, TO ReceivedChangeMsg
        ]);

        // TxRemoveOutput
        define_state_machine_transitions!(sent_tx_remove_output, &msgs::TxRemoveOutput, [
                FROM ReceivedChangeMsg, TO SentChangeMsg,
                FROM ReceivedTxComplete, TO SentChangeMsg
        ]);
        define_state_machine_transitions!(received_tx_remove_output, &msgs::TxRemoveOutput, [
                FROM SentChangeMsg, TO ReceivedChangeMsg,
                FROM SentTxComplete, TO ReceivedChangeMsg
        ]);

        // TxComplete
        // A `tx_complete` that directly answers the other side's `tx_complete` moves to
        // `NegotiationComplete`; otherwise it is only recorded as sent/received.
        define_state_machine_transitions!(sent_tx_complete, &msgs::TxComplete, [
                FROM ReceivedChangeMsg, TO SentTxComplete,
                FROM ReceivedTxComplete, TO NegotiationComplete
        ]);
        define_state_machine_transitions!(received_tx_complete, &msgs::TxComplete, [
                FROM SentChangeMsg, TO ReceivedTxComplete,
                FROM SentTxComplete, TO NegotiationComplete
        ]);
675 }
676
/// Drives one side of an interactive transaction negotiation: it owns the negotiation state
/// machine and the holder's not-yet-sent contributions.
pub struct InteractiveTxConstructor {
        /// Current negotiation state; replaced on every sent or handled message.
        state_machine: StateMachine,
        /// Channel ID stamped on every outgoing message.
        channel_id: ChannelId,
        /// Inputs (with their previous transaction) the holder still has to send, sorted by serial ID
        /// at construction time and consumed from the back.
        inputs_to_contribute: Vec<(SerialId, TxIn, TransactionU16LenLimited)>,
        /// Outputs the holder still has to send, sorted by serial ID at construction time and
        /// consumed from the back.
        outputs_to_contribute: Vec<(SerialId, TxOut)>,
}
683
/// A message the holder needs to send to the counterparty next, as produced by
/// `InteractiveTxConstructor`.
pub enum InteractiveTxMessageSend {
        /// Send a `tx_add_input` contributing one of the holder's inputs.
        TxAddInput(msgs::TxAddInput),
        /// Send a `tx_add_output` contributing one of the holder's outputs.
        TxAddOutput(msgs::TxAddOutput),
        /// Send a `tx_complete`; produced once the holder has nothing left to contribute.
        TxComplete(msgs::TxComplete),
}
689
// Runs the state machine transition `$transition` with `$msg` on `$self.state_machine`.
//
// The current state is moved out with `core::mem::take` (leaving the `Default` placeholder
// behind for the duration of the call) and replaced by whatever the transition returns. The
// macro evaluates to `Err(reason)` if the resulting state is `NegotiationAborted`, and to
// `Ok(())` otherwise.
macro_rules! do_state_transition {
        ($self: ident, $transition: ident, $msg: expr) => {{
                let previous_state = core::mem::take(&mut $self.state_machine);
                $self.state_machine = previous_state.$transition($msg);
                if let StateMachine::NegotiationAborted(aborted) = &$self.state_machine {
                        Err(aborted.0.clone())
                } else {
                        Ok(())
                }
        }};
}
701
702 fn generate_holder_serial_id<ES: Deref>(entropy_source: &ES, is_initiator: bool) -> SerialId
703 where
704         ES::Target: EntropySource,
705 {
706         let rand_bytes = entropy_source.get_secure_random_bytes();
707         let mut serial_id_bytes = [0u8; 8];
708         serial_id_bytes.copy_from_slice(&rand_bytes[..8]);
709         let mut serial_id = u64::from_be_bytes(serial_id_bytes);
710         if serial_id.is_for_initiator() != is_initiator {
711                 serial_id ^= 1;
712         }
713         serial_id
714 }
715
/// The outcome of successfully handling a counterparty `tx_complete`.
pub enum HandleTxCompleteValue {
        /// The holder still had an input or output to contribute; send the wrapped message. The
        /// negotiation continues.
        SendTxMessage(InteractiveTxMessageSend),
        /// The holder answered with the wrapped message (its own `tx_complete`), which completed the
        /// negotiation; the transaction is the one carried by the `NegotiationComplete` state.
        SendTxComplete(InteractiveTxMessageSend, Transaction),
        /// The counterparty's `tx_complete` completed the negotiation; nothing further needs to be
        /// sent. The transaction is the one carried by the `NegotiationComplete` state.
        NegotiationComplete(Transaction),
}
721
722 impl InteractiveTxConstructor {
723         /// Instantiates a new `InteractiveTxConstructor`.
724         ///
725         /// A tuple is returned containing the newly instantiate `InteractiveTxConstructor` and optionally
726         /// an initial wrapped `Tx_` message which the holder needs to send to the counterparty.
727         pub fn new<ES: Deref>(
728                 entropy_source: &ES, channel_id: ChannelId, feerate_sat_per_kw: u32, is_initiator: bool,
729                 funding_tx_locktime: AbsoluteLockTime,
730                 inputs_to_contribute: Vec<(TxIn, TransactionU16LenLimited)>,
731                 outputs_to_contribute: Vec<TxOut>,
732         ) -> (Self, Option<InteractiveTxMessageSend>)
733         where
734                 ES::Target: EntropySource,
735         {
736                 let state_machine =
737                         StateMachine::new(feerate_sat_per_kw, is_initiator, funding_tx_locktime);
738                 let mut inputs_to_contribute: Vec<(SerialId, TxIn, TransactionU16LenLimited)> =
739                         inputs_to_contribute
740                                 .into_iter()
741                                 .map(|(input, tx)| {
742                                         let serial_id = generate_holder_serial_id(entropy_source, is_initiator);
743                                         (serial_id, input, tx)
744                                 })
745                                 .collect();
746                 // We'll sort by the randomly generated serial IDs, effectively shuffling the order of the inputs
747                 // as the user passed them to us to avoid leaking any potential categorization of transactions
748                 // before we pass any of the inputs to the counterparty.
749                 inputs_to_contribute.sort_unstable_by_key(|(serial_id, _, _)| *serial_id);
750                 let mut outputs_to_contribute: Vec<(SerialId, TxOut)> = outputs_to_contribute
751                         .into_iter()
752                         .map(|output| {
753                                 let serial_id = generate_holder_serial_id(entropy_source, is_initiator);
754                                 (serial_id, output)
755                         })
756                         .collect();
757                 // In the same manner and for the same rationale as the inputs above, we'll shuffle the outputs.
758                 outputs_to_contribute.sort_unstable_by_key(|(serial_id, _)| *serial_id);
759                 let mut constructor =
760                         Self { state_machine, channel_id, inputs_to_contribute, outputs_to_contribute };
761                 let message_send = if is_initiator {
762                         match constructor.maybe_send_message() {
763                                 Ok(msg_send) => Some(msg_send),
764                                 Err(_) => {
765                                         debug_assert!(
766                                                 false,
767                                                 "We should always be able to start our state machine successfully"
768                                         );
769                                         None
770                                 },
771                         }
772                 } else {
773                         None
774                 };
775                 (constructor, message_send)
776         }
777
778         fn maybe_send_message(&mut self) -> Result<InteractiveTxMessageSend, AbortReason> {
779                 // We first attempt to send inputs we want to add, then outputs. Once we are done sending
780                 // them both, then we always send tx_complete.
781                 if let Some((serial_id, input, prevtx)) = self.inputs_to_contribute.pop() {
782                         let msg = msgs::TxAddInput {
783                                 channel_id: self.channel_id,
784                                 serial_id,
785                                 prevtx,
786                                 prevtx_out: input.previous_output.vout,
787                                 sequence: input.sequence.to_consensus_u32(),
788                         };
789                         do_state_transition!(self, sent_tx_add_input, &msg)?;
790                         Ok(InteractiveTxMessageSend::TxAddInput(msg))
791                 } else if let Some((serial_id, output)) = self.outputs_to_contribute.pop() {
792                         let msg = msgs::TxAddOutput {
793                                 channel_id: self.channel_id,
794                                 serial_id,
795                                 sats: output.value,
796                                 script: output.script_pubkey,
797                         };
798                         do_state_transition!(self, sent_tx_add_output, &msg)?;
799                         Ok(InteractiveTxMessageSend::TxAddOutput(msg))
800                 } else {
801                         let msg = msgs::TxComplete { channel_id: self.channel_id };
802                         do_state_transition!(self, sent_tx_complete, &msg)?;
803                         Ok(InteractiveTxMessageSend::TxComplete(msg))
804                 }
805         }
806
807         pub fn handle_tx_add_input(
808                 &mut self, msg: &msgs::TxAddInput,
809         ) -> Result<InteractiveTxMessageSend, AbortReason> {
810                 do_state_transition!(self, received_tx_add_input, msg)?;
811                 self.maybe_send_message()
812         }
813
814         pub fn handle_tx_remove_input(
815                 &mut self, msg: &msgs::TxRemoveInput,
816         ) -> Result<InteractiveTxMessageSend, AbortReason> {
817                 do_state_transition!(self, received_tx_remove_input, msg)?;
818                 self.maybe_send_message()
819         }
820
821         pub fn handle_tx_add_output(
822                 &mut self, msg: &msgs::TxAddOutput,
823         ) -> Result<InteractiveTxMessageSend, AbortReason> {
824                 do_state_transition!(self, received_tx_add_output, msg)?;
825                 self.maybe_send_message()
826         }
827
828         pub fn handle_tx_remove_output(
829                 &mut self, msg: &msgs::TxRemoveOutput,
830         ) -> Result<InteractiveTxMessageSend, AbortReason> {
831                 do_state_transition!(self, received_tx_remove_output, msg)?;
832                 self.maybe_send_message()
833         }
834
835         pub fn handle_tx_complete(
836                 &mut self, msg: &msgs::TxComplete,
837         ) -> Result<HandleTxCompleteValue, AbortReason> {
838                 do_state_transition!(self, received_tx_complete, msg)?;
839                 match &self.state_machine {
840                         StateMachine::ReceivedTxComplete(_) => {
841                                 let msg_send = self.maybe_send_message()?;
842                                 match &self.state_machine {
843                                         StateMachine::NegotiationComplete(s) => {
844                                                 Ok(HandleTxCompleteValue::SendTxComplete(msg_send, s.0.clone()))
845                                         },
846                                         StateMachine::SentChangeMsg(_) => {
847                                                 Ok(HandleTxCompleteValue::SendTxMessage(msg_send))
848                                         }, // We either had an input or output to contribute.
849                                         _ => {
850                                                 debug_assert!(false, "We cannot transition to any other states after receiving `tx_complete` and responding");
851                                                 Err(AbortReason::InvalidStateTransition)
852                                         },
853                                 }
854                         },
855                         StateMachine::NegotiationComplete(s) => {
856                                 Ok(HandleTxCompleteValue::NegotiationComplete(s.0.clone()))
857                         },
858                         _ => {
859                                 debug_assert!(
860                                         false,
861                                         "We cannot transition to any other states after receiving `tx_complete`"
862                                 );
863                                 Err(AbortReason::InvalidStateTransition)
864                         },
865                 }
866         }
867 }
868
869 #[cfg(test)]
870 mod tests {
871         use crate::chain::chaininterface::FEERATE_FLOOR_SATS_PER_KW;
872         use crate::ln::channel::TOTAL_BITCOIN_SUPPLY_SATOSHIS;
873         use crate::ln::interactivetxs::{
874                 generate_holder_serial_id, AbortReason, HandleTxCompleteValue, InteractiveTxConstructor,
875                 InteractiveTxMessageSend, MAX_INPUTS_OUTPUTS_COUNT, MAX_RECEIVED_TX_ADD_INPUT_COUNT,
876                 MAX_RECEIVED_TX_ADD_OUTPUT_COUNT,
877         };
878         use crate::ln::ChannelId;
879         use crate::sign::EntropySource;
880         use crate::util::atomic_counter::AtomicCounter;
881         use crate::util::ser::TransactionU16LenLimited;
882         use bitcoin::blockdata::opcodes;
883         use bitcoin::blockdata::script::Builder;
884         use bitcoin::{
885                 absolute::LockTime as AbsoluteLockTime, OutPoint, Sequence, Transaction, TxIn, TxOut,
886         };
887         use core::ops::Deref;
888
889         // A simple entropy source that works based on an atomic counter.
890         struct TestEntropySource(AtomicCounter);
891         impl EntropySource for TestEntropySource {
892                 fn get_secure_random_bytes(&self) -> [u8; 32] {
893                         let mut res = [0u8; 32];
894                         let increment = self.0.get_increment();
895                         for i in 0..32 {
896                                 // Rotate the increment value by 'i' bits to the right, to avoid clashes
897                                 // when `generate_local_serial_id` does a parity flip on consecutive calls for the
898                                 // same party.
899                                 let rotated_increment = increment.rotate_right(i as u32);
900                                 res[i] = (rotated_increment & 0xff) as u8;
901                         }
902                         res
903                 }
904         }
905
906         // An entropy source that deliberately returns you the same seed every time. We use this
907         // to test if the constructor would catch inputs/outputs that are attempting to be added
908         // with duplicate serial ids.
909         struct DuplicateEntropySource;
910         impl EntropySource for DuplicateEntropySource {
911                 fn get_secure_random_bytes(&self) -> [u8; 32] {
912                         let mut res = [0u8; 32];
913                         let count = 1u64;
914                         res[0..8].copy_from_slice(&count.to_be_bytes());
915                         res
916                 }
917         }
918
        /// Which party a test session expects to have caused the negotiation to abort.
        #[derive(Debug, PartialEq, Eq)]
        enum ErrorCulprit {
                /// The abort was raised while node B handled a message sent by node A (the initiator).
                NodeA,
                /// The abort was raised while node A handled a message sent by node B (the non-initiator).
                NodeB,
                // Some error values are only checked at the end of the negotiation and are not easy to attribute
                // to a particular party. Both parties would indicate an `AbortReason` in this case.
                // e.g. Exceeded max inputs and outputs after negotiation.
                Indeterminate,
        }
928
        /// One negotiation scenario: what each node contributes and the expected outcome.
        struct TestSession {
                /// Human-readable name, included in assertion failure messages.
                description: &'static str,
                /// Inputs (with their previous transactions) contributed by node A, the initiator.
                inputs_a: Vec<(TxIn, TransactionU16LenLimited)>,
                /// Outputs contributed by node A.
                outputs_a: Vec<TxOut>,
                /// Inputs (with their previous transactions) contributed by node B, the non-initiator.
                inputs_b: Vec<(TxIn, TransactionU16LenLimited)>,
                /// Outputs contributed by node B.
                outputs_b: Vec<TxOut>,
                /// `None` if the negotiation must succeed; otherwise the expected abort reason and the
                /// party it is attributed to.
                expect_error: Option<(AbortReason, ErrorCulprit)>,
        }
937
938         fn do_test_interactive_tx_constructor(session: TestSession) {
939                 let entropy_source = TestEntropySource(AtomicCounter::new());
940                 do_test_interactive_tx_constructor_internal(session, &&entropy_source);
941         }
942
        /// Runs `session` with a caller-supplied entropy source, e.g. one that produces duplicate
        /// serial IDs.
        fn do_test_interactive_tx_constructor_with_entropy_source<ES: Deref>(
                session: TestSession, entropy_source: ES,
        ) where
                ES::Target: EntropySource,
        {
                do_test_interactive_tx_constructor_internal(session, &entropy_source);
        }
950
        /// Drives a full negotiation between two `InteractiveTxConstructor`s (A is the initiator, B is
        /// not) and checks the outcome against `session.expect_error`.
        ///
        /// Messages are passed back and forth until both sides hold a final transaction or one side
        /// aborts. On success both transactions must be equal and no error may have been expected; on
        /// abort the reason and the attributed culprit must match the expectation.
        fn do_test_interactive_tx_constructor_internal<ES: Deref>(
                session: TestSession, entropy_source: &ES,
        ) where
                ES::Target: EntropySource,
        {
                let channel_id = ChannelId(entropy_source.get_secure_random_bytes());
                let tx_locktime = AbsoluteLockTime::from_height(1337).unwrap();

                // Node A is the initiator, node B the non-initiator; both share the same feerate and
                // locktime.
                let (mut constructor_a, first_message_a) = InteractiveTxConstructor::new(
                        entropy_source,
                        channel_id,
                        FEERATE_FLOOR_SATS_PER_KW * 10,
                        true,
                        tx_locktime,
                        session.inputs_a,
                        session.outputs_a,
                );
                let (mut constructor_b, first_message_b) = InteractiveTxConstructor::new(
                        entropy_source,
                        channel_id,
                        FEERATE_FLOOR_SATS_PER_KW * 10,
                        false,
                        tx_locktime,
                        session.inputs_b,
                        session.outputs_b,
                );

                // Delivers `msg` to `for_constructor` and normalizes the result into
                // `(next message to send back, final transaction if negotiation completed)`.
                let handle_message_send =
                        |msg: InteractiveTxMessageSend, for_constructor: &mut InteractiveTxConstructor| {
                                match msg {
                                        InteractiveTxMessageSend::TxAddInput(msg) => for_constructor
                                                .handle_tx_add_input(&msg)
                                                .map(|msg_send| (Some(msg_send), None)),
                                        InteractiveTxMessageSend::TxAddOutput(msg) => for_constructor
                                                .handle_tx_add_output(&msg)
                                                .map(|msg_send| (Some(msg_send), None)),
                                        InteractiveTxMessageSend::TxComplete(msg) => {
                                                for_constructor.handle_tx_complete(&msg).map(|value| match value {
                                                        HandleTxCompleteValue::SendTxMessage(msg_send) => {
                                                                (Some(msg_send), None)
                                                        },
                                                        HandleTxCompleteValue::SendTxComplete(msg_send, tx) => {
                                                                (Some(msg_send), Some(tx))
                                                        },
                                                        HandleTxCompleteValue::NegotiationComplete(tx) => (None, Some(tx)),
                                                })
                                        },
                                }
                        };

                // Only the initiator produces a message at construction time.
                assert!(first_message_b.is_none());
                let mut message_send_a = first_message_a;
                let mut message_send_b = None;
                let mut final_tx_a = None;
                let mut final_tx_b = None;
                // Ping-pong: deliver A's pending message to B, then B's pending message to A, until both
                // sides have a final transaction.
                while final_tx_a.is_none() || final_tx_b.is_none() {
                        if let Some(message_send_a) = message_send_a.take() {
                                match handle_message_send(message_send_a, &mut constructor_b) {
                                        Ok((msg_send, final_tx)) => {
                                                message_send_b = msg_send;
                                                final_tx_b = final_tx;
                                        },
                                        Err(abort_reason) => {
                                                // B rejected a message from A, so A is the culprit unless the reason cannot
                                                // be attributed to a single party.
                                                let error_culprit = match abort_reason {
                                                        AbortReason::ExceededNumberOfInputsOrOutputs => {
                                                                ErrorCulprit::Indeterminate
                                                        },
                                                        _ => ErrorCulprit::NodeA,
                                                };
                                                assert_eq!(
                                                        Some((abort_reason, error_culprit)),
                                                        session.expect_error,
                                                        "Test: {}",
                                                        session.description
                                                );
                                                assert!(message_send_b.is_none());
                                                return;
                                        },
                                }
                        }
                        if let Some(message_send_b) = message_send_b.take() {
                                match handle_message_send(message_send_b, &mut constructor_a) {
                                        Ok((msg_send, final_tx)) => {
                                                message_send_a = msg_send;
                                                final_tx_a = final_tx;
                                        },
                                        Err(abort_reason) => {
                                                // A rejected a message from B, so B is the culprit unless the reason cannot
                                                // be attributed to a single party.
                                                let error_culprit = match abort_reason {
                                                        AbortReason::ExceededNumberOfInputsOrOutputs => {
                                                                ErrorCulprit::Indeterminate
                                                        },
                                                        _ => ErrorCulprit::NodeB,
                                                };
                                                assert_eq!(
                                                        Some((abort_reason, error_culprit)),
                                                        session.expect_error,
                                                        "Test: {}",
                                                        session.description
                                                );
                                                assert!(message_send_a.is_none());
                                                return;
                                        },
                                }
                        }
                }
                // Successful negotiation: nothing left in flight, both sides agree on the transaction, and
                // the session did not expect a failure.
                assert!(message_send_a.is_none());
                assert!(message_send_b.is_none());
                assert_eq!(final_tx_a, final_tx_b);
                assert!(session.expect_error.is_none());
        }
1061
        /// Builds a test transaction with one output per entry in `values`, using a fixed locktime
        /// height of 1337. See `generate_tx_with_locktime`.
        fn generate_tx(values: &[u64]) -> Transaction {
                generate_tx_with_locktime(values, 1337)
        }
1065
1066         fn generate_tx_with_locktime(values: &[u64], locktime: u32) -> Transaction {
1067                 Transaction {
1068                         version: 2,
1069                         lock_time: AbsoluteLockTime::from_height(locktime).unwrap(),
1070                         input: vec![TxIn { ..Default::default() }],
1071                         output: values
1072                                 .iter()
1073                                 .map(|value| TxOut {
1074                                         value: *value,
1075                                         script_pubkey: Builder::new()
1076                                                 .push_opcode(opcodes::OP_TRUE)
1077                                                 .into_script()
1078                                                 .to_v0_p2wsh(),
1079                                 })
1080                                 .collect(),
1081                 }
1082         }
1083
1084         fn generate_inputs(values: &[u64]) -> Vec<(TxIn, TransactionU16LenLimited)> {
1085                 let tx = generate_tx(values);
1086                 let txid = tx.txid();
1087                 tx.output
1088                         .iter()
1089                         .enumerate()
1090                         .map(|(idx, _)| {
1091                                 let input = TxIn {
1092                                         previous_output: OutPoint { txid, vout: idx as u32 },
1093                                         script_sig: Default::default(),
1094                                         sequence: Sequence::ENABLE_RBF_NO_LOCKTIME,
1095                                         witness: Default::default(),
1096                                 };
1097                                 (input, TransactionU16LenLimited::new(tx.clone()).unwrap())
1098                         })
1099                         .collect()
1100         }
1101
1102         fn generate_outputs(values: &[u64]) -> Vec<TxOut> {
1103                 values
1104                         .iter()
1105                         .map(|value| TxOut {
1106                                 value: *value,
1107                                 script_pubkey: Builder::new()
1108                                         .push_opcode(opcodes::OP_TRUE)
1109                                         .into_script()
1110                                         .to_v0_p2wsh(),
1111                         })
1112                         .collect()
1113         }
1114
1115         fn generate_fixed_number_of_inputs(count: u16) -> Vec<(TxIn, TransactionU16LenLimited)> {
1116                 // Generate transactions with a total `count` number of outputs such that no transaction has a
1117                 // serialized length greater than u16::MAX.
1118                 let max_outputs_per_prevtx = 1_500;
1119                 let mut remaining = count;
1120                 let mut inputs: Vec<(TxIn, TransactionU16LenLimited)> = Vec::with_capacity(count as usize);
1121
1122                 while remaining > 0 {
1123                         let tx_output_count = remaining.min(max_outputs_per_prevtx);
1124                         remaining -= tx_output_count;
1125
1126                         // Use unique locktime for each tx so outpoints are different across transactions
1127                         let tx = generate_tx_with_locktime(
1128                                 &vec![1_000_000; tx_output_count as usize],
1129                                 (1337 + remaining).into(),
1130                         );
1131                         let txid = tx.txid();
1132
1133                         let mut temp: Vec<(TxIn, TransactionU16LenLimited)> = tx
1134                                 .output
1135                                 .iter()
1136                                 .enumerate()
1137                                 .map(|(idx, _)| {
1138                                         let input = TxIn {
1139                                                 previous_output: OutPoint { txid, vout: idx as u32 },
1140                                                 script_sig: Default::default(),
1141                                                 sequence: Sequence::ENABLE_RBF_NO_LOCKTIME,
1142                                                 witness: Default::default(),
1143                                         };
1144                                         (input, TransactionU16LenLimited::new(tx.clone()).unwrap())
1145                                 })
1146                                 .collect();
1147
1148                         inputs.append(&mut temp);
1149                 }
1150
1151                 inputs
1152         }
1153
1154         fn generate_fixed_number_of_outputs(count: u16) -> Vec<TxOut> {
1155                 // Set a constant value for each TxOut
1156                 generate_outputs(&vec![1_000_000; count as usize])
1157         }
1158
1159         fn generate_non_witness_output(value: u64) -> TxOut {
1160                 TxOut {
1161                         value,
1162                         script_pubkey: Builder::new().push_opcode(opcodes::OP_TRUE).into_script().to_p2sh(),
1163                 }
1164         }
1165
1166         #[test]
1167         fn test_interactive_tx_constructor() {
1168                 do_test_interactive_tx_constructor(TestSession {
1169                         description: "No contributions",
1170                         inputs_a: vec![],
1171                         outputs_a: vec![],
1172                         inputs_b: vec![],
1173                         outputs_b: vec![],
1174                         expect_error: Some((AbortReason::InsufficientFees, ErrorCulprit::NodeA)),
1175                 });
1176                 do_test_interactive_tx_constructor(TestSession {
1177                         description: "Single contribution, no initiator inputs",
1178                         inputs_a: vec![],
1179                         outputs_a: generate_outputs(&[1_000_000]),
1180                         inputs_b: vec![],
1181                         outputs_b: vec![],
1182                         expect_error: Some((AbortReason::OutputsValueExceedsInputsValue, ErrorCulprit::NodeA)),
1183                 });
1184                 do_test_interactive_tx_constructor(TestSession {
1185                         description: "Single contribution, no initiator outputs",
1186                         inputs_a: generate_inputs(&[1_000_000]),
1187                         outputs_a: vec![],
1188                         inputs_b: vec![],
1189                         outputs_b: vec![],
1190                         expect_error: None,
1191                 });
1192                 do_test_interactive_tx_constructor(TestSession {
1193                         description: "Single contribution, insufficient fees",
1194                         inputs_a: generate_inputs(&[1_000_000]),
1195                         outputs_a: generate_outputs(&[1_000_000]),
1196                         inputs_b: vec![],
1197                         outputs_b: vec![],
1198                         expect_error: Some((AbortReason::InsufficientFees, ErrorCulprit::NodeA)),
1199                 });
1200                 do_test_interactive_tx_constructor(TestSession {
1201                         description: "Initiator contributes sufficient fees, but non-initiator does not",
1202                         inputs_a: generate_inputs(&[1_000_000]),
1203                         outputs_a: vec![],
1204                         inputs_b: generate_inputs(&[100_000]),
1205                         outputs_b: generate_outputs(&[100_000]),
1206                         expect_error: Some((AbortReason::InsufficientFees, ErrorCulprit::NodeB)),
1207                 });
1208                 do_test_interactive_tx_constructor(TestSession {
1209                         description: "Multi-input-output contributions from both sides",
1210                         inputs_a: generate_inputs(&[1_000_000, 1_000_000]),
1211                         outputs_a: generate_outputs(&[1_000_000, 200_000]),
1212                         inputs_b: generate_inputs(&[1_000_000, 500_000]),
1213                         outputs_b: generate_outputs(&[1_000_000, 400_000]),
1214                         expect_error: None,
1215                 });
1216
1217                 let non_segwit_output_tx = {
1218                         let mut tx = generate_tx(&[1_000_000]);
1219                         tx.output.push(TxOut {
1220                                 script_pubkey: Builder::new()
1221                                         .push_opcode(opcodes::all::OP_RETURN)
1222                                         .into_script()
1223                                         .to_p2sh(),
1224                                 ..Default::default()
1225                         });
1226
1227                         TransactionU16LenLimited::new(tx).unwrap()
1228                 };
1229                 let non_segwit_input = TxIn {
1230                         previous_output: OutPoint {
1231                                 txid: non_segwit_output_tx.as_transaction().txid(),
1232                                 vout: 1,
1233                         },
1234                         sequence: Sequence::ENABLE_RBF_NO_LOCKTIME,
1235                         ..Default::default()
1236                 };
1237                 do_test_interactive_tx_constructor(TestSession {
1238                         description: "Prevout from initiator is not a witness program",
1239                         inputs_a: vec![(non_segwit_input, non_segwit_output_tx)],
1240                         outputs_a: vec![],
1241                         inputs_b: vec![],
1242                         outputs_b: vec![],
1243                         expect_error: Some((AbortReason::PrevTxOutInvalid, ErrorCulprit::NodeA)),
1244                 });
1245
1246                 let tx = TransactionU16LenLimited::new(generate_tx(&[1_000_000])).unwrap();
1247                 let invalid_sequence_input = TxIn {
1248                         previous_output: OutPoint { txid: tx.as_transaction().txid(), vout: 0 },
1249                         ..Default::default()
1250                 };
1251                 do_test_interactive_tx_constructor(TestSession {
1252                         description: "Invalid input sequence from initiator",
1253                         inputs_a: vec![(invalid_sequence_input, tx.clone())],
1254                         outputs_a: generate_outputs(&[1_000_000]),
1255                         inputs_b: vec![],
1256                         outputs_b: vec![],
1257                         expect_error: Some((AbortReason::IncorrectInputSequenceValue, ErrorCulprit::NodeA)),
1258                 });
1259                 let duplicate_input = TxIn {
1260                         previous_output: OutPoint { txid: tx.as_transaction().txid(), vout: 0 },
1261                         sequence: Sequence::ENABLE_RBF_NO_LOCKTIME,
1262                         ..Default::default()
1263                 };
1264                 do_test_interactive_tx_constructor(TestSession {
1265                         description: "Duplicate prevout from initiator",
1266                         inputs_a: vec![(duplicate_input.clone(), tx.clone()), (duplicate_input, tx.clone())],
1267                         outputs_a: generate_outputs(&[1_000_000]),
1268                         inputs_b: vec![],
1269                         outputs_b: vec![],
1270                         expect_error: Some((AbortReason::PrevTxOutInvalid, ErrorCulprit::NodeB)),
1271                 });
1272                 let duplicate_input = TxIn {
1273                         previous_output: OutPoint { txid: tx.as_transaction().txid(), vout: 0 },
1274                         sequence: Sequence::ENABLE_RBF_NO_LOCKTIME,
1275                         ..Default::default()
1276                 };
1277                 do_test_interactive_tx_constructor(TestSession {
1278                         description: "Non-initiator uses same prevout as initiator",
1279                         inputs_a: vec![(duplicate_input.clone(), tx.clone())],
1280                         outputs_a: generate_outputs(&[1_000_000]),
1281                         inputs_b: vec![(duplicate_input.clone(), tx.clone())],
1282                         outputs_b: vec![],
1283                         expect_error: Some((AbortReason::PrevTxOutInvalid, ErrorCulprit::NodeA)),
1284                 });
1285                 do_test_interactive_tx_constructor(TestSession {
1286                         description: "Initiator sends too many TxAddInputs",
1287                         inputs_a: generate_fixed_number_of_inputs(MAX_RECEIVED_TX_ADD_INPUT_COUNT + 1),
1288                         outputs_a: vec![],
1289                         inputs_b: vec![],
1290                         outputs_b: vec![],
1291                         expect_error: Some((AbortReason::ReceivedTooManyTxAddInputs, ErrorCulprit::NodeA)),
1292                 });
1293                 do_test_interactive_tx_constructor_with_entropy_source(
1294                         TestSession {
1295                                 // We use a deliberately bad entropy source, `DuplicateEntropySource` to simulate this.
1296                                 description: "Attempt to queue up two inputs with duplicate serial ids",
1297                                 inputs_a: generate_fixed_number_of_inputs(2),
1298                                 outputs_a: vec![],
1299                                 inputs_b: vec![],
1300                                 outputs_b: vec![],
1301                                 expect_error: Some((AbortReason::DuplicateSerialId, ErrorCulprit::NodeA)),
1302                         },
1303                         &DuplicateEntropySource,
1304                 );
1305                 do_test_interactive_tx_constructor(TestSession {
1306                         description: "Initiator sends too many TxAddOutputs",
1307                         inputs_a: vec![],
1308                         outputs_a: generate_fixed_number_of_outputs(MAX_RECEIVED_TX_ADD_OUTPUT_COUNT + 1),
1309                         inputs_b: vec![],
1310                         outputs_b: vec![],
1311                         expect_error: Some((AbortReason::ReceivedTooManyTxAddOutputs, ErrorCulprit::NodeA)),
1312                 });
1313                 do_test_interactive_tx_constructor(TestSession {
1314                         description: "Initiator sends an output below dust value",
1315                         inputs_a: vec![],
1316                         outputs_a: generate_outputs(&[1]),
1317                         inputs_b: vec![],
1318                         outputs_b: vec![],
1319                         expect_error: Some((AbortReason::BelowDustLimit, ErrorCulprit::NodeA)),
1320                 });
1321                 do_test_interactive_tx_constructor(TestSession {
1322                         description: "Initiator sends an output above maximum sats allowed",
1323                         inputs_a: vec![],
1324                         outputs_a: generate_outputs(&[TOTAL_BITCOIN_SUPPLY_SATOSHIS + 1]),
1325                         inputs_b: vec![],
1326                         outputs_b: vec![],
1327                         expect_error: Some((AbortReason::ExceededMaximumSatsAllowed, ErrorCulprit::NodeA)),
1328                 });
1329                 do_test_interactive_tx_constructor(TestSession {
1330                         description: "Initiator sends an output without a witness program",
1331                         inputs_a: vec![],
1332                         outputs_a: vec![generate_non_witness_output(1_000_000)],
1333                         inputs_b: vec![],
1334                         outputs_b: vec![],
1335                         expect_error: Some((AbortReason::InvalidOutputScript, ErrorCulprit::NodeA)),
1336                 });
1337                 do_test_interactive_tx_constructor_with_entropy_source(
1338                         TestSession {
1339                                 // We use a deliberately bad entropy source, `DuplicateEntropySource` to simulate this.
1340                                 description: "Attempt to queue up two outputs with duplicate serial ids",
1341                                 inputs_a: vec![],
1342                                 outputs_a: generate_fixed_number_of_outputs(2),
1343                                 inputs_b: vec![],
1344                                 outputs_b: vec![],
1345                                 expect_error: Some((AbortReason::DuplicateSerialId, ErrorCulprit::NodeA)),
1346                         },
1347                         &DuplicateEntropySource,
1348                 );
1349
1350                 do_test_interactive_tx_constructor(TestSession {
1351                         description: "Peer contributed more output value than inputs",
1352                         inputs_a: generate_inputs(&[100_000]),
1353                         outputs_a: generate_outputs(&[1_000_000]),
1354                         inputs_b: vec![],
1355                         outputs_b: vec![],
1356                         expect_error: Some((AbortReason::OutputsValueExceedsInputsValue, ErrorCulprit::NodeA)),
1357                 });
1358
1359                 do_test_interactive_tx_constructor(TestSession {
1360                         description: "Peer contributed more than allowed number of inputs",
1361                         inputs_a: generate_fixed_number_of_inputs(MAX_INPUTS_OUTPUTS_COUNT as u16 + 1),
1362                         outputs_a: vec![],
1363                         inputs_b: vec![],
1364                         outputs_b: vec![],
1365                         expect_error: Some((
1366                                 AbortReason::ExceededNumberOfInputsOrOutputs,
1367                                 ErrorCulprit::Indeterminate,
1368                         )),
1369                 });
1370                 do_test_interactive_tx_constructor(TestSession {
1371                         description: "Peer contributed more than allowed number of outputs",
1372                         inputs_a: generate_inputs(&[TOTAL_BITCOIN_SUPPLY_SATOSHIS]),
1373                         outputs_a: generate_fixed_number_of_outputs(MAX_INPUTS_OUTPUTS_COUNT as u16 + 1),
1374                         inputs_b: vec![],
1375                         outputs_b: vec![],
1376                         expect_error: Some((
1377                                 AbortReason::ExceededNumberOfInputsOrOutputs,
1378                                 ErrorCulprit::Indeterminate,
1379                         )),
1380                 });
1381         }
1382
1383         #[test]
1384         fn test_generate_local_serial_id() {
1385                 let entropy_source = TestEntropySource(AtomicCounter::new());
1386
1387                 // Initiators should have even serial id, non-initiators should have odd serial id.
1388                 assert_eq!(generate_holder_serial_id(&&entropy_source, true) % 2, 0);
1389                 assert_eq!(generate_holder_serial_id(&&entropy_source, false) % 2, 1)
1390         }
1391 }