Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import uuid
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Callable, Dict, List, Mapping, Sequence
from typing import Any, AsyncGenerator, Callable, Dict, List, Mapping, Optional, Sequence

from autogen_core import (
AgentId,
Expand Down Expand Up @@ -75,6 +75,7 @@ def __init__(
runtime: AgentRuntime | None = None,
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
emit_team_events: bool = False,
source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None = None,
):
self._name = name
self._description = description
Expand Down Expand Up @@ -130,6 +131,7 @@ def __init__(
self._output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination] = (
asyncio.Queue()
)
self._source_verifier = source_verifier

# Create a runtime for the team.
if runtime is not None:
Expand Down Expand Up @@ -173,6 +175,7 @@ def _create_group_chat_manager_factory(
termination_condition: TerminationCondition | None,
max_turns: int | None,
message_factory: MessageFactory,
source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None,
) -> Callable[[], SequentialRoutedAgent]: ...

def _create_participant_factory(
Expand Down Expand Up @@ -224,6 +227,7 @@ async def _init(self, runtime: AgentRuntime) -> None:
termination_condition=self._termination_condition,
max_turns=self._max_turns,
message_factory=self._message_factory,
source_verifier=self._source_verifier,
),
)
# Add subscriptions for the group chat manager.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
from abc import ABC, abstractmethod
from typing import Any, List, Sequence
from typing import Any, Callable, List, Optional, Sequence

from autogen_core import CancellationToken, DefaultTopicId, MessageContext, event, rpc

