From ac4869b989e27b4500e00e901ea5851604315345 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Sun, 25 Oct 2020 00:36:05 -0400 Subject: [PATCH] Fix returning traits in trait calls, move towards clone on return --- genbindings.py | 27 ++-- src/main/java/org/ldk/impl/bindings.java | 4 + .../java/org/ldk/structs/ChannelKeys.java | 8 ++ .../java/org/ldk/structs/KeysInterface.java | 1 - .../org/ldk/structs/SocketDescriptor.java | 8 ++ src/main/jni/bindings.c | 19 ++- src/main/jni/org_ldk_impl_bindings.h | 16 +++ .../java/org/ldk/HumanObjectPeerTest.java | 133 +++++++++++++++--- 8 files changed, 184 insertions(+), 32 deletions(-) diff --git a/genbindings.py b/genbindings.py index 56d36da4..3f7a12ee 100755 --- a/genbindings.py +++ b/genbindings.py @@ -230,9 +230,6 @@ trait_structs = set() result_types = set() tuple_types = {} -def is_common_base_ext(struct_name): - return struct_name in complex_enums or struct_name in opaque_structs or struct_name in trait_structs or struct_name in result_types - var_is_arr_regex = re.compile("\(\*([A-za-z0-9_]*)\)\[([a-z0-9]*)\]") var_ty_regex = re.compile("([A-za-z_0-9]*)(.*)") java_c_types_none_allowed = True # Unset when we do the real pass that populates the above sets @@ -737,12 +734,19 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java: arg_conv_cleanup = None, ret_conv = ("jclass " + ty_info.var_name + "_conv = " + ty_info.rust_obj + "_to_java(_env, ", ");"), ret_conv_name = ty_info.var_name + "_conv", to_hu_conv = None, to_hu_conv_name = None, from_hu_conv = None) - base_conv = ty_info.rust_obj + " " + ty_info.var_name + "_conv = *(" + ty_info.rust_obj + "*)" + ty_info.var_name + ";"; + base_conv = ty_info.rust_obj + " " + ty_info.var_name + "_conv = *(" + ty_info.rust_obj + "*)" + ty_info.var_name + ";" if ty_info.rust_obj in trait_structs: if not is_free: - base_conv = base_conv + "\nif (" + ty_info.var_name + "_conv.free == " + ty_info.rust_obj + "_JCalls_free) {\n" - base_conv = base_conv + "\t// If this_arg is a JCalls struct, then we need to increment the refcnt in it.\n" - base_conv = base_conv + "\t" + ty_info.rust_obj + "_JCalls_clone(" + ty_info.var_name + "_conv.this_arg);\n}" + needs_full_clone = not is_free and (not ty_info.is_ptr and not holds_ref or ty_info.requires_clone == True) and ty_info.requires_clone != False + if needs_full_clone and (ty_info.java_hu_ty + "_clone") in clone_fns: + base_conv = base_conv + "\n" + ty_info.var_name + "_conv = " + ty_info.java_hu_ty + "_clone(" + ty_info.var_name + ");" + else: + base_conv = base_conv + "\nif (" + ty_info.var_name + "_conv.free == " + ty_info.rust_obj + "_JCalls_free) {\n" + base_conv = base_conv + "\t// If this_arg is a JCalls struct, then we need to increment the refcnt in it.\n" + base_conv = base_conv + "\t" + ty_info.rust_obj + "_JCalls_clone(" + ty_info.var_name + "_conv.this_arg);\n}" + if needs_full_clone: + base_conv = base_conv + "// Warning: we may need a move here but can't do a full clone!\n" + else: base_conv = base_conv + "\n" + "FREE((void*)" + ty_info.var_name + ");" return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name, @@ -992,7 +996,10 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java: out_java_struct.write(info.arg_name) out_java_struct.write(");\n") if ret_info.to_hu_conv is not None: - out_java_struct.write("\t\t" + ret_info.to_hu_conv.replace("\n", "\n\t\t") + "\n") + if ret_info.rust_obj == "LDK" + struct_meth: + out_java_struct.write("\t\t" + ret_info.to_hu_conv.replace("\n", "\n\t\t").replace("this", ret_info.to_hu_conv_name) + "\n") + else: + out_java_struct.write("\t\t" + ret_info.to_hu_conv.replace("\n", "\n\t\t") + "\n") for info in arg_names: if info.arg_name == "this_ptr" or info.arg_name == "this_arg": @@ -1326,7 +1333,8 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java: java_trait_constr = java_trait_constr + "\t\t\t\t" + ret_ty_info.java_ty + " result = " + ret_ty_info.from_hu_conv[0] + ";\n" if ret_ty_info.from_hu_conv[1] != "": java_trait_constr = java_trait_constr + "\t\t\t\t" + ret_ty_info.from_hu_conv[1].replace("this", "impl_holder.held") + ";\n" - if is_common_base_ext(ret_ty_info.rust_obj): + if ret_ty_info.rust_obj in result_types: + # Avoid double-free by breaking the result - we should learn to clone these and then we can be safe instead java_trait_constr = java_trait_constr + "\t\t\t\tret.ptr = 0;\n" java_trait_constr = java_trait_constr + "\t\t\t\treturn result;\n" else: @@ -1395,7 +1403,6 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java: elif fn_line.group(2) == "free": write_c("\t\t.free = " + struct_name + "_JCalls_free,\n") else: - clone_fns.add(struct_name + "_clone") write_c("\t\t.clone = " + struct_name + "_JCalls_clone,\n") for idx, var_line in enumerate(field_var_lines): if var_line.group(1) in trait_structs: diff --git a/src/main/java/org/ldk/impl/bindings.java b/src/main/java/org/ldk/impl/bindings.java index 966995a8..3d0dfd99 100644 --- a/src/main/java/org/ldk/impl/bindings.java +++ b/src/main/java/org/ldk/impl/bindings.java @@ -1024,6 +1024,8 @@ public class bindings { public static native void SpendableOutputDescriptor_free(long this_ptr); // LDKSpendableOutputDescriptor SpendableOutputDescriptor_clone(const LDKSpendableOutputDescriptor *orig); public static native long SpendableOutputDescriptor_clone(long orig); + // LDKChannelKeys ChannelKeys_clone(const LDKChannelKeys *orig); + public static native long ChannelKeys_clone(long orig); // void ChannelKeys_free(LDKChannelKeys this_ptr); public static native void ChannelKeys_free(long this_ptr); // void KeysInterface_free(LDKKeysInterface this_ptr); @@ -2040,6 +2042,8 @@ public class bindings { public static native void MessageHandler_set_route_handler(long this_ptr, long val); // MUST_USE_RES LDKMessageHandler MessageHandler_new(LDKChannelMessageHandler chan_handler_arg, LDKRoutingMessageHandler route_handler_arg); public static native long MessageHandler_new(long chan_handler_arg, long route_handler_arg); + // LDKSocketDescriptor SocketDescriptor_clone(const LDKSocketDescriptor *orig); + public static native long SocketDescriptor_clone(long orig); // void SocketDescriptor_free(LDKSocketDescriptor this_ptr); public static native void SocketDescriptor_free(long this_ptr); // void PeerHandleError_free(LDKPeerHandleError this_ptr); diff --git a/src/main/java/org/ldk/structs/ChannelKeys.java b/src/main/java/org/ldk/structs/ChannelKeys.java index 9a37199b..609e3f85 100644 --- a/src/main/java/org/ldk/structs/ChannelKeys.java +++ b/src/main/java/org/ldk/structs/ChannelKeys.java @@ -189,4 +189,12 @@ public class ChannelKeys extends CommonBase { return ret_hu_conv; } + public static ChannelKeys constructor_clone(ChannelKeys orig) { + long ret = bindings.ChannelKeys_clone(orig == null ? 0 : orig.ptr); + ChannelKeys ret_hu_conv = new ChannelKeys(null, ret); + ret_hu_conv.ptrs_to.add(ret_hu_conv); + ret_hu_conv.ptrs_to.add(orig); + return ret_hu_conv; + } + } diff --git a/src/main/java/org/ldk/structs/KeysInterface.java b/src/main/java/org/ldk/structs/KeysInterface.java index f93c7a0e..cc934aef 100644 --- a/src/main/java/org/ldk/structs/KeysInterface.java +++ b/src/main/java/org/ldk/structs/KeysInterface.java @@ -46,7 +46,6 @@ public class KeysInterface extends CommonBase { ChannelKeys ret = arg.get_channel_keys(inbound, channel_value_satoshis); long result = ret == null ? 0 : ret.ptr; impl_holder.held.ptrs_to.add(ret); - ret.ptr = 0; return result; } @Override public byte[] get_secure_random_bytes() { diff --git a/src/main/java/org/ldk/structs/SocketDescriptor.java b/src/main/java/org/ldk/structs/SocketDescriptor.java index 0784a608..f4956896 100644 --- a/src/main/java/org/ldk/structs/SocketDescriptor.java +++ b/src/main/java/org/ldk/structs/SocketDescriptor.java @@ -61,4 +61,12 @@ public class SocketDescriptor extends CommonBase { return ret; } + public static SocketDescriptor constructor_clone(SocketDescriptor orig) { + long ret = bindings.SocketDescriptor_clone(orig == null ? 0 : orig.ptr); + SocketDescriptor ret_hu_conv = new SocketDescriptor(null, ret); + ret_hu_conv.ptrs_to.add(ret_hu_conv); + ret_hu_conv.ptrs_to.add(orig); + return ret_hu_conv; + } + } diff --git a/src/main/jni/bindings.c b/src/main/jni/bindings.c index f3d3cca1..93ad2e57 100644 --- a/src/main/jni/bindings.c +++ b/src/main/jni/bindings.c @@ -2867,10 +2867,7 @@ LDKChannelKeys get_channel_keys_jcall(const void* this_arg, bool inbound, uint64 CHECK(obj != NULL); LDKChannelKeys* ret = (LDKChannelKeys*)(*_env)->CallLongMethod(_env, obj, j_calls->get_channel_keys_meth, inbound, channel_value_satoshis); LDKChannelKeys ret_conv = *(LDKChannelKeys*)ret; - if (ret_conv.free == LDKChannelKeys_JCalls_free) { - // If this_arg is a JCalls struct, then we need to increment the refcnt in it. - LDKChannelKeys_JCalls_clone(ret_conv.this_arg); - } + ret_conv = ChannelKeys_clone(ret); return ret_conv; } LDKThirtyTwoBytes get_secure_random_bytes_jcall(const void* this_arg) { @@ -6800,6 +6797,13 @@ JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_SpendableOutputDescriptor_1cl return ret_ref; } +JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_ChannelKeys_1clone(JNIEnv * _env, jclass _b, jlong orig) { + LDKChannelKeys* orig_conv = (LDKChannelKeys*)orig; + LDKChannelKeys* ret = MALLOC(sizeof(LDKChannelKeys), "LDKChannelKeys"); + *ret = ChannelKeys_clone(orig_conv); + return (long)ret; +} + JNIEXPORT void JNICALL Java_org_ldk_impl_bindings_ChannelKeys_1free(JNIEnv * _env, jclass _b, jlong this_ptr) { LDKChannelKeys this_ptr_conv = *(LDKChannelKeys*)this_ptr; FREE((void*)this_ptr); @@ -12328,6 +12332,13 @@ JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_MessageHandler_1new(JNIEnv * return ret_ref; } +JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_SocketDescriptor_1clone(JNIEnv * _env, jclass _b, jlong orig) { + LDKSocketDescriptor* orig_conv = (LDKSocketDescriptor*)orig; + LDKSocketDescriptor* ret = MALLOC(sizeof(LDKSocketDescriptor), "LDKSocketDescriptor"); + *ret = SocketDescriptor_clone(orig_conv); + return (long)ret; +} + JNIEXPORT void JNICALL Java_org_ldk_impl_bindings_SocketDescriptor_1free(JNIEnv * _env, jclass _b, jlong this_ptr) { LDKSocketDescriptor this_ptr_conv = *(LDKSocketDescriptor*)this_ptr; FREE((void*)this_ptr); diff --git a/src/main/jni/org_ldk_impl_bindings.h b/src/main/jni/org_ldk_impl_bindings.h index d2594975..07711151 100644 --- a/src/main/jni/org_ldk_impl_bindings.h +++ b/src/main/jni/org_ldk_impl_bindings.h @@ -3407,6 +3407,14 @@ JNIEXPORT void JNICALL Java_org_ldk_impl_bindings_SpendableOutputDescriptor_1fre JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_SpendableOutputDescriptor_1clone (JNIEnv *, jclass, jlong); +/* + * Class: org_ldk_impl_bindings + * Method: ChannelKeys_clone + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_ChannelKeys_1clone + (JNIEnv *, jclass, jlong); + /* * Class: org_ldk_impl_bindings * Method: ChannelKeys_free @@ -7471,6 +7479,14 @@ JNIEXPORT void JNICALL Java_org_ldk_impl_bindings_MessageHandler_1set_1route_1ha JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_MessageHandler_1new (JNIEnv *, jclass, jlong, jlong); +/* + * Class: org_ldk_impl_bindings + * Method: SocketDescriptor_clone + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_SocketDescriptor_1clone + (JNIEnv *, jclass, jlong); + /* * Class: org_ldk_impl_bindings * Method: SocketDescriptor_free diff --git a/src/test/java/org/ldk/HumanObjectPeerTest.java b/src/test/java/org/ldk/HumanObjectPeerTest.java index c66b4dd8..26f20d34 100644 --- a/src/test/java/org/ldk/HumanObjectPeerTest.java +++ b/src/test/java/org/ldk/HumanObjectPeerTest.java @@ -17,6 +17,96 @@ import java.util.concurrent.ConcurrentLinkedQueue; class HumanObjectPeerTestInstance { class Peer { + KeysInterface manual_keysif(KeysInterface underlying_if) { + return KeysInterface.new_impl(new KeysInterface.KeysInterfaceInterface() { + @Override + public byte[] get_node_secret() { + return underlying_if.get_node_secret(); + } + + @Override + public byte[] get_destination_script() { + return underlying_if.get_destination_script(); + } + + @Override + public byte[] get_shutdown_pubkey() { + return underlying_if.get_shutdown_pubkey(); + } + + @Override + public ChannelKeys get_channel_keys(boolean inbound, long channel_value_satoshis) { + ChannelKeys underlying_ck = underlying_if.get_channel_keys(inbound, channel_value_satoshis); + ChannelKeys.ChannelKeysInterface cki = new ChannelKeys.ChannelKeysInterface() { + @Override + public byte[] get_per_commitment_point(long idx) { + return underlying_ck.get_per_commitment_point(idx); + } + + @Override + public byte[] release_commitment_secret(long idx) { + return underlying_ck.release_commitment_secret(idx); + } + + @Override + public TwoTuple key_derivation_params() { + return new TwoTuple((long)0, (long)1); + } + + @Override + public Result_C2Tuple_SignatureCVec_SignatureZZNoneZ sign_counterparty_commitment(int feerate_per_kw, byte[] commitment_tx, PreCalculatedTxCreationKeys keys, HTLCOutputInCommitment[] htlcs) { + return underlying_ck.sign_counterparty_commitment(feerate_per_kw, commitment_tx, keys, htlcs); + } + + @Override + public Result_SignatureNoneZ sign_holder_commitment(HolderCommitmentTransaction holder_commitment_tx) { + return underlying_ck.sign_holder_commitment(holder_commitment_tx); + } + + @Override + public Result_CVec_SignatureZNoneZ sign_holder_commitment_htlc_transactions(HolderCommitmentTransaction holder_commitment_tx) { + return underlying_ck.sign_holder_commitment_htlc_transactions(holder_commitment_tx); + } + + @Override + public Result_SignatureNoneZ sign_justice_transaction(byte[] justice_tx, long input, long amount, byte[] per_commitment_key, HTLCOutputInCommitment htlc) { + return underlying_ck.sign_justice_transaction(justice_tx, input, amount, per_commitment_key, htlc); + } + + @Override + public Result_SignatureNoneZ sign_counterparty_htlc_transaction(byte[] htlc_tx, long input, long amount, byte[] per_commitment_point, HTLCOutputInCommitment htlc) { + return underlying_ck.sign_counterparty_htlc_transaction(htlc_tx, input, amount, per_commitment_point, htlc); + } + + @Override + public Result_SignatureNoneZ sign_closing_transaction(byte[] closing_tx) { + return underlying_ck.sign_closing_transaction(closing_tx); + } + + @Override + public Result_SignatureNoneZ sign_channel_announcement(UnsignedChannelAnnouncement msg) { + return underlying_ck.sign_channel_announcement(msg); + } + + @Override + public void on_accept(ChannelPublicKeys channel_points, short counterparty_selected_contest_delay, short holder_selected_contest_delay) { + underlying_ck.on_accept(channel_points, counterparty_selected_contest_delay, holder_selected_contest_delay); + } + }; + ChannelKeys resp = ChannelKeys.new_impl(cki, underlying_ck.get_pubkeys()); + must_free_objs.add(new WeakReference<>(cki)); + must_free_objs.add(new WeakReference<>(resp)); + must_free_objs.add(new WeakReference<>(underlying_ck)); + return resp; + } + + @Override + public byte[] get_secure_random_bytes() { + return underlying_if.get_secure_random_bytes(); + } + }); + } + final Logger logger; final FeeEstimator fee_estimator; final BroadcasterInterface tx_broadcaster; @@ -29,7 +119,7 @@ class HumanObjectPeerTestInstance { byte[] node_id; final LinkedList broadcast_set = new LinkedList<>(); - Peer(byte seed) { + Peer(byte seed, boolean use_km_wrapper) { logger = Logger.new_impl((String arg) -> System.out.println(seed + ": " + arg)); fee_estimator = FeeEstimator.new_impl((confirmation_target -> 253)); tx_broadcaster = BroadcasterInterface.new_impl(tx -> { @@ -70,8 +160,13 @@ class HumanObjectPeerTestInstance { for (byte i = 0; i < 32; i++) { key_seed[i] = (byte) (i ^ seed); } - KeysManager keys = KeysManager.constructor_new(key_seed, LDKNetwork.LDKNetwork_Bitcoin, System.currentTimeMillis() / 1000, (int) (System.currentTimeMillis() * 1000) & 0xffffffff); - this.keys_interface = keys.as_KeysInterface(); + if (use_km_wrapper) { + KeysManager underlying = KeysManager.constructor_new(key_seed, LDKNetwork.LDKNetwork_Bitcoin, System.currentTimeMillis() / 1000, (int) (System.currentTimeMillis() * 1000) & 0xffffffff); + this.keys_interface = manual_keysif(underlying.as_KeysInterface()); + } else { + KeysManager keys = KeysManager.constructor_new(key_seed, LDKNetwork.LDKNetwork_Bitcoin, System.currentTimeMillis() / 1000, (int) (System.currentTimeMillis() * 1000) & 0xffffffff); + this.keys_interface = keys.as_KeysInterface(); + } this.chan_manager = ChannelManager.constructor_new(LDKNetwork.LDKNetwork_Bitcoin, FeeEstimator.new_impl(confirmation_target -> 0), chain_monitor, tx_broadcaster, logger, this.keys_interface, UserConfig.constructor_default(), 1); this.node_id = chan_manager.get_our_node_id(); this.chan_manager_events = chan_manager.as_EventsProvider(); @@ -145,10 +240,10 @@ class HumanObjectPeerTestInstance { } } - void do_test_message_handler(boolean nice_close) throws InterruptedException { + void do_test_message_handler(boolean nice_close, boolean use_km_wrapper) throws InterruptedException { GcCheck obj = new GcCheck(); - Peer peer1 = new Peer((byte) 1); - Peer peer2 = new Peer((byte) 2); + Peer peer1 = new Peer((byte) 1, use_km_wrapper); + Peer peer2 = new Peer((byte) 2, use_km_wrapper); ConcurrentLinkedQueue list = new ConcurrentLinkedQueue(); LongHolder descriptor1 = new LongHolder(); @@ -331,10 +426,9 @@ class HumanObjectPeerTestInstance { } public class HumanObjectPeerTest { - @Test - public void test_message_handler_force_close() throws InterruptedException { + void do_test(boolean nice_close, boolean use_km_wrapper) throws InterruptedException { HumanObjectPeerTestInstance instance = new HumanObjectPeerTestInstance(); - instance.do_test_message_handler(false); + instance.do_test_message_handler(nice_close, use_km_wrapper); while (!instance.gc_ran) { System.gc(); System.runFinalization(); @@ -343,14 +437,19 @@ public class HumanObjectPeerTest { assert o.get() == null; } @Test + public void test_message_handler_force_close() throws InterruptedException { + do_test(false, false); + } + @Test public void test_message_handler_nice_close() throws InterruptedException { - HumanObjectPeerTestInstance instance = new HumanObjectPeerTestInstance(); - instance.do_test_message_handler(true); - while (!instance.gc_ran) { - System.gc(); - System.runFinalization(); - } - for (WeakReference o : instance.must_free_objs) - assert o.get() == null; + do_test(true, false); + } + @Test + public void test_message_handler_nice_close_wrapper() throws InterruptedException { + do_test(true, true); + } + @Test + public void test_message_handler_force_close_wrapper() throws InterruptedException { + do_test(false, true); } } -- 2.39.5