2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-28 21:07:59 +00:00

Revert "Stop instantiating logger objects and directly use the logging module"

This reverts commit 792068d7
This commit is contained in:
Dan 2019-09-08 19:24:06 +02:00
parent 792068d7c8
commit a015f998fa
23 changed files with 148 additions and 94 deletions

View File

@ -48,6 +48,8 @@ from .methods import Methods
from .storage import Storage, FileStorage, MemoryStorage from .storage import Storage, FileStorage, MemoryStorage
from .types import User, SentCode, TermsOfService from .types import User, SentCode, TermsOfService
log = logging.getLogger(__name__)
class Client(Methods, BaseClient): class Client(Methods, BaseClient):
"""Pyrogram Client, the main means for interacting with Telegram. """Pyrogram Client, the main means for interacting with Telegram.
@ -340,7 +342,7 @@ class Client(Methods, BaseClient):
if self.takeout_id: if self.takeout_id:
self.send(functions.account.FinishTakeoutSession()) self.send(functions.account.FinishTakeoutSession())
logging.warning("Takeout session {} finished".format(self.takeout_id)) log.warning("Takeout session {} finished".format(self.takeout_id))
Syncer.remove(self) Syncer.remove(self)
self.dispatcher.stop() self.dispatcher.stop()
@ -728,7 +730,7 @@ class Client(Methods, BaseClient):
print(e.MESSAGE.format(x=e.x)) print(e.MESSAGE.format(x=e.x))
time.sleep(e.x) time.sleep(e.x)
except Exception as e: except Exception as e:
logging.error(e, exc_info=True) log.error(e, exc_info=True)
raise raise
else: else:
self.password = None self.password = None
@ -828,7 +830,7 @@ class Client(Methods, BaseClient):
if not self.storage.is_bot and self.takeout: if not self.storage.is_bot and self.takeout:
self.takeout_id = self.send(functions.account.InitTakeoutSession()).id self.takeout_id = self.send(functions.account.InitTakeoutSession()).id
logging.warning("Takeout session {} initiated".format(self.takeout_id)) log.warning("Takeout session {} initiated".format(self.takeout_id))
self.send(functions.updates.GetState()) self.send(functions.updates.GetState())
except (Exception, KeyboardInterrupt): except (Exception, KeyboardInterrupt):
@ -1227,7 +1229,7 @@ class Client(Methods, BaseClient):
def download_worker(self): def download_worker(self):
name = threading.current_thread().name name = threading.current_thread().name
logging.debug("{} started".format(name)) log.debug("{} started".format(name))
while True: while True:
packet = self.download_queue.get() packet = self.download_queue.get()
@ -1262,7 +1264,7 @@ class Client(Methods, BaseClient):
os.makedirs(directory, exist_ok=True) os.makedirs(directory, exist_ok=True)
shutil.move(temp_file_path, final_file_path) shutil.move(temp_file_path, final_file_path)
except Exception as e: except Exception as e:
logging.error(e, exc_info=True) log.error(e, exc_info=True)
try: try:
os.remove(temp_file_path) os.remove(temp_file_path)
@ -1276,11 +1278,11 @@ class Client(Methods, BaseClient):
finally: finally:
done.set() done.set()
logging.debug("{} stopped".format(name)) log.debug("{} stopped".format(name))
def updates_worker(self): def updates_worker(self):
name = threading.current_thread().name name = threading.current_thread().name
logging.debug("{} started".format(name)) log.debug("{} started".format(name))
while True: while True:
updates = self.updates_queue.get() updates = self.updates_queue.get()
@ -1308,7 +1310,7 @@ class Client(Methods, BaseClient):
pts_count = getattr(update, "pts_count", None) pts_count = getattr(update, "pts_count", None)
if isinstance(update, types.UpdateChannelTooLong): if isinstance(update, types.UpdateChannelTooLong):
logging.warning(update) log.warning(update)
if isinstance(update, types.UpdateNewChannelMessage) and is_min: if isinstance(update, types.UpdateNewChannelMessage) and is_min:
message = update.message message = update.message
@ -1360,11 +1362,11 @@ class Client(Methods, BaseClient):
elif isinstance(updates, types.UpdateShort): elif isinstance(updates, types.UpdateShort):
self.dispatcher.updates_queue.put((updates.update, {}, {})) self.dispatcher.updates_queue.put((updates.update, {}, {}))
elif isinstance(updates, types.UpdatesTooLong): elif isinstance(updates, types.UpdatesTooLong):
logging.info(updates) log.info(updates)
except Exception as e: except Exception as e:
logging.error(e, exc_info=True) log.error(e, exc_info=True)
logging.debug("{} stopped".format(name)) log.debug("{} stopped".format(name))
def send(self, data: TLObject, retries: int = Session.MAX_RETRIES, timeout: float = Session.WAIT_TIMEOUT): def send(self, data: TLObject, retries: int = Session.MAX_RETRIES, timeout: float = Session.WAIT_TIMEOUT):
"""Send raw Telegram queries. """Send raw Telegram queries.
@ -1539,7 +1541,7 @@ class Client(Methods, BaseClient):
if isinstance(handler, Handler) and isinstance(group, int): if isinstance(handler, Handler) and isinstance(group, int):
self.add_handler(handler, group) self.add_handler(handler, group)
logging.info('[{}] [LOAD] {}("{}") in group {} from "{}"'.format( log.info('[{}] [LOAD] {}("{}") in group {} from "{}"'.format(
self.session_name, type(handler).__name__, name, group, module_path)) self.session_name, type(handler).__name__, name, group, module_path))
count += 1 count += 1
@ -1553,12 +1555,12 @@ class Client(Methods, BaseClient):
try: try:
module = import_module(module_path) module = import_module(module_path)
except ImportError: except ImportError:
logging.warning('[{}] [LOAD] Ignoring non-existent module "{}"'.format( log.warning('[{}] [LOAD] Ignoring non-existent module "{}"'.format(
self.session_name, module_path)) self.session_name, module_path))
continue continue
if "__path__" in dir(module): if "__path__" in dir(module):
logging.warning('[{}] [LOAD] Ignoring namespace "{}"'.format( log.warning('[{}] [LOAD] Ignoring namespace "{}"'.format(
self.session_name, module_path)) self.session_name, module_path))
continue continue
@ -1574,13 +1576,13 @@ class Client(Methods, BaseClient):
if isinstance(handler, Handler) and isinstance(group, int): if isinstance(handler, Handler) and isinstance(group, int):
self.add_handler(handler, group) self.add_handler(handler, group)
logging.info('[{}] [LOAD] {}("{}") in group {} from "{}"'.format( log.info('[{}] [LOAD] {}("{}") in group {} from "{}"'.format(
self.session_name, type(handler).__name__, name, group, module_path)) self.session_name, type(handler).__name__, name, group, module_path))
count += 1 count += 1
except Exception: except Exception:
if warn_non_existent_functions: if warn_non_existent_functions:
logging.warning('[{}] [LOAD] Ignoring non-existent function "{}" from "{}"'.format( log.warning('[{}] [LOAD] Ignoring non-existent function "{}" from "{}"'.format(
self.session_name, name, module_path)) self.session_name, name, module_path))
if exclude: if exclude:
@ -1591,12 +1593,12 @@ class Client(Methods, BaseClient):
try: try:
module = import_module(module_path) module = import_module(module_path)
except ImportError: except ImportError:
logging.warning('[{}] [UNLOAD] Ignoring non-existent module "{}"'.format( log.warning('[{}] [UNLOAD] Ignoring non-existent module "{}"'.format(
self.session_name, module_path)) self.session_name, module_path))
continue continue
if "__path__" in dir(module): if "__path__" in dir(module):
logging.warning('[{}] [UNLOAD] Ignoring namespace "{}"'.format( log.warning('[{}] [UNLOAD] Ignoring namespace "{}"'.format(
self.session_name, module_path)) self.session_name, module_path))
continue continue
@ -1612,20 +1614,20 @@ class Client(Methods, BaseClient):
if isinstance(handler, Handler) and isinstance(group, int): if isinstance(handler, Handler) and isinstance(group, int):
self.remove_handler(handler, group) self.remove_handler(handler, group)
logging.info('[{}] [UNLOAD] {}("{}") from group {} in "{}"'.format( log.info('[{}] [UNLOAD] {}("{}") from group {} in "{}"'.format(
self.session_name, type(handler).__name__, name, group, module_path)) self.session_name, type(handler).__name__, name, group, module_path))
count -= 1 count -= 1
except Exception: except Exception:
if warn_non_existent_functions: if warn_non_existent_functions:
logging.warning('[{}] [UNLOAD] Ignoring non-existent function "{}" from "{}"'.format( log.warning('[{}] [UNLOAD] Ignoring non-existent function "{}" from "{}"'.format(
self.session_name, name, module_path)) self.session_name, name, module_path))
if count > 0: if count > 0:
logging.warning('[{}] Successfully loaded {} plugin{} from "{}"'.format( log.warning('[{}] Successfully loaded {} plugin{} from "{}"'.format(
self.session_name, count, "s" if count > 1 else "", root)) self.session_name, count, "s" if count > 1 else "", root))
else: else:
logging.warning('[{}] No plugin loaded from "{}"'.format( log.warning('[{}] No plugin loaded from "{}"'.format(
self.session_name, root)) self.session_name, root))
# def get_initial_dialogs_chunk(self, offset_date: int = 0): # def get_initial_dialogs_chunk(self, offset_date: int = 0):
@ -1642,10 +1644,10 @@ class Client(Methods, BaseClient):
# ) # )
# ) # )
# except FloodWait as e: # except FloodWait as e:
# logging.warning("get_dialogs flood: waiting {} seconds".format(e.x)) # log.warning("get_dialogs flood: waiting {} seconds".format(e.x))
# time.sleep(e.x) # time.sleep(e.x)
# else: # else:
# logging.info("Total peers: {}".format(self.storage.peers_count)) # log.info("Total peers: {}".format(self.storage.peers_count))
# return r # return r
# #
# def get_initial_dialogs(self): # def get_initial_dialogs(self):
@ -1868,7 +1870,7 @@ class Client(Methods, BaseClient):
except Client.StopTransmission: except Client.StopTransmission:
raise raise
except Exception as e: except Exception as e:
logging.error(e, exc_info=True) log.error(e, exc_info=True)
else: else:
if is_big: if is_big:
return types.InputFileBig( return types.InputFileBig(
@ -2094,7 +2096,7 @@ class Client(Methods, BaseClient):
raise e raise e
except Exception as e: except Exception as e:
if not isinstance(e, Client.StopTransmission): if not isinstance(e, Client.StopTransmission):
logging.error(e, exc_info=True) log.error(e, exc_info=True)
try: try:
os.remove(file_name) os.remove(file_name)

View File

@ -36,6 +36,8 @@ from ..handlers import (
UserStatusHandler, RawUpdateHandler, InlineQueryHandler, PollHandler UserStatusHandler, RawUpdateHandler, InlineQueryHandler, PollHandler
) )
log = logging.getLogger(__name__)
class Dispatcher: class Dispatcher:
NEW_MESSAGE_UPDATES = ( NEW_MESSAGE_UPDATES = (
@ -156,7 +158,7 @@ class Dispatcher:
def update_worker(self, lock): def update_worker(self, lock):
name = threading.current_thread().name name = threading.current_thread().name
logging.debug("{} started".format(name)) log.debug("{} started".format(name))
while True: while True:
packet = self.updates_queue.get() packet = self.updates_queue.get()
@ -184,7 +186,7 @@ class Dispatcher:
if handler.check(parsed_update): if handler.check(parsed_update):
args = (parsed_update,) args = (parsed_update,)
except Exception as e: except Exception as e:
logging.error(e, exc_info=True) log.error(e, exc_info=True)
continue continue
elif isinstance(handler, RawUpdateHandler): elif isinstance(handler, RawUpdateHandler):
@ -200,12 +202,12 @@ class Dispatcher:
except pyrogram.ContinuePropagation: except pyrogram.ContinuePropagation:
continue continue
except Exception as e: except Exception as e:
logging.error(e, exc_info=True) log.error(e, exc_info=True)
break break
except pyrogram.StopPropagation: except pyrogram.StopPropagation:
pass pass
except Exception as e: except Exception as e:
logging.error(e, exc_info=True) log.error(e, exc_info=True)
logging.debug("{} stopped".format(name)) log.debug("{} stopped".format(name))

View File

@ -20,6 +20,8 @@ import logging
import time import time
from threading import Thread, Event, Lock from threading import Thread, Event, Lock
log = logging.getLogger(__name__)
class Syncer: class Syncer:
INTERVAL = 20 INTERVAL = 20
@ -77,9 +79,9 @@ class Syncer:
start = time.time() start = time.time()
client.storage.save() client.storage.save()
except Exception as e: except Exception as e:
logging.critical(e, exc_info=True) log.critical(e, exc_info=True)
else: else:
logging.info('Synced "{}" in {:.6} ms'.format( log.info('Synced "{}" in {:.6} ms'.format(
client.storage.name, client.storage.name,
(time.time() - start) * 1000 (time.time() - start) * 1000
)) ))

View File

@ -25,6 +25,8 @@ from pyrogram.api import functions, types
from pyrogram.errors import FloodWait from pyrogram.errors import FloodWait
from ...ext import BaseClient from ...ext import BaseClient
log = logging.getLogger(__name__)
class Filters: class Filters:
ALL = "all" ALL = "all"
@ -151,7 +153,7 @@ class GetChatMembers(BaseClient):
return pyrogram.List(pyrogram.ChatMember._parse(self, member, users) for member in members) return pyrogram.List(pyrogram.ChatMember._parse(self, member, users) for member in members)
except FloodWait as e: except FloodWait as e:
logging.warning("Sleeping for {}s".format(e.x)) log.warning("Sleeping for {}s".format(e.x))
time.sleep(e.x) time.sleep(e.x)
else: else:
raise ValueError("The chat_id \"{}\" belongs to a user".format(chat_id)) raise ValueError("The chat_id \"{}\" belongs to a user".format(chat_id))

View File

@ -25,6 +25,8 @@ from pyrogram.api import functions, types
from pyrogram.errors import FloodWait from pyrogram.errors import FloodWait
from ...ext import BaseClient, utils from ...ext import BaseClient, utils
log = logging.getLogger(__name__)
class GetDialogs(BaseClient): class GetDialogs(BaseClient):
def get_dialogs( def get_dialogs(
@ -80,7 +82,7 @@ class GetDialogs(BaseClient):
) )
) )
except FloodWait as e: except FloodWait as e:
logging.warning("Sleeping {}s".format(e.x)) log.warning("Sleeping {}s".format(e.x))
time.sleep(e.x) time.sleep(e.x)
else: else:
break break
@ -109,6 +111,6 @@ class GetDialogs(BaseClient):
if not isinstance(dialog, types.Dialog): if not isinstance(dialog, types.Dialog):
continue continue
parsed_dialogs.append(pyrogram.Dialogging._parse(self, dialog, messages, users, chats)) parsed_dialogs.append(pyrogram.Dialog._parse(self, dialog, messages, users, chats))
return pyrogram.List(parsed_dialogs) return pyrogram.List(parsed_dialogs)

View File

@ -25,6 +25,8 @@ from pyrogram.api import functions
from pyrogram.errors import FloodWait from pyrogram.errors import FloodWait
from ...ext import BaseClient from ...ext import BaseClient
log = logging.getLogger(__name__)
class GetContacts(BaseClient): class GetContacts(BaseClient):
def get_contacts(self) -> List["pyrogram.User"]: def get_contacts(self) -> List["pyrogram.User"]:
@ -43,7 +45,7 @@ class GetContacts(BaseClient):
try: try:
contacts = self.send(functions.contacts.GetContacts(hash=0)) contacts = self.send(functions.contacts.GetContacts(hash=0))
except FloodWait as e: except FloodWait as e:
logging.warning("get_contacts flood: waiting {} seconds".format(e.x)) log.warning("get_contacts flood: waiting {} seconds".format(e.x))
time.sleep(e.x) time.sleep(e.x)
else: else:
return pyrogram.List(pyrogram.User._parse(self, user) for user in contacts.users) return pyrogram.List(pyrogram.User._parse(self, user) for user in contacts.users)

View File

@ -26,6 +26,8 @@ from pyrogram.client.ext import utils
from pyrogram.errors import FloodWait from pyrogram.errors import FloodWait
from ...ext import BaseClient from ...ext import BaseClient
log = logging.getLogger(__name__)
class GetHistory(BaseClient): class GetHistory(BaseClient):
def get_history( def get_history(
@ -101,7 +103,7 @@ class GetHistory(BaseClient):
) )
) )
except FloodWait as e: except FloodWait as e:
logging.warning("Sleeping for {}s".format(e.x)) log.warning("Sleeping for {}s".format(e.x))
time.sleep(e.x) time.sleep(e.x)
else: else:
break break

