diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index 60f222912387..5a383df51f2d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -1,7 +1,7 @@ import asyncio import uuid from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, Callable, Dict, List, Mapping, Sequence +from typing import Any, AsyncGenerator, Callable, Dict, List, Mapping, Optional, Sequence from autogen_core import ( AgentId, @@ -75,6 +75,7 @@ def __init__( runtime: AgentRuntime | None = None, custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None, emit_team_events: bool = False, + source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None = None, ): self._name = name self._description = description @@ -130,6 +131,7 @@ def __init__( self._output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination] = ( asyncio.Queue() ) + self._source_verifier = source_verifier # Create a runtime for the team. if runtime is not None: @@ -173,6 +175,7 @@ def _create_group_chat_manager_factory( termination_condition: TerminationCondition | None, max_turns: int | None, message_factory: MessageFactory, + source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None, ) -> Callable[[], SequentialRoutedAgent]: ... def _create_participant_factory( @@ -224,6 +227,7 @@ async def _init(self, runtime: AgentRuntime) -> None: termination_condition=self._termination_condition, max_turns=self._max_turns, message_factory=self._message_factory, + source_verifier=self._source_verifier, ), ) # Add subscriptions for the group chat manager. diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py index b0a0c1d55fc4..74a45d20105c 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py @@ -1,6 +1,6 @@ import asyncio from abc import ABC, abstractmethod -from typing import Any, List, Sequence +from typing import Any, Callable, List, Optional, Sequence from autogen_core import CancellationToken, DefaultTopicId, MessageContext, event, rpc @@ -47,7 +47,18 @@ def __init__( max_turns: int | None, message_factory: MessageFactory, emit_team_events: bool = False, + source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None = None, ): + """ + Args: + source_verifier: Optional callback to verify the source of agent responses. + Called when a GroupChatAgentResponse or GroupChatTeamResponse is received, + before processing the response. The callback receives the claimed source name + and the message thread. It may return None to accept the message, or a + string describing the reason for rejection, which will cause the group chat + to terminate with an error. This is useful for identity verification, + content pinning, or policy enforcement. Defaults to None (no verification). + """ super().__init__( description="Group chat manager", sequential_message_types=[ @@ -82,6 +93,7 @@ def __init__( self._message_factory = message_factory self._emit_team_events = emit_team_events self._active_speakers: List[str] = [] + self._source_verifier = source_verifier @rpc async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None: @@ -136,6 +148,16 @@ async def handle_agent_response( self, message: GroupChatAgentResponse | GroupChatTeamResponse, ctx: MessageContext ) -> None: try: + # Verify the source of the agent response if a verifier is registered. + if self._source_verifier is not None: + rejection_reason = self._source_verifier(message.name, self._message_thread) + if rejection_reason is not None: + error = SerializableException( + error_type="SourceVerificationError", + error_message=rejection_reason, + ) + await self._signal_termination_with_error(error) + return # Construct the detla from the agent response. delta: List[BaseAgentEvent | BaseChatMessage] = [] if isinstance(message, GroupChatAgentResponse): diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py index d77b42dd17f2..f34597d91c9a 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py @@ -1,6 +1,6 @@ import asyncio from collections import Counter, deque -from typing import Any, Callable, Deque, Dict, List, Literal, Mapping, Sequence, Set, Union +from typing import Any, Callable, Deque, Dict, List, Literal, Mapping, Optional, Sequence, Set, Union from autogen_core import AgentRuntime, Component, ComponentModel from pydantic import BaseModel, Field, model_validator @@ -322,6 +322,7 @@ def __init__( max_turns: int | None, message_factory: MessageFactory, graph: DiGraph, + source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None = None, ) -> None: """Initialize the graph-based execution manager.""" super().__init__( @@ -335,6 +336,7 @@ def __init__( termination_condition=termination_condition, max_turns=max_turns, message_factory=message_factory, + source_verifier=source_verifier, ) graph.graph_validate() if graph.get_has_cycles() and self._termination_condition is None and self._max_turns is None: @@ -826,6 +828,7 @@ def _create_group_chat_manager_factory( termination_condition: TerminationCondition | None, max_turns: int | None, message_factory: MessageFactory, + source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None, ) -> Callable[[], GraphFlowManager]: """Creates the factory method for initializing the DiGraph-based chat manager.""" @@ -842,6 +845,7 @@ def _factory() -> GraphFlowManager: max_turns=max_turns, message_factory=message_factory, graph=self._graph, + source_verifier=source_verifier, ) return _factory diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py index 3f529f0c4474..3e2b3502db5b 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Callable, List, Mapping, Sequence +from typing import Any, Callable, List, Mapping, Optional, Sequence from autogen_core import AgentRuntime, Component, ComponentModel from pydantic import BaseModel @@ -29,6 +29,7 @@ def __init__( max_turns: int | None, message_factory: MessageFactory, emit_team_events: bool, + source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None = None, ) -> None: super().__init__( name, @@ -42,6 +43,7 @@ def __init__( max_turns, message_factory, emit_team_events, + source_verifier, ) self._next_speaker_index = 0 @@ -276,6 +278,7 @@ def _create_group_chat_manager_factory( termination_condition: TerminationCondition | None, max_turns: int | None, message_factory: MessageFactory, + source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None, ) -> Callable[[], RoundRobinGroupChatManager]: def _factory() -> RoundRobinGroupChatManager: return RoundRobinGroupChatManager( @@ -290,6 +293,7 @@ def _factory() -> RoundRobinGroupChatManager: max_turns, message_factory, self._emit_team_events, + source_verifier, ) return _factory diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index 480dc6b71641..afdc4e05444c 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -72,6 +72,7 @@ def __init__( emit_team_events: bool, model_context: ChatCompletionContext | None, model_client_streaming: bool = False, + source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None = None, ) -> None: super().__init__( name, @@ -85,6 +86,7 @@ def __init__( max_turns, message_factory, emit_team_events, + source_verifier, ) self._model_client = model_client self._selector_prompt = selector_prompt @@ -620,6 +622,7 @@ def __init__( emit_team_events: bool = False, model_client_streaming: bool = False, model_context: ChatCompletionContext | None = None, + source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None = None, ): super().__init__( name=name or self.DEFAULT_NAME, @@ -632,6 +635,7 @@ def __init__( runtime=runtime, custom_message_types=custom_message_types, emit_team_events=emit_team_events, + source_verifier=source_verifier, ) # Validate the participants. if len(participants) < 2: @@ -657,6 +661,7 @@ def _create_group_chat_manager_factory( termination_condition: TerminationCondition | None, max_turns: int | None, message_factory: MessageFactory, + source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None, ) -> Callable[[], BaseGroupChatManager]: return lambda: SelectorGroupChatManager( name, @@ -678,6 +683,7 @@ def _create_group_chat_manager_factory( self._emit_team_events, self._model_context, self._model_client_streaming, + source_verifier, ) def _to_config(self) -> SelectorGroupChatConfig: diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py index c9b495083939..8b03c0b015e5 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Callable, List, Mapping, Sequence +from typing import Any, Callable, List, Mapping, Optional, Sequence from autogen_core import AgentRuntime, Component, ComponentModel from pydantic import BaseModel @@ -28,6 +28,7 @@ def __init__( max_turns: int | None, message_factory: MessageFactory, emit_team_events: bool, + source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None = None, ) -> None: super().__init__( name, @@ -41,6 +42,7 @@ def __init__( max_turns, message_factory, emit_team_events, + source_verifier, ) self._current_speaker = self._participant_names[0] @@ -241,6 +243,7 @@ def __init__( runtime: AgentRuntime | None = None, custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None, emit_team_events: bool = False, + source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None = None, ) -> None: for participant in participants: if not isinstance(participant, ChatAgent): @@ -256,6 +259,7 @@ def __init__( runtime=runtime, custom_message_types=custom_message_types, emit_team_events=emit_team_events, + source_verifier=source_verifier, ) # The first participant must be able to produce handoff messages. first_participant = self._participants[0] @@ -275,6 +279,7 @@ def _create_group_chat_manager_factory( termination_condition: TerminationCondition | None, max_turns: int | None, message_factory: MessageFactory, + source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None, ) -> Callable[[], SwarmGroupChatManager]: def _factory() -> SwarmGroupChatManager: return SwarmGroupChatManager( @@ -289,6 +294,7 @@ def _factory() -> SwarmGroupChatManager: max_turns, message_factory, self._emit_team_events, + source_verifier, ) return _factory