]> git.bitcoin.ninja Git - rust-lightning/commitdiff
Accept multi-hop route hints in get_route
authorJeffrey Czyz <jkczyz@gmail.com>
Thu, 10 Jun 2021 22:49:14 +0000 (15:49 -0700)
committerJeffrey Czyz <jkczyz@gmail.com>
Fri, 11 Jun 2021 15:44:32 +0000 (08:44 -0700)
Lightning invoices allow for zero or more multi-hop route hints. Update
get_route's interface to accept such hints, although only the last hop
from each is used for the time being.

Moves RouteHint from lightning-invoice crate to lightning crate. Adds a
PrivateRoute wrapper around RouteHint for use in lightning-invoice.

fuzz/src/router.rs
lightning-invoice/src/de.rs
lightning-invoice/src/lib.rs
lightning-invoice/src/ser.rs
lightning-invoice/src/utils.rs
lightning/src/routing/router.rs

index d0f7b13219fb3adb29b22715dee471cfc8b85332..e8e025e247b2ec7c7442ce05c72363946d9bc85b 100644 (file)
@@ -16,7 +16,7 @@ use lightning::chain::transaction::OutPoint;
 use lightning::ln::channelmanager::ChannelDetails;
 use lightning::ln::features::InitFeatures;
 use lightning::ln::msgs;
 use lightning::ln::channelmanager::ChannelDetails;
 use lightning::ln::features::InitFeatures;
 use lightning::ln::msgs;
-use lightning::routing::router::{get_route, RouteHintHop};
+use lightning::routing::router::{get_route, RouteHint, RouteHintHop};
 use lightning::util::logger::Logger;
 use lightning::util::ser::Readable;
 use lightning::routing::network_graph::{NetworkGraph, RoutingFees};
 use lightning::util::logger::Logger;
 use lightning::util::ser::Readable;
 use lightning::routing::network_graph::{NetworkGraph, RoutingFees};
@@ -225,13 +225,13 @@ pub fn do_test<Out: test_logger::Output>(data: &[u8], out: Out) {
                                                Some(&first_hops_vec[..])
                                        },
                                };
                                                Some(&first_hops_vec[..])
                                        },
                                };
