2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-28 21:07:59 +00:00

Do not consume async gens, turn them to gens instead

This commit is contained in:
Dan 2022-02-10 01:08:11 +01:00
parent 462e5d11a5
commit 921d87304f

View File

@ -30,8 +30,23 @@ def async_to_sync(obj, name):
function = getattr(obj, name) function = getattr(obj, name)
main_loop = asyncio.get_event_loop() main_loop = asyncio.get_event_loop()
async def consume_generator(coroutine): def async_to_sync_gen(agen, loop, is_main_thread):
return types.List([i async for i in coroutine]) async def anext(agen):
try:
return await agen.__anext__(), False
except StopAsyncIteration:
return None, True
while True:
if is_main_thread:
item, done = loop.run_until_complete(anext(agen))
else:
item, done = asyncio.run_coroutine_threadsafe(anext(agen), loop).result()
if done:
break
yield item
@functools.wraps(function) @functools.wraps(function)
def async_to_sync_wrap(*args, **kwargs): def async_to_sync_wrap(*args, **kwargs):
@ -51,7 +66,7 @@ def async_to_sync(obj, name):
return loop.run_until_complete(coroutine) return loop.run_until_complete(coroutine)
if inspect.isasyncgen(coroutine): if inspect.isasyncgen(coroutine):
return loop.run_until_complete(consume_generator(coroutine)) return async_to_sync_gen(coroutine, loop, True)
else: else:
if inspect.iscoroutine(coroutine): if inspect.iscoroutine(coroutine):
if loop.is_running(): if loop.is_running():
@ -66,7 +81,7 @@ def async_to_sync(obj, name):
if loop.is_running(): if loop.is_running():
return coroutine return coroutine
else: else:
return asyncio.run_coroutine_threadsafe(consume_generator(coroutine), main_loop).result() return async_to_sync_gen(coroutine, main_loop, False)
setattr(obj, name, async_to_sync_wrap) setattr(obj, name, async_to_sync_wrap)