Return, instead of writing, from map_type
[ldk-java] / genbindings.py
1 #!/usr/bin/env python3
2 import sys, re
3
# Exactly three paths must follow the program name: the C header to read,
# then the Java and C files to generate.
if len(sys.argv) - 1 != 3:
    print("USAGE: /path/to/lightning.h /path/to/bindings/output.java /path/to/bindings/output.c")
    raise SystemExit(1)
7
8 with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.argv[3], "w") as out_c:
9     opaque_structs = set()
10
11     var_is_arr_regex = re.compile("\(\*([A-za-z_]*)\)\[([0-9]*)\]")
12     var_ty_regex = re.compile("([A-za-z_0-9]*)(.*)")
13     def java_c_types(fn_arg, ret_arr_len):
14         fn_arg = fn_arg.strip()
15         if fn_arg.startswith("MUST_USE_RES "):
16             fn_arg = fn_arg[13:]
17         if fn_arg.startswith("const "):
18             fn_arg = fn_arg[6:]
19
20         is_ptr = False
21         take_by_ptr = False
22         if fn_arg.startswith("void"):
23             java_ty = "void"
24             c_ty = "void"
25             fn_arg = fn_arg[4:].strip()
26         elif fn_arg.startswith("bool"):
27             java_ty = "boolean"
28             c_ty = "jboolean"
29             fn_arg = fn_arg[4:].strip()
30         elif fn_arg.startswith("uint8_t"):
31             java_ty = "byte"
32             c_ty = "jbyte"
33             fn_arg = fn_arg[7:].strip()
34         elif fn_arg.startswith("uint16_t"):
35             java_ty = "short"
36             c_ty = "jshort"
37             fn_arg = fn_arg[8:].strip()
38         elif fn_arg.startswith("uint32_t"):
39             java_ty = "int"
40             c_ty = "jint"
41             fn_arg = fn_arg[8:].strip()
42         elif fn_arg.startswith("uint64_t"):
43             java_ty = "long"
44             c_ty = "jlong"
45             fn_arg = fn_arg[8:].strip()
46         else:
47             ma = var_ty_regex.match(fn_arg)
48             java_ty = "long"
49             c_ty = "jlong"
50             fn_arg = ma.group(2).strip()
51             take_by_ptr = True
52
53         if fn_arg.startswith(" *") or fn_arg.startswith("*"):
54             fn_arg = fn_arg.replace("*", "").strip()
55             is_ptr = True
56             c_ty = "jlong"
57             java_ty = "long"
58
59         var_is_arr = var_is_arr_regex.match(fn_arg)
60         if var_is_arr is not None or ret_arr_len is not None:
61             assert(not take_by_ptr)
62             java_ty = java_ty + "[]"
63             c_ty = c_ty + "Array"
64             if var_is_arr is not None:
65                 return (java_ty, c_ty, is_ptr, False, var_is_arr.group(1))
66         return (java_ty, c_ty, is_ptr or take_by_ptr, is_ptr, fn_arg)
67
68     class TypeInfo:
69         def __init__(self, c_ty, java_ty, arg_name, arg_conv, ret_conv, arg_conv_name):
70             assert(c_ty is not None)
71             assert(java_ty is not None)
72             assert(arg_name is not None)
73             self.c_ty = c_ty
74             self.java_ty = java_ty
75             self.arg_name = arg_name
76             self.arg_conv = arg_conv
77             self.ret_conv = ret_conv
78             self.arg_conv_name = arg_conv_name
79
80         def print_ty(self):
81             out_c.write(self.c_ty)
82             out_java.write(self.java_ty)
83
84         def print_name(self):
85             if self.arg_name != "":
86                 out_java.write(" " + self.arg_name)
87                 out_c.write(" " + self.arg_name)
88             else:
89                 out_java.write(" arg")
90                 out_c.write(" arg")
91
92     def map_type(fn_arg, print_void, ret_arr_len, is_free):
93         fn_arg = fn_arg.strip()
94         if fn_arg.startswith("MUST_USE_RES "):
95             fn_arg = fn_arg[13:]
96         if fn_arg.startswith("const "):
97             fn_arg = fn_arg[6:]
98
99         (java_ty, c_ty, is_ptr, rust_takes_ptr, var_name) = java_c_types(fn_arg, ret_arr_len)
100         is_ptr_to_obj = None
101         if fn_arg.startswith("void"):
102             if not print_void:
103                 return TypeInfo(c_ty = c_ty, java_ty = java_ty, arg_name = var_name,
104                     arg_conv = None, ret_conv = None, arg_conv_name = None)
105             fn_arg = fn_arg.strip("void ")
106         elif not is_ptr:
107             split = fn_arg.split(" ", 2)
108             if len(split) > 1:
109                 fn_arg = split[1]
110             else:
111                 fn_arg = ""
112         else:
113             ma = var_ty_regex.match(fn_arg)
114             is_ptr_to_obj = ma.group(1)
115             fn_arg = ma.group(2)
116
117         var_is_arr = var_is_arr_regex.match(fn_arg)
118         if c_ty.endswith("Array"):
119             if var_is_arr is not None:
120                 arr_name = var_name
121                 arr_len = var_is_arr.group(2)
122             else:
123                 arr_name = "ret"
124                 arr_len = ret_arr_len
125             assert(c_ty == "jbyteArray")
126             return TypeInfo(c_ty = c_ty, java_ty = java_ty, arg_name = var_name,
127                 arg_conv = "unsigned char " + arr_name + "_arr[" + arr_len + "];\n" +
128                     "(*_env)->GetByteArrayRegion (_env, """ + arr_name + ", 0, " + arr_len + ", " + arr_name + "_arr);\n" +
129                     "unsigned char (*""" + arr_name + "_ref)[" + arr_len + "] = &" + arr_name + "_arr;",
130                 ret_conv = (c_ty + " " + arr_name + "_arr = (*_env)->NewByteArray(_env, " + arr_len + ");\n" +
131                     "(*_env)->SetByteArrayRegion(_env, " + arr_name + "_arr, 0, " + arr_len + ", *",
132                     ");\nreturn ret_arr;"),
133                 arg_conv_name = arr_name + "_ref")
134         elif var_name != "":
135             # If we have a parameter name, print it (noting that it may indicate its a pointer)
136             if is_ptr_to_obj is not None:
137                 assert(is_ptr)
138                 if not rust_takes_ptr:
139                     base_conv = is_ptr_to_obj + " " + var_name + "_conv = *(" + is_ptr_to_obj + "*)" + var_name + ";\nfree((void*)" + var_name + ");";
140                     if is_ptr_to_obj in opaque_structs:
141                         return TypeInfo(c_ty = c_ty, java_ty = java_ty, arg_name = var_name,
142                             arg_conv = base_conv + "\n" + var_name + "_conv._underlying_ref = false;",
143                             ret_conv = None, arg_conv_name = var_name + "_conv")
144                     return TypeInfo(c_ty = c_ty, java_ty = java_ty, arg_name = var_name,
145                         arg_conv = base_conv, ret_conv = None, arg_conv_name = var_name + "_conv")
146                 else:
147                     assert(not is_free)
148                     return TypeInfo(c_ty = c_ty, java_ty = java_ty, arg_name = var_name,
149                         arg_conv = is_ptr_to_obj + "* " + var_name + "_conv = (" + is_ptr_to_obj + "*)" + var_name + ";",
150                             ret_conv = None, arg_conv_name = var_name + "_conv")
151             elif rust_takes_ptr:
152                 return TypeInfo(c_ty = c_ty, java_ty = java_ty, arg_name = var_name,
153                     arg_conv = None, ret_conv = None, arg_conv_name = var_name)
154             else:
155                 return TypeInfo(c_ty = c_ty, java_ty = java_ty, arg_name = var_name,
156                     arg_conv = None, ret_conv = None, arg_conv_name = var_name)
157         elif not print_void:
158             # We don't have a parameter name, and want one, just call it arg
159             if is_ptr_to_obj is not None:
160                 assert(not is_free or is_ptr_to_obj not in opaque_structs);
161                 return TypeInfo(c_ty = c_ty, java_ty = java_ty, arg_name = var_name,
162                     arg_conv = is_ptr_to_obj + " arg_conv = *(" + is_ptr_to_obj + "*)arg;\nfree((void*)arg);",
163                     ret_conv = None, arg_conv_name = "arg_conv")
164             else:
165                 assert(not is_free)
166                 return TypeInfo(c_ty = c_ty, java_ty = java_ty, arg_name = var_name,
167                     arg_conv = None, ret_conv = None, arg_conv_name = "arg")
168         else:
169             # We don't have a parameter name, and don't want one (cause we're returning)
170             if is_ptr_to_obj is not None:
171                 if not rust_takes_ptr:
172                     if is_ptr_to_obj in opaque_structs:
173                         # If we're returning a newly-allocated struct, we don't want Rust to ever
174                         # free, instead relying on the Java GC to lose the ref. We undo this in
175                         # any _free function.
176                         # To avoid any issues, we first assert that the incoming object is non-ref.
177                         return TypeInfo(c_ty = c_ty, java_ty = java_ty, arg_name = var_name,
178                             ret_conv = (is_ptr_to_obj + "* ret = malloc(sizeof(" + is_ptr_to_obj + "));\n*ret = ", ";\nassert(!ret->_underlying_ref);\nret->_underlying_ref = true;\nreturn (long)ret;"),
179                             arg_conv = None, arg_conv_name = None)
180                     else:
181                         return TypeInfo(c_ty = c_ty, java_ty = java_ty, arg_name = var_name,
182                             ret_conv = (is_ptr_to_obj + "* ret = malloc(sizeof(" + is_ptr_to_obj + "));\n*ret = ", ";\nreturn (long)ret;"),
183                             arg_conv = None, arg_conv_name = None)
184                 else:
185                     return TypeInfo(c_ty = c_ty, java_ty = java_ty, arg_name = var_name,
186                         ret_conv = ("return (long) ", ";"),
187                         arg_conv = None, arg_conv_name = None)
188             else:
189                 return TypeInfo(c_ty = c_ty, java_ty = java_ty, arg_name = var_name,
190                     arg_conv = None, ret_conv = None, arg_conv_name = None)
191
192     def map_fn(re_match, ret_arr_len):
193         out_java.write("\t/// " + line)
194         out_java.write("\tpublic static native ")
195         out_c.write("JNIEXPORT ")
196
197         ret_info = map_type(re_match.group(1), True, ret_arr_len, False)
198         ret_info.print_ty()
199         if ret_info.ret_conv is not None:
200             ret_conv_pfx, ret_conv_sfx = ret_info.ret_conv
201
202         out_java.write(" " + re_match.group(2) + "(")
203         out_c.write(" JNICALL Java_org_ldk_impl_bindings_" + re_match.group(2).replace('_', '_1') + "(JNIEnv * _env, jclass _b")
204
205         arg_names = []
206         for idx, arg in enumerate(re_match.group(3).split(',')):
207             if idx != 0:
208                 out_java.write(", ")
209             if arg != "void":
210                 out_c.write(", ")
211             arg_conv_info = map_type(arg, False, None, re_match.group(2).endswith("_free"))
212             if arg_conv_info.c_ty != "void":
213                 arg_conv_info.print_ty()
214                 arg_conv_info.print_name()
215             arg_names.append(arg_conv_info)
216
217         out_java.write(");\n")
218         out_c.write(") {\n")
219
220         for info in arg_names:
221             if info.arg_conv is not None:
222                 out_c.write("\t" + info.arg_conv.replace('\n', "\n\t") + "\n");
223
224         if ret_info.ret_conv is not None:
225             out_c.write("\t" + ret_conv_pfx.replace('\n', '\n\t'));
226         else:
227             out_c.write("\treturn ");
228
229         out_c.write(re_match.group(2) + "(")
230         for idx, info in enumerate(arg_names):
231             if info.arg_conv_name is not None:
232                 if idx != 0:
233                     out_c.write(", ")
234                 out_c.write(info.arg_conv_name)
235         out_c.write(")")
236         if ret_info.ret_conv is not None:
237             out_c.write(ret_conv_sfx.replace('\n', '\n\t'))
238         else:
239             out_c.write(";")
240         out_c.write("\n}\n\n")
241
242     out_java.write("""package org.ldk.impl;
243
244 public class bindings {
245         static {
246                 System.loadLibrary(\"lightningjni\");
247         }
248
249 """)
250     out_c.write("#include \"org_ldk_impl_bindings.h\"\n")
251     out_c.write("#include <rust_types.h>\n")
252     out_c.write("#include <lightning.h>\n")
253     out_c.write("#include <assert.h>\n\n")
254     out_c.write("#include <string.h>\n\n")
255
256     in_block_comment = False
257     in_block_enum = False
258     cur_block_struct = None
259     in_block_union = False
260
261     fn_ptr_regex = re.compile("^extern const ([A-Za-z_0-9\* ]*) \(\*(.*)\)\((.*)\);$")
262     fn_ret_arr_regex = re.compile("(.*) \(\*(.*)\((.*)\)\)\[([0-9]*)\];$")
263     reg_fn_regex = re.compile("([A-Za-z_0-9\* ]* \*?)([a-zA-Z_0-9]*)\((.*)\);$")
264     const_val_regex = re.compile("^extern const ([A-Za-z_0-9]*) ([A-Za-z_0-9]*);$")
265
266     line_indicates_opaque_regex = re.compile("^   bool _underlying_ref;$")
267     line_indicates_trait_regex = re.compile("^   ([A-Za-z_0-9]* \*?)\(\*([A-Za-z_0-9]*)\)\((const )?void \*this_arg(.*)\);$")
268     assert(line_indicates_trait_regex.match("   uintptr_t (*send_data)(void *this_arg, LDKu8slice data, bool resume_read);"))
269     assert(line_indicates_trait_regex.match("   LDKCVec_MessageSendEventZ (*get_and_clear_pending_msg_events)(const void *this_arg);"))
270     assert(line_indicates_trait_regex.match("   void *(*clone)(const void *this_arg);"))
271     struct_name_regex = re.compile("^typedef struct (MUST_USE_STRUCT )?(LDK[A-Za-z_0-9]*) {$")
272     assert(struct_name_regex.match("typedef struct LDKCVecTempl_u8 {"))
273
    # Scan the header one line at a time. Block comments, enums and unions are
    # skipped; structs are accumulated until their closing line and then
    # analysed; function declarations are handed to map_fn.
    for line in in_h:
        if in_block_comment:
            #out_java.write("\t" + line)
            if line.endswith("*/\n"):
                in_block_comment = False
        elif cur_block_struct is not None:
            cur_block_struct  = cur_block_struct + line
            # A line starting with "} " closes the struct: analyse everything
            # gathered since the "typedef struct" line.
            if line.startswith("} "):
                field_lines = []
                struct_name = None
                struct_lines = cur_block_struct.split("\n")
                is_opaque = False
                trait_fn_lines = []

                # Classify each non-comment line of the struct body.
                for idx, struct_line in enumerate(struct_lines):
                    if struct_line.strip().startswith("/*"):
                        in_block_comment = True
                    if in_block_comment:
                        if struct_line.endswith("*/"):
                            in_block_comment = False
                    else:
                        struct_name_match = struct_name_regex.match(struct_line)
                        if struct_name_match is not None:
                            struct_name = struct_name_match.group(2)
                        if line_indicates_opaque_regex.match(struct_line):
                            is_opaque = True
                        trait_fn_match = line_indicates_trait_regex.match(struct_line)
                        if trait_fn_match is not None:
                            trait_fn_lines.append(trait_fn_match)
                        field_lines.append(struct_line)

                assert(struct_name is not None)
                # A struct is never both opaque and a trait.
                assert(len(trait_fn_lines) == 0 or not is_opaque)
                if is_opaque:
                    opaque_structs.add(struct_name)
                if len(trait_fn_lines) > 0:
                    # Trait struct: emit a <name>_JCalls C struct holding the JNI
                    # environment, the Java object and one method ID per trait
                    # function (free and clone get dedicated helpers instead).
                    out_c.write("typedef struct " + struct_name + "_JCalls {\n")
                    out_c.write("\tJNIEnv *env;\n")
                    out_c.write("\tjobject o;\n")
                    for fn_line in trait_fn_lines:
                        if fn_line.group(2) != "free" and fn_line.group(2) != "clone":
                            out_c.write("\tjmethodID " + fn_line.group(2) + "_meth;\n")
                    out_c.write("} " + struct_name + "_JCalls;\n")

                    # Java interface plus, per method, a C "<method>_jcall"
                    # trampoline that calls the Java method.
                    out_java.write("\tpublic interface " + struct_name + " {\n")
                    for fn_line in trait_fn_lines:
                        if fn_line.group(2) != "free" and fn_line.group(2) != "clone":
                            (java_ty, c_ty, is_ptr, _, _) = java_c_types(fn_line.group(1), None)

                            out_java.write("\t\t " + java_ty + " " + fn_line.group(2) + "(")
                            is_const = fn_line.group(3) is not None
                            out_c.write(fn_line.group(1) + fn_line.group(2) + "_jcall(")
                            if is_const:
                                out_c.write("const void* this_arg")
                            else:
                                out_c.write("void* this_arg")

                            arg_names = []
                            # group(4) starts with the comma after this_arg, so the
                            # first split element is empty and real arguments
                            # start at index 1.
                            for idx, arg in enumerate(fn_line.group(4).split(',')):
                                if arg == "":
                                    continue
                                if idx >= 2:
                                    out_java.write(", ")
                                out_c.write(", ")
                                arg_conv_info = map_type(arg, True, None, False)
                                out_c.write(arg.strip())
                                out_java.write(arg_conv_info.java_ty + " " + arg_conv_info.arg_name)
                                arg_names.append(arg_conv_info)

                            out_java.write(");\n")
                            out_c.write(") {\n")
                            out_c.write("\t" + struct_name + "_JCalls *j_calls = (" + struct_name + "_JCalls*) this_arg;\n")

                            # Non-pointer results are returned directly from
                            # Call<Type>Method; pointer results arrive as a long
                            # that is dereferenced and freed below.
                            if not is_ptr:
                                out_c.write("\treturn (*j_calls->env)->Call" + java_ty.title() + "Method(j_calls->env, j_calls->o, j_calls->" + fn_line.group(2) + "_meth")
                            else:
                                out_c.write("\t" + fn_line.group(1).strip() + "* ret = (" + fn_line.group(1).strip() + "*)(*j_calls->env)->CallLongMethod(j_calls->env, j_calls->o, j_calls->" + fn_line.group(2) + "_meth");
                            for arg in fn_line.group(4).split(','):
                                if arg == "":
                                    continue
                                (arg_java_ty, arg_c_ty, arg_is_ptr, _, arg_name) = java_c_types(arg, None)
                                # TODO: Run conversion here!
                                out_c.write(", " + arg_name)
                            out_c.write(");\n");

                            if is_ptr:
                                out_c.write("\t" + fn_line.group(1).strip() + " res = *ret;\n")
                                out_c.write("\tfree(ret);\n")
                                out_c.write("\treturn res;\n")
                            out_c.write("}\n")
                        elif fn_line.group(2) == "free":
                            # free: drop the global ref and release the JCalls struct.
                            out_c.write("void " + struct_name + "_JCalls_free(void* this_arg) {\n")
                            out_c.write("\t" + struct_name + "_JCalls *j_calls = (" + struct_name + "_JCalls*) this_arg;\n")
                            out_c.write("\t(*j_calls->env)->DeleteGlobalRef(j_calls->env, j_calls->o);\n")
                            out_c.write("\tfree(j_calls);\n")
                            out_c.write("}\n")
                        elif fn_line.group(2) == "clone":
                            # clone: byte-copy the JCalls struct into a new allocation.
                            out_c.write("void* " + struct_name + "_JCalls_clone(const void* this_arg) {\n")
                            out_c.write("\t" + struct_name + "_JCalls *ret = malloc(sizeof(" + struct_name + "_JCalls));\n")
                            out_c.write("\tmemcpy(ret, this_arg, sizeof(" + struct_name + "_JCalls));\n")
                            out_c.write("\treturn ret;\n")
                            out_c.write("}\n")
                    out_java.write("\t}\n")
                    out_java.write("\tpublic static native long " + struct_name + "_new(" + struct_name + " impl);\n")

                    # <name>_new: build the JCalls struct for a Java object and
                    # wrap it in a heap-allocated trait struct returned as a long.
                    out_c.write("JNIEXPORT long JNICALL Java_org_ldk_impl_bindings_" + struct_name.replace("_", "_1") + "_1new (JNIEnv * env, jclass _a, jobject o) {\n")
                    out_c.write("\tjclass c = (*env)->GetObjectClass(env, o);\n")
                    out_c.write("\tassert(c != NULL);\n")
                    out_c.write("\t" + struct_name + "_JCalls *calls = malloc(sizeof(" + struct_name + "_JCalls));\n")
                    out_c.write("\tcalls->env = env;\n")
                    out_c.write("\tcalls->o = (*env)->NewGlobalRef(env, o);\n")
                    for fn_line in trait_fn_lines:
                        if fn_line.group(2) != "free" and fn_line.group(2) != "clone":
                            # The JNI method signature is still the placeholder "TODO".
                            out_c.write("\tcalls->" + fn_line.group(2) + "_meth = (*env)->GetMethodID(env, c, \"" + fn_line.group(2) + "\", \"" + "TODO" + "\");\n")
                            out_c.write("\tassert(calls->" + fn_line.group(2) + "_meth != NULL);\n")
                    out_c.write("\n\t" + struct_name + " *ret = malloc(sizeof(" + struct_name + "));\n")
                    out_c.write("\tret->this_arg = (void*) calls;\n")
                    for fn_line in trait_fn_lines:
                        if fn_line.group(2) != "free" and fn_line.group(2) != "clone":
                            out_c.write("\tret->" + fn_line.group(2) + " = " + fn_line.group(2) + "_jcall;\n")
                        elif fn_line.group(2) == "free":
                            out_c.write("\tret->free = " + struct_name + "_JCalls_free;\n")
                        else:
                            out_c.write("\tret->clone = " + struct_name + "_JCalls_clone;\n")
                    out_c.write("\treturn (long)ret;\n")
                    out_c.write("}\n\n")

                    #out_java.write("/* " + "\n".join(field_lines) + "*/\n")
                cur_block_struct = None
        elif in_block_union:
            # Unions are skipped entirely.
            if line.startswith("} "):
                in_block_union = False
        elif in_block_enum:
            # Enums are skipped entirely.
            if line.startswith("} "):
                in_block_enum = False
        else:
            # Top-level line: try every declaration pattern, then dispatch on
            # the kind of line.
            fn_ptr = fn_ptr_regex.match(line)
            fn_ret_arr = fn_ret_arr_regex.match(line)
            reg_fn = reg_fn_regex.match(line)
            const_val = const_val_regex.match(line)

            if line.startswith("#include <"):
                pass
            elif line.startswith("/*"):
                #out_java.write("\t" + line)
                if not line.endswith("*/\n"):
                    in_block_comment = True
            elif line.startswith("typedef enum "):
                in_block_enum = True
            elif line.startswith("typedef struct "):
                cur_block_struct = line
            elif line.startswith("typedef union "):
                in_block_union = True
            elif line.startswith("typedef "):
                pass
            elif fn_ptr is not None:
                map_fn(fn_ptr, None)
            elif fn_ret_arr is not None:
                map_fn(fn_ret_arr, fn_ret_arr.group(4))
            elif reg_fn is not None:
                map_fn(reg_fn, None)
            # NOTE(review): this tests the compiled pattern object, which is never
            # None, so this branch always runs and the `else` below is
            # unreachable. `const_val` was probably intended - confirm before
            # changing, since lines that are silently ignored today would then
            # trip the assert.
            elif const_val_regex is not None:
                # TODO Map const variables
                pass
            else:
                assert(line == "\n")

    # Close the Java `bindings` class opened in the preamble.
    out_java.write("}\n")