Fix returning traits in trait calls, move towards clone on return
authorMatt Corallo <git@bluematt.me>
Sun, 25 Oct 2020 04:36:05 +0000 (00:36 -0400)
committerMatt Corallo <git@bluematt.me>
Sun, 25 Oct 2020 04:41:05 +0000 (00:41 -0400)
genbindings.py
src/main/java/org/ldk/impl/bindings.java
src/main/java/org/ldk/structs/ChannelKeys.java
src/main/java/org/ldk/structs/KeysInterface.java
src/main/java/org/ldk/structs/SocketDescriptor.java
src/main/jni/bindings.c
src/main/jni/org_ldk_impl_bindings.h
src/test/java/org/ldk/HumanObjectPeerTest.java

index 56d36da4b653c2222155052bd072810a0cd417e6..3f7a12ee23d402b00044cd967b4c27c308af2006 100755 (executable)
@@ -230,9 +230,6 @@ trait_structs = set()
 result_types = set()
 tuple_types = {}
 
-def is_common_base_ext(struct_name):
-    return struct_name in complex_enums or struct_name in opaque_structs or struct_name in trait_structs or struct_name in result_types
-
 var_is_arr_regex = re.compile("\(\*([A-za-z0-9_]*)\)\[([a-z0-9]*)\]")
 var_ty_regex = re.compile("([A-za-z_0-9]*)(.*)")
 java_c_types_none_allowed = True # Unset when we do the real pass that populates the above sets
@@ -737,12 +734,19 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                         arg_conv_cleanup = None,
                         ret_conv = ("jclass " + ty_info.var_name + "_conv = " + ty_info.rust_obj + "_to_java(_env, ", ");"),
                         ret_conv_name = ty_info.var_name + "_conv", to_hu_conv = None, to_hu_conv_name = None, from_hu_conv = None)
-                base_conv = ty_info.rust_obj + " " + ty_info.var_name + "_conv = *(" + ty_info.rust_obj + "*)" + ty_info.var_name + ";";
+                base_conv = ty_info.rust_obj + " " + ty_info.var_name + "_conv = *(" + ty_info.rust_obj + "*)" + ty_info.var_name + ";"
                 if ty_info.rust_obj in trait_structs:
                     if not is_free:
-                        base_conv = base_conv + "\nif (" + ty_info.var_name + "_conv.free == " + ty_info.rust_obj + "_JCalls_free) {\n"
-                        base_conv = base_conv + "\t// If this_arg is a JCalls struct, then we need to increment the refcnt in it.\n"
-                        base_conv = base_conv + "\t" + ty_info.rust_obj + "_JCalls_clone(" + ty_info.var_name + "_conv.this_arg);\n}"
+                        needs_full_clone = not is_free and (not ty_info.is_ptr and not holds_ref or ty_info.requires_clone == True) and ty_info.requires_clone != False
+                        if needs_full_clone and (ty_info.java_hu_ty + "_clone") in clone_fns:
+                            base_conv = base_conv + "\n" + ty_info.var_name + "_conv = " + ty_info.java_hu_ty + "_clone(" + ty_info.var_name + ");"
+                        else:
+                            base_conv = base_conv + "\nif (" + ty_info.var_name + "_conv.free == " + ty_info.rust_obj + "_JCalls_free) {\n"
+                            base_conv = base_conv + "\t// If this_arg is a JCalls struct, then we need to increment the refcnt in it.\n"
+                            base_conv = base_conv + "\t" + ty_info.rust_obj + "_JCalls_clone(" + ty_info.var_name + "_conv.this_arg);\n}"
+                            if needs_full_clone:
+                                base_conv = base_conv + "// Warning: we may need a move here but can't do a full clone!\n"
+
                     else:
                         base_conv = base_conv + "\n" + "FREE((void*)" + ty_info.var_name + ");"
                     return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
@@ -992,7 +996,10 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                     out_java_struct.write(info.arg_name)
             out_java_struct.write(");\n")
             if ret_info.to_hu_conv is not None:
-                out_java_struct.write("\t\t" + ret_info.to_hu_conv.replace("\n", "\n\t\t") + "\n")
+                if ret_info.rust_obj == "LDK" + struct_meth:
+                    out_java_struct.write("\t\t" + ret_info.to_hu_conv.replace("\n", "\n\t\t").replace("this", ret_info.to_hu_conv_name) + "\n")
+                else:
+                    out_java_struct.write("\t\t" + ret_info.to_hu_conv.replace("\n", "\n\t\t") + "\n")
 
             for info in arg_names:
                 if info.arg_name == "this_ptr" or info.arg_name == "this_arg":