-                               let mut last_hops_vec = Vec::new();
+                               let mut last_hops = Vec::new();
                                {
                                        let count = get_slice!(1)[0];
                                        for _ in 0..count {
                                                scid += 1;
                                                let rnid = node_pks.iter().skip(slice_to_be16(get_slice!(2))as usize % node_pks.len()).next().unwrap();
                                {
                                        let count = get_slice!(1)[0];
                                        for _ in 0..count {
                                                scid += 1;
                                                let rnid = node_pks.iter().skip(slice_to_be16(get_slice!(2))as usize % node_pks.len()).next().unwrap();
-                                               last_hops_vec.push(RouteHintHop {
+                                               last_hops.push(RouteHint(vec![RouteHintHop {
                                                        src_node_id: *rnid,
                                                        short_channel_id: scid,
                                                        fees: RoutingFees {
                                                        src_node_id: *rnid,
                                                        short_channel_id: scid,
                                                        fees: RoutingFees {
@@ -241,10 +241,9 @@ pub fn do_test<Out: test_logger::Output>(data: &[u8], out: Out) {
                                                        cltv_expiry_delta: slice_to_be16(get_slice!(2)),
                                                        htlc_minimum_msat: Some(slice_to_be64(get_slice!(8))),
                                                        htlc_maximum_msat: None,
                                                        cltv_expiry_delta: slice_to_be16(get_slice!(2)),
                                                        htlc_minimum_msat: Some(slice_to_be64(get_slice!(8))),
                                                        htlc_maximum_msat: None,
-                                               });
+                                               }]));
                                        }
                                }
                                        }
                                }
-                               let last_hops = &last_hops_vec[..];
                                for target in node_pks.iter() {
                                        let _ = get_route(&our_pubkey, &net_graph, target, None,
                                                first_hops.map(|c| c.iter().collect::<Vec<_>>()).as_ref().map(|a| a.as_slice()),
                                for target in node_pks.iter() {
                                        let _ = get_route(&our_pubkey, &net_graph, target, None,
                                                first_hops.map(|c| c.iter().collect::<Vec<_>>()).as_ref().map(|a| a.as_slice()),
index 47557665b4bf4b473a90657280c9b807ebadbfdc..9c5120e4ad67cb2e9ccdc492b0757a8479c380f0 100644 (file)
@@ -12,7 +12,7 @@ use bitcoin_hashes::Hash;
 use bitcoin_hashes::sha256;
 use lightning::ln::PaymentSecret;
 use lightning::routing::network_graph::RoutingFees;
 use bitcoin_hashes::sha256;
 use lightning::ln::PaymentSecret;
 use lightning::routing::network_graph::RoutingFees;
-use lightning::routing::router::RouteHintHop;
+use lightning::routing::router::{RouteHint, RouteHintHop};
 
 use num_traits::{CheckedAdd, CheckedMul};
 
 
 use num_traits::{CheckedAdd, CheckedMul};
 
@@ -21,7 +21,7 @@ use secp256k1::recovery::{RecoveryId, RecoverableSignature};
 use secp256k1::key::PublicKey;
 
 use super::{Invoice, Sha256, TaggedField, ExpiryTime, MinFinalCltvExpiry, Fallback, PayeePubKey, InvoiceSignature, PositiveTimestamp,
 use secp256k1::key::PublicKey;
 
 use super::{Invoice, Sha256, TaggedField, ExpiryTime, MinFinalCltvExpiry, Fallback, PayeePubKey, InvoiceSignature, PositiveTimestamp,
-       SemanticError, RouteHint, Description, RawTaggedField, Currency, RawHrp, SiPrefix, RawInvoice, constants, SignedRawInvoice,
+       SemanticError, PrivateRoute, Description, RawTaggedField, Currency, RawHrp, SiPrefix, RawInvoice, constants, SignedRawInvoice,
        RawDataPart, CreationError, InvoiceFeatures};
 
 use self::hrp_sm::parse_hrp;
        RawDataPart, CreationError, InvoiceFeatures};
 
 use self::hrp_sm::parse_hrp;
@@ -433,8 +433,8 @@ impl FromBase32 for TaggedField {
                                Ok(TaggedField::MinFinalCltvExpiry(MinFinalCltvExpiry::from_base32(field_data)?)),
                        constants::TAG_FALLBACK =>
                                Ok(TaggedField::Fallback(Fallback::from_base32(field_data)?)),
                                Ok(TaggedField::MinFinalCltvExpiry(MinFinalCltvExpiry::from_base32(field_data)?)),
                        constants::TAG_FALLBACK =>
                                Ok(TaggedField::Fallback(Fallback::from_base32(field_data)?)),
-                       constants::TAG_ROUTE =>
-                               Ok(TaggedField::Route(RouteHint::from_base32(field_data)?)),
+                       constants::TAG_PRIVATE_ROUTE =>
+                               Ok(TaggedField::PrivateRoute(PrivateRoute::from_base32(field_data)?)),
                        constants::TAG_PAYMENT_SECRET =>
                                Ok(TaggedField::PaymentSecret(PaymentSecret::from_base32(field_data)?)),
                        constants::TAG_FEATURES =>
                        constants::TAG_PAYMENT_SECRET =>
                                Ok(TaggedField::PaymentSecret(PaymentSecret::from_base32(field_data)?)),
                        constants::TAG_FEATURES =>
@@ -558,10 +558,10 @@ impl FromBase32 for Fallback {
        }
 }
 
        }
 }
 
-impl FromBase32 for RouteHint {
+impl FromBase32 for PrivateRoute {
        type Err = ParseError;
 
        type Err = ParseError;
 
-       fn from_base32(field_data: &[u5]) -> Result<RouteHint, ParseError> {
+       fn from_base32(field_data: &[u5]) -> Result<PrivateRoute, ParseError> {
                let bytes = Vec::<u8>::from_base32(field_data)?;
 
                if bytes.len() % 51 != 0 {
                let bytes = Vec::<u8>::from_base32(field_data)?;
 
                if bytes.len() % 51 != 0 {
@@ -593,7 +593,7 @@ impl FromBase32 for RouteHint {
                        route_hops.push(hop);
                }
 
                        route_hops.push(hop);
                }
 
-               Ok(RouteHint(route_hops))
+               Ok(PrivateRoute(RouteHint(route_hops)))
        }
 }
 
        }
 }
 
@@ -930,8 +930,8 @@ mod test {
        #[test]
        fn test_parse_route() {
                use lightning::routing::network_graph::RoutingFees;
        #[test]
        fn test_parse_route() {
                use lightning::routing::network_graph::RoutingFees;
-               use lightning::routing::router::RouteHintHop;
-               use ::RouteHint;
+               use lightning::routing::router::{RouteHint, RouteHintHop};
+               use ::PrivateRoute;
                use bech32::FromBase32;
                use de::parse_int_be;
 
                use bech32::FromBase32;
                use de::parse_int_be;
 
@@ -976,10 +976,10 @@ mod test {
                        htlc_maximum_msat: None
                });
 
                        htlc_maximum_msat: None
                });
 
-               assert_eq!(RouteHint::from_base32(&input), Ok(RouteHint(expected)));
+               assert_eq!(PrivateRoute::from_base32(&input), Ok(PrivateRoute(RouteHint(expected))));
 
                assert_eq!(
 
                assert_eq!(
-                       RouteHint::from_base32(&[u5::try_from_u8(0).unwrap(); 40][..]),
+                       PrivateRoute::from_base32(&[u5::try_from_u8(0).unwrap(); 40][..]),
                        Err(ParseError::UnexpectedEndOfTaggedFields)
                );
        }
                        Err(ParseError::UnexpectedEndOfTaggedFields)
                );
        }
index 1d5bb4ffb53e533cb6c06373f8e8803e14b67475..19d53b5456bc3107e75cb7bf85ce364983d5438b 100644 (file)
@@ -30,7 +30,7 @@ use lightning::ln::PaymentSecret;
 use lightning::ln::features::InvoiceFeatures;
 #[cfg(any(doc, test))]
 use lightning::routing::network_graph::RoutingFees;
 use lightning::ln::features::InvoiceFeatures;
 #[cfg(any(doc, test))]
 use lightning::routing::network_graph::RoutingFees;
-use lightning::routing::router::RouteHintHop;
+use lightning::routing::router::RouteHint;
 
 use secp256k1::key::PublicKey;
 use secp256k1::{Message, Secp256k1};
 
 use secp256k1::key::PublicKey;
 use secp256k1::{Message, Secp256k1};
@@ -362,7 +362,7 @@ pub enum TaggedField {
        ExpiryTime(ExpiryTime),
        MinFinalCltvExpiry(MinFinalCltvExpiry),
        Fallback(Fallback),
        ExpiryTime(ExpiryTime),
        MinFinalCltvExpiry(MinFinalCltvExpiry),
        Fallback(Fallback),
-       Route(RouteHint),
+       PrivateRoute(PrivateRoute),
        PaymentSecret(PaymentSecret),
        Features(InvoiceFeatures),
 }
        PaymentSecret(PaymentSecret),
        Features(InvoiceFeatures),
 }
@@ -419,7 +419,7 @@ pub struct InvoiceSignature(pub RecoverableSignature);
 /// The encoded route has to be <1024 5bit characters long (<=639 bytes or <=12 hops)
 ///
 #[derive(Eq, PartialEq, Debug, Clone)]
 /// The encoded route has to be <1024 5bit characters long (<=639 bytes or <=12 hops)
 ///
 #[derive(Eq, PartialEq, Debug, Clone)]
-pub struct RouteHint(Vec<RouteHintHop>);
+pub struct PrivateRoute(RouteHint);
 
 /// Tag constants as specified in BOLT11
 #[allow(missing_docs)]
 
 /// Tag constants as specified in BOLT11
 #[allow(missing_docs)]
@@ -431,7 +431,7 @@ pub mod constants {
        pub const TAG_EXPIRY_TIME: u8 = 6;
        pub const TAG_MIN_FINAL_CLTV_EXPIRY: u8 = 24;
        pub const TAG_FALLBACK: u8 = 9;
        pub const TAG_EXPIRY_TIME: u8 = 6;
        pub const TAG_MIN_FINAL_CLTV_EXPIRY: u8 = 24;
        pub const TAG_FALLBACK: u8 = 9;
-       pub const TAG_ROUTE: u8 = 3;
+       pub const TAG_PRIVATE_ROUTE: u8 = 3;
        pub const TAG_PAYMENT_SECRET: u8 = 16;
        pub const TAG_FEATURES: u8 = 5;
 }
        pub const TAG_PAYMENT_SECRET: u8 = 16;
        pub const TAG_FEATURES: u8 = 5;
 }
@@ -509,9 +509,9 @@ impl<D: tb::Bool, H: tb::Bool, T: tb::Bool, C: tb::Bool, S: tb::Bool> InvoiceBui
        }
 
        /// Adds a private route.
        }
 
        /// Adds a private route.
