Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions backend/app/api/dingtalk.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,22 @@ async def process_dingtalk_message(
)
platform_user_id = platform_user.id

# Check for channel commands (/new, /reset)
from app.services.channel_commands import is_channel_command, handle_channel_command
if is_channel_command(user_text):
cmd_result = await handle_channel_command(
db=db, command=user_text, agent_id=agent_id,
user_id=platform_user_id, external_conv_id=conv_id,
source_channel="dingtalk",
)
await db.commit()
async with httpx.AsyncClient(timeout=10) as _cl_cmd:
await _cl_cmd.post(session_webhook, json={
"msgtype": "text",
"text": {"content": cmd_result["message"]},
})
return

# Find or create session
sess = await find_or_create_channel_session(
db=db,
Expand Down
28 changes: 28 additions & 0 deletions backend/app/api/feishu.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,13 @@ async def process_feishu_event(agent_id: uuid.UUID, body: dict, db: AsyncSession
from app.models.agent import DEFAULT_CONTEXT_WINDOW_SIZE
ctx_size = (agent_obj.context_window_size or DEFAULT_CONTEXT_WINDOW_SIZE) if agent_obj else DEFAULT_CONTEXT_WINDOW_SIZE

# Detect channel command early, but defer processing until we have
# resolved the real sender's platform_user_id (see below). Handling the
# command before resolve_channel_user() would attribute the new P2P
# session to the agent creator instead of the actual Feishu sender.
from app.services.channel_commands import is_channel_command, handle_channel_command
_is_cmd = is_channel_command(user_text)

# Pre-resolve session so history lookup uses the UUID (session created later if new)
_pre_sess_r = await db.execute(
select(__import__('app.models.chat_session', fromlist=['ChatSession']).ChatSession).where(
Expand Down Expand Up @@ -561,6 +568,27 @@ async def process_feishu_event(agent_id: uuid.UUID, body: dict, db: AsyncSession
)
platform_user_id = platform_user.id

# Now that the real sender is resolved, handle /new or /reset so the
# replacement P2P session is attributed to the sender (not creator_id).
# Mirrors the user_id rule used by find_or_create_channel_session below:
# group → creator_id (placeholder); P2P → platform_user_id.
if _is_cmd:
_is_group_cmd = (chat_type == "group")
_cmd_user_id = creator_id if _is_group_cmd else platform_user_id
_cmd_result = await handle_channel_command(
db=db, command=user_text, agent_id=agent_id,
user_id=_cmd_user_id, external_conv_id=conv_id,
source_channel="feishu",
)
await db.commit()
import json as _j_cmd
_cmd_reply = _j_cmd.dumps({"text": _cmd_result["message"]})
if _is_group_cmd and chat_id:
await feishu_service.send_message(config.app_id, config.app_secret, chat_id, "text", _cmd_reply, receive_id_type="chat_id")
else:
await feishu_service.send_message(config.app_id, config.app_secret, sender_open_id, "text", _cmd_reply)
return {"code": 0, "msg": "command handled"}

# ── Find-or-create a ChatSession via external_conv_id (DB-based, no cache needed) ──
from datetime import datetime as _dt, timezone as _tz
_is_group = (chat_type == "group")
Expand Down
69 changes: 69 additions & 0 deletions backend/app/services/channel_commands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""Channel command handler for external channels (DingTalk, Feishu, etc.)

Supports slash commands like /new to reset session context.
"""

import uuid
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.chat_session import ChatSession
from app.services.channel_session import find_or_create_channel_session


COMMANDS = {"/new", "/reset"}


def is_channel_command(text: str) -> bool:
"""Check if the message is a recognized channel command."""
stripped = text.strip().lower()
return stripped in COMMANDS


async def handle_channel_command(
db: AsyncSession,
command: str,
agent_id: uuid.UUID,
user_id: uuid.UUID,
external_conv_id: str,
source_channel: str,
) -> dict:
"""Handle a channel command and return response info.

Returns:
{"action": "new_session", "message": "..."}
"""
cmd = command.strip().lower()

if cmd in ("/new", "/reset"):
# Find current session. Scope by source_channel as well so we never
# accidentally archive a session from a different channel that happens
# to share the same external_conv_id (defensive against future changes
# to the per-channel ID prefix scheme).
result = await db.execute(
select(ChatSession).where(
ChatSession.agent_id == agent_id,
ChatSession.external_conv_id == external_conv_id,
ChatSession.source_channel == source_channel,
)
)
old_session = result.scalar_one_or_none()

if old_session:
# Rename old external_conv_id so find_or_create will make a new one
now = datetime.now(timezone.utc)
old_session.external_conv_id = (
f"{external_conv_id}__archived_{now.strftime('%Y%m%d_%H%M%S')}"
)
await db.flush()

# Defer session creation to the user's next message so its title
# auto-names from that message (via find_or_create_channel_session)
# instead of being locked to a hard-coded placeholder.
return {
"action": "new_session",
"message": "已开启新对话,之前的上下文已清除。",
}

return {"action": "unknown", "message": f"未知命令: {cmd}"}
194 changes: 194 additions & 0 deletions backend/tests/test_channel_commands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
"""Unit tests for app.services.channel_commands.

Covers:
1. `handle_channel_command()` scopes its archive lookup by source_channel
(no cross-channel collision on shared external_conv_id).
2. It archives the matching old session by renaming its external_conv_id.
3. It defers new-session creation to the next user message so the session
title auto-names from the first message — rather than being locked to
a hard-coded 'New Session' placeholder.
"""

from __future__ import annotations

import uuid
from types import SimpleNamespace
from typing import Any

import pytest

from app.services import channel_commands


class _ExecutedQuery:
"""Captures WHERE-clause state for an executed SQLAlchemy select()."""

def __init__(self, statement: Any) -> None:
self.statement = statement
# Extract column names referenced by equality comparisons in the WHERE
# clause. This lets tests assert that source_channel is part of the
# filter without depending on clause order.
self.filter_columns: set[str] = set()
self.filter_values: dict[str, Any] = {}
whereclause = getattr(statement, "whereclause", None)
self._collect(whereclause)

def _collect(self, clause: Any) -> None:
if clause is None:
return
# BooleanClauseList (AND/OR) has .clauses
sub_clauses = getattr(clause, "clauses", None)
if sub_clauses:
for c in sub_clauses:
self._collect(c)
return
left = getattr(clause, "left", None)
right = getattr(clause, "right", None)
if left is not None:
name = getattr(left, "key", None) or getattr(left, "name", None)
if name:
self.filter_columns.add(name)
if right is not None:
# BindParameter exposes .value
val = getattr(right, "value", None)
if val is not None:
self.filter_values[name] = val


class _FakeResult:
def __init__(self, value: Any) -> None:
self._value = value

def scalar_one_or_none(self) -> Any:
return self._value


class FakeDB:
"""Minimal AsyncSession stub that records executes / adds / flush / commit."""

def __init__(self, lookup_result: Any = None) -> None:
self._lookup_result = lookup_result
self.executed: list[_ExecutedQuery] = []
self.added: list[Any] = []
self.flushes = 0

async def execute(self, statement, _params=None): # noqa: D401
self.executed.append(_ExecutedQuery(statement))
return _FakeResult(self._lookup_result)

def add(self, obj) -> None:
# Assign an id so handle_channel_command can stringify it.
if getattr(obj, "id", None) is None:
try:
obj.id = uuid.uuid4()
except Exception:
pass
self.added.append(obj)

async def flush(self) -> None:
self.flushes += 1


@pytest.mark.asyncio
async def test_handle_channel_command_scopes_lookup_by_source_channel():
"""Regression test for review concern #2.

The session-archive lookup must include `source_channel` in its WHERE
clause so a /new command on one channel never archives a same-external-id
session on another channel.
"""
agent_id = uuid.uuid4()
user_id = uuid.uuid4()

db = FakeDB(lookup_result=None) # no pre-existing session

result = await channel_commands.handle_channel_command(
db=db,
command="/new",
agent_id=agent_id,
user_id=user_id,
external_conv_id="feishu_p2p_ou_xxx",
source_channel="feishu",
)

assert result["action"] == "new_session"
# Exactly one SELECT for the old-session lookup.
assert len(db.executed) == 1
q = db.executed[0]
# The WHERE clause must filter on all three columns.
assert "agent_id" in q.filter_columns
assert "external_conv_id" in q.filter_columns
assert "source_channel" in q.filter_columns, (
"handle_channel_command() must scope the archive lookup by source_channel "
"so it never archives a cross-channel session with a colliding external_conv_id"
)
assert q.filter_values.get("source_channel") == "feishu"


@pytest.mark.asyncio
async def test_handle_channel_command_does_not_preempt_session_creation():
"""New sessions must not be pre-created by /new — they're built by the
next user message via find_or_create_channel_session, so the first real
message content becomes the session title instead of "New Session".
"""
agent_id = uuid.uuid4()
user_id = uuid.uuid4()

# Simulate no pre-existing session (lookup miss).
db = FakeDB(lookup_result=None)

result = await channel_commands.handle_channel_command(
db=db,
command="/new",
agent_id=agent_id,
user_id=user_id,
external_conv_id="shared_conv_id_xxx",
source_channel="feishu",
)

assert result["action"] == "new_session"
# Nothing should be added to the DB — creation is deferred.
assert db.added == []
# And the response must not leak a session_id (there is no session yet).
assert "session_id" not in result


@pytest.mark.asyncio
async def test_handle_channel_command_archives_old_session():
"""When a session for the same (agent_id, external_conv_id, source_channel)
exists, /reset must archive it by renaming its external_conv_id, so the
next user message creates a fresh one.
"""
agent_id = uuid.uuid4()
user_id = uuid.uuid4()

# Existing session to be archived.
old_session = SimpleNamespace(external_conv_id="feishu_p2p_ou_zzz")
db = FakeDB(lookup_result=old_session)

result = await channel_commands.handle_channel_command(
db=db,
command="/reset",
agent_id=agent_id,
user_id=user_id,
external_conv_id="feishu_p2p_ou_zzz",
source_channel="feishu",
)

assert result["action"] == "new_session"
# Old session got its external_conv_id renamed to the archived form.
assert old_session.external_conv_id.startswith("feishu_p2p_ou_zzz__archived_")
# No new session pre-created (deferred to next user message).
assert db.added == []


@pytest.mark.asyncio
async def test_is_channel_command_recognises_slash_commands():
assert channel_commands.is_channel_command("/new") is True
assert channel_commands.is_channel_command("/reset") is True
assert channel_commands.is_channel_command(" /NEW ") is True
assert channel_commands.is_channel_command("/RESET") is True
# Non-commands
assert channel_commands.is_channel_command("hello") is False
assert channel_commands.is_channel_command("/newish") is False
assert channel_commands.is_channel_command("") is False