@@ -1326,7 +1333,8 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                             java_trait_constr = java_trait_constr + "\t\t\t\t" + ret_ty_info.java_ty + " result = " + ret_ty_info.from_hu_conv[0] + ";\n"
                             if ret_ty_info.from_hu_conv[1] != "":
                                 java_trait_constr = java_trait_constr + "\t\t\t\t" + ret_ty_info.from_hu_conv[1].replace("this", "impl_holder.held") + ";\n"
-                            if is_common_base_ext(ret_ty_info.rust_obj):
+                            if ret_ty_info.rust_obj in result_types:
+                                # Avoid double-free by breaking the result - we should learn to clone these and then we can be safe instead
                                 java_trait_constr = java_trait_constr + "\t\t\t\tret.ptr = 0;\n"
                             java_trait_constr = java_trait_constr + "\t\t\t\treturn result;\n"
                         else:
@@ -1395,7 +1403,6 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                 elif fn_line.group(2) == "free":
                     write_c("\t\t.free = " + struct_name + "_JCalls_free,\n")
                 else:
-                    clone_fns.add(struct_name + "_clone")
                     write_c("\t\t.clone = " + struct_name + "_JCalls_clone,\n")
             for idx, var_line in enumerate(field_var_lines):
                 if var_line.group(1) in trait_structs:
