Fix race condition in NioPeerHandler on `socket_disconnected`
authorMatt Corallo <git@bluematt.me>
Wed, 3 Nov 2021 21:10:12 +0000 (21:10 +0000)
committerMatt Corallo <git@bluematt.me>
Thu, 4 Nov 2021 05:03:28 +0000 (05:03 +0000)
src/main/java/org/ldk/batteries/NioPeerHandler.java

index bf111729115492ba7e3c94c26645290b000794d3..afe728b059c482cc35983051ab3f72ed998b0e99 100644 (file)
@@ -174,10 +174,11 @@ public class NioPeerHandler {
                                 if (chan == null) continue;
                                 try {
                                     Peer peer = setup_socket(chan);
+                                    peer.key = chan.register(this.selector, SelectionKey.OP_READ, peer);
                                     Result_NonePeerHandleErrorZ res = this.peer_manager.new_inbound_connection(peer.descriptor);
-                                    if (res instanceof Result_NonePeerHandleErrorZ.Result_NonePeerHandleErrorZ_OK) {
-                                        peer.key = chan.register(this.selector, SelectionKey.OP_READ, peer);
-                                    }
+                                    if (res instanceof Result_NonePeerHandleErrorZ.Result_NonePeerHandleErrorZ_Err) {
+                                                                               peer.descriptor.disconnect_socket();
+                                                                       }
                                 } catch (IOException ignored) { }
                             }
                             continue; // There is no attachment so the rest of the loop is useless
@@ -273,14 +274,17 @@ public class NioPeerHandler {
             throw new IOException("Timed out");
         }
         Peer peer = setup_socket(chan);
+        do_selector_action(() -> peer.key = chan.register(this.selector, SelectionKey.OP_READ, peer));
         Result_CVec_u8ZPeerHandleErrorZ res = this.peer_manager.new_outbound_connection(their_node_id, peer.descriptor);
         if (res instanceof  Result_CVec_u8ZPeerHandleErrorZ.Result_CVec_u8ZPeerHandleErrorZ_OK) {
             byte[] initial_bytes = ((Result_CVec_u8ZPeerHandleErrorZ.Result_CVec_u8ZPeerHandleErrorZ_OK) res).res;
             if (chan.write(ByteBuffer.wrap(initial_bytes)) != initial_bytes.length) {
+                peer.descriptor.disconnect_socket();
+                this.peer_manager.socket_disconnected(peer.descriptor);
                 throw new IOException("We assume TCP socket buffer is at least a single packet in length");
             }
-            do_selector_action(() -> peer.key = chan.register(this.selector, SelectionKey.OP_READ, peer));
         } else {
+            peer.descriptor.disconnect_socket();
             throw new IOException("LDK rejected outbound connection. This likely shouldn't ever happen.");
         }
     }