View File

@ -16,11 +16,14 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import logging
from typing import Union from typing import Union
from pyrogram.api import types, functions from pyrogram.api import types, functions
from pyrogram.client.ext import BaseClient from pyrogram.client.ext import BaseClient
log = logging.getLogger(__name__)
class GetHistoryCount(BaseClient): class GetHistoryCount(BaseClient):
def get_history_count( def get_history_count(

View File

@ -25,6 +25,8 @@ from pyrogram.api import functions, types
from pyrogram.errors import FloodWait from pyrogram.errors import FloodWait
from ...ext import BaseClient, utils from ...ext import BaseClient, utils
log = logging.getLogger(__name__)
# TODO: Rewrite using a flag for replied messages and have message_ids non-optional # TODO: Rewrite using a flag for replied messages and have message_ids non-optional
@ -114,7 +116,7 @@ class GetMessages(BaseClient):
try: try:
r = self.send(rpc) r = self.send(rpc)
except FloodWait as e: except FloodWait as e:
logging.warning("Sleeping for {}s".format(e.x)) log.warning("Sleeping for {}s".format(e.x))
time.sleep(e.x) time.sleep(e.x)
else: else:
break break

View File

@ -26,6 +26,8 @@ from pyrogram.api import functions, types
from pyrogram.client.ext import BaseClient, utils from pyrogram.client.ext import BaseClient, utils
from pyrogram.errors import FloodWait from pyrogram.errors import FloodWait
log = logging.getLogger(__name__)
class SendMediaGroup(BaseClient): class SendMediaGroup(BaseClient):
# TODO: Add progress parameter # TODO: Add progress parameter
@ -87,7 +89,7 @@ class SendMediaGroup(BaseClient):
) )
) )
except FloodWait as e: except FloodWait as e:
logging.warning("Sleeping for {}s".format(e.x)) log.warning("Sleeping for {}s".format(e.x))
time.sleep(e.x) time.sleep(e.x)
else: else:
break break
@ -142,7 +144,7 @@ class SendMediaGroup(BaseClient):
) )
) )
except FloodWait as e: except FloodWait as e:
logging.warning("Sleeping for {}s".format(e.x)) log.warning("Sleeping for {}s".format(e.x))
time.sleep(e.x) time.sleep(e.x)
else: else:
break break
@ -193,7 +195,7 @@ class SendMediaGroup(BaseClient):
) )
) )
except FloodWait as e: except FloodWait as e:
logging.warning("Sleeping for {}s".format(e.x)) log.warning("Sleeping for {}s".format(e.x))
time.sleep(e.x) time.sleep(e.x)
else: else:
break break

