]> git.bitcoin.ninja Git - ldk-java/commitdiff
override CommonBase file, and make private constructor work for mapped traits
authorArik Sosman <git@arik.io>
Wed, 13 Jan 2021 10:43:04 +0000 (02:43 -0800)
committerArik Sosman <git@arik.io>
Wed, 13 Jan 2021 10:43:04 +0000 (02:43 -0800)
genbindings.py
typescript_strings.py

index d0e080a8b3557efe369dbb396347430d351a9aa1..ecdb6bd7a59aa479cd0b697047ffbe9ca68d30d2 100755 (executable)
@@ -924,7 +924,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
             write_c(c_out)
             out_java_enum.write(native_file_out)
             out_java.write(native_out)
+
     def map_complex_enum(struct_name, union_enum_items):
         java_hu_type = struct_name.replace("LDK", "")
         complex_enums.add(struct_name)
@@ -1232,7 +1232,7 @@ public class bindings {
 
 """)
 
-    with open(f"{sys.argv[3]}/structs/CommonBase{consts.file_ext}", "a") as out_java_struct:
+    with open(f"{sys.argv[3]}/structs/CommonBase{consts.file_ext}", "w") as out_java_struct:
         out_java_struct.write(consts.common_base)
 
     in_block_comment = False
index edd1cfc795092f70a1cfa63ccd3b794b81e8426f..3c44daea0e8d996bf308069ed86bbf75629f0f24 100644 (file)
@@ -1,11 +1,22 @@
+from bindingstypes import ConvInfo
+
+
+def first_to_lower(string: str) -> str:
+    first = string[0]
+    return first.lower() + string[1:]
+
+
 class Consts:
     def __init__(self, DEBUG):
         self.common_base = """
             export default class CommonBase {
                 ptr: number;
-                ptrs_to: object[] = new Array(); // new LinkedList(); TODO: build linked list implementation
+                ptrs_to: object[] = []; // new LinkedList(); TODO: build linked list implementation
                 protected constructor(ptr: number) { this.ptr = ptr; }
                 public _test_only_get_ptr(): number { return this.ptr; }
+                protected finalize() {
+                    // TODO: finalize myself
+                }
             }
         """
 
@@ -237,8 +248,306 @@ import * as bindings from '../bindings' // TODO: figure out location
             ret = ret + "; (void) " + param
         return ret
 
-    def native_c_map_trait(self, struct_name, field_var_convs, field_fn_lines):
-        return ("", "", "")
+    def native_c_map_trait(self, struct_name, field_var_conversions, field_function_lines):
+        out_java = "out_java:native_c_map_trait"
+        out_java_trait = "out_java_trait:native_c_map_trait"
+        out_c = "out_c:native_c_map_trait"
+
+
+        out_java_trait = ""
+        out_java = ""
+
+        constructor_arguments = ""
+        super_instantiator = ""
+        pointer_to_adder = ""
+        for var in field_var_conversions:
+            if isinstance(var, ConvInfo):
+                constructor_arguments += f", {first_to_lower(var.arg_name)}?: {var.java_hu_ty}"
+                if var.from_hu_conv is not None:
+                    super_instantiator += ", " + var.from_hu_conv[0]
+                    if var.from_hu_conv[1] != "":
+                        pointer_to_adder += var.from_hu_conv[1] + ";\n"
+                else:
+                    super_instantiator += ", " + first_to_lower(var.arg_name)
+            else:
+                constructor_arguments += f", {first_to_lower(var[1])}?: bindings.{var[0]}"
+                super_instantiator += ", " + first_to_lower(var[1])
+                pointer_to_adder += "this.ptrs_to.push(" + first_to_lower(var[1]) + ");\n"
+
+        out_java_trait = f"""
+            {self.hu_struct_file_prefix}
+            
+            export class {struct_name.replace("LDK","")} extends CommonBase {{
+            
+                bindings_instance?: bindings.{struct_name};
+                
+                constructor(ptr?: number, arg?: bindings.{struct_name}{constructor_arguments}) {{
+                    if (Number.isFinite(ptr)) {{
+                                       super(ptr);
+                                       this.bindings_instance = null;
+                                   }} else {{
+                                       // TODO: private constructor instantiation
+                                       super(bindings.{struct_name}_new(arg{super_instantiator}));
+                                       this.ptrs_to.push(arg);
+                                       {pointer_to_adder}
+                                   }}
+                }}
+                
+                protected finalize() {{
+                    if (this.ptr != 0) {{ 
+                        bindings.{struct_name.replace("LDK","")}_free(this.ptr); 
+                    }} 
+                    super.finalize();
+                }}
+                
+            }}
+        """
+
+
+
+        java_trait_constr = "\tprivate static class " + struct_name + "Holder { " + struct_name.replace("LDK", "") + " held; }\n"
+        java_trait_constr = java_trait_constr + "\tpublic static " + struct_name.replace("LDK", "") + " new_impl(" + struct_name.replace("LDK", "") + "Interface arg"
+        for var in field_var_conversions:
+            if isinstance(var, ConvInfo):
+                java_trait_constr = java_trait_constr + ", " + var.java_hu_ty + " " + var.arg_name
+            else:
+                # Ideally we'd be able to take any instance of the interface, but our C code can only represent
+                # Java-implemented version, so we require users pass a Java implementation here :/
+                java_trait_constr = java_trait_constr + ", " + var[0].replace("LDK", "") + "." + var[0].replace("LDK", "") + "Interface " + var[1] + "_impl"
+        java_trait_constr = java_trait_constr + ") {\n\t\tfinal " + struct_name + "Holder impl_holder = new " + struct_name + "Holder();\n"
+        java_trait_constr = java_trait_constr + "\t\timpl_holder.held = new " + struct_name.replace("LDK", "") + "(new bindings." + struct_name + "() {\n"
+        out_java_trait = out_java_trait + "\tpublic static interface " + struct_name.replace("LDK", "") + "Interface {\n"
+        out_java = out_java + "\tpublic interface " + struct_name + " {\n"
+        java_meths = []
+        for fn_line in field_function_lines:
+            java_meth_descr = "("
+            if fn_line.fn_name != "free" and fn_line.fn_name != "clone":
+                out_java = out_java + "\t\t " + fn_line.ret_ty_info.java_ty + " " + fn_line.fn_name + "("
+                java_trait_constr = java_trait_constr + "\t\t\t@Override public " + fn_line.ret_ty_info.java_ty + " " + fn_line.fn_name + "("
+                out_java_trait = out_java_trait + "\t\t" + fn_line.ret_ty_info.java_hu_ty + " " + fn_line.fn_name + "("
+
+                for idx, arg_conv_info in enumerate(fn_line.args_ty):
+                    if idx >= 1:
+                        out_java = out_java + ", "
+                        java_trait_constr = java_trait_constr + ", "
+                        out_java_trait = out_java_trait + ", "
+                    out_java = out_java + arg_conv_info.java_ty + " " + arg_conv_info.arg_name
+                    out_java_trait = out_java_trait + arg_conv_info.java_hu_ty + " " + arg_conv_info.arg_name
+                    java_trait_constr = java_trait_constr + arg_conv_info.java_ty + " " + arg_conv_info.arg_name
+                    java_meth_descr = java_meth_descr + arg_conv_info.java_fn_ty_arg
+                java_meth_descr = java_meth_descr + ")" + fn_line.ret_ty_info.java_fn_ty_arg
+                java_meths.append((fn_line.fn_name, java_meth_descr))
+
+                out_java = out_java + ");\n"
+                out_java_trait = out_java_trait + ");\n"
+                java_trait_constr = java_trait_constr + ") {\n"
+
+                for arg_info in fn_line.args_ty:
+                    if arg_info.to_hu_conv is not None:
+                        java_trait_constr = java_trait_constr + "\t\t\t\t" + arg_info.to_hu_conv.replace("\n", "\n\t\t\t\t") + "\n"
+
+                if fn_line.ret_ty_info.java_ty != "void":
+                    java_trait_constr = java_trait_constr + "\t\t\t\t" + fn_line.ret_ty_info.java_hu_ty + " ret = arg." + fn_line.fn_name + "("
+                else:
+                    java_trait_constr = java_trait_constr + "\t\t\t\targ." + fn_line.fn_name + "("
+
+                for idx, arg_info in enumerate(fn_line.args_ty):
+                    if idx != 0:
+                        java_trait_constr = java_trait_constr + ", "
+                    if arg_info.to_hu_conv_name is not None:
+                        java_trait_constr = java_trait_constr + arg_info.to_hu_conv_name
+                    else:
+                        java_trait_constr = java_trait_constr + arg_info.arg_name
+
+                java_trait_constr = java_trait_constr + ");\n"
+                if fn_line.ret_ty_info.java_ty != "void":
+                    if fn_line.ret_ty_info.from_hu_conv is not None:
+                        java_trait_constr = java_trait_constr + "\t\t\t\t" + fn_line.ret_ty_info.java_ty + " result = " + fn_line.ret_ty_info.from_hu_conv[0] + ";\n"
+                        if fn_line.ret_ty_info.from_hu_conv[1] != "":
+                            java_trait_constr = java_trait_constr + "\t\t\t\t" + fn_line.ret_ty_info.from_hu_conv[1].replace("this", "impl_holder.held") + ";\n"
+                        #if fn_line.ret_ty_info.rust_obj in result_types:
+                        # XXX: We need to handle this in conversion logic so that its cross-language!
+                        # Avoid double-free by breaking the result - we should learn to clone these and then we can be safe instead
+                        #    java_trait_constr = java_trait_constr + "\t\t\t\tret.ptr = 0;\n"
+                        java_trait_constr = java_trait_constr + "\t\t\t\treturn result;\n"
+                    else:
+                        java_trait_constr = java_trait_constr + "\t\t\t\treturn ret;\n"
+                java_trait_constr = java_trait_constr + "\t\t\t}\n"
+        java_trait_constr = java_trait_constr + "\t\t}"
+        for var in field_var_conversions:
+            if isinstance(var, ConvInfo):
+                java_trait_constr = java_trait_constr + ", " + var.arg_name
+            else:
+                java_trait_constr = java_trait_constr + ", " + var[1] + ".new_impl(" + var[1] + "_impl).bindings_instance"
+        out_java_trait = out_java_trait + "\t}\n"
+        out_java_trait = out_java_trait + java_trait_constr + ");\n\t\treturn impl_holder.held;\n\t}\n"
+
+        out_java = out_java + "\t}\n"
+
+        out_java = out_java + "\tpublic static native long " + struct_name + "_new(" + struct_name + " impl"
+        for var in field_var_conversions:
+            if isinstance(var, ConvInfo):
+                out_java = out_java + ", " + var.java_ty + " " + var.arg_name
+            else:
+                out_java = out_java + ", " + var[0] + " " + var[1]
+        out_java = out_java + ");\n"
+        out_java = out_java + "\tpublic static native " + struct_name + " " + struct_name + "_get_obj_from_jcalls(long val);\n"
+
+        # Now that we've written out our java code (and created java_meths), generate C
+        out_c = "typedef struct " + struct_name + "_JCalls {\n"
+        out_c = out_c + "\tatomic_size_t refcnt;\n"
+        out_c = out_c + "\tJavaVM *vm;\n"
+        out_c = out_c + "\tjweak o;\n"
+        for var in field_var_conversions:
+            if isinstance(var, ConvInfo):
+                # We're a regular ol' field
+                pass
+            else:
+                # We're a supertrait
+                out_c = out_c + "\t" + var[0] + "_JCalls* " + var[1] + ";\n"
+        for fn in field_function_lines:
+            if fn.fn_name != "free" and fn.fn_name != "clone":
+                out_c = out_c + "\tjmethodID " + fn.fn_name + "_meth;\n"
+        out_c = out_c + "} " + struct_name + "_JCalls;\n"
+
+        for fn_line in field_function_lines:
+            if fn_line.fn_name == "free":
+                out_c = out_c + "static void " + struct_name + "_JCalls_free(void* this_arg) {\n"
+                out_c = out_c + "\t" + struct_name + "_JCalls *j_calls = (" + struct_name + "_JCalls*) this_arg;\n"
+                out_c = out_c + "\tif (atomic_fetch_sub_explicit(&j_calls->refcnt, 1, memory_order_acquire) == 1) {\n"
+                out_c = out_c + "\t\tJNIEnv *env;\n"
+                out_c = out_c + "\t\tDO_ASSERT((*j_calls->vm)->GetEnv(j_calls->vm, (void**)&env, JNI_VERSION_1_8) == JNI_OK);\n"
+                out_c = out_c + "\t\t(*env)->DeleteWeakGlobalRef(env, j_calls->o);\n"
+                out_c = out_c + "\t\tFREE(j_calls);\n"
+                out_c = out_c + "\t}\n}\n"
+
+        for idx, fn_line in enumerate(field_function_lines):
+            if fn_line.fn_name != "free" and fn_line.fn_name != "clone":
+                assert fn_line.ret_ty_info.ty_info.get_full_rust_ty()[1] == ""
+                out_c = out_c + fn_line.ret_ty_info.ty_info.get_full_rust_ty()[0] + " " + fn_line.fn_name + "_jcall("
+                if fn_line.self_is_const:
+                    out_c = out_c + "const void* this_arg"
+                else:
+                    out_c = out_c + "void* this_arg"
+
+                for idx, arg in enumerate(fn_line.args_ty):
+                    out_c = out_c + ", " + arg.ty_info.get_full_rust_ty()[0] + " " + arg.arg_name + arg.ty_info.get_full_rust_ty()[1]
+
+                out_c = out_c + ") {\n"
+                out_c = out_c + "\t" + struct_name + "_JCalls *j_calls = (" + struct_name + "_JCalls*) this_arg;\n"
+                out_c = out_c + "\tJNIEnv *env;\n"
+                out_c = out_c + "\tDO_ASSERT((*j_calls->vm)->GetEnv(j_calls->vm, (void**)&env, JNI_VERSION_1_8) == JNI_OK);\n"
+
+                for arg_info in fn_line.args_ty:
+                    if arg_info.ret_conv is not None:
+                        out_c = out_c + "\t" + arg_info.ret_conv[0].replace('\n', '\n\t')
+                        out_c = out_c + arg_info.arg_name
+                        out_c = out_c + arg_info.ret_conv[1].replace('\n', '\n\t') + "\n"
+
+                out_c = out_c + "\tjobject obj = (*env)->NewLocalRef(env, j_calls->o);\n\tCHECK(obj != NULL);\n"
+                if fn_line.ret_ty_info.c_ty.endswith("Array"):
+                    out_c = out_c + "\t" + fn_line.ret_ty_info.c_ty + " arg = (*env)->CallObjectMethod(env, obj, j_calls->" + fn_line.fn_name + "_meth"
+                elif not fn_line.ret_ty_info.passed_as_ptr:
+                    out_c = out_c + "\treturn (*env)->Call" + fn_line.ret_ty_info.java_ty.title() + "Method(env, obj, j_calls->" + fn_line.fn_name + "_meth"
+                else:
+                    out_c = out_c + "\t" + fn_line.ret_ty_info.rust_obj + "* ret = (" + fn_line.ret_ty_info.rust_obj + "*)(*env)->CallLongMethod(env, obj, j_calls->" + fn_line.fn_name + "_meth"
+
+                for idx, arg_info in enumerate(fn_line.args_ty):
+                    if arg_info.ret_conv is not None:
+                        out_c = out_c + ", " + arg_info.ret_conv_name
+                    else:
+                        out_c = out_c + ", " + arg_info.arg_name
+                out_c = out_c + ");\n"
+                if fn_line.ret_ty_info.arg_conv is not None:
+                    out_c = out_c + "\t" + fn_line.ret_ty_info.arg_conv.replace("\n", "\n\t") + "\n\treturn " + fn_line.ret_ty_info.arg_conv_name + ";\n"
+
+                out_c = out_c + "}\n"
+
+        # Write out a clone function whether we need one or not, as we use them in moving to rust
+        out_c = out_c + "static void* " + struct_name + "_JCalls_clone(const void* this_arg) {\n"
+        out_c = out_c + "\t" + struct_name + "_JCalls *j_calls = (" + struct_name + "_JCalls*) this_arg;\n"
+        out_c = out_c + "\tatomic_fetch_add_explicit(&j_calls->refcnt, 1, memory_order_release);\n"
+        for var in field_var_conversions:
+            if not isinstance(var, ConvInfo):
+                out_c = out_c + "\tatomic_fetch_add_explicit(&j_calls->" + var[1] + "->refcnt, 1, memory_order_release);\n"
+        out_c = out_c + "\treturn (void*) this_arg;\n"
+        out_c = out_c + "}\n"
+
+        out_c = out_c + "static inline " + struct_name + " " + struct_name + "_init (" + self.c_fn_args_pfx + ", jobject o"
+        for var in field_var_conversions:
+            if isinstance(var, ConvInfo):
+                out_c = out_c + ", " + var.c_ty + " " + var.arg_name
+            else:
+                out_c = out_c + ", jobject " + var[1]
+        out_c = out_c + ") {\n"
+
+        out_c = out_c + "\tjclass c = (*env)->GetObjectClass(env, o);\n"
+        out_c = out_c + "\tCHECK(c != NULL);\n"
+        out_c = out_c + "\t" + struct_name + "_JCalls *calls = MALLOC(sizeof(" + struct_name + "_JCalls), \"" + struct_name + "_JCalls\");\n"
+        out_c = out_c + "\tatomic_init(&calls->refcnt, 1);\n"
+        out_c = out_c + "\tDO_ASSERT((*env)->GetJavaVM(env, &calls->vm) == 0);\n"
+        out_c = out_c + "\tcalls->o = (*env)->NewWeakGlobalRef(env, o);\n"
+
+        for (fn_name, java_meth_descr) in java_meths:
+            if fn_name != "free" and fn_name != "clone":
+                out_c = out_c + "\tcalls->" + fn_name + "_meth = (*env)->GetMethodID(env, c, \"" + fn_name + "\", \"" + java_meth_descr + "\");\n"
+                out_c = out_c + "\tCHECK(calls->" + fn_name + "_meth != NULL);\n"
+
+        for var in field_var_conversions:
+            if isinstance(var, ConvInfo) and var.arg_conv is not None:
+                out_c = out_c + "\n\t" + var.arg_conv.replace("\n", "\n\t") +"\n"
+        out_c = out_c + "\n\t" + struct_name + " ret = {\n"
+        out_c = out_c + "\t\t.this_arg = (void*) calls,\n"
+        for fn_line in field_function_lines:
+            if fn_line.fn_name != "free" and fn_line.fn_name != "clone":
+                out_c = out_c + "\t\t." + fn_line.fn_name + " = " + fn_line.fn_name + "_jcall,\n"
+            elif fn_line.fn_name == "free":
+                out_c = out_c + "\t\t.free = " + struct_name + "_JCalls_free,\n"
+            else:
+                out_c = out_c + "\t\t.clone = " + struct_name + "_JCalls_clone,\n"
+        for var in field_var_conversions:
+            if isinstance(var, ConvInfo):
+                if var.arg_conv_name is not None:
+                    out_c = out_c + "\t\t." + var.arg_name + " = " + var.arg_conv_name + ",\n"
+                    out_c = out_c + "\t\t.set_" + var.arg_name + " = NULL,\n"
+                else:
+                    out_c = out_c + "\t\t." + var.var_name + " = " + var.var_name + ",\n"
+                    out_c = out_c + "\t\t.set_" + var.var_name + " = NULL,\n"
+            else:
+                out_c = out_c + "\t\t." + var[1] + " = " + var[0] + "_init(env, clz, " + var[1] + "),\n"
+        out_c = out_c + "\t};\n"
+        for var in field_var_conversions:
+            if not isinstance(var, ConvInfo):
+                out_c = out_c + "\tcalls->" + var[1] + " = ret." + var[1] + ".this_arg;\n"
+        out_c = out_c + "\treturn ret;\n"
+        out_c = out_c + "}\n"
+
+        out_c = out_c + self.c_fn_ty_pfx + "long " + self.c_fn_name_pfx + struct_name.replace("_", "_1") + "_1new (" + self.c_fn_args_pfx + ", jobject o"
+        for var in field_var_conversions:
+            if isinstance(var, ConvInfo):
+                out_c = out_c + ", " + var.c_ty + " " + var.arg_name
+            else:
+                out_c = out_c + ", jobject " + var[1]
+        out_c = out_c + ") {\n"
+        out_c = out_c + "\t" + struct_name + " *res_ptr = MALLOC(sizeof(" + struct_name + "), \"" + struct_name + "\");\n"
+        out_c = out_c + "\t*res_ptr = " + struct_name + "_init(env, clz, o"
+        for var in field_var_conversions:
+            if isinstance(var, ConvInfo):
+                out_c = out_c + ", " + var.arg_name
+            else:
+                out_c = out_c + ", " + var[1]
+        out_c = out_c + ");\n"
+        out_c = out_c + "\treturn (long)res_ptr;\n"
+        out_c = out_c + "}\n"
+
+        out_c = out_c + self.c_fn_ty_pfx + "jobject " + self.c_fn_name_pfx + struct_name.replace("_", "_1") + "_1get_1obj_1from_1jcalls (" + self.c_fn_args_pfx + ", " + self.ptr_c_ty + " val) {\n"
+        out_c = out_c + "\tjobject ret = (*env)->NewLocalRef(env, ((" + struct_name + "_JCalls*)val)->o);\n"
+        out_c = out_c + "\tCHECK(ret != NULL);\n"
+        out_c = out_c + "\treturn ret;\n"
+        out_c = out_c + "}\n"
+
+
+        return (out_java, out_java_trait, out_c)
 
     def map_complex_enum(self, struct_name, variant_list, camel_to_snake):
         java_hu_type = struct_name.replace("LDK", "")