From a78ab1934f872c6c914d83d6cf3193c142cf1a49 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Thu, 5 Aug 2021 17:00:21 +0000 Subject: [PATCH] WIP tokio 1 conversion --- Cargo.toml | 2 +- src/bgp_client.rs | 175 +++++++++++++++++++++++++++------------------- src/main.rs | 4 +- 3 files changed, 107 insertions(+), 74 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ab04e15..a8ffcbd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ edition = "2018" [dependencies] bitcoin = "0.26" bgp-rs = "0.6" -tokio = "0.1" +tokio = { version = "1", features = ["full"] } bytes = "0.4" futures = "0.1" rand = "0.8" diff --git a/src/bgp_client.rs b/src/bgp_client.rs index 78c2f88..17b6598 100644 --- a/src/bgp_client.rs +++ b/src/bgp_client.rs @@ -11,16 +11,10 @@ use bgp_rs::Segment; use bgp_rs::Message; use bgp_rs::Reader; -use tokio::prelude::*; -use tokio::codec; -use tokio::codec::Framed; use tokio::net::TcpStream; -use tokio::timer::Delay; - -use futures::sync::mpsc; +use tokio::time; use crate::printer::{Printer, Stat}; -use crate::timeout_stream::TimeoutStream; const PATH_SUFFIX_LEN: usize = 3; #[derive(Clone)] @@ -202,7 +196,7 @@ impl RoutingTable { } } -struct BytesCoder<'a>(&'a mut bytes::BytesMut); +/*struct BytesCoder<'a>(&'a mut bytes::BytesMut); impl<'a> std::io::Write for BytesCoder<'a> { fn write(&mut self, b: &[u8]) -> Result { self.0.extend_from_slice(&b); @@ -244,7 +238,6 @@ impl codec::Decoder for MsgCoder { Ok((_header, msg)) => { decoder.buf.advance(decoder.pos); if let Message::Open(ref o) = &msg { - self.0 = Some(Capabilities::from_parameters(o.parameters.clone())); } Ok(Some(msg)) }, @@ -263,7 +256,7 @@ impl codec::Encoder for MsgCoder { msg.encode(&mut BytesCoder(res))?; Ok(()) } -} +}*/ pub struct BGPClient { routes: Mutex, @@ -380,26 +373,101 @@ impl BGPClient { } else { None } } + async fn handle_peer(open_msg: Message, stream: TcpStream, timeout: Duration, printer: &'static Printer, client: Arc) -> Result<(), std::io::Error> { + let mut open_bytes = [0; 64]; + let len = { + let mut write_cursor = std::io::Cursor::new(&mut open_bytes); + open_msg.encode(&mut write_cursor); + write_cursor.position() + }; + stream.write_all(&open_bytes[..len]).await?; + let mut cap = Default::default(); + + let mut readpending = Vec::new(); + let mut readbuf = [0; 8192]; + let mut msg_timeout = time::sleep(timeout); + 'read_loop: loop { + if client.shutdown.load(Ordering::Relaxed) { + return std::io::Error::new(std::io::ErrorKind::Other, "Shutting Down"); + } + tokio::select! { + _ = msg_timeout => { + return Err(std::io::Error::new(std::io::ErrorKind::TimedOut, "Keepalive expired")); + }, + res = stream.read(&mut readbuf) => { + let mut msg_opt = None; + let bytecnt = res?; + if readpending.is_empty() { + let mut cursor = std::io::Cursor::new(&readbuf[..bytecnt]); + let mut reader = Reader { stream: &mut cursor, capabilities: &cap }; + match reader.read() { + Ok((_header, newmsg)) => { readpending.append(&readbuf[cursor.position()..bytecnt]); newmsg = Some(msg_opt) }, + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + readpending.append(&readbuf[..bytecnt]); + continue 'read_loop; + }, + Err(e) => return Err(e), + } + } else { readpending.append(&readbuf[..bytecnt]); } + loop { + if msg_opt.is_none() { + let mut cursor = std::io::Cursor::new(&readpending); + let mut reader = Reader { stream: &mut cursor, capabilities: &cap }; + match reader.read() { + Ok((_header, newmsg)) => { newmsg = Some(msg_opt) }, + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {}, + Err(e) => return Err(e), + } + readpending = readpending.split_off(cursor.position()); + } + if let Some(bgp_msg) = msg_opt.take() { + match bgp_msg { + Message::Open(o) => { + cap = Capabilities::from_parameters(o.parameters.clone()); + client.routes.lock().unwrap().v4_table.clear(); + client.routes.lock().unwrap().v6_table.clear(); + printer.add_line("Connected to BGP route provider".to_string(), false); + }, + Message::KeepAlive => { + msg_timeout = time::sleep(timeout); + //XXX: let _ = sender.try_send(Message::KeepAlive); + }, + Message::Update(mut upd) => { + upd.normalize(); + let mut route_table = client.routes.lock().unwrap(); + for r in upd.withdrawn_routes { + route_table.withdraw(r); + } + if let Some(path) = Self::map_attrs(upd.attributes) { + for r in upd.announced_routes { + route_table.announce(r, path.clone()); + } + } + printer.set_stat(Stat::V4RoutingTableSize(route_table.v4_table.len())); + printer.set_stat(Stat::V6RoutingTableSize(route_table.v6_table.len())); + printer.set_stat(Stat::RoutingTablePaths(route_table.max_paths)); + }, + _ => {} + } + } else { break; } + } + } + }; + } + } + fn connect_given_client(remote_asn: u32, addr: SocketAddr, timeout: Duration, printer: &'static Printer, client: Arc) { - tokio::spawn(Delay::new(Instant::now() + timeout / 4).then(move |_| { - let connect_timeout = Delay::new(Instant::now() + timeout.clone()).then(|_| { - future::err(std::io::Error::new(std::io::ErrorKind::TimedOut, "timeout reached")) - }); + tokio::spawn(async move { + time::sleep(timeout / 4).await; + let client_reconn = Arc::clone(&client); - TcpStream::connect(&addr).select(connect_timeout) - .or_else(move |_| { - Delay::new(Instant::now() + timeout / 2).then(|_| { - future::err(()) - }) - }).and_then(move |stream| { - let (write, read) = Framed::new(stream.0, MsgCoder(None)).split(); - let (mut sender, receiver) = mpsc::channel(10); // We never really should send more than 10 messages unless they're dumb - tokio::spawn(write.sink_map_err(|_| { () }).send_all(receiver) - .then(|_| { - future::err(()) - })); + tokio::select! { + _ = time::sleep(timeout) => { + time::sleep(timeout / 2).await; + }, + mut stream = TcpStream::connect(&addr) => { let peer_asn = if remote_asn > u16::max_value() as u32 { 23456 } else { remote_asn as u16 }; - let _ = sender.try_send(Message::Open(Open { + let open_msg = Message::Open(Open { version: 4, peer_asn, hold_timer: timeout.as_secs() as u16, @@ -413,50 +481,15 @@ impl BGPClient { (AFI::IPV4, SAFI::Unicast, AddPathDirection::ReceivePaths), (AFI::IPV6, SAFI::Unicast, AddPathDirection::ReceivePaths)]), ])], - })); - TimeoutStream::new_persistent(read, timeout).for_each(move |bgp_msg| { - if client.shutdown.load(Ordering::Relaxed) { - return future::err(std::io::Error::new(std::io::ErrorKind::Other, "Shutting Down")); - } - match bgp_msg { - Message::Open(_) => { - client.routes.lock().unwrap().v4_table.clear(); - client.routes.lock().unwrap().v6_table.clear(); - printer.add_line("Connected to BGP route provider".to_string(), false); - }, - Message::KeepAlive => { - let _ = sender.try_send(Message::KeepAlive); - }, - Message::Update(mut upd) => { - upd.normalize(); - let mut route_table = client.routes.lock().unwrap(); - for r in upd.withdrawn_routes { - route_table.withdraw(r); - } - if let Some(path) = Self::map_attrs(upd.attributes) { - for r in upd.announced_routes { - route_table.announce(r, path.clone()); - } - } - printer.set_stat(Stat::V4RoutingTableSize(route_table.v4_table.len())); - printer.set_stat(Stat::V6RoutingTableSize(route_table.v6_table.len())); - printer.set_stat(Stat::RoutingTablePaths(route_table.max_paths)); - }, - _ => {} - } - future::ok(()) - }).or_else(move |e| { - printer.add_line(format!("Got error from BGP stream: {:?}", e), true); - future::ok(()) - }) - }).then(move |_| { - if !client_reconn.shutdown.load(Ordering::Relaxed) { - BGPClient::connect_given_client(remote_asn, addr, timeout, printer, client_reconn); - } - future::ok(()) - }) - }) - ); + }); + let e = Self::handle_peer(open_msg, stream, printer, client); + printer.add_line(format!("Got error from BGP stream: {:?}", e), true); + } + }; + if !client_reconn.shutdown.load(Ordering::Relaxed) { + BGPClient::connect_given_client(remote_asn, addr, timeout, printer, client_reconn); + } + }); } pub fn new(remote_asn: u32, addr: SocketAddr, timeout: Duration, printer: &'static Printer) -> Arc { diff --git a/src/main.rs b/src/main.rs index d887934..dfb7aa8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,7 @@ mod printer; mod reader; mod peer; mod bgp_client; -mod timeout_stream; +//mod timeout_stream; mod datastore; use std::env; @@ -24,7 +24,7 @@ use bitcoin::network::message_blockdata::{GetHeadersMessage, Inventory}; use printer::{Printer, Stat}; use peer::Peer; use datastore::{AddressState, Store, U64Setting, RegexSetting}; -use timeout_stream::TimeoutStream; +//use timeout_stream::TimeoutStream; use rand::Rng; use bgp_client::BGPClient; -- 2.30.2