Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
cancel-in-progress: true
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12"]
os: [ubuntu-latest, windows-latest, macos-latest]
fail-fast: false
env:
Expand Down
38 changes: 32 additions & 6 deletions nonebot/adapters/telegram/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,16 @@
import anyio
from pydantic.main import BaseModel
from pydantic.json import pydantic_encoder
from nonebot.utils import escape_tag, logger_wrapper
from nonebot.drivers import URL, Driver, Request, Response, HTTPServerSetup
from nonebot.utils import UNSET, escape_tag, logger_wrapper
from nonebot.drivers import (
URL,
DEFAULT_TIMEOUT,
Driver,
Request,
Timeout,
Response,
HTTPServerSetup,
)

from nonebot.adapters import Adapter as BaseAdapter

Expand All @@ -31,7 +39,7 @@ class Adapter(BaseAdapter):
def __init__(self, driver: Driver, **kwargs: Any):
super().__init__(driver, **kwargs)
self.adapter_config = AdapterConfig(**self.config.model_dump())
self.tasks: list[asyncio.Task] = []
self.tasks: set[asyncio.Task] = set()
self.setup()

@classmethod
Expand Down Expand Up @@ -90,11 +98,13 @@ async def poll(self, bot: Bot):
if update_offset is not None:
for update in updates:
update_offset = update.update_id + 1
asyncio.create_task(
task = asyncio.create_task(
self.__handle_update(
bot, update.model_dump(by_alias=True, exclude_none=True)
)
)
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
elif updates:
update_offset = updates[0].update_id
except Exception as e:
Expand All @@ -104,7 +114,9 @@ async def poll(self, bot: Bot):
def setup_polling(self, bot: Bot):
@self.on_ready
async def _():
self.tasks.append(asyncio.create_task(self.poll(bot)))
task = asyncio.create_task(self.poll(bot))
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)

@self.driver.on_shutdown
async def _():
Expand All @@ -120,7 +132,9 @@ async def handle_http(self, request: Request) -> Response:
if bot.secret_token == token:
if request.content:
update: dict = json.loads(request.content)
asyncio.create_task(self.__handle_update(bot, update))
task = asyncio.create_task(self.__handle_update(bot, update))
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
return Response(204)
return Response(401)

Expand Down Expand Up @@ -152,6 +166,17 @@ async def _call_api(self, bot: Bot, api: str, **data) -> Any:
s.capitalize() for s in api.split("_")[1:]
)
data = _escape_none(data)
request_timeout = UNSET
if api == "getUpdates":
timeout = data.get("timeout")
if not isinstance(timeout, bool) and isinstance(timeout, (int, float)):
# Telegram timeout is server-side long polling; the HTTP read
# timeout must be slightly longer.
request_timeout = Timeout(
total=DEFAULT_TIMEOUT.total,
connect=DEFAULT_TIMEOUT.connect,
read=float(timeout) + 5,
)

# 分离文件到 files
files: dict[str, tuple[str, bytes]] = {}
Expand Down Expand Up @@ -233,6 +258,7 @@ async def process_input_file(file: Union[InputFile, str]) -> Optional[str]:
data=data if files else None,
json=data if not files else None,
files=files, # type: ignore
timeout=request_timeout,
proxy=self.adapter_config.proxy,
)
try:
Expand Down
2 changes: 1 addition & 1 deletion nonebot/adapters/telegram/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ async def call_api(self, api: str, *args: Any, **kargs: Any) -> Any:
)
return await super().call_api(api, **kargs)

def __getattribute__(self, __name: str) -> Any:
def __getattribute__(self, __name: str, /) -> Any:
if not __name.startswith("__") and hasattr(API, __name):
return partial(self.call_api, __name)
return object.__getattribute__(self, __name)
Expand Down
15 changes: 11 additions & 4 deletions nonebot/adapters/telegram/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class EventWithChat(Protocol):


class Event(BaseEvent):
telegram_model: Update = Field(default=None)
telegram_model: Optional[Update] = Field(default=None)

@classmethod
def __parse_event(cls, obj: dict) -> "Event":
Expand Down Expand Up @@ -362,7 +362,7 @@ def get_event_description(self) -> str:


class GroupEditedMessageEvent(EditedMessageEvent):
from_: User = Field(default=None, alias="from")
from_: Optional[User] = Field(default=None, alias="from")
sender_chat: Optional[Chat] = None

@classmethod
Expand All @@ -380,14 +380,17 @@ def get_event_name(self) -> str:

@override
def get_user_id(self) -> str:
assert self.from_ is not None
return str(self.from_.id)

@override
def get_session_id(self) -> str:
assert self.from_ is not None
return f"group_{self.chat.id}_{self.from_.id}"

@override
def get_event_description(self) -> str:
assert self.from_ is not None
return (
f"EditedMessage {self.message_id} from {self.from_.id}"
f"@[Chat {self.chat.id}]: {self.get_message_description()}"
Expand All @@ -403,10 +406,12 @@ def get_event_name(self) -> str:

@override
def get_session_id(self) -> str:
assert self.from_ is not None
return f"group_{self.chat.id}_thread{self.message_thread_id}_{self.from_.id}"

@override
def get_event_description(self) -> str:
assert self.from_ is not None
return (
f"EditedMessage {self.message_id} from {self.from_.id}@[Chat {self.chat.id}"
f" Thread {self.message_thread_id}]: {self.get_message_description()}"
Expand Down Expand Up @@ -470,7 +475,7 @@ class PinnedMessageEvent(NoticeEvent):
sender_chat: Optional[Chat] = None
chat: Chat
date: int
pinned_message: MessageEvent = Field(default=None)
pinned_message: Optional[MessageEvent] = Field(default=None)

@classmethod
def __parse_event(cls, obj: dict):
Expand All @@ -485,10 +490,12 @@ def get_event_name(self) -> str:

@override
def get_message(self) -> Message:
assert self.pinned_message is not None
return self.pinned_message.get_message()

@override
def get_event_description(self) -> str:
assert self.pinned_message is not None
return (
f"PinnedMessage {self.pinned_message.message_id} "
f"@[Chat {self.pinned_message.chat.id}]: {self.get_message_description()}"
Expand Down Expand Up @@ -705,7 +712,7 @@ def get_event_description(self) -> str:


class CallbackQueryEvent(InlineEvent, CallbackQuery):
chat: Chat = Field(default=None)
chat: Optional[Chat] = Field(default=None)

@override
def get_event_name(self) -> str:
Expand Down
Loading
Loading