View File

@ -28,6 +28,8 @@ from pyrogram.api import types
from pyrogram.errors import PeerIdInvalid from pyrogram.errors import PeerIdInvalid
from . import utils from . import utils
log = logging.getLogger(__name__)
class Parser(HTMLParser): class Parser(HTMLParser):
MENTION_RE = re.compile(r"tg://user\?id=(\d+)") MENTION_RE = re.compile(r"tg://user\?id=(\d+)")
@ -95,7 +97,7 @@ class Parser(HTMLParser):
line, offset = self.getpos() line, offset = self.getpos()
offset += 1 offset += 1
logging.warning("Unmatched closing tag </{}> at line {}:{}".format(tag, line, offset)) log.warning("Unmatched closing tag </{}> at line {}:{}".format(tag, line, offset))
else: else:
if not self.tag_entities[tag]: if not self.tag_entities[tag]:
self.tag_entities.pop(tag) self.tag_entities.pop(tag)
@ -121,7 +123,7 @@ class HTML:
for tag, entities in parser.tag_entities.items(): for tag, entities in parser.tag_entities.items():
unclosed_tags.append("<{}> (x{})".format(tag, len(entities))) unclosed_tags.append("<{}> (x{})".format(tag, len(entities)))
logging.warning("Unclosed tags: {}".format(", ".join(unclosed_tags))) log.warning("Unclosed tags: {}".format(", ".join(unclosed_tags)))
entities = [] entities = []

