]> git.bitcoin.ninja Git - ldk-java/commitdiff
Make String types language-specific and fix TS string conversion
authorMatt Corallo <git@bluematt.me>
Sun, 9 Jan 2022 06:24:30 +0000 (06:24 +0000)
committerMatt Corallo <git@bluematt.me>
Mon, 10 Jan 2022 06:33:14 +0000 (06:33 +0000)
gen_type_mapping.py
genbindings.py
java_strings.py
typescript_strings.py

index c7f75484daa660529eb9b4bb7da81806e577d7d0..b285ab34477554b4f60f3822b7b6093bd444ded0 100644 (file)
@@ -200,7 +200,7 @@ class TypeMappingGenerator:
                 return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
                     arg_conv = arg_conv, arg_conv_name = arg_conv_name, arg_conv_cleanup = arg_conv_cleanup,
                     ret_conv = ret_conv, ret_conv_name = arr_name + "_arr", to_hu_conv = to_hu_conv, to_hu_conv_name = to_hu_conv_name, from_hu_conv = from_hu_conv)
-        elif ty_info.java_ty == "String":
+        elif ty_info.java_fn_ty_arg == "Ljava/lang/String;":
             assert not is_nullable
             if not is_free:
                 arg_conv = "LDKStr " + ty_info.var_name + "_conv = " + self.consts.str_ref_to_c_call(ty_info.var_name) + ";"
@@ -209,6 +209,7 @@ class TypeMappingGenerator:
                 arg_conv = "LDKStr dummy = { .chars = NULL, .len = 0, .chars_is_owned = false };"
                 arg_conv_name = "dummy"
             if ty_info.arr_access is None:
+                assert False
                 return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
                     arg_conv = arg_conv, arg_conv_name = arg_conv_name, arg_conv_cleanup = None,
                     ret_conv = ("const char* " + ty_info.var_name + "_str = ",
@@ -219,11 +220,17 @@ class TypeMappingGenerator:
                 free_str = ""
                 if not holds_ref:
                     free_str = "\nStr_free(" + ty_info.var_name + "_str);"
+                to_hu_conv = self.consts.str_to_hu_conv(ty_info.var_name)
+                to_hu_conv_name = None
+                if to_hu_conv is not None:
+                    to_hu_conv_name = ty_info.var_name + "_conv"
                 return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
                     arg_conv = arg_conv, arg_conv_name = arg_conv_name, arg_conv_cleanup = None,
                     ret_conv = ("LDKStr " + ty_info.var_name + "_str = ",
                         ";\njstring " + ty_info.var_name + "_conv = " + self.consts.str_ref_to_native_call(ty_info.var_name + "_str." + ty_info.arr_access, ty_info.var_name + "_str." + ty_info.arr_len) + ";" + free_str),
-                    ret_conv_name = ty_info.var_name + "_conv", to_hu_conv = None, to_hu_conv_name = None, from_hu_conv = None)
+                    ret_conv_name = ty_info.var_name + "_conv",
+                    to_hu_conv=to_hu_conv, to_hu_conv_name=to_hu_conv_name,
+                    from_hu_conv = self.consts.str_from_hu_conv(ty_info.var_name))
         elif ty_info.var_name == "" and not print_void:
             # We don't have a parameter name, and want one, just call it arg
             if not ty_info.is_native_primitive:
index 4a495a1ac51c5d414a1da3af6ac9d1d468b40613..34434b2ee33e4327b29073bb58d20b3861a11e33 100755 (executable)
@@ -285,13 +285,15 @@ def java_c_types(fn_arg, ret_arr_len):
             fn_arg = fn_arg[9:].strip()
         is_primitive = True
     elif is_const and fn_arg.startswith("char *"):
-        java_ty = "String"
+        java_ty = consts.java_type_map["String"]
+        java_hu_ty = consts.java_hu_type_map["String"]
         c_ty = "const char*"
         fn_ty_arg = "Ljava/lang/String;"
         fn_arg = fn_arg[6:].strip()
     elif fn_arg.startswith("LDKStr"):
         rust_obj = "LDKStr"
-        java_ty = "String"
+        java_ty = consts.java_type_map["String"]
+        java_hu_ty = consts.java_hu_type_map["String"]
         c_ty = "jstring"
         fn_ty_arg = "Ljava/lang/String;"
         fn_arg = fn_arg[6:].strip()
index c0eddca110ad91fdae9271ca41fd825e3e21a848..d477aa31575a93252ade4d6aa41ddebb3bee47f1 100644 (file)
@@ -16,6 +16,12 @@ class Consts:
             uint32_t = ['int'],
             uint64_t = ['long'],
         )
