diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index 50009ba4..ea77ff75 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -25,6 +25,7 @@ import mimetypes import os import re import struct +import tempfile import threading import time from collections import namedtuple @@ -34,9 +35,6 @@ from hashlib import sha256, md5 from queue import Queue from signal import signal, SIGINT, SIGTERM, SIGABRT from threading import Event, Thread -import tempfile -import shutil -import errno from pyrogram.api import functions, types from pyrogram.api.core import Object @@ -506,23 +504,17 @@ class Client: while True: media = self.download_queue.get() + temp_file_path = "" + final_file_path = "" if media is None: break try: media, file_name, done, progress, path = media - tmp_file_name = None - download_directory = "downloads" - - if file_name.endswith('/') or file_name.endswith('\\'): - # treat the file name as a directory - download_directory = file_name - file_name = None - elif '/' in file_name or '\\' in file_name: - # use file_name as a full path instead - download_directory = '' + directory, file_name = os.path.split(file_name) + directory = directory or "downloads" if isinstance(media, types.MessageMediaDocument): document = media.document @@ -548,7 +540,7 @@ class Client: elif isinstance(i, types.DocumentAttributeAnimated): file_name = file_name.replace("doc", "gif") - tmp_file_name = self.get_file( + temp_file_path = self.get_file( dc_id=document.dc_id, id=document.id, access_hash=document.access_hash, @@ -571,7 +563,7 @@ class Client: photo_loc = photo.sizes[-1].location - tmp_file_name = self.get_file( + temp_file_path = self.get_file( dc_id=photo_loc.dc_id, volume_id=photo_loc.volume_id, local_id=photo_loc.local_id, @@ -580,37 +572,29 @@ class Client: progress=progress ) - if tmp_file_name is None: - return None + if temp_file_path: + final_file_path = os.path.join(directory, file_name) - if file_name is not None: - path[0] = os.path.join(download_directory, file_name) - - try: - os.remove(os.path.join(download_directory, file_name)) - except OSError: - pass - finally: try: - if download_directory: - os.makedirs(download_directory, exist_ok=True) - else: - os.makedirs(os.path.dirname(file_name), exist_ok=True) + os.remove(final_file_path) + except OSError: + pass - # avoid errors moving between drives/partitions etc. - shutil.move(tmp_file_name, os.path.join(download_directory, file_name)) - except OSError as e: - log.error(e, exc_info=True) + os.renames(temp_file_path, final_file_path) except Exception as e: log.error(e, exc_info=True) - finally: - done.set() try: - os.remove(tmp_file_name) - except OSError as e: - if e.errno != errno.ENOENT: - log.error(e, exc_info=True) + os.remove(temp_file_path) + except OSError: + pass + else: + # TODO: "" or None for faulty download, which is better? + # os.path methods return "" in case something does not exist, I prefer this. + # For now let's keep None + path[0] = final_file_path or None + finally: + done.set() log.debug("{} stopped".format(name)) @@ -2250,7 +2234,7 @@ class Client: limit = 1024 * 1024 offset = 0 - file_name = None + file_name = "" try: r = session.send( @@ -2303,6 +2287,7 @@ class Client: try: with tempfile.NamedTemporaryFile('wb', delete=False) as f: file_name = f.name + while True: r2 = cdn_session.send( functions.upload.GetCdnFile( @@ -2370,6 +2355,8 @@ class Client: os.remove(file_name) except OSError: pass + + return "" else: return file_name finally: @@ -2628,7 +2615,7 @@ class Client: def download_media(self, message: types.Message, - file_name: str = None, + file_name: str = "", block: bool = True, progress: callable = None): """Use this method to download the media from a Message. @@ -2684,7 +2671,7 @@ class Client: def download_photo(self, photo: types.Photo or types.UserProfilePhoto or types.ChatPhoto, - file_name: str = None, + file_name: str = "", block: bool = True): """Use this method to download a photo not contained inside a Message. For example, a photo of a User or a Chat/Channel.