Expand Down Expand Up @@ -47,7 +47,18 @@ def __init__(
max_turns: int | None,
message_factory: MessageFactory,
emit_team_events: bool = False,
source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None = None,
):
"""
Args:
source_verifier: Optional callback to verify the source of agent responses.
Called when a GroupChatAgentResponse or GroupChatTeamResponse is received,
before processing the response. The callback receives the claimed source name
and the message thread. It may return None to accept the message, or a
string describing the reason for rejection, which will cause the group chat
to terminate with an error. This is useful for identity verification,
content pinning, or policy enforcement. Defaults to None (no verification).
"""
super().__init__(
description="Group chat manager",
sequential_message_types=[
Expand Down Expand Up @@ -82,6 +93,7 @@ def __init__(
self._message_factory = message_factory
self._emit_team_events = emit_team_events
self._active_speakers: List[str] = []
self._source_verifier = source_verifier

@rpc
async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None:
Expand Down Expand Up @@ -136,6 +148,16 @@ async def handle_agent_response(
self, message: GroupChatAgentResponse | GroupChatTeamResponse, ctx: MessageContext
) -> None:
try:
# Verify the source of the agent response if a verifier is registered.
if self._source_verifier is not None:
rejection_reason = self._source_verifier(message.name, self._message_thread)
if rejection_reason is not None:
error = SerializableException(
error_type="SourceVerificationError",
error_message=rejection_reason,
)
await self._signal_termination_with_error(error)
return
# Construct the detla from the agent response.
delta: List[BaseAgentEvent | BaseChatMessage] = []
if isinstance(message, GroupChatAgentResponse):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
from collections import Counter, deque
from typing import Any, Callable, Deque, Dict, List, Literal, Mapping, Sequence, Set, Union
from typing import Any, Callable, Deque, Dict, List, Literal, Mapping, Optional, Sequence, Set, Union

from autogen_core import AgentRuntime, Component, ComponentModel
from pydantic import BaseModel, Field, model_validator
Expand Down Expand Up @@ -322,6 +322,7 @@ def __init__(
max_turns: int | None,
message_factory: MessageFactory,
graph: DiGraph,
source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None = None,
) -> None:
"""Initialize the graph-based execution manager."""
super().__init__(
Expand All @@ -335,6 +336,7 @@ def __init__(
termination_condition=termination_condition,
max_turns=max_turns,
message_factory=message_factory,
source_verifier=source_verifier,
)
graph.graph_validate()
if graph.get_has_cycles() and self._termination_condition is None and self._max_turns is None:
Expand Down Expand Up @@ -826,6 +828,7 @@ def _create_group_chat_manager_factory(
termination_condition: TerminationCondition | None,
max_turns: int | None,
message_factory: MessageFactory,
source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None,
) -> Callable[[], GraphFlowManager]:
"""Creates the factory method for initializing the DiGraph-based chat manager."""

Expand All @@ -842,6 +845,7 @@ def _factory() -> GraphFlowManager:
max_turns=max_turns,
message_factory=message_factory,
graph=self._graph,
source_verifier=source_verifier,
)

return _factory
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import Any, Callable, List, Mapping, Sequence
from typing import Any, Callable, List, Mapping, Optional, Sequence

from autogen_core import AgentRuntime, Component, ComponentModel
from pydantic import BaseModel
Expand Down Expand Up @@ -29,6 +29,7 @@ def __init__(
max_turns: int | None,
message_factory: MessageFactory,
emit_team_events: bool,
source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None = None,
) -> None:
super().__init__(
name,
Expand All @@ -42,6 +43,7 @@ def __init__(
max_turns,
message_factory,
emit_team_events,
source_verifier,
)
self._next_speaker_index = 0

Expand Down Expand Up @@ -276,6 +278,7 @@ def _create_group_chat_manager_factory(
termination_condition: TerminationCondition | None,
max_turns: int | None,
message_factory: MessageFactory,
source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None,
) -> Callable[[], RoundRobinGroupChatManager]:
def _factory() -> RoundRobinGroupChatManager:
return RoundRobinGroupChatManager(
Expand All @@ -290,6 +293,7 @@ def _factory() -> RoundRobinGroupChatManager:
max_turns,
message_factory,
self._emit_team_events,
source_verifier,
)

return _factory
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(
emit_team_events: bool,
model_context: ChatCompletionContext | None,
model_client_streaming: bool = False,
source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None = None,
) -> None:
super().__init__(
name,
Expand All @@ -85,6 +86,7 @@ def __init__(
max_turns,
message_factory,
emit_team_events,
source_verifier,
)
self._model_client = model_client
self._selector_prompt = selector_prompt
Expand Down Expand Up @@ -620,6 +622,7 @@ def __init__(
emit_team_events: bool = False,
model_client_streaming: bool = False,
model_context: ChatCompletionContext | None = None,
source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None = None,
):
super().__init__(
name=name or self.DEFAULT_NAME,
Expand All @@ -632,6 +635,7 @@ def __init__(
runtime=runtime,
custom_message_types=custom_message_types,
emit_team_events=emit_team_events,
source_verifier=source_verifier,
)
# Validate the participants.
if len(participants) < 2:
Expand All @@ -657,6 +661,7 @@ def _create_group_chat_manager_factory(
termination_condition: TerminationCondition | None,
max_turns: int | None,
message_factory: MessageFactory,
source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None,
) -> Callable[[], BaseGroupChatManager]:
return lambda: SelectorGroupChatManager(
name,
Expand All @@ -678,6 +683,7 @@ def _create_group_chat_manager_factory(
self._emit_team_events,
self._model_context,
self._model_client_streaming,
source_verifier,
)

def _to_config(self) -> SelectorGroupChatConfig:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import Any, Callable, List, Mapping, Sequence
from typing import Any, Callable, List, Mapping, Optional, Sequence

from autogen_core import AgentRuntime, Component, ComponentModel
from pydantic import BaseModel
Expand Down Expand Up @@ -28,6 +28,7 @@ def __init__(
max_turns: int | None,
message_factory: MessageFactory,
emit_team_events: bool,
source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None = None,
) -> None:
super().__init__(
name,
Expand All @@ -41,6 +42,7 @@ def __init__(
max_turns,
message_factory,
emit_team_events,
source_verifier,
)
self._current_speaker = self._participant_names[0]

Expand Down Expand Up @@ -241,6 +243,7 @@ def __init__(
runtime: AgentRuntime | None = None,
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
emit_team_events: bool = False,
source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None = None,
) -> None:
for participant in participants:
if not isinstance(participant, ChatAgent):
Expand All @@ -256,6 +259,7 @@ def __init__(
runtime=runtime,
custom_message_types=custom_message_types,
emit_team_events=emit_team_events,
source_verifier=source_verifier,
)
# The first participant must be able to produce handoff messages.
first_participant = self._participants[0]
Expand All @@ -275,6 +279,7 @@ def _create_group_chat_manager_factory(
termination_condition: TerminationCondition | None,
max_turns: int | None,
message_factory: MessageFactory,
source_verifier: Callable[[str, Sequence[BaseAgentEvent | BaseChatMessage]], Optional[str]] | None,
) -> Callable[[], SwarmGroupChatManager]:
def _factory() -> SwarmGroupChatManager:
return SwarmGroupChatManager(
Expand All @@ -289,6 +294,7 @@ def _factory() -> SwarmGroupChatManager:
max_turns,
message_factory,
self._emit_team_events,
source_verifier,
)

return _factory
Expand Down