return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
arg_conv = base_conv,
arg_conv_name = ty_info.var_name + "_conv",
- ret_conv = None, ret_conv_name = None)
+ ret_conv = ("CANT PASS TRAIT TO Java?", ""), ret_conv_name = "NO CONV POSSIBLE")
if ty_info.rust_obj != "LDKu8slice":
# Don't bother free'ing slices passed in - we often pass them Rust -> Rust
base_conv = base_conv + "\nFREE((void*)" + ty_info.var_name + ");";
return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
arg_conv = base_conv + "\n" + ty_info.var_name + "_conv.is_owned = true;",
arg_conv_name = ty_info.var_name + "_conv",
- ret_conv = None, ret_conv_name = None)
+ ret_conv = ("long " + ty_info.var_name + "_ref = (long)&", ";"), ret_conv_name = ty_info.var_name + "_ref")
return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
arg_conv = base_conv, arg_conv_name = ty_info.var_name + "_conv",
- ret_conv = None, ret_conv_name = None)
+ ret_conv = ("long " + ty_info.var_name + "_ref = (long)&", ";"), ret_conv_name = ty_info.var_name + "_ref")
else:
assert(not is_free)
return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
arg_conv = ty_info.rust_obj + "* " + ty_info.var_name + "_conv = (" + ty_info.rust_obj + "*)" + ty_info.var_name + ";",
arg_conv_name = ty_info.var_name + "_conv",
- ret_conv = None, ret_conv_name = None)
+ ret_conv = None, ret_conv_name = None) # its a pointer, no conv needed
elif ty_info.is_ptr:
return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
arg_conv = None, arg_conv_name = ty_info.var_name, ret_conv = None, ret_conv_name = None)
# any _free function.
# To avoid any issues, we first assert that the incoming object is non-ref.
return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
- ret_conv = (ty_info.rust_obj + "* ret = MALLOC(sizeof(" + ty_info.rust_obj + "), \"" + ty_info.rust_obj + "\");\n*ret = ", ";\nassert(ret->is_owned);\nret->is_owned = false;"),
+ ret_conv = (ty_info.rust_obj + "* ret = MALLOC(sizeof(" + ty_info.rust_obj + "), \"" + ty_info.rust_obj + "\");\n*ret = ", ";\nDO_ASSERT(ret->is_owned);\nret->is_owned = false;"),
ret_conv_name = "(long)ret",
arg_conv = None, arg_conv_name = None)
else:
def map_trait(struct_name, field_var_lines, trait_fn_lines):
out_c.write("typedef struct " + struct_name + "_JCalls {\n")
out_c.write("\tatomic_size_t refcnt;\n")
- out_c.write("\tJNIEnv *env;\n")
+ out_c.write("\tJavaVM *vm;\n")
out_c.write("\tjobject o;\n")
for var_line in field_var_lines:
if var_line.group(1) in trait_structs:
out_java.write(");\n")
out_c.write(") {\n")
out_c.write("\t" + struct_name + "_JCalls *j_calls = (" + struct_name + "_JCalls*) this_arg;\n")
+ out_c.write("\tJNIEnv *env;\n")
+ out_c.write("\tDO_ASSERT((*j_calls->vm)->GetEnv(j_calls->vm, (void**)&env, JNI_VERSION_1_8) == JNI_OK);\n")
for arg_info in arg_names:
if arg_info.ret_conv is not None:
- out_c.write("\t" + arg_info.ret_conv[0].replace('\n', '\n\t').replace("_env", "j_calls->env"));
+ out_c.write("\t" + arg_info.ret_conv[0].replace('\n', '\n\t').replace("_env", "env"));
out_c.write(arg_info.arg_name)
- out_c.write(arg_info.ret_conv[1].replace('\n', '\n\t').replace("_env", "j_calls->env") + "\n")
+ out_c.write(arg_info.ret_conv[1].replace('\n', '\n\t').replace("_env", "env") + "\n")
if not ret_ty_info.passed_as_ptr:
- out_c.write("\treturn (*j_calls->env)->Call" + ret_ty_info.java_ty.title() + "Method(j_calls->env, j_calls->o, j_calls->" + fn_line.group(2) + "_meth")
+ out_c.write("\treturn (*env)->Call" + ret_ty_info.java_ty.title() + "Method(env, j_calls->o, j_calls->" + fn_line.group(2) + "_meth")
else:
- out_c.write("\t" + fn_line.group(1).strip() + "* ret = (" + fn_line.group(1).strip() + "*)(*j_calls->env)->CallLongMethod(j_calls->env, j_calls->o, j_calls->" + fn_line.group(2) + "_meth");
+ out_c.write("\t" + fn_line.group(1).strip() + "* ret = (" + fn_line.group(1).strip() + "*)(*env)->CallLongMethod(env, j_calls->o, j_calls->" + fn_line.group(2) + "_meth");
for arg_info in arg_names:
if arg_info.ret_conv is not None:
out_c.write("static void " + struct_name + "_JCalls_free(void* this_arg) {\n")
out_c.write("\t" + struct_name + "_JCalls *j_calls = (" + struct_name + "_JCalls*) this_arg;\n")
out_c.write("\tif (atomic_fetch_sub_explicit(&j_calls->refcnt, 1, memory_order_acquire) == 1) {\n")
- out_c.write("\t\t(*j_calls->env)->DeleteGlobalRef(j_calls->env, j_calls->o);\n")
+ out_c.write("\t\tJNIEnv *env;\n")
+ out_c.write("\t\tDO_ASSERT((*j_calls->vm)->GetEnv(j_calls->vm, (void**)&env, JNI_VERSION_1_8) == JNI_OK);\n")
+ out_c.write("\t\t(*env)->DeleteGlobalRef(env, j_calls->o);\n")
out_c.write("\t\tFREE(j_calls);\n")
out_c.write("\t}\n}\n")
out_c.write(") {\n")
out_c.write("\tjclass c = (*env)->GetObjectClass(env, o);\n")
- out_c.write("\tassert(c != NULL);\n")
+ out_c.write("\tDO_ASSERT(c != NULL);\n")
out_c.write("\t" + struct_name + "_JCalls *calls = MALLOC(sizeof(" + struct_name + "_JCalls), \"" + struct_name + "_JCalls\");\n")
out_c.write("\tatomic_init(&calls->refcnt, 1);\n")
- out_c.write("\tcalls->env = env;\n")
+ out_c.write("\tDO_ASSERT((*env)->GetJavaVM(env, &calls->vm) == 0);\n")
out_c.write("\tcalls->o = (*env)->NewGlobalRef(env, o);\n")
for (fn_line, java_meth_descr) in zip(trait_fn_lines, java_meths):
if fn_line.group(2) != "free" and fn_line.group(2) != "clone":
out_c.write("\tcalls->" + fn_line.group(2) + "_meth = (*env)->GetMethodID(env, c, \"" + fn_line.group(2) + "\", \"" + java_meth_descr + "\");\n")
- out_c.write("\tassert(calls->" + fn_line.group(2) + "_meth != NULL);\n")
+ out_c.write("\tDO_ASSERT(calls->" + fn_line.group(2) + "_meth != NULL);\n")
out_c.write("\n\t" + struct_name + " ret = {\n")
out_c.write("\t\t.this_arg = (void*) calls,\n")
for fn_line in trait_fn_lines:
out_c.write("""#include \"org_ldk_impl_bindings.h\"
#include <rust_types.h>
#include <lightning.h>
-#include <assert.h>
#include <string.h>
#include <stdatomic.h>
""")
if sys.argv[4] == "false":
out_c.write("#define MALLOC(a, _) malloc(a)\n")
out_c.write("#define FREE free\n")
+ out_c.write("#define DO_ASSERT(a) (void)(a)\n")
else:
- out_c.write("""
+ out_c.write("""#include <assert.h>
+#define DO_ASSERT(a) do { bool _assert_val = (a); assert(_assert_val); } while(0)
+
#include <threads.h>
static mtx_t allocation_mtx;
void __attribute__((constructor)) init_mtx() {
- assert(mtx_init(&allocation_mtx, mtx_plain) == thrd_success);
+ DO_ASSERT(mtx_init(&allocation_mtx, mtx_plain) == thrd_success);
}
typedef struct allocation {
allocation* new_alloc = malloc(sizeof(allocation));
new_alloc->ptr = res;
new_alloc->struct_name = struct_name;
- assert(mtx_lock(&allocation_mtx) == thrd_success);
+ DO_ASSERT(mtx_lock(&allocation_mtx) == thrd_success);
new_alloc->next = allocation_ll;
allocation_ll = new_alloc;
- assert(mtx_unlock(&allocation_mtx) == thrd_success);
+ DO_ASSERT(mtx_unlock(&allocation_mtx) == thrd_success);
return res;
}
void FREE(void* ptr) {
allocation* p = NULL;
- assert(mtx_lock(&allocation_mtx) == thrd_success);
+ DO_ASSERT(mtx_lock(&allocation_mtx) == thrd_success);
allocation* it = allocation_ll;
while (it->ptr != ptr) { p = it; it = it->next; }
if (p) { p->next = it->next; } else { allocation_ll = it->next; }
- assert(mtx_unlock(&allocation_mtx) == thrd_success);
- assert(it->ptr == ptr);
+ DO_ASSERT(mtx_unlock(&allocation_mtx) == thrd_success);
+ DO_ASSERT(it->ptr == ptr);
free(it);
free(ptr);
}
void __attribute__((destructor)) check_leaks() {
for (allocation* a = allocation_ll; a != NULL; a = a->next) { fprintf(stderr, "%s %p remains\\n", a->struct_name, a->ptr); }
- assert(allocation_ll == NULL);
+ DO_ASSERT(allocation_ll == NULL);
}
""")
public static native boolean deref_bool(long ptr);
public static native long deref_long(long ptr);
public static native void free_heap_ptr(long ptr);
- public static native long u8_vec_len(long vec);
+ public static native byte[] get_u8_slice_bytes(long slice_ptr);
+ public static native long bytes_to_u8_vec(byte[] bytes);
+ public static native long vec_slice_len(long vec);
+ public static native long new_empty_slice_vec();
""")
out_c.write("""
jmethodID ordinal_meth = NULL;
JNIEXPORT void Java_org_ldk_impl_bindings_init(JNIEnv * env, jclass _b, jclass enum_class) {
ordinal_meth = (*env)->GetMethodID(env, enum_class, "ordinal", "()I");
- assert(ordinal_meth != NULL);
+ DO_ASSERT(ordinal_meth != NULL);
}
JNIEXPORT jboolean JNICALL Java_org_ldk_impl_bindings_deref_1bool (JNIEnv * env, jclass _a, jlong ptr) {
JNIEXPORT void JNICALL Java_org_ldk_impl_bindings_free_1heap_1ptr (JNIEnv * env, jclass _a, jlong ptr) {
FREE((void*)ptr);
}
-JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_u8_1vec_1len (JNIEnv * env, jclass _a, jlong ptr) {
+JNIEXPORT jbyteArray JNICALL Java_org_ldk_impl_bindings_get_1u8_1slice_1bytes (JNIEnv * _env, jclass _b, jlong slice_ptr) {
+ LDKu8slice *slice = (LDKu8slice*)slice_ptr;
+ jbyteArray ret_arr = (*_env)->NewByteArray(_env, slice->datalen);
+ (*_env)->SetByteArrayRegion(_env, ret_arr, 0, slice->datalen, slice->data);
+ return ret_arr;
+}
+JNIEXPORT long JNICALL Java_org_ldk_impl_bindings_bytes_1to_1u8_1vec (JNIEnv * _env, jclass _b, jbyteArray bytes) {
+ LDKCVec_u8Z *vec = (LDKCVec_u8Z*)MALLOC(sizeof(LDKCVec_u8Z), "LDKCVec_u8");
+ vec->datalen = (*_env)->GetArrayLength(_env, bytes);
+ vec->data = (uint8_t*)malloc(vec->datalen); // May be freed by rust, so don't track allocation
+ (*_env)->GetByteArrayRegion (_env, bytes, 0, vec->datalen, vec->data);
+ return (long)vec;
+}
+JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_vec_1slice_1len (JNIEnv * env, jclass _a, jlong ptr) {
+ // Check offsets of a few Vec types are all consistent as we're meant to be generic across types
+ _Static_assert(offsetof(LDKCVec_u8Z, datalen) == offsetof(LDKCVec_SignatureZ, datalen), "Vec<*> needs to be mapped identically");
+ _Static_assert(offsetof(LDKCVec_u8Z, datalen) == offsetof(LDKCVec_MessageSendEventZ, datalen), "Vec<*> needs to be mapped identically");
+ _Static_assert(offsetof(LDKCVec_u8Z, datalen) == offsetof(LDKCVec_EventZ, datalen), "Vec<*> needs to be mapped identically");
+ _Static_assert(offsetof(LDKCVec_u8Z, datalen) == offsetof(LDKCVec_C2Tuple_usizeTransactionZZ, datalen), "Vec<*> needs to be mapped identically");
LDKCVec_u8Z *vec = (LDKCVec_u8Z*)ptr;
return (long)vec->datalen;
}
+JNIEXPORT long JNICALL Java_org_ldk_impl_bindings_new_1empty_1slice_1vec (JNIEnv * _env, jclass _b) {
+ // Check sizes of a few Vec types are all consistent as we're meant to be generic across types
+ _Static_assert(sizeof(LDKCVec_u8Z) == sizeof(LDKCVec_SignatureZ), "Vec<*> needs to be mapped identically");
+ _Static_assert(sizeof(LDKCVec_u8Z) == sizeof(LDKCVec_MessageSendEventZ), "Vec<*> needs to be mapped identically");
+ _Static_assert(sizeof(LDKCVec_u8Z) == sizeof(LDKCVec_EventZ), "Vec<*> needs to be mapped identically");
+ _Static_assert(sizeof(LDKCVec_u8Z) == sizeof(LDKCVec_C2Tuple_usizeTransactionZZ), "Vec<*> needs to be mapped identically");
+ LDKCVec_u8Z *vec = (LDKCVec_u8Z*)MALLOC(sizeof(LDKCVec_u8Z), "Empty LDKCVec");
+ vec->data = NULL;
+ vec->datalen = 0;
+ return (long)vec;
+}
// We assume that CVec_u8Z and u8slice are the same size and layout (and thus pointers to the two can be mixed)
_Static_assert(sizeof(LDKCVec_u8Z) == sizeof(LDKu8slice), "Vec<u8> and [u8] need to have been mapped identically");
assert(not is_union or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_union_enum or is_opaque))
if is_opaque:
opaque_structs.add(struct_name)
+ out_java.write("\tpublic static native long " + struct_name + "_optional_none();\n")
+ out_c.write("JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_" + struct_name.replace("_", "_1") + "_1optional_1none (JNIEnv * env, jclass _a) {\n")
+ out_c.write("\t" + struct_name + " *ret = MALLOC(sizeof(" + struct_name + "), \"" + struct_name + "\");\n")
+ out_c.write("\tret->inner = NULL;\n")
+ out_c.write("\treturn (long)ret;\n")
+ out_c.write("}\n")
elif is_result:
result_templ_structs.add(struct_name)
elif is_unitary_enum:
out_c.write("\t\tcase %d: return %s;\n" % (ord_v, struct_line.strip().strip(",")))
ord_v = ord_v + 1
out_c.write("\t}\n")
- out_c.write("\tassert(false);\n")
+ out_c.write("\tabort();\n")
out_c.write("}\n")
ord_v = 0
out_c.write("static inline jclass " + struct_name + "_to_java(JNIEnv *env, " + struct_name + " val) {\n")
out_c.write("\t// TODO: This is pretty inefficient, we really need to cache the field IDs and class\n")
out_c.write("\tjclass enum_class = (*env)->FindClass(env, \"Lorg/ldk/impl/bindings$" + struct_name + ";\");\n")
- out_c.write("\tassert(enum_class != NULL);\n")
+ out_c.write("\tDO_ASSERT(enum_class != NULL);\n")
out_c.write("\tswitch (val) {\n")
for idx, struct_line in enumerate(field_lines):
if idx > 0 and idx < len(field_lines) - 3:
variant = struct_line.strip().strip(",")
out_c.write("\t\tcase " + variant + ": {\n")
out_c.write("\t\t\tjfieldID field = (*env)->GetStaticFieldID(env, enum_class, \"" + variant + "\", \"Lorg/ldk/impl/bindings$" + struct_name + ";\");\n")
- out_c.write("\t\t\tassert(field != NULL);\n")
+ out_c.write("\t\t\tDO_ASSERT(field != NULL);\n")
out_c.write("\t\t\treturn (*env)->GetStaticObjectField(env, enum_class, field);\n")
out_c.write("\t\t}\n")
ord_v = ord_v + 1
- out_c.write("\t\tdefault: assert(false);\n")
+ out_c.write("\t\tdefault: abort();\n")
out_c.write("\t}\n")
out_c.write("}\n\n")
elif len(trait_fn_lines) > 0: