2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-22 18:07:21 +00:00

Switch to non-blocking sockets & use a send queue

This commit is contained in:
Dan 2022-12-24 16:15:07 +01:00
parent f350691c69
commit 84d60b56b3
4 changed files with 26 additions and 14 deletions

View File

@ -57,7 +57,7 @@ class Connection:
await self.protocol.connect(self.address) await self.protocol.connect(self.address)
except OSError as e: except OSError as e:
log.warning(f"Unable to connect due to network issues: {e}") log.warning(f"Unable to connect due to network issues: {e}")
self.protocol.close() await self.protocol.close()
await asyncio.sleep(1) await asyncio.sleep(1)
else: else:
log.info("Connected! {} DC{}{} - IPv{} - {}".format( log.info("Connected! {} DC{}{} - IPv{} - {}".format(
@ -72,8 +72,8 @@ class Connection:
log.warning("Connection failed! Trying again...") log.warning("Connection failed! Trying again...")
raise TimeoutError raise TimeoutError
def close(self): async def close(self):
self.protocol.close() await self.protocol.close()
log.info("Disconnected") log.info("Disconnected")
async def send(self, data: bytes): async def send(self, data: bytes):

View File

@ -21,7 +21,7 @@ import ipaddress
import logging import logging
import socket import socket
import time import time
from concurrent.futures import ThreadPoolExecutor from typing import Optional
try: try:
import socks import socks
@ -76,17 +76,21 @@ class TCP:
else socket.AF_INET else socket.AF_INET
) )
self.socket.setblocking(False)
self.socket.settimeout(TCP.TIMEOUT) self.socket.settimeout(TCP.TIMEOUT)
self.send_queue = asyncio.Queue()
self.send_task = None
async def connect(self, address: tuple): async def connect(self, address: tuple):
# The socket used by the whole logic is blocking and thus it blocks when connecting. await asyncio.get_event_loop().sock_connect(self.socket, address)
# Offload the task to a thread executor to avoid blocking the main event loop.
with ThreadPoolExecutor(1) as executor:
await self.loop.run_in_executor(executor, self.socket.connect, address)
self.reader, self.writer = await asyncio.open_connection(sock=self.socket) self.reader, self.writer = await asyncio.open_connection(sock=self.socket)
self.send_task = asyncio.create_task(self.send_worker())
async def close(self):
await self.send_queue.put(None)
await self.send_task
def close(self):
try: try:
self.writer.close() self.writer.close()
except AttributeError: except AttributeError:
@ -100,8 +104,16 @@ class TCP:
time.sleep(0.001) time.sleep(0.001)
self.socket.close() self.socket.close()
async def send(self, data: bytes): async def send(self, data: Optional[bytes]):
async with self.lock: await self.send_queue.put(data)
async def send_worker(self):
while True:
data = await self.send_queue.get()
if data is None:
break
self.writer.write(data) self.writer.write(data)
await self.writer.drain() await self.writer.drain()

View File

@ -258,4 +258,4 @@ class Auth:
else: else:
return auth_key return auth_key
finally: finally:
self.connection.close() await self.connection.close()

View File

@ -157,7 +157,7 @@ class Session:
self.ping_task_event.clear() self.ping_task_event.clear()
self.connection.close() await self.connection.close()
if self.network_task: if self.network_task:
await self.network_task await self.network_task