WIP tokio 1 conversion 2021-08-tokio-1
authorMatt Corallo <git@bluematt.me>
Thu, 5 Aug 2021 17:00:21 +0000 (17:00 +0000)
committerMatt Corallo <git@bluematt.me>
Thu, 5 Aug 2021 17:00:21 +0000 (17:00 +0000)
Cargo.toml
src/bgp_client.rs
src/main.rs

index ab04e15f4ef968cd59c5dbab9be0815a2b3d41f9..a8ffcbdc608657852382e8739b60709c8b0051b4 100644 (file)
@@ -7,7 +7,7 @@ edition = "2018"
 [dependencies]
 bitcoin = "0.26"
 bgp-rs = "0.6"
-tokio = "0.1"
+tokio = { version = "1", features = ["full"] }
 bytes = "0.4"
 futures = "0.1"
 rand = "0.8"
index 78c2f8829b96181b5f65038d954ff4a549ff420d..17b6598936f7dc31e02fc696e000bd6b02eea4ce 100644 (file)
@@ -11,16 +11,10 @@ use bgp_rs::Segment;
 use bgp_rs::Message;
 use bgp_rs::Reader;
 
-use tokio::prelude::*;
-use tokio::codec;
-use tokio::codec::Framed;
 use tokio::net::TcpStream;
-use tokio::timer::Delay;
-
-use futures::sync::mpsc;
+use tokio::time;
 
 use crate::printer::{Printer, Stat};
-use crate::timeout_stream::TimeoutStream;
 
 const PATH_SUFFIX_LEN: usize = 3;
 #[derive(Clone)]
@@ -202,7 +196,7 @@ impl RoutingTable {
        }
 }
 
-struct BytesCoder<'a>(&'a mut bytes::BytesMut);
+/*struct BytesCoder<'a>(&'a mut bytes::BytesMut);
 impl<'a> std::io::Write for BytesCoder<'a> {
        fn write(&mut self, b: &[u8]) -> Result<usize, std::io::Error> {
                self.0.extend_from_slice(&b);
@@ -244,7 +238,6 @@ impl codec::Decoder for MsgCoder {
                        Ok((_header, msg)) => {
                                decoder.buf.advance(decoder.pos);
                                if let Message::Open(ref o) = &msg {
-                                       self.0 = Some(Capabilities::from_parameters(o.parameters.clone()));
                                }
                                Ok(Some(msg))
                        },
@@ -263,7 +256,7 @@ impl codec::Encoder for MsgCoder {
                msg.encode(&mut BytesCoder(res))?;
                Ok(())
        }
-}
+}*/
 
 pub struct BGPClient {
        routes: Mutex<RoutingTable>,
@@ -380,26 +373,101 @@ impl BGPClient {
                } else { None }
        }
 
