2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-09-05 00:35:10 +00:00

Client becomes async

This commit is contained in:
Dan
2018-06-20 11:41:22 +02:00
parent 399a7b6403
commit 6fcf41d857

View File

@@ -16,6 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import base64
import binascii
import getpass
@@ -28,7 +29,6 @@ import re
import shutil
import struct
import tempfile
import threading
import time
from configparser import ConfigParser
from datetime import datetime
@@ -43,11 +43,11 @@ from pyrogram.api.errors import (
PhoneCodeExpired, PhoneCodeEmpty, SessionPasswordNeeded,
PasswordHashInvalid, FloodWait, PeerIdInvalid, FirstnameInvalid, PhoneNumberBanned,
VolumeLocNotFound, UserMigrate, FileIdInvalid)
from pyrogram.client.handlers import DisconnectHandler
from pyrogram.crypto import AES
from pyrogram.session import Auth, Session
from .dispatcher import Dispatcher
from .ext import utils, Syncer, BaseClient
from .ext import BaseClient, Syncer, utils
from .handlers import DisconnectHandler
from .methods import Methods
# Custom format for nice looking log lines
@@ -114,7 +114,7 @@ class Client(Methods, BaseClient):
be an empty string: "". Only applicable for new sessions.
workers (``int``, *optional*):
Thread pool size for handling incoming updates. Defaults to 4.
Number of maximum concurrent workers for handling incoming updates. Defaults to 4.
workdir (``str``, *optional*):
Define a custom working directory. The working directory is the location in your filesystem
@@ -168,15 +168,10 @@ class Client(Methods, BaseClient):
self._proxy["enabled"] = True
self._proxy.update(value)
async def start(self, debug: bool = False):
async def start(self):
"""Use this method to start the Client after creating it.
Requires no parameters.
Args:
debug (``bool``, *optional*):
Enable or disable debug mode. When enabled, extra logging
lines will be printed out on your console.
Raises:
:class:`Error <pyrogram.Error>`
"""
@@ -188,7 +183,7 @@ class Client(Methods, BaseClient):
self.session_name = self.session_name.split(":")[0]
self.load_config()
self.load_session()
await self.load_session()
self.session = Session(
self.dc_id,
@@ -204,9 +199,9 @@ class Client(Methods, BaseClient):
if self.user_id is None:
if self.token is None:
self.authorize_user()
await self.authorize_user()
else:
self.authorize_bot()
await self.authorize_bot()
self.save_session()
@@ -217,38 +212,27 @@ class Client(Methods, BaseClient):
self.peers_by_username = {}
self.peers_by_phone = {}
self.get_dialogs()
self.get_contacts()
await self.get_dialogs()
await self.get_contacts()
else:
self.send(functions.messages.GetPinnedDialogs())
self.get_dialogs_chunk(0)
await self.send(functions.messages.GetPinnedDialogs())
await self.get_dialogs_chunk(0)
else:
await self.send(functions.updates.GetState())
# for i in range(self.UPDATES_WORKERS):
# self.updates_workers_list.append(
# Thread(
# target=self.updates_worker,
# name="UpdatesWorker#{}".format(i + 1)
# )
# )
#
# self.updates_workers_list[-1].start()
#
# for i in range(self.DOWNLOAD_WORKERS):
# self.download_workers_list.append(
# Thread(
# target=self.download_worker,
# name="DownloadWorker#{}".format(i + 1)
# )
# )
#
# self.download_workers_list[-1].start()
#
# self.dispatcher.start()
self.updates_worker_task = asyncio.ensure_future(self.updates_worker())
for _ in range(Client.DOWNLOAD_WORKERS):
self.download_worker_tasks.append(
asyncio.ensure_future(self.download_worker())
)
log.info("Started {} DownloadWorkerTasks".format(Client.DOWNLOAD_WORKERS))
await self.dispatcher.start()
await Syncer.add(self)
mimetypes.init()
# Syncer.add(self)
async def stop(self):
"""Use this method to manually stop the Client.
@@ -257,29 +241,26 @@ class Client(Methods, BaseClient):
if not self.is_started:
raise ConnectionError("Client is already stopped")
# Syncer.remove(self)
# self.dispatcher.stop()
#
# for _ in range(self.DOWNLOAD_WORKERS):
# self.download_queue.put(None)
#
# for i in self.download_workers_list:
# i.join()
#
# self.download_workers_list.clear()
#
# for _ in range(self.UPDATES_WORKERS):
# self.updates_queue.put(None)
#
# for i in self.updates_workers_list:
# i.join()
#
# self.updates_workers_list.clear()
#
# for i in self.media_sessions.values():
# i.stop()
#
# self.media_sessions.clear()
await Syncer.remove(self)
await self.dispatcher.stop()
for _ in range(Client.DOWNLOAD_WORKERS):
self.download_queue.put_nowait(None)
for task in self.download_worker_tasks:
await task
self.download_worker_tasks.clear()
log.info("Stopped {} DownloadWorkerTasks".format(Client.DOWNLOAD_WORKERS))
self.updates_queue.put_nowait(None)
await self.updates_worker_task
for media_session in self.media_sessions.values():
await media_session.stop()
self.media_sessions.clear()
self.is_started = False
await self.session.stop()
@@ -327,9 +308,9 @@ class Client(Methods, BaseClient):
else:
self.dispatcher.remove_handler(handler, group)
def authorize_bot(self):
async def authorize_bot(self):
try:
r = self.send(
r = await self.send(
functions.auth.ImportBotAuthorization(
flags=0,
api_id=self.api_id,
@@ -338,10 +319,10 @@ class Client(Methods, BaseClient):
)
)
except UserMigrate as e:
self.session.stop()
await self.session.stop()
self.dc_id = e.x
self.auth_key = Auth(self.dc_id, self.test_mode, self._proxy).create()
self.auth_key = await Auth(self.dc_id, self.test_mode, self._proxy).create()
self.session = Session(
self.dc_id,
@@ -352,12 +333,12 @@ class Client(Methods, BaseClient):
client=self
)
self.session.start()
self.authorize_bot()
await self.session.start()
await self.authorize_bot()
else:
self.user_id = r.user.id
def authorize_user(self):
async def authorize_user(self):
phone_number_invalid_raises = self.phone_number is not None
phone_code_invalid_raises = self.phone_code is not None
password_hash_invalid_raises = self.password is not None
@@ -378,7 +359,7 @@ class Client(Methods, BaseClient):
self.phone_number = self.phone_number.strip("+")
try:
r = self.send(
r = await self.send(
functions.auth.SendCode(
self.phone_number,
self.api_id,
@@ -386,10 +367,10 @@ class Client(Methods, BaseClient):
)
)
except (PhoneMigrate, NetworkMigrate) as e:
self.session.stop()
await self.session.stop()
self.dc_id = e.x
self.auth_key = Auth(self.dc_id, self.test_mode, self._proxy).create()
self.auth_key = await Auth(self.dc_id, self.test_mode, self._proxy).create()
self.session = Session(
self.dc_id,
@@ -399,9 +380,9 @@ class Client(Methods, BaseClient):
self.api_id,
client=self
)
self.session.start()
await self.session.start()
r = self.send(
r = await self.send(
functions.auth.SendCode(
self.phone_number,
self.api_id,
@@ -430,7 +411,7 @@ class Client(Methods, BaseClient):
phone_code_hash = r.phone_code_hash
if self.force_sms:
self.send(
await self.send(
functions.auth.ResendCode(
phone_number=self.phone_number,
phone_code_hash=phone_code_hash
@@ -446,7 +427,7 @@ class Client(Methods, BaseClient):
try:
if phone_registered:
r = self.send(
r = await self.send(
functions.auth.SignIn(
self.phone_number,
phone_code_hash,
@@ -455,7 +436,7 @@ class Client(Methods, BaseClient):
)
else:
try:
self.send(
await self.send(
functions.auth.SignIn(
self.phone_number,
phone_code_hash,
@@ -468,7 +449,7 @@ class Client(Methods, BaseClient):
self.first_name = self.first_name if self.first_name is not None else input("First name: ")
self.last_name = self.last_name if self.last_name is not None else input("Last name: ")
r = self.send(
r = await self.send(
functions.auth.SignUp(
self.phone_number,
phone_code_hash,
@@ -491,7 +472,7 @@ class Client(Methods, BaseClient):
self.first_name = None
except SessionPasswordNeeded as e:
print(e.MESSAGE)
r = self.send(functions.account.GetPassword())
r = await self.send(functions.account.GetPassword())
while True:
try:
@@ -505,7 +486,7 @@ class Client(Methods, BaseClient):
password_hash = sha256(self.password).digest()
r = self.send(functions.auth.CheckPassword(password_hash))
r = await self.send(functions.auth.CheckPassword(password_hash))
except PasswordHashInvalid as e:
if password_hash_invalid_raises:
raise
@@ -594,12 +575,9 @@ class Client(Methods, BaseClient):
if username is not None:
self.peers_by_username[username.lower()] = input_peer
def download_worker(self):
name = threading.current_thread().name
log.debug("{} started".format(name))
async def download_worker(self):
while True:
media = self.download_queue.get()
media = await self.download_queue.get()
if media is None:
break
@@ -666,7 +644,7 @@ class Client(Methods, BaseClient):
extension
)
temp_file_path = self.get_file(
temp_file_path = await self.get_file(
dc_id=dc_id,
id=id,
access_hash=access_hash,
@@ -697,14 +675,11 @@ class Client(Methods, BaseClient):
finally:
done.set()
log.debug("{} stopped".format(name))
def updates_worker(self):
name = threading.current_thread().name
log.debug("{} started".format(name))
async def updates_worker(self):
log.info("UpdatesWorkerTask started")
while True:
updates = self.updates_queue.get()
updates = await self.updates_queue.get()
if updates is None:
break
@@ -730,9 +705,9 @@ class Client(Methods, BaseClient):
message = update.message
if not isinstance(message, types.MessageEmpty):
diff = self.send(
diff = await self.send(
functions.updates.GetChannelDifference(
channel=self.resolve_peer(int("-100" + str(channel_id))),
channel=await self.resolve_peer(int("-100" + str(channel_id))),
filter=types.ChannelMessagesFilter(
ranges=[types.MessageRange(
min_id=update.message.id,
@@ -760,9 +735,9 @@ class Client(Methods, BaseClient):
if len(self.channels_pts[channel_id]) > 50:
self.channels_pts[channel_id] = self.channels_pts[channel_id][25:]
self.dispatcher.updates.put((update, updates.users, updates.chats))
self.dispatcher.updates.put_nowait((update, updates.users, updates.chats))
elif isinstance(updates, (types.UpdateShortMessage, types.UpdateShortChatMessage)):
diff = self.send(
diff = await self.send(
functions.updates.GetDifference(
pts=updates.pts - updates.pts_count,
date=updates.date,
@@ -771,7 +746,7 @@ class Client(Methods, BaseClient):
)
if diff.new_messages:
self.dispatcher.updates.put((
self.dispatcher.updates.put_nowait((
types.UpdateNewMessage(
message=diff.new_messages[0],
pts=updates.pts,
@@ -781,18 +756,19 @@ class Client(Methods, BaseClient):
diff.chats
))
else:
self.dispatcher.updates.put((diff.other_updates[0], [], []))
self.dispatcher.updates.put_nowait((diff.other_updates[0], [], []))
elif isinstance(updates, types.UpdateShort):
self.dispatcher.updates.put((updates.update, [], []))
self.dispatcher.updates.put_nowait((updates.update, [], []))
except Exception as e:
log.error(e, exc_info=True)
log.debug("{} stopped".format(name))
log.info("UpdatesWorkerTask stopped")
def signal_handler(self, *args):
log.info("Stop signal received ({}). Exiting...".format(args[0]))
self.is_idle = False
def idle(self, stop_signals: tuple = (SIGINT, SIGTERM, SIGABRT)):
async def idle(self, stop_signals: tuple = (SIGINT, SIGTERM, SIGABRT)):
"""Blocks the program execution until one of the signals are received,
then gently stop the Client by closing the underlying connection.
@@ -807,9 +783,9 @@ class Client(Methods, BaseClient):
self.is_idle = True
while self.is_idle:
time.sleep(1)
await asyncio.sleep(1)
self.stop()
await self.stop()
async def send(self, data: Object):
"""Use this method to send Raw Function queries.
@@ -863,14 +839,14 @@ class Client(Methods, BaseClient):
self._proxy["username"] = parser.get("proxy", "username", fallback=None) or None
self._proxy["password"] = parser.get("proxy", "password", fallback=None) or None
def load_session(self):
async def load_session(self):
try:
with open(os.path.join(self.workdir, "{}.session".format(self.session_name)), encoding="utf-8") as f:
s = json.load(f)
except FileNotFoundError:
self.dc_id = 1
self.date = 0
self.auth_key = Auth(self.dc_id, self.test_mode, self._proxy).create()
self.auth_key = await Auth(self.dc_id, self.test_mode, self._proxy).create()
else:
self.dc_id = s["dc_id"]
self.test_mode = s["test_mode"]
@@ -912,10 +888,10 @@ class Client(Methods, BaseClient):
indent=4
)
def get_dialogs_chunk(self, offset_date):
async def get_dialogs_chunk(self, offset_date):
while True:
try:
r = self.send(
r = await self.send(
functions.messages.GetDialogs(
offset_date, 0, types.InputPeerEmpty(),
self.DIALOGS_AT_ONCE, True
@@ -923,24 +899,24 @@ class Client(Methods, BaseClient):
)
except FloodWait as e:
log.warning("get_dialogs flood: waiting {} seconds".format(e.x))
time.sleep(e.x)
await asyncio.sleep(e.x)
else:
log.info("Total peers: {}".format(len(self.peers_by_id)))
return r
def get_dialogs(self):
self.send(functions.messages.GetPinnedDialogs())
async def get_dialogs(self):
await self.send(functions.messages.GetPinnedDialogs())
dialogs = self.get_dialogs_chunk(0)
dialogs = await self.get_dialogs_chunk(0)
offset_date = utils.get_offset_date(dialogs)
while len(dialogs.dialogs) == self.DIALOGS_AT_ONCE:
dialogs = self.get_dialogs_chunk(offset_date)
dialogs = await self.get_dialogs_chunk(offset_date)
offset_date = utils.get_offset_date(dialogs)
self.get_dialogs_chunk(0)
await self.get_dialogs_chunk(0)
def resolve_peer(self, peer_id: int or str):
async def resolve_peer(self, peer_id: int or str):
"""Use this method to get the *InputPeer* of a known *peer_id*.
It is intended to be used when working with Raw Functions (i.e: a Telegram API method you wish to use which is
@@ -968,7 +944,7 @@ class Client(Methods, BaseClient):
try:
decoded = base64.b64decode(match.group(1) + "=" * (-len(match.group(1)) % 4), "-_")
return self.resolve_peer(struct.unpack(">2iq", decoded)[1])
return await self.resolve_peer(struct.unpack(">2iq", decoded)[1])
except (AttributeError, binascii.Error, struct.error):
pass
@@ -980,7 +956,7 @@ class Client(Methods, BaseClient):
try:
return self.peers_by_username[peer_id]
except KeyError:
self.send(functions.contacts.ResolveUsername(peer_id))
await self.send(functions.contacts.ResolveUsername(peer_id))
return self.peers_by_username[peer_id]
else:
try:
@@ -1007,12 +983,12 @@ class Client(Methods, BaseClient):
except (KeyError, ValueError):
raise PeerIdInvalid
def save_file(self,
path: str,
file_id: int = None,
file_part: int = 0,
progress: callable = None,
progress_args: tuple = ()):
async def save_file(self,
path: str,
file_id: int = None,
file_part: int = 0,
progress: callable = None,
progress_args: tuple = ()):
part_size = 512 * 1024
file_size = os.path.getsize(path)
file_total_parts = int(math.ceil(file_size / part_size))
@@ -1022,7 +998,7 @@ class Client(Methods, BaseClient):
md5_sum = md5() if not is_big and not is_missing_part else None
session = Session(self.dc_id, self.test_mode, self._proxy, self.auth_key, self.api_id)
session.start()
await session.start()
try:
with open(path, "rb") as f:
@@ -1050,7 +1026,7 @@ class Client(Methods, BaseClient):
bytes=chunk
)
assert self.send(rpc), "Couldn't upload file"
assert await session.send(rpc), "Couldn't upload file"
if is_missing_part:
return
@@ -1080,25 +1056,25 @@ class Client(Methods, BaseClient):
md5_checksum=md5_sum
)
finally:
session.stop()
await session.stop()
def get_file(self,
dc_id: int,
id: int = None,
access_hash: int = None,
volume_id: int = None,
local_id: int = None,
secret: int = None,
version: int = 0,
size: int = None,
progress: callable = None,
progress_args: tuple = None) -> str:
with self.media_sessions_lock:
async def get_file(self,
dc_id: int,
id: int = None,
access_hash: int = None,
volume_id: int = None,
local_id: int = None,
secret: int = None,
version: int = 0,
size: int = None,
progress: callable = None,
progress_args: tuple = None) -> str:
with await self.media_sessions_lock:
session = self.media_sessions.get(dc_id, None)
if session is None:
if dc_id != self.dc_id:
exported_auth = self.send(
exported_auth = await self.send(
functions.auth.ExportAuthorization(
dc_id=dc_id
)
@@ -1108,15 +1084,15 @@ class Client(Methods, BaseClient):
dc_id,
self.test_mode,
self._proxy,
Auth(dc_id, self.test_mode, self._proxy).create(),
await Auth(dc_id, self.test_mode, self._proxy).create(),
self.api_id
)
session.start()
await session.start()
self.media_sessions[dc_id] = session
session.send(
await session.send(
functions.auth.ImportAuthorization(
id=exported_auth.id,
bytes=exported_auth.bytes
@@ -1131,7 +1107,7 @@ class Client(Methods, BaseClient):
self.api_id
)
session.start()
await session.start()
self.media_sessions[dc_id] = session
@@ -1153,7 +1129,7 @@ class Client(Methods, BaseClient):
file_name = ""
try:
r = session.send(
r = await session.send(
functions.upload.GetFile(
location=location,
offset=offset,
@@ -1180,7 +1156,7 @@ class Client(Methods, BaseClient):
if progress:
progress(self, min(offset, size), size, *progress_args)
r = session.send(
r = await session.send(
functions.upload.GetFile(
location=location,
offset=offset,
@@ -1189,7 +1165,7 @@ class Client(Methods, BaseClient):
)
elif isinstance(r, types.upload.FileCdnRedirect):
with self.media_sessions_lock:
with await self.media_sessions_lock:
cdn_session = self.media_sessions.get(r.dc_id, None)
if cdn_session is None:
@@ -1197,12 +1173,12 @@ class Client(Methods, BaseClient):
r.dc_id,
self.test_mode,
self._proxy,
Auth(r.dc_id, self.test_mode, self._proxy).create(),
await Auth(r.dc_id, self.test_mode, self._proxy).create(),
self.api_id,
is_cdn=True
)
cdn_session.start()
await cdn_session.start()
self.media_sessions[r.dc_id] = cdn_session
@@ -1211,7 +1187,7 @@ class Client(Methods, BaseClient):
file_name = f.name
while True:
r2 = cdn_session.send(
r2 = await cdn_session.send(
functions.upload.GetCdnFile(
file_token=r.file_token,
offset=offset,
@@ -1221,7 +1197,7 @@ class Client(Methods, BaseClient):
if isinstance(r2, types.upload.CdnFileReuploadNeeded):
try:
session.send(
await session.send(
functions.upload.ReuploadCdnFile(
file_token=r.file_token,
request_token=r2.request_token
@@ -1244,7 +1220,7 @@ class Client(Methods, BaseClient):
)
)
hashes = session.send(
hashes = await session.send(
functions.upload.GetCdnFileHashes(
r.file_token,
offset