2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-27 20:37:54 +00:00

Move update handler into Client

This commit is contained in:
Dan 2018-02-08 20:46:47 +01:00
parent 7cee6b079f
commit d8edfb38bf
2 changed files with 37 additions and 14 deletions

View File

@ -23,12 +23,14 @@ import math
import mimetypes import mimetypes
import os import os
import re import re
import threading
import time import time
from collections import namedtuple from collections import namedtuple
from configparser import ConfigParser from configparser import ConfigParser
from hashlib import sha256, md5 from hashlib import sha256, md5
from queue import Queue
from signal import signal, SIGINT, SIGTERM, SIGABRT from signal import signal, SIGINT, SIGTERM, SIGABRT
from threading import Event from threading import Event, Thread
from pyrogram.api import functions, types from pyrogram.api import functions, types
from pyrogram.api.core import Object from pyrogram.api.core import Object
@ -141,6 +143,8 @@ class Client:
self.update_handler = None self.update_handler = None
self.is_idle = Event() self.is_idle = Event()
self.update_queue = Queue()
def start(self): def start(self):
"""Use this method to start the Client after creating it. """Use this method to start the Client after creating it.
Requires no parameters. Requires no parameters.
@ -156,7 +160,8 @@ class Client:
self.test_mode, self.test_mode,
self.proxy, self.proxy,
self.auth_key, self.auth_key,
self.config.api_id self.config.api_id,
client=self
) )
terms = self.session.start() terms = self.session.start()
@ -170,7 +175,9 @@ class Client:
self.rnd_id = self.session.msg_id self.rnd_id = self.session.msg_id
self.get_dialogs() self.get_dialogs()
self.session.set_update_handler(self, self.update_handler)
for i in range(self.workers):
Thread(target=self.update_worker, name="UpdateWorker#{}".format(i + 1)).start()
mimetypes.init() mimetypes.init()
@ -180,6 +187,26 @@ class Client:
""" """
self.session.stop() self.session.stop()
for i in range(self.workers):
self.update_queue.put(None)
def update_worker(self):
name = threading.current_thread().name
log.debug("{} started".format(name))
while True:
update = self.update_queue.get()
if update is None:
break
try:
self.update_handler(self, update)
except Exception as e:
log.error(e, exc_info=True)
log.debug("{} stopped".format(name))
def signal_handler(self, *args): def signal_handler(self, *args):
self.stop() self.stop()
self.is_idle.set() self.is_idle.set()
@ -261,7 +288,8 @@ class Client:
self.test_mode, self.test_mode,
self.proxy, self.proxy,
self.auth_key, self.auth_key,
self.config.api_id self.config.api_id,
client=self
) )
self.session.start() self.session.start()

View File

@ -74,7 +74,8 @@ class Session:
proxy: type, proxy: type,
auth_key: bytes, auth_key: bytes,
api_id: str, api_id: str,
is_cdn: bool = False): is_cdn: bool = False,
client: pyrogram = None):
if not Session.notice_displayed: if not Session.notice_displayed:
print("Pyrogram v{}, {}".format(__version__, __copyright__)) print("Pyrogram v{}, {}".format(__version__, __copyright__))
print("Licensed under the terms of the " + __license__, end="\n\n") print("Licensed under the terms of the " + __license__, end="\n\n")
@ -83,6 +84,7 @@ class Session:
self.connection = Connection(DataCenter(dc_id, test_mode), proxy) self.connection = Connection(DataCenter(dc_id, test_mode), proxy)
self.api_id = api_id self.api_id = api_id
self.is_cdn = is_cdn self.is_cdn = is_cdn
self.client = client
self.auth_key = auth_key self.auth_key = auth_key
self.auth_key_id = sha1(auth_key).digest()[-8:] self.auth_key_id = sha1(auth_key).digest()[-8:]
@ -106,9 +108,6 @@ class Session:
self.is_connected = Event() self.is_connected = Event()
self.client = None
self.update_handler = None
def start(self): def start(self):
terms = None terms = None
@ -236,10 +235,6 @@ class Session:
log.debug("{} stopped".format(name)) log.debug("{} stopped".format(name))
def set_update_handler(self, client: pyrogram, update_handler: callable):
self.client = client
self.update_handler = update_handler
def unpack_dispatch_and_ack(self, packet: bytes): def unpack_dispatch_and_ack(self, packet: bytes):
data = self.unpack(BytesIO(packet)) data = self.unpack(BytesIO(packet))
@ -274,8 +269,8 @@ class Session:
elif isinstance(msg.body, types.Pong): elif isinstance(msg.body, types.Pong):
msg_id = msg.body.msg_id msg_id = msg.body.msg_id
else: else:
if self.update_handler: if self.client is not None:
self.update_handler(self.client, msg.body) self.client.update_queue.put(msg.body)
if msg_id in self.results: if msg_id in self.results:
self.results[msg_id].value = getattr(msg.body, "result", msg.body) self.results[msg_id].value = getattr(msg.body, "result", msg.body)