+       async fn handle_peer(open_msg: Message, stream: TcpStream, timeout: Duration, printer: &'static Printer, client: Arc<BGPClient>) -> Result<(), std::io::Error> {
+               let mut open_bytes = [0; 64];
+                       let len = {
+                               let mut write_cursor = std::io::Cursor::new(&mut open_bytes);
+                               open_msg.encode(&mut write_cursor);
+                               write_cursor.position()
+                       };
+                       stream.write_all(&open_bytes[..len]).await?;
+                       let mut cap = Default::default();
+
+                       let mut readpending = Vec::new();
+                       let mut readbuf = [0; 8192];
+                       let mut msg_timeout = time::sleep(timeout);
+                       'read_loop: loop {
+                               if client.shutdown.load(Ordering::Relaxed) {
+                                       return std::io::Error::new(std::io::ErrorKind::Other, "Shutting Down");
+                               }
+                               tokio::select! {
+                                       _ = msg_timeout => {
+                                               return Err(std::io::Error::new(std::io::ErrorKind::TimedOut, "Keepalive expired"));
+                                       },
+                                       res = stream.read(&mut readbuf) => {
+                                               let mut msg_opt = None;
+                                               let bytecnt = res?;
+                                               if readpending.is_empty() {
+                                                       let mut cursor = std::io::Cursor::new(&readbuf[..bytecnt]);
+                                                       let mut reader = Reader { stream: &mut cursor, capabilities: &cap };
+                                                       match reader.read() {
+                                                               Ok((_header, newmsg)) => { readpending.append(&readbuf[cursor.position()..bytecnt]); newmsg = Some(msg_opt) },
+                                                               Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
+                                                                       readpending.append(&readbuf[..bytecnt]);
+                                                                       continue 'read_loop;
+                                                               },
+                                                               Err(e) => return Err(e),
+                                                       }
+                                               } else { readpending.append(&readbuf[..bytecnt]); }
+                                               loop {
+                                                       if msg_opt.is_none() {
+                                                               let mut cursor = std::io::Cursor::new(&readpending);
+                                                               let mut reader = Reader { stream: &mut cursor, capabilities: &cap };
+                                                               match reader.read() {
+                                                                       Ok((_header, newmsg)) => { newmsg = Some(msg_opt) },
+                                                                       Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {},
+                                                                       Err(e) => return Err(e),
+                                                               }
+                                                               readpending = readpending.split_off(cursor.position());
+                                                       }
+                                                       if let Some(bgp_msg) = msg_opt.take() {
+                                                               match bgp_msg {
+                                                                       Message::Open(o) => {
+                                                                               cap = Capabilities::from_parameters(o.parameters.clone());
+                                                                               client.routes.lock().unwrap().v4_table.clear();
+                                                                               client.routes.lock().unwrap().v6_table.clear();
+                                                                               printer.add_line("Connected to BGP route provider".to_string(), false);
+                                                                       },
+                                                                       Message::KeepAlive => {
+                                                                               msg_timeout = time::sleep(timeout);
+                                                                               //XXX: let _ = sender.try_send(Message::KeepAlive);
+                                                                       },
+                                                                       Message::Update(mut upd) => {
+                                                                               upd.normalize();
+                                                                               let mut route_table = client.routes.lock().unwrap();
+                                                                               for r in upd.withdrawn_routes {
+                                                                                       route_table.withdraw(r);
+                                                                               }
+                                                                               if let Some(path) = Self::map_attrs(upd.attributes) {
+                                                                                       for r in upd.announced_routes {
+                                                                                               route_table.announce(r, path.clone());
+                                                                                       }
+                                                                               }
+                                                                               printer.set_stat(Stat::V4RoutingTableSize(route_table.v4_table.len()));
+                                                                               printer.set_stat(Stat::V6RoutingTableSize(route_table.v6_table.len()));
+                                                                               printer.set_stat(Stat::RoutingTablePaths(route_table.max_paths));
+                                                                       },
+                                                                       _ => {}
+                                                               }
+                                                       } else { break; }
+                                               }
+                                       }
+                               };
+               }
+       }
+
        fn connect_given_client(remote_asn: u32, addr: SocketAddr, timeout: Duration, printer: &'static Printer, client: Arc<BGPClient>) {
-               tokio::spawn(Delay::new(Instant::now() + timeout / 4).then(move |_| {
-                       let connect_timeout = Delay::new(Instant::now() + timeout.clone()).then(|_| {
-                               future::err(std::io::Error::new(std::io::ErrorKind::TimedOut, "timeout reached"))
-                       });
+               tokio::spawn(async move {
+                       time::sleep(timeout / 4).await;
+
                        let client_reconn = Arc::clone(&client);
-                       TcpStream::connect(&addr).select(connect_timeout)
-                               .or_else(move |_| {
-                                       Delay::new(Instant::now() + timeout / 2).then(|_| {
-                                               future::err(())
-                                       })
-                               }).and_then(move |stream| {
-                                       let (write, read) = Framed::new(stream.0, MsgCoder(None)).split();
-                                       let (mut sender, receiver) = mpsc::channel(10); // We never really should send more than 10 messages unless they're dumb
-                                       tokio::spawn(write.sink_map_err(|_| { () }).send_all(receiver)
-                                               .then(|_| {
-                                                       future::err(())
-                                               }));
+                       tokio::select! {
+                               _ = time::sleep(timeout) => {
+                                       time::sleep(timeout / 2).await;
+                               },
+                               mut stream = TcpStream::connect(&addr) => {
                                        let peer_asn = if remote_asn > u16::max_value() as u32 { 23456 } else { remote_asn as u16 };
-                                       let _ = sender.try_send(Message::Open(Open {
+                                       let open_msg = Message::Open(Open {
                                                version: 4,
                                                peer_asn,
                                                hold_timer: timeout.as_secs() as u16,
@@ -413,50 +481,15 @@ impl BGPClient {
                                                                (AFI::IPV4, SAFI::Unicast, AddPathDirection::ReceivePaths),
                                                                (AFI::IPV6, SAFI::Unicast, AddPathDirection::ReceivePaths)]),
                                                ])],
-                                       }));
-                                       TimeoutStream::new_persistent(read, timeout).for_each(move |bgp_msg| {
-                                               if client.shutdown.load(Ordering::Relaxed) {
-                                                       return future::err(std::io::Error::new(std::io::ErrorKind::Other, "Shutting Down"));
-                                               }
-                                               match bgp_msg {
-                                                       Message::Open(_) => {
-                                                               client.routes.lock().unwrap().v4_table.clear();
-                                                               client.routes.lock().unwrap().v6_table.clear();
-                                                               printer.add_line("Connected to BGP route provider".to_string(), false);
-                                                       },
-                                                       Message::KeepAlive => {
-                                                               let _ = sender.try_send(Message::KeepAlive);
-                                                       },
-                                                       Message::Update(mut upd) => {
-                                                               upd.normalize();
-                                                               let mut route_table = client.routes.lock().unwrap();
-                                                               for r in upd.withdrawn_routes {
-                                                                       route_table.withdraw(r);
-                                                               }
-                                                               if let Some(path) = Self::map_attrs(upd.attributes) {
-                                                                       for r in upd.announced_routes {
-                                                                               route_table.announce(r, path.clone());
-                                                                       }
-                                                               }
-                                                               printer.set_stat(Stat::V4RoutingTableSize(route_table.v4_table.len()));
-                                                               printer.set_stat(Stat::V6RoutingTableSize(route_table.v6_table.len()));
-                                                               printer.set_stat(Stat::RoutingTablePaths(route_table.max_paths));
-                                                       },
-                                                       _ => {}
-                                               }
-                                               future::ok(())
-                                       }).or_else(move |e| {
-                                               printer.add_line(format!("Got error from BGP stream: {:?}", e), true);
-                                               future::ok(())
-                                       })
-                               }).then(move |_| {
-                                       if !client_reconn.shutdown.load(Ordering::Relaxed) {
-                                               BGPClient::connect_given_client(remote_asn, addr, timeout, printer, client_reconn);
-                                       }
-                                       future::ok(())
-                               })
-                       })
-               );
+                                       });
+                                       let e = Self::handle_peer(open_msg, stream, printer, client);
+                                       printer.add_line(format!("Got error from BGP stream: {:?}", e), true);
+                               }
+                       };
+                       if !client_reconn.shutdown.load(Ordering::Relaxed) {
+                               BGPClient::connect_given_client(remote_asn, addr, timeout, printer, client_reconn);
+                       }
+               });
        }
 
        pub fn new(remote_asn: u32, addr: SocketAddr, timeout: Duration, printer: &'static Printer) -> Arc<BGPClient> {
index d887934d435e847af43d34a9c60c8d50f9d5e881..dfb7aa819ba43d12b3018327fccbb25bd36d30dd 100644 (file)
@@ -3,7 +3,7 @@ mod printer;
 mod reader;
 mod peer;
 mod bgp_client;
-mod timeout_stream;
+//mod timeout_stream;
 mod datastore;
 
 use std::env;
@@ -24,7 +24,7 @@ use bitcoin::network::message_blockdata::{GetHeadersMessage, Inventory};
 use printer::{Printer, Stat};
 use peer::Peer;
 use datastore::{AddressState, Store, U64Setting, RegexSetting};
-use timeout_stream::TimeoutStream;
+//use timeout_stream::TimeoutStream;
 use rand::Rng;
 use bgp_client::BGPClient;