index 966995a8f66b84d2c4ce8d49bdaf2b03ec7ef487..3d0dfd99e729e84df27dc2363442df183af9bd41 100644 (file)
@@ -1024,6 +1024,8 @@ public class bindings {
        public static native void SpendableOutputDescriptor_free(long this_ptr);
        // LDKSpendableOutputDescriptor SpendableOutputDescriptor_clone(const LDKSpendableOutputDescriptor *orig);
        public static native long SpendableOutputDescriptor_clone(long orig);
+       // LDKChannelKeys ChannelKeys_clone(const LDKChannelKeys *orig);
+       public static native long ChannelKeys_clone(long orig);
        // void ChannelKeys_free(LDKChannelKeys this_ptr);
        public static native void ChannelKeys_free(long this_ptr);
        // void KeysInterface_free(LDKKeysInterface this_ptr);
@@ -2040,6 +2042,8 @@ public class bindings {
        public static native void MessageHandler_set_route_handler(long this_ptr, long val);
        // MUST_USE_RES LDKMessageHandler MessageHandler_new(LDKChannelMessageHandler chan_handler_arg, LDKRoutingMessageHandler route_handler_arg);
        public static native long MessageHandler_new(long chan_handler_arg, long route_handler_arg);
+       // LDKSocketDescriptor SocketDescriptor_clone(const LDKSocketDescriptor *orig);
+       public static native long SocketDescriptor_clone(long orig);
        // void SocketDescriptor_free(LDKSocketDescriptor this_ptr);
        public static native void SocketDescriptor_free(long this_ptr);
        // void PeerHandleError_free(LDKPeerHandleError this_ptr);
index 9a37199b247faee2ec2b4940b25c3eefe50846ec..609e3f85d6191b76d2829e15a7d09b4ed21d4a48 100644 (file)
@@ -189,4 +189,12 @@ public class ChannelKeys extends CommonBase {
                return ret_hu_conv;
        }
 
+       public static ChannelKeys constructor_clone(ChannelKeys orig) {
+               long ret = bindings.ChannelKeys_clone(orig == null ? 0 : orig.ptr);
+               ChannelKeys ret_hu_conv = new ChannelKeys(null, ret);
+               ret_hu_conv.ptrs_to.add(ret_hu_conv);
+               ret_hu_conv.ptrs_to.add(orig);
+               return ret_hu_conv;
+       }
+
 }
index f93c7a0e776a0f53d82e066ad56bc8e8192bdb87..cc934aef83befc9bbb7c865b9a4ef1e68af21a6e 100644 (file)
@@ -46,7 +46,6 @@ public class KeysInterface extends CommonBase {
                                ChannelKeys ret = arg.get_channel_keys(inbound, channel_value_satoshis);
                                long result = ret == null ? 0 : ret.ptr;
                                impl_holder.held.ptrs_to.add(ret);
-                               ret.ptr = 0;
                                return result;
                        }
                        @Override public byte[] get_secure_random_bytes() {
index 0784a608532103f2ce560ce8dc2f67cc7b8a812a..f4956896545c94a77f9d217ccdb5d3505b802c5b 100644 (file)
@@ -61,4 +61,12 @@ public class SocketDescriptor extends CommonBase {
                return ret;
        }
 
+       public static SocketDescriptor constructor_clone(SocketDescriptor orig) {
+               long ret = bindings.SocketDescriptor_clone(orig == null ? 0 : orig.ptr);
+               SocketDescriptor ret_hu_conv = new SocketDescriptor(null, ret);
+               ret_hu_conv.ptrs_to.add(ret_hu_conv);
+               ret_hu_conv.ptrs_to.add(orig);
+               return ret_hu_conv;
+       }
+
 }
index f3d3cca176a22081e38e1859a22f29ecb4e159a4..93ad2e57aae3c48107f54cc73ba52fb6a7df89b4 100644 (file)
@@ -2867,10 +2867,7 @@ LDKChannelKeys get_channel_keys_jcall(const void* this_arg, bool inbound, uint64
        CHECK(obj != NULL);
        LDKChannelKeys* ret = (LDKChannelKeys*)(*_env)->CallLongMethod(_env, obj, j_calls->get_channel_keys_meth, inbound, channel_value_satoshis);
        LDKChannelKeys ret_conv = *(LDKChannelKeys*)ret;
-       if (ret_conv.free == LDKChannelKeys_JCalls_free) {
-               // If this_arg is a JCalls struct, then we need to increment the refcnt in it.
-               LDKChannelKeys_JCalls_clone(ret_conv.this_arg);
-       }
+       ret_conv = ChannelKeys_clone(ret);
        return ret_conv;
 }
 LDKThirtyTwoBytes get_secure_random_bytes_jcall(const void* this_arg) {
@@ -6800,6 +6797,13 @@ JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_SpendableOutputDescriptor_1cl
        return ret_ref;
 }
 
+JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_ChannelKeys_1clone(JNIEnv * _env, jclass _b, jlong orig) {
+       LDKChannelKeys* orig_conv = (LDKChannelKeys*)orig;
+       LDKChannelKeys* ret = MALLOC(sizeof(LDKChannelKeys), "LDKChannelKeys");
+       *ret = ChannelKeys_clone(orig_conv);
+       return (long)ret;
+}
+
 JNIEXPORT void JNICALL Java_org_ldk_impl_bindings_ChannelKeys_1free(JNIEnv * _env, jclass _b, jlong this_ptr) {
        LDKChannelKeys this_ptr_conv = *(LDKChannelKeys*)this_ptr;
        FREE((void*)this_ptr);
@@ -12328,6 +12332,13 @@ JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_MessageHandler_1new(JNIEnv *
        return ret_ref;
 }
 
+JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_SocketDescriptor_1clone(JNIEnv * _env, jclass _b, jlong orig) {
+       LDKSocketDescriptor* orig_conv = (LDKSocketDescriptor*)orig;
+       LDKSocketDescriptor* ret = MALLOC(sizeof(LDKSocketDescriptor), "LDKSocketDescriptor");
+       *ret = SocketDescriptor_clone(orig_conv);
+       return (long)ret;
+}
+
 JNIEXPORT void JNICALL Java_org_ldk_impl_bindings_SocketDescriptor_1free(JNIEnv * _env, jclass _b, jlong this_ptr) {
        LDKSocketDescriptor this_ptr_conv = *(LDKSocketDescriptor*)this_ptr;
        FREE((void*)this_ptr);
index d25949759a41e904d49ad6e3ecddb56a05103daf..07711151cabe4d749a05b919dc67615decd15551 100644 (file)
@@ -3407,6 +3407,14 @@ JNIEXPORT void JNICALL Java_org_ldk_impl_bindings_SpendableOutputDescriptor_1fre
 JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_SpendableOutputDescriptor_1clone
   (JNIEnv *, jclass, jlong);
 
+/*
+ * Class:     org_ldk_impl_bindings
+ * Method:    ChannelKeys_clone
+ * Signature: (J)J
+ */
+JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_ChannelKeys_1clone
+  (JNIEnv *, jclass, jlong);
+
 /*
  * Class:     org_ldk_impl_bindings
  * Method:    ChannelKeys_free
@@ -7471,6 +7479,14 @@ JNIEXPORT void JNICALL Java_org_ldk_impl_bindings_MessageHandler_1set_1route_1ha
 JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_MessageHandler_1new
   (JNIEnv *, jclass, jlong, jlong);
 
+/*
+ * Class:     org_ldk_impl_bindings
+ * Method:    SocketDescriptor_clone
+ * Signature: (J)J
+ */
+JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_SocketDescriptor_1clone
+  (JNIEnv *, jclass, jlong);
+
 /*
  * Class:     org_ldk_impl_bindings
  * Method:    SocketDescriptor_free
index c66b4dd84d3e82d61b2c7c8091637f5ca902d09e..26f20d34b5e70b8d088c2ebe8574061d0d847fac 100644 (file)
@@ -17,6 +17,96 @@ import java.util.concurrent.ConcurrentLinkedQueue;
 
 class HumanObjectPeerTestInstance {
     class Peer {
+        KeysInterface manual_keysif(KeysInterface underlying_if) {
+            return KeysInterface.new_impl(new KeysInterface.KeysInterfaceInterface() {
+                @Override
+                public byte[] get_node_secret() {
+                    return underlying_if.get_node_secret();
+                }
+
+                @Override
+                public byte[] get_destination_script() {
+                    return underlying_if.get_destination_script();
+                }
+
+                @Override
+                public byte[] get_shutdown_pubkey() {
+                    return underlying_if.get_shutdown_pubkey();
+                }
+
+                @Override
+                public ChannelKeys get_channel_keys(boolean inbound, long channel_value_satoshis) {
+                    ChannelKeys underlying_ck = underlying_if.get_channel_keys(inbound, channel_value_satoshis);
+                    ChannelKeys.ChannelKeysInterface cki = new ChannelKeys.ChannelKeysInterface() {
+                        @Override
+                        public byte[] get_per_commitment_point(long idx) {
+                            return underlying_ck.get_per_commitment_point(idx);
+                        }
+
+                        @Override
+                        public byte[] release_commitment_secret(long idx) {
+                            return underlying_ck.release_commitment_secret(idx);
+                        }
+
+                        @Override
+                        public TwoTuple<Long, Long> key_derivation_params() {
+                            return new TwoTuple<Long, Long>((long)0, (long)1);
+                        }
+
+                        @Override
+                        public Result_C2Tuple_SignatureCVec_SignatureZZNoneZ sign_counterparty_commitment(int feerate_per_kw, byte[] commitment_tx, PreCalculatedTxCreationKeys keys, HTLCOutputInCommitment[] htlcs) {
+                            return underlying_ck.sign_counterparty_commitment(feerate_per_kw, commitment_tx, keys, htlcs);
+                        }
+
+                        @Override
+                        public Result_SignatureNoneZ sign_holder_commitment(HolderCommitmentTransaction holder_commitment_tx) {
+                            return underlying_ck.sign_holder_commitment(holder_commitment_tx);
+                        }
+
+                        @Override
+                        public Result_CVec_SignatureZNoneZ sign_holder_commitment_htlc_transactions(HolderCommitmentTransaction holder_commitment_tx) {
+                            return underlying_ck.sign_holder_commitment_htlc_transactions(holder_commitment_tx);
+                        }
+
+                        @Override
+                        public Result_SignatureNoneZ sign_justice_transaction(byte[] justice_tx, long input, long amount, byte[] per_commitment_key, HTLCOutputInCommitment htlc) {
+                            return underlying_ck.sign_justice_transaction(justice_tx, input, amount, per_commitment_key, htlc);
+                        }
+
+                        @Override
+                        public Result_SignatureNoneZ sign_counterparty_htlc_transaction(byte[] htlc_tx, long input, long amount, byte[] per_commitment_point, HTLCOutputInCommitment htlc) {
+                            return underlying_ck.sign_counterparty_htlc_transaction(htlc_tx, input, amount, per_commitment_point, htlc);
+                        }
+
+                        @Override
+                        public Result_SignatureNoneZ sign_closing_transaction(byte[] closing_tx) {
+                            return underlying_ck.sign_closing_transaction(closing_tx);
+                        }
+
+                        @Override
+                        public Result_SignatureNoneZ sign_channel_announcement(UnsignedChannelAnnouncement msg) {
+                            return underlying_ck.sign_channel_announcement(msg);
+                        }
+
+                        @Override
+                        public void on_accept(ChannelPublicKeys channel_points, short counterparty_selected_contest_delay, short holder_selected_contest_delay) {
+                            underlying_ck.on_accept(channel_points, counterparty_selected_contest_delay, holder_selected_contest_delay);
+                        }
+                    };
+                    ChannelKeys resp = ChannelKeys.new_impl(cki, underlying_ck.get_pubkeys());
+                    must_free_objs.add(new WeakReference<>(cki));
+                    must_free_objs.add(new WeakReference<>(resp));
+                    must_free_objs.add(new WeakReference<>(underlying_ck));
+                    return resp;
+                }
+
+                @Override
+                public byte[] get_secure_random_bytes() {
+                    return underlying_if.get_secure_random_bytes();
+                }
+            });
+        }
+
         final Logger logger;
         final FeeEstimator fee_estimator;
         final BroadcasterInterface tx_broadcaster;
@@ -29,7 +119,7 @@ class HumanObjectPeerTestInstance {
         byte[] node_id;
         final LinkedList<byte[]> broadcast_set = new LinkedList<>();
 
-        Peer(byte seed) {
+        Peer(byte seed, boolean use_km_wrapper) {
             logger = Logger.new_impl((String arg) -> System.out.println(seed + ": " + arg));
             fee_estimator = FeeEstimator.new_impl((confirmation_target -> 253));
             tx_broadcaster = BroadcasterInterface.new_impl(tx -> {
@@ -70,8 +160,13 @@ class HumanObjectPeerTestInstance {
             for (byte i = 0; i < 32; i++) {
                 key_seed[i] = (byte) (i ^ seed);
             }
-            KeysManager keys = KeysManager.constructor_new(key_seed, LDKNetwork.LDKNetwork_Bitcoin, System.currentTimeMillis() / 1000, (int) (System.currentTimeMillis() * 1000) & 0xffffffff);
-            this.keys_interface = keys.as_KeysInterface();
+            if (use_km_wrapper) {
+                KeysManager underlying = KeysManager.constructor_new(key_seed, LDKNetwork.LDKNetwork_Bitcoin, System.currentTimeMillis() / 1000, (int) (System.currentTimeMillis() * 1000) & 0xffffffff);
+                this.keys_interface = manual_keysif(underlying.as_KeysInterface());
+            } else {
+                KeysManager keys = KeysManager.constructor_new(key_seed, LDKNetwork.LDKNetwork_Bitcoin, System.currentTimeMillis() / 1000, (int) (System.currentTimeMillis() * 1000) & 0xffffffff);
+                this.keys_interface = keys.as_KeysInterface();
+            }
             this.chan_manager = ChannelManager.constructor_new(LDKNetwork.LDKNetwork_Bitcoin, FeeEstimator.new_impl(confirmation_target -> 0), chain_monitor, tx_broadcaster, logger, this.keys_interface, UserConfig.constructor_default(), 1);
             this.node_id = chan_manager.get_our_node_id();
             this.chan_manager_events = chan_manager.as_EventsProvider();
@@ -145,10 +240,10 @@ class HumanObjectPeerTestInstance {
         }
     }
 
-    void do_test_message_handler(boolean nice_close) throws InterruptedException {
+    void do_test_message_handler(boolean nice_close, boolean use_km_wrapper) throws InterruptedException {
         GcCheck obj = new GcCheck();
-        Peer peer1 = new Peer((byte) 1);
-        Peer peer2 = new Peer((byte) 2);
+        Peer peer1 = new Peer((byte) 1, use_km_wrapper);
+        Peer peer2 = new Peer((byte) 2, use_km_wrapper);
 
         ConcurrentLinkedQueue<Thread> list = new ConcurrentLinkedQueue<Thread>();
         LongHolder descriptor1 = new LongHolder();
@@ -331,10 +426,9 @@ class HumanObjectPeerTestInstance {
 
 }
 public class HumanObjectPeerTest {
-    @Test
-    public void test_message_handler_force_close() throws InterruptedException {
+    void do_test(boolean nice_close, boolean use_km_wrapper) throws InterruptedException {
         HumanObjectPeerTestInstance instance = new HumanObjectPeerTestInstance();
-        instance.do_test_message_handler(false);
+        instance.do_test_message_handler(nice_close, use_km_wrapper);
         while (!instance.gc_ran) {
             System.gc();
             System.runFinalization();
@@ -343,14 +437,19 @@ public class HumanObjectPeerTest {
             assert o.get() == null;
     }
     @Test
+    public void test_message_handler_force_close() throws InterruptedException {
+        do_test(false, false);
+    }
+    @Test
     public void test_message_handler_nice_close() throws InterruptedException {
-        HumanObjectPeerTestInstance instance = new HumanObjectPeerTestInstance();
-        instance.do_test_message_handler(true);
-        while (!instance.gc_ran) {
-            System.gc();
-            System.runFinalization();
-        }
-        for (WeakReference<Object> o : instance.must_free_objs)
-            assert o.get() == null;
+        do_test(true, false);
+    }
+    @Test
+    public void test_message_handler_nice_close_wrapper() throws InterruptedException {
+        do_test(true, true);
+    }
+    @Test
+    public void test_message_handler_force_close_wrapper() throws InterruptedException {
+        do_test(false, true);
     }
 }