2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-30 05:48:14 +00:00

Merge develop -> asyncio

This commit is contained in:
Dan 2019-09-08 19:27:37 +02:00
commit 928ce5d850
23 changed files with 141 additions and 91 deletions

View File

@ -48,6 +48,8 @@ from .methods import Methods
from .storage import Storage, FileStorage, MemoryStorage from .storage import Storage, FileStorage, MemoryStorage
from .types import User, SentCode, TermsOfService from .types import User, SentCode, TermsOfService
log = logging.getLogger(__name__)
class Client(Methods, BaseClient): class Client(Methods, BaseClient):
"""Pyrogram Client, the main means for interacting with Telegram. """Pyrogram Client, the main means for interacting with Telegram.
@ -340,7 +342,7 @@ class Client(Methods, BaseClient):
if self.takeout_id: if self.takeout_id:
await self.send(functions.account.FinishTakeoutSession()) await self.send(functions.account.FinishTakeoutSession())
logging.warning("Takeout session {} finished".format(self.takeout_id)) log.warning("Takeout session {} finished".format(self.takeout_id))
await Syncer.remove(self) await Syncer.remove(self)
await self.dispatcher.stop() await self.dispatcher.stop()
@ -733,7 +735,7 @@ class Client(Methods, BaseClient):
print(e.MESSAGE.format(x=e.x)) print(e.MESSAGE.format(x=e.x))
time.sleep(e.x) time.sleep(e.x)
except Exception as e: except Exception as e:
logging.error(e, exc_info=True) log.error(e, exc_info=True)
raise raise
else: else:
self.password = None self.password = None
@ -833,7 +835,7 @@ class Client(Methods, BaseClient):
if not self.storage.is_bot and self.takeout: if not self.storage.is_bot and self.takeout:
self.takeout_id = (await self.send(functions.account.InitTakeoutSession())).id self.takeout_id = (await self.send(functions.account.InitTakeoutSession())).id
logging.warning("Takeout session {} initiated".format(self.takeout_id)) log.warning("Takeout session {} initiated".format(self.takeout_id))
await self.send(functions.updates.GetState()) await self.send(functions.updates.GetState())
except (Exception, KeyboardInterrupt): except (Exception, KeyboardInterrupt):
@ -1273,7 +1275,7 @@ class Client(Methods, BaseClient):
os.makedirs(directory, exist_ok=True) os.makedirs(directory, exist_ok=True)
shutil.move(temp_file_path, final_file_path) shutil.move(temp_file_path, final_file_path)
except Exception as e: except Exception as e:
logging.error(e, exc_info=True) log.error(e, exc_info=True)
try: try:
os.remove(temp_file_path) os.remove(temp_file_path)
@ -1314,7 +1316,7 @@ class Client(Methods, BaseClient):
pts_count = getattr(update, "pts_count", None) pts_count = getattr(update, "pts_count", None)
if isinstance(update, types.UpdateChannelTooLong): if isinstance(update, types.UpdateChannelTooLong):
logging.warning(update) log.warning(update)
if isinstance(update, types.UpdateNewChannelMessage) and is_min: if isinstance(update, types.UpdateNewChannelMessage) and is_min:
message = update.message message = update.message
@ -1366,9 +1368,9 @@ class Client(Methods, BaseClient):
elif isinstance(updates, types.UpdateShort): elif isinstance(updates, types.UpdateShort):
self.dispatcher.updates_queue.put_nowait((updates.update, {}, {})) self.dispatcher.updates_queue.put_nowait((updates.update, {}, {}))
elif isinstance(updates, types.UpdatesTooLong): elif isinstance(updates, types.UpdatesTooLong):
logging.info(updates) log.info(updates)
except Exception as e: except Exception as e:
logging.error(e, exc_info=True) log.error(e, exc_info=True)
async def send(self, data: TLObject, retries: int = Session.MAX_RETRIES, timeout: float = Session.WAIT_TIMEOUT): async def send(self, data: TLObject, retries: int = Session.MAX_RETRIES, timeout: float = Session.WAIT_TIMEOUT):
"""Send raw Telegram queries. """Send raw Telegram queries.
@ -1543,7 +1545,7 @@ class Client(Methods, BaseClient):
if isinstance(handler, Handler) and isinstance(group, int): if isinstance(handler, Handler) and isinstance(group, int):
self.add_handler(handler, group) self.add_handler(handler, group)
logging.info('[{}] [LOAD] {}("{}") in group {} from "{}"'.format( log.info('[{}] [LOAD] {}("{}") in group {} from "{}"'.format(
self.session_name, type(handler).__name__, name, group, module_path)) self.session_name, type(handler).__name__, name, group, module_path))
count += 1 count += 1
@ -1557,12 +1559,12 @@ class Client(Methods, BaseClient):
try: try:
module = import_module(module_path) module = import_module(module_path)
except ImportError: except ImportError:
logging.warning('[{}] [LOAD] Ignoring non-existent module "{}"'.format( log.warning('[{}] [LOAD] Ignoring non-existent module "{}"'.format(
self.session_name, module_path)) self.session_name, module_path))
continue continue
if "__path__" in dir(module): if "__path__" in dir(module):
logging.warning('[{}] [LOAD] Ignoring namespace "{}"'.format( log.warning('[{}] [LOAD] Ignoring namespace "{}"'.format(
self.session_name, module_path)) self.session_name, module_path))
continue continue
@ -1578,13 +1580,13 @@ class Client(Methods, BaseClient):
if isinstance(handler, Handler) and isinstance(group, int): if isinstance(handler, Handler) and isinstance(group, int):
self.add_handler(handler, group) self.add_handler(handler, group)
logging.info('[{}] [LOAD] {}("{}") in group {} from "{}"'.format( log.info('[{}] [LOAD] {}("{}") in group {} from "{}"'.format(
self.session_name, type(handler).__name__, name, group, module_path)) self.session_name, type(handler).__name__, name, group, module_path))
count += 1 count += 1
except Exception: except Exception:
if warn_non_existent_functions: if warn_non_existent_functions:
logging.warning('[{}] [LOAD] Ignoring non-existent function "{}" from "{}"'.format( log.warning('[{}] [LOAD] Ignoring non-existent function "{}" from "{}"'.format(
self.session_name, name, module_path)) self.session_name, name, module_path))
if exclude: if exclude:
@ -1595,12 +1597,12 @@ class Client(Methods, BaseClient):
try: try:
module = import_module(module_path) module = import_module(module_path)
except ImportError: except ImportError:
logging.warning('[{}] [UNLOAD] Ignoring non-existent module "{}"'.format( log.warning('[{}] [UNLOAD] Ignoring non-existent module "{}"'.format(
self.session_name, module_path)) self.session_name, module_path))
continue continue
if "__path__" in dir(module): if "__path__" in dir(module):
logging.warning('[{}] [UNLOAD] Ignoring namespace "{}"'.format( log.warning('[{}] [UNLOAD] Ignoring namespace "{}"'.format(
self.session_name, module_path)) self.session_name, module_path))
continue continue
@ -1616,20 +1618,20 @@ class Client(Methods, BaseClient):
if isinstance(handler, Handler) and isinstance(group, int): if isinstance(handler, Handler) and isinstance(group, int):
self.remove_handler(handler, group) self.remove_handler(handler, group)
logging.info('[{}] [UNLOAD] {}("{}") from group {} in "{}"'.format( log.info('[{}] [UNLOAD] {}("{}") from group {} in "{}"'.format(
self.session_name, type(handler).__name__, name, group, module_path)) self.session_name, type(handler).__name__, name, group, module_path))
count -= 1 count -= 1
except Exception: except Exception:
if warn_non_existent_functions: if warn_non_existent_functions:
logging.warning('[{}] [UNLOAD] Ignoring non-existent function "{}" from "{}"'.format( log.warning('[{}] [UNLOAD] Ignoring non-existent function "{}" from "{}"'.format(
self.session_name, name, module_path)) self.session_name, name, module_path))
if count > 0: if count > 0:
logging.warning('[{}] Successfully loaded {} plugin{} from "{}"'.format( log.warning('[{}] Successfully loaded {} plugin{} from "{}"'.format(
self.session_name, count, "s" if count > 1 else "", root)) self.session_name, count, "s" if count > 1 else "", root))
else: else:
logging.warning('[{}] No plugin loaded from "{}"'.format( log.warning('[{}] No plugin loaded from "{}"'.format(
self.session_name, root)) self.session_name, root))
# def get_initial_dialogs_chunk(self, offset_date: int = 0): # def get_initial_dialogs_chunk(self, offset_date: int = 0):
@ -1646,10 +1648,10 @@ class Client(Methods, BaseClient):
# ) # )
# ) # )
# except FloodWait as e: # except FloodWait as e:
# logging.warning("get_dialogs flood: waiting {} seconds".format(e.x)) # log.warning("get_dialogs flood: waiting {} seconds".format(e.x))
# time.sleep(e.x) # time.sleep(e.x)
# else: # else:
# logging.info("Total peers: {}".format(self.storage.peers_count)) # log.info("Total peers: {}".format(self.storage.peers_count))
# return r # return r
# #
# def get_initial_dialogs(self): # def get_initial_dialogs(self):
@ -1885,7 +1887,7 @@ class Client(Methods, BaseClient):
except Client.StopTransmission: except Client.StopTransmission:
raise raise
except Exception as e: except Exception as e:
logging.error(e, exc_info=True) log.error(e, exc_info=True)
else: else:
if is_big: if is_big:
return types.InputFileBig( return types.InputFileBig(
@ -2117,7 +2119,7 @@ class Client(Methods, BaseClient):
raise e raise e
except Exception as e: except Exception as e:
if not isinstance(e, Client.StopTransmission): if not isinstance(e, Client.StopTransmission):
logging.error(e, exc_info=True) log.error(e, exc_info=True)
try: try:
os.remove(file_name) os.remove(file_name)

View File

@ -34,6 +34,8 @@ from ..handlers import (
UserStatusHandler, RawUpdateHandler, InlineQueryHandler, PollHandler UserStatusHandler, RawUpdateHandler, InlineQueryHandler, PollHandler
) )
log = logging.getLogger(__name__)
class Dispatcher: class Dispatcher:
NEW_MESSAGE_UPDATES = ( NEW_MESSAGE_UPDATES = (
@ -183,7 +185,7 @@ class Dispatcher:
if handler.check(parsed_update): if handler.check(parsed_update):
args = (parsed_update,) args = (parsed_update,)
except Exception as e: except Exception as e:
logging.error(e, exc_info=True) log.error(e, exc_info=True)
continue continue
elif isinstance(handler, RawUpdateHandler): elif isinstance(handler, RawUpdateHandler):
@ -199,10 +201,10 @@ class Dispatcher:
except pyrogram.ContinuePropagation: except pyrogram.ContinuePropagation:
continue continue
except Exception as e: except Exception as e:
logging.error(e, exc_info=True) log.error(e, exc_info=True)
break break
except pyrogram.StopPropagation: except pyrogram.StopPropagation:
pass pass
except Exception as e: except Exception as e:
logging.error(e, exc_info=True) log.error(e, exc_info=True)

View File

@ -20,6 +20,8 @@ import asyncio
import logging import logging
import time import time
log = logging.getLogger(__name__)
class Syncer: class Syncer:
INTERVAL = 20 INTERVAL = 20
@ -81,9 +83,9 @@ class Syncer:
start = time.time() start = time.time()
client.storage.save() client.storage.save()
except Exception as e: except Exception as e:
logging.critical(e, exc_info=True) log.critical(e, exc_info=True)
else: else:
logging.info('Synced "{}" in {:.6} ms'.format( log.info('Synced "{}" in {:.6} ms'.format(
client.storage.name, client.storage.name,
(time.time() - start) * 1000 (time.time() - start) * 1000
)) ))

View File

@ -23,9 +23,10 @@ from typing import Union, List
import pyrogram import pyrogram
from pyrogram.api import functions, types from pyrogram.api import functions, types
from pyrogram.errors import FloodWait from pyrogram.errors import FloodWait
from ...ext import BaseClient from ...ext import BaseClient
log = logging.getLogger(__name__)
class Filters: class Filters:
ALL = "all" ALL = "all"
@ -152,7 +153,7 @@ class GetChatMembers(BaseClient):
return pyrogram.List(pyrogram.ChatMember._parse(self, member, users) for member in members) return pyrogram.List(pyrogram.ChatMember._parse(self, member, users) for member in members)
except FloodWait as e: except FloodWait as e:
logging.warning("Sleeping for {}s".format(e.x)) log.warning("Sleeping for {}s".format(e.x))
await asyncio.sleep(e.x) await asyncio.sleep(e.x)
else: else:
raise ValueError("The chat_id \"{}\" belongs to a user".format(chat_id)) raise ValueError("The chat_id \"{}\" belongs to a user".format(chat_id))

View File

@ -23,9 +23,10 @@ from typing import List
import pyrogram import pyrogram
from pyrogram.api import functions, types from pyrogram.api import functions, types
from pyrogram.errors import FloodWait from pyrogram.errors import FloodWait
from ...ext import BaseClient, utils from ...ext import BaseClient, utils
log = logging.getLogger(__name__)
class GetDialogs(BaseClient): class GetDialogs(BaseClient):
async def get_dialogs( async def get_dialogs(
@ -81,7 +82,7 @@ class GetDialogs(BaseClient):
) )
) )
except FloodWait as e: except FloodWait as e:
logging.warning("Sleeping for {}s".format(e.x)) log.warning("Sleeping for {}s".format(e.x))
await asyncio.sleep(e.x) await asyncio.sleep(e.x)
else: else:
break break
@ -110,6 +111,6 @@ class GetDialogs(BaseClient):
if not isinstance(dialog, types.Dialog): if not isinstance(dialog, types.Dialog):
continue continue
parsed_dialogs.append(pyrogram.Dialogging._parse(self, dialog, messages, users, chats)) parsed_dialogs.append(pyrogram.Dialog._parse(self, dialog, messages, users, chats))
return pyrogram.List(parsed_dialogs) return pyrogram.List(parsed_dialogs)

View File

@ -25,6 +25,8 @@ from pyrogram.api import functions
from pyrogram.errors import FloodWait from pyrogram.errors import FloodWait
from ...ext import BaseClient from ...ext import BaseClient
log = logging.getLogger(__name__)
class GetContacts(BaseClient): class GetContacts(BaseClient):
async def get_contacts(self) -> List["pyrogram.User"]: async def get_contacts(self) -> List["pyrogram.User"]:
@ -43,7 +45,7 @@ class GetContacts(BaseClient):
try: try:
contacts = await self.send(functions.contacts.GetContacts(hash=0)) contacts = await self.send(functions.contacts.GetContacts(hash=0))
except FloodWait as e: except FloodWait as e:
logging.warning("get_contacts flood: waiting {} seconds".format(e.x)) log.warning("get_contacts flood: waiting {} seconds".format(e.x))
await asyncio.sleep(e.x) await asyncio.sleep(e.x)
else: else:
return pyrogram.List(pyrogram.User._parse(self, user) for user in contacts.users) return pyrogram.List(pyrogram.User._parse(self, user) for user in contacts.users)

View File

@ -24,9 +24,10 @@ import pyrogram
from pyrogram.api import functions from pyrogram.api import functions
from pyrogram.client.ext import utils from pyrogram.client.ext import utils
from pyrogram.errors import FloodWait from pyrogram.errors import FloodWait
from ...ext import BaseClient from ...ext import BaseClient
log = logging.getLogger(__name__)
class GetHistory(BaseClient): class GetHistory(BaseClient):
async def get_history( async def get_history(
@ -102,7 +103,7 @@ class GetHistory(BaseClient):
) )
) )
except FloodWait as e: except FloodWait as e:
logging.warning("Sleeping for {}s".format(e.x)) log.warning("Sleeping for {}s".format(e.x))
await asyncio.sleep(e.x) await asyncio.sleep(e.x)
else: else:
break break

