Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -383,13 +383,17 @@ GroundingDINO is a model for text-prompted object detection and segmentation, of
- `models/RMBG/SDMatte/SDMatte.safetensors` (standard) or `SDMatte_plus.safetensors` (plus)
- Components (config files) are auto-downloaded; if needed, mirror the structure from the Hugging Face repo to `models/RMBG/SDMatte/` (`scheduler/`, `text_encoder/`, `tokenizer/`, `unet/`, `vae/`)

## Troubleshooting (short)
- 401 error when initializing GroundingDINO / missing `models/sam2`:
- Delete `%USERPROFILE%\.cache\huggingface\token` (and `%USERPROFILE%\.huggingface\token` if present)
- Ensure no `HF_TOKEN`/`HUGGINGFACE_TOKEN` env vars are set
- Re-run; public repos download anonymously (no login required)
- Preview shows "Required input is missing: images":
- Ensure image outputs are connected and upstream nodes ran successfully
## Troubleshooting (short)
- 401 error when initializing GroundingDINO / missing `models/sam2`:
- Delete `%USERPROFILE%\.cache\huggingface\token` (and `%USERPROFILE%\.huggingface\token` if present)
- Ensure no `HF_TOKEN`/`HUGGINGFACE_TOKEN` env vars are set
- Re-run; public repos download anonymously (no login required)
- Windows: ComfyUI exits when running `RMBG-2.0`:
- This can be a native crash inside PyTorch when importing/executing the bundled model code.
- The node can run `RMBG-2.0` in a subprocess (default on Windows). Set `COMFYUI_RMBG_RMBG2_SUBPROCESS=0` to disable.
- Set `COMFYUI_RMBG_DEBUG_PROGRESS=1` to write a debug log to `ComfyUI/user/rmbg_progress.log`.
- Preview shows "Required input is missing: images":
- Ensure image outputs are connected and upstream nodes ran successfully

## Credits
- RMBG-2.0: https://huggingface.co/briaai/RMBG-2.0
Expand Down
219 changes: 179 additions & 40 deletions py/AILab_RMBG.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,18 @@
# Source: https://github.com/1038lab/ComfyUI-RMBG

import os
import platform
import subprocess
import time
import uuid
import shutil
import torch
from PIL import Image
from torchvision import transforms
import numpy as np
import folder_paths
from PIL import ImageFilter
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
import shutil
import sys
import importlib.util
from transformers import AutoModelForImageSegmentation
Expand All @@ -35,6 +38,109 @@

folder_paths.add_model_folder_path("rmbg", os.path.join(folder_paths.models_dir, "RMBG"))


def _rmbg_progress_path():
try:
user_dir = folder_paths.get_user_directory()
os.makedirs(user_dir, exist_ok=True)
return os.path.join(user_dir, "rmbg_progress.log")
except Exception:
return os.path.join(os.getcwd(), "rmbg_progress.log")


def _rmbg_progress(msg: str):
"""
Crash-resilient marker logging to help locate native crashes.
Enabled only when COMFYUI_RMBG_DEBUG_PROGRESS=1.
"""
if not os.environ.get("COMFYUI_RMBG_DEBUG_PROGRESS"):
return
try:
path = _rmbg_progress_path()
ts = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
line = f"[{ts}] pid={os.getpid()} {msg}\r\n".encode("utf-8", errors="replace")
fd = os.open(path, os.O_CREAT | os.O_APPEND | os.O_WRONLY, 0o666)
try:
os.write(fd, line)
finally:
os.close(fd)
except Exception:
pass


def _run_rmbg2_subprocess(images, model_name: str, params: dict):
"""
Run RMBG-2.0 inference in a separate process to avoid native crashes taking down ComfyUI.
Returns a list of PIL 'L' masks.
"""
tmp_root = folder_paths.get_temp_directory()
run_id = f"rmbg2_{os.getpid()}_{uuid.uuid4().hex}"
run_dir = os.path.join(tmp_root, run_id)
os.makedirs(run_dir, exist_ok=True)

