]> git.bitcoin.ninja Git - ldk-java/commitdiff
[Java] Expose the ProbabilisticScorer from the CMC correctly
authorMatt Corallo <git@bluematt.me>
Wed, 8 Mar 2023 05:04:14 +0000 (05:04 +0000)
committerMatt Corallo <git@bluematt.me>
Wed, 8 Mar 2023 05:13:01 +0000 (05:13 +0000)
We'd previously exposed a race-y version of the ProbabilisticScorer
where its held within a lock but also made public with no threading
requirements placed on it.

This resolves that issue by exposing it via a wrapper that holds
the score lock.

src/main/java/org/ldk/batteries/ChannelManagerConstructor.java
src/test/java/org/ldk/HumanObjectPeerTest.java

index 029fac792e0b407be50b8eeeb7ffe625ec4be2d0..db37118dfd71cb10f43dda02d21789d6facad658 100644 (file)
@@ -76,10 +76,32 @@ public class ChannelManagerConstructor {
      * we want to expose underlying details of the scorer itself. Thus, we expose a safe version that takes the lock
      * then returns a reference to this scorer.
      */
-    @Nullable private final ProbabilisticScorer prob_scorer;
+    private final ProbabilisticScorer prob_scorer;
     private final Logger logger;
     private final KeysManager keys_manager;
 
+    /**
+     * Exposes the `ProbabilisticScorer` wrapped inside a lock. Don't forget to `close` this lock when you're done with
+     * it so normal scoring operation can continue.
+     */
+    public class ScorerWrapper implements AutoCloseable {
+        private final Score lock;
+        public final ProbabilisticScorer prob_scorer;
+        private ScorerWrapper(Score lock, ProbabilisticScorer prob_scorer) {
+            this.lock = lock; this.prob_scorer = prob_scorer;
+        }
+        @Override public void close() throws Exception {
+            lock.destroy();
+        }
+    }
+    /**
+     * Gets the `ProbabilisticScorer` which backs the public lockable `scorer`. Don't forget to `close` the lock when
+     * you're done with it.
+     */
+    public ScorerWrapper get_locked_scorer() {
+        return new ScorerWrapper(this.scorer.as_LockableScore().lock(), this.prob_scorer);
+    }
+
     /**
      * Deserializes a channel manager and a set of channel monitors from the given serialized copies and interface implementations
      *
index 113eae9035d32401597bc566776bc6e0bdef0cb2..0fa0c35fe5c36fd3f76dadea7ad50868ab52b89c 100644 (file)
@@ -946,7 +946,7 @@ class HumanObjectPeerTestInstance {
             this.best_blockhash = best_blockhash;
         }
     }
-    void do_test_message_handler_b(TestState state) throws InterruptedException {
+    void do_test_message_handler_b(TestState state) throws Exception {
         GcCheck obj = new GcCheck();
         if (state.ref_block != null) {
             // Ensure the original peers get freed before we move on. Note that we have to be in a different function
@@ -1128,6 +1128,16 @@ class HumanObjectPeerTestInstance {
         assert upd_msg instanceof Result_ChannelUpdateDecodeErrorZ.Result_ChannelUpdateDecodeErrorZ_OK;
         assert ((Result_ChannelUpdateDecodeErrorZ.Result_ChannelUpdateDecodeErrorZ_OK) upd_msg).res.get_contents().get_htlc_maximum_msat() == 0xdeadbeef42424242L;
         Option_NetworkUpdateZ upd = Option_NetworkUpdateZ.some(NetworkUpdate.channel_update_message(((Result_ChannelUpdateDecodeErrorZ.Result_ChannelUpdateDecodeErrorZ_OK) upd_msg).res));
+
+        if (use_chan_manager_constructor) {
+            // Lock the scorer twice back-to-back to check that the try-with-resources AutoCloseable on the scorer works.
+            try (ChannelManagerConstructor.ScorerWrapper score = state.peer1.constructor.get_locked_scorer()) {
+                score.prob_scorer.debug_log_liquidity_stats();
+            }
+            try (ChannelManagerConstructor.ScorerWrapper score = state.peer1.constructor.get_locked_scorer()) {
+                score.prob_scorer.clear_manual_penalties();
+            }
+        }
     }
 
     java.util.LinkedList<WeakReference<Object>> must_free_objs = new java.util.LinkedList();
@@ -1143,13 +1153,13 @@ class HumanObjectPeerTestInstance {
     }
 }
 public class HumanObjectPeerTest {
-    static HumanObjectPeerTestInstance do_test_run(boolean nice_close, boolean use_km_wrapper, boolean use_manual_watch, boolean reload_peers, boolean break_cross_peer_refs, boolean nio_peer_handler, boolean use_ignoring_routing_handler, boolean use_chan_manager_constructor, boolean use_invoice_payer) throws InterruptedException {
+    static HumanObjectPeerTestInstance do_test_run(boolean nice_close, boolean use_km_wrapper, boolean use_manual_watch, boolean reload_peers, boolean break_cross_peer_refs, boolean nio_peer_handler, boolean use_ignoring_routing_handler, boolean use_chan_manager_constructor, boolean use_invoice_payer) throws Exception {
         HumanObjectPeerTestInstance instance = new HumanObjectPeerTestInstance(nice_close, use_km_wrapper, use_manual_watch, reload_peers, break_cross_peer_refs, nio_peer_handler, !nio_peer_handler, use_ignoring_routing_handler, use_chan_manager_constructor, use_invoice_payer);
         HumanObjectPeerTestInstance.TestState state = instance.do_test_message_handler();
         instance.do_test_message_handler_b(state);
         return instance;
     }
-    static void do_test(boolean nice_close, boolean use_km_wrapper, boolean use_manual_watch, boolean reload_peers, boolean break_cross_peer_refs, boolean nio_peer_handler, boolean use_ignoring_routing_handler, boolean use_chan_manager_constructor, boolean use_invoice_payer) throws InterruptedException {
+    static void do_test(boolean nice_close, boolean use_km_wrapper, boolean use_manual_watch, boolean reload_peers, boolean break_cross_peer_refs, boolean nio_peer_handler, boolean use_ignoring_routing_handler, boolean use_chan_manager_constructor, boolean use_invoice_payer) throws Exception {
         HumanObjectPeerTestInstance instance = do_test_run(nice_close, use_km_wrapper, use_manual_watch, reload_peers, break_cross_peer_refs, nio_peer_handler, use_ignoring_routing_handler, use_chan_manager_constructor, use_invoice_payer);
         while (instance.gc_count != instance.gc_exp_count) {
             System.gc();
@@ -1159,7 +1169,7 @@ public class HumanObjectPeerTest {
             assert o.get() == null;
     }
     public static final int TEST_ITERATIONS = (1 << 9);
-    public static void do_test_message_handler(IntConsumer statusFunction) throws InterruptedException {
+    public static void do_test_message_handler(IntConsumer statusFunction) throws Exception {
         Thread gc_thread = new Thread(() -> {
             while (true) {
                 System.gc();
@@ -1209,7 +1219,7 @@ public class HumanObjectPeerTest {
         gc_thread.join();
     }
     @Test
-    public void test_message_handler() throws InterruptedException {
+    public void test_message_handler() throws Exception {
         do_test_message_handler(i -> System.err.println("Running test with flags " + i));
     }