View File

@ -26,6 +26,8 @@ from threading import Lock
from .memory_storage import MemoryStorage from .memory_storage import MemoryStorage
log = logging.getLogger(__name__)
class FileStorage(MemoryStorage): class FileStorage(MemoryStorage):
FILE_EXTENSION = ".session" FILE_EXTENSION = ".session"
@ -81,20 +83,20 @@ class FileStorage(MemoryStorage):
except ValueError: except ValueError:
pass pass
else: else:
logging.warning("JSON session storage detected! Converting it into an SQLite session storage...") log.warning("JSON session storage detected! Converting it into an SQLite session storage...")
path.rename(path.name + ".OLD") path.rename(path.name + ".OLD")
logging.warning('The old session file has been renamed to "{}.OLD"'.format(path.name)) log.warning('The old session file has been renamed to "{}.OLD"'.format(path.name))
self.migrate_from_json(session_json) self.migrate_from_json(session_json)
logging.warning("Done! The session has been successfully converted from JSON to SQLite storage") log.warning("Done! The session has been successfully converted from JSON to SQLite storage")
return return
if Path(path.name + ".OLD").is_file(): if Path(path.name + ".OLD").is_file():
logging.warning('Old session file detected: "{}.OLD". You can remove this file now'.format(path.name)) log.warning('Old session file detected: "{}.OLD". You can remove this file now'.format(path.name))
self.conn = sqlite3.connect( self.conn = sqlite3.connect(
str(path), str(path),

View File

@ -18,6 +18,7 @@
import base64 import base64
import inspect import inspect
import logging
import sqlite3 import sqlite3
import struct import struct
import time import time
@ -28,6 +29,8 @@ from typing import List, Tuple
from pyrogram.api import types from pyrogram.api import types
from pyrogram.client.storage.storage import Storage from pyrogram.client.storage.storage import Storage
log = logging.getLogger(__name__)
class MemoryStorage(Storage): class MemoryStorage(Storage):
SCHEMA_VERSION = 1 SCHEMA_VERSION = 1

View File

@ -23,6 +23,8 @@ import time
from .transport import * from .transport import *
from ..session.internals import DataCenter from ..session.internals import DataCenter
log = logging.getLogger(__name__)
class Connection: class Connection:
MAX_RETRIES = 3 MAX_RETRIES = 3
@ -51,14 +53,14 @@ class Connection:
self.connection = self.mode(self.ipv6, self.proxy) self.connection = self.mode(self.ipv6, self.proxy)
try: try:
logging.info("Connecting...") log.info("Connecting...")
self.connection.connect(self.address) self.connection.connect(self.address)
except OSError as e: except OSError as e:
logging.warning(e) # TODO: Remove log.warning(e) # TODO: Remove
self.connection.close() self.connection.close()
time.sleep(1) time.sleep(1)
else: else:
logging.info("Connected! {} DC{} - IPv{} - {}".format( log.info("Connected! {} DC{} - IPv{} - {}".format(
"Test" if self.test_mode else "Production", "Test" if self.test_mode else "Production",
self.dc_id, self.dc_id,
"6" if self.ipv6 else "4", "6" if self.ipv6 else "4",
@ -66,12 +68,12 @@ class Connection:
)) ))
break break
else: else:
logging.warning("Connection failed! Trying again...") log.warning("Connection failed! Trying again...")
raise TimeoutError raise TimeoutError
def close(self): def close(self):
self.connection.close() self.connection.close()
logging.info("Disconnected") log.info("Disconnected")
def send(self, data: bytes): def send(self, data: bytes):
with self.lock: with self.lock:

View File

@ -30,6 +30,8 @@ except ImportError as e:
raise e raise e
log = logging.getLogger(__name__)
class TCP(socks.socksocket): class TCP(socks.socksocket):
def __init__(self, ipv6: bool, proxy: dict): def __init__(self, ipv6: bool, proxy: dict):
@ -55,7 +57,7 @@ class TCP(socks.socksocket):
password=proxy.get("password", None) password=proxy.get("password", None)
) )
logging.info("Using proxy {}:{}".format(hostname, port)) log.info("Using proxy {}:{}".format(hostname, port))
else: else:
super().__init__( super().__init__(
socket.AF_INET6 if ipv6 socket.AF_INET6 if ipv6

View File

@ -16,8 +16,12 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import logging
from .tcp import TCP from .tcp import TCP
log = logging.getLogger(__name__)
class TCPAbridged(TCP): class TCPAbridged(TCP):
def __init__(self, ipv6: bool, proxy: dict): def __init__(self, ipv6: bool, proxy: dict):

View File

@ -16,11 +16,14 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import logging
import os import os
from .tcp import TCP from .tcp import TCP
from ....crypto.aes import AES from ....crypto.aes import AES
log = logging.getLogger(__name__)
class TCPAbridgedO(TCP): class TCPAbridgedO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4) RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)