input_paths = []
output_paths = []
for i, img in enumerate(images):
in_path = os.path.join(run_dir, f"in_{i}.png")
out_path = os.path.join(run_dir, f"out_{i}.png")
tensor2pil(img).convert("RGB").save(in_path)
input_paths.append(in_path)
output_paths.append(out_path)

node_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
worker_path = os.path.join(os.path.dirname(__file__), "rmbg_worker.py")

cache_dir = os.path.join(folder_paths.models_dir, "RMBG", AVAILABLE_MODELS[model_name]["cache_dir"])

cmd = [
sys.executable,
"-s",
worker_path,
"--node-root",
node_root,
"--model",
model_name,
"--cache-dir",
cache_dir,
"--process-res",
str(int(params.get("process_res", 1024))),
"--sensitivity",
str(float(params.get("sensitivity", 1.0))),
"--inputs",
*input_paths,
"--outputs",
*output_paths,
]

debug_keep = bool(os.environ.get("COMFYUI_RMBG_DEBUG_KEEP_TEMP"))
try:
_rmbg_progress(f"RMBG-2.0 subprocess spawn: {worker_path} inputs={len(input_paths)}")
comfy_root = os.path.dirname(folder_paths.__file__)
result = subprocess.run(cmd, cwd=comfy_root, capture_output=True, text=True)
_rmbg_progress(f"RMBG-2.0 subprocess exit={result.returncode}")

if result.returncode != 0:
stderr_tail = (result.stderr or "").strip().splitlines()[-40:]
stdout_tail = (result.stdout or "").strip().splitlines()[-40:]
_rmbg_progress(f"RMBG-2.0 subprocess stdout_tail={stdout_tail}")
_rmbg_progress(f"RMBG-2.0 subprocess stderr_tail={stderr_tail}")

print("[RMBG ERROR] RMBG-2.0 subprocess failed. Last output:")
if stdout_tail:
print("\n".join(stdout_tail))
if stderr_tail:
print("\n".join(stderr_tail))
raise RuntimeError("RMBG-2.0 subprocess failed")

masks = []
for out_path in output_paths:
mask = Image.open(out_path).convert("L")
masks.append(mask)
return masks
finally:
if not debug_keep:
shutil.rmtree(run_dir, ignore_errors=True)