+        self.java_type_map = dict(
+            String = "String"
+        )
+        self.java_hu_type_map = dict(
+            String = "String"
+        )
 
         self.to_hu_conv_templates = dict(
             ptr = '{human_type} {var_name}_hu_conv = null; if ({var_name} < 0 || {var_name} > 4096) { {var_name}_hu_conv = new {human_type}(null, {var_name}); }',
@@ -631,6 +637,10 @@ import javax.annotation.Nullable;
         return "str_ref_to_java(env, " + var_name + ", " + str_len + ")"
     def str_ref_to_c_call(self, var_name):
         return "java_to_owned_str(env, " + var_name + ")"
+    def str_to_hu_conv(self, var_name):
+        return None
+    def str_from_hu_conv(self, var_name):
+        return None
 
     def c_fn_name_define_pfx(self, fn_name, has_args):
         if has_args:
@@ -950,7 +960,7 @@ import javax.annotation.Nullable;
                     out_c = out_c + "\t" + fn_line.ret_ty_info.c_ty + " ret = (*env)->CallObjectMethod(env, obj, j_calls->" + fn_line.fn_name + "_meth"
                 elif fn_line.ret_ty_info.c_ty == "void":
                     out_c += "\t(*env)->Call" + fn_line.ret_ty_info.java_ty.title() + "Method(env, obj, j_calls->" + fn_line.fn_name + "_meth"
-                elif fn_line.ret_ty_info.java_ty == "String":
+                elif fn_line.ret_ty_info.java_hu_ty == "String":
                     # Manually write out String methods as they're just an Object
                     out_c += "\t" + fn_line.ret_ty_info.c_ty + " ret = (*env)->CallObjectMethod(env, obj, j_calls->" + fn_line.fn_name + "_meth"
                 elif not fn_line.ret_ty_info.passed_as_ptr:
index 79976674bdcb295d2a1d6a2cdab3826bc60a89f3..892c63b802488a20b73b7101fa348e11b58afaee 100644 (file)
@@ -22,6 +22,12 @@ class Consts:
             uint32_t = ['number', 'Uint32Array'],
             uint64_t = ['BigInt'],
         )
+        self.java_type_map = dict(
+            String = "number"
+        )
+        self.java_hu_type_map = dict(
+            String = "string"
+        )
 
         self.wasm_decoding_map = dict(
             int8_tArray = 'decodeUint8Array'
@@ -353,6 +359,10 @@ import * as InternalUtils from '../InternalUtils.mjs'
         return "str_ref_to_ts(" + var_name + ", " + str_len + ")"
     def str_ref_to_c_call(self, var_name):
         return "str_ref_to_owned_c(" + var_name + ")"
+    def str_to_hu_conv(self, var_name):
+        return "const " + var_name + "_conv: string = bindings.decodeString(" + var_name + ");"
+    def str_from_hu_conv(self, var_name):
+        return ("bindings.encodeString(" + var_name + ")", "")
 
     def c_fn_name_define_pfx(self, fn_name, have_args):
         return " __attribute__((export_name(\"TS_" + fn_name + "\"))) TS_" + fn_name + "("
@@ -504,38 +514,22 @@ const decodeUint32Array = (arrayPointer: number, free = true) => {
        return actualArray;
 }
 
-const encodeString = (string: string) => {
-       // make malloc count divisible by 4
-       const memoryNeed = nextMultipleOfFour(string.length + 1);
-       const stringPointer = wasm.TS_malloc(memoryNeed);
-       const stringMemoryView = new Uint8Array(
-               wasm.memory.buffer, // value
-               stringPointer, // offset
-               string.length + 1 // length
-       );
-       for (let i = 0; i < string.length; i++) {
-               stringMemoryView[i] = string.charCodeAt(i);
-       }
-       stringMemoryView[string.length] = 0;
-       return stringPointer;
+export function encodeString(str: string): number {
+       const charArray = new TextEncoder().encode(str);
+       return encodeUint8Array(charArray);
 }
 
-const decodeString = (stringPointer: number, free = true) => {
-       const memoryView = new Uint8Array(wasm.memory.buffer, stringPointer);
-       let cursor = 0;
-       let result = '';
-
-       while (memoryView[cursor] !== 0) {
-               result += String.fromCharCode(memoryView[cursor]);
-               cursor++;
-       }
+export function decodeString(stringPointer: number, free = true): string {
+       const arraySize = getArrayLength(stringPointer);
+       const memoryView = new Uint8Array(wasm.memory.buffer, stringPointer + 4, arraySize);
+       const result = new TextDecoder("utf-8").decode(memoryView);
 
        if (free) {
-               wasm.wasm_free(stringPointer);
+               wasm.TS_free(stringPointer);
        }
 
        return result;
-};
+}
 """
 
     def init_str(self):
@@ -816,7 +810,7 @@ export class {struct_name.replace("LDK","")} extends CommonBase {{
                     out_c += "js_invoke_function_" + str(len(fn_line.args_ty)) + "(j_calls->instance_ptr, " + str(self.function_ptr_counter)
                 elif fn_line.ret_ty_info.java_ty == "void":
                     out_c = out_c + "\tjs_invoke_function_" + str(len(fn_line.args_ty)) + "(j_calls->instance_ptr, " + str(self.function_ptr_counter)
-                elif fn_line.ret_ty_info.java_ty == "String":
+                elif fn_line.ret_ty_info.java_hu_ty == "string":
                     out_c = out_c + "\tjstring ret = (jstring)js_invoke_function_" + str(len(fn_line.args_ty)) + "(j_calls->instance_ptr, " + str(self.function_ptr_counter)
                 elif not fn_line.ret_ty_info.passed_as_ptr:
                     out_c = out_c + "\treturn js_invoke_function_" + str(len(fn_line.args_ty)) + "(j_calls->instance_ptr, " + str(self.function_ptr_counter)