2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-29 05:18:10 +00:00

Merge develop -> asyncio

This commit is contained in:
Dan 2019-09-08 13:26:10 +02:00
commit 8f0b8babc2
28 changed files with 214 additions and 528 deletions

View File

@ -264,6 +264,7 @@ def pyrogram_api():
send_recovery_code send_recovery_code
recover_password recover_password
accept_terms_of_service accept_terms_of_service
log_out
""", """,
advanced=""" advanced="""
Advanced Advanced

View File

@ -38,12 +38,8 @@ from pyrogram.client.handlers.handler import Handler
from pyrogram.client.methods.password.utils import compute_check from pyrogram.client.methods.password.utils import compute_check
from pyrogram.crypto import AES from pyrogram.crypto import AES
from pyrogram.errors import ( from pyrogram.errors import (
PhoneMigrate, NetworkMigrate, PhoneNumberInvalid, PhoneMigrate, NetworkMigrate, SessionPasswordNeeded,
PhoneNumberUnoccupied, PhoneCodeInvalid, PhoneCodeHashEmpty, FloodWait, PeerIdInvalid, VolumeLocNotFound, UserMigrate, ChannelPrivate, AuthBytesInvalid,
PhoneCodeExpired, PhoneCodeEmpty, SessionPasswordNeeded,
PasswordHashInvalid, FloodWait, PeerIdInvalid, FirstnameInvalid, PhoneNumberBanned,
VolumeLocNotFound, UserMigrate, ChannelPrivate, PhoneNumberOccupied,
PasswordRecoveryNa, PasswordEmpty, AuthBytesInvalid,
BadRequest) BadRequest)
from pyrogram.session import Auth, Session from pyrogram.session import Auth, Session
from .ext import utils, Syncer, BaseClient, Dispatcher from .ext import utils, Syncer, BaseClient, Dispatcher
@ -52,8 +48,6 @@ from .methods import Methods
from .storage import Storage, FileStorage, MemoryStorage from .storage import Storage, FileStorage, MemoryStorage
from .types import User, SentCode, TermsOfService from .types import User, SentCode, TermsOfService
log = logging.getLogger(__name__)
class Client(Methods, BaseClient): class Client(Methods, BaseClient):
"""Pyrogram Client, the main means for interacting with Telegram. """Pyrogram Client, the main means for interacting with Telegram.
@ -68,24 +62,24 @@ class Client(Methods, BaseClient):
:meth:`~pyrogram.Client.export_session_string` before stopping the client to get a session string you can :meth:`~pyrogram.Client.export_session_string` before stopping the client to get a session string you can
pass here as argument. pass here as argument.
api_id (``int``, *optional*): api_id (``int`` | ``str``, *optional*):
The *api_id* part of your Telegram API Key, as integer. E.g.: 12345 The *api_id* part of your Telegram API Key, as integer. E.g.: "12345".
This is an alternative way to pass it if you don't want to use the *config.ini* file. This is an alternative way to pass it if you don't want to use the *config.ini* file.
api_hash (``str``, *optional*): api_hash (``str``, *optional*):
The *api_hash* part of your Telegram API Key, as string. E.g.: "0123456789abcdef0123456789abcdef". The *api_hash* part of your Telegram API Key, as string. E.g.: "0123456789abcdef0123456789abcdef".
This is an alternative way to pass it if you don't want to use the *config.ini* file. This is an alternative way to set it if you don't want to use the *config.ini* file.
app_version (``str``, *optional*): app_version (``str``, *optional*):
Application version. Defaults to "Pyrogram X.Y.Z" Application version. Defaults to "Pyrogram |version|".
This is an alternative way to set it if you don't want to use the *config.ini* file. This is an alternative way to set it if you don't want to use the *config.ini* file.
device_model (``str``, *optional*): device_model (``str``, *optional*):
Device model. Defaults to *platform.python_implementation() + " " + platform.python_version()* Device model. Defaults to *platform.python_implementation() + " " + platform.python_version()*.
This is an alternative way to set it if you don't want to use the *config.ini* file. This is an alternative way to set it if you don't want to use the *config.ini* file.
system_version (``str``, *optional*): system_version (``str``, *optional*):
Operating System version. Defaults to *platform.system() + " " + platform.release()* Operating System version. Defaults to *platform.system() + " " + platform.release()*.
This is an alternative way to set it if you don't want to use the *config.ini* file. This is an alternative way to set it if you don't want to use the *config.ini* file.
lang_code (``str``, *optional*): lang_code (``str``, *optional*):
@ -99,69 +93,52 @@ class Client(Methods, BaseClient):
proxy (``dict``, *optional*): proxy (``dict``, *optional*):
Your SOCKS5 Proxy settings as dict, Your SOCKS5 Proxy settings as dict,
e.g.: *dict(hostname="11.22.33.44", port=1080, username="user", password="pass")*. e.g.: *dict(hostname="11.22.33.44", port=1080, username="user", password="pass")*.
*username* and *password* can be omitted if your proxy doesn't require authorization. The *username* and *password* can be omitted if your proxy doesn't require authorization.
This is an alternative way to setup a proxy if you don't want to use the *config.ini* file. This is an alternative way to setup a proxy if you don't want to use the *config.ini* file.
test_mode (``bool``, *optional*): test_mode (``bool``, *optional*):
Enable or disable login to the test servers. Defaults to False. Enable or disable login to the test servers.
Only applicable for new sessions and will be ignored in case previously Only applicable for new sessions and will be ignored in case previously created sessions are loaded.
created sessions are loaded. Defaults to False.
bot_token (``str``, *optional*): bot_token (``str``, *optional*):
Pass your Bot API token to create a bot session, e.g.: "123456:ABC-DEF1234ghIkl-zyx57W2v1u123ew11" Pass your Bot API token to create a bot session, e.g.: "123456:ABC-DEF1234ghIkl-zyx57W2v1u123ew11"
Only applicable for new sessions. Only applicable for new sessions.
This is an alternative way to set it if you don't want to use the *config.ini* file. This is an alternative way to set it if you don't want to use the *config.ini* file.
phone_number (``str`` | ``callable``, *optional*): phone_number (``str``, *optional*):
Pass your phone number as string (with your Country Code prefix included) to avoid entering it manually. Pass your phone number as string (with your Country Code prefix included) to avoid entering it manually.
Or pass a callback function which accepts no arguments and must return the correct phone number as string
(e.g., "391234567890").
Only applicable for new sessions. Only applicable for new sessions.
phone_code (``str`` | ``callable``, *optional*): phone_code (``str``, *optional*):
Pass the phone code as string (for test numbers only) to avoid entering it manually. Or pass a callback Pass the phone code as string (for test numbers only) to avoid entering it manually.
function which accepts a single positional argument *(phone_number)* and must return the correct phone code
as string (e.g., "12345").
Only applicable for new sessions. Only applicable for new sessions.
password (``str``, *optional*): password (``str``, *optional*):
Pass your Two-Step Verification password as string (if you have one) to avoid entering it manually. Pass your Two-Step Verification password as string (if you have one) to avoid entering it manually.
Or pass a callback function which accepts a single positional argument *(password_hint)* and must return
the correct password as string (e.g., "password").
Only applicable for new sessions. Only applicable for new sessions.
recovery_code (``callable``, *optional*): force_sms (``bool``, *optional*):
Pass a callback function which accepts a single positional argument *(email_pattern)* and must return the
correct password recovery code as string (e.g., "987654").
Only applicable for new sessions.
force_sms (``str``, *optional*):
Pass True to force Telegram sending the authorization code via SMS. Pass True to force Telegram sending the authorization code via SMS.
Only applicable for new sessions. Only applicable for new sessions.
Defaults to False.
first_name (``str``, *optional*):
Pass a First Name as string to avoid entering it manually. Or pass a callback function which accepts no
arguments and must return the correct name as string (e.g., "Dan"). It will be used to automatically create
a new Telegram account in case the phone number you passed is not registered yet.
Only applicable for new sessions.
last_name (``str``, *optional*):
Same purpose as *first_name*; pass a Last Name to avoid entering it manually. It can
be an empty string: "". Only applicable for new sessions.
workers (``int``, *optional*): workers (``int``, *optional*):
Number of maximum concurrent workers for handling incoming updates. Defaults to 4. Number of maximum concurrent workers for handling incoming updates.
Defaults to 4.
workdir (``str``, *optional*): workdir (``str``, *optional*):
Define a custom working directory. The working directory is the location in your filesystem Define a custom working directory. The working directory is the location in your filesystem where Pyrogram
where Pyrogram will store your session files. Defaults to the parent directory of the main script. will store your session files.
Defaults to the parent directory of the main script.
config_file (``str``, *optional*): config_file (``str``, *optional*):
Path of the configuration file. Defaults to ./config.ini Path of the configuration file.
Defaults to ./config.ini
plugins (``dict``, *optional*): plugins (``dict``, *optional*):
Your Smart Plugins settings as dict, e.g.: *dict(root="plugins")*. Your Smart Plugins settings as dict, e.g.: *dict(root="plugins")*.
This is an alternative way to setup plugins if you don't want to use the *config.ini* file. This is an alternative way setup plugins if you don't want to use the *config.ini* file.
no_updates (``bool``, *optional*): no_updates (``bool``, *optional*):
Pass True to completely disable incoming updates for the current session. Pass True to completely disable incoming updates for the current session.
@ -175,17 +152,6 @@ class Client(Methods, BaseClient):
download_media, ...) are less prone to throw FloodWait exceptions. download_media, ...) are less prone to throw FloodWait exceptions.
Only available for users, bots will ignore this parameter. Only available for users, bots will ignore this parameter.
Defaults to False (normal session). Defaults to False (normal session).
Example:
.. code-block:: python
from pyrogram import Client
app = Client("my_account")
with app:
app.send_message("me", "Hi!")
""" """
def __init__( def __init__(
@ -202,12 +168,9 @@ class Client(Methods, BaseClient):
test_mode: bool = False, test_mode: bool = False,
bot_token: str = None, bot_token: str = None,
phone_number: str = None, phone_number: str = None,
phone_code: Union[str, callable] = None, phone_code: str = None,
password: str = None, password: str = None,
recovery_code: callable = None,
force_sms: bool = False, force_sms: bool = False,
first_name: str = None,
last_name: str = None,
workers: int = BaseClient.WORKERS, workers: int = BaseClient.WORKERS,
workdir: str = BaseClient.WORKDIR, workdir: str = BaseClient.WORKDIR,
config_file: str = BaseClient.CONFIG_FILE, config_file: str = BaseClient.CONFIG_FILE,
@ -232,10 +195,7 @@ class Client(Methods, BaseClient):
self.phone_number = phone_number self.phone_number = phone_number
self.phone_code = phone_code self.phone_code = phone_code
self.password = password self.password = password
self.recovery_code = recovery_code
self.force_sms = force_sms self.force_sms = force_sms
self.first_name = first_name
self.last_name = last_name
self.workers = workers self.workers = workers
self.workdir = Path(workdir) self.workdir = Path(workdir)
self.config_file = Path(config_file) self.config_file = Path(config_file)
@ -260,7 +220,10 @@ class Client(Methods, BaseClient):
return self.start() return self.start()
def __exit__(self, *args): def __exit__(self, *args):
self.stop() try:
self.stop()
except ConnectionError:
pass
async def __aenter__(self): async def __aenter__(self):
return await self.start() return await self.start()
@ -349,20 +312,18 @@ class Client(Methods, BaseClient):
asyncio.ensure_future(self.updates_worker()) asyncio.ensure_future(self.updates_worker())
) )
log.info("Started {} UpdatesWorkerTasks".format(Client.UPDATES_WORKERS)) logging.info("Started {} UpdatesWorkerTasks".format(Client.UPDATES_WORKERS))
for _ in range(Client.DOWNLOAD_WORKERS): for _ in range(Client.DOWNLOAD_WORKERS):
self.download_worker_tasks.append( self.download_worker_tasks.append(
asyncio.ensure_future(self.download_worker()) asyncio.ensure_future(self.download_worker())
) )
log.info("Started {} DownloadWorkerTasks".format(Client.DOWNLOAD_WORKERS)) logging.info("Started {} DownloadWorkerTasks".format(Client.DOWNLOAD_WORKERS))
await self.dispatcher.start() await self.dispatcher.start()
await Syncer.add(self) await Syncer.add(self)
Syncer.add(self)
self.is_initialized = True self.is_initialized = True
async def terminate(self): async def terminate(self):
@ -379,7 +340,7 @@ class Client(Methods, BaseClient):
if self.takeout_id: if self.takeout_id:
await self.send(functions.account.FinishTakeoutSession()) await self.send(functions.account.FinishTakeoutSession())
log.warning("Takeout session {} finished".format(self.takeout_id)) logging.warning("Takeout session {} finished".format(self.takeout_id))
await Syncer.remove(self) await Syncer.remove(self)
await self.dispatcher.stop() await self.dispatcher.stop()
@ -392,7 +353,7 @@ class Client(Methods, BaseClient):
self.download_worker_tasks.clear() self.download_worker_tasks.clear()
log.info("Stopped {} DownloadWorkerTasks".format(Client.DOWNLOAD_WORKERS)) logging.info("Stopped {} DownloadWorkerTasks".format(Client.DOWNLOAD_WORKERS))
for _ in range(Client.UPDATES_WORKERS): for _ in range(Client.UPDATES_WORKERS):
self.updates_queue.put_nowait(None) self.updates_queue.put_nowait(None)
@ -402,7 +363,7 @@ class Client(Methods, BaseClient):
self.updates_worker_tasks.clear() self.updates_worker_tasks.clear()
log.info("Stopped {} UpdatesWorkerTasks".format(Client.UPDATES_WORKERS)) logging.info("Stopped {} UpdatesWorkerTasks".format(Client.UPDATES_WORKERS))
for media_session in self.media_sessions.values(): for media_session in self.media_sessions.values():
await media_session.stop() await media_session.stop()
@ -689,36 +650,37 @@ class Client(Methods, BaseClient):
return True return True
async def authorize(self) -> User: async def authorize(self) -> User:
if self.bot_token is not None: if self.bot_token:
return await self.sign_in_bot(self.bot_token) return await self.sign_in_bot(self.bot_token)
while True: while True:
if self.phone_number is None:
while True:
value = await ainput("Enter phone number or bot token: ")
confirm = await ainput("Is \"{}\" correct? (y/n): ".format(value))
if confirm in ("y", "1"):
break
elif confirm in ("n", "2"):
continue
if ":" in value:
self.bot_token = value
return await self.sign_in_bot(value)
else:
self.phone_number = value
try: try:
if not self.phone_number:
while True:
value = await ainput("Enter phone number or bot token: ")
if not value:
continue
confirm = input("Is \"{}\" correct? (y/N): ".format(value)).lower()
if confirm == "y":
break
if ":" in value:
self.bot_token = value
return await self.sign_in_bot(value)
else:
self.phone_number = value
sent_code = await self.send_code(self.phone_number) sent_code = await self.send_code(self.phone_number)
except BadRequest as e: except BadRequest as e:
print(e.MESSAGE) print(e.MESSAGE)
self.phone_number = None self.phone_number = None
self.bot_token = None
except FloodWait as e: except FloodWait as e:
print(e.MESSAGE.format(x=e.x)) print(e.MESSAGE.format(x=e.x))
time.sleep(e.x) time.sleep(e.x)
except Exception as e:
log.error(e, exc_info=True)
else: else:
break break
@ -735,7 +697,7 @@ class Client(Methods, BaseClient):
)) ))
while True: while True:
if self.phone_code is None: if not self.phone_code:
self.phone_code = await ainput("Enter confirmation code: ") self.phone_code = await ainput("Enter confirmation code: ")
try: try:
@ -749,14 +711,14 @@ class Client(Methods, BaseClient):
while True: while True:
print("Password hint: {}".format(await self.get_password_hint())) print("Password hint: {}".format(await self.get_password_hint()))
if self.password is None: if not self.password:
self.password = await ainput("Enter password (empty to recover): ") self.password = await ainput("Enter password (empty to recover): ")
try: try:
if self.password == "": if not self.password:
confirm = await ainput("Confirm password recovery (y/n): ") confirm = await ainput("Confirm password recovery (y/n): ")
if confirm in ("y", "1"): if confirm == "y":
email_pattern = await self.send_recovery_code() email_pattern = await self.send_recovery_code()
print("The recovery code has been sent to {}".format(email_pattern)) print("The recovery code has been sent to {}".format(email_pattern))
@ -771,10 +733,9 @@ class Client(Methods, BaseClient):
print(e.MESSAGE.format(x=e.x)) print(e.MESSAGE.format(x=e.x))
time.sleep(e.x) time.sleep(e.x)
except Exception as e: except Exception as e:
log.error(e, exc_info=True) logging.error(e, exc_info=True)
raise raise
else:
elif confirm in ("n", "2"):
self.password = None self.password = None
else: else:
return await self.check_password(self.password) return await self.check_password(self.password)
@ -784,14 +745,9 @@ class Client(Methods, BaseClient):
except FloodWait as e: except FloodWait as e:
print(e.MESSAGE.format(x=e.x)) print(e.MESSAGE.format(x=e.x))
time.sleep(e.x) time.sleep(e.x)
except Exception as e:
log.error(e, exc_info=True)
raise
except FloodWait as e: except FloodWait as e:
print(e.MESSAGE.format(x=e.x)) print(e.MESSAGE.format(x=e.x))
time.sleep(e.x) time.sleep(e.x)
except Exception as e:
log.error(e, exc_info=True)
else: else:
break break
@ -799,20 +755,18 @@ class Client(Methods, BaseClient):
return signed_in return signed_in
while True: while True:
self.first_name = await ainput("Enter first name: ") first_name = await ainput("Enter first name: ")
self.last_name = await ainput("Enter last name (empty to skip): ") last_name = await ainput("Enter last name (empty to skip): ")
try: try:
signed_up = await self.sign_up( signed_up = await self.sign_up(
self.phone_number, self.phone_number,
sent_code.phone_code_hash, sent_code.phone_code_hash,
self.first_name, first_name,
self.last_name last_name
) )
except BadRequest as e: except BadRequest as e:
print(e.MESSAGE) print(e.MESSAGE)
self.first_name = None
self.last_name = None
except FloodWait as e: except FloodWait as e:
print(e.MESSAGE.format(x=e.x)) print(e.MESSAGE.format(x=e.x))
time.sleep(e.x) time.sleep(e.x)
@ -825,7 +779,28 @@ class Client(Methods, BaseClient):
return signed_up return signed_up
def start(self): async def log_out(self):
"""Log out from Telegram and delete the *\\*.session* file.
When you log out, the current client is stopped and the storage session destroyed.
No more API calls can be made until you start the client and re-authorize again.
Returns:
``bool``: On success, True is returned.
Example:
.. code-block:: python
# Log out.
app.log_out()
"""
await self.send(functions.auth.LogOut())
await self.stop()
self.storage.destroy()
return True
async def start(self):
"""Start the client. """Start the client.
This method connects the client to Telegram and, in case of new sessions, automatically manages the full This method connects the client to Telegram and, in case of new sessions, automatically manages the full
@ -850,25 +825,25 @@ class Client(Methods, BaseClient):
app.stop() app.stop()
""" """
is_authorized = self.connect() is_authorized = await self.connect()
try: try:
if not is_authorized: if not is_authorized:
self.authorize() await self.authorize()
if not self.storage.is_bot and self.takeout: if not self.storage.is_bot and self.takeout:
self.takeout_id = self.send(functions.account.InitTakeoutSession()).id self.takeout_id = (await self.send(functions.account.InitTakeoutSession())).id
log.warning("Takeout session {} initiated".format(self.takeout_id)) logging.warning("Takeout session {} initiated".format(self.takeout_id))
self.send(functions.updates.GetState()) await self.send(functions.updates.GetState())
except Exception as e: except (Exception, KeyboardInterrupt):
self.disconnect() await self.disconnect()
raise e raise
else: else:
self.initialize() await self.initialize()
return self return self
def stop(self): async def stop(self):
"""Stop the Client. """Stop the Client.
This method disconnects the client from Telegram and stops the underlying tasks. This method disconnects the client from Telegram and stops the underlying tasks.
@ -892,8 +867,8 @@ class Client(Methods, BaseClient):
app.stop() app.stop()
""" """
self.terminate() await self.terminate()
self.disconnect() await self.disconnect()
return self return self
@ -976,8 +951,8 @@ class Client(Methods, BaseClient):
app3.stop() app3.stop()
""" """
def signal_handler(*args): def signal_handler(_, __):
log.info("Stop signal received ({}). Exiting...".format(args[0])) logging.info("Stop signal received ({}). Exiting...".format(_))
Client.is_idling = False Client.is_idling = False
for s in stop_signals: for s in stop_signals:
@ -1199,254 +1174,6 @@ class Client(Methods, BaseClient):
self.parse_mode = parse_mode self.parse_mode = parse_mode
async def authorize_bot(self):
try:
r = await self.send(
functions.auth.ImportBotAuthorization(
flags=0,
api_id=self.api_id,
api_hash=self.api_hash,
bot_auth_token=self.bot_token
)
)
except UserMigrate as e:
await self.session.stop()
self.storage.dc_id = e.x
self.storage.auth_key = await Auth(self, self.storage.dc_id).create()
self.session = Session(self, self.storage.dc_id, self.storage.auth_key)
await self.session.start()
await self.authorize_bot()
else:
self.storage.user_id = r.user.id
print("Logged in successfully as @{}".format(r.user.username))
async def authorize_user(self):
phone_number_invalid_raises = self.phone_number is not None
phone_code_invalid_raises = self.phone_code is not None
password_invalid_raises = self.password is not None
first_name_invalid_raises = self.first_name is not None
async def default_phone_number_callback():
while True:
phone_number = await ainput("Enter phone number: ")
confirm = await ainput("Is \"{}\" correct? (y/n): ".format(phone_number))
if confirm in ("y", "1"):
return phone_number
elif confirm in ("n", "2"):
continue
while True:
self.phone_number = (
await default_phone_number_callback() if self.phone_number is None
else str(await self.phone_number()) if callable(self.phone_number)
else str(self.phone_number)
)
self.phone_number = self.phone_number.strip("+")
try:
r = await self.send(
functions.auth.SendCode(
phone_number=self.phone_number,
api_id=self.api_id,
api_hash=self.api_hash,
settings=types.CodeSettings()
)
)
except (PhoneMigrate, NetworkMigrate) as e:
await self.session.stop()
self.storage.dc_id = e.x
self.storage.auth_key = await Auth(self, self.storage.dc_id).create()
self.session = Session(self, self.storage.dc_id, self.storage.auth_key)
await self.session.start()
except (PhoneNumberInvalid, PhoneNumberBanned) as e:
if phone_number_invalid_raises:
raise
else:
print(e.MESSAGE)
self.phone_number = None
except FloodWait as e:
if phone_number_invalid_raises:
raise
else:
print(e.MESSAGE.format(x=e.x))
await asyncio.sleep(e.x)
except Exception as e:
log.error(e, exc_info=True)
raise
else:
break
phone_registered = r.phone_registered
phone_code_hash = r.phone_code_hash
terms_of_service = r.terms_of_service
if terms_of_service and not Client.terms_of_service_displayed:
print("\n" + terms_of_service.text + "\n")
Client.terms_of_service_displayed = True
if self.force_sms:
await self.send(
functions.auth.ResendCode(
phone_number=self.phone_number,
phone_code_hash=phone_code_hash
)
)
while True:
if not phone_registered:
self.first_name = (
await ainput("First name: ") if self.first_name is None
else str(await self.first_name()) if callable(self.first_name)
else str(self.first_name)
)
self.last_name = (
await ainput("Last name: ") if self.last_name is None
else str(await self.last_name()) if callable(self.last_name)
else str(self.last_name)
)
self.phone_code = (
await ainput("Enter phone code: ") if self.phone_code is None
else str(await self.phone_code(self.phone_number)) if callable(self.phone_code)
else str(self.phone_code)
)
try:
if phone_registered:
try:
r = await self.send(
functions.auth.SignIn(
phone_number=self.phone_number,
phone_code_hash=phone_code_hash,
phone_code=self.phone_code
)
)
except PhoneNumberUnoccupied:
log.warning("Phone number unregistered")
phone_registered = False
continue
else:
try:
r = await self.send(
functions.auth.SignUp(
phone_number=self.phone_number,
phone_code_hash=phone_code_hash,
phone_code=self.phone_code,
first_name=self.first_name,
last_name=self.last_name
)
)
except PhoneNumberOccupied:
log.warning("Phone number already registered")
phone_registered = True
continue
except (PhoneCodeInvalid, PhoneCodeEmpty, PhoneCodeExpired, PhoneCodeHashEmpty) as e:
if phone_code_invalid_raises:
raise
else:
print(e.MESSAGE)
self.phone_code = None
except FirstnameInvalid as e:
if first_name_invalid_raises:
raise
else:
print(e.MESSAGE)
self.first_name = None
except SessionPasswordNeeded as e:
print(e.MESSAGE)
async def default_password_callback(password_hint: str) -> str:
print("Hint: {}".format(password_hint))
return await ainput("Enter password (empty to recover): ")
async def default_recovery_callback(email_pattern: str) -> str:
print("An e-mail containing the recovery code has been sent to {}".format(email_pattern))
return await ainput("Enter password recovery code: ")
while True:
try:
r = await self.send(functions.account.GetPassword())
self.password = (
await default_password_callback(r.hint) if self.password is None
else str((await self.password(r.hint)) or "") if callable(self.password)
else str(self.password)
)
if self.password == "":
r = await self.send(functions.auth.RequestPasswordRecovery())
self.recovery_code = (
await default_recovery_callback(r.email_pattern) if self.recovery_code is None
else str(await self.recovery_code(r.email_pattern)) if callable(self.recovery_code)
else str(self.recovery_code)
)
r = await self.send(
functions.auth.RecoverPassword(
code=self.recovery_code
)
)
else:
r = await self.send(
functions.auth.CheckPassword(
password=compute_check(r, self.password)
)
)
except (PasswordEmpty, PasswordRecoveryNa, PasswordHashInvalid) as e:
if password_invalid_raises:
raise
else:
print(e.MESSAGE)
self.password = None
self.recovery_code = None
except FloodWait as e:
if password_invalid_raises:
raise
else:
print(e.MESSAGE.format(x=e.x))
await asyncio.sleep(e.x)
self.password = None
self.recovery_code = None
except Exception as e:
log.error(e, exc_info=True)
raise
else:
break
break
except FloodWait as e:
if phone_code_invalid_raises or first_name_invalid_raises:
raise
else:
print(e.MESSAGE.format(x=e.x))
await asyncio.sleep(e.x)
except Exception as e:
log.error(e, exc_info=True)
raise
else:
break
if terms_of_service:
assert await self.send(
functions.help.AcceptTermsOfService(
id=terms_of_service.id
)
)
self.password = None
self.storage.user_id = r.user.id
print("Logged in successfully as {}".format(r.user.first_name))
def fetch_peers( def fetch_peers(
self, self,
peers: List[ peers: List[
@ -1546,7 +1273,7 @@ class Client(Methods, BaseClient):
os.makedirs(directory, exist_ok=True) os.makedirs(directory, exist_ok=True)
shutil.move(temp_file_path, final_file_path) shutil.move(temp_file_path, final_file_path)
except Exception as e: except Exception as e:
log.error(e, exc_info=True) logging.error(e, exc_info=True)
try: try:
os.remove(temp_file_path) os.remove(temp_file_path)
@ -1587,7 +1314,7 @@ class Client(Methods, BaseClient):
pts_count = getattr(update, "pts_count", None) pts_count = getattr(update, "pts_count", None)
if isinstance(update, types.UpdateChannelTooLong): if isinstance(update, types.UpdateChannelTooLong):
log.warning(update) logging.warning(update)
if isinstance(update, types.UpdateNewChannelMessage) and is_min: if isinstance(update, types.UpdateNewChannelMessage) and is_min:
message = update.message message = update.message
@ -1639,14 +1366,11 @@ class Client(Methods, BaseClient):
elif isinstance(updates, types.UpdateShort): elif isinstance(updates, types.UpdateShort):
self.dispatcher.updates_queue.put_nowait((updates.update, {}, {})) self.dispatcher.updates_queue.put_nowait((updates.update, {}, {}))
elif isinstance(updates, types.UpdatesTooLong): elif isinstance(updates, types.UpdatesTooLong):
log.warning(updates) logging.info(updates)
except Exception as e: except Exception as e:
log.error(e, exc_info=True) logging.error(e, exc_info=True)
async def send(self, async def send(self, data: TLObject, retries: int = Session.MAX_RETRIES, timeout: float = Session.WAIT_TIMEOUT):
data: TLObject,
retries: int = Session.MAX_RETRIES,
timeout: float = Session.WAIT_TIMEOUT):
"""Send raw Telegram queries. """Send raw Telegram queries.
This method makes it possible to manually call every single Telegram API method in a low-level manner. This method makes it possible to manually call every single Telegram API method in a low-level manner.
@ -1819,7 +1543,7 @@ class Client(Methods, BaseClient):
if isinstance(handler, Handler) and isinstance(group, int): if isinstance(handler, Handler) and isinstance(group, int):
self.add_handler(handler, group) self.add_handler(handler, group)
log.info('[{}] [LOAD] {}("{}") in group {} from "{}"'.format( logging.info('[{}] [LOAD] {}("{}") in group {} from "{}"'.format(
self.session_name, type(handler).__name__, name, group, module_path)) self.session_name, type(handler).__name__, name, group, module_path))
count += 1 count += 1
@ -1833,12 +1557,12 @@ class Client(Methods, BaseClient):
try: try:
module = import_module(module_path) module = import_module(module_path)
except ImportError: except ImportError:
log.warning('[{}] [LOAD] Ignoring non-existent module "{}"'.format( logging.warning('[{}] [LOAD] Ignoring non-existent module "{}"'.format(
self.session_name, module_path)) self.session_name, module_path))
continue continue
if "__path__" in dir(module): if "__path__" in dir(module):
log.warning('[{}] [LOAD] Ignoring namespace "{}"'.format( logging.warning('[{}] [LOAD] Ignoring namespace "{}"'.format(
self.session_name, module_path)) self.session_name, module_path))
continue continue
@ -1854,13 +1578,13 @@ class Client(Methods, BaseClient):
if isinstance(handler, Handler) and isinstance(group, int): if isinstance(handler, Handler) and isinstance(group, int):
self.add_handler(handler, group) self.add_handler(handler, group)
log.info('[{}] [LOAD] {}("{}") in group {} from "{}"'.format( logging.info('[{}] [LOAD] {}("{}") in group {} from "{}"'.format(
self.session_name, type(handler).__name__, name, group, module_path)) self.session_name, type(handler).__name__, name, group, module_path))
count += 1 count += 1
except Exception: except Exception:
if warn_non_existent_functions: if warn_non_existent_functions:
log.warning('[{}] [LOAD] Ignoring non-existent function "{}" from "{}"'.format( logging.warning('[{}] [LOAD] Ignoring non-existent function "{}" from "{}"'.format(
self.session_name, name, module_path)) self.session_name, name, module_path))
if exclude: if exclude:
@ -1871,12 +1595,12 @@ class Client(Methods, BaseClient):
try: try:
module = import_module(module_path) module = import_module(module_path)
except ImportError: except ImportError:
log.warning('[{}] [UNLOAD] Ignoring non-existent module "{}"'.format( logging.warning('[{}] [UNLOAD] Ignoring non-existent module "{}"'.format(
self.session_name, module_path)) self.session_name, module_path))
continue continue
if "__path__" in dir(module): if "__path__" in dir(module):
log.warning('[{}] [UNLOAD] Ignoring namespace "{}"'.format( logging.warning('[{}] [UNLOAD] Ignoring namespace "{}"'.format(
self.session_name, module_path)) self.session_name, module_path))
continue continue
@ -1892,20 +1616,20 @@ class Client(Methods, BaseClient):
if isinstance(handler, Handler) and isinstance(group, int): if isinstance(handler, Handler) and isinstance(group, int):
self.remove_handler(handler, group) self.remove_handler(handler, group)
log.info('[{}] [UNLOAD] {}("{}") from group {} in "{}"'.format( logging.info('[{}] [UNLOAD] {}("{}") from group {} in "{}"'.format(
self.session_name, type(handler).__name__, name, group, module_path)) self.session_name, type(handler).__name__, name, group, module_path))
count -= 1 count -= 1
except Exception: except Exception:
if warn_non_existent_functions: if warn_non_existent_functions:
log.warning('[{}] [UNLOAD] Ignoring non-existent function "{}" from "{}"'.format( logging.warning('[{}] [UNLOAD] Ignoring non-existent function "{}" from "{}"'.format(
self.session_name, name, module_path)) self.session_name, name, module_path))
if count > 0: if count > 0:
log.warning('[{}] Successfully loaded {} plugin{} from "{}"'.format( logging.warning('[{}] Successfully loaded {} plugin{} from "{}"'.format(
self.session_name, count, "s" if count > 1 else "", root)) self.session_name, count, "s" if count > 1 else "", root))
else: else:
log.warning('[{}] No plugin loaded from "{}"'.format( logging.warning('[{}] No plugin loaded from "{}"'.format(
self.session_name, root)) self.session_name, root))
# def get_initial_dialogs_chunk(self, offset_date: int = 0): # def get_initial_dialogs_chunk(self, offset_date: int = 0):
@ -1922,10 +1646,10 @@ class Client(Methods, BaseClient):
# ) # )
# ) # )
# except FloodWait as e: # except FloodWait as e:
# log.warning("get_dialogs flood: waiting {} seconds".format(e.x)) # logging.warning("get_dialogs flood: waiting {} seconds".format(e.x))
# time.sleep(e.x) # time.sleep(e.x)
# else: # else:
# log.info("Total peers: {}".format(self.storage.peers_count)) # logging.info("Total peers: {}".format(self.storage.peers_count))
# return r # return r
# #
# def get_initial_dialogs(self): # def get_initial_dialogs(self):
@ -1940,8 +1664,7 @@ class Client(Methods, BaseClient):
# #
# self.get_initial_dialogs_chunk() # self.get_initial_dialogs_chunk()
async def resolve_peer(self, async def resolve_peer(self, peer_id: Union[int, str]):
peer_id: Union[int, str]):
"""Get the InputPeer of a known peer id. """Get the InputPeer of a known peer id.
Useful whenever an InputPeer type is required. Useful whenever an InputPeer type is required.
@ -1980,9 +1703,11 @@ class Client(Methods, BaseClient):
try: try:
return self.storage.get_peer_by_username(peer_id) return self.storage.get_peer_by_username(peer_id)
except KeyError: except KeyError:
await self.send(functions.contacts.ResolveUsername(username=peer_id await self.send(
) functions.contacts.ResolveUsername(
) username=peer_id
)
)
return self.storage.get_peer_by_username(peer_id) return self.storage.get_peer_by_username(peer_id)
else: else:
@ -2094,7 +1819,7 @@ class Client(Methods, BaseClient):
try: try:
await asyncio.ensure_future(session.send(data)) await asyncio.ensure_future(session.send(data))
except Exception as e: except Exception as e:
log.error(e) logging.error(e)
part_size = 512 * 1024 part_size = 512 * 1024
file_size = os.path.getsize(path) file_size = os.path.getsize(path)
@ -2160,7 +1885,7 @@ class Client(Methods, BaseClient):
except Client.StopTransmission: except Client.StopTransmission:
raise raise
except Exception as e: except Exception as e:
log.error(e, exc_info=True) logging.error(e, exc_info=True)
else: else:
if is_big: if is_big:
return types.InputFileBig( return types.InputFileBig(
@ -2392,7 +2117,7 @@ class Client(Methods, BaseClient):
raise e raise e
except Exception as e: except Exception as e:
if not isinstance(e, Client.StopTransmission): if not isinstance(e, Client.StopTransmission):
log.error(e, exc_info=True) logging.error(e, exc_info=True)
try: try:
os.remove(file_name) os.remove(file_name)

View File

@ -50,7 +50,7 @@ class BaseClient:
INVITE_LINK_RE = re.compile(r"^(?:https?://)?(?:www\.)?(?:t(?:elegram)?\.(?:org|me|dog)/joinchat/)([\w-]+)$") INVITE_LINK_RE = re.compile(r"^(?:https?://)?(?:www\.)?(?:t(?:elegram)?\.(?:org|me|dog)/joinchat/)([\w-]+)$")
DIALOGS_AT_ONCE = 100 DIALOGS_AT_ONCE = 100
UPDATES_WORKERS = 1 UPDATES_WORKERS = 4
DOWNLOAD_WORKERS = 4 DOWNLOAD_WORKERS = 4
OFFLINE_SLEEP = 900 OFFLINE_SLEEP = 900
WORKERS = 4 WORKERS = 4

View File

@ -34,8 +34,6 @@ from ..handlers import (
UserStatusHandler, RawUpdateHandler, InlineQueryHandler, PollHandler UserStatusHandler, RawUpdateHandler, InlineQueryHandler, PollHandler
) )
log = logging.getLogger(__name__)
class Dispatcher: class Dispatcher:
NEW_MESSAGE_UPDATES = ( NEW_MESSAGE_UPDATES = (
@ -111,7 +109,7 @@ class Dispatcher:
asyncio.ensure_future(self.update_worker(self.locks_list[-1])) asyncio.ensure_future(self.update_worker(self.locks_list[-1]))
) )
log.info("Started {} UpdateWorkerTasks".format(self.workers)) logging.info("Started {} UpdateWorkerTasks".format(self.workers))
async def stop(self): async def stop(self):
for i in range(self.workers): for i in range(self.workers):
@ -123,7 +121,7 @@ class Dispatcher:
self.update_worker_tasks.clear() self.update_worker_tasks.clear()
self.groups.clear() self.groups.clear()
log.info("Stopped {} UpdateWorkerTasks".format(self.workers)) logging.info("Stopped {} UpdateWorkerTasks".format(self.workers))
def add_handler(self, handler, group: int): def add_handler(self, handler, group: int):
async def fn(): async def fn():
@ -185,7 +183,7 @@ class Dispatcher:
if handler.check(parsed_update): if handler.check(parsed_update):
args = (parsed_update,) args = (parsed_update,)
except Exception as e: except Exception as e:
log.error(e, exc_info=True) logging.error(e, exc_info=True)
continue continue
elif isinstance(handler, RawUpdateHandler): elif isinstance(handler, RawUpdateHandler):
@ -201,10 +199,10 @@ class Dispatcher:
except pyrogram.ContinuePropagation: except pyrogram.ContinuePropagation:
continue continue
except Exception as e: except Exception as e:
log.error(e, exc_info=True) logging.error(e, exc_info=True)
break break
except pyrogram.StopPropagation: except pyrogram.StopPropagation:
pass pass
except Exception as e: except Exception as e:
log.error(e, exc_info=True) logging.error(e, exc_info=True)

View File

@ -20,8 +20,6 @@ import asyncio
import logging import logging
import time import time
log = logging.getLogger(__name__)
class Syncer: class Syncer:
INTERVAL = 20 INTERVAL = 20
@ -83,9 +81,9 @@ class Syncer:
start = time.time() start = time.time()
client.storage.save() client.storage.save()
except Exception as e: except Exception as e:
log.critical(e, exc_info=True) logging.critical(e, exc_info=True)
else: else:
log.info('Synced "{}" in {:.6} ms'.format( logging.info('Synced "{}" in {:.6} ms'.format(
client.storage.name, client.storage.name,
(time.time() - start) * 1000 (time.time() - start) * 1000
)) ))

View File

@ -26,8 +26,6 @@ from pyrogram.errors import FloodWait
from ...ext import BaseClient from ...ext import BaseClient
log = logging.getLogger(__name__)
class Filters: class Filters:
ALL = "all" ALL = "all"
@ -154,7 +152,7 @@ class GetChatMembers(BaseClient):
return pyrogram.List(pyrogram.ChatMember._parse(self, member, users) for member in members) return pyrogram.List(pyrogram.ChatMember._parse(self, member, users) for member in members)
except FloodWait as e: except FloodWait as e:
log.warning("Sleeping for {}s".format(e.x)) logging.warning("Sleeping for {}s".format(e.x))
await asyncio.sleep(e.x) await asyncio.sleep(e.x)
else: else:
raise ValueError("The chat_id \"{}\" belongs to a user".format(chat_id)) raise ValueError("The chat_id \"{}\" belongs to a user".format(chat_id))

View File

@ -26,8 +26,6 @@ from pyrogram.errors import FloodWait
from ...ext import BaseClient, utils from ...ext import BaseClient, utils
log = logging.getLogger(__name__)
class GetDialogs(BaseClient): class GetDialogs(BaseClient):
async def get_dialogs( async def get_dialogs(
@ -83,7 +81,7 @@ class GetDialogs(BaseClient):
) )
) )
except FloodWait as e: except FloodWait as e:
log.warning("Sleeping for {}s".format(e.x)) logging.warning("Sleeping for {}s".format(e.x))
await asyncio.sleep(e.x) await asyncio.sleep(e.x)
else: else:
break break
@ -112,6 +110,6 @@ class GetDialogs(BaseClient):
if not isinstance(dialog, types.Dialog): if not isinstance(dialog, types.Dialog):
continue continue
parsed_dialogs.append(pyrogram.Dialog._parse(self, dialog, messages, users, chats)) parsed_dialogs.append(pyrogram.Dialogging._parse(self, dialog, messages, users, chats))
return pyrogram.List(parsed_dialogs) return pyrogram.List(parsed_dialogs)

View File

@ -25,8 +25,6 @@ from pyrogram.api import functions
from pyrogram.errors import FloodWait from pyrogram.errors import FloodWait
from ...ext import BaseClient from ...ext import BaseClient
log = logging.getLogger(__name__)
class GetContacts(BaseClient): class GetContacts(BaseClient):
async def get_contacts(self) -> List["pyrogram.User"]: async def get_contacts(self) -> List["pyrogram.User"]:
@ -45,7 +43,7 @@ class GetContacts(BaseClient):
try: try:
contacts = await self.send(functions.contacts.GetContacts(hash=0)) contacts = await self.send(functions.contacts.GetContacts(hash=0))
except FloodWait as e: except FloodWait as e:
log.warning("get_contacts flood: waiting {} seconds".format(e.x)) logging.warning("get_contacts flood: waiting {} seconds".format(e.x))
await asyncio.sleep(e.x) await asyncio.sleep(e.x)
else: else:
return pyrogram.List(pyrogram.User._parse(self, user) for user in contacts.users) return pyrogram.List(pyrogram.User._parse(self, user) for user in contacts.users)

View File

@ -27,8 +27,6 @@ from pyrogram.errors import FloodWait
from ...ext import BaseClient from ...ext import BaseClient
log = logging.getLogger(__name__)
class GetHistory(BaseClient): class GetHistory(BaseClient):
async def get_history( async def get_history(
@ -104,7 +102,7 @@ class GetHistory(BaseClient):
) )
) )
except FloodWait as e: except FloodWait as e:
log.warning("Sleeping for {}s".format(e.x)) logging.warning("Sleeping for {}s".format(e.x))
await asyncio.sleep(e.x) await asyncio.sleep(e.x)
else: else:
break break

View File

@ -16,14 +16,11 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import logging
from typing import Union from typing import Union
from pyrogram.api import types, functions from pyrogram.api import types, functions
from pyrogram.client.ext import BaseClient from pyrogram.client.ext import BaseClient
log = logging.getLogger(__name__)
class GetHistoryCount(BaseClient): class GetHistoryCount(BaseClient):
async def get_history_count( async def get_history_count(

View File

@ -26,8 +26,6 @@ from pyrogram.errors import FloodWait
from ...ext import BaseClient, utils from ...ext import BaseClient, utils
log = logging.getLogger(__name__)
# TODO: Rewrite using a flag for replied messages and have message_ids non-optional # TODO: Rewrite using a flag for replied messages and have message_ids non-optional
@ -117,7 +115,7 @@ class GetMessages(BaseClient):
try: try:
r = await self.send(rpc) r = await self.send(rpc)
except FloodWait as e: except FloodWait as e:
log.warning("Sleeping for {}s".format(e.x)) logging.warning("Sleeping for {}s".format(e.x))
await asyncio.sleep(e.x) await asyncio.sleep(e.x)
else: else:
break break

View File

@ -26,8 +26,6 @@ from pyrogram.api import functions, types
from pyrogram.client.ext import BaseClient, utils from pyrogram.client.ext import BaseClient, utils
from pyrogram.errors import FloodWait from pyrogram.errors import FloodWait
log = logging.getLogger(__name__)
class SendMediaGroup(BaseClient): class SendMediaGroup(BaseClient):
# TODO: Add progress parameter # TODO: Add progress parameter
@ -89,7 +87,7 @@ class SendMediaGroup(BaseClient):
) )
) )
except FloodWait as e: except FloodWait as e:
log.warning("Sleeping for {}s".format(e.x)) logging.warning("Sleeping for {}s".format(e.x))
await asyncio.sleep(e.x) await asyncio.sleep(e.x)
else: else:
break break
@ -144,7 +142,7 @@ class SendMediaGroup(BaseClient):
) )
) )
except FloodWait as e: except FloodWait as e:
log.warning("Sleeping for {}s".format(e.x)) logging.warning("Sleeping for {}s".format(e.x))
await asyncio.sleep(e.x) await asyncio.sleep(e.x)
else: else:
break break
@ -195,7 +193,7 @@ class SendMediaGroup(BaseClient):
) )
) )
except FloodWait as e: except FloodWait as e:
log.warning("Sleeping for {}s".format(e.x)) logging.warning("Sleeping for {}s".format(e.x))
await asyncio.sleep(e.x) await asyncio.sleep(e.x)
else: else:
break break