-       pub fn route(mut self, route: Vec<RouteHintHop>) -> Self {
-               match RouteHint::new(route) {
-                       Ok(r) => self.tagged_fields.push(TaggedField::Route(r)),
+       pub fn private_route(mut self, hint: RouteHint) -> Self {
+               match PrivateRoute::new(hint) {
+                       Ok(r) => self.tagged_fields.push(TaggedField::PrivateRoute(r)),
                        Err(e) => self.error = Some(e),
                }
                self
                        Err(e) => self.error = Some(e),
                }
                self
@@ -913,8 +913,8 @@ impl RawInvoice {
                find_all_extract!(self.known_tagged_fields(), TaggedField::Fallback(ref x), x).collect()
        }
 
                find_all_extract!(self.known_tagged_fields(), TaggedField::Fallback(ref x), x).collect()
        }
 
-       pub fn routes(&self) -> Vec<&RouteHint> {
-               find_all_extract!(self.known_tagged_fields(), TaggedField::Route(ref x), x).collect()
+       pub fn private_routes(&self) -> Vec<&PrivateRoute> {
+               find_all_extract!(self.known_tagged_fields(), TaggedField::PrivateRoute(ref x), x).collect()
        }
 
        pub fn amount_pico_btc(&self) -> Option<u64> {
        }
 
        pub fn amount_pico_btc(&self) -> Option<u64> {
@@ -1163,8 +1163,15 @@ impl Invoice {
        }
 
        /// Returns a list of all routes included in the invoice
        }
 
        /// Returns a list of all routes included in the invoice
-       pub fn routes(&self) -> Vec<&RouteHint> {
-               self.signed_invoice.routes()
+       pub fn private_routes(&self) -> Vec<&PrivateRoute> {
+               self.signed_invoice.private_routes()
+       }
+
+       /// Returns a list of all routes included in the invoice as the underlying hints
+       pub fn route_hints(&self) -> Vec<&RouteHint> {
+               find_all_extract!(
+                       self.signed_invoice.known_tagged_fields(), TaggedField::PrivateRoute(ref x), x
+               ).map(|route| &**route).collect()
        }
 
        /// Returns the currency for which the invoice was issued
        }
 
        /// Returns the currency for which the invoice was issued
@@ -1195,7 +1202,7 @@ impl TaggedField {
                        TaggedField::ExpiryTime(_) => constants::TAG_EXPIRY_TIME,
                        TaggedField::MinFinalCltvExpiry(_) => constants::TAG_MIN_FINAL_CLTV_EXPIRY,
                        TaggedField::Fallback(_) => constants::TAG_FALLBACK,
                        TaggedField::ExpiryTime(_) => constants::TAG_EXPIRY_TIME,
                        TaggedField::MinFinalCltvExpiry(_) => constants::TAG_MIN_FINAL_CLTV_EXPIRY,
                        TaggedField::Fallback(_) => constants::TAG_FALLBACK,
-                       TaggedField::Route(_) => constants::TAG_ROUTE,
+                       TaggedField::PrivateRoute(_) => constants::TAG_PRIVATE_ROUTE,
                        TaggedField::PaymentSecret(_) => constants::TAG_PAYMENT_SECRET,
                        TaggedField::Features(_) => constants::TAG_FEATURES,
                };
                        TaggedField::PaymentSecret(_) => constants::TAG_PAYMENT_SECRET,
                        TaggedField::Features(_) => constants::TAG_FEATURES,
                };
@@ -1286,32 +1293,32 @@ impl ExpiryTime {
        }
 }
 
        }
 }
 
