diff --git a/monai/apps/utils.py b/monai/apps/utils.py index 856bc64c9e..cbeed9091c 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -220,8 +220,7 @@ def download_url( HTTPError: See urllib.request.urlretrieve. ContentTooShortError: See urllib.request.urlretrieve. IOError: See urllib.request.urlretrieve. - RuntimeError: When the hash validation of the ``url`` downloaded file fails. - + ValueError: When the hash validation of the ``url`` downloaded file fails. """ if not filepath: filepath = Path(".", _basename(url)).resolve() @@ -229,8 +228,8 @@ def download_url( filepath = Path(filepath) if filepath.exists(): if not check_hash(filepath, hash_val, hash_type): - raise RuntimeError( - f"{hash_type} check of existing file failed: filepath={filepath}, expected {hash_type}={hash_val}." + raise ValueError( + f"{hash_type} hash check of existing file failed: filepath={filepath}, expected {hash_type}={hash_val}." ) logger.info(f"File exists: {filepath}, skipped downloading.") return @@ -260,6 +259,13 @@ def download_url( raise RuntimeError( f"Download of file from {url} to {filepath} failed due to network issue or denied permission." ) + if not check_hash(tmp_name, hash_val, hash_type): + raise ValueError( + f"{hash_type} hash check of downloaded file failed: URL={url}, " + f"filepath={filepath}, expected {hash_type}={hash_val}, " + f"The file may be corrupted or tampered with. " + "Please retry the download or verify the source." + ) file_dir = filepath.parent if file_dir: os.makedirs(file_dir, exist_ok=True) @@ -267,11 +273,6 @@ def download_url( except (PermissionError, NotADirectoryError): # project-monai/monai issue #3613 #3757 for windows pass logger.info(f"Downloaded: {filepath}") - if not check_hash(filepath, hash_val, hash_type): - raise RuntimeError( - f"{hash_type} check of downloaded file failed: URL={url}, " - f"filepath={filepath}, expected {hash_type}={hash_val}." - ) def _extract_zip(filepath, output_dir): @@ -325,10 +326,16 @@ def extractall( be False. Raises: - RuntimeError: When the hash validation of the ``filepath`` compressed file fails. + ValueError: When the hash validation of the ``filepath`` compressed file fails. NotImplementedError: When the ``filepath`` file extension is not one of [zip", "tar.gz", "tar"]. """ + filepath = Path(filepath) + if hash_val and not check_hash(filepath, hash_val, hash_type): + raise ValueError( + f"{hash_type} hash check of compressed file failed: " + f"filepath={filepath}, expected {hash_type}={hash_val}." + ) if has_base: # the extracted files will be in this folder cache_dir = Path(output_dir, _basename(filepath).split(".")[0]) @@ -337,11 +344,6 @@ def extractall( if cache_dir.exists() and next(cache_dir.iterdir(), None) is not None: logger.info(f"Non-empty folder exists in {cache_dir}, skipped extracting.") return - filepath = Path(filepath) - if hash_val and not check_hash(filepath, hash_val, hash_type): - raise RuntimeError( - f"{hash_type} check of compressed file failed: " f"filepath={filepath}, expected {hash_type}={hash_val}." - ) logger.info(f"Writing into directory: {output_dir}.") _file_type = file_type.lower().strip() if filepath.name.endswith("zip") or _file_type == "zip": diff --git a/tests/apps/test_download_and_extract.py b/tests/apps/test_download_and_extract.py index 6d16a72735..a1e5381d90 100644 --- a/tests/apps/test_download_and_extract.py +++ b/tests/apps/test_download_and_extract.py @@ -16,7 +16,6 @@ import unittest import zipfile from pathlib import Path -from urllib.error import ContentTooShortError, HTTPError from parameterized import parameterized @@ -26,39 +25,62 @@ @SkipIfNoModule("requests") class TestDownloadAndExtract(unittest.TestCase): + def setUp(self): + self.testing_dir = Path(__file__).parents[1] / "testing_data" + self.config = testing_data_config("images", "mednist") + self.url = self.config["url"] + self.hash_val = self.config["hash_val"] + self.hash_type = self.config["hash_type"] + @skip_if_quick - def test_actions(self): - testing_dir = Path(__file__).parents[1] / "testing_data" - config_dict = testing_data_config("images", "mednist") - url = config_dict["url"] - filepath = Path(testing_dir) / "MedNIST.tar.gz" - output_dir = Path(testing_dir) - hash_val, hash_type = config_dict["hash_val"], config_dict["hash_type"] + def test_download_and_extract_success(self): + """End-to-end: download and extract should succeed with correct hash.""" + filepath = self.testing_dir / "MedNIST.tar.gz" + output_dir = self.testing_dir + with skip_if_downloading_fails(): - download_and_extract(url, filepath, output_dir, hash_val=hash_val, hash_type=hash_type) - download_and_extract(url, filepath, output_dir, hash_val=hash_val, hash_type=hash_type) + download_and_extract(self.url, filepath, output_dir, hash_val=self.hash_val, hash_type=self.hash_type) - wrong_md5 = "0" - with self.assertLogs(logger="monai.apps", level="ERROR"): - try: - download_url(url, filepath, wrong_md5) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors - - try: - extractall(filepath, output_dir, wrong_md5) - except RuntimeError as e: - self.assertTrue(str(e).startswith("md5 check")) + self.assertTrue(filepath.exists(), "Downloaded file does not exist") + self.assertTrue(any(output_dir.iterdir()), "Extraction output is empty") + + @skip_if_quick + def test_download_url_hash_mismatch(self): + """download_url should raise ValueError on hash mismatch.""" + filepath = self.testing_dir / "MedNIST.tar.gz" + + with skip_if_downloading_fails(): + # First ensure file is downloaded correctly + download_url(self.url, filepath, hash_val=self.hash_val, hash_type=self.hash_type) + + # Now test incorrect hash + with self.assertRaises(ValueError) as ctx: + download_url(self.url, filepath, hash_val="0" * len(self.hash_val), hash_type=self.hash_type) + + self.assertIn("hash check", str(ctx.exception).lower()) @skip_if_quick - @parameterized.expand((("icon", "tar"), ("favicon", "zip"))) - def test_default(self, key, file_type): + def test_extractall_hash_mismatch(self): + """extractall should raise ValueError when hash is incorrect.""" + filepath = self.testing_dir / "MedNIST.tar.gz" + output_dir = self.testing_dir + + with skip_if_downloading_fails(): + download_url(self.url, filepath, hash_val=self.hash_val, hash_type=self.hash_type) + + with self.assertRaises(ValueError) as ctx: + extractall(filepath, output_dir, hash_val="0" * len(self.hash_val), hash_type=self.hash_type) + + self.assertIn("hash check", str(ctx.exception).lower()) + + @skip_if_quick + @parameterized.expand([("icon", "tar"), ("favicon", "zip")]) + def test_download_and_extract_various_formats(self, key, file_type): + """Verify different archive formats download and extract correctly.""" with tempfile.TemporaryDirectory() as tmp_dir: + img_spec = testing_data_config("images", key) + with skip_if_downloading_fails(): - img_spec = testing_data_config("images", key) download_and_extract( img_spec["url"], output_dir=tmp_dir, @@ -67,6 +89,8 @@ def test_default(self, key, file_type): file_type=file_type, ) + self.assertTrue(any(Path(tmp_dir).iterdir()), f"Extraction failed for format: {file_type}") + class TestPathTraversalProtection(unittest.TestCase): """Test cases for path traversal attack protection in extractall function.""" diff --git a/tests/test_utils.py b/tests/test_utils.py index 03fa7abce3..02b6a18452 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -79,7 +79,7 @@ "unexpected EOF", # incomplete download "network issue", "gdown dependency", # gdown not installed - "md5 check", + "hash check", # check hash value of downloaded file "limit", # HTTP Error 503: Egress is over the account limit "authenticate", "timed out", # urlopen error [Errno 110] Connection timed out @@ -182,37 +182,11 @@ def skip_if_downloading_fails(): raise unittest.SkipTest(f"Error while downloading: {rt_e}") from rt_e # incomplete download raise rt_e + except ValueError as v_e: + if "hash check" in str(v_e): + raise unittest.SkipTest(f"Hash value error while downloading: {v_e}") from v_e - -SAMPLE_TIFF = "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/CMU-1.tiff" -SAMPLE_TIFF_HASH = "73a7e89bc15576587c3d68e55d9bf92f09690280166240b48ff4b48230b13bcd" -SAMPLE_TIFF_HASH_TYPE = "sha256" - - -class TestDownloadUrl(unittest.TestCase): - """Exercise ``download_url`` success and hash-mismatch paths.""" - - def test_download_url(self): - """Download a sample TIFF and validate hash handling. - - Raises: - RuntimeError: When the downloaded file's hash does not match. - """ - with tempfile.TemporaryDirectory() as tempdir: - with skip_if_downloading_fails(): - download_url( - url=SAMPLE_TIFF, - filepath=os.path.join(tempdir, "model.tiff"), - hash_val=SAMPLE_TIFF_HASH, - hash_type=SAMPLE_TIFF_HASH_TYPE, - ) - with self.assertRaises(RuntimeError): - download_url( - url=SAMPLE_TIFF, - filepath=os.path.join(tempdir, "model_bad.tiff"), - hash_val="0" * 64, - hash_type=SAMPLE_TIFF_HASH_TYPE, - ) + raise v_e def test_pretrained_networks(network, input_param, device):