[Java] Allow the user to hook route-finding when using CMC
authorMatt Corallo <git@bluematt.me>
Thu, 9 Mar 2023 18:48:32 +0000 (18:48 +0000)
committerMatt Corallo <git@bluematt.me>
Thu, 9 Mar 2023 19:12:27 +0000 (19:12 +0000)
src/main/java/org/ldk/batteries/ChannelManagerConstructor.java
src/test/java/org/ldk/HumanObjectPeerTest.java

index db37118dfd71cb10f43dda02d21789d6facad658..920eb0523a537d6fd5bf0750588e9ea679c14860 100644 (file)
@@ -102,17 +102,40 @@ public class ChannelManagerConstructor {
         return new ScorerWrapper(this.scorer.as_LockableScore().lock(), this.prob_scorer);
     }
 
+    /**
+     * A simple interface to provide routes to LDK.
+     */
+    public interface RouterWrapper {
+        /**
+         * Gets a route for the given payment.
+         *
+         * @param payment_hash is non-null for this-node-originated payments, however in the future trampoline or other
+         *                     HTLC re-routing may cause it to be null as we find routes for payments which we did not
+         *                     originate.
+         * @param payment_id is non-null for this-node-originated payments, however in the future trampoline or other
+         *                   HTLC re-routing may cause it to be null as we find routes for payments which we did not
+         *                   originate.
+         * @param default_router Provides a router which uses the LDK route-finder and a ProbabilisticScorer using the
+         *                       provided ProbabilisticScoringParameters. You may use this to fetch a "default" route,
+         *                       modifying or storing it as you wish before returning the route to LDK.
+         */
+        Result_RouteLightningErrorZ find_route(byte[] payer_node_id, RouteParameters route_params, ChannelDetails[] first_hops,
+            InFlightHtlcs inflight_htlcs, @Nullable byte[] payment_hash, @Nullable byte[] payment_id, DefaultRouter default_router);
+    }
+
     /**
      * Deserializes a channel manager and a set of channel monitors from the given serialized copies and interface implementations
      *
      * @param filter If provided, the outputs which were previously registered to be monitored for will be loaded into the filter.
      *               Note that if the provided Watch is a ChainWatch and has an associated filter, the previously registered
      *               outputs will be loaded when chain_sync_completed is called.
+     * @param router_wrapper If provided, routes will be fetched by calling the given router rather than an LDK `DefaultRouter`.
      */
     public ChannelManagerConstructor(byte[] channel_manager_serialized, byte[][] channel_monitors_serialized, UserConfig config,
                                      KeysManager keys_manager, FeeEstimator fee_estimator, ChainMonitor chain_monitor,
                                      @Nullable Filter filter, byte[] net_graph_serialized,
                                      ProbabilisticScoringParameters scoring_params, byte[] probabilistic_scorer_bytes,
+                                     @Nullable RouterWrapper router_wrapper,
                                      BroadcasterInterface tx_broadcaster, Logger logger) throws InvalidSerializedDataException {
         this.keys_manager = keys_manager;
         EntropySource entropy_source = keys_manager.as_EntropySource();
@@ -130,7 +153,21 @@ public class ChannelManagerConstructor {
         }
         this.prob_scorer = ((Result_ProbabilisticScorerDecodeErrorZ.Result_ProbabilisticScorerDecodeErrorZ_OK)scorer_res).res;
         this.scorer = MultiThreadedLockableScore.of(this.prob_scorer.as_Score());
-        DefaultRouter router = DefaultRouter.of(this.net_graph, logger, entropy_source.get_secure_random_bytes(), scorer.as_LockableScore());
+
+        DefaultRouter default_router = DefaultRouter.of(this.net_graph, logger, entropy_source.get_secure_random_bytes(), scorer.as_LockableScore());
+        Router router;
+        if (router_wrapper != null) {
+            router = Router.new_impl(new Router.RouterInterface() {
+                @Override public Result_RouteLightningErrorZ find_route(byte[] payer, RouteParameters route_params, ChannelDetails[] first_hops, InFlightHtlcs inflight_htlcs) {
+                    return router_wrapper.find_route(payer, route_params, first_hops, inflight_htlcs, null, null, default_router);
+                }
+                @Override public Result_RouteLightningErrorZ find_route_with_id(byte[] payer, RouteParameters route_params, ChannelDetails[] first_hops, InFlightHtlcs inflight_htlcs, byte[] payment_hash, byte[] payment_id) {
+                    return router_wrapper.find_route(payer, route_params, first_hops, inflight_htlcs, payment_hash, payment_id, default_router);
+                }
+            });
+        } else {
+            router = default_router.as_Router();
+        }
 
         final ChannelMonitor[] monitors = new ChannelMonitor[channel_monitors_serialized.length];
         this.channel_monitors = new TwoTuple_BlockHashChannelMonitorZ[monitors.length];
@@ -149,7 +186,7 @@ public class ChannelManagerConstructor {
         Result_C2Tuple_BlockHashChannelManagerZDecodeErrorZ res =
                 UtilMethods.C2Tuple_BlockHashChannelManagerZ_read(channel_manager_serialized, keys_manager.as_EntropySource(),
                         keys_manager.as_NodeSigner(), keys_manager.as_SignerProvider(), fee_estimator, chain_monitor.as_Watch(),
-                        tx_broadcaster, router.as_Router(), logger, config, monitors);
+                        tx_broadcaster, router, logger, config, monitors);
         if (!res.is_ok()) {
             throw new InvalidSerializedDataException("Serialized ChannelManager was corrupt");
         }
@@ -166,10 +203,13 @@ public class ChannelManagerConstructor {
 
     /**
      * Constructs a channel manager from the given interface implementations
+     *
+     * @param router_wrapper If provided, routes will be fetched by calling the given router rather than an LDK `DefaultRouter`.
      */
     public ChannelManagerConstructor(Network network, UserConfig config, byte[] current_blockchain_tip_hash, int current_blockchain_tip_height,
                                      KeysManager keys_manager, FeeEstimator fee_estimator, ChainMonitor chain_monitor,
                                      NetworkGraph net_graph, ProbabilisticScoringParameters scoring_params,
+                                     @Nullable RouterWrapper router_wrapper,
                                      BroadcasterInterface tx_broadcaster, Logger logger) {
         this.keys_manager = keys_manager;
         EntropySource entropy_source = keys_manager.as_EntropySource();
@@ -178,14 +218,27 @@ public class ChannelManagerConstructor {
         assert(scoring_params != null);
         this.prob_scorer = ProbabilisticScorer.of(scoring_params, net_graph, logger);
         this.scorer = MultiThreadedLockableScore.of(this.prob_scorer.as_Score());
-        DefaultRouter router = DefaultRouter.of(this.net_graph, logger, entropy_source.get_secure_random_bytes(), scorer.as_LockableScore());
 
+        DefaultRouter default_router = DefaultRouter.of(this.net_graph, logger, entropy_source.get_secure_random_bytes(), scorer.as_LockableScore());
+        Router router;
+        if (router_wrapper != null) {
+            router = Router.new_impl(new Router.RouterInterface() {
+                @Override public Result_RouteLightningErrorZ find_route(byte[] payer, RouteParameters route_params, ChannelDetails[] first_hops, InFlightHtlcs inflight_htlcs) {
+                    return router_wrapper.find_route(payer, route_params, first_hops, inflight_htlcs, null, null, default_router);
+                }
+                @Override public Result_RouteLightningErrorZ find_route_with_id(byte[] payer, RouteParameters route_params, ChannelDetails[] first_hops, InFlightHtlcs inflight_htlcs, byte[] payment_hash, byte[] payment_id) {
+                    return router_wrapper.find_route(payer, route_params, first_hops, inflight_htlcs, payment_hash, payment_id, default_router);
+                }
+            });
+        } else {
+            router = default_router.as_Router();
+        }
         channel_monitors = new TwoTuple_BlockHashChannelMonitorZ[0];
         channel_manager_latest_block_hash = null;
         this.chain_monitor = chain_monitor;
         BestBlock block = BestBlock.of(current_blockchain_tip_hash, current_blockchain_tip_height);
         ChainParameters params = ChainParameters.of(network, block);
-        channel_manager = ChannelManager.of(fee_estimator, chain_monitor.as_Watch(), tx_broadcaster, router.as_Router(), logger,
+        channel_manager = ChannelManager.of(fee_estimator, chain_monitor.as_Watch(), tx_broadcaster, router, logger,
             keys_manager.as_EntropySource(), keys_manager.as_NodeSigner(), keys_manager.as_SignerProvider(), config, params);
         this.logger = logger;
     }
index 0fa0c35fe5c36fd3f76dadea7ad50868ab52b89c..a0ab3be3bc6e27b49e90e16b415b52ed2b800b54 100644 (file)
@@ -376,7 +376,14 @@ class HumanObjectPeerTestInstance {
             if (use_chan_manager_constructor) {
                 this.constructor = new ChannelManagerConstructor(Network.LDKNetwork_Bitcoin, get_config(), new byte[32], 0,
                         this.explicit_keys_manager, this.fee_estimator, this.chain_monitor, this.net_graph,
-                        ProbabilisticScoringParameters.with_default(), this.tx_broadcaster, this.logger);
+                        ProbabilisticScoringParameters.with_default(), (ChannelManagerConstructor.RouterWrapper)
+                            (payer_node_id, route_params, first_hops, inflight_htlcs, payment_hash, payment_id, default_router) -> {
+                                assert payment_hash != null && payment_id != null;
+                                Router r = default_router.as_Router();
+                                must_free_objs.add(new WeakReference<>(r));
+                                return r.find_route_with_id(payer_node_id, route_params, first_hops, inflight_htlcs, payment_hash, payment_id);
+                            },
+                        this.tx_broadcaster, this.logger);
                 constructor.chain_sync_completed(new ChannelManagerConstructor.EventHandler() {
                     @Override public void handle_event(Event event) {
                         synchronized (pending_manager_events) {
@@ -456,7 +463,7 @@ class HumanObjectPeerTestInstance {
                     }
                     this.constructor = new ChannelManagerConstructor(serialized, monitors, get_config(),
                             this.explicit_keys_manager, this.fee_estimator, this.chain_monitor, filter_nullable,
-                            serialized_graph, ProbabilisticScoringParameters.with_default(), serialized_scorer,
+                            serialized_graph, ProbabilisticScoringParameters.with_default(), serialized_scorer, null,
                             this.tx_broadcaster, this.logger);
                     try {
                         // Test that ChannelManagerConstructor correctly rejects duplicate ChannelMonitors
@@ -465,7 +472,7 @@ class HumanObjectPeerTestInstance {
                         monitors_dupd[1] = monitors[0];
                         ChannelManagerConstructor constr = new ChannelManagerConstructor(serialized, monitors_dupd, get_config(),
                                 this.explicit_keys_manager, this.fee_estimator, this.chain_monitor, filter_nullable,
-                                serialized_graph, ProbabilisticScoringParameters.with_default(), serialized_scorer,
+                                serialized_graph, ProbabilisticScoringParameters.with_default(), serialized_scorer, null,
                                 this.tx_broadcaster, this.logger);
                         assert false;
                     } catch (ChannelManagerConstructor.InvalidSerializedDataException e) {}