diff --git a/src/dist/download.rs b/src/dist/download.rs index e10f027854..7e63d34439 100644 --- a/src/dist/download.rs +++ b/src/dist/download.rs @@ -129,6 +129,11 @@ impl<'a> DownloadCfg<'a> { } } + #[allow(dead_code)] + pub(crate) async fn content_length(&self, url: &Url) -> Result> { + crate::download::content_length(url, self.process).await + } + pub(crate) fn clean(&self, hashes: &[impl AsRef]) -> Result<()> { for hash in hashes.iter() { let used_file = self.download_dir.join(hash); diff --git a/src/dist/manifestation.rs b/src/dist/manifestation.rs index 1995572ec8..41f3d5d50f 100644 --- a/src/dist/manifestation.rs +++ b/src/dist/manifestation.rs @@ -240,6 +240,50 @@ impl Manifestation { ); }; + // TODO: enable this for windows + // disk_free() is only available for cfg(unix) for now + #[cfg(unix)] + { + let estimated_download_size = futures_util::future::join_all( + components + .iter() + .map(|component| component.required_size_via_head()), + ) + .await + .iter() + .try_fold(0u64, |acc, elem| { + match elem { + Ok(Some(size)) => Some(acc + size), + // This component size's unknown (None or Err) + _ => { + warn!("failed to fetch component size, continueing..."); + None + } + } + }); + + let path = &components + .first() + .unwrap() + .download_cfg + .tmp_cx + .root_directory; + let disk_free = match utils::disk_free(path) { + Ok(size) => Some(size), + Err(e) => { + warn!("failed to acquire disk free space: {e}\ncontinueing... "); + None + } + }; + + if let (Some(df), Some(est)) = (disk_free, estimated_download_size) + && df < est + && !force_update + { + bail!("insufficient storage: {est} bytes required, {df} bytes available") + } + }; + let mut stream = InstallEvents::new(components.into_iter(), Arc::new(self)); let mut transaction = Some(tx); tx = loop { @@ -767,6 +811,12 @@ impl<'a> ComponentBinary<'a> { })) } + #[allow(dead_code)] + async fn required_size_via_head(&self) -> Result> { + let url = self.download_cfg.url(&self.binary.url)?; + self.download_cfg.content_length(&url).await + } + async fn download(self, max_retries: usize) -> Result<(ComponentInstall, &'a str)> { use tokio_retry::{RetryIf, strategy::FixedInterval}; diff --git a/src/dist/temp.rs b/src/dist/temp.rs index 0eacbb44f6..7e232e03b7 100644 --- a/src/dist/temp.rs +++ b/src/dist/temp.rs @@ -73,7 +73,7 @@ impl Drop for File { } pub struct Context { - root_directory: PathBuf, + pub root_directory: PathBuf, pub dist_server: String, } diff --git a/src/download/mod.rs b/src/download/mod.rs index 31f7c83624..66f90b3553 100644 --- a/src/download/mod.rs +++ b/src/download/mod.rs @@ -84,50 +84,7 @@ pub(crate) fn is_network_failure(err: &anyhow::Error) -> bool { } } -async fn download_file_( - url: &Url, - path: &Path, - hasher: Option<&mut Sha256>, - resume_from_partial: bool, - status: Option<&DownloadStatus>, - process: &Process, -) -> anyhow::Result<()> { - #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] - use crate::download::{Backend, Event, TlsBackend}; - use sha2::Digest; - use std::cell::RefCell; - - debug!(url = %url, "downloading file"); - let hasher = RefCell::new(hasher); - - // This callback will write the download to disk and optionally - // hash the contents, then forward the notification up the stack - let callback: &dyn Fn(Event<'_>) -> anyhow::Result<()> = &|msg| { - if let Event::DownloadDataReceived(data) = msg - && let Some(h) = hasher.borrow_mut().as_mut() - { - h.update(data); - } - - match msg { - Event::DownloadContentLengthReceived(len) => { - if let Some(status) = status { - status.received_length(len) - } - } - Event::DownloadDataReceived(data) => { - if let Some(status) = status { - status.received_data(data.len()) - } - } - Event::ResumingPartialDownload => debug!("resuming partial download"), - } - - Ok(()) - }; - - // Download the file - +fn select_backend(process: &Process) -> anyhow::Result { // Keep the curl env var around for a bit let use_curl_backend = process.var_os("RUSTUP_USE_CURL").map(|it| it != "0"); if use_curl_backend == Some(true) { @@ -199,8 +156,12 @@ async fn download_file_( _ => Backend::Curl, }; + Ok(backend) +} + +fn timeout(process: &Process) -> anyhow::Result { let timeout = Duration::from_secs(match process.var("RUSTUP_DOWNLOAD_TIMEOUT") { - Ok(s) => NonZero::from_str(&s) + Ok(s) => NonZero::from_str(&s) .context( "invalid value in RUSTUP_DOWNLOAD_TIMEOUT -- must be a natural number greater than zero", )? @@ -208,6 +169,55 @@ async fn download_file_( Err(_) => 180, }); + Ok(timeout) +} + +async fn download_file_( + url: &Url, + path: &Path, + hasher: Option<&mut Sha256>, + resume_from_partial: bool, + status: Option<&DownloadStatus>, + process: &Process, +) -> anyhow::Result<()> { + #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] + use crate::download::{Backend, Event}; + use sha2::Digest; + use std::cell::RefCell; + + debug!(url = %url, "downloading file"); + let hasher = RefCell::new(hasher); + + // This callback will write the download to disk and optionally + // hash the contents, then forward the notification up the stack + let callback: &dyn Fn(Event<'_>) -> anyhow::Result<()> = &|msg| { + if let Event::DownloadDataReceived(data) = msg + && let Some(h) = hasher.borrow_mut().as_mut() + { + h.update(data); + } + + match msg { + Event::DownloadContentLengthReceived(len) => { + if let Some(status) = status { + status.received_length(len) + } + } + Event::DownloadDataReceived(data) => { + if let Some(status) = status { + status.received_data(data.len()) + } + } + Event::ResumingPartialDownload => debug!("resuming partial download"), + } + + Ok(()) + }; + + // Download the file + let backend = select_backend(process)?; + let timeout = timeout(process)?; + match backend { #[cfg(feature = "curl-backend")] Backend::Curl => debug!("downloading with curl"), @@ -230,6 +240,28 @@ async fn download_file_( res } +#[allow(dead_code)] +pub(crate) async fn content_length(url: &Url, process: &Process) -> anyhow::Result> { + if url.scheme() == "file" { + let path = url + .to_file_path() + .map_err(|_| anyhow::anyhow!("bogus file url: '{url}'"))?; + return Ok(Some(std::fs::metadata(path)?.len())); + } + + let backend = select_backend(process)?; + let timeout = timeout(process)?; + + match backend { + #[cfg(feature = "curl-backend")] + Backend::Curl => debug!(url = %url, "fetching content-length with curl"), + #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] + Backend::Reqwest(_) => debug!(url = %url, "fetching content-length with reqwest"), + }; + + backend.content_length(url, timeout).await +} + /// User agent header value for HTTP request. /// See: https://github.com/rust-lang/rustup/issues/2860. #[cfg(feature = "curl-backend")] @@ -392,6 +424,16 @@ impl Backend { Self::Reqwest(tls) => tls.download(url, resume_from, callback, timeout).await, } } + + #[allow(dead_code)] + async fn content_length(self, url: &Url, timeout: Duration) -> anyhow::Result> { + match self { + #[cfg(feature = "curl-backend")] + Self::Curl => curl::content_length(url, timeout), + #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] + Self::Reqwest(tls) => tls.content_length(url, timeout).await, + } + } } #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] @@ -421,6 +463,18 @@ impl TlsBackend { reqwest_be::download(url, resume_from, callback, client).await } + + #[allow(dead_code)] + async fn content_length(self, url: &Url, timeout: Duration) -> anyhow::Result> { + let client = match self { + #[cfg(feature = "reqwest-rustls-tls")] + Self::Rustls => reqwest_be::rustls_client(timeout)?, + #[cfg(feature = "reqwest-native-tls")] + Self::NativeTls => reqwest_be::native_tls_client(timeout)?, + }; + + reqwest_be::content_length(url, client).await + } } #[derive(Debug, Copy, Clone)] @@ -448,6 +502,47 @@ mod curl { use super::{DownloadError, Event}; + #[allow(dead_code)] + pub(super) fn content_length(url: &Url, timeout: Duration) -> Result> { + let mut handle = Easy::new(); + handle.url(url.as_ref())?; + handle.follow_location(true)?; + handle.useragent(super::CURL_USER_AGENT)?; + handle.nobody(true)?; + handle.connect_timeout(timeout)?; + + let length = std::cell::Cell::new(None); + { + let mut transfer = handle.transfer(); + transfer.header_function(|header| { + let Ok(data) = str::from_utf8(header) else { + return true; + }; + let prefix = "content-length: "; + let Some((dp, ds)) = data.split_at_checked(prefix.len()) else { + return true; + }; + if !dp.eq_ignore_ascii_case(prefix) { + return true; + } + if let Ok(s) = ds.trim().parse::() { + length.set(Some(s)); + } + true + })?; + + transfer.perform()?; + } + + let code = handle.response_code()?; + match code { + 0 | 200..=299 => Ok(length.get()), + // Some servers do not support HEAD for file assets + 405 => Ok(None), + _ => Err(DownloadError::HttpStatus(code).into()), + } + } + pub(super) fn download( url: &Url, resume_from: u64, @@ -610,6 +705,23 @@ mod reqwest_be { Ok(()) } + #[allow(dead_code)] + pub(super) async fn content_length(url: &Url, client: &Client) -> anyhow::Result> { + let res = client + .head(url.as_str()) + .send() + .await + .context("error fetching content length")?; + + let status = res.status().into(); + match status { + 200..=299 => Ok(res.content_length()), + // Some servers do not support HEAD for file assets + 405 => Ok(None), + _ => Err(DownloadError::HttpStatus(u32::from(status)).into()), + } + } + fn client_generic() -> ClientBuilder { Client::builder() // HACK: set `pool_max_idle_per_host` to `0` to avoid an issue in the underlying diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 649cfb9758..179b55054b 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -516,6 +516,30 @@ pub(crate) fn home_dir_from_passwd() -> Option { } } +#[cfg(unix)] +pub(crate) fn disk_free(path: impl AsRef) -> Result { + use libc::statvfs; + use std::mem::MaybeUninit; + use std::os::unix::ffi::OsStrExt; + + let mut os_path = path.as_ref().as_bytes().to_vec(); + os_path.push(0); + + let mut stat = MaybeUninit::::uninit(); + match unsafe { statvfs(os_path.as_ptr() as *const _, stat.as_mut_ptr()) } { + // bit width of f_bavail and f_bsize may differ on platforms and sometimes u32 + #[allow(clippy::useless_conversion)] + 0 => { + let stat = unsafe { stat.assume_init() }; + let available_blocks: u64 = stat.f_bavail.into(); + let block_size: u64 = stat.f_bsize.into(); + + Ok(available_blocks.saturating_mul(block_size)) + } + _ => anyhow::bail!("failed to acquire block size"), + } +} + #[cfg(test)] mod tests { use super::*;