# Model configuration
AVAILABLE_MODELS = {
"RMBG-2.0": {
Expand Down Expand Up @@ -243,57 +349,78 @@ def process_image(self, images, model_name, params):
try:
self.load_model(model_name)

# Prepare batch processing
transform_image = transforms.Compose([
transforms.Resize((params["process_res"], params["process_res"])),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

if isinstance(images, torch.Tensor):
if len(images.shape) == 3:
images = [images]
else:
images = [img for img in images]

original_sizes = [tensor2pil(img).size for img in images]

input_tensors = [transform_image(tensor2pil(img)).unsqueeze(0) for img in images]
input_batch = torch.cat(input_tensors, dim=0).to(device)
# Avoid torchvision CPU transforms on Windows (can crash in torch_cpu.dll on some setups).
# ComfyUI IMAGE tensors are HWC in [0, 1] on CPU; move to GPU and do resize/normalize there.
original_sizes = [(int(img.shape[1]), int(img.shape[0])) for img in images]

input_tensors = []
for img in images:
if img.ndim != 3:
handle_model_error(f"Unexpected image tensor shape: {tuple(img.shape)}")
chw = img.permute(2, 0, 1)
if chw.shape[0] > 3:
chw = chw[:3, :, :]
input_tensors.append(chw)

input_batch = torch.stack(input_tensors, dim=0).to(device=device, dtype=torch.float32, non_blocking=True)
input_batch = F.interpolate(
input_batch,
size=(params["process_res"], params["process_res"]),
mode="bilinear",
align_corners=False,
)
mean = torch.tensor([0.485, 0.456, 0.406], device=device, dtype=input_batch.dtype).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225], device=device, dtype=input_batch.dtype).view(1, 3, 1, 1)
input_batch = (input_batch - mean) / std

with torch.no_grad():
outputs = self.model(input_batch)

results = None
if isinstance(outputs, list) and len(outputs) > 0:
results = outputs[-1].sigmoid().cpu()
elif isinstance(outputs, dict) and 'logits' in outputs:
results = outputs['logits'].sigmoid().cpu()
results = outputs[-1]
elif isinstance(outputs, dict) and "logits" in outputs:
results = outputs["logits"]
elif isinstance(outputs, torch.Tensor):
results = outputs.sigmoid().cpu()
results = outputs
else:
try:
if hasattr(outputs, 'last_hidden_state'):
results = outputs.last_hidden_state.sigmoid().cpu()
if hasattr(outputs, "last_hidden_state"):
results = outputs.last_hidden_state
else:
for k, v in outputs.items():
for _, v in outputs.items():
if isinstance(v, torch.Tensor):
results = v.sigmoid().cpu()
results = v
break
except:
handle_model_error("Unable to recognize model output format")
except Exception:
results = None

if results is None:
handle_model_error("Unable to recognize model output format")

results = results.sigmoid()

masks = []

for i, (result, (orig_w, orig_h)) in enumerate(zip(results, original_sizes)):
result = result.squeeze()
result = result * (1 + (1 - params["sensitivity"]))
result = torch.clamp(result, 0, 1)

result = F.interpolate(result.unsqueeze(0).unsqueeze(0),
size=(orig_h, orig_w),
mode='bilinear').squeeze()

masks.append(tensor2pil(result))
for result, (orig_w, orig_h) in zip(results, original_sizes):
mask = result.squeeze()
mask = mask * (1 + (1 - params["sensitivity"]))
mask = torch.clamp(mask, 0, 1)

mask = F.interpolate(
mask.unsqueeze(0).unsqueeze(0),
size=(orig_h, orig_w),
mode="bilinear",
align_corners=False,
).squeeze()

masks.append(tensor2pil(mask.detach().float().cpu()))

return masks

Expand Down Expand Up @@ -558,6 +685,10 @@ def INPUT_TYPES(s):

def process_image(self, image, model, **params):
try:
_rmbg_progress(
f"RMBG node start model={model} batch={getattr(image, 'shape', None)} "
f"refine={params.get('refine_foreground', False)} res={params.get('process_res', None)}"
)
processed_images = []
processed_masks = []

Expand All @@ -583,10 +714,11 @@ def _process_pair(img, mask):
else:
mask_local = mask

mask_tensor_local = pil2tensor(mask_local)
mask_tensor_local = mask_tensor_local * (1 + (1 - params["sensitivity"]))
mask_tensor_local = torch.clamp(mask_tensor_local, 0, 1)
mask_img_local = tensor2pil(mask_tensor_local)
# Avoid CPU torch ops here (some Windows setups crash in torch_cpu.dll during clamp/interpolate).
mask_arr = np.array(mask_local, dtype=np.float32) / 255.0
mask_arr = mask_arr * (1 + (1 - params["sensitivity"]))
mask_arr = np.clip(mask_arr, 0.0, 1.0)
mask_img_local = Image.fromarray((mask_arr * 255.0).astype(np.uint8), mode="L")

if params["mask_blur"] > 0:
mask_img_local = mask_img_local.filter(ImageFilter.GaussianBlur(radius=params["mask_blur"]))
Expand Down Expand Up @@ -643,7 +775,13 @@ def hex_to_rgba(hex_color):
chunk_size = 4
for start in range(0, len(images_list), chunk_size):
batch_imgs = images_list[start:start + chunk_size]
masks = model_instance.process_image(batch_imgs, model, params)
_rmbg_progress(f"RMBG node calling model_instance.process_image type={model_type} chunk_start={start} chunk_len={len(batch_imgs)}")
use_subprocess = os.environ.get("COMFYUI_RMBG_RMBG2_SUBPROCESS", "1") != "0"
if use_subprocess and model == "RMBG-2.0" and platform.system() == "Windows":
masks = _run_rmbg2_subprocess(batch_imgs, model, params)
else:
masks = model_instance.process_image(batch_imgs, model, params)
_rmbg_progress("RMBG node returned from model_instance.process_image")
if isinstance(masks, Image.Image):
masks = [masks]
for img_item, mask_item in zip(batch_imgs, masks):
Expand All @@ -659,10 +797,11 @@ def hex_to_rgba(hex_color):
mask_images.append(mask_image)

mask_image_output = torch.cat(mask_images, dim=0)

_rmbg_progress(f"RMBG node finish processed_images={len(processed_images)} processed_masks={len(processed_masks)}")
return (torch.cat(processed_images, dim=0), torch.cat(processed_masks, dim=0), mask_image_output)

except Exception as e:
_rmbg_progress(f"RMBG node exception: {type(e).__name__}: {e}")
handle_model_error(f"Error in image processing: {str(e)}")
empty_mask = torch.zeros((image.shape[0], image.shape[2], image.shape[3]))
empty_mask_image = empty_mask.reshape((-1, 1, empty_mask.shape[-2], empty_mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
Expand All @@ -674,4 +813,4 @@ def hex_to_rgba(hex_color):

NODE_DISPLAY_NAME_MAPPINGS = {
"RMBG": "Remove Background (RMBG)"
}
}
27 changes: 17 additions & 10 deletions py/AILab_SAM3Segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@
if str(MODELS_ROOT) not in sys.path:
sys.path.insert(0, str(MODELS_ROOT))

SAM3_BPE_PATH = SAM3_LOCAL_DIR / "assets" / "bpe_simple_vocab_16e6.txt.gz"
if not os.path.isfile(SAM3_BPE_PATH):
raise RuntimeError("SAM3 assets missing; ensure sam3/assets/bpe_simple_vocab_16e6.txt.gz exists.")
SAM3_BPE_PATH = SAM3_LOCAL_DIR / "assets" / "bpe_simple_vocab_16e6.txt.gz"

_DEFAULT_PT_ENTRY = {
"model_url": "https://huggingface.co/1038lab/sam3/resolve/main/sam3.pt",
Expand Down Expand Up @@ -104,11 +102,19 @@ def _resolve_device(user_choice):
return auto_device


from sam3.model_builder import build_sam3_image_model # noqa: E402
from sam3.model.sam3_image_processor import Sam3Processor # noqa: E402

def _lazy_import_sam3():
if not os.path.isfile(SAM3_BPE_PATH):
raise RuntimeError("SAM3 assets missing; ensure models/sam3/assets/bpe_simple_vocab_16e6.txt.gz exists.")

# Import only when node is used to avoid breaking ComfyUI startup if optional deps are missing.
from sam3.model_builder import build_sam3_image_model # noqa: E402
from sam3.model.sam3_image_processor import Sam3Processor # noqa: E402

return build_sam3_image_model, Sam3Processor


class SAM3Segment:
class SAM3Segment:
@classmethod
def INPUT_TYPES(cls):
return {
Expand Down Expand Up @@ -139,10 +145,11 @@ def INPUT_TYPES(cls):
def __init__(self):
self.processor_cache = {}

def _load_processor(self, device_choice):
torch_device = _resolve_device(device_choice)
device_str = "cuda" if torch_device.type == "cuda" else "cpu"
cache_key = ("sam3", device_str)
def _load_processor(self, device_choice):
build_sam3_image_model, Sam3Processor = _lazy_import_sam3()
torch_device = _resolve_device(device_choice)
device_str = "cuda" if torch_device.type == "cuda" else "cpu"
cache_key = ("sam3", device_str)
if cache_key not in self.processor_cache:
model_info = SAM3_MODELS["sam3"]
ckpt_path = get_or_download_model_file(model_info["filename"], model_info["model_url"])
Expand Down
Loading