+/// Writer that only tracks the amount of data written - useful if you need to calculate the length
+/// of some data when serialized but don't yet need the full data.
+pub(crate) struct LengthCalculatingWriter(pub usize);
+impl Writer for LengthCalculatingWriter {
+ #[inline]
+ fn write_all(&mut self, buf: &[u8]) -> Result<(), io::Error> {
+ self.0 += buf.len();
+ Ok(())
+ }
+ #[inline]
+ fn size_hint(&mut self, _size: usize) {}
+}
+
+/// Essentially std::io::Take but a bit simpler and with a method to walk the underlying stream
+/// forward to ensure we always consume exactly the fixed length specified.
+pub(crate) struct FixedLengthReader<R: Read> {
+ read: R,
+ bytes_read: u64,
+ total_bytes: u64,
+}
+impl<R: Read> FixedLengthReader<R> {
+ pub fn new(read: R, total_bytes: u64) -> Self {
+ Self { read, bytes_read: 0, total_bytes }
+ }
+
+ #[inline]
+ pub fn bytes_remain(&mut self) -> bool {
+ self.bytes_read != self.total_bytes
+ }
+
+ #[inline]
+ pub fn eat_remaining(&mut self) -> Result<(), DecodeError> {
+ copy(self, &mut sink()).unwrap();
+ if self.bytes_read != self.total_bytes {
+ Err(DecodeError::ShortRead)
+ } else {
+ Ok(())
+ }
+ }
+}
+impl<R: Read> Read for FixedLengthReader<R> {
+ #[inline]
+ fn read(&mut self, dest: &mut [u8]) -> Result<usize, io::Error> {
+ if self.total_bytes == self.bytes_read {
+ Ok(0)
+ } else {
+ let read_len = cmp::min(dest.len() as u64, self.total_bytes - self.bytes_read);
+ match self.read.read(&mut dest[0..(read_len as usize)]) {
+ Ok(v) => {
+ self.bytes_read += v as u64;
+ Ok(v)
+ },
+ Err(e) => Err(e),
+ }
+ }
+ }
+}
+
+/// A Read which tracks whether any bytes have been read at all. This allows us to distinguish
+/// between "EOF reached before we started" and "EOF reached mid-read".
+pub(crate) struct ReadTrackingReader<R: Read> {
+ read: R,
+ pub have_read: bool,
+}
+impl<R: Read> ReadTrackingReader<R> {
+ pub fn new(read: R) -> Self {
+ Self { read, have_read: false }
+ }
+}
+impl<R: Read> Read for ReadTrackingReader<R> {
+ #[inline]
+ fn read(&mut self, dest: &mut [u8]) -> Result<usize, io::Error> {
+ match self.read.read(dest) {
+ Ok(0) => Ok(0),
+ Ok(len) => {
+ self.have_read = true;
+ Ok(len)
+ },
+ Err(e) => Err(e),
+ }
+ }
+}
+