diff --git a/README.md b/README.md index 85e9c6e..f8de42d 100644 --- a/README.md +++ b/README.md @@ -383,13 +383,17 @@ GroundingDINO is a model for text-prompted object detection and segmentation, of - `models/RMBG/SDMatte/SDMatte.safetensors` (standard) or `SDMatte_plus.safetensors` (plus) - Components (config files) are auto-downloaded; if needed, mirror the structure from the Hugging Face repo to `models/RMBG/SDMatte/` (`scheduler/`, `text_encoder/`, `tokenizer/`, `unet/`, `vae/`) -## Troubleshooting (short) -- 401 error when initializing GroundingDINO / missing `models/sam2`: - - Delete `%USERPROFILE%\.cache\huggingface\token` (and `%USERPROFILE%\.huggingface\token` if present) - - Ensure no `HF_TOKEN`/`HUGGINGFACE_TOKEN` env vars are set - - Re-run; public repos download anonymously (no login required) -- Preview shows "Required input is missing: images": - - Ensure image outputs are connected and upstream nodes ran successfully +## Troubleshooting (short) +- 401 error when initializing GroundingDINO / missing `models/sam2`: + - Delete `%USERPROFILE%\.cache\huggingface\token` (and `%USERPROFILE%\.huggingface\token` if present) + - Ensure no `HF_TOKEN`/`HUGGINGFACE_TOKEN` env vars are set + - Re-run; public repos download anonymously (no login required) +- Windows: ComfyUI exits when running `RMBG-2.0`: + - This can be a native crash inside PyTorch when importing/executing the bundled model code. + - The node can run `RMBG-2.0` in a subprocess (default on Windows). Set `COMFYUI_RMBG_RMBG2_SUBPROCESS=0` to disable. + - Set `COMFYUI_RMBG_DEBUG_PROGRESS=1` to write a debug log to `ComfyUI/user/rmbg_progress.log`. +- Preview shows "Required input is missing: images": + - Ensure image outputs are connected and upstream nodes ran successfully ## Credits - RMBG-2.0: https://huggingface.co/briaai/RMBG-2.0 diff --git a/py/AILab_RMBG.py b/py/AILab_RMBG.py index 06c5a63..91938b8 100644 --- a/py/AILab_RMBG.py +++ b/py/AILab_RMBG.py @@ -16,15 +16,18 @@ # Source: https://github.com/1038lab/ComfyUI-RMBG import os +import platform +import subprocess +import time +import uuid +import shutil import torch from PIL import Image -from torchvision import transforms import numpy as np import folder_paths from PIL import ImageFilter import torch.nn.functional as F from huggingface_hub import hf_hub_download -import shutil import sys import importlib.util from transformers import AutoModelForImageSegmentation @@ -35,6 +38,109 @@ folder_paths.add_model_folder_path("rmbg", os.path.join(folder_paths.models_dir, "RMBG")) + +def _rmbg_progress_path(): + try: + user_dir = folder_paths.get_user_directory() + os.makedirs(user_dir, exist_ok=True) + return os.path.join(user_dir, "rmbg_progress.log") + except Exception: + return os.path.join(os.getcwd(), "rmbg_progress.log") + + +def _rmbg_progress(msg: str): + """ + Crash-resilient marker logging to help locate native crashes. + Enabled only when COMFYUI_RMBG_DEBUG_PROGRESS=1. + """ + if not os.environ.get("COMFYUI_RMBG_DEBUG_PROGRESS"): + return + try: + path = _rmbg_progress_path() + ts = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + line = f"[{ts}] pid={os.getpid()} {msg}\r\n".encode("utf-8", errors="replace") + fd = os.open(path, os.O_CREAT | os.O_APPEND | os.O_WRONLY, 0o666) + try: + os.write(fd, line) + finally: + os.close(fd) + except Exception: + pass + + +def _run_rmbg2_subprocess(images, model_name: str, params: dict): + """ + Run RMBG-2.0 inference in a separate process to avoid native crashes taking down ComfyUI. + Returns a list of PIL 'L' masks. + """ + tmp_root = folder_paths.get_temp_directory() + run_id = f"rmbg2_{os.getpid()}_{uuid.uuid4().hex}" + run_dir = os.path.join(tmp_root, run_id) + os.makedirs(run_dir, exist_ok=True) + + input_paths = [] + output_paths = [] + for i, img in enumerate(images): + in_path = os.path.join(run_dir, f"in_{i}.png") + out_path = os.path.join(run_dir, f"out_{i}.png") + tensor2pil(img).convert("RGB").save(in_path) + input_paths.append(in_path) + output_paths.append(out_path) + + node_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + worker_path = os.path.join(os.path.dirname(__file__), "rmbg_worker.py") + + cache_dir = os.path.join(folder_paths.models_dir, "RMBG", AVAILABLE_MODELS[model_name]["cache_dir"]) + + cmd = [ + sys.executable, + "-s", + worker_path, + "--node-root", + node_root, + "--model", + model_name, + "--cache-dir", + cache_dir, + "--process-res", + str(int(params.get("process_res", 1024))), + "--sensitivity", + str(float(params.get("sensitivity", 1.0))), + "--inputs", + *input_paths, + "--outputs", + *output_paths, + ] + + debug_keep = bool(os.environ.get("COMFYUI_RMBG_DEBUG_KEEP_TEMP")) + try: + _rmbg_progress(f"RMBG-2.0 subprocess spawn: {worker_path} inputs={len(input_paths)}") + comfy_root = os.path.dirname(folder_paths.__file__) + result = subprocess.run(cmd, cwd=comfy_root, capture_output=True, text=True) + _rmbg_progress(f"RMBG-2.0 subprocess exit={result.returncode}") + + if result.returncode != 0: + stderr_tail = (result.stderr or "").strip().splitlines()[-40:] + stdout_tail = (result.stdout or "").strip().splitlines()[-40:] + _rmbg_progress(f"RMBG-2.0 subprocess stdout_tail={stdout_tail}") + _rmbg_progress(f"RMBG-2.0 subprocess stderr_tail={stderr_tail}") + + print("[RMBG ERROR] RMBG-2.0 subprocess failed. Last output:") + if stdout_tail: + print("\n".join(stdout_tail)) + if stderr_tail: + print("\n".join(stderr_tail)) + raise RuntimeError("RMBG-2.0 subprocess failed") + + masks = [] + for out_path in output_paths: + mask = Image.open(out_path).convert("L") + masks.append(mask) + return masks + finally: + if not debug_keep: + shutil.rmtree(run_dir, ignore_errors=True) + # Model configuration AVAILABLE_MODELS = { "RMBG-2.0": { @@ -243,57 +349,78 @@ def process_image(self, images, model_name, params): try: self.load_model(model_name) - # Prepare batch processing - transform_image = transforms.Compose([ - transforms.Resize((params["process_res"], params["process_res"])), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) - ]) - if isinstance(images, torch.Tensor): if len(images.shape) == 3: images = [images] else: images = [img for img in images] - original_sizes = [tensor2pil(img).size for img in images] - - input_tensors = [transform_image(tensor2pil(img)).unsqueeze(0) for img in images] - input_batch = torch.cat(input_tensors, dim=0).to(device) + # Avoid torchvision CPU transforms on Windows (can crash in torch_cpu.dll on some setups). + # ComfyUI IMAGE tensors are HWC in [0, 1] on CPU; move to GPU and do resize/normalize there. + original_sizes = [(int(img.shape[1]), int(img.shape[0])) for img in images] + + input_tensors = [] + for img in images: + if img.ndim != 3: + handle_model_error(f"Unexpected image tensor shape: {tuple(img.shape)}") + chw = img.permute(2, 0, 1) + if chw.shape[0] > 3: + chw = chw[:3, :, :] + input_tensors.append(chw) + + input_batch = torch.stack(input_tensors, dim=0).to(device=device, dtype=torch.float32, non_blocking=True) + input_batch = F.interpolate( + input_batch, + size=(params["process_res"], params["process_res"]), + mode="bilinear", + align_corners=False, + ) + mean = torch.tensor([0.485, 0.456, 0.406], device=device, dtype=input_batch.dtype).view(1, 3, 1, 1) + std = torch.tensor([0.229, 0.224, 0.225], device=device, dtype=input_batch.dtype).view(1, 3, 1, 1) + input_batch = (input_batch - mean) / std with torch.no_grad(): outputs = self.model(input_batch) + results = None if isinstance(outputs, list) and len(outputs) > 0: - results = outputs[-1].sigmoid().cpu() - elif isinstance(outputs, dict) and 'logits' in outputs: - results = outputs['logits'].sigmoid().cpu() + results = outputs[-1] + elif isinstance(outputs, dict) and "logits" in outputs: + results = outputs["logits"] elif isinstance(outputs, torch.Tensor): - results = outputs.sigmoid().cpu() + results = outputs else: try: - if hasattr(outputs, 'last_hidden_state'): - results = outputs.last_hidden_state.sigmoid().cpu() + if hasattr(outputs, "last_hidden_state"): + results = outputs.last_hidden_state else: - for k, v in outputs.items(): + for _, v in outputs.items(): if isinstance(v, torch.Tensor): - results = v.sigmoid().cpu() + results = v break - except: - handle_model_error("Unable to recognize model output format") + except Exception: + results = None + + if results is None: + handle_model_error("Unable to recognize model output format") + + results = results.sigmoid() masks = [] - for i, (result, (orig_w, orig_h)) in enumerate(zip(results, original_sizes)): - result = result.squeeze() - result = result * (1 + (1 - params["sensitivity"])) - result = torch.clamp(result, 0, 1) - - result = F.interpolate(result.unsqueeze(0).unsqueeze(0), - size=(orig_h, orig_w), - mode='bilinear').squeeze() - - masks.append(tensor2pil(result)) + for result, (orig_w, orig_h) in zip(results, original_sizes): + mask = result.squeeze() + mask = mask * (1 + (1 - params["sensitivity"])) + mask = torch.clamp(mask, 0, 1) + + mask = F.interpolate( + mask.unsqueeze(0).unsqueeze(0), + size=(orig_h, orig_w), + mode="bilinear", + align_corners=False, + ).squeeze() + + masks.append(tensor2pil(mask.detach().float().cpu())) return masks @@ -558,6 +685,10 @@ def INPUT_TYPES(s): def process_image(self, image, model, **params): try: + _rmbg_progress( + f"RMBG node start model={model} batch={getattr(image, 'shape', None)} " + f"refine={params.get('refine_foreground', False)} res={params.get('process_res', None)}" + ) processed_images = [] processed_masks = [] @@ -583,10 +714,11 @@ def _process_pair(img, mask): else: mask_local = mask - mask_tensor_local = pil2tensor(mask_local) - mask_tensor_local = mask_tensor_local * (1 + (1 - params["sensitivity"])) - mask_tensor_local = torch.clamp(mask_tensor_local, 0, 1) - mask_img_local = tensor2pil(mask_tensor_local) + # Avoid CPU torch ops here (some Windows setups crash in torch_cpu.dll during clamp/interpolate). + mask_arr = np.array(mask_local, dtype=np.float32) / 255.0 + mask_arr = mask_arr * (1 + (1 - params["sensitivity"])) + mask_arr = np.clip(mask_arr, 0.0, 1.0) + mask_img_local = Image.fromarray((mask_arr * 255.0).astype(np.uint8), mode="L") if params["mask_blur"] > 0: mask_img_local = mask_img_local.filter(ImageFilter.GaussianBlur(radius=params["mask_blur"])) @@ -643,7 +775,13 @@ def hex_to_rgba(hex_color): chunk_size = 4 for start in range(0, len(images_list), chunk_size): batch_imgs = images_list[start:start + chunk_size] - masks = model_instance.process_image(batch_imgs, model, params) + _rmbg_progress(f"RMBG node calling model_instance.process_image type={model_type} chunk_start={start} chunk_len={len(batch_imgs)}") + use_subprocess = os.environ.get("COMFYUI_RMBG_RMBG2_SUBPROCESS", "1") != "0" + if use_subprocess and model == "RMBG-2.0" and platform.system() == "Windows": + masks = _run_rmbg2_subprocess(batch_imgs, model, params) + else: + masks = model_instance.process_image(batch_imgs, model, params) + _rmbg_progress("RMBG node returned from model_instance.process_image") if isinstance(masks, Image.Image): masks = [masks] for img_item, mask_item in zip(batch_imgs, masks): @@ -659,10 +797,11 @@ def hex_to_rgba(hex_color): mask_images.append(mask_image) mask_image_output = torch.cat(mask_images, dim=0) - + _rmbg_progress(f"RMBG node finish processed_images={len(processed_images)} processed_masks={len(processed_masks)}") return (torch.cat(processed_images, dim=0), torch.cat(processed_masks, dim=0), mask_image_output) except Exception as e: + _rmbg_progress(f"RMBG node exception: {type(e).__name__}: {e}") handle_model_error(f"Error in image processing: {str(e)}") empty_mask = torch.zeros((image.shape[0], image.shape[2], image.shape[3])) empty_mask_image = empty_mask.reshape((-1, 1, empty_mask.shape[-2], empty_mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) @@ -674,4 +813,4 @@ def hex_to_rgba(hex_color): NODE_DISPLAY_NAME_MAPPINGS = { "RMBG": "Remove Background (RMBG)" -} \ No newline at end of file +} diff --git a/py/AILab_SAM3Segment.py b/py/AILab_SAM3Segment.py index 6f42573..45d0f31 100644 --- a/py/AILab_SAM3Segment.py +++ b/py/AILab_SAM3Segment.py @@ -22,9 +22,7 @@ if str(MODELS_ROOT) not in sys.path: sys.path.insert(0, str(MODELS_ROOT)) -SAM3_BPE_PATH = SAM3_LOCAL_DIR / "assets" / "bpe_simple_vocab_16e6.txt.gz" -if not os.path.isfile(SAM3_BPE_PATH): - raise RuntimeError("SAM3 assets missing; ensure sam3/assets/bpe_simple_vocab_16e6.txt.gz exists.") +SAM3_BPE_PATH = SAM3_LOCAL_DIR / "assets" / "bpe_simple_vocab_16e6.txt.gz" _DEFAULT_PT_ENTRY = { "model_url": "https://huggingface.co/1038lab/sam3/resolve/main/sam3.pt", @@ -104,11 +102,19 @@ def _resolve_device(user_choice): return auto_device -from sam3.model_builder import build_sam3_image_model # noqa: E402 -from sam3.model.sam3_image_processor import Sam3Processor # noqa: E402 + +def _lazy_import_sam3(): + if not os.path.isfile(SAM3_BPE_PATH): + raise RuntimeError("SAM3 assets missing; ensure models/sam3/assets/bpe_simple_vocab_16e6.txt.gz exists.") + + # Import only when node is used to avoid breaking ComfyUI startup if optional deps are missing. + from sam3.model_builder import build_sam3_image_model # noqa: E402 + from sam3.model.sam3_image_processor import Sam3Processor # noqa: E402 + + return build_sam3_image_model, Sam3Processor -class SAM3Segment: +class SAM3Segment: @classmethod def INPUT_TYPES(cls): return { @@ -139,10 +145,11 @@ def INPUT_TYPES(cls): def __init__(self): self.processor_cache = {} - def _load_processor(self, device_choice): - torch_device = _resolve_device(device_choice) - device_str = "cuda" if torch_device.type == "cuda" else "cpu" - cache_key = ("sam3", device_str) + def _load_processor(self, device_choice): + build_sam3_image_model, Sam3Processor = _lazy_import_sam3() + torch_device = _resolve_device(device_choice) + device_str = "cuda" if torch_device.type == "cuda" else "cpu" + cache_key = ("sam3", device_str) if cache_key not in self.processor_cache: model_info = SAM3_MODELS["sam3"] ckpt_path = get_or_download_model_file(model_info["filename"], model_info["model_url"]) diff --git a/py/rmbg_worker.py b/py/rmbg_worker.py new file mode 100644 index 0000000..39c440b --- /dev/null +++ b/py/rmbg_worker.py @@ -0,0 +1,61 @@ +import argparse +import os +import sys +from pathlib import Path + + +def _parse_args(): + p = argparse.ArgumentParser() + p.add_argument("--node-root", required=True, help="Path to the comfyui-rmbg node root") + p.add_argument("--model", required=True) + p.add_argument("--cache-dir", required=True) + p.add_argument("--process-res", type=int, required=True) + p.add_argument("--sensitivity", type=float, required=True) + p.add_argument("--inputs", nargs="+", required=True) + p.add_argument("--outputs", nargs="+", required=True) + return p.parse_args() + + +def main(): + args = _parse_args() + node_root = Path(args.node_root).resolve() + node_py = node_root / "py" + sys.path.insert(0, str(node_root)) + sys.path.insert(0, str(node_py)) + + from PIL import Image + import numpy as np + import torch + + from AILab_RMBG import RMBGModel + + if not os.path.isdir(args.cache_dir): + raise RuntimeError(f"Cache dir does not exist: {args.cache_dir}") + + model = RMBGModel() + model.load_model(args.model) + + images = [] + for path in args.inputs: + img = Image.open(path).convert("RGB") + t = torch.from_numpy(np.array(img).astype("float32") / 255.0) + images.append(t) + + params = {"process_res": args.process_res, "sensitivity": args.sensitivity} + masks = model.process_image(images, args.model, params) + + if len(masks) != len(args.outputs): + raise RuntimeError(f"Expected {len(args.outputs)} masks, got {len(masks)}") + + for mask, out_path in zip(masks, args.outputs): + if isinstance(mask, Image.Image): + mask_img = mask.convert("L") + else: + mask_img = mask + Path(out_path).parent.mkdir(parents=True, exist_ok=True) + mask_img.save(out_path) + + +if __name__ == "__main__": + main() + diff --git a/pyproject.toml b/pyproject.toml index 9912758..1d79051 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ description = "A sophisticated ComfyUI custom node engineered for advanced image version = "2.9.6" license = {file = "LICENSE"} -dependencies = ["huggingface-hub>=0.19.0", "transparent-background>=1.1.2", "segment-anything>=1.0", "groundingdino-py>=0.4.0", "opencv-python>=4.7.0", "onnxruntime>=1.15.0", "onnxruntime-gpu>=1.15.0", "protobuf>=3.20.2,<6.0.0", "hydra-core>=1.3.0", "omegaconf>=2.3.0", "iopath>=0.1.9"] +dependencies = ["huggingface-hub>=0.19.0", "transparent-background>=1.1.2", "segment-anything>=1.0", "groundingdino-py>=0.4.0; platform_system != 'Windows'", "opencv-python>=4.7.0", "onnxruntime>=1.15.0", "onnxruntime-gpu>=1.15.0", "protobuf>=3.20.2,<6.0.0", "hydra-core>=1.3.0", "omegaconf>=2.3.0", "iopath>=0.1.9"] [project.urls] Repository = "https://github.com/1038lab/ComfyUI-RMBG" diff --git a/requirements.txt b/requirements.txt index f122f0f..5952774 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,9 @@ huggingface-hub>=0.19.0 transparent-background>=1.1.2 segment-anything>=1.0 -groundingdino-py>=0.4.0 +# Optional: often fails to build on Windows (esp. Python 3.12); nodes fall back when unavailable. +groundingdino-py>=0.4.0; platform_system != "Windows" + opencv-python>=4.7.0 onnxruntime>=1.15.0 onnxruntime-gpu>=1.15.0