From b6d4878c51a25b566c74e217731f3793af7ea01a Mon Sep 17 00:00:00 2001 From: rking32 Date: Fri, 15 Jan 2021 04:42:30 +0530 Subject: [PATCH] add flood control and wildcard support to upload and closes 237 --- userge/core/client.py | 26 +++--- userge/core/database.py | 4 + userge/core/ext/pool.py | 54 ++++-------- userge/core/ext/raw_client.py | 88 ++++++++++++++++++- .../core/methods/decorators/raw_decorator.py | 18 ++-- userge/core/types/bound/message.py | 6 +- userge/core/types/new/channel_logger.py | 4 +- userge/core/types/new/conversation.py | 2 +- userge/plugins/misc/upload.py | 29 +++--- 9 files changed, 154 insertions(+), 77 deletions(-) diff --git a/userge/core/client.py b/userge/core/client.py index 675d818ba..430dd5f17 100644 --- a/userge/core/client.py +++ b/userge/core/client.py @@ -25,6 +25,7 @@ from userge.plugins import get_all_plugins from .methods import Methods from .ext import RawClient, pool +from .database import _close_db _LOG = logging.getLogger(__name__) _LOG_STR = "<<>>" @@ -103,7 +104,7 @@ async def reload_plugins(self) -> int: return len(reloaded) -class _UsergeBot(_AbstractUserge): +class UsergeBot(_AbstractUserge): """ UsergeBot, the bot """ def __init__(self, **kwargs) -> None: _LOG.info(_LOG_STR, "Setting UsergeBot Configs") @@ -131,12 +132,13 @@ def __init__(self, **kwargs) -> None: kwargs['bot_token'] = Config.BOT_TOKEN if Config.HU_STRING_SESSION and Config.BOT_TOKEN: RawClient.DUAL_MODE = True - kwargs['bot'] = _UsergeBot(bot=self, **kwargs) + kwargs['bot'] = UsergeBot(bot=self, **kwargs) kwargs['session_name'] = Config.HU_STRING_SESSION or ":memory:" super().__init__(**kwargs) + self.executor.shutdown() @property - def bot(self) -> Union['_UsergeBot', 'Userge']: + def bot(self) -> Union['UsergeBot', 'Userge']: """ returns usergebot """ if self._bot is None: if Config.BOT_TOKEN: @@ -146,7 +148,7 @@ def bot(self) -> Union['_UsergeBot', 'Userge']: async def start(self) -> None: """ start client and bot """ - pool._start() # pylint: disable=protected-access + self.executor = pool._get() # pylint: disable=protected-access _LOG.info(_LOG_STR, "Starting Userge") await super().start() if self._bot is not None: @@ -161,7 +163,8 @@ async def stop(self) -> None: # pylint: disable=arguments-differ await self._bot.stop() _LOG.info(_LOG_STR, "Stopping Userge") await super().stop() - await pool._stop() # pylint: disable=protected-access + _close_db() + pool._stop() # pylint: disable=protected-access def begin(self, coro: Optional[Awaitable[Any]] = None) -> None: """ start userge """ @@ -174,11 +177,14 @@ async def _finalize() -> None: task.cancel() if self.is_initialized: await self.stop() - # pylint: disable=expression-not-assigned - [t.cancel() for t in asyncio.all_tasks() if t is not asyncio.current_task()] - await self.loop.shutdown_asyncgens() - self.loop.stop() - _LOG.info(_LOG_STR, "Loop Stopped !") + else: + _close_db() + pool._stop() # pylint: disable=protected-access + # pylint: disable=expression-not-assigned + [t.cancel() for t in asyncio.all_tasks() if t is not asyncio.current_task()] + await self.loop.shutdown_asyncgens() + self.loop.stop() + _LOG.info(_LOG_STR, "Loop Stopped !") async def _shutdown(sig: signal.Signals) -> None: _LOG.info(_LOG_STR, f"Received Stop Signal [{sig.name}], Exiting Userge ...") diff --git a/userge/core/database.py b/userge/core/database.py index b8d05cc3a..a16f66600 100644 --- a/userge/core/database.py +++ b/userge/core/database.py @@ -44,4 +44,8 @@ def get_collection(name: str) -> AgnosticCollection: return _DATABASE[name] +def _close_db() -> None: + _MGCLIENT.close() + + logbot.del_last_msg() diff --git a/userge/core/ext/pool.py b/userge/core/ext/pool.py index 1abfb2757..4c3a786d8 100644 --- a/userge/core/ext/pool.py +++ b/userge/core/ext/pool.py @@ -8,31 +8,24 @@ # # All rights reserved. -__all__ = ['submit_task', 'submit_thread', 'run_in_thread'] +__all__ = ['submit_thread', 'run_in_thread'] import asyncio -from typing import Any, Callable, List +from typing import Any, Callable from concurrent.futures import ThreadPoolExecutor, Future from functools import wraps, partial -from userge import logging, Config +from motor.frameworks.asyncio import _EXECUTOR # pylint: disable=protected-access + +from userge import logging -_WORKERS = Config.WORKERS -_THREAD_POOL: ThreadPoolExecutor -_ASYNC_QUEUE = asyncio.Queue() -_TASKS: List[asyncio.Task] = [] _LOG = logging.getLogger(__name__) _LOG_STR = "<<>>" -def submit_task(task: asyncio.coroutines.CoroWrapper) -> None: - """ submit task to task pool """ - _ASYNC_QUEUE.put_nowait(task) - - def submit_thread(func: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Future: """ submit thread to thread pool """ - return _THREAD_POOL.submit(func, *args, **kwargs) + return _EXECUTOR.submit(func, *args, **kwargs) def run_in_thread(func: Callable[[Any], Any]) -> Callable[[Any], Any]: @@ -40,34 +33,19 @@ def run_in_thread(func: Callable[[Any], Any]) -> Callable[[Any], Any]: @wraps(func) async def wrapper(*args: Any, **kwargs: Any) -> Any: loop = asyncio.get_running_loop() - return await loop.run_in_executor(_THREAD_POOL, partial(func, *args, **kwargs)) + return await loop.run_in_executor(_EXECUTOR, partial(func, *args, **kwargs)) return wrapper -def _start(): - global _THREAD_POOL # pylint: disable=global-statement - _THREAD_POOL = ThreadPoolExecutor(_WORKERS) +def _get() -> ThreadPoolExecutor: + return _EXECUTOR + - async def _task_worker(): - while True: - coro = await _ASYNC_QUEUE.get() - if coro is None: - break - await coro - loop = asyncio.get_event_loop() - for _ in range(_WORKERS): - _TASKS.append(loop.create_task(_task_worker())) - _LOG.info(_LOG_STR, f"Started Pool : {_WORKERS} Workers") +def _stop(): + _EXECUTOR.shutdown() + # pylint: disable=protected-access + _LOG.info(_LOG_STR, f"Stopped Pool : {_EXECUTOR._max_workers} Workers") -async def _stop(): - _THREAD_POOL.shutdown() - for _ in range(_WORKERS): - _ASYNC_QUEUE.put_nowait(None) - for task in _TASKS: - try: - await asyncio.wait_for(task, timeout=0.3) - except asyncio.TimeoutError: - task.cancel() - _TASKS.clear() - _LOG.info(_LOG_STR, f"Stopped Pool : {_WORKERS} Workers") +# pylint: disable=protected-access +_LOG.info(_LOG_STR, f"Started Pool : {_EXECUTOR._max_workers} Workers") diff --git a/userge/core/ext/raw_client.py b/userge/core/ext/raw_client.py index ced0287cf..494c9d86e 100644 --- a/userge/core/ext/raw_client.py +++ b/userge/core/ext/raw_client.py @@ -10,21 +10,105 @@ __all__ = ['RawClient'] +import asyncio import time -from typing import Optional +from typing import Optional, Dict +import pyrogram.raw.functions as funcs +import pyrogram.raw.types as types from pyrogram import Client +from pyrogram.session import Session +from pyrogram.raw.core import TLObject import userge # pylint: disable=unused-import +_LOG = userge.logging.getLogger(__name__) +_LOG_STR = "<<>>" + class RawClient(Client): """ userge raw client """ DUAL_MODE = False LAST_OUTGOING_TIME = time.time() - def __init__(self, bot: Optional['userge.core.client._UsergeBot'] = None, **kwargs) -> None: + REQ_LOGS: Dict[int, 'ChatReq'] = {} + DELAY_BET_MSG_REQ = 1 + MSG_REQ_PER_MIN = 20 + REQ_LOCK = asyncio.Lock() + + def __init__(self, bot: Optional['userge.core.client.UsergeBot'] = None, **kwargs) -> None: self._bot = bot super().__init__(**kwargs) self._channel = userge.core.types.new.ChannelLogger(self, "CORE") userge.core.types.new.Conversation.init(self) + + async def send(self, data: TLObject, retries: int = Session.MAX_RETRIES, + timeout: float = Session.WAIT_TIMEOUT, sleep_threshold: float = None): + key = 0 + if isinstance(data, (funcs.messages.SendMessage, + funcs.messages.EditMessage, + funcs.messages.ForwardMessages)): + if isinstance(data, funcs.messages.ForwardMessages): + tmp = data.to_peer + else: + tmp = data.peer + if isinstance(tmp, (types.InputPeerChannel, types.InputPeerChannelFromMessage)): + key = int(tmp.channel_id) + elif isinstance(tmp, types.InputPeerChat): + key = int(tmp.chat_id) + elif isinstance(tmp, (types.InputPeerUser, types.InputPeerUserFromMessage)): + key = int(tmp.user_id) + elif isinstance(data, funcs.channels.DeleteMessages): + if isinstance(data.channel, (types.InputChannel, types.InputChannelFromMessage)): + key = int(data.channel.channel_id) + if key: + async def slp(to_sl: float) -> None: + if to_sl > 0.1: + if to_sl > 1: + _LOG.info(_LOG_STR, to_sl, key) + else: + _LOG.debug(_LOG_STR, to_sl, key) + await asyncio.sleep(to_sl) + async with self.REQ_LOCK: + if key in self.REQ_LOGS: + chat_req = self.REQ_LOGS[key] + else: + chat_req = self.REQ_LOGS[key] = ChatReq() + diff = chat_req.small_diff + if 0 < diff < self.DELAY_BET_MSG_REQ: + await slp(1 - diff) + diff = chat_req.big_diff + if diff >= 60: + chat_req.reset() + elif chat_req.count > self.MSG_REQ_PER_MIN: + await slp(60 - diff) + chat_req.reset() + else: + chat_req.update() + return await super().send(data, retries, timeout, sleep_threshold) + + +class ChatReq: + def __init__(self) -> None: + self._first = self._last = time.time() + self._count = 0 + + @property + def small_diff(self) -> float: + return time.time() - self._last + + @property + def big_diff(self) -> float: + return time.time() - self._first + + @property + def count(self) -> float: + return self._count + + def reset(self) -> None: + self._first = self._last = time.time() + self._count = 1 + + def update(self) -> None: + self._last = time.time() + self._count += 1 diff --git a/userge/core/methods/decorators/raw_decorator.py b/userge/core/methods/decorators/raw_decorator.py index 5b062b11a..6d48484ae 100644 --- a/userge/core/methods/decorators/raw_decorator.py +++ b/userge/core/methods/decorators/raw_decorator.py @@ -87,7 +87,7 @@ def _clear_cht() -> None: _TASK_1_START_TO = time.time() -async def _init(r_c: Union['_client.Userge', '_client._UsergeBot'], +async def _init(r_c: Union['_client.Userge', '_client.UsergeBot'], r_m: RawMessage) -> None: global _U_ID, _B_ID # pylint: disable=global-statement if r_m.from_user and ( @@ -110,7 +110,7 @@ async def _init(r_c: Union['_client.Userge', '_client._UsergeBot'], _U_ID = (await r_c.ubot.get_me()).id -async def _raise_func(r_c: Union['_client.Userge', '_client._UsergeBot'], +async def _raise_func(r_c: Union['_client.Userge', '_client.UsergeBot'], chat_id: int, message_id: int, text: str) -> None: try: _sent = await r_c.send_message( @@ -123,7 +123,7 @@ async def _raise_func(r_c: Union['_client.Userge', '_client._UsergeBot'], pass -async def _is_admin(r_c: Union['_client.Userge', '_client._UsergeBot'], +async def _is_admin(r_c: Union['_client.Userge', '_client.UsergeBot'], r_m: RawMessage) -> bool: if r_m.chat.type in ("private", "bot"): return False @@ -136,7 +136,7 @@ async def _is_admin(r_c: Union['_client.Userge', '_client._UsergeBot'], return r_m.chat.id in _B_AD_CHT -def _get_chat_member(r_c: Union['_client.Userge', '_client._UsergeBot'], +def _get_chat_member(r_c: Union['_client.Userge', '_client.UsergeBot'], r_m: RawMessage) -> Optional[ChatMember]: if r_m.chat.type in ("private", "bot"): return None @@ -156,7 +156,7 @@ async def _get_lock(key: str) -> asyncio.Lock: return _CH_LKS[key] -async def _bot_is_present(r_c: Union['_client.Userge', '_client._UsergeBot'], +async def _bot_is_present(r_c: Union['_client.Userge', '_client.UsergeBot'], r_m: RawMessage) -> bool: global _TASK_2_START_TO # pylint: disable=global-statement if isinstance(r_c, _client.Userge): @@ -175,7 +175,7 @@ async def _bot_is_present(r_c: Union['_client.Userge', '_client._UsergeBot'], return r_m.chat.id in _B_CMN_CHT -async def _both_are_admins(r_c: Union['_client.Userge', '_client._UsergeBot'], +async def _both_are_admins(r_c: Union['_client.Userge', '_client.UsergeBot'], r_m: RawMessage) -> bool: if not await _bot_is_present(r_c, r_m): return False @@ -183,7 +183,7 @@ async def _both_are_admins(r_c: Union['_client.Userge', '_client._UsergeBot'], async def _both_have_perm(flt: Union['types.raw.Command', 'types.raw.Filter'], - r_c: Union['_client.Userge', '_client._UsergeBot'], + r_c: Union['_client.Userge', '_client.UsergeBot'], r_m: RawMessage) -> bool: if not await _bot_is_present(r_c, r_m): return False @@ -233,7 +233,7 @@ def _build_decorator(self, flt: Union['types.raw.Command', 'types.raw.Filter'], **kwargs: Union[str, bool]) -> 'RawDecorator._PYRORETTYPE': def decorator(func: _PYROFUNC) -> _PYROFUNC: - async def template(r_c: Union['_client.Userge', '_client._UsergeBot'], + async def template(r_c: Union['_client.Userge', '_client.UsergeBot'], r_m: RawMessage) -> None: if Config.DISABLED_ALL and r_m.chat.id != Config.LOG_CHANNEL_ID: return @@ -303,7 +303,7 @@ async def template(r_c: Union['_client.Userge', '_client._UsergeBot'], if cond: if Config.USE_USER_FOR_CLIENT_CHECKS: # pylint: disable=protected-access - if isinstance(r_c, _client._UsergeBot): + if isinstance(r_c, _client.UsergeBot): return elif await _bot_is_present(r_c, r_m): if isinstance(r_c, _client.Userge): diff --git a/userge/core/types/bound/message.py b/userge/core/types/bound/message.py index cac5e71e9..3798e1a5c 100644 --- a/userge/core/types/bound/message.py +++ b/userge/core/types/bound/message.py @@ -33,7 +33,7 @@ class Message(RawMessage): """ Modded Message Class For Userge """ def __init__(self, - client: Union['_client.Userge', '_client._UsergeBot'], + client: Union['_client.Userge', '_client.UsergeBot'], mvars: Dict[str, object], module: str, **kwargs: Union[str, bool]) -> None: self._filtered = False self._filtered_input_str = '' @@ -44,7 +44,7 @@ def __init__(self, super().__init__(client=client, **mvars) @classmethod - def parse(cls, client: Union['_client.Userge', '_client._UsergeBot'], + def parse(cls, client: Union['_client.Userge', '_client.UsergeBot'], message: RawMessage, **kwargs: Union[str, bool]) -> 'Message': """ parse message """ mvars = vars(message) @@ -57,7 +57,7 @@ def parse(cls, client: Union['_client.Userge', '_client._UsergeBot'], return cls(client, mvars, **kwargs) @property - def client(self) -> Union['_client.Userge', '_client._UsergeBot']: + def client(self) -> Union['_client.Userge', '_client.UsergeBot']: """ returns client """ return self._client diff --git a/userge/core/types/new/channel_logger.py b/userge/core/types/new/channel_logger.py index 377e2af69..c3e5f5747 100644 --- a/userge/core/types/new/channel_logger.py +++ b/userge/core/types/new/channel_logger.py @@ -31,7 +31,7 @@ def _gen_string(name: str) -> str: class ChannelLogger: """ Channel logger for Userge """ - def __init__(self, client: Union['_client.Userge', '_client._UsergeBot'], name: str) -> None: + def __init__(self, client: Union['_client.Userge', '_client.UsergeBot'], name: str) -> None: self._id = Config.LOG_CHANNEL_ID self._client = client self._string = _gen_string(name) @@ -148,7 +148,7 @@ async def store(self, return message_id async def forward_stored(self, - client: Union['_client.Userge', '_client._UsergeBot'], + client: Union['_client.Userge', '_client.UsergeBot'], message_id: int, chat_id: int, user_id: int, diff --git a/userge/core/types/new/conversation.py b/userge/core/types/new/conversation.py index beadefc0b..283393bd5 100644 --- a/userge/core/types/new/conversation.py +++ b/userge/core/types/new/conversation.py @@ -27,7 +27,7 @@ _LOG = logging.getLogger(__name__) _LOG_STR = "<<>>" -_CL_TYPE = Union['_client.Userge', '_client._UsergeBot'] +_CL_TYPE = Union['_client.Userge', '_client.UsergeBot'] _CONV_DICT: Dict[Tuple[int, _CL_TYPE], Union[asyncio.Queue, Tuple[int, asyncio.Queue]]] = {} diff --git a/userge/plugins/misc/upload.py b/userge/plugins/misc/upload.py index 31448e3ab..d20fff370 100644 --- a/userge/plugins/misc/upload.py +++ b/userge/plugins/misc/upload.py @@ -121,14 +121,19 @@ async def _handle_message(message: Message) -> None: async def upload_path(message: Message, path: Path, del_path: bool): file_paths = [] - - def explorer(_path: Path) -> None: - if _path.is_file() and _path.stat().st_size: - file_paths.append(_path) - elif _path.is_dir(): - for i in sorted(_path.iterdir()): - explorer(i) - explorer(path) + if path.exists(): + def explorer(_path: Path) -> None: + if _path.is_file() and _path.stat().st_size: + file_paths.append(_path) + elif _path.is_dir(): + for i in sorted(_path.iterdir()): + explorer(i) + explorer(path) + else: + path = path.expanduser() + str_path = os.path.join(*(path.parts[1:] if path.is_absolute() else path.parts)) + for p in Path(path.root).glob(str_path): + file_paths.append(p) current = 0 for p_t in file_paths: current += 1 @@ -172,7 +177,7 @@ async def doc_upload(message: Message, path, del_path: bool = False, chat_id=message.chat.id, document=str_path, thumb=thumb, - caption=str_path, + caption=path.name, parse_mode="html", disable_notification=True, progress=progress, @@ -220,7 +225,7 @@ async def vid_upload(message: Message, path, del_path: bool = False, thumb=thumb, width=width, height=height, - caption=str_path, + caption=path.name, parse_mode="html", disable_notification=True, progress=progress, @@ -276,7 +281,7 @@ async def audio_upload(message: Message, path, del_path: bool = False, chat_id=message.chat.id, audio=str_path, thumb=thumb, - caption=f"{str_path} [ {file_size} ]", + caption=f"{path.name} [ {file_size} ]", title=title, performer=artist, duration=duration, @@ -317,7 +322,7 @@ async def photo_upload(message: Message, path, del_path: bool = False, extra: st progress_args=(message, f"uploading {extra}", str_path) ) except ValueError as e_e: - await sent.edit(f"Skipping `{path}` due to {e_e}") + await sent.edit(f"Skipping `{str_path}` due to {e_e}") except Exception as u_e: await sent.edit(str(u_e)) raise u_e