View File

@ -16,11 +16,14 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import logging
from typing import Union from typing import Union
from pyrogram.api import types, functions from pyrogram.api import types, functions
from pyrogram.client.ext import BaseClient from pyrogram.client.ext import BaseClient
log = logging.getLogger(__name__)
class GetHistoryCount(BaseClient): class GetHistoryCount(BaseClient):
async def get_history_count( async def get_history_count(

View File

@ -23,9 +23,10 @@ from typing import Union, Iterable, List
import pyrogram import pyrogram
from pyrogram.api import functions, types from pyrogram.api import functions, types
from pyrogram.errors import FloodWait from pyrogram.errors import FloodWait
from ...ext import BaseClient, utils from ...ext import BaseClient, utils
log = logging.getLogger(__name__)
# TODO: Rewrite using a flag for replied messages and have message_ids non-optional # TODO: Rewrite using a flag for replied messages and have message_ids non-optional
@ -115,7 +116,7 @@ class GetMessages(BaseClient):
try: try:
r = await self.send(rpc) r = await self.send(rpc)
except FloodWait as e: except FloodWait as e:
logging.warning("Sleeping for {}s".format(e.x)) log.warning("Sleeping for {}s".format(e.x))
await asyncio.sleep(e.x) await asyncio.sleep(e.x)
else: else:
break break

View File

@ -26,6 +26,8 @@ from pyrogram.api import functions, types
from pyrogram.client.ext import BaseClient, utils from pyrogram.client.ext import BaseClient, utils
from pyrogram.errors import FloodWait from pyrogram.errors import FloodWait
log = logging.getLogger(__name__)
class SendMediaGroup(BaseClient): class SendMediaGroup(BaseClient):
# TODO: Add progress parameter # TODO: Add progress parameter
@ -87,7 +89,7 @@ class SendMediaGroup(BaseClient):
) )
) )
except FloodWait as e: except FloodWait as e:
logging.warning("Sleeping for {}s".format(e.x)) log.warning("Sleeping for {}s".format(e.x))
await asyncio.sleep(e.x) await asyncio.sleep(e.x)
else: else:
break break
@ -142,7 +144,7 @@ class SendMediaGroup(BaseClient):
) )
) )
except FloodWait as e: except FloodWait as e:
logging.warning("Sleeping for {}s".format(e.x)) log.warning("Sleeping for {}s".format(e.x))
await asyncio.sleep(e.x) await asyncio.sleep(e.x)
else: else:
break break
@ -193,7 +195,7 @@ class SendMediaGroup(BaseClient):
) )
) )
except FloodWait as e: except FloodWait as e:
logging.warning("Sleeping for {}s".format(e.x)) log.warning("Sleeping for {}s".format(e.x))
await asyncio.sleep(e.x) await asyncio.sleep(e.x)
else: else:
break break

