diff --git a/examples/http_get.rs b/examples/http_get.rs index a54274c..e9c7036 100644 --- a/examples/http_get.rs +++ b/examples/http_get.rs @@ -17,10 +17,8 @@ async fn main() -> Result<(), Box> { .ok_or_else(|| "response expected to have Content-Type header")?; assert_eq!(content_type, "application/json; charset=utf-8"); - // Would much prefer read_to_end here: - let mut body_buf = vec![0; 4096]; - let body_len = response.body().read(&mut body_buf).await?; - body_buf.truncate(body_len); + let mut body_buf = Vec::new(); + let _body_len = response.body().read_to_end(&mut body_buf).await?; let val: serde_json::Value = serde_json::from_slice(&body_buf)?; let body_url = val diff --git a/src/http/response.rs b/src/http/response.rs index dec56b1..c8693dc 100644 --- a/src/http/response.rs +++ b/src/http/response.rs @@ -101,8 +101,8 @@ pub struct IncomingBody { // How many bytes have we already read from the buf? buf_offset: usize, - // IMPORTANT: the order of these fields here matters. `incoming_body` must - // be dropped before `body_stream`. + // IMPORTANT: the order of these fields here matters. `body_stream` must + // be dropped before `_incoming_body`. body_stream: InputStream, _incoming_body: WasiIncomingBody, } @@ -117,12 +117,16 @@ impl AsyncRead for IncomingBody { Reactor::current().wait_for(pollable).await; // Read the bytes from the body stream - let buf = self.body_stream.read(CHUNK_SIZE).map_err(|err| match err { - StreamError::LastOperationFailed(err) => { - std::io::Error::other(format!("{}", err.to_debug_string())) + let buf = match self.body_stream.read(CHUNK_SIZE) { + Ok(buf) => buf, + Err(StreamError::Closed) => return Ok(0), + Err(StreamError::LastOperationFailed(err)) => { + return Err(std::io::Error::other(format!( + "last operation failed: {}", + err.to_debug_string() + ))) } - StreamError::Closed => std::io::Error::other("Connection closed"), - })?; + }; self.buf.insert(buf) } }; diff --git a/src/io/copy.rs b/src/io/copy.rs index 7212524..6b896c7 100644 --- a/src/io/copy.rs +++ b/src/io/copy.rs @@ -12,14 +12,6 @@ where if bytes_read == 0 { break 'read Ok(()); } - let mut slice = &buf[0..bytes_read]; - - 'write: loop { - let bytes_written = writer.write(slice).await?; - slice = &slice[bytes_written..]; - if slice.is_empty() { - break 'write; - } - } + writer.write_all(&buf[0..bytes_read]).await?; } } diff --git a/src/io/read.rs b/src/io/read.rs index be54fcb..82ae5a0 100644 --- a/src/io/read.rs +++ b/src/io/read.rs @@ -1,6 +1,27 @@ use crate::io; +const CHUNK_SIZE: usize = 2048; + /// Read bytes from a source. pub trait AsyncRead { async fn read(&mut self, buf: &mut [u8]) -> io::Result; + async fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { + // total bytes written to buf + let mut n = 0; + + loop { + // grow buf if empty + if buf.len() == n { + buf.resize(n + CHUNK_SIZE, 0u8); + } + + let len = self.read(&mut buf[n..]).await?; + if len == 0 { + buf.truncate(n); + return Ok(n); + } + + n += len; + } + } } diff --git a/src/io/write.rs b/src/io/write.rs index 493c79b..b775ca7 100644 --- a/src/io/write.rs +++ b/src/io/write.rs @@ -5,4 +5,15 @@ pub trait AsyncWrite { // Required methods async fn write(&mut self, buf: &[u8]) -> io::Result; async fn flush(&mut self) -> io::Result<()>; + + async fn write_all(&mut self, buf: &[u8]) -> io::Result<()> { + let mut to_write = &buf[0..]; + loop { + let bytes_written = self.write(to_write).await?; + to_write = &to_write[bytes_written..]; + if to_write.is_empty() { + return Ok(()); + } + } + } } diff --git a/src/net/tcp_stream.rs b/src/net/tcp_stream.rs index 8293369..5c6ab8a 100644 --- a/src/net/tcp_stream.rs +++ b/src/net/tcp_stream.rs @@ -31,7 +31,11 @@ impl TcpStream { impl AsyncRead for TcpStream { async fn read(&mut self, buf: &mut [u8]) -> io::Result { Reactor::current().wait_for(self.input.subscribe()).await; - let slice = self.input.read(buf.len() as u64).map_err(to_io_err)?; + let slice = match self.input.read(buf.len() as u64) { + Ok(slice) => slice, + Err(StreamError::Closed) => return Ok(0), + Err(e) => return Err(to_io_err(e)), + }; let bytes_read = slice.len(); buf[..bytes_read].clone_from_slice(&slice); Ok(bytes_read) @@ -41,7 +45,11 @@ impl AsyncRead for TcpStream { impl AsyncRead for &TcpStream { async fn read(&mut self, buf: &mut [u8]) -> io::Result { Reactor::current().wait_for(self.input.subscribe()).await; - let slice = self.input.read(buf.len() as u64).map_err(to_io_err)?; + let slice = match self.input.read(buf.len() as u64) { + Ok(slice) => slice, + Err(StreamError::Closed) => return Ok(0), + Err(e) => return Err(to_io_err(e)), + }; let bytes_read = slice.len(); buf[..bytes_read].clone_from_slice(&slice); Ok(bytes_read)