View File

@ -28,8 +28,6 @@ from pyrogram.api import types
from pyrogram.errors import PeerIdInvalid from pyrogram.errors import PeerIdInvalid
from . import utils from . import utils
log = logging.getLogger(__name__)
class Parser(HTMLParser): class Parser(HTMLParser):
MENTION_RE = re.compile(r"tg://user\?id=(\d+)") MENTION_RE = re.compile(r"tg://user\?id=(\d+)")
@ -97,7 +95,7 @@ class Parser(HTMLParser):
line, offset = self.getpos() line, offset = self.getpos()
offset += 1 offset += 1
log.warning("Unmatched closing tag </{}> at line {}:{}".format(tag, line, offset)) logging.warning("Unmatched closing tag </{}> at line {}:{}".format(tag, line, offset))
else: else:
if not self.tag_entities[tag]: if not self.tag_entities[tag]:
self.tag_entities.pop(tag) self.tag_entities.pop(tag)
@ -123,7 +121,7 @@ class HTML:
for tag, entities in parser.tag_entities.items(): for tag, entities in parser.tag_entities.items():
unclosed_tags.append("<{}> (x{})".format(tag, len(entities))) unclosed_tags.append("<{}> (x{})".format(tag, len(entities)))
log.warning("Unclosed tags: {}".format(", ".join(unclosed_tags))) logging.warning("Unclosed tags: {}".format(", ".join(unclosed_tags)))
entities = [] entities = []