View File

@ -28,6 +28,8 @@ from pyrogram.api import types
from pyrogram.errors import PeerIdInvalid from pyrogram.errors import PeerIdInvalid
from . import utils from . import utils
log = logging.getLogger(__name__)
class Parser(HTMLParser): class Parser(HTMLParser):
MENTION_RE = re.compile(r"tg://user\?id=(\d+)") MENTION_RE = re.compile(r"tg://user\?id=(\d+)")
@ -95,7 +97,7 @@ class Parser(HTMLParser):
line, offset = self.getpos() line, offset = self.getpos()
offset += 1 offset += 1
logging.warning("Unmatched closing tag </{}> at line {}:{}".format(tag, line, offset)) log.warning("Unmatched closing tag </{}> at line {}:{}".format(tag, line, offset))
else: else:
if not self.tag_entities[tag]: if not self.tag_entities[tag]:
self.tag_entities.pop(tag) self.tag_entities.pop(tag)
@ -121,7 +123,7 @@ class HTML:
for tag, entities in parser.tag_entities.items(): for tag, entities in parser.tag_entities.items():
unclosed_tags.append("<{}> (x{})".format(tag, len(entities))) unclosed_tags.append("<{}> (x{})".format(tag, len(entities)))
logging.warning("Unclosed tags: {}".format(", ".join(unclosed_tags))) log.warning("Unclosed tags: {}".format(", ".join(unclosed_tags)))
entities = [] entities = []

