Replace expect_value_msat with expect_send
authorJeffrey Czyz <jkczyz@gmail.com>
Wed, 10 Nov 2021 00:07:23 +0000 (18:07 -0600)
committerJeffrey Czyz <jkczyz@gmail.com>
Tue, 16 Nov 2021 19:57:58 +0000 (13:57 -0600)
Modify all InvoicePayer unit tests to use expect_send instead of
expect_value_msat, since the former can discern whether the send was for
an invoice, spontaneous payment, or a retry. Updates tests to set payer
expectations if they weren't already and assert these before returning a
failure.

lightning-invoice/src/payment.rs

index b3b84e351b72f7a19d497810d2fb1a5e85b36a85..55d877d415e569b11166a11b1701fa84556a5662 100644 (file)
@@ -566,8 +566,9 @@ mod tests {
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = invoice(payment_preimage);
                let payment_hash = PaymentHash(invoice.payment_hash().clone().into_inner());
+               let final_value_msat = invoice.amount_milli_satoshis().unwrap();
 
-               let payer = TestPayer::new();
+               let payer = TestPayer::new().expect_send(Amount::ForInvoice(final_value_msat));
                let router = TestRouter {};
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
@@ -595,8 +596,8 @@ mod tests {
                let final_value_msat = invoice.amount_milli_satoshis().unwrap();
 
                let payer = TestPayer::new()
-                       .expect_value_msat(final_value_msat)
-                       .expect_value_msat(final_value_msat / 2);
+                       .expect_send(Amount::ForInvoice(final_value_msat))
+                       .expect_send(Amount::OnRetry(final_value_msat / 2));
                let router = TestRouter {};
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
@@ -637,7 +638,9 @@ mod tests {
                let payment_hash = PaymentHash(invoice.payment_hash().clone().into_inner());
                let final_value_msat = invoice.amount_milli_satoshis().unwrap();
 
-               let payer = TestPayer::new();
+               let payer = TestPayer::new()
+                       .expect_send(Amount::OnRetry(final_value_msat / 2))
+                       .expect_send(Amount::OnRetry(final_value_msat / 2));
                let router = TestRouter {};
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
@@ -680,9 +683,9 @@ mod tests {
                let final_value_msat = invoice.amount_milli_satoshis().unwrap();
 
                let payer = TestPayer::new()
-                       .expect_value_msat(final_value_msat)
-                       .expect_value_msat(final_value_msat / 2)
-                       .expect_value_msat(final_value_msat / 2);
+                       .expect_send(Amount::ForInvoice(final_value_msat))
+                       .expect_send(Amount::OnRetry(final_value_msat / 2))
+                       .expect_send(Amount::OnRetry(final_value_msat / 2));
                let router = TestRouter {};
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
@@ -732,15 +735,17 @@ mod tests {
                let event_handled = core::cell::RefCell::new(false);
                let event_handler = |_: &_| { *event_handled.borrow_mut() = true; };
 
-               let payer = TestPayer::new();
+               let payment_preimage = PaymentPreimage([1; 32]);
+               let invoice = invoice(payment_preimage);
+               let final_value_msat = invoice.amount_milli_satoshis().unwrap();
+
+               let payer = TestPayer::new().expect_send(Amount::ForInvoice(final_value_msat));
                let router = TestRouter {};
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
                        InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
 
-               let payment_preimage = PaymentPreimage([1; 32]);
-               let invoice = invoice(payment_preimage);
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
 
@@ -783,15 +788,17 @@ mod tests {
                let event_handled = core::cell::RefCell::new(false);
                let event_handler = |_: &_| { *event_handled.borrow_mut() = true; };
 
-               let payer = TestPayer::new();
+               let payment_preimage = PaymentPreimage([1; 32]);
+               let invoice = invoice(payment_preimage);
+               let final_value_msat = invoice.amount_milli_satoshis().unwrap();
+
+               let payer = TestPayer::new().expect_send(Amount::ForInvoice(final_value_msat));
                let router = TestRouter {};
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
                        InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
 
-               let payment_preimage = PaymentPreimage([1; 32]);
-               let invoice = invoice(payment_preimage);
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
 
@@ -825,7 +832,8 @@ mod tests {
 
                let payer = TestPayer::new()
                        .fails_on_attempt(2)
-                       .expect_value_msat(final_value_msat);
+                       .expect_send(Amount::ForInvoice(final_value_msat))
+                       .expect_send(Amount::OnRetry(final_value_msat / 2));
                let router = TestRouter {};
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
@@ -855,15 +863,17 @@ mod tests {
                let event_handled = core::cell::RefCell::new(false);
                let event_handler = |_: &_| { *event_handled.borrow_mut() = true; };
 
-               let payer = TestPayer::new();
+               let payment_preimage = PaymentPreimage([1; 32]);
+               let invoice = invoice(payment_preimage);
+               let final_value_msat = invoice.amount_milli_satoshis().unwrap();
+
+               let payer = TestPayer::new().expect_send(Amount::ForInvoice(final_value_msat));
                let router = TestRouter {};
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
                        InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
 
-               let payment_preimage = PaymentPreimage([1; 32]);
-               let invoice = invoice(payment_preimage);
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
 
@@ -887,15 +897,19 @@ mod tests {
                let event_handled = core::cell::RefCell::new(false);
                let event_handler = |_: &_| { *event_handled.borrow_mut() = true; };
 
-               let payer = TestPayer::new();
+               let payment_preimage = PaymentPreimage([1; 32]);
+               let invoice = invoice(payment_preimage);
+               let final_value_msat = invoice.amount_milli_satoshis().unwrap();
+
+               let payer = TestPayer::new()
+                       .expect_send(Amount::ForInvoice(final_value_msat))
+                       .expect_send(Amount::ForInvoice(final_value_msat));
                let router = TestRouter {};
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
                        InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
 
-               let payment_preimage = PaymentPreimage([1; 32]);
-               let invoice = invoice(payment_preimage);
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
 
                // Cannot repay an invoice pending payment.
@@ -946,15 +960,19 @@ mod tests {
 
        #[test]
        fn fails_paying_invoice_with_sending_errors() {
-               let payer = TestPayer::new().fails_on_attempt(1);
+               let payment_preimage = PaymentPreimage([1; 32]);
+               let invoice = invoice(payment_preimage);
+               let final_value_msat = invoice.amount_milli_satoshis().unwrap();
+
+               let payer = TestPayer::new()
+                       .fails_on_attempt(1)
+                       .expect_send(Amount::ForInvoice(final_value_msat));
                let router = TestRouter {};
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
                        InvoicePayer::new(&payer, router, &scorer, &logger, |_: &_| {}, RetryAttempts(0));
 
-               let payment_preimage = PaymentPreimage([1; 32]);
-               let invoice = invoice(payment_preimage);
                match invoice_payer.pay_invoice(&invoice) {
                        Err(PaymentError::Sending(_)) => {},
                        Err(_) => panic!("unexpected error"),
@@ -972,7 +990,7 @@ mod tests {
                let payment_hash = PaymentHash(invoice.payment_hash().clone().into_inner());
                let final_value_msat = 100;
 
-               let payer = TestPayer::new().expect_value_msat(final_value_msat);
+               let payer = TestPayer::new().expect_send(Amount::ForInvoice(final_value_msat));
                let router = TestRouter {};
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
@@ -1026,7 +1044,7 @@ mod tests {
 
                let payer = TestPayer::new()
                        .expect_send(Amount::Spontaneous(final_value_msat))
-                       .expect_send(Amount::ForInvoiceOrRetry(final_value_msat));
+                       .expect_send(Amount::OnRetry(final_value_msat));
                let router = TestRouter {};
                let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
@@ -1077,7 +1095,9 @@ mod tests {
                let short_channel_id = Some(path[0].short_channel_id);
 
                // Expect that scorer is given short_channel_id upon handling the event.
-               let payer = TestPayer::new();
+               let payer = TestPayer::new()
+                       .expect_send(Amount::ForInvoice(final_value_msat))
+                       .expect_send(Amount::OnRetry(final_value_msat / 2));
                let router = TestRouter {};
                let scorer = RefCell::new(TestScorer::new().expect_channel_failure(short_channel_id.unwrap()));
                let logger = TestLogger::new();
@@ -1218,8 +1238,9 @@ mod tests {
 
        #[derive(Clone, Debug, PartialEq, Eq)]
        enum Amount {
-               ForInvoiceOrRetry(u64),
+               ForInvoice(u64),
                Spontaneous(u64),
+               OnRetry(u64),
        }
 
        impl TestPayer {
@@ -1231,11 +1252,6 @@ mod tests {
                        }
                }
 
-               fn expect_value_msat(self, value_msat: u64) -> Self {
-                       self.expectations.borrow_mut().push_back(Amount::ForInvoiceOrRetry(value_msat));
-                       self
-               }
-
                fn expect_send(self, value_msat: Amount) -> Self {
                        self.expectations.borrow_mut().push_back(value_msat);
                        self
@@ -1249,13 +1265,13 @@ mod tests {
                        }
                }
 
-               fn check_attempts(&self) -> bool {
+               fn check_attempts(&self) -> Result<PaymentId, PaymentSendFailure> {
                        let mut attempts = self.attempts.borrow_mut();
                        *attempts += 1;
                        match self.failing_on_attempt {
-                               None => true,
-                               Some(attempt) if attempt != *attempts => true,
-                               Some(_) => false,
+                               None => Ok(PaymentId([1; 32])),
+                               Some(attempt) if attempt != *attempts => Ok(PaymentId([1; 32])),
+                               Some(_) => Err(PaymentSendFailure::ParameterError(APIError::MonitorUpdateFailed)),
                        }
                }
 
@@ -1290,41 +1306,25 @@ mod tests {
                }
 
                fn send_payment(
-                       &self,
-                       route: &Route,
-                       _payment_hash: PaymentHash,
+                       &self, route: &Route, _payment_hash: PaymentHash,
                        _payment_secret: &Option<PaymentSecret>
                ) -> Result<PaymentId, PaymentSendFailure> {
-                       if self.check_attempts() {
-                               self.check_value_msats(Amount::ForInvoiceOrRetry(route.get_total_amount()));
-                               Ok(PaymentId([1; 32]))
-                       } else {
-                               Err(PaymentSendFailure::ParameterError(APIError::MonitorUpdateFailed))
-                       }
+                       self.check_value_msats(Amount::ForInvoice(route.get_total_amount()));
+                       self.check_attempts()
                }
 
                fn send_spontaneous_payment(
-                       &self,
-                       route: &Route,
-                       _payment_preimage: PaymentPreimage,
+                       &self, route: &Route, _payment_preimage: PaymentPreimage,
                ) -> Result<PaymentId, PaymentSendFailure> {
-                       if self.check_attempts() {
-                               self.check_value_msats(Amount::Spontaneous(route.get_total_amount()));
-                               Ok(PaymentId([1; 32]))
-                       } else {
-                               Err(PaymentSendFailure::ParameterError(APIError::MonitorUpdateFailed))
-                       }
+                       self.check_value_msats(Amount::Spontaneous(route.get_total_amount()));
+                       self.check_attempts()
                }
 
                fn retry_payment(
                        &self, route: &Route, _payment_id: PaymentId
                ) -> Result<(), PaymentSendFailure> {
-                       if self.check_attempts() {
-                               self.check_value_msats(Amount::ForInvoiceOrRetry(route.get_total_amount()));
-                               Ok(())
-                       } else {
-                               Err(PaymentSendFailure::ParameterError(APIError::MonitorUpdateFailed))
-                       }
+                       self.check_value_msats(Amount::OnRetry(route.get_total_amount()));
+                       self.check_attempts().map(|_| ())
                }
        }