diff --git a/ubot/fixes/fast_telethon.py b/ubot/fixes/fast_telethon.py index 9116838..e35f570 100644 --- a/ubot/fixes/fast_telethon.py +++ b/ubot/fixes/fast_telethon.py @@ -327,7 +327,7 @@ async def upload_file(client: TelegramClient, file: Union[BinaryIO, BytesIO, str if isinstance(file, str): with open(file, mode="rb") as fh: res = (await _internal_transfer_to_telegram(client, fh, progress_callback, file))[0] - if isinstance(file, BytesIO): + elif isinstance(file, BytesIO): res = (await _internal_transfer_to_telegram(client, file, progress_callback, file.name))[0] else: res = (await _internal_transfer_to_telegram(client, file, progress_callback))[0] diff --git a/ubot/fixes/parallel_download.py b/ubot/fixes/parallel_download.py index df984bf..2e9e733 100644 --- a/ubot/fixes/parallel_download.py +++ b/ubot/fixes/parallel_download.py @@ -2,52 +2,71 @@ import mimetypes from asyncio import gather -from io import BytesIO +from os import remove +import aiofiles from aiohttp import ClientSession class ParallelDownload: - def __init__(self, url: str, aioclient: ClientSession): + def __init__(self, url: str, aioclient: ClientSession, file_name: str): self.url = url self.aioclient = aioclient + self.file_name = file_name - async def download_chunk(self, chunk_start: int, chunk_end: int, total_size: int) -> bytes: + async def download_chunk(self, chunk_start: int, chunk_end: int, total_size: int, chunk_number: int) -> str: chunk_headers = { "Content-Range": f"bytes {chunk_start}-{chunk_end}/{total_size}" } async with self.aioclient.get(self.url, headers=chunk_headers) as response: - chunk_data = await response.read() + async with aiofiles.open(f"ubot/cache/{self.file_name}.part{chunk_number}", mode="wb") as cache_file: + while True: + chunk = await response.content.read(4096) - return chunk_data + if not chunk: + break - async def generate_chunk_coros(self, chunk_size: int = 1000000) -> list: + await cache_file.write(chunk) + + await cache_file.flush() + + return f"ubot/cache/{self.file_name}.part{chunk_number}" + + async def generate_chunk_coros(self, chunk_size: int) -> (list, str): async with self.aioclient.get(self.url) as response: content_length = int(response.headers["content-length"]) file_extension = mimetypes.guess_extension(response.headers["content-type"]) place = 0 - remaining_length = content_length + chunk_number = 0 chunk_coros = [] - while remaining_length > 0: - if remaining_length < chunk_size: - chunk_coros.append(self.download_chunk(place, content_length, content_length)) + while place < content_length: + if place + chunk_size > content_length: + chunk_coros.append(self.download_chunk(place, content_length, content_length, chunk_number)) break - chunk_coros.append(self.download_chunk(place, place + chunk_size, content_length)) + chunk_coros.append(self.download_chunk(place, place + chunk_size, content_length, chunk_number)) place += chunk_size - remaining_length -= chunk_size + + chunk_number += 1 return chunk_coros, file_extension -async def download(url: str, aioclient: ClientSession = ClientSession()) -> BytesIO: - downloader = ParallelDownload(url, aioclient) - chunk_coros, file_extension = await downloader.generate_chunk_coros() - downloaded_byte_chunks = await gather(*chunk_coros) - downloaded_bytes = BytesIO(b''.join(downloaded_byte_chunks)) - downloaded_bytes.name = f"downloaded_file{file_extension}" +async def download(url: str, file_name: str, aioclient: ClientSession = ClientSession(), chunk_size: int = 5000000) -> str: + downloader = ParallelDownload(url, aioclient, file_name) + chunk_coros, file_extension = await downloader.generate_chunk_coros(chunk_size) + downloaded_part_files = await gather(*chunk_coros) - return downloaded_bytes + async with aiofiles.open(f"ubot/cache/{file_name}{file_extension}", "wb") as final_fh: + for part_file in downloaded_part_files: + async with aiofiles.open(part_file, "rb") as part_fh: + await final_fh.write(await part_fh.read()) + + remove(part_file) + + await final_fh.flush() + + return f"ubot/cache/{file_name}{file_extension}" diff --git a/ubot/modules/scrapers.py b/ubot/modules/scrapers.py index 00e799e..5953edc 100644 --- a/ubot/modules/scrapers.py +++ b/ubot/modules/scrapers.py @@ -198,7 +198,7 @@ async def youtube_cmd(event): try: if await ldr.cache.is_cache_required(video_stream.url): - file_path = await download(video_stream.url, ldr.aioclient) + file_path = await download(video_stream.url, f"{event.chat_id}_{event.id}", ldr.aioclient) file_handle = await upload_file(event.client, file_path) await event.client.send_file(event.chat, file=file_handle, reply_to=event, attributes=[ @@ -208,9 +208,16 @@ async def youtube_cmd(event): h=video_stream.dimensions[1], supports_streaming=True )]) + + ldr.cache.remove_cache(file_path) else: await event.reply(file=video_stream.url) except: + try: + ldr.cache.remove_cache(file_path) + except: + pass + await event.reply(f"Download failed: [URL]({video_stream.url})")