View File

@ -26,6 +26,8 @@ from threading import Lock
from .memory_storage import MemoryStorage from .memory_storage import MemoryStorage
log = logging.getLogger(__name__)
class FileStorage(MemoryStorage): class FileStorage(MemoryStorage):
FILE_EXTENSION = ".session" FILE_EXTENSION = ".session"
@ -81,20 +83,20 @@ class FileStorage(MemoryStorage):
except ValueError: except ValueError:
pass pass
else: else:
logging.warning("JSON session storage detected! Converting it into an SQLite session storage...") log.warning("JSON session storage detected! Converting it into an SQLite session storage...")
path.rename(path.name + ".OLD") path.rename(path.name + ".OLD")
logging.warning('The old session file has been renamed to "{}.OLD"'.format(path.name)) log.warning('The old session file has been renamed to "{}.OLD"'.format(path.name))
self.migrate_from_json(session_json) self.migrate_from_json(session_json)
logging.warning("Done! The session has been successfully converted from JSON to SQLite storage") log.warning("Done! The session has been successfully converted from JSON to SQLite storage")
return return
if Path(path.name + ".OLD").is_file(): if Path(path.name + ".OLD").is_file():
logging.warning('Old session file detected: "{}.OLD". You can remove this file now'.format(path.name)) log.warning('Old session file detected: "{}.OLD". You can remove this file now'.format(path.name))
self.conn = sqlite3.connect( self.conn = sqlite3.connect(
str(path), str(path),

View File

@ -18,6 +18,7 @@
import base64 import base64
import inspect import inspect
import logging
import sqlite3 import sqlite3
import struct import struct
import time import time
@ -28,6 +29,8 @@ from typing import List, Tuple
from pyrogram.api import types from pyrogram.api import types
from pyrogram.client.storage.storage import Storage from pyrogram.client.storage.storage import Storage
log = logging.getLogger(__name__)
class MemoryStorage(Storage): class MemoryStorage(Storage):
SCHEMA_VERSION = 1 SCHEMA_VERSION = 1

View File

@ -22,6 +22,8 @@ import logging
from .transport import * from .transport import *
from ..session.internals import DataCenter from ..session.internals import DataCenter
log = logging.getLogger(__name__)
class Connection: class Connection:
MAX_RETRIES = 3 MAX_RETRIES = 3
@ -49,14 +51,14 @@ class Connection:
self.protocol = self.mode(self.ipv6, self.proxy) self.protocol = self.mode(self.ipv6, self.proxy)
try: try:
logging.info("Connecting...") log.info("Connecting...")
await self.protocol.connect(self.address) await self.protocol.connect(self.address)
except OSError as e: except OSError as e:
logging.warning(e) # TODO: Remove log.warning(e) # TODO: Remove
self.protocol.close() self.protocol.close()
await asyncio.sleep(1) await asyncio.sleep(1)
else: else:
logging.info("Connected! {} DC{} - IPv{} - {}".format( log.info("Connected! {} DC{} - IPv{} - {}".format(
"Test" if self.test_mode else "Production", "Test" if self.test_mode else "Production",
self.dc_id, self.dc_id,
"6" if self.ipv6 else "4", "6" if self.ipv6 else "4",
@ -64,12 +66,12 @@ class Connection:
)) ))
break break
else: else:
logging.warning("Connection failed! Trying again...") log.warning("Connection failed! Trying again...")
raise TimeoutError raise TimeoutError
def close(self): def close(self):
self.protocol.close() self.protocol.close()
logging.info("Disconnected") log.info("Disconnected")
async def send(self, data: bytes): async def send(self, data: bytes):
try: try:

View File

@ -31,6 +31,8 @@ except ImportError as e:
raise e raise e
log = logging.getLogger(__name__)
class TCP: class TCP:
TIMEOUT = 10 TIMEOUT = 10
@ -65,7 +67,7 @@ class TCP:
password=proxy.get("password", None) password=proxy.get("password", None)
) )
logging.info("Using proxy {}:{}".format(hostname, port)) log.info("Using proxy {}:{}".format(hostname, port))
else: else:
self.socket = socks.socksocket( self.socket = socks.socksocket(
socket.AF_INET6 if ipv6 socket.AF_INET6 if ipv6

View File

@ -16,8 +16,12 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import logging
from .tcp import TCP from .tcp import TCP
log = logging.getLogger(__name__)
class TCPAbridged(TCP): class TCPAbridged(TCP):
def __init__(self, ipv6: bool, proxy: dict): def __init__(self, ipv6: bool, proxy: dict):

View File

@ -16,11 +16,14 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import logging
import os import os
from .tcp import TCP from .tcp import TCP
from ....crypto.aes import AES from ....crypto.aes import AES
log = logging.getLogger(__name__)
class TCPAbridgedO(TCP): class TCPAbridgedO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4) RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)