View File

@ -16,11 +16,14 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import logging
from binascii import crc32 from binascii import crc32
from struct import pack, unpack from struct import pack, unpack
from .tcp import TCP from .tcp import TCP
log = logging.getLogger(__name__)
class TCPFull(TCP): class TCPFull(TCP):
def __init__(self, ipv6: bool, proxy: dict): def __init__(self, ipv6: bool, proxy: dict):

View File

@ -16,10 +16,13 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import logging
from struct import pack, unpack from struct import pack, unpack
from .tcp import TCP from .tcp import TCP
log = logging.getLogger(__name__)
class TCPIntermediate(TCP): class TCPIntermediate(TCP):
def __init__(self, ipv6: bool, proxy: dict): def __init__(self, ipv6: bool, proxy: dict):

View File

@ -16,12 +16,15 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import logging
import os import os
from struct import pack, unpack from struct import pack, unpack
from .tcp import TCP from .tcp import TCP
from ....crypto.aes import AES from ....crypto.aes import AES
log = logging.getLogger(__name__)
class TCPIntermediateO(TCP): class TCPIntermediateO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4) RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)

View File

@ -18,10 +18,12 @@
import logging import logging
log = logging.getLogger(__name__)
try: try:
import tgcrypto import tgcrypto
logging.info("Using TgCrypto") log.info("Using TgCrypto")
class AES: class AES:
@ -51,7 +53,7 @@ try:
except ImportError: except ImportError:
import pyaes import pyaes
logging.warning( log.warning(
"TgCrypto is missing! " "TgCrypto is missing! "
"Pyrogram will work the same, but at a much slower speed. " "Pyrogram will work the same, but at a much slower speed. "
"More info: https://docs.pyrogram.org/topics/tgcrypto" "More info: https://docs.pyrogram.org/topics/tgcrypto"

View File

@ -29,6 +29,8 @@ from pyrogram.connection import Connection
from pyrogram.crypto import AES, RSA, Prime from pyrogram.crypto import AES, RSA, Prime
from .internals import MsgId from .internals import MsgId
log = logging.getLogger(__name__)
class Auth: class Auth:
MAX_RETRIES = 5 MAX_RETRIES = 5
@ -75,34 +77,34 @@ class Auth:
self.connection = Connection(self.dc_id, self.test_mode, self.ipv6, self.proxy) self.connection = Connection(self.dc_id, self.test_mode, self.ipv6, self.proxy)
try: try:
logging.info("Start creating a new auth key on DC{}".format(self.dc_id)) log.info("Start creating a new auth key on DC{}".format(self.dc_id))
self.connection.connect() self.connection.connect()
# Step 1; Step 2 # Step 1; Step 2
nonce = int.from_bytes(urandom(16), "little", signed=True) nonce = int.from_bytes(urandom(16), "little", signed=True)
logging.debug("Send req_pq: {}".format(nonce)) log.debug("Send req_pq: {}".format(nonce))
res_pq = self.send(functions.ReqPqMulti(nonce=nonce)) res_pq = self.send(functions.ReqPqMulti(nonce=nonce))
logging.debug("Got ResPq: {}".format(res_pq.server_nonce)) log.debug("Got ResPq: {}".format(res_pq.server_nonce))
logging.debug("Server public key fingerprints: {}".format(res_pq.server_public_key_fingerprints)) log.debug("Server public key fingerprints: {}".format(res_pq.server_public_key_fingerprints))
for i in res_pq.server_public_key_fingerprints: for i in res_pq.server_public_key_fingerprints:
if i in RSA.server_public_keys: if i in RSA.server_public_keys:
logging.debug("Using fingerprint: {}".format(i)) log.debug("Using fingerprint: {}".format(i))
public_key_fingerprint = i public_key_fingerprint = i
break break
else: else:
logging.debug("Fingerprint unknown: {}".format(i)) log.debug("Fingerprint unknown: {}".format(i))
else: else:
raise Exception("Public key not found") raise Exception("Public key not found")
# Step 3 # Step 3
pq = int.from_bytes(res_pq.pq, "big") pq = int.from_bytes(res_pq.pq, "big")
logging.debug("Start PQ factorization: {}".format(pq)) log.debug("Start PQ factorization: {}".format(pq))
start = time.time() start = time.time()
g = Prime.decompose(pq) g = Prime.decompose(pq)
p, q = sorted((g, pq // g)) # p < q p, q = sorted((g, pq // g)) # p < q
logging.debug("Done PQ factorization ({}s): {} {}".format(round(time.time() - start, 3), p, q)) log.debug("Done PQ factorization ({}s): {} {}".format(round(time.time() - start, 3), p, q))
# Step 4 # Step 4
server_nonce = res_pq.server_nonce server_nonce = res_pq.server_nonce
@ -122,10 +124,10 @@ class Auth:
data_with_hash = sha + data + padding data_with_hash = sha + data + padding
encrypted_data = RSA.encrypt(data_with_hash, public_key_fingerprint) encrypted_data = RSA.encrypt(data_with_hash, public_key_fingerprint)
logging.debug("Done encrypt data with RSA") log.debug("Done encrypt data with RSA")
# Step 5. TODO: Handle "server_DH_params_fail". Code assumes response is ok # Step 5. TODO: Handle "server_DH_params_fail". Code assumes response is ok
logging.debug("Send req_DH_params") log.debug("Send req_DH_params")
server_dh_params = self.send( server_dh_params = self.send(
functions.ReqDHParams( functions.ReqDHParams(
nonce=nonce, nonce=nonce,
@ -159,12 +161,12 @@ class Auth:
server_dh_inner_data = TLObject.read(BytesIO(answer)) server_dh_inner_data = TLObject.read(BytesIO(answer))
logging.debug("Done decrypting answer") log.debug("Done decrypting answer")
dh_prime = int.from_bytes(server_dh_inner_data.dh_prime, "big") dh_prime = int.from_bytes(server_dh_inner_data.dh_prime, "big")
delta_time = server_dh_inner_data.server_time - time.time() delta_time = server_dh_inner_data.server_time - time.time()
logging.debug("Delta time: {}".format(round(delta_time, 3))) log.debug("Delta time: {}".format(round(delta_time, 3)))
# Step 6 # Step 6
g = server_dh_inner_data.g g = server_dh_inner_data.g
@ -185,7 +187,7 @@ class Auth:
data_with_hash = sha + data + padding data_with_hash = sha + data + padding
encrypted_data = AES.ige256_encrypt(data_with_hash, tmp_aes_key, tmp_aes_iv) encrypted_data = AES.ige256_encrypt(data_with_hash, tmp_aes_key, tmp_aes_iv)
logging.debug("Send set_client_DH_params") log.debug("Send set_client_DH_params")
set_client_dh_params_answer = self.send( set_client_dh_params_answer = self.send(
functions.SetClientDHParams( functions.SetClientDHParams(
nonce=nonce, nonce=nonce,
@ -208,7 +210,7 @@ class Auth:
####################### #######################
assert dh_prime == Prime.CURRENT_DH_PRIME assert dh_prime == Prime.CURRENT_DH_PRIME
logging.debug("DH parameters check: OK") log.debug("DH parameters check: OK")
# https://core.telegram.org/mtproto/security_guidelines#g-a-and-g-b-validation # https://core.telegram.org/mtproto/security_guidelines#g-a-and-g-b-validation
g_b = int.from_bytes(g_b, "big") g_b = int.from_bytes(g_b, "big")
@ -217,12 +219,12 @@ class Auth:
assert 1 < g_b < dh_prime - 1 assert 1 < g_b < dh_prime - 1
assert 2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64) assert 2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64)
assert 2 ** (2048 - 64) < g_b < dh_prime - 2 ** (2048 - 64) assert 2 ** (2048 - 64) < g_b < dh_prime - 2 ** (2048 - 64)
logging.debug("g_a and g_b validation: OK") log.debug("g_a and g_b validation: OK")
# https://core.telegram.org/mtproto/security_guidelines#checking-sha1-hash-values # https://core.telegram.org/mtproto/security_guidelines#checking-sha1-hash-values
answer = server_dh_inner_data.write() # Call .write() to remove padding answer = server_dh_inner_data.write() # Call .write() to remove padding
assert answer_with_hash[:20] == sha1(answer).digest() assert answer_with_hash[:20] == sha1(answer).digest()
logging.debug("SHA1 hash values check: OK") log.debug("SHA1 hash values check: OK")
# https://core.telegram.org/mtproto/security_guidelines#checking-nonce-server-nonce-and-new-nonce-fields # https://core.telegram.org/mtproto/security_guidelines#checking-nonce-server-nonce-and-new-nonce-fields
# 1st message # 1st message
@ -235,14 +237,14 @@ class Auth:
assert nonce == set_client_dh_params_answer.nonce assert nonce == set_client_dh_params_answer.nonce
assert server_nonce == set_client_dh_params_answer.server_nonce assert server_nonce == set_client_dh_params_answer.server_nonce
server_nonce = server_nonce.to_bytes(16, "little", signed=True) server_nonce = server_nonce.to_bytes(16, "little", signed=True)
logging.debug("Nonce fields check: OK") log.debug("Nonce fields check: OK")
# Step 9 # Step 9
server_salt = AES.xor(new_nonce[:8], server_nonce[:8]) server_salt = AES.xor(new_nonce[:8], server_nonce[:8])
logging.debug("Server salt: {}".format(int.from_bytes(server_salt, "little"))) log.debug("Server salt: {}".format(int.from_bytes(server_salt, "little")))
logging.info( log.info(
"Done auth key exchange: {}".format( "Done auth key exchange: {}".format(
set_client_dh_params_answer.__class__.__name__ set_client_dh_params_answer.__class__.__name__
) )

View File

@ -36,6 +36,8 @@ from pyrogram.crypto import AES, KDF
from pyrogram.errors import RPCError, InternalServerError, AuthKeyDuplicated from pyrogram.errors import RPCError, InternalServerError, AuthKeyDuplicated
from .internals import MsgId, MsgFactory from .internals import MsgId, MsgFactory
log = logging.getLogger(__name__)
class Result: class Result:
def __init__(self): def __init__(self):
@ -169,9 +171,9 @@ class Session:
self.ping_thread = Thread(target=self.ping, name="PingThread") self.ping_thread = Thread(target=self.ping, name="PingThread")
self.ping_thread.start() self.ping_thread.start()
logging.info("Session initialized: Layer {}".format(layer)) log.info("Session initialized: Layer {}".format(layer))
logging.info("Device: {} - {}".format(self.client.device_model, self.client.app_version)) log.info("Device: {} - {}".format(self.client.device_model, self.client.app_version))
logging.info("System: {} ({})".format(self.client.system_version, self.client.lang_code.upper())) log.info("System: {} ({})".format(self.client.system_version, self.client.lang_code.upper()))
except AuthKeyDuplicated as e: except AuthKeyDuplicated as e:
self.stop() self.stop()
@ -186,7 +188,7 @@ class Session:
self.is_connected.set() self.is_connected.set()
logging.debug("Session started") log.debug("Session started")
def stop(self): def stop(self):
self.is_connected.clear() self.is_connected.clear()
@ -221,9 +223,9 @@ class Session:
try: try:
self.client.disconnect_handler(self.client) self.client.disconnect_handler(self.client)
except Exception as e: except Exception as e:
logging.error(e, exc_info=True) log.error(e, exc_info=True)
logging.debug("Session stopped") log.debug("Session stopped")
def restart(self): def restart(self):
self.stop() self.stop()
@ -266,7 +268,7 @@ class Session:
def net_worker(self): def net_worker(self):
name = threading.current_thread().name name = threading.current_thread().name
logging.debug("{} started".format(name)) log.debug("{} started".format(name))
while True: while True:
packet = self.recv_queue.get() packet = self.recv_queue.get()
@ -283,7 +285,7 @@ class Session:
else [data] else [data]
) )
logging.debug(data) log.debug(data)
for msg in messages: for msg in messages:
if msg.seq_no % 2 != 0: if msg.seq_no % 2 != 0:
@ -316,7 +318,7 @@ class Session:
self.results[msg_id].event.set() self.results[msg_id].event.set()
if len(self.pending_acks) >= self.ACKS_THRESHOLD: if len(self.pending_acks) >= self.ACKS_THRESHOLD:
logging.info("Send {} acks".format(len(self.pending_acks))) log.info("Send {} acks".format(len(self.pending_acks)))
try: try:
self._send(types.MsgsAck(msg_ids=list(self.pending_acks)), False) self._send(types.MsgsAck(msg_ids=list(self.pending_acks)), False)
@ -325,12 +327,12 @@ class Session:
else: else:
self.pending_acks.clear() self.pending_acks.clear()
except Exception as e: except Exception as e:
logging.error(e, exc_info=True) log.error(e, exc_info=True)
logging.debug("{} stopped".format(name)) log.debug("{} stopped".format(name))
def ping(self): def ping(self):
logging.debug("PingThread started") log.debug("PingThread started")
while True: while True:
self.ping_thread_event.wait(self.PING_INTERVAL) self.ping_thread_event.wait(self.PING_INTERVAL)
@ -345,10 +347,10 @@ class Session:
except (OSError, TimeoutError, RPCError): except (OSError, TimeoutError, RPCError):
pass pass
logging.debug("PingThread stopped") log.debug("PingThread stopped")
def next_salt(self): def next_salt(self):
logging.debug("NextSaltThread started") log.debug("NextSaltThread started")
while True: while True:
now = datetime.now() now = datetime.now()
@ -358,7 +360,7 @@ class Session:
valid_until = datetime.fromtimestamp(self.current_salt.valid_until) valid_until = datetime.fromtimestamp(self.current_salt.valid_until)
dt = (valid_until - now).total_seconds() - 900 dt = (valid_until - now).total_seconds() - 900
logging.debug("Current salt: {} | Next salt in {:.0f}m {:.0f}s ({})".format( log.debug("Current salt: {} | Next salt in {:.0f}m {:.0f}s ({})".format(
self.current_salt.salt, self.current_salt.salt,
dt // 60, dt // 60,
dt % 60, dt % 60,
@ -376,17 +378,17 @@ class Session:
self.connection.close() self.connection.close()
break break
logging.debug("NextSaltThread stopped") log.debug("NextSaltThread stopped")
def recv(self): def recv(self):
logging.debug("RecvThread started") log.debug("RecvThread started")
while True: while True:
packet = self.connection.recv() packet = self.connection.recv()
if packet is None or len(packet) == 4: if packet is None or len(packet) == 4:
if packet: if packet:
logging.warning("Server sent \"{}\"".format(Int.read(BytesIO(packet)))) log.warning("Server sent \"{}\"".format(Int.read(BytesIO(packet))))
if self.is_connected.is_set(): if self.is_connected.is_set():
Thread(target=self.restart, name="RestartThread").start() Thread(target=self.restart, name="RestartThread").start()
@ -394,7 +396,7 @@ class Session:
self.recv_queue.put(packet) self.recv_queue.put(packet)
logging.debug("RecvThread stopped") log.debug("RecvThread stopped")
def _send(self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT): def _send(self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT):
message = self.msg_factory(data) message = self.msg_factory(data)