-impl RouteHint {
-       /// Create a new (partial) route from a list of hops
-       pub fn new(hops: Vec<RouteHintHop>) -> Result<RouteHint, CreationError> {
-               if hops.len() <= 12 {
-                       Ok(RouteHint(hops))
+impl PrivateRoute {
+       /// Creates a new (partial) route from a list of hops
+       pub fn new(hops: RouteHint) -> Result<PrivateRoute, CreationError> {
+               if hops.0.len() <= 12 {
+                       Ok(PrivateRoute(hops))
                } else {
                        Err(CreationError::RouteTooLong)
                }
        }
 
                } else {
                        Err(CreationError::RouteTooLong)
                }
        }
 
-       /// Returrn the underlying vector of hops
-       pub fn into_inner(self) -> Vec<RouteHintHop> {
+       /// Returns the underlying list of hops
+       pub fn into_inner(self) -> RouteHint {
                self.0
        }
 }
 
                self.0
        }
 }
 
-impl Into<Vec<RouteHintHop>> for RouteHint {
-       fn into(self) -> Vec<RouteHintHop> {
+impl Into<RouteHint> for PrivateRoute {
+       fn into(self) -> RouteHint {
                self.into_inner()
        }
 }
 
                self.into_inner()
        }
 }
 
-impl Deref for RouteHint {
-       type Target = Vec<RouteHintHop>;
+impl Deref for PrivateRoute {
+       type Target = RouteHint;
 
 
-       fn deref(&self) -> &Vec<RouteHintHop> {
+       fn deref(&self) -> &RouteHint {
                &self.0
        }
 }
                &self.0
        }
 }