View File

@ -16,11 +16,14 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import logging
from binascii import crc32 from binascii import crc32
from struct import pack, unpack from struct import pack, unpack
from .tcp import TCP from .tcp import TCP
log = logging.getLogger(__name__)
class TCPFull(TCP): class TCPFull(TCP):
def __init__(self, ipv6: bool, proxy: dict): def __init__(self, ipv6: bool, proxy: dict):

View File

@ -16,10 +16,13 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import logging
from struct import pack, unpack from struct import pack, unpack
from .tcp import TCP from .tcp import TCP
log = logging.getLogger(__name__)
class TCPIntermediate(TCP): class TCPIntermediate(TCP):
def __init__(self, ipv6: bool, proxy: dict): def __init__(self, ipv6: bool, proxy: dict):

View File

@ -16,12 +16,15 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import logging
import os import os
from struct import pack, unpack from struct import pack, unpack
from .tcp import TCP from .tcp import TCP
from ....crypto.aes import AES from ....crypto.aes import AES
log = logging.getLogger(__name__)
class TCPIntermediateO(TCP): class TCPIntermediateO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4) RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)

View File

@ -18,10 +18,12 @@
import logging import logging
log = logging.getLogger(__name__)
try: try:
import tgcrypto import tgcrypto
logging.info("Using TgCrypto") log.info("Using TgCrypto")
class AES: class AES:
@ -51,7 +53,7 @@ try:
except ImportError: except ImportError:
import pyaes import pyaes
logging.warning( log.warning(
"TgCrypto is missing! " "TgCrypto is missing! "
"Pyrogram will work the same, but at a much slower speed. " "Pyrogram will work the same, but at a much slower speed. "
"More info: https://docs.pyrogram.org/topics/tgcrypto" "More info: https://docs.pyrogram.org/topics/tgcrypto"

View File

@ -30,6 +30,8 @@ from pyrogram.connection import Connection
from pyrogram.crypto import AES, RSA, Prime from pyrogram.crypto import AES, RSA, Prime
from .internals import MsgId from .internals import MsgId
log = logging.getLogger(__name__)
class Auth: class Auth:
MAX_RETRIES = 5 MAX_RETRIES = 5
@ -76,34 +78,34 @@ class Auth:
self.connection = Connection(self.dc_id, self.test_mode, self.ipv6, self.proxy) self.connection = Connection(self.dc_id, self.test_mode, self.ipv6, self.proxy)
try: try:
logging.info("Start creating a new auth key on DC{}".format(self.dc_id)) log.info("Start creating a new auth key on DC{}".format(self.dc_id))
await self.connection.connect() await self.connection.connect()
# Step 1; Step 2 # Step 1; Step 2
nonce = int.from_bytes(urandom(16), "little", signed=True) nonce = int.from_bytes(urandom(16), "little", signed=True)
logging.debug("Send req_pq: {}".format(nonce)) log.debug("Send req_pq: {}".format(nonce))
res_pq = await self.send(functions.ReqPqMulti(nonce=nonce)) res_pq = await self.send(functions.ReqPqMulti(nonce=nonce))
logging.debug("Got ResPq: {}".format(res_pq.server_nonce)) log.debug("Got ResPq: {}".format(res_pq.server_nonce))
logging.debug("Server public key fingerprints: {}".format(res_pq.server_public_key_fingerprints)) log.debug("Server public key fingerprints: {}".format(res_pq.server_public_key_fingerprints))
for i in res_pq.server_public_key_fingerprints: for i in res_pq.server_public_key_fingerprints:
if i in RSA.server_public_keys: if i in RSA.server_public_keys:
logging.debug("Using fingerprint: {}".format(i)) log.debug("Using fingerprint: {}".format(i))
public_key_fingerprint = i public_key_fingerprint = i
break break
else: else:
logging.debug("Fingerprint unknown: {}".format(i)) log.debug("Fingerprint unknown: {}".format(i))
else: else:
raise Exception("Public key not found") raise Exception("Public key not found")
# Step 3 # Step 3
pq = int.from_bytes(res_pq.pq, "big") pq = int.from_bytes(res_pq.pq, "big")
logging.debug("Start PQ factorization: {}".format(pq)) log.debug("Start PQ factorization: {}".format(pq))
start = time.time() start = time.time()
g = Prime.decompose(pq) g = Prime.decompose(pq)
p, q = sorted((g, pq // g)) # p < q p, q = sorted((g, pq // g)) # p < q
logging.debug("Done PQ factorization ({}s): {} {}".format(round(time.time() - start, 3), p, q)) log.debug("Done PQ factorization ({}s): {} {}".format(round(time.time() - start, 3), p, q))
# Step 4 # Step 4
server_nonce = res_pq.server_nonce server_nonce = res_pq.server_nonce
@ -123,10 +125,10 @@ class Auth:
data_with_hash = sha + data + padding data_with_hash = sha + data + padding
encrypted_data = RSA.encrypt(data_with_hash, public_key_fingerprint) encrypted_data = RSA.encrypt(data_with_hash, public_key_fingerprint)
logging.debug("Done encrypt data with RSA") log.debug("Done encrypt data with RSA")
# Step 5. TODO: Handle "server_DH_params_fail". Code assumes response is ok # Step 5. TODO: Handle "server_DH_params_fail". Code assumes response is ok
logging.debug("Send req_DH_params") log.debug("Send req_DH_params")
server_dh_params = await self.send( server_dh_params = await self.send(
functions.ReqDHParams( functions.ReqDHParams(
nonce=nonce, nonce=nonce,
@ -160,12 +162,12 @@ class Auth:
server_dh_inner_data = TLObject.read(BytesIO(answer)) server_dh_inner_data = TLObject.read(BytesIO(answer))
logging.debug("Done decrypting answer") log.debug("Done decrypting answer")
dh_prime = int.from_bytes(server_dh_inner_data.dh_prime, "big") dh_prime = int.from_bytes(server_dh_inner_data.dh_prime, "big")
delta_time = server_dh_inner_data.server_time - time.time() delta_time = server_dh_inner_data.server_time - time.time()
logging.debug("Delta time: {}".format(round(delta_time, 3))) log.debug("Delta time: {}".format(round(delta_time, 3)))
# Step 6 # Step 6
g = server_dh_inner_data.g g = server_dh_inner_data.g
@ -186,7 +188,7 @@ class Auth:
data_with_hash = sha + data + padding data_with_hash = sha + data + padding
encrypted_data = AES.ige256_encrypt(data_with_hash, tmp_aes_key, tmp_aes_iv) encrypted_data = AES.ige256_encrypt(data_with_hash, tmp_aes_key, tmp_aes_iv)
logging.debug("Send set_client_DH_params") log.debug("Send set_client_DH_params")
set_client_dh_params_answer = await self.send( set_client_dh_params_answer = await self.send(
functions.SetClientDHParams( functions.SetClientDHParams(
nonce=nonce, nonce=nonce,
@ -209,7 +211,7 @@ class Auth:
####################### #######################
assert dh_prime == Prime.CURRENT_DH_PRIME assert dh_prime == Prime.CURRENT_DH_PRIME
logging.debug("DH parameters check: OK") log.debug("DH parameters check: OK")
# https://core.telegram.org/mtproto/security_guidelines#g-a-and-g-b-validation # https://core.telegram.org/mtproto/security_guidelines#g-a-and-g-b-validation
g_b = int.from_bytes(g_b, "big") g_b = int.from_bytes(g_b, "big")
@ -218,12 +220,12 @@ class Auth:
assert 1 < g_b < dh_prime - 1 assert 1 < g_b < dh_prime - 1
assert 2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64) assert 2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64)
assert 2 ** (2048 - 64) < g_b < dh_prime - 2 ** (2048 - 64) assert 2 ** (2048 - 64) < g_b < dh_prime - 2 ** (2048 - 64)
logging.debug("g_a and g_b validation: OK") log.debug("g_a and g_b validation: OK")
# https://core.telegram.org/mtproto/security_guidelines#checking-sha1-hash-values # https://core.telegram.org/mtproto/security_guidelines#checking-sha1-hash-values
answer = server_dh_inner_data.write() # Call .write() to remove padding answer = server_dh_inner_data.write() # Call .write() to remove padding
assert answer_with_hash[:20] == sha1(answer).digest() assert answer_with_hash[:20] == sha1(answer).digest()
logging.debug("SHA1 hash values check: OK") log.debug("SHA1 hash values check: OK")
# https://core.telegram.org/mtproto/security_guidelines#checking-nonce-server-nonce-and-new-nonce-fields # https://core.telegram.org/mtproto/security_guidelines#checking-nonce-server-nonce-and-new-nonce-fields
# 1st message # 1st message
@ -236,14 +238,14 @@ class Auth:
assert nonce == set_client_dh_params_answer.nonce assert nonce == set_client_dh_params_answer.nonce
assert server_nonce == set_client_dh_params_answer.server_nonce assert server_nonce == set_client_dh_params_answer.server_nonce
server_nonce = server_nonce.to_bytes(16, "little", signed=True) server_nonce = server_nonce.to_bytes(16, "little", signed=True)
logging.debug("Nonce fields check: OK") log.debug("Nonce fields check: OK")
# Step 9 # Step 9
server_salt = AES.xor(new_nonce[:8], server_nonce[:8]) server_salt = AES.xor(new_nonce[:8], server_nonce[:8])
logging.debug("Server salt: {}".format(int.from_bytes(server_salt, "little"))) log.debug("Server salt: {}".format(int.from_bytes(server_salt, "little")))
logging.info( log.info(
"Done auth key exchange: {}".format( "Done auth key exchange: {}".format(
set_client_dh_params_answer.__class__.__name__ set_client_dh_params_answer.__class__.__name__
) )

View File

@ -32,6 +32,8 @@ from pyrogram.crypto import MTProto
from pyrogram.errors import RPCError, InternalServerError, AuthKeyDuplicated from pyrogram.errors import RPCError, InternalServerError, AuthKeyDuplicated
from .internals import MsgId, MsgFactory from .internals import MsgId, MsgFactory
log = logging.getLogger(__name__)
class Result: class Result:
def __init__(self): def __init__(self):
@ -156,9 +158,9 @@ class Session:
self.ping_task = asyncio.ensure_future(self.ping()) self.ping_task = asyncio.ensure_future(self.ping())
logging.info("Session initialized: Layer {}".format(layer)) log.info("Session initialized: Layer {}".format(layer))
logging.info("Device: {} - {}".format(self.client.device_model, self.client.app_version)) log.info("Device: {} - {}".format(self.client.device_model, self.client.app_version))
logging.info("System: {} ({})".format(self.client.system_version, self.client.lang_code.upper())) log.info("System: {} ({})".format(self.client.system_version, self.client.lang_code.upper()))
except AuthKeyDuplicated as e: except AuthKeyDuplicated as e:
await self.stop() await self.stop()
@ -173,7 +175,7 @@ class Session:
self.is_connected.set() self.is_connected.set()
logging.info("Session started") log.info("Session started")
async def stop(self): async def stop(self):
self.is_connected.clear() self.is_connected.clear()
@ -205,9 +207,9 @@ class Session:
try: try:
await self.client.disconnect_handler(self.client) await self.client.disconnect_handler(self.client)
except Exception as e: except Exception as e:
logging.error(e, exc_info=True) log.error(e, exc_info=True)
logging.info("Session stopped") log.info("Session stopped")
async def restart(self): async def restart(self):
await self.stop() await self.stop()
@ -236,7 +238,7 @@ class Session:
else [data] else [data]
) )
logging.debug(data) log.debug(data)
for msg in messages: for msg in messages:
if msg.seq_no % 2 != 0: if msg.seq_no % 2 != 0:
@ -269,7 +271,7 @@ class Session:
self.results[msg_id].event.set() self.results[msg_id].event.set()
if len(self.pending_acks) >= self.ACKS_THRESHOLD: if len(self.pending_acks) >= self.ACKS_THRESHOLD:
logging.info("Send {} acks".format(len(self.pending_acks))) log.info("Send {} acks".format(len(self.pending_acks)))
try: try:
await self._send(types.MsgsAck(msg_ids=list(self.pending_acks)), False) await self._send(types.MsgsAck(msg_ids=list(self.pending_acks)), False)
@ -278,12 +280,12 @@ class Session:
else: else:
self.pending_acks.clear() self.pending_acks.clear()
except Exception as e: except Exception as e:
logging.error(e, exc_info=True) log.error(e, exc_info=True)
logging.info("NetWorkerTask stopped") log.info("NetWorkerTask stopped")
async def ping(self): async def ping(self):
logging.info("PingTask started") log.info("PingTask started")
while True: while True:
try: try:
@ -302,10 +304,10 @@ class Session:
except (OSError, TimeoutError, RPCError): except (OSError, TimeoutError, RPCError):
pass pass
logging.info("PingTask stopped") log.info("PingTask stopped")
async def next_salt(self): async def next_salt(self):
logging.info("NextSaltTask started") log.info("NextSaltTask started")
while True: while True:
now = datetime.now() now = datetime.now()
@ -315,7 +317,7 @@ class Session:
valid_until = datetime.fromtimestamp(self.current_salt.valid_until) valid_until = datetime.fromtimestamp(self.current_salt.valid_until)
dt = (valid_until - now).total_seconds() - 900 dt = (valid_until - now).total_seconds() - 900
logging.info("Next salt in {:.0f}m {:.0f}s ({})".format( log.info("Next salt in {:.0f}m {:.0f}s ({})".format(
dt // 60, dt % 60, dt // 60, dt % 60,
now + timedelta(seconds=dt) now + timedelta(seconds=dt)
)) ))
@ -333,10 +335,10 @@ class Session:
self.connection.close() self.connection.close()
break break
logging.info("NextSaltTask stopped") log.info("NextSaltTask stopped")
async def recv(self): async def recv(self):
logging.info("RecvTask started") log.info("RecvTask started")
while True: while True:
packet = await self.connection.recv() packet = await self.connection.recv()
@ -345,7 +347,7 @@ class Session:
self.recv_queue.put_nowait(None) self.recv_queue.put_nowait(None)
if packet: if packet:
logging.warning("Server sent \"{}\"".format(Int.read(BytesIO(packet)))) log.warning("Server sent \"{}\"".format(Int.read(BytesIO(packet))))
if self.is_connected.is_set(): if self.is_connected.is_set():
asyncio.ensure_future(self.restart()) asyncio.ensure_future(self.restart())
@ -354,7 +356,7 @@ class Session:
self.recv_queue.put_nowait(packet) self.recv_queue.put_nowait(packet)
logging.info("RecvTask stopped") log.info("RecvTask stopped")
async def _send(self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT): async def _send(self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT):
message = self.msg_factory(data) message = self.msg_factory(data)