Merge pull request #110 from TheBlueMatt/main
[ldk-java] / src / test / java / org / ldk / PeerTest.java
1 package org.ldk;
2
3 import org.bitcoinj.core.*;
4 import org.bitcoinj.script.Script;
5 import org.junit.jupiter.api.Test;
6 import org.ldk.enums.Network;
7 import org.ldk.enums.Recipient;
8 import org.ldk.impl.bindings;
9
10 import java.lang.ref.Reference;
11 import java.util.ArrayList;
12 import java.util.Arrays;
13 import java.util.HashMap;
14 import java.util.LinkedList;
15 import java.util.concurrent.ConcurrentHashMap;
16 import java.util.concurrent.ConcurrentLinkedQueue;
17
18 public class PeerTest {
19     class Peer {
20         final long logger;
21         final long fee_estimator;
22         final long tx_broadcaster;
23         final long chain_monitor;
24         final long keys;
25         final long keys_interface;
26         final long config;
27         final long chan_manager;
28         final long chan_manager_events;
29         final long chan_handler;
30         final long router;
31         final long router_wrapper;
32         final long route_handler;
33         final long message_handler;
34         final long custom_message_handler;
35         final long peer_manager;
36         HashMap<String, Long> monitors; // Wow I forgot just how terrible Java is - we can't put a byte array here.
37         byte[] node_id;
38         bindings.LDKFeeEstimator fee_est;
39         bindings.LDKBroadcasterInterface broad_trait;
40         bindings.LDKLogger log_trait;
41         bindings.LDKWatch watcher;
42
43         Peer(byte seed) {
44             this.log_trait = (long arg) -> {
45                 System.out.println(seed + ": " + bindings.Record_get_args(arg));
46                 bindings.Record_free(arg);
47             };
48             logger = bindings.LDKLogger_new(this.log_trait);
49             this.fee_est = confirmation_target -> 0;
50             this.fee_estimator = bindings.LDKFeeEstimator_new(this.fee_est);
51             this.broad_trait = tx -> {
52                 // We should broadcast
53             };
54             this.tx_broadcaster = bindings.LDKBroadcasterInterface_new(this.broad_trait);
55             this.monitors = new HashMap<>();
56             this.watcher = new bindings.LDKWatch() {
57                 @Override
58                 public long watch_channel(long funding_txo, long monitor) {
59                     synchronized (monitors) {
60                         assert monitors.put(Arrays.toString(bindings.OutPoint_get_txid(funding_txo)), monitor) == null;
61                     }
62                     bindings.OutPoint_free(funding_txo);
63                     return bindings.CResult_NoneChannelMonitorUpdateErrZ_ok();
64                 }
65
66                 @Override
67                 public long update_channel(long funding_txo, long update) {
68                     synchronized (monitors) {
69                         String txid = Arrays.toString(bindings.OutPoint_get_txid(funding_txo));
70                         assert monitors.containsKey(txid);
71                         long update_res = bindings.ChannelMonitor_update_monitor(monitors.get(txid), update, tx_broadcaster, fee_estimator, logger);
72                         assert bindings.CResult_NoneNoneZ_is_ok(update_res);
73                         bindings.CResult_NoneNoneZ_free(update_res);
74                     }
75                     bindings.OutPoint_free(funding_txo);
76                     bindings.ChannelMonitorUpdate_free(update);
77                     return bindings.CResult_NoneChannelMonitorUpdateErrZ_ok();
78                 }
79
80                 @Override
81                 public long[] release_pending_monitor_events() {
82                     synchronized (monitors) {
83                         assert monitors.size() <= 1;
84                         for (Long mon : monitors.values()) {
85                             long funding_info = bindings.ChannelMonitor_get_funding_txo(mon);
86                             long funding_txo = bindings.C2Tuple_OutPointScriptZ_get_a(funding_info);
87                             long[] mon_events = bindings.ChannelMonitor_get_and_clear_pending_monitor_events(mon);
88                             byte[] counterparty_pk = bindings.ChannelMonitor_get_counterparty_node_id(mon);
89                             long funding_mon_tuple = bindings.C3Tuple_OutPointCVec_MonitorEventZPublicKeyZ_new(funding_txo, mon_events, counterparty_pk);
90                             bindings.C2Tuple_OutPointScriptZ_free(funding_info);
91                             return new long[] {funding_mon_tuple};
92                         }
93                     }
94                     return new long[0];
95                 }
96             };
97             this.chain_monitor = bindings.LDKWatch_new(this.watcher);
98
99             byte[] key_seed = new byte[32];
100             for (byte i = 0; i < 32; i++) { key_seed[i] = (byte) (i ^ seed); }
101             this.keys = bindings.KeysManager_new(key_seed, System.currentTimeMillis() / 1000, (int)(System.currentTimeMillis() * 1000) & 0xffffffff);
102             this.keys_interface = bindings.KeysManager_as_KeysInterface(keys);
103             this.config = bindings.UserConfig_default();
104             long params = bindings.ChainParameters_new(Network.LDKNetwork_Bitcoin, bindings.BestBlock_new(new byte[32], 0));
105             this.chan_manager = bindings.ChannelManager_new(fee_estimator, chain_monitor, tx_broadcaster, logger, keys_interface, config, params);
106             this.node_id = bindings.ChannelManager_get_our_node_id(chan_manager);
107             this.chan_manager_events = bindings.ChannelManager_as_EventsProvider(chan_manager);
108
109             this.chan_handler = bindings.ChannelManager_as_ChannelMessageHandler(chan_manager);
110             this.router = bindings.NetworkGraph_new(new byte[32], logger);
111             this.router_wrapper = bindings.P2PGossipSync_new(router, bindings.COption_AccessZ_none(), logger);
112             this.route_handler = bindings.P2PGossipSync_as_RoutingMessageHandler(router_wrapper);
113             this.message_handler = bindings.MessageHandler_new(chan_handler, route_handler);
114             this.custom_message_handler = bindings.IgnoringMessageHandler_new();
115
116             byte[] random_data = new byte[32];
117             for (byte i = 0; i < 32; i++) { random_data[i] = (byte) ((i ^ seed) ^ 0xf0); }
118
119             long node_id_result = bindings.KeysInterface_get_node_secret(keys_interface, Recipient.LDKRecipient_Node);
120             assert bindings.CResult_SecretKeyNoneZ_is_ok(node_id_result);
121             this.peer_manager = bindings.PeerManager_new(message_handler, bindings.CResult_SecretKeyNoneZ_get_ok(node_id_result),
122                     random_data, logger, bindings.IgnoringMessageHandler_as_CustomMessageHandler(this.custom_message_handler));
123             bindings.CResult_SecretKeyNoneZ_free(node_id_result);
124         }
125
126         void connect_block(Block b, Transaction t, int height) {
127             long listen = bindings.ChannelManager_as_Listen(chan_manager);
128             bindings.Listen_block_connected(listen, b.bitcoinSerialize(), height);
129             bindings.Listen_free(listen);
130             synchronized (monitors) {
131                 for (Long mon : monitors.values()) {
132                     long[] txn;
133                     if (t != null)
134                         txn = new long[]{bindings.C2Tuple_usizeTransactionZ_new(1, t.bitcoinSerialize())};
135                     else
136                         txn = new long[0];
137                     byte[] header = Arrays.copyOfRange(b.bitcoinSerialize(), 0, 80);
138                     long[] ret = bindings.ChannelMonitor_block_connected(mon, header, txn, height, tx_broadcaster, fee_estimator, logger);
139                     for (long r : ret) {
140                         bindings.C2Tuple_TxidCVec_C2Tuple_u32TxOutZZZ_free(r);
141                     }
142                 }
143             }
144         }
145
146         void free() {
147             // Note that we can't rely on finalizer order, so don't bother trying to rely on it here
148             bindings.Logger_free(logger);
149             bindings.FeeEstimator_free(fee_estimator);
150             bindings.BroadcasterInterface_free(tx_broadcaster);
151             bindings.Watch_free(chain_monitor);
152             bindings.KeysManager_free(keys);
153             bindings.KeysInterface_free(keys_interface);
154             bindings.UserConfig_free(config);
155             bindings.ChannelManager_free(chan_manager);
156             bindings.EventsProvider_free(chan_manager_events);
157             bindings.ChannelMessageHandler_free(chan_handler);
158             bindings.NetworkGraph_free(router);
159             bindings.P2PGossipSync_free(router_wrapper);
160             bindings.RoutingMessageHandler_free(route_handler);
161             //MessageHandler was actually moved into the route_handler!: bindings.MessageHandler_free(message_handler);
162             bindings.PeerManager_free(peer_manager);
163             synchronized (monitors) {
164                 for (Long mon : monitors.values()) {
165                     bindings.ChannelMonitor_free(mon);
166                 }
167             }
168         }
169     }
170
171     class LongHolder { long val; }
172     class PendingWrite {
173         long pm;
174         long descriptor;
175         byte[] array;
176         PendingWrite(long pm, long descriptor, byte[] array) { this.pm = pm; this.descriptor = descriptor; this.array = array; }
177         void process() {
178             long res = bindings.PeerManager_read_event(pm, descriptor, array);
179             assert bindings.CResult_boolPeerHandleErrorZ_is_ok(res);
180             assert !bindings.CResult_boolPeerHandleErrorZ_get_ok(res);
181             bindings.CResult_boolPeerHandleErrorZ_free(res);
182         }
183     }
184     ConcurrentLinkedQueue<PendingWrite> pending_writes = new ConcurrentLinkedQueue<>();
185
186     void do_read_event(ConcurrentLinkedQueue<Thread> list, long pm, long descriptor, byte[] arr) {
187         pending_writes.add(new PendingWrite(pm, descriptor, arr));
188         Thread thread = new Thread(() -> {
189             synchronized (pending_writes) {
190                 while (true) {
191                     PendingWrite write = pending_writes.poll();
192                     if (write == null) break;
193                     write.process();
194                 }
195             }
196         });
197         thread.start();
198         list.add(thread);
199     }
200
201     void deliver_peer_messages(ConcurrentLinkedQueue<Thread> list, long peer1, long peer2) throws InterruptedException {
202         bindings.PeerManager_process_events(peer1);
203         bindings.PeerManager_process_events(peer2);
204         while (!list.isEmpty()) {
205             list.poll().join();
206             bindings.PeerManager_process_events(peer1);
207             bindings.PeerManager_process_events(peer2);
208         }
209     }
210
211     @Test
212     void test_message_handler() throws InterruptedException {
213         Peer peer1 = new Peer((byte) 1);
214         Peer peer2 = new Peer((byte) 2);
215
216         ConcurrentLinkedQueue<Thread> list = new ConcurrentLinkedQueue<Thread>();
217         LongHolder descriptor1 = new LongHolder();
218         LongHolder descriptor1ref = descriptor1;
219         bindings.LDKSocketDescriptor sock1 = new bindings.LDKSocketDescriptor() {
220             @Override
221             public long send_data(byte[] data, boolean resume_read) {
222                 do_read_event(list, peer1.peer_manager, descriptor1ref.val, data);
223                 return data.length;
224             }
225
226             @Override public void disconnect_socket() { assert false; }
227             @Override public boolean eq(long other_arg) { boolean ret = bindings.SocketDescriptor_hash(other_arg) == 2; bindings.SocketDescriptor_free(other_arg); return ret; }
228             @Override public long hash() { return 2; }
229         };
230         long descriptor2 = bindings.LDKSocketDescriptor_new(sock1);
231
232         bindings.LDKSocketDescriptor sock2 = new bindings.LDKSocketDescriptor() {
233             @Override
234             public long send_data(byte[] data, boolean resume_read) {
235                 do_read_event(list, peer2.peer_manager, descriptor2, data);
236                 return data.length;
237             }
238
239             @Override public void disconnect_socket() { assert false; }
240             @Override public boolean eq(long other_arg) { boolean ret = bindings.SocketDescriptor_hash(other_arg) == 1; bindings.SocketDescriptor_free(other_arg); return ret; }
241             @Override public long hash() { return 1; }
242         };
243         descriptor1.val = bindings.LDKSocketDescriptor_new(sock2);
244
245         long no_netaddr = bindings.COption_NetAddressZ_none();
246         long init_vec = bindings.PeerManager_new_outbound_connection(peer1.peer_manager, peer2.node_id, descriptor1.val, no_netaddr);
247         assert(bindings.CResult_CVec_u8ZPeerHandleErrorZ_is_ok(init_vec));
248
249         long con_res = bindings.PeerManager_new_inbound_connection(peer2.peer_manager, descriptor2, no_netaddr);
250         assert(bindings.CResult_NonePeerHandleErrorZ_is_ok(con_res));
251         bindings.CResult_NonePeerHandleErrorZ_free(con_res);
252         do_read_event(list, peer2.peer_manager, descriptor2, bindings.CResult_CVec_u8ZPeerHandleErrorZ_get_ok(init_vec));
253         bindings.CResult_CVec_u8ZPeerHandleErrorZ_free(init_vec);
254         bindings.COption_NetAddressZ_free(no_netaddr);
255
256         deliver_peer_messages(list, peer1.peer_manager, peer2.peer_manager);
257
258         long cc_res = bindings.ChannelManager_create_channel(peer1.chan_manager, peer2.node_id, 10000, 1000, 42, 0);
259         assert bindings.CResult_NoneAPIErrorZ_is_ok(cc_res);
260         bindings.CResult_NoneAPIErrorZ_free(cc_res);
261
262         deliver_peer_messages(list, peer1.peer_manager, peer2.peer_manager);
263
264         ArrayList<Long> events = new ArrayList();
265         bindings.LDKEventHandler events_adder = events::add;
266         long handler = bindings.LDKEventHandler_new(events_adder);
267
268         bindings.EventsProvider_process_pending_events(peer1.chan_manager_events, handler);
269         assert events.size() == 1;
270         bindings.LDKEvent event = bindings.LDKEvent_ref_from_ptr(events.get(0));
271         assert event instanceof bindings.LDKEvent.FundingGenerationReady;
272         assert ((bindings.LDKEvent.FundingGenerationReady)event).channel_value_satoshis == 10000;
273         assert ((bindings.LDKEvent.FundingGenerationReady)event).user_channel_id == 42;
274         byte[] funding_spk = ((bindings.LDKEvent.FundingGenerationReady)event).output_script;
275         assert funding_spk.length == 34 && funding_spk[0] == 0 && funding_spk[1] == 32; // P2WSH
276         byte[] chan_id = ((bindings.LDKEvent.FundingGenerationReady)event).temporary_channel_id;
277         bindings.Event_free(events.remove(0));
278
279         Transaction funding = new Transaction(NetworkParameters.fromID(NetworkParameters.ID_MAINNET));
280         funding.addInput(new TransactionInput(NetworkParameters.fromID(NetworkParameters.ID_MAINNET), funding, new byte[0]));
281         funding.getInputs().get(0).setWitness(new TransactionWitness(2)); // Make sure we don't complain about lack of witness
282         funding.getInput(0).getWitness().setPush(0, new byte[] {0x1});
283         funding.addOutput(Coin.SATOSHI.multiply(10000), new Script(funding_spk));
284         bindings.ChannelManager_funding_transaction_generated(peer1.chan_manager, chan_id, peer2.node_id, funding.bitcoinSerialize());
285
286         deliver_peer_messages(list, peer1.peer_manager, peer2.peer_manager);
287
288         Block b = new Block(NetworkParameters.fromID(NetworkParameters.ID_MAINNET), 2, Sha256Hash.ZERO_HASH, Sha256Hash.ZERO_HASH, 42, 0, 0, Arrays.asList(new Transaction[]{funding}));
289         peer1.connect_block(b, funding, 1);
290         peer2.connect_block(b, funding, 1);
291
292         for (int height = 2; height < 10; height++) {
293             b = new Block(NetworkParameters.fromID(NetworkParameters.ID_MAINNET), 2, b.getHash(), Sha256Hash.ZERO_HASH, 42, 0, 0, Arrays.asList(new Transaction[0]));
294             peer1.connect_block(b, null, height);
295             peer2.connect_block(b, null, height);
296         }
297
298         deliver_peer_messages(list, peer1.peer_manager, peer2.peer_manager);
299
300         long[] peer1_chans = bindings.ChannelManager_list_channels(peer1.chan_manager);
301         long[] peer2_chans = bindings.ChannelManager_list_channels(peer2.chan_manager);
302         assert peer1_chans.length == 1;
303         assert peer2_chans.length == 1;
304         assert bindings.ChannelDetails_get_channel_value_satoshis(peer1_chans[0]) == 10000;
305         assert bindings.ChannelDetails_get_is_usable(peer1_chans[0]);
306         assert Arrays.equals(bindings.ChannelDetails_get_channel_id(peer1_chans[0]), funding.getTxId().getReversedBytes());
307         assert Arrays.equals(bindings.ChannelDetails_get_channel_id(peer2_chans[0]), funding.getTxId().getReversedBytes());
308         for (long chan : peer2_chans) bindings.ChannelDetails_free(chan);
309
310         long no_min_val = bindings.COption_u64Z_none();
311         long inbound_payment = bindings.ChannelManager_create_inbound_payment(peer2.chan_manager, no_min_val, 7200);
312         assert bindings.CResult_C2Tuple_PaymentHashPaymentSecretZNoneZ_is_ok(inbound_payment);
313         long payment_tuple = bindings.CResult_C2Tuple_PaymentHashPaymentSecretZNoneZ_get_ok(inbound_payment);
314         bindings.COption_u64Z_free(no_min_val);
315
316         long scorer = bindings.ProbabilisticScorer_new(bindings.ProbabilisticScoringParameters_default(), peer1.router, peer1.logger);
317         long scorer_interface = bindings.ProbabilisticScorer_as_Score(scorer);
318
319         long no_u64 = bindings.COption_u64Z_none();
320         long invoice_features = bindings.InvoiceFeatures_known();
321         long payee = bindings.PaymentParameters_new(peer2.node_id, invoice_features, new long[0], no_u64, 6*24*14, (byte)1, (byte)1, new long[0]);
322         bindings.InvoiceFeatures_free(invoice_features);
323         bindings.COption_u64Z_free(no_u64);
324         long route_params = bindings.RouteParameters_new(payee, 1000, 42);
325         long route = bindings.find_route(peer1.node_id, route_params, peer1.router, peer1_chans, peer1.logger,
326                 scorer_interface, new byte[32]);
327         bindings.RouteParameters_free(route_params);
328         bindings.PaymentParameters_free(payee);
329         bindings.Score_free(scorer_interface);
330         bindings.ProbabilisticScorer_free(scorer);
331
332         for (long chan : peer1_chans) bindings.ChannelDetails_free(chan);
333         assert bindings.CResult_RouteLightningErrorZ_is_ok(route);
334         long payment_res = bindings.ChannelManager_send_payment(peer1.chan_manager, bindings.CResult_RouteLightningErrorZ_get_ok(route),
335                 bindings.C2Tuple_PaymentHashPaymentSecretZ_get_a(payment_tuple), bindings.C2Tuple_PaymentHashPaymentSecretZ_get_b(payment_tuple));
336         bindings.CResult_RouteLightningErrorZ_free(route);
337         bindings.CResult_C2Tuple_PaymentHashPaymentSecretZNoneZ_is_ok(inbound_payment);
338         assert bindings.CResult_NonePaymentSendFailureZ_is_ok(payment_res);
339         bindings.CResult_NonePaymentSendFailureZ_free(payment_res);
340
341         deliver_peer_messages(list, peer1.peer_manager, peer2.peer_manager);
342
343         bindings.EventsProvider_process_pending_events(peer2.chan_manager_events, handler);
344         assert events.size() == 1;
345         bindings.LDKEvent forwardable = bindings.LDKEvent_ref_from_ptr(events.get(0));
346         assert forwardable instanceof bindings.LDKEvent.PendingHTLCsForwardable;
347         bindings.Event_free(events.remove(0));
348         bindings.ChannelManager_process_pending_htlc_forwards(peer2.chan_manager);
349
350         bindings.EventsProvider_process_pending_events(peer2.chan_manager_events, handler);
351         assert events.size() == 1;
352         bindings.LDKEvent payment_recvd = bindings.LDKEvent_ref_from_ptr(events.get(0));
353         assert payment_recvd instanceof bindings.LDKEvent.PaymentReceived;
354         bindings.LDKPaymentPurpose purpose = bindings.LDKPaymentPurpose_ref_from_ptr(((bindings.LDKEvent.PaymentReceived) payment_recvd).purpose);
355         assert purpose instanceof bindings.LDKPaymentPurpose.InvoicePayment;
356         bindings.ChannelManager_claim_funds(peer2.chan_manager, ((bindings.LDKPaymentPurpose.InvoicePayment) purpose).payment_preimage);
357         bindings.Event_free(events.remove(0));
358
359         bindings.EventsProvider_process_pending_events(peer2.chan_manager_events, handler);
360         assert events.size() == 1;
361         bindings.LDKEvent payment_claimed = bindings.LDKEvent_ref_from_ptr(events.get(0));
362         assert payment_claimed instanceof bindings.LDKEvent.PaymentClaimed;
363         bindings.Event_free(events.remove(0));
364
365         deliver_peer_messages(list, peer1.peer_manager, peer2.peer_manager);
366
367         bindings.EventsProvider_process_pending_events(peer1.chan_manager_events, handler);
368         assert events.size() == 2;
369         bindings.LDKEvent sent = bindings.LDKEvent_ref_from_ptr(events.get(0));
370         assert sent instanceof bindings.LDKEvent.PaymentSent;
371         bindings.Event_free(events.remove(0));
372         bindings.LDKEvent sent_path = bindings.LDKEvent_ref_from_ptr(events.get(0));
373         assert sent_path instanceof bindings.LDKEvent.PaymentPathSuccessful;
374         bindings.Event_free(events.remove(0));
375         Reference.reachabilityFence(events_adder);
376
377         bindings.EventHandler_free(handler);
378
379         peer1.free();
380         peer2.free();
381         bindings.SocketDescriptor_free(descriptor2);
382         bindings.SocketDescriptor_free(descriptor1.val);
383     }
384 }