assert var_is_arr_regex.match(fn_arg[8:])
rust_obj = "LDKSecretKey"
arr_access = "bytes"
- #if fn_arg.startswith("LDKSignature"):
- # fn_arg = "uint8_t (*" + fn_arg[13:] + ")[64]"
- # assert var_is_arr_regex.match(fn_arg[8:])
- # rust_obj = "LDKSignature"
+ if fn_arg.startswith("LDKSignature"):
+ fn_arg = "uint8_t (*" + fn_arg[13:] + ")[64]"
+ assert var_is_arr_regex.match(fn_arg[8:])
+ rust_obj = "LDKSignature"
+ arr_access = "compact_form"
+ if fn_arg.startswith("LDKThreeBytes"):
+ fn_arg = "uint8_t (*" + fn_arg[14:] + ")[3]"
+ assert var_is_arr_regex.match(fn_arg[8:])
+ rust_obj = "LDKThreeBytes"
+ arr_access = "data"
if fn_arg.startswith("void"):
java_ty = "void"
arr_len = ret_arr_len
assert(ty_info.c_ty == "jbyteArray")
if ty_info.rust_obj is not None:
- arg_conv = ty_info.rust_obj + " " + arr_name + "_ref;\n" + "(*_env)->GetByteArrayRegion (_env, " + arr_name + ", 0, " + arr_len + ", " + arr_name + "_ref." + ty_info.arr_access + ");"
+ arg_conv = ty_info.rust_obj + " " + arr_name + "_ref;\n"
+ arg_conv = arg_conv + "CHECK((*_env)->GetArrayLength (_env, " + arr_name + ") == " + arr_len + ");\n"
+ arg_conv = arg_conv + "(*_env)->GetByteArrayRegion (_env, " + arr_name + ", 0, " + arr_len + ", " + arr_name + "_ref." + ty_info.arr_access + ");"
arr_access = ("", "." + ty_info.arr_access)
else:
- arg_conv = "unsigned char " + arr_name + "_arr[" + arr_len + "];\n" + "(*_env)->GetByteArrayRegion (_env, " + arr_name + ", 0, " + arr_len + ", " + arr_name + "_arr);\n" + "unsigned char (*" + arr_name + "_ref)[" + arr_len + "] = &" + arr_name + "_arr;"
+ arg_conv = "unsigned char " + arr_name + "_arr[" + arr_len + "];\n"
+ arg_conv = arg_conv + "CHECK((*_env)->GetArrayLength (_env, " + arr_name + ") == " + arr_len + ");\n"
+ arg_conv = arg_conv + "(*_env)->GetByteArrayRegion (_env, " + arr_name + ", 0, " + arr_len + ", " + arr_name + "_arr);\n" + "unsigned char (*" + arr_name + "_ref)[" + arr_len + "] = &" + arr_name + "_arr;"
arr_access = ("*", "")
return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
arg_conv = arg_conv,
opaque_arg_conv = ty_info.rust_obj + " " + ty_info.var_name + "_conv;\n"
opaque_arg_conv = opaque_arg_conv + ty_info.var_name + "_conv.inner = (void*)(" + ty_info.var_name + " & (~1));\n"
opaque_arg_conv = opaque_arg_conv + ty_info.var_name + "_conv.is_owned = (" + ty_info.var_name + " & 1) || (" + ty_info.var_name + " == 0);"
- if (ty_info.rust_obj.replace("LDK", "") + "_clone") in clone_fns and not ty_info.is_ptr and not is_free:
- # TODO: This is a bit too naive, even with the checks above, we really need to know if rust wants a ref or not, not just if its pass as a ptr.
- opaque_arg_conv = opaque_arg_conv + "\nif (" + ty_info.var_name + "_conv.inner != NULL)\n"
- opaque_arg_conv = opaque_arg_conv + "\t" + ty_info.var_name + "_conv = " + ty_info.rust_obj.replace("LDK", "") + "_clone(&" + ty_info.var_name + "_conv);"
+ if not ty_info.is_ptr and not is_free:
+ if (ty_info.rust_obj.replace("LDK", "") + "_clone") in clone_fns:
+ # TODO: This is a bit too naive, even with the checks above, we really need to know if rust wants a ref or not, not just if its pass as a ptr.
+ opaque_arg_conv = opaque_arg_conv + "\nif (" + ty_info.var_name + "_conv.inner != NULL)\n"
+ opaque_arg_conv = opaque_arg_conv + "\t" + ty_info.var_name + "_conv = " + ty_info.rust_obj.replace("LDK", "") + "_clone(&" + ty_info.var_name + "_conv);"
+ elif ty_info.passed_as_ptr:
+ opaque_arg_conv = opaque_arg_conv + "\n// Warning: we may need a move here but can't clone!"
if not ty_info.is_ptr:
if ty_info.rust_obj in unitary_enums:
return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
ret_conv = ("jclass " + ty_info.var_name + "_conv = " + ty_info.rust_obj + "_to_java(_env, ", ");"),
ret_conv_name = ty_info.var_name + "_conv")
if ty_info.rust_obj in opaque_structs:
- ret_conv_suf = ";\nDO_ASSERT((((long)" + ty_info.var_name + "_var.inner) & 1) == 0); // We rely on a free low bit, malloc guarantees this.\n"
- ret_conv_suf = ret_conv_suf + "DO_ASSERT((((long)&" + ty_info.var_name + "_var) & 1) == 0); // We rely on a free low bit, pointer alignment guarantees this.\n"
+ ret_conv_suf = ";\nCHECK((((long)" + ty_info.var_name + "_var.inner) & 1) == 0); // We rely on a free low bit, malloc guarantees this.\n"
+ ret_conv_suf = ret_conv_suf + "CHECK((((long)&" + ty_info.var_name + "_var) & 1) == 0); // We rely on a free low bit, pointer alignment guarantees this.\n"
ret_conv_suf = ret_conv_suf + "long " + ty_info.var_name + "_ref;\n"
ret_conv_suf = ret_conv_suf + "if (" + ty_info.var_name + "_var.is_owned) {\n"
ret_conv_suf = ret_conv_suf + "\t" + ty_info.var_name + "_ref = (long)" + ty_info.var_name + "_var.inner | 1;\n"
out_c.write("static jfieldID " + struct_name + "_" + variant + " = NULL;\n")
out_c.write("JNIEXPORT void JNICALL Java_org_ldk_enums_" + struct_name.replace("_", "_1") + "_init (JNIEnv * env, jclass clz) {\n")
out_c.write("\t" + struct_name + "_class = (*env)->NewGlobalRef(env, clz);\n")
- out_c.write("\tDO_ASSERT(" + struct_name + "_class != NULL);\n")
+ out_c.write("\tCHECK(" + struct_name + "_class != NULL);\n")
for idx, struct_line in enumerate(field_lines):
if idx > 0 and idx < len(field_lines) - 3:
variant = struct_line.strip().strip(",")
out_c.write("\t" + struct_name + "_" + variant + " = (*env)->GetStaticFieldID(env, " + struct_name + "_class, \"" + variant + "\", \"Lorg/ldk/enums/" + struct_name + ";\");\n")
- out_c.write("\tDO_ASSERT(" + struct_name + "_" + variant + " != NULL);\n")
+ out_c.write("\tCHECK(" + struct_name + "_" + variant + " != NULL);\n")
out_c.write("}\n")
out_c.write("static inline jclass " + struct_name + "_to_java(JNIEnv *env, " + struct_name + " val) {\n")
out_c.write("\tswitch (val) {\n")
var_name = struct_line.strip(' ,')[len(struct_name) + 1:]
out_c.write("\t" + struct_name + "_" + var_name + "_class =\n")
out_c.write("\t\t(*env)->NewGlobalRef(env, (*env)->FindClass(env, \"Lorg/ldk/impl/bindings$" + struct_name + "$" + var_name + ";\"));\n")
- out_c.write("\tDO_ASSERT(" + struct_name + "_" + var_name + "_class != NULL);\n")
+ out_c.write("\tCHECK(" + struct_name + "_" + var_name + "_class != NULL);\n")
out_c.write("\t" + struct_name + "_" + var_name + "_meth = (*env)->GetMethodID(env, " + struct_name + "_" + var_name + "_class, \"<init>\", \"(" + init_meth_jty_strs[var_name] + ")V\");\n")
- out_c.write("\tDO_ASSERT(" + struct_name + "_" + var_name + "_meth != NULL);\n")
+ out_c.write("\tCHECK(" + struct_name + "_" + var_name + "_meth != NULL);\n")
out_c.write("}\n")
out_c.write("JNIEXPORT jobject JNICALL Java_org_ldk_impl_bindings_" + struct_name.replace("_", "_1") + "_1ref_1from_1ptr (JNIEnv * env, jclass _c, jlong ptr) {\n")
out_c.write("\t" + struct_name + " *obj = (" + struct_name + "*)ptr;\n")
out_c.write(arg_info.arg_name)
out_c.write(arg_info.ret_conv[1].replace('\n', '\n\t').replace("_env", "env") + "\n")
- out_c.write("\tjobject obj = (*env)->NewLocalRef(env, j_calls->o);\n\tDO_ASSERT(obj != NULL);\n")
+ out_c.write("\tjobject obj = (*env)->NewLocalRef(env, j_calls->o);\n\tCHECK(obj != NULL);\n")
if ret_ty_info.c_ty.endswith("Array"):
assert(ret_ty_info.c_ty == "jbyteArray")
out_c.write("\tjbyteArray jret = (*env)->CallObjectMethod(env, obj, j_calls->" + fn_line.group(2) + "_meth")
out_c.write(");\n");
if ret_ty_info.c_ty.endswith("Array"):
out_c.write("\t" + ret_ty_info.rust_obj + " ret;\n")
+ out_c.write("\tCHECK((*env)->GetArrayLength(env, jret) == " + ret_ty_info.arr_len + ");\n")
out_c.write("\t(*env)->GetByteArrayRegion(env, jret, 0, " + ret_ty_info.arr_len + ", ret." + ret_ty_info.arr_access + ");\n")
out_c.write("\treturn ret;\n")
out_c.write(") {\n")
out_c.write("\tjclass c = (*env)->GetObjectClass(env, o);\n")
- out_c.write("\tDO_ASSERT(c != NULL);\n")
+ out_c.write("\tCHECK(c != NULL);\n")
out_c.write("\t" + struct_name + "_JCalls *calls = MALLOC(sizeof(" + struct_name + "_JCalls), \"" + struct_name + "_JCalls\");\n")
out_c.write("\tatomic_init(&calls->refcnt, 1);\n")
out_c.write("\tDO_ASSERT((*env)->GetJavaVM(env, &calls->vm) == 0);\n")
for (fn_line, java_meth_descr) in zip(trait_fn_lines, java_meths):
if fn_line.group(2) != "free" and fn_line.group(2) != "clone":
out_c.write("\tcalls->" + fn_line.group(2) + "_meth = (*env)->GetMethodID(env, c, \"" + fn_line.group(2) + "\", \"" + java_meth_descr + "\");\n")
- out_c.write("\tDO_ASSERT(calls->" + fn_line.group(2) + "_meth != NULL);\n")
+ out_c.write("\tCHECK(calls->" + fn_line.group(2) + "_meth != NULL);\n")
out_c.write("\n\t" + struct_name + " ret = {\n")
out_c.write("\t\t.this_arg = (void*) calls,\n")
for fn_line in trait_fn_lines:
out_java.write("\tpublic static native " + struct_name + " " + struct_name + "_get_obj_from_jcalls(long val);\n")
out_c.write("JNIEXPORT jobject JNICALL Java_org_ldk_impl_bindings_" + struct_name.replace("_", "_1") + "_1get_1obj_1from_1jcalls (JNIEnv * env, jclass _a, jlong val) {\n")
out_c.write("\tjobject ret = (*env)->NewLocalRef(env, ((" + struct_name + "_JCalls*)val)->o);\n")
- out_c.write("\tDO_ASSERT(ret != NULL);\n")
+ out_c.write("\tCHECK(ret != NULL);\n")
out_c.write("\treturn ret;\n")
out_c.write("}\n")
out_c.write("#define MALLOC(a, _) malloc(a)\n")
out_c.write("#define FREE free\n")
out_c.write("#define DO_ASSERT(a) (void)(a)\n")
+ out_c.write("#define CHECK(a)\n")
else:
out_c.write("""#include <assert.h>
+// Always run a, then assert it is true:
#define DO_ASSERT(a) do { bool _assert_val = (a); assert(_assert_val); } while(0)
+// Assert a is true or do nothing
+#define CHECK(a) DO_ASSERT(a)
// Running a leak check across all the allocations and frees of the JDK is a mess,
// so instead we implement our own naive leak checker here, relying on the -wrap
static jclass slicedef_cls = NULL;
JNIEXPORT void Java_org_ldk_impl_bindings_init(JNIEnv * env, jclass _b, jclass enum_class, jclass slicedef_class) {
ordinal_meth = (*env)->GetMethodID(env, enum_class, "ordinal", "()I");
- DO_ASSERT(ordinal_meth != NULL);
+ CHECK(ordinal_meth != NULL);
slicedef_meth = (*env)->GetMethodID(env, slicedef_class, "<init>", "(JJJ)V");
- DO_ASSERT(slicedef_meth != NULL);
+ CHECK(slicedef_meth != NULL);
slicedef_cls = (*env)->NewGlobalRef(env, slicedef_class);
- DO_ASSERT(slicedef_cls != NULL);
+ CHECK(slicedef_cls != NULL);
}
JNIEXPORT jboolean JNICALL Java_org_ldk_impl_bindings_deref_1bool (JNIEnv * env, jclass _a, jlong ptr) {
out_c.write("\tjlongArray ret = (*env)->NewLongArray(env, vec->datalen);\n")
out_c.write("\tjlong *ret_elems = (*env)->GetPrimitiveArrayCritical(env, ret, NULL);\n")
out_c.write("\tfor (size_t i = 0; i < vec->datalen; i++) {\n")
- out_c.write("\t\tDO_ASSERT((((long)vec->data[i].inner) & 1) == 0);\n")
+ out_c.write("\t\tCHECK((((long)vec->data[i].inner) & 1) == 0);\n")
out_c.write("\t\tret_elems[i] = (long)vec->data[i].inner | (vec->data[i].is_owned ? 1 : 0);\n")
out_c.write("\t}\n")
out_c.write("\t(*env)->ReleasePrimitiveArrayCritical(env, ret, ret_elems, 0);\n")