Expand router fuzzer to test returned routes match the graph
authorMatt Corallo <git@bluematt.me>
Thu, 25 Feb 2021 16:47:25 +0000 (11:47 -0500)
committerMatt Corallo <git@bluematt.me>
Thu, 25 Feb 2021 23:22:35 +0000 (18:22 -0500)
fuzz/src/router.rs

index 4d06d9e3e2ef926fa6093a18f931ab236606288f..1ec5294b5f2b41fbf87934ada735e0a8cf03bb8f 100644 (file)
@@ -170,6 +170,8 @@ pub fn do_test<Out: test_logger::Output>(data: &[u8], out: Out) {
        let mut node_pks: HashSet<_, NonRandomHash> = HashSet::default();
        let mut scid = 42;
 
+       let mut channel_limits = HashMap::new();
+
        loop {
                match get_slice!(1)[0] {
                        0 => {
@@ -195,11 +197,16 @@ pub fn do_test<Out: test_logger::Output>(data: &[u8], out: Out) {
                                let _ = net_graph.update_channel_from_unsigned_announcement(&msg, &Some(&FuzzChainSource { input: Arc::clone(&input) }));
                        },
                        3 => {
-                               let _ = net_graph.update_channel_unsigned(&decode_msg!(msgs::UnsignedChannelUpdate, 72));
+                               let msg = decode_msg!(msgs::UnsignedChannelUpdate, 72);
+                               if net_graph.update_channel_unsigned(&msg).is_ok() {
+                                       channel_limits.insert((msg.short_channel_id, if msg.flags & 1 == 1 { true } else { false }), msg);
+                               }
                        },
                        4 => {
                                let short_channel_id = slice_to_be64(get_slice!(8));
                                net_graph.close_channel_from_update(short_channel_id, false);
+                               channel_limits.remove(&(short_channel_id, true));
+                               channel_limits.remove(&(short_channel_id, false));
                        },
                        _ if node_pks.is_empty() => {},
                        _ => {
@@ -209,7 +216,7 @@ pub fn do_test<Out: test_logger::Output>(data: &[u8], out: Out) {
                                        count => {
                                                for _ in 0..count {
                                                        scid += 1;
-                                                       let rnid = node_pks.iter().skip(slice_to_be16(get_slice!(2))as usize % node_pks.len()).next().unwrap();
+                                                       let rnid = node_pks.iter().skip(slice_to_be16(get_slice!(2)) as usize % node_pks.len()).next().unwrap();
                                                        first_hops_vec.push(ChannelDetails {
                                                                channel_id: [0; 32],
                                                                short_channel_id: Some(scid),
@@ -230,7 +237,7 @@ pub fn do_test<Out: test_logger::Output>(data: &[u8], out: Out) {
                                        let count = get_slice!(1)[0];
                                        for _ in 0..count {
                                                scid += 1;
-                                               let rnid = node_pks.iter().skip(slice_to_be16(get_slice!(2))as usize % node_pks.len()).next().unwrap();
+                                               let rnid = node_pks.iter().skip(slice_to_be16(get_slice!(2)) as usize % node_pks.len()).next().unwrap();
                                                last_hops_vec.push(RouteHint {
                                                        src_node_id: *rnid,
                                                        short_channel_id: scid,
@@ -246,10 +253,115 @@ pub fn do_test<Out: test_logger::Output>(data: &[u8], out: Out) {
                                }
                                let last_hops = &last_hops_vec[..];
                                for target in node_pks.iter() {
-                                       let _ = get_route(&our_pubkey, &net_graph, target,
-                                               first_hops.map(|c| c.iter().collect::<Vec<_>>()).as_ref().map(|a| a.as_slice()),
-                                               &last_hops.iter().collect::<Vec<_>>(),
-                                               slice_to_be64(get_slice!(8)), slice_to_be32(get_slice!(4)), Arc::clone(&logger));
+                                       let value_msat = slice_to_be64(get_slice!(8));
+                                       let cltv = slice_to_be32(get_slice!(4));
+                                       if let Ok(route) = get_route(&our_pubkey, &net_graph, target,
+                                                       first_hops.map(|c| c.iter().collect::<Vec<_>>()).as_ref().map(|a| a.as_slice()),
+                                                       &last_hops.iter().collect::<Vec<_>>(),
+                                                       value_msat, cltv, Arc::clone(&logger)) {
+                                               let mut sent_msat = 0;
+                                               let mut a_last_hop_hit_min = false;
+                                               for (idxp, path) in route.paths.iter().enumerate() {
+                                                       macro_rules! lookup_fee_info {
+                                                               ($short_channel_id: expr, $check_first_hops: expr, $check_last_hops: expr) => { {
+                                                                       'find_chan_loop: loop {
+                                                                               if $check_first_hops {
+                                                                                       if let Some(hops) = first_hops {
+                                                                                               for first_hop in hops {
+                                                                                                       if first_hop.short_channel_id == Some($short_channel_id) {
+                                                                                                               break 'find_chan_loop Some((None, Some(first_hop.outbound_capacity_msat),
+                                                                                                                       0, RoutingFees { base_msat: 0, proportional_millionths: 0 }));
+                                                                                                       }
+                                                                                               }
+                                                                                       }
+                                                                               }
+                                                                               if $check_last_hops {
+                                                                                       for last_hop in last_hops {
+                                                                                               if last_hop.short_channel_id == $short_channel_id {
+                                                                                                       break 'find_chan_loop Some((last_hop.htlc_minimum_msat,
+                                                                                                               last_hop.htlc_maximum_msat, last_hop.cltv_expiry_delta,
+                                                                                                               last_hop.fees));
+                                                                                               }
+                                                                                       }
+                                                                               }
+                                                                               // We don't know by looking at a route whether the inbound or outbound
+                                                                               // direction is in use, so we only test if we only have one filled in.
+                                                                               let upd_a = channel_limits.get(&($short_channel_id, false));
+                                                                               let upd_b = channel_limits.get(&($short_channel_id, true));
+                                                                               if upd_a.is_some() && upd_b.is_some() { break 'find_chan_loop None; }
+                                                                               let upd = if let Some(u) = upd_a { u } else if let Some(u) = upd_b { u } else { panic!(); };
+                                                                               break 'find_chan_loop Some((Some(upd.htlc_minimum_msat),
+                                                                                       match upd.htlc_maximum_msat {
+                                                                                               msgs::OptionalField::Absent => None,
+                                                                                               msgs::OptionalField::Present(v) => Some(v),
+                                                                                       }, upd.cltv_expiry_delta,
+                                                                                       RoutingFees { base_msat: upd.fee_base_msat,
+                                                                                               proportional_millionths: upd.fee_proportional_millionths }));
+                                                                       }
+                                                               } }
+                                                       }
+
+                                                       sent_msat += path.last().unwrap().fee_msat;
+                                                       assert_eq!(path.last().unwrap().cltv_expiry_delta, cltv);
+
+                                                       let mut path_total_msat = path.last().unwrap().fee_msat;
+                                                       let mut hop_hit_min = false;
+                                                       let mut hop_overpaid_fees = false;
+
+                                                       for (idx, first_prev_hop) in path.windows(2).enumerate().rev() {
+                                                               let (prev_hop, hop) = (&first_prev_hop[0], &first_prev_hop[1]);
+                                                               if let Some((min, max, expiry, fees)) = lookup_fee_info!(hop.short_channel_id, idx == 0, idx == path.len() - 2) {
+
+                                                                       if let Some(v) = max {
+                                                                               assert!(path_total_msat <= v);
+                                                                       }
+                                                                       if let Some(v) = min {
+                                                                               assert!(path_total_msat >= v);
+                                                                               if path_total_msat == v {
+                                                                                       if idx == 0 { a_last_hop_hit_min = true; }
+                                                                                       hop_hit_min = true;
+                                                                               }
+                                                                       }
+                                                                       let expected_fees = fees.base_msat as u64 + fees.proportional_millionths as u64 * path_total_msat / 1_000_000;
+                                                                       assert!(prev_hop.fee_msat >= expected_fees);
+                                                                       if prev_hop.fee_msat > expected_fees { hop_overpaid_fees = true; }
+                                                                       assert!(prev_hop.fee_msat >= fees.base_msat as u64);
+                                                                       path_total_msat += prev_hop.fee_msat;
+                                                                       assert_eq!(prev_hop.cltv_expiry_delta, expiry as u32);
+                                                               } else {
+                                                                       // Assume the hop hit the minimum, as we failed to find the
+                                                                       // exact fee information
+                                                                       if idx == 0 { a_last_hop_hit_min = true; }
+                                                                       hop_hit_min = true;
+                                                               }
+                                                       }
+
+                                                       // Finally, handle the first hop - we don't pay fees on it, but we
+                                                       // still need to check the min/max and apply the minimum checks.
+                                                       if let Some((min, max, _, _)) = lookup_fee_info!(path.first().unwrap().short_channel_id, true, path.len() == 1) {
+                                                               if let Some(v) = max {
+                                                                       assert!(path_total_msat <= v);
+                                                               }
+                                                               if let Some(v) = min {
+                                                                       assert!(path_total_msat >= v);
+                                                                       if path_total_msat == v {
+                                                                               if path.len() == 1 { a_last_hop_hit_min = true; }
+                                                                               hop_hit_min = true;
+                                                                       }
+                                                               }
+                                                       } else {
+                                                               if path.len() == 1 { a_last_hop_hit_min = true; }
+                                                       }
+
+                                                       // If we overpaid on fees, it has to be because at least
+                                                       // one hop only *just* paid the htlc_minimum_msat value.
+                                                       assert!(hop_hit_min || !hop_overpaid_fees);
+                                               }
+                                               assert!(sent_msat >= value_msat);
+                                               if !a_last_hop_hit_min {
+                                                       assert_eq!(sent_msat, value_msat);
+                                               }
+                                       }
                                }
                        },
                }