View File

@ -19,14 +19,13 @@
import base64 import base64
import json import json
import logging import logging
import os
import sqlite3 import sqlite3
from pathlib import Path from pathlib import Path
from threading import Lock from threading import Lock
from .memory_storage import MemoryStorage from .memory_storage import MemoryStorage
log = logging.getLogger(__name__)
class FileStorage(MemoryStorage): class FileStorage(MemoryStorage):
FILE_EXTENSION = ".session" FILE_EXTENSION = ".session"
@ -82,20 +81,20 @@ class FileStorage(MemoryStorage):
except ValueError: except ValueError:
pass pass
else: else:
log.warning("JSON session storage detected! Converting it into an SQLite session storage...") logging.warning("JSON session storage detected! Converting it into an SQLite session storage...")
path.rename(path.name + ".OLD") path.rename(path.name + ".OLD")
log.warning('The old session file has been renamed to "{}.OLD"'.format(path.name)) logging.warning('The old session file has been renamed to "{}.OLD"'.format(path.name))
self.migrate_from_json(session_json) self.migrate_from_json(session_json)
log.warning("Done! The session has been successfully converted from JSON to SQLite storage") logging.warning("Done! The session has been successfully converted from JSON to SQLite storage")
return return
if Path(path.name + ".OLD").is_file(): if Path(path.name + ".OLD").is_file():
log.warning('Old session file detected: "{}.OLD". You can remove this file now'.format(path.name)) logging.warning('Old session file detected: "{}.OLD". You can remove this file now'.format(path.name))
self.conn = sqlite3.connect( self.conn = sqlite3.connect(
str(path), str(path),
@ -108,3 +107,6 @@ class FileStorage(MemoryStorage):
with self.conn: with self.conn:
self.conn.execute("VACUUM") self.conn.execute("VACUUM")
def destroy(self):
os.remove(self.database)

View File

@ -18,7 +18,6 @@
import base64 import base64
import inspect import inspect
import logging
import sqlite3 import sqlite3
import struct import struct
import time import time
@ -29,8 +28,6 @@ from typing import List, Tuple
from pyrogram.api import types from pyrogram.api import types
from pyrogram.client.storage.storage import Storage from pyrogram.client.storage.storage import Storage
log = logging.getLogger(__name__)
class MemoryStorage(Storage): class MemoryStorage(Storage):
SCHEMA_VERSION = 1 SCHEMA_VERSION = 1
@ -97,6 +94,9 @@ class MemoryStorage(Storage):
with self.lock: with self.lock:
self.conn.close() self.conn.close()
def destroy(self):
pass
def update_peers(self, peers: List[Tuple[int, int, str, str, str]]): def update_peers(self, peers: List[Tuple[int, int, str, str, str]]):
with self.lock: with self.lock:
self.conn.executemany( self.conn.executemany(

View File

@ -30,6 +30,9 @@ class Storage:
def close(self): def close(self):
raise NotImplementedError raise NotImplementedError
def destroy(self):
raise NotImplementedError
def update_peers(self, peers): def update_peers(self, peers):
raise NotImplementedError raise NotImplementedError

View File

@ -22,8 +22,9 @@ from .chat_permissions import ChatPermissions
from .chat_photo import ChatPhoto from .chat_photo import ChatPhoto
from .chat_preview import ChatPreview from .chat_preview import ChatPreview
from .dialog import Dialog from .dialog import Dialog
from .restriction import Restriction
from .user import User from .user import User
__all__ = [ __all__ = [
"Chat", "ChatMember", "ChatPermissions", "ChatPhoto", "ChatPreview", "Dialog", "User" "Chat", "ChatMember", "ChatPermissions", "ChatPhoto", "ChatPreview", "Dialog", "User", "Restriction"
] ]

View File

@ -20,7 +20,6 @@ from struct import pack
import pyrogram import pyrogram
from pyrogram.api import types from pyrogram.api import types
from pyrogram.errors import PeerIdInvalid
from ..object import Object from ..object import Object
from ...ext.utils import encode from ...ext.utils import encode
@ -60,8 +59,10 @@ class ChatPhoto(Object):
loc_big = chat_photo.photo_big loc_big = chat_photo.photo_big
try: try:
peer = client.resolve_peer(peer_id) # We just want a local storage lookup by id, whose method is not async.
except PeerIdInvalid: # Otherwise we have to turn this _parse method async and also all the other methods that use this one.
peer = client.storage.get_peer_by_id(peer_id)
except KeyError:
return None return None
if isinstance(peer, types.InputPeerUser): if isinstance(peer, types.InputPeerUser):

View File

@ -22,8 +22,6 @@ import logging
from .transport import * from .transport import *
from ..session.internals import DataCenter from ..session.internals import DataCenter
log = logging.getLogger(__name__)
class Connection: class Connection:
MAX_RETRIES = 3 MAX_RETRIES = 3
@ -51,14 +49,14 @@ class Connection:
self.protocol = self.mode(self.ipv6, self.proxy) self.protocol = self.mode(self.ipv6, self.proxy)
try: try:
log.info("Connecting...") logging.info("Connecting...")
await self.protocol.connect(self.address) await self.protocol.connect(self.address)
except OSError as e: except OSError as e:
log.warning(e) # TODO: Remove logging.warning(e) # TODO: Remove
self.protocol.close() self.protocol.close()
await asyncio.sleep(1) await asyncio.sleep(1)
else: else:
log.info("Connected! {} DC{} - IPv{} - {}".format( logging.info("Connected! {} DC{} - IPv{} - {}".format(
"Test" if self.test_mode else "Production", "Test" if self.test_mode else "Production",
self.dc_id, self.dc_id,
"6" if self.ipv6 else "4", "6" if self.ipv6 else "4",
@ -66,12 +64,12 @@ class Connection:
)) ))
break break
else: else:
log.warning("Connection failed! Trying again...") logging.warning("Connection failed! Trying again...")
raise TimeoutError raise TimeoutError
def close(self): def close(self):
self.protocol.close() self.protocol.close()
log.info("Disconnected") logging.info("Disconnected")
async def send(self, data: bytes): async def send(self, data: bytes):
try: try:

View File

@ -31,8 +31,6 @@ except ImportError as e:
raise e raise e
log = logging.getLogger(__name__)
class TCP: class TCP:
TIMEOUT = 10 TIMEOUT = 10
@ -67,7 +65,7 @@ class TCP:
password=proxy.get("password", None) password=proxy.get("password", None)
) )
log.info("Using proxy {}:{}".format(hostname, port)) logging.info("Using proxy {}:{}".format(hostname, port))
else: else:
self.socket = socks.socksocket( self.socket = socks.socksocket(
socket.AF_INET6 if ipv6 socket.AF_INET6 if ipv6

View File

@ -16,12 +16,8 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import logging
from .tcp import TCP from .tcp import TCP
log = logging.getLogger(__name__)
class TCPAbridged(TCP): class TCPAbridged(TCP):
def __init__(self, ipv6: bool, proxy: dict): def __init__(self, ipv6: bool, proxy: dict):

View File

@ -16,14 +16,11 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import logging
import os import os
from .tcp import TCP from .tcp import TCP
from ....crypto.aes import AES from ....crypto.aes import AES
log = logging.getLogger(__name__)
class TCPAbridgedO(TCP): class TCPAbridgedO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4) RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)

View File

@ -16,14 +16,11 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import logging
from binascii import crc32 from binascii import crc32
from struct import pack, unpack from struct import pack, unpack
from .tcp import TCP from .tcp import TCP
log = logging.getLogger(__name__)
class TCPFull(TCP): class TCPFull(TCP):
def __init__(self, ipv6: bool, proxy: dict): def __init__(self, ipv6: bool, proxy: dict):

View File

@ -16,13 +16,10 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import logging
from struct import pack, unpack from struct import pack, unpack
from .tcp import TCP from .tcp import TCP
log = logging.getLogger(__name__)
class TCPIntermediate(TCP): class TCPIntermediate(TCP):
def __init__(self, ipv6: bool, proxy: dict): def __init__(self, ipv6: bool, proxy: dict):

View File

@ -16,15 +16,12 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import logging
import os import os
from struct import pack, unpack from struct import pack, unpack
from .tcp import TCP from .tcp import TCP
from ....crypto.aes import AES from ....crypto.aes import AES
log = logging.getLogger(__name__)
class TCPIntermediateO(TCP): class TCPIntermediateO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4) RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)

