Test the `RouteParameters` passed to `TestRouter`
[rust-lightning] / lightning / src / util / test_utils.rs
index 1dae61ab3ef5b1841241b069612f71c7e5cc8f38..f9a2eafb82d51739011dc1e80d922655a0f329bb 100644 (file)
@@ -77,7 +77,7 @@ impl chaininterface::FeeEstimator for TestFeeEstimator {
 
 pub struct TestRouter<'a> {
        pub network_graph: Arc<NetworkGraph<&'a TestLogger>>,
-       pub next_routes: Mutex<VecDeque<Result<Route, LightningError>>>,
+       pub next_routes: Mutex<VecDeque<(RouteParameters, Result<Route, LightningError>)>>,
 }
 
 impl<'a> TestRouter<'a> {
@@ -85,9 +85,9 @@ impl<'a> TestRouter<'a> {
                Self { network_graph, next_routes: Mutex::new(VecDeque::new()), }
        }
 
-       pub fn expect_find_route(&self, result: Result<Route, LightningError>) {
+       pub fn expect_find_route(&self, query: RouteParameters, result: Result<Route, LightningError>) {
                let mut expected_routes = self.next_routes.lock().unwrap();
-               expected_routes.push_back(result);
+               expected_routes.push_back((query, result));
        }
 }
 
@@ -96,8 +96,9 @@ impl<'a> Router for TestRouter<'a> {
                &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&channelmanager::ChannelDetails]>,
                inflight_htlcs: &InFlightHtlcs
        ) -> Result<Route, msgs::LightningError> {
-               if let Some(find_route_res) = self.next_routes.lock().unwrap().pop_front() {
-                       return find_route_res
+               if let Some((find_route_query, find_route_res)) = self.next_routes.lock().unwrap().pop_front() {
+                       assert_eq!(find_route_query, *params);
+                       return find_route_res;
                }
                let logger = TestLogger::new();
                find_route(