Merge pull request #3039 from valentinewallace/2024-04-invoice-amt-msats-overflow
[rust-lightning] / lightning-invoice / src / lib.rs
index 690bf57640862daae0adb55019ab48d90a2384a0..fb34240bee1b32ee397f9acce1de769db65bf9a7 100644 (file)
@@ -31,7 +31,6 @@ pub mod utils;
 
 extern crate bech32;
 #[macro_use] extern crate lightning;
-extern crate num_traits;
 extern crate secp256k1;
 extern crate alloc;
 #[cfg(any(test, feature = "std"))]
@@ -66,7 +65,7 @@ use core::str;
 use serde::{Deserialize, Deserializer,Serialize, Serializer, de::Error};
 
 #[doc(no_inline)]
-pub use lightning::ln::PaymentSecret;
+pub use lightning::ln::types::PaymentSecret;
 #[doc(no_inline)]
 pub use lightning::routing::router::{RouteHint, RouteHintHop};
 #[doc(no_inline)]
@@ -77,25 +76,15 @@ mod de;
 mod ser;
 mod tb;
 
+#[allow(unused_imports)]
 mod prelude {
-       #[cfg(feature = "hashbrown")]
-       extern crate hashbrown;
-
        pub use alloc::{vec, vec::Vec, string::String};
-       #[cfg(not(feature = "hashbrown"))]
-       pub use std::collections::{HashMap, hash_map};
-       #[cfg(feature = "hashbrown")]
-       pub use self::hashbrown::{HashMap, HashSet, hash_map};
 
        pub use alloc::string::ToString;
 }
 
 use crate::prelude::*;
 
-/// Sync compat for std/no_std
-#[cfg(not(feature = "std"))]
-mod sync;
-
 /// Errors that indicate what is wrong with the invoice. They have some granularity for debug
 /// reasons, but should generally result in an "invalid BOLT11 invoice" message for the user.
 #[allow(missing_docs)]
@@ -173,7 +162,7 @@ pub const DEFAULT_MIN_FINAL_CLTV_EXPIRY_DELTA: u64 = 18;
 /// use secp256k1::Secp256k1;
 /// use secp256k1::SecretKey;
 ///
-/// use lightning::ln::PaymentSecret;
+/// use lightning::ln::types::PaymentSecret;
 ///
 /// use lightning_invoice::{Currency, InvoiceBuilder};
 ///
@@ -553,7 +542,7 @@ impl InvoiceBuilder<tb::False, tb::False, tb::False, tb::False, tb::False, tb::F
                        amount: None,
                        si_prefix: None,
                        timestamp: None,
-                       tagged_fields: Vec::new(),
+                       tagged_fields: Vec::with_capacity(8),
                        error: None,
 
                        phantom_d: core::marker::PhantomData,
@@ -588,7 +577,13 @@ impl<D: tb::Bool, H: tb::Bool, T: tb::Bool, C: tb::Bool, S: tb::Bool, M: tb::Boo
 
        /// Sets the amount in millisatoshis. The optimal SI prefix is chosen automatically.
        pub fn amount_milli_satoshis(mut self, amount_msat: u64) -> Self {
-               let amount = amount_msat * 10; // Invoices are denominated in "pico BTC"
+               let amount = match amount_msat.checked_mul(10) { // Invoices are denominated in "pico BTC"
+                       Some(amt) => amt,
+                       None => {
+                               self.error = Some(CreationError::InvalidAmount);
+                               return self
+                       }
+               };
                let biggest_possible_si_prefix = SiPrefix::values_desc()
                        .iter()
                        .find(|prefix| amount % prefix.multiplier() == 0)
@@ -1079,9 +1074,10 @@ impl RawBolt11Invoice {
                find_all_extract!(self.known_tagged_fields(), TaggedField::PrivateRoute(ref x), x).collect()
        }
 
+       /// Returns `None` if no amount is set or on overflow.
        pub fn amount_pico_btc(&self) -> Option<u64> {
-               self.hrp.raw_amount.map(|v| {
-                       v * self.hrp.si_prefix.as_ref().map_or(1_000_000_000_000, |si| { si.multiplier() })
+               self.hrp.raw_amount.and_then(|v| {
+                       v.checked_mul(self.hrp.si_prefix.as_ref().map_or(1_000_000_000_000, |si| { si.multiplier() }))
                })
        }
 
@@ -1358,6 +1354,15 @@ impl Bolt11Invoice {
                self.signed_invoice.recover_payee_pub_key().expect("was checked by constructor").0
        }
 
+       /// Recover the payee's public key if one was included in the invoice, otherwise return the
+       /// recovered public key from the signature
+       pub fn get_payee_pub_key(&self) -> PublicKey {
+               match self.payee_pub_key() {
+                       Some(pk) => *pk,
+                       None => self.recover_payee_pub_key()
+               }
+       }
+
        /// Returns the Duration since the Unix epoch at which the invoice expires.
        /// Returning None if overflow occurred.
        pub fn expires_at(&self) -> Option<Duration> {
@@ -1878,7 +1883,7 @@ mod test {
                         Bolt11SemanticError};
 
                let private_key = SecretKey::from_slice(&[42; 32]).unwrap();
-               let payment_secret = lightning::ln::PaymentSecret([21; 32]);
+               let payment_secret = lightning::ln::types::PaymentSecret([21; 32]);
                let invoice_template = RawBolt11Invoice {
                        hrp: RawHrp {
                                currency: Currency::Bitcoin,
@@ -2049,7 +2054,7 @@ mod test {
                use lightning::routing::router::RouteHintHop;
                use secp256k1::Secp256k1;
                use secp256k1::{SecretKey, PublicKey};
-               use std::time::{UNIX_EPOCH, Duration};
+               use std::time::Duration;
 
                let secp_ctx = Secp256k1::new();
 
@@ -2065,7 +2070,7 @@ mod test {
                let route_1 = RouteHint(vec![
                        RouteHintHop {
                                src_node_id: public_key,
-                               short_channel_id: de::parse_int_be(&[123; 8], 256).expect("short chan ID slice too big?"),
+                               short_channel_id: u64::from_be_bytes([123; 8]),
                                fees: RoutingFees {
                                        base_msat: 2,
                                        proportional_millionths: 1,
@@ -2076,7 +2081,7 @@ mod test {
                        },
                        RouteHintHop {
                                src_node_id: public_key,
-                               short_channel_id: de::parse_int_be(&[42; 8], 256).expect("short chan ID slice too big?"),
+                               short_channel_id: u64::from_be_bytes([42; 8]),
                                fees: RoutingFees {
                                        base_msat: 3,
                                        proportional_millionths: 2,
@@ -2101,7 +2106,7 @@ mod test {
                        },
                        RouteHintHop {
                                src_node_id: public_key,
-                               short_channel_id: de::parse_int_be(&[1; 8], 256).expect("short chan ID slice too big?"),
+                               short_channel_id: u64::from_be_bytes([1; 8]),
                                fees: RoutingFees {
                                        base_msat: 5,
                                        proportional_millionths: 4,
@@ -2138,7 +2143,7 @@ mod test {
                assert_eq!(invoice.currency(), Currency::BitcoinTestnet);
                #[cfg(feature = "std")]
                assert_eq!(
-                       invoice.timestamp().duration_since(UNIX_EPOCH).unwrap().as_secs(),
+                       invoice.timestamp().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs(),
                        1234567
                );
                assert_eq!(invoice.payee_pub_key(), Some(&public_key));