View File

@ -18,12 +18,10 @@
import logging import logging
log = logging.getLogger(__name__)
try: try:
import tgcrypto import tgcrypto
log.info("Using TgCrypto") logging.info("Using TgCrypto")
class AES: class AES:
@ -53,7 +51,7 @@ try:
except ImportError: except ImportError:
import pyaes import pyaes
log.warning( logging.warning(
"TgCrypto is missing! " "TgCrypto is missing! "
"Pyrogram will work the same, but at a much slower speed. " "Pyrogram will work the same, but at a much slower speed. "
"More info: https://docs.pyrogram.org/topics/tgcrypto" "More info: https://docs.pyrogram.org/topics/tgcrypto"

View File

@ -30,8 +30,6 @@ from pyrogram.connection import Connection
from pyrogram.crypto import AES, RSA, Prime from pyrogram.crypto import AES, RSA, Prime
from .internals import MsgId from .internals import MsgId
log = logging.getLogger(__name__)
class Auth: class Auth:
MAX_RETRIES = 5 MAX_RETRIES = 5
@ -78,34 +76,34 @@ class Auth:
self.connection = Connection(self.dc_id, self.test_mode, self.ipv6, self.proxy) self.connection = Connection(self.dc_id, self.test_mode, self.ipv6, self.proxy)
try: try:
log.info("Start creating a new auth key on DC{}".format(self.dc_id)) logging.info("Start creating a new auth key on DC{}".format(self.dc_id))
await self.connection.connect() await self.connection.connect()
# Step 1; Step 2 # Step 1; Step 2
nonce = int.from_bytes(urandom(16), "little", signed=True) nonce = int.from_bytes(urandom(16), "little", signed=True)
log.debug("Send req_pq: {}".format(nonce)) logging.debug("Send req_pq: {}".format(nonce))
res_pq = await self.send(functions.ReqPqMulti(nonce=nonce)) res_pq = await self.send(functions.ReqPqMulti(nonce=nonce))
log.debug("Got ResPq: {}".format(res_pq.server_nonce)) logging.debug("Got ResPq: {}".format(res_pq.server_nonce))
log.debug("Server public key fingerprints: {}".format(res_pq.server_public_key_fingerprints)) logging.debug("Server public key fingerprints: {}".format(res_pq.server_public_key_fingerprints))
for i in res_pq.server_public_key_fingerprints: for i in res_pq.server_public_key_fingerprints:
if i in RSA.server_public_keys: if i in RSA.server_public_keys:
log.debug("Using fingerprint: {}".format(i)) logging.debug("Using fingerprint: {}".format(i))
public_key_fingerprint = i public_key_fingerprint = i
break break
else: else:
log.debug("Fingerprint unknown: {}".format(i)) logging.debug("Fingerprint unknown: {}".format(i))
else: else:
raise Exception("Public key not found") raise Exception("Public key not found")
# Step 3 # Step 3
pq = int.from_bytes(res_pq.pq, "big") pq = int.from_bytes(res_pq.pq, "big")
log.debug("Start PQ factorization: {}".format(pq)) logging.debug("Start PQ factorization: {}".format(pq))
start = time.time() start = time.time()
g = Prime.decompose(pq) g = Prime.decompose(pq)
p, q = sorted((g, pq // g)) # p < q p, q = sorted((g, pq // g)) # p < q
log.debug("Done PQ factorization ({}s): {} {}".format(round(time.time() - start, 3), p, q)) logging.debug("Done PQ factorization ({}s): {} {}".format(round(time.time() - start, 3), p, q))
# Step 4 # Step 4
server_nonce = res_pq.server_nonce server_nonce = res_pq.server_nonce
@ -125,10 +123,10 @@ class Auth:
data_with_hash = sha + data + padding data_with_hash = sha + data + padding
encrypted_data = RSA.encrypt(data_with_hash, public_key_fingerprint) encrypted_data = RSA.encrypt(data_with_hash, public_key_fingerprint)
log.debug("Done encrypt data with RSA") logging.debug("Done encrypt data with RSA")
# Step 5. TODO: Handle "server_DH_params_fail". Code assumes response is ok # Step 5. TODO: Handle "server_DH_params_fail". Code assumes response is ok
log.debug("Send req_DH_params") logging.debug("Send req_DH_params")
server_dh_params = await self.send( server_dh_params = await self.send(
functions.ReqDHParams( functions.ReqDHParams(
nonce=nonce, nonce=nonce,
@ -162,12 +160,12 @@ class Auth:
server_dh_inner_data = TLObject.read(BytesIO(answer)) server_dh_inner_data = TLObject.read(BytesIO(answer))
log.debug("Done decrypting answer") logging.debug("Done decrypting answer")
dh_prime = int.from_bytes(server_dh_inner_data.dh_prime, "big") dh_prime = int.from_bytes(server_dh_inner_data.dh_prime, "big")
delta_time = server_dh_inner_data.server_time - time.time() delta_time = server_dh_inner_data.server_time - time.time()
log.debug("Delta time: {}".format(round(delta_time, 3))) logging.debug("Delta time: {}".format(round(delta_time, 3)))
# Step 6 # Step 6
g = server_dh_inner_data.g g = server_dh_inner_data.g
@ -188,7 +186,7 @@ class Auth:
data_with_hash = sha + data + padding data_with_hash = sha + data + padding
encrypted_data = AES.ige256_encrypt(data_with_hash, tmp_aes_key, tmp_aes_iv) encrypted_data = AES.ige256_encrypt(data_with_hash, tmp_aes_key, tmp_aes_iv)
log.debug("Send set_client_DH_params") logging.debug("Send set_client_DH_params")
set_client_dh_params_answer = await self.send( set_client_dh_params_answer = await self.send(
functions.SetClientDHParams( functions.SetClientDHParams(
nonce=nonce, nonce=nonce,
@ -211,7 +209,7 @@ class Auth:
####################### #######################
assert dh_prime == Prime.CURRENT_DH_PRIME assert dh_prime == Prime.CURRENT_DH_PRIME
log.debug("DH parameters check: OK") logging.debug("DH parameters check: OK")
# https://core.telegram.org/mtproto/security_guidelines#g-a-and-g-b-validation # https://core.telegram.org/mtproto/security_guidelines#g-a-and-g-b-validation
g_b = int.from_bytes(g_b, "big") g_b = int.from_bytes(g_b, "big")
@ -220,12 +218,12 @@ class Auth:
assert 1 < g_b < dh_prime - 1 assert 1 < g_b < dh_prime - 1
assert 2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64) assert 2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64)
assert 2 ** (2048 - 64) < g_b < dh_prime - 2 ** (2048 - 64) assert 2 ** (2048 - 64) < g_b < dh_prime - 2 ** (2048 - 64)
log.debug("g_a and g_b validation: OK") logging.debug("g_a and g_b validation: OK")
# https://core.telegram.org/mtproto/security_guidelines#checking-sha1-hash-values # https://core.telegram.org/mtproto/security_guidelines#checking-sha1-hash-values
answer = server_dh_inner_data.write() # Call .write() to remove padding answer = server_dh_inner_data.write() # Call .write() to remove padding
assert answer_with_hash[:20] == sha1(answer).digest() assert answer_with_hash[:20] == sha1(answer).digest()
log.debug("SHA1 hash values check: OK") logging.debug("SHA1 hash values check: OK")
# https://core.telegram.org/mtproto/security_guidelines#checking-nonce-server-nonce-and-new-nonce-fields # https://core.telegram.org/mtproto/security_guidelines#checking-nonce-server-nonce-and-new-nonce-fields
# 1st message # 1st message
@ -238,14 +236,14 @@ class Auth:
assert nonce == set_client_dh_params_answer.nonce assert nonce == set_client_dh_params_answer.nonce
assert server_nonce == set_client_dh_params_answer.server_nonce assert server_nonce == set_client_dh_params_answer.server_nonce
server_nonce = server_nonce.to_bytes(16, "little", signed=True) server_nonce = server_nonce.to_bytes(16, "little", signed=True)
log.debug("Nonce fields check: OK") logging.debug("Nonce fields check: OK")
# Step 9 # Step 9
server_salt = AES.xor(new_nonce[:8], server_nonce[:8]) server_salt = AES.xor(new_nonce[:8], server_nonce[:8])
log.debug("Server salt: {}".format(int.from_bytes(server_salt, "little"))) logging.debug("Server salt: {}".format(int.from_bytes(server_salt, "little")))
log.info( logging.info(
"Done auth key exchange: {}".format( "Done auth key exchange: {}".format(
set_client_dh_params_answer.__class__.__name__ set_client_dh_params_answer.__class__.__name__
) )

View File

@ -32,8 +32,6 @@ from pyrogram.crypto import MTProto
from pyrogram.errors import RPCError, InternalServerError, AuthKeyDuplicated from pyrogram.errors import RPCError, InternalServerError, AuthKeyDuplicated
from .internals import MsgId, MsgFactory from .internals import MsgId, MsgFactory
log = logging.getLogger(__name__)
class Result: class Result:
def __init__(self): def __init__(self):
@ -158,9 +156,9 @@ class Session:
self.ping_task = asyncio.ensure_future(self.ping()) self.ping_task = asyncio.ensure_future(self.ping())
log.info("Session initialized: Layer {}".format(layer)) logging.info("Session initialized: Layer {}".format(layer))
log.info("Device: {} - {}".format(self.client.device_model, self.client.app_version)) logging.info("Device: {} - {}".format(self.client.device_model, self.client.app_version))
log.info("System: {} ({})".format(self.client.system_version, self.client.lang_code.upper())) logging.info("System: {} ({})".format(self.client.system_version, self.client.lang_code.upper()))
except AuthKeyDuplicated as e: except AuthKeyDuplicated as e:
await self.stop() await self.stop()
@ -175,7 +173,7 @@ class Session:
self.is_connected.set() self.is_connected.set()
log.info("Session started") logging.info("Session started")
async def stop(self): async def stop(self):
self.is_connected.clear() self.is_connected.clear()
@ -207,16 +205,16 @@ class Session:
try: try:
await self.client.disconnect_handler(self.client) await self.client.disconnect_handler(self.client)
except Exception as e: except Exception as e:
log.error(e, exc_info=True) logging.error(e, exc_info=True)
log.info("Session stopped") logging.info("Session stopped")
async def restart(self): async def restart(self):
await self.stop() await self.stop()
await self.start() await self.start()
async def net_worker(self): async def net_worker(self):
log.info("NetWorkerTask started") logging.info("NetWorkerTask started")
while True: while True:
packet = await self.recv_queue.get() packet = await self.recv_queue.get()
@ -238,7 +236,7 @@ class Session:
else [data] else [data]
) )
log.debug(data) logging.debug(data)
for msg in messages: for msg in messages:
if msg.seq_no % 2 != 0: if msg.seq_no % 2 != 0:
@ -271,7 +269,7 @@ class Session:
self.results[msg_id].event.set() self.results[msg_id].event.set()
if len(self.pending_acks) >= self.ACKS_THRESHOLD: if len(self.pending_acks) >= self.ACKS_THRESHOLD:
log.info("Send {} acks".format(len(self.pending_acks))) logging.info("Send {} acks".format(len(self.pending_acks)))
try: try:
await self._send(types.MsgsAck(msg_ids=list(self.pending_acks)), False) await self._send(types.MsgsAck(msg_ids=list(self.pending_acks)), False)
@ -280,12 +278,12 @@ class Session:
else: else:
self.pending_acks.clear() self.pending_acks.clear()
except Exception as e: except Exception as e:
log.error(e, exc_info=True) logging.error(e, exc_info=True)
log.info("NetWorkerTask stopped") logging.info("NetWorkerTask stopped")
async def ping(self): async def ping(self):
log.info("PingTask started") logging.info("PingTask started")
while True: while True:
try: try:
@ -304,10 +302,10 @@ class Session:
except (OSError, TimeoutError, RPCError): except (OSError, TimeoutError, RPCError):
pass pass
log.info("PingTask stopped") logging.info("PingTask stopped")
async def next_salt(self): async def next_salt(self):
log.info("NextSaltTask started") logging.info("NextSaltTask started")
while True: while True:
now = datetime.now() now = datetime.now()
@ -317,7 +315,7 @@ class Session:
valid_until = datetime.fromtimestamp(self.current_salt.valid_until) valid_until = datetime.fromtimestamp(self.current_salt.valid_until)
dt = (valid_until - now).total_seconds() - 900 dt = (valid_until - now).total_seconds() - 900
log.info("Next salt in {:.0f}m {:.0f}s ({})".format( logging.info("Next salt in {:.0f}m {:.0f}s ({})".format(
dt // 60, dt % 60, dt // 60, dt % 60,
now + timedelta(seconds=dt) now + timedelta(seconds=dt)
)) ))
@ -335,10 +333,10 @@ class Session:
self.connection.close() self.connection.close()
break break
log.info("NextSaltTask stopped") logging.info("NextSaltTask stopped")
async def recv(self): async def recv(self):
log.info("RecvTask started") logging.info("RecvTask started")
while True: while True:
packet = await self.connection.recv() packet = await self.connection.recv()
@ -347,7 +345,7 @@ class Session:
self.recv_queue.put_nowait(None) self.recv_queue.put_nowait(None)
if packet: if packet:
log.warning("Server sent \"{}\"".format(Int.read(BytesIO(packet)))) logging.warning("Server sent \"{}\"".format(Int.read(BytesIO(packet))))
if self.is_connected.is_set(): if self.is_connected.is_set():
asyncio.ensure_future(self.restart()) asyncio.ensure_future(self.restart())
@ -356,7 +354,7 @@ class Session:
self.recv_queue.put_nowait(packet) self.recv_queue.put_nowait(packet)
log.info("RecvTask stopped") logging.info("RecvTask stopped")
async def _send(self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT): async def _send(self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT):
message = self.msg_factory(data) message = self.msg_factory(data)