Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions monai/apps/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,17 +220,16 @@ def download_url(
HTTPError: See urllib.request.urlretrieve.
ContentTooShortError: See urllib.request.urlretrieve.
IOError: See urllib.request.urlretrieve.
RuntimeError: When the hash validation of the ``url`` downloaded file fails.

ValueError: When the hash validation of the ``url`` downloaded file fails.
"""
if not filepath:
filepath = Path(".", _basename(url)).resolve()
logger.info(f"Default downloading to '{filepath}'")
filepath = Path(filepath)
if filepath.exists():
if not check_hash(filepath, hash_val, hash_type):
raise RuntimeError(
f"{hash_type} check of existing file failed: filepath={filepath}, expected {hash_type}={hash_val}."
raise ValueError(
f"{hash_type} hash check of existing file failed: filepath={filepath}, expected {hash_type}={hash_val}."
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
logger.info(f"File exists: {filepath}, skipped downloading.")
return
Expand Down Expand Up @@ -260,18 +259,20 @@ def download_url(
raise RuntimeError(
f"Download of file from {url} to {filepath} failed due to network issue or denied permission."
)
if not check_hash(tmp_name, hash_val, hash_type):
raise ValueError(
f"{hash_type} hash check of downloaded file failed: URL={url}, "
f"filepath={filepath}, expected {hash_type}={hash_val}, "
f"The file may be corrupted or tampered with. "
"Please retry the download or verify the source."
)
file_dir = filepath.parent
if file_dir:
os.makedirs(file_dir, exist_ok=True)
shutil.move(f"{tmp_name}", f"{filepath}") # copy the downloaded to a user-specified cache.
except (PermissionError, NotADirectoryError): # project-monai/monai issue #3613 #3757 for windows
pass
logger.info(f"Downloaded: {filepath}")
if not check_hash(filepath, hash_val, hash_type):
raise RuntimeError(
f"{hash_type} check of downloaded file failed: URL={url}, "
f"filepath={filepath}, expected {hash_type}={hash_val}."
)


def _extract_zip(filepath, output_dir):
Expand Down Expand Up @@ -325,10 +326,16 @@ def extractall(
be False.

Raises:
RuntimeError: When the hash validation of the ``filepath`` compressed file fails.
ValueError: When the hash validation of the ``filepath`` compressed file fails.
NotImplementedError: When the ``filepath`` file extension is not one of [zip", "tar.gz", "tar"].

"""
filepath = Path(filepath)
if hash_val and not check_hash(filepath, hash_val, hash_type):
raise ValueError(
f"{hash_type} hash check of compressed file failed: "
f"filepath={filepath}, expected {hash_type}={hash_val}."
)
if has_base:
# the extracted files will be in this folder
cache_dir = Path(output_dir, _basename(filepath).split(".")[0])
Expand All @@ -337,11 +344,6 @@ def extractall(
if cache_dir.exists() and next(cache_dir.iterdir(), None) is not None:
logger.info(f"Non-empty folder exists in {cache_dir}, skipped extracting.")
return
filepath = Path(filepath)
if hash_val and not check_hash(filepath, hash_val, hash_type):
raise RuntimeError(
f"{hash_type} check of compressed file failed: " f"filepath={filepath}, expected {hash_type}={hash_val}."
)
logger.info(f"Writing into directory: {output_dir}.")
_file_type = file_type.lower().strip()
if filepath.name.endswith("zip") or _file_type == "zip":
Expand Down
78 changes: 51 additions & 27 deletions tests/apps/test_download_and_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import unittest
import zipfile
from pathlib import Path
from urllib.error import ContentTooShortError, HTTPError

from parameterized import parameterized

Expand All @@ -26,39 +25,62 @@

@SkipIfNoModule("requests")
class TestDownloadAndExtract(unittest.TestCase):
def setUp(self):
self.testing_dir = Path(__file__).parents[1] / "testing_data"
self.config = testing_data_config("images", "mednist")
self.url = self.config["url"]
self.hash_val = self.config["hash_val"]
self.hash_type = self.config["hash_type"]

@skip_if_quick
def test_actions(self):
testing_dir = Path(__file__).parents[1] / "testing_data"
config_dict = testing_data_config("images", "mednist")
url = config_dict["url"]
filepath = Path(testing_dir) / "MedNIST.tar.gz"
output_dir = Path(testing_dir)
hash_val, hash_type = config_dict["hash_val"], config_dict["hash_type"]
def test_download_and_extract_success(self):
"""End-to-end: download and extract should succeed with correct hash."""
filepath = self.testing_dir / "MedNIST.tar.gz"
output_dir = self.testing_dir

with skip_if_downloading_fails():
download_and_extract(url, filepath, output_dir, hash_val=hash_val, hash_type=hash_type)
download_and_extract(url, filepath, output_dir, hash_val=hash_val, hash_type=hash_type)
download_and_extract(self.url, filepath, output_dir, hash_val=self.hash_val, hash_type=self.hash_type)

wrong_md5 = "0"
with self.assertLogs(logger="monai.apps", level="ERROR"):
try:
download_url(url, filepath, wrong_md5)
except (ContentTooShortError, HTTPError, RuntimeError) as e:
if isinstance(e, RuntimeError):
# FIXME: skip MD5 check as current downloading method may fail
self.assertTrue(str(e).startswith("md5 check"))
return # skipping this test due the network connection errors

try:
extractall(filepath, output_dir, wrong_md5)
except RuntimeError as e:
self.assertTrue(str(e).startswith("md5 check"))
self.assertTrue(filepath.exists(), "Downloaded file does not exist")
self.assertTrue(any(output_dir.iterdir()), "Extraction output is empty")

@skip_if_quick
def test_download_url_hash_mismatch(self):
"""download_url should raise ValueError on hash mismatch."""
filepath = self.testing_dir / "MedNIST.tar.gz"

with skip_if_downloading_fails():
# First ensure file is downloaded correctly
download_url(self.url, filepath, hash_val=self.hash_val, hash_type=self.hash_type)

# Now test incorrect hash
with self.assertRaises(ValueError) as ctx:
download_url(self.url, filepath, hash_val="0" * len(self.hash_val), hash_type=self.hash_type)

self.assertIn("hash check", str(ctx.exception).lower())

@skip_if_quick
@parameterized.expand((("icon", "tar"), ("favicon", "zip")))
def test_default(self, key, file_type):
def test_extractall_hash_mismatch(self):
"""extractall should raise ValueError when hash is incorrect."""
filepath = self.testing_dir / "MedNIST.tar.gz"
output_dir = self.testing_dir

with skip_if_downloading_fails():
download_url(self.url, filepath, hash_val=self.hash_val, hash_type=self.hash_type)

with self.assertRaises(ValueError) as ctx:
extractall(filepath, output_dir, hash_val="0" * len(self.hash_val), hash_type=self.hash_type)

self.assertIn("hash check", str(ctx.exception).lower())

@skip_if_quick
@parameterized.expand([("icon", "tar"), ("favicon", "zip")])
def test_download_and_extract_various_formats(self, key, file_type):
"""Verify different archive formats download and extract correctly."""
with tempfile.TemporaryDirectory() as tmp_dir:
img_spec = testing_data_config("images", key)

with skip_if_downloading_fails():
img_spec = testing_data_config("images", key)
download_and_extract(
img_spec["url"],
output_dir=tmp_dir,
Expand All @@ -67,6 +89,8 @@ def test_default(self, key, file_type):
file_type=file_type,
)

self.assertTrue(any(Path(tmp_dir).iterdir()), f"Extraction failed for format: {file_type}")


class TestPathTraversalProtection(unittest.TestCase):
"""Test cases for path traversal attack protection in extractall function."""
Expand Down
36 changes: 5 additions & 31 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
"unexpected EOF", # incomplete download
"network issue",
"gdown dependency", # gdown not installed
"md5 check",
"hash check", # check hash value of downloaded file
"limit", # HTTP Error 503: Egress is over the account limit
"authenticate",
"timed out", # urlopen error [Errno 110] Connection timed out
Expand Down Expand Up @@ -182,37 +182,11 @@ def skip_if_downloading_fails():
raise unittest.SkipTest(f"Error while downloading: {rt_e}") from rt_e # incomplete download

raise rt_e
except ValueError as v_e:
if "hash check" in str(v_e):
raise unittest.SkipTest(f"Hash value error while downloading: {v_e}") from v_e
Comment thread
coderabbitai[bot] marked this conversation as resolved.


SAMPLE_TIFF = "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/CMU-1.tiff"
SAMPLE_TIFF_HASH = "73a7e89bc15576587c3d68e55d9bf92f09690280166240b48ff4b48230b13bcd"
SAMPLE_TIFF_HASH_TYPE = "sha256"


class TestDownloadUrl(unittest.TestCase):
"""Exercise ``download_url`` success and hash-mismatch paths."""

def test_download_url(self):
"""Download a sample TIFF and validate hash handling.

Raises:
RuntimeError: When the downloaded file's hash does not match.
"""
with tempfile.TemporaryDirectory() as tempdir:
with skip_if_downloading_fails():
download_url(
url=SAMPLE_TIFF,
filepath=os.path.join(tempdir, "model.tiff"),
hash_val=SAMPLE_TIFF_HASH,
hash_type=SAMPLE_TIFF_HASH_TYPE,
)
with self.assertRaises(RuntimeError):
download_url(
url=SAMPLE_TIFF,
filepath=os.path.join(tempdir, "model_bad.tiff"),
hash_val="0" * 64,
hash_type=SAMPLE_TIFF_HASH_TYPE,
)
raise v_e


def test_pretrained_networks(network, input_param, device):
Expand Down