diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index 3f3c8d545e..065d55089c 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -24,6 +24,7 @@ get_bundle_info, get_bundle_versions, init_bundle, + inspect_ckpt, load, onnx_export, push_to_hf_hub, diff --git a/monai/bundle/__main__.py b/monai/bundle/__main__.py index 778c9ef2f0..edce1567df 100644 --- a/monai/bundle/__main__.py +++ b/monai/bundle/__main__.py @@ -16,6 +16,7 @@ download, download_large_files, init_bundle, + inspect_ckpt, onnx_export, run, run_workflow, diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index fa9ba27096..e74bfad312 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -2013,3 +2013,76 @@ def download_large_files(bundle_path: str | None = None, large_file_name: str | lf_data["filepath"] = os.path.join(bundle_path, lf_data["path"]) lf_data.pop("path") download_url(**lf_data) + +def inspect_ckpt( + path: str, + print_all_vars: bool = True, + compute_hash: bool = False, + hash_type: str = "md5", +) -> dict: + """ + Inspect the variables and shapes saved in a checkpoint file. + Prints a human-readable summary of the tensor names, shapes, and dtypes + stored in the checkpoint, similar to TensorFlow's inspect_checkpoint tool. + Optionally also computes the hash value of the file (useful when creating + a ``large_files.yml`` for model-zoo bundles). + + Typical usage examples: + + .. code-block:: bash + + # Display all tensor names, shapes, and dtypes: + python -m monai.bundle inspect_ckpt --path model.pt + + # Suppress individual variable printing (only show file-level info): + python -m monai.bundle inspect_ckpt --path model.pt --print_all_vars false + + # Also compute md5 hash of the checkpoint file: + python -m monai.bundle inspect_ckpt --path model.pt --compute_hash true + + # Use sha256 hash instead of md5: + python -m monai.bundle inspect_ckpt --path model.pt --compute_hash true --hash_type sha256 + + Args: + path: path to the checkpoint file to inspect. + print_all_vars: whether to print individual variable names, shapes, + and dtypes. Default to ``True``. + compute_hash: whether to compute and print the hash value of the + checkpoint file. Default to ``False``. + hash_type: the hash type to use when ``compute_hash`` is ``True``. + Should be ``"md5"`` or ``"sha256"``. Default to ``"md5"``. + + Returns: + A dictionary mapping variable names to a dict containing + ``"shape"`` (tuple) and ``"dtype"`` (str) for each tensor. + """ + import hashlib + + _log_input_summary(tag="inspect_ckpt", args={"path": path, "print_all_vars": print_all_vars, "compute_hash": compute_hash}) + + ckpt = torch.load(path, map_location="cpu", weights_only=True) + if not isinstance(ckpt, Mapping): + ckpt = get_state_dict(ckpt) + + var_info: dict = {} + for name, val in ckpt.items(): + if isinstance(val, torch.Tensor): + var_info[name] = {"shape": tuple(val.shape), "dtype": str(val.dtype)} + else: + var_info[name] = {"shape": None, "dtype": type(val).__name__} + + logger.info(f"checkpoint file: {path}") + logger.info(f"total variables: {len(var_info)}") + if print_all_vars: + for name, info in var_info.items(): + logger.info(f" {name}: shape={info['shape']}, dtype={info['dtype']}") + + if compute_hash: + h = hashlib.new(hash_type) + with open(path, "rb") as f: + for chunk in iter(lambda: f.read(1 << 20), b""): + h.update(chunk) + digest = h.hexdigest() + logger.info(f"{hash_type} hash: {digest}") + + return var_info diff --git a/monai/losses/spectral_loss.py b/monai/losses/spectral_loss.py index 06714f3993..fcba03f132 100644 --- a/monai/losses/spectral_loss.py +++ b/monai/losses/spectral_loss.py @@ -55,8 +55,8 @@ def __init__( self.fft_norm = fft_norm def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - input_amplitude = self._get_fft_amplitude(target) - target_amplitude = self._get_fft_amplitude(input) + input_amplitude = self._get_fft_amplitude(input) + target_amplitude = self._get_fft_amplitude(target) # Compute distance between amplitude of frequency components # See Section 3.3 from https://arxiv.org/abs/2005.00341 diff --git a/monai/losses/ssim_loss.py b/monai/losses/ssim_loss.py index 8ee1da7267..3fa578da29 100644 --- a/monai/losses/ssim_loss.py +++ b/monai/losses/ssim_loss.py @@ -111,17 +111,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # 2D data x = torch.ones([1,1,10,10])/2 y = torch.ones([1,1,10,10])/2 - print(1-SSIMLoss(spatial_dims=2)(x,y)) + print(SSIMLoss(spatial_dims=2)(x,y)) # pseudo-3D data x = torch.ones([1,5,10,10])/2 # 5 could represent number of slices y = torch.ones([1,5,10,10])/2 - print(1-SSIMLoss(spatial_dims=2)(x,y)) + print(SSIMLoss(spatial_dims=2)(x,y)) # 3D data x = torch.ones([1,1,10,10,10])/2 y = torch.ones([1,1,10,10,10])/2 - print(1-SSIMLoss(spatial_dims=3)(x,y)) + print(SSIMLoss(spatial_dims=3)(x,y)) """ ssim_value = self.ssim_metric._compute_tensor(input, target).view(-1, 1) loss: torch.Tensor = 1 - ssim_value diff --git a/tests/bundle/test_bundle_inspect_ckpt.py b/tests/bundle/test_bundle_inspect_ckpt.py new file mode 100644 index 0000000000..ab569b234d --- /dev/null +++ b/tests/bundle/test_bundle_inspect_ckpt.py @@ -0,0 +1,70 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import os +import tempfile +import unittest + +import torch + +from monai.bundle import inspect_ckpt + + +class TestInspectCkpt(unittest.TestCase): + def setUp(self): + # Create a temporary checkpoint file with a simple state dict + self.tmp_dir = tempfile.mkdtemp() + self.ckpt_path = os.path.join(self.tmp_dir, "model.pt") + state_dict = { + "layer1.weight": torch.randn(4, 3), + "layer1.bias": torch.zeros(4), + "layer2.weight": torch.randn(2, 4), + } + torch.save(state_dict, self.ckpt_path) + + def test_returns_dict_with_correct_keys(self): + result = inspect_ckpt(path=self.ckpt_path, print_all_vars=False) + self.assertIsInstance(result, dict) + self.assertIn("layer1.weight", result) + self.assertIn("layer1.bias", result) + self.assertIn("layer2.weight", result) + + def test_shapes_are_correct(self): + result = inspect_ckpt(path=self.ckpt_path, print_all_vars=False) + self.assertEqual(result["layer1.weight"]["shape"], (4, 3)) + self.assertEqual(result["layer1.bias"]["shape"], (4,)) + self.assertEqual(result["layer2.weight"]["shape"], (2, 4)) + + def test_dtype_is_reported(self): + result = inspect_ckpt(path=self.ckpt_path, print_all_vars=False) + self.assertIn("dtype", result["layer1.weight"]) + self.assertTrue(result["layer1.weight"]["dtype"].startswith("torch.")) + + def test_compute_hash_md5(self): + # Should not raise; hash value is logged but not returned in dict + result = inspect_ckpt(path=self.ckpt_path, print_all_vars=False, compute_hash=True, hash_type="md5") + self.assertIsInstance(result, dict) + + def test_compute_hash_sha256(self): + result = inspect_ckpt(path=self.ckpt_path, print_all_vars=False, compute_hash=True, hash_type="sha256") + self.assertIsInstance(result, dict) + + def test_print_all_vars_true_does_not_raise(self): + # Should log each variable without raising + try: + inspect_ckpt(path=self.ckpt_path, print_all_vars=True) + except Exception as e: + self.fail(f"inspect_ckpt raised an exception with print_all_vars=True: {e}") + + +if __name__ == "__main__": + unittest.main()