From e80e8c80627e554034100af197e8ec173db4dd21 Mon Sep 17 00:00:00 2001 From: benthecarman Date: Tue, 14 Nov 2023 10:14:30 -0600 Subject: [PATCH] Have Invoice Description use UntrustedString --- lightning-invoice/src/lib.rs | 25 ++++++------------------- lightning-invoice/src/ser.rs | 4 ++-- lightning-invoice/src/utils.rs | 11 ++++++----- lightning/src/util/string.rs | 2 +- 4 files changed, 15 insertions(+), 27 deletions(-) diff --git a/lightning-invoice/src/lib.rs b/lightning-invoice/src/lib.rs index d283cba7a..42d0a337e 100644 --- a/lightning-invoice/src/lib.rs +++ b/lightning-invoice/src/lib.rs @@ -73,6 +73,7 @@ pub use lightning::ln::PaymentSecret; pub use lightning::routing::router::{RouteHint, RouteHintHop}; #[doc(no_inline)] pub use lightning::routing::gossip::RoutingFees; +use lightning::util::string::UntrustedString; mod de; mod ser; @@ -480,7 +481,7 @@ impl Sha256 { /// # Invariants /// The description can be at most 639 __bytes__ long #[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd, Default)] -pub struct Description(String); +pub struct Description(UntrustedString); /// Payee public key #[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] @@ -684,7 +685,7 @@ impl InvoiceBui pub fn invoice_description(self, description: Bolt11InvoiceDescription) -> InvoiceBuilder { match description { Bolt11InvoiceDescription::Direct(desc) => { - self.description(desc.clone().into_inner()) + self.description(desc.clone().into_inner().0) } Bolt11InvoiceDescription::Hash(hash) => { self.description_hash(hash.0) @@ -1517,30 +1518,16 @@ impl Description { if description.len() > 639 { Err(CreationError::DescriptionTooLong) } else { - Ok(Description(description)) + Ok(Description(UntrustedString(description))) } } - /// Returns the underlying description [`String`] - pub fn into_inner(self) -> String { + /// Returns the underlying description [`UntrustedString`] + pub fn into_inner(self) -> UntrustedString { self.0 } } -impl From for String { - fn from(val: Description) -> Self { - val.into_inner() - } -} - -impl Deref for Description { - type Target = str; - - fn deref(&self) -> &str { - &self.0 - } -} - impl Display for Description { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "{}", self.0) diff --git a/lightning-invoice/src/ser.rs b/lightning-invoice/src/ser.rs index dc5dba45d..fe42f72b6 100644 --- a/lightning-invoice/src/ser.rs +++ b/lightning-invoice/src/ser.rs @@ -279,13 +279,13 @@ impl Base32Len for Sha256 { impl ToBase32 for Description { fn write_base32(&self, writer: &mut W) -> Result<(), ::Err> { - self.as_bytes().write_base32(writer) + self.0.0.as_bytes().write_base32(writer) } } impl Base32Len for Description { fn base32_len(&self) -> usize { - self.0.as_bytes().base32_len() + self.0.0.as_bytes().base32_len() } } diff --git a/lightning-invoice/src/utils.rs b/lightning-invoice/src/utils.rs index d551d248a..f3e642a2e 100644 --- a/lightning-invoice/src/utils.rs +++ b/lightning-invoice/src/utils.rs @@ -158,7 +158,7 @@ where let invoice = match description { Bolt11InvoiceDescription::Direct(description) => { - InvoiceBuilder::new(network).description(description.0.clone()) + InvoiceBuilder::new(network).description(description.0.0.clone()) } Bolt11InvoiceDescription::Hash(hash) => InvoiceBuilder::new(network).description_hash(hash.0), }; @@ -538,7 +538,7 @@ fn _create_invoice_from_channelmanager_and_duration_since_epoch_with_payment_has let invoice = match description { Bolt11InvoiceDescription::Direct(description) => { - InvoiceBuilder::new(network).description(description.0.clone()) + InvoiceBuilder::new(network).description(description.0.0.clone()) } Bolt11InvoiceDescription::Hash(hash) => InvoiceBuilder::new(network).description_hash(hash.0), }; @@ -808,6 +808,7 @@ mod test { use lightning::util::config::UserConfig; use crate::utils::{create_invoice_from_channelmanager_and_duration_since_epoch, rotate_through_iterators}; use std::collections::HashSet; + use lightning::util::string::UntrustedString; #[test] fn test_prefer_current_channel() { @@ -852,7 +853,7 @@ mod test { assert_eq!(invoice.amount_pico_btc(), Some(100_000)); // If no `min_final_cltv_expiry_delta` is specified, then it should be `MIN_FINAL_CLTV_EXPIRY_DELTA`. assert_eq!(invoice.min_final_cltv_expiry_delta(), MIN_FINAL_CLTV_EXPIRY_DELTA as u64); - assert_eq!(invoice.description(), Bolt11InvoiceDescription::Direct(&Description("test".to_string()))); + assert_eq!(invoice.description(), Bolt11InvoiceDescription::Direct(&Description(UntrustedString("test".to_string())))); assert_eq!(invoice.expiry_time(), Duration::from_secs(non_default_invoice_expiry_secs.into())); // Invoice SCIDs should always use inbound SCID aliases over the real channel ID, if one is @@ -963,7 +964,7 @@ mod test { ).unwrap(); assert_eq!(invoice.amount_pico_btc(), Some(100_000)); assert_eq!(invoice.min_final_cltv_expiry_delta(), MIN_FINAL_CLTV_EXPIRY_DELTA as u64); - assert_eq!(invoice.description(), Bolt11InvoiceDescription::Direct(&Description("test".to_string()))); + assert_eq!(invoice.description(), Bolt11InvoiceDescription::Direct(&Description(UntrustedString("test".to_string())))); assert_eq!(invoice.payment_hash(), &sha256::Hash::from_slice(&payment_hash.0[..]).unwrap()); } @@ -1315,7 +1316,7 @@ mod test { }; assert_eq!(invoice.min_final_cltv_expiry_delta(), MIN_FINAL_CLTV_EXPIRY_DELTA as u64); - assert_eq!(invoice.description(), Bolt11InvoiceDescription::Direct(&Description("test".to_string()))); + assert_eq!(invoice.description(), Bolt11InvoiceDescription::Direct(&Description(UntrustedString("test".to_string())))); assert_eq!(invoice.route_hints().len(), 2); assert_eq!(invoice.expiry_time(), Duration::from_secs(non_default_invoice_expiry_secs.into())); assert!(!invoice.features().unwrap().supports_basic_mpp()); diff --git a/lightning/src/util/string.rs b/lightning/src/util/string.rs index 3e5942f6f..6949c936e 100644 --- a/lightning/src/util/string.rs +++ b/lightning/src/util/string.rs @@ -16,7 +16,7 @@ use crate::ln::msgs; use crate::util::ser::{Writeable, Writer, Readable}; /// Struct to `Display` fields in a safe way using `PrintableString` -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default)] pub struct UntrustedString(pub String); impl Writeable for UntrustedString { -- 2.39.5