@@ -1670,6 +1677,7 @@ mod test {
        #[test]
        fn test_builder_fail() {
                use ::*;
        #[test]
        fn test_builder_fail() {
                use ::*;
+               use lightning::routing::router::RouteHintHop;
                use std::iter::FromIterator;
                use secp256k1::key::PublicKey;
 
                use std::iter::FromIterator;
                use secp256k1::key::PublicKey;
 
@@ -1704,10 +1712,10 @@ mod test {
                        htlc_minimum_msat: None,
                        htlc_maximum_msat: None,
                };
                        htlc_minimum_msat: None,
                        htlc_maximum_msat: None,
                };
-               let too_long_route = vec![route_hop; 13];
+               let too_long_route = RouteHint(vec![route_hop; 13]);
                let long_route_res = builder.clone()
                        .description("Test".into())
                let long_route_res = builder.clone()
                        .description("Test".into())
-                       .route(too_long_route)
+                       .private_route(too_long_route)
                        .build_raw();
                assert_eq!(long_route_res, Err(CreationError::RouteTooLong));
 
                        .build_raw();
                assert_eq!(long_route_res, Err(CreationError::RouteTooLong));
 
@@ -1722,6 +1730,7 @@ mod test {
        #[test]
        fn test_builder_ok() {
                use ::*;
        #[test]
        fn test_builder_ok() {
                use ::*;
+               use lightning::routing::router::RouteHintHop;
                use secp256k1::Secp256k1;
                use secp256k1::key::{SecretKey, PublicKey};
                use std::time::{UNIX_EPOCH, Duration};
                use secp256k1::Secp256k1;
                use secp256k1::key::{SecretKey, PublicKey};
                use std::time::{UNIX_EPOCH, Duration};
@@ -1737,7 +1746,7 @@ mod test {
                ).unwrap();
                let public_key = PublicKey::from_secret_key(&secp_ctx, &private_key);
 
                ).unwrap();
                let public_key = PublicKey::from_secret_key(&secp_ctx, &private_key);
 
-               let route_1 = vec![
+               let route_1 = RouteHint(vec![
                        RouteHintHop {
                                src_node_id: public_key.clone(),
                                short_channel_id: de::parse_int_be(&[123; 8], 256).expect("short chan ID slice too big?"),
                        RouteHintHop {
                                src_node_id: public_key.clone(),
                                short_channel_id: de::parse_int_be(&[123; 8], 256).expect("short chan ID slice too big?"),
@@ -1760,9 +1769,9 @@ mod test {
                                htlc_minimum_msat: None,
                                htlc_maximum_msat: None,
                        }
                                htlc_minimum_msat: None,
                                htlc_maximum_msat: None,
                        }
-               ];
+               ]);
 
 
-               let route_2 = vec![
+               let route_2 = RouteHint(vec![
                        RouteHintHop {
                                src_node_id: public_key.clone(),
                                short_channel_id: 0,
                        RouteHintHop {
                                src_node_id: public_key.clone(),
                                short_channel_id: 0,
@@ -1785,7 +1794,7 @@ mod test {
                                htlc_minimum_msat: None,
                                htlc_maximum_msat: None,
                        }
                                htlc_minimum_msat: None,
                                htlc_maximum_msat: None,
                        }
-               ];
+               ]);
 
                let builder = InvoiceBuilder::new(Currency::BitcoinTestnet)
                        .amount_pico_btc(123)
 
                let builder = InvoiceBuilder::new(Currency::BitcoinTestnet)
                        .amount_pico_btc(123)
@@ -1794,8 +1803,8 @@ mod test {
                        .expiry_time(Duration::from_secs(54321))
                        .min_final_cltv_expiry(144)
                        .fallback(Fallback::PubKeyHash([0;20]))
                        .expiry_time(Duration::from_secs(54321))
                        .min_final_cltv_expiry(144)
                        .fallback(Fallback::PubKeyHash([0;20]))
-                       .route(route_1.clone())
-                       .route(route_2.clone())
+                       .private_route(route_1.clone())
+                       .private_route(route_2.clone())
                        .description_hash(sha256::Hash::from_slice(&[3;32][..]).unwrap())
                        .payment_hash(sha256::Hash::from_slice(&[21;32][..]).unwrap())
                        .payment_secret(PaymentSecret([42; 32]))
                        .description_hash(sha256::Hash::from_slice(&[3;32][..]).unwrap())
                        .payment_hash(sha256::Hash::from_slice(&[21;32][..]).unwrap())
                        .payment_secret(PaymentSecret([42; 32]))
@@ -1818,7 +1827,7 @@ mod test {
                assert_eq!(invoice.expiry_time(), Duration::from_secs(54321));
                assert_eq!(invoice.min_final_cltv_expiry(), 144);
                assert_eq!(invoice.fallbacks(), vec![&Fallback::PubKeyHash([0;20])]);
                assert_eq!(invoice.expiry_time(), Duration::from_secs(54321));
                assert_eq!(invoice.min_final_cltv_expiry(), 144);
                assert_eq!(invoice.fallbacks(), vec![&Fallback::PubKeyHash([0;20])]);
-               assert_eq!(invoice.routes(), vec![&RouteHint(route_1), &RouteHint(route_2)]);
+               assert_eq!(invoice.private_routes(), vec![&PrivateRoute(route_1), &PrivateRoute(route_2)]);
                assert_eq!(
                        invoice.description(),
                        InvoiceDescription::Hash(&Sha256(sha256::Hash::from_slice(&[3;32][..]).unwrap()))
                assert_eq!(
                        invoice.description(),
                        InvoiceDescription::Hash(&Sha256(sha256::Hash::from_slice(&[3;32][..]).unwrap()))
index 885ea2684f2295665ebd0ac8800b90e08f7685ed..5c7b4aa8978456aebda20014a43c4ea165a2effb 100644 (file)
@@ -3,7 +3,7 @@ use std::fmt::{Display, Formatter};
 use bech32::{ToBase32, u5, WriteBase32, Base32Len};
 
 use super::{Invoice, Sha256, TaggedField, ExpiryTime, MinFinalCltvExpiry, Fallback, PayeePubKey, InvoiceSignature, PositiveTimestamp,
 use bech32::{ToBase32, u5, WriteBase32, Base32Len};
 
 use super::{Invoice, Sha256, TaggedField, ExpiryTime, MinFinalCltvExpiry, Fallback, PayeePubKey, InvoiceSignature, PositiveTimestamp,
-       RouteHint, Description, RawTaggedField, Currency, RawHrp, SiPrefix, constants, SignedRawInvoice, RawDataPart};
+       PrivateRoute, Description, RawTaggedField, Currency, RawHrp, SiPrefix, constants, SignedRawInvoice, RawDataPart};
 
 /// Converts a stream of bytes written to it to base32. On finalization the according padding will
 /// be applied. That means the results of writing two data blocks with one or two `BytesToBase32`
 
 /// Converts a stream of bytes written to it to base32. On finalization the according padding will
 /// be applied. That means the results of writing two data blocks with one or two `BytesToBase32`
@@ -356,11 +356,11 @@ impl Base32Len for Fallback {
        }
 }
 
        }
 }
 
-impl ToBase32 for RouteHint {
+impl ToBase32 for PrivateRoute {
        fn write_base32<W: WriteBase32>(&self, writer: &mut W) -> Result<(), <W as WriteBase32>::Err> {
                let mut converter = BytesToBase32::new(writer);
 
        fn write_base32<W: WriteBase32>(&self, writer: &mut W) -> Result<(), <W as WriteBase32>::Err> {
                let mut converter = BytesToBase32::new(writer);
 
-               for hop in self.iter() {
+               for hop in (self.0).0.iter() {
                        converter.append(&hop.src_node_id.serialize()[..])?;
                        let short_channel_id = try_stretch(
                                encode_int_be_base256(hop.short_channel_id),
                        converter.append(&hop.src_node_id.serialize()[..])?;
                        let short_channel_id = try_stretch(
                                encode_int_be_base256(hop.short_channel_id),
@@ -392,9 +392,9 @@ impl ToBase32 for RouteHint {
        }
 }
 
        }
 }
 
-impl Base32Len for RouteHint {
+impl Base32Len for PrivateRoute {
        fn base32_len(&self) -> usize {
        fn base32_len(&self) -> usize {
-               bytes_size_to_base32_size(self.0.len() * 51)
+               bytes_size_to_base32_size((self.0).0.len() * 51)
        }
 }
 
        }
 }
 
@@ -439,8 +439,8 @@ impl ToBase32 for TaggedField {
                        TaggedField::Fallback(ref fallback_address) => {
                                write_tagged_field(writer, constants::TAG_FALLBACK, fallback_address)
                        },
                        TaggedField::Fallback(ref fallback_address) => {
                                write_tagged_field(writer, constants::TAG_FALLBACK, fallback_address)
                        },
-                       TaggedField::Route(ref route_hops) => {
-                               write_tagged_field(writer, constants::TAG_ROUTE, route_hops)
+                       TaggedField::PrivateRoute(ref route_hops) => {
+                               write_tagged_field(writer, constants::TAG_PRIVATE_ROUTE, route_hops)
                        },
                        TaggedField::PaymentSecret(ref payment_secret) => {
                                  write_tagged_field(writer, constants::TAG_PAYMENT_SECRET, payment_secret)
                        },
                        TaggedField::PaymentSecret(ref payment_secret) => {
                                  write_tagged_field(writer, constants::TAG_PAYMENT_SECRET, payment_secret)
index b75a215ee430341b64ca62bacfe4d278e86c02d0..9d41b928ddc6bc4b610039463ade719714a3844a 100644 (file)
@@ -7,7 +7,7 @@ use lightning::chain::chaininterface::{BroadcasterInterface, FeeEstimator};
 use lightning::chain::keysinterface::{Sign, KeysInterface};
 use lightning::ln::channelmanager::{ChannelManager, MIN_FINAL_CLTV_EXPIRY};
 use lightning::routing::network_graph::RoutingFees;
 use lightning::chain::keysinterface::{Sign, KeysInterface};
 use lightning::ln::channelmanager::{ChannelManager, MIN_FINAL_CLTV_EXPIRY};
 use lightning::routing::network_graph::RoutingFees;
-use lightning::routing::router::RouteHintHop;
+use lightning::routing::router::{RouteHint, RouteHintHop};
 use lightning::util::logger::Logger;
 use std::convert::TryInto;
 use std::ops::Deref;
 use lightning::util::logger::Logger;
 use std::convert::TryInto;
 use std::ops::Deref;
@@ -40,7 +40,7 @@ where
                        Some(info) => info,
                        None => continue,
                };
                        Some(info) => info,
                        None => continue,
                };
-               route_hints.push(vec![RouteHintHop {
+               route_hints.push(RouteHint(vec![RouteHintHop {
                        src_node_id: channel.remote_network_id,
                        short_channel_id,
                        fees: RoutingFees {
                        src_node_id: channel.remote_network_id,
                        short_channel_id,
                        fees: RoutingFees {
@@ -50,7 +50,7 @@ where
                        cltv_expiry_delta: forwarding_info.cltv_expiry_delta,
                        htlc_minimum_msat: None,
                        htlc_maximum_msat: None,
                        cltv_expiry_delta: forwarding_info.cltv_expiry_delta,
                        htlc_minimum_msat: None,
                        htlc_maximum_msat: None,
-               }]);
+               }]));
        }
 
        let (payment_hash, payment_secret) = channelmanager.create_inbound_payment(
        }
 
        let (payment_hash, payment_secret) = channelmanager.create_inbound_payment(
@@ -70,8 +70,8 @@ where
        if let Some(amt) = amt_msat {
                invoice = invoice.amount_pico_btc(amt * 10);
        }
        if let Some(amt) = amt_msat {
                invoice = invoice.amount_pico_btc(amt * 10);
        }
-       for hint in route_hints.drain(..) {
-               invoice = invoice.route(hint);
+       for hint in route_hints {
+               invoice = invoice.private_route(hint);
        }
 
        let raw_invoice = match invoice.build_raw() {
        }
 
        let raw_invoice = match invoice.build_raw() {
@@ -112,14 +112,9 @@ mod test {
                assert_eq!(invoice.min_final_cltv_expiry(), MIN_FINAL_CLTV_EXPIRY as u64);
                assert_eq!(invoice.description(), InvoiceDescription::Direct(&Description("test".to_string())));
 
                assert_eq!(invoice.min_final_cltv_expiry(), MIN_FINAL_CLTV_EXPIRY as u64);
                assert_eq!(invoice.description(), InvoiceDescription::Direct(&Description("test".to_string())));
 
-               let mut route_hints = invoice.routes().clone();
-               let mut last_hops = Vec::new();
-               for hint in route_hints.drain(..) {
-                       last_hops.push(hint[hint.len() - 1].clone());
-               }
                let amt_msat = invoice.amount_pico_btc().unwrap() / 10;
                let amt_msat = invoice.amount_pico_btc().unwrap() / 10;
-
                let first_hops = nodes[0].node.list_usable_channels();
                let first_hops = nodes[0].node.list_usable_channels();
+               let last_hops = invoice.route_hints();
                let network_graph = nodes[0].net_graph_msg_handler.network_graph.read().unwrap();
                let logger = test_utils::TestLogger::new();
                let route = router::get_route(
                let network_graph = nodes[0].net_graph_msg_handler.network_graph.read().unwrap();
                let logger = test_utils::TestLogger::new();
                let route = router::get_route(
@@ -128,7 +123,7 @@ mod test {
                        &invoice.recover_payee_pub_key(),
                        Some(invoice.features().unwrap().clone()),
                        Some(&first_hops.iter().collect::<Vec<_>>()),
                        &invoice.recover_payee_pub_key(),
                        Some(invoice.features().unwrap().clone()),
                        Some(&first_hops.iter().collect::<Vec<_>>()),
-                       &last_hops.iter().collect::<Vec<_>>(),
+                       &last_hops,
                        amt_msat,
                        invoice.min_final_cltv_expiry() as u32,
                        &logger,
                        amt_msat,
                        invoice.min_final_cltv_expiry() as u32,
                        &logger,
index 52ac9b77578c0586a6586c58ffff17f4b9e80e98..9f4b8ac3589a27ef7c713751eb1d98ad18596978 100644 (file)
@@ -107,7 +107,11 @@ impl Readable for Route {
        }
 }
 
        }
 }
 
-/// A channel descriptor which provides a last-hop route to get_route
+/// A list of hops along a payment path terminating with a channel to the recipient.
+#[derive(Eq, PartialEq, Debug, Clone)]
+pub struct RouteHint(pub Vec<RouteHintHop>);
+
+/// A channel descriptor for a hop along a payment path.
 #[derive(Eq, PartialEq, Debug, Clone)]
 pub struct RouteHintHop {
        /// The node_id of the non-target end of the route
 #[derive(Eq, PartialEq, Debug, Clone)]
 pub struct RouteHintHop {
        /// The node_id of the non-target end of the route
@@ -329,8 +333,8 @@ fn compute_fees(amount_msat: u64, channel_fees: RoutingFees) -> Option<u64> {
 /// If the payee provided features in their invoice, they should be provided via payee_features.
 /// Without this, MPP will only be used if the payee's features are available in the network graph.
 ///
 /// If the payee provided features in their invoice, they should be provided via payee_features.
 /// Without this, MPP will only be used if the payee's features are available in the network graph.
 ///
-/// Extra routing hops between known nodes and the target will be used if they are included in
-/// last_hops.
+/// Private routing paths between a public node and the target may be included in `last_hops`.
+/// Currently, only the last hop in each path is considered.
 ///
 /// If some channels aren't announced, it may be useful to fill in a first_hops with the
 /// results from a local ChannelManager::list_usable_channels() call. If it is filled in, our
 ///
 /// If some channels aren't announced, it may be useful to fill in a first_hops with the
 /// results from a local ChannelManager::list_usable_channels() call. If it is filled in, our
@@ -344,7 +348,7 @@ fn compute_fees(amount_msat: u64, channel_fees: RoutingFees) -> Option<u64> {
 /// equal), however the enabled/disabled bit on such channels as well as the
 /// htlc_minimum_msat/htlc_maximum_msat *are* checked as they may change based on the receiving node.
 pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, payee: &PublicKey, payee_features: Option<InvoiceFeatures>, first_hops: Option<&[&ChannelDetails]>,
 /// equal), however the enabled/disabled bit on such channels as well as the
 /// htlc_minimum_msat/htlc_maximum_msat *are* checked as they may change based on the receiving node.
 pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, payee: &PublicKey, payee_features: Option<InvoiceFeatures>, first_hops: Option<&[&ChannelDetails]>,
-       last_hops: &[&RouteHintHop], final_value_msat: u64, final_cltv: u32, logger: L) -> Result<Route, LightningError> where L::Target: Logger {
+       last_hops: &[&RouteHint], final_value_msat: u64, final_cltv: u32, logger: L) -> Result<Route, LightningError> where L::Target: Logger {
        // TODO: Obviously *only* using total fee cost sucks. We should consider weighting by
        // uptime/success in using a node in the past.
        if *payee == *our_node_id {
        // TODO: Obviously *only* using total fee cost sucks. We should consider weighting by
        // uptime/success in using a node in the past.
        if *payee == *our_node_id {
@@ -359,7 +363,8 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
                return Err(LightningError{err: "Cannot send a payment of 0 msat".to_owned(), action: ErrorAction::IgnoreError});
        }
 
                return Err(LightningError{err: "Cannot send a payment of 0 msat".to_owned(), action: ErrorAction::IgnoreError});
        }
 
-       for last_hop in last_hops {
+       let last_hops = last_hops.iter().filter_map(|hops| hops.0.last()).collect::<Vec<_>>();
+       for last_hop in last_hops.iter() {
                if last_hop.src_node_id == *payee {
                        return Err(LightningError{err: "Last hop cannot have a payee as a source.".to_owned(), action: ErrorAction::IgnoreError});
                }
                if last_hop.src_node_id == *payee {
                        return Err(LightningError{err: "Last hop cannot have a payee as a source.".to_owned(), action: ErrorAction::IgnoreError});
                }
@@ -1154,7 +1159,7 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
 
 #[cfg(test)]
 mod tests {
 
 #[cfg(test)]
 mod tests {
-       use routing::router::{get_route, RouteHintHop, RoutingFees};
+       use routing::router::{get_route, RouteHint, RouteHintHop, RoutingFees};
        use routing::network_graph::{NetworkGraph, NetGraphMsgHandler};
        use chain::transaction::OutPoint;
        use ln::features::{ChannelFeatures, InitFeatures, InvoiceFeatures, NodeFeatures};
        use routing::network_graph::{NetworkGraph, NetGraphMsgHandler};
        use chain::transaction::OutPoint;
        use ln::features::{ChannelFeatures, InitFeatures, InvoiceFeatures, NodeFeatures};
@@ -2085,19 +2090,19 @@ mod tests {
                assert_eq!(route.paths[0][1].channel_features.le_flags(), &id_to_feature_flags(13));
        }
 
                assert_eq!(route.paths[0][1].channel_features.le_flags(), &id_to_feature_flags(13));
        }
 
-       fn last_hops(nodes: &Vec<PublicKey>) -> Vec<RouteHintHop> {
+       fn last_hops(nodes: &Vec<PublicKey>) -> Vec<RouteHint> {
                let zero_fees = RoutingFees {
                        base_msat: 0,
                        proportional_millionths: 0,
                };
                let zero_fees = RoutingFees {
                        base_msat: 0,
                        proportional_millionths: 0,
                };
-               vec!(RouteHintHop {
+               vec![RouteHint(vec![RouteHintHop {
                        src_node_id: nodes[3].clone(),
                        short_channel_id: 8,
                        fees: zero_fees,
                        cltv_expiry_delta: (8 << 8) | 1,
                        htlc_minimum_msat: None,
                        htlc_maximum_msat: None,
                        src_node_id: nodes[3].clone(),
                        short_channel_id: 8,
                        fees: zero_fees,
                        cltv_expiry_delta: (8 << 8) | 1,
                        htlc_minimum_msat: None,
                        htlc_maximum_msat: None,
-               }RouteHintHop {
+               }]), RouteHint(vec![RouteHintHop {
                        src_node_id: nodes[4].clone(),
                        short_channel_id: 9,
                        fees: RoutingFees {
                        src_node_id: nodes[4].clone(),
                        short_channel_id: 9,
                        fees: RoutingFees {
@@ -2107,14 +2112,14 @@ mod tests {
                        cltv_expiry_delta: (9 << 8) | 1,
                        htlc_minimum_msat: None,
                        htlc_maximum_msat: None,
                        cltv_expiry_delta: (9 << 8) | 1,
                        htlc_minimum_msat: None,
                        htlc_maximum_msat: None,
-               }RouteHintHop {
+               }]), RouteHint(vec![RouteHintHop {
                        src_node_id: nodes[5].clone(),
                        short_channel_id: 10,
                        fees: zero_fees,
                        cltv_expiry_delta: (10 << 8) | 1,
                        htlc_minimum_msat: None,
                        htlc_maximum_msat: None,
                        src_node_id: nodes[5].clone(),
                        short_channel_id: 10,
                        fees: zero_fees,
                        cltv_expiry_delta: (10 << 8) | 1,
                        htlc_minimum_msat: None,
                        htlc_maximum_msat: None,
-               })
+               }])]
        }
 
        #[test]
        }
 
        #[test]
@@ -2124,8 +2129,8 @@ mod tests {
 
                // Simple test across 2, 3, 5, and 4 via a last_hop channel
 
 
                // Simple test across 2, 3, 5, and 4 via a last_hop channel
 
-               // First check that lst hop can't have its source as the payee.
-               let invalid_last_hop = RouteHintHop {
+               // First check that last hop can't have its source as the payee.
+               let invalid_last_hop = RouteHint(vec![RouteHintHop {
                        src_node_id: nodes[6],
                        short_channel_id: 8,
                        fees: RoutingFees {
                        src_node_id: nodes[6],
                        short_channel_id: 8,
                        fees: RoutingFees {
@@ -2135,7 +2140,7 @@ mod tests {
                        cltv_expiry_delta: (8 << 8) | 1,
                        htlc_minimum_msat: None,
                        htlc_maximum_msat: None,
                        cltv_expiry_delta: (8 << 8) | 1,
                        htlc_minimum_msat: None,
                        htlc_maximum_msat: None,
-               };
+               }]);
 
                let mut invalid_last_hops = last_hops(&nodes);
                invalid_last_hops.push(invalid_last_hop);
 
                let mut invalid_last_hops = last_hops(&nodes);
                invalid_last_hops.push(invalid_last_hop);
@@ -2224,7 +2229,7 @@ mod tests {
                assert_eq!(route.paths[0][1].node_features.le_flags(), &Vec::<u8>::new()); // We dont pass flags in from invoices yet
                assert_eq!(route.paths[0][1].channel_features.le_flags(), &Vec::<u8>::new()); // We can't learn any flags from invoices, sadly
 
                assert_eq!(route.paths[0][1].node_features.le_flags(), &Vec::<u8>::new()); // We dont pass flags in from invoices yet
                assert_eq!(route.paths[0][1].channel_features.le_flags(), &Vec::<u8>::new()); // We can't learn any flags from invoices, sadly
 
-               last_hops[0].fees.base_msat = 1000;
+               last_hops[0].0[0].fees.base_msat = 1000;
 
                // Revert to via 6 as the fee on 8 goes up
                let route = get_route(&our_id, &net_graph_msg_handler.network_graph.read().unwrap(), &nodes[6], None, None, &last_hops.iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger)).unwrap();
 
                // Revert to via 6 as the fee on 8 goes up
                let route = get_route(&our_id, &net_graph_msg_handler.network_graph.read().unwrap(), &nodes[6], None, None, &last_hops.iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger)).unwrap();
@@ -2312,7 +2317,7 @@ mod tests {
                let target_node_id = PublicKey::from_secret_key(&Secp256k1::new(), &SecretKey::from_slice(&hex::decode(format!("{:02}", 43).repeat(32)).unwrap()[..]).unwrap());
 
                // If we specify a channel to a middle hop, that overrides our local channel view and that gets used
                let target_node_id = PublicKey::from_secret_key(&Secp256k1::new(), &SecretKey::from_slice(&hex::decode(format!("{:02}", 43).repeat(32)).unwrap()[..]).unwrap());
 
                // If we specify a channel to a middle hop, that overrides our local channel view and that gets used
-               let last_hops = vec![RouteHintHop {
+               let last_hops = RouteHint(vec![RouteHintHop {
                        src_node_id: middle_node_id,
                        short_channel_id: 8,
                        fees: RoutingFees {
                        src_node_id: middle_node_id,
                        short_channel_id: 8,
                        fees: RoutingFees {
@@ -2322,7 +2327,7 @@ mod tests {
                        cltv_expiry_delta: (8 << 8) | 1,
                        htlc_minimum_msat: None,
                        htlc_maximum_msat: None,
                        cltv_expiry_delta: (8 << 8) | 1,
                        htlc_minimum_msat: None,
                        htlc_maximum_msat: None,
-               }];
+               }]);
                let our_chans = vec![channelmanager::ChannelDetails {
                        channel_id: [0; 32],
                        funding_txo: Some(OutPoint { txid: bitcoin::Txid::from_slice(&[0; 32]).unwrap(), index: 0 }),
                let our_chans = vec![channelmanager::ChannelDetails {
                        channel_id: [0; 32],
                        funding_txo: Some(OutPoint { txid: bitcoin::Txid::from_slice(&[0; 32]).unwrap(), index: 0 }),
@@ -2337,7 +2342,7 @@ mod tests {
                        is_usable: true, is_public: true,
                        counterparty_forwarding_info: None,
                }];
                        is_usable: true, is_public: true,
                        counterparty_forwarding_info: None,
                }];
-               let route = get_route(&source_node_id, &NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash()), &target_node_id, None, Some(&our_chans.iter().collect::<Vec<_>>()), &last_hops.iter().collect::<Vec<_>>(), 100, 42, Arc::new(test_utils::TestLogger::new())).unwrap();
+               let route = get_route(&source_node_id, &NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash()), &target_node_id, None, Some(&our_chans.iter().collect::<Vec<_>>()), &vec![&last_hops], 100, 42, Arc::new(test_utils::TestLogger::new())).unwrap();
 
                assert_eq!(route.paths[0].len(), 2);
 
 
                assert_eq!(route.paths[0].len(), 2);