2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-28 12:57:52 +00:00

Add support for progress callback when downloading media

This commit is contained in:
Dan 2018-02-24 17:16:25 +01:00
parent 2e4802fbda
commit ed4ff07742

View File

@ -300,10 +300,10 @@ class Client:
if media is None: if media is None:
break break
media, file_name, done = media
tmp_file_name = ""
try: try:
media, file_name, done, progress, path = media
tmp_file_name = None
if isinstance(media, types.MessageMediaDocument): if isinstance(media, types.MessageMediaDocument):
document = media.document document = media.document
@ -331,15 +331,10 @@ class Client:
dc_id=document.dc_id, dc_id=document.dc_id,
id=document.id, id=document.id,
access_hash=document.access_hash, access_hash=document.access_hash,
version=document.version version=document.version,
size=document.size,
progress=progress
) )
try:
os.remove("./downloads/{}".format(file_name))
except FileNotFoundError:
pass
os.renames("./{}".format(tmp_file_name), "./downloads/{}".format(file_name))
elif isinstance(media, types.MessageMediaPhoto): elif isinstance(media, types.MessageMediaPhoto):
photo = media.photo photo = media.photo
@ -355,23 +350,32 @@ class Client:
dc_id=photo_loc.dc_id, dc_id=photo_loc.dc_id,
volume_id=photo_loc.volume_id, volume_id=photo_loc.volume_id,
local_id=photo_loc.local_id, local_id=photo_loc.local_id,
secret=photo_loc.secret secret=photo_loc.secret,
size=photo.sizes[-1].size,
progress=progress
) )
try: if file_name is not None:
os.remove("downloads/{}".format(file_name)) path[0] = "downloads/{}".format(file_name)
except FileNotFoundError:
pass
try:
os.remove("downloads/{}".format(file_name))
except OSError:
pass
finally:
try:
os.renames("{}".format(tmp_file_name), "downloads/{}".format(file_name)) os.renames("{}".format(tmp_file_name), "downloads/{}".format(file_name))
except OSError:
pass
except Exception as e: except Exception as e:
log.error(e, exc_info=True) log.error(e, exc_info=True)
finally: finally:
print(done)
done.set() done.set()
try: try:
os.remove("{}".format(tmp_file_name)) os.remove("{}".format(tmp_file_name))
except FileNotFoundError: except OSError:
pass pass
log.debug("{} stopped".format(name)) log.debug("{} stopped".format(name))
@ -1861,7 +1865,9 @@ class Client:
volume_id: int = None, volume_id: int = None,
local_id: int = None, local_id: int = None,
secret: int = None, secret: int = None,
version: int = 0) -> str: version: int = 0,
size: int = None,
progress: callable = None) -> str:
if dc_id != self.dc_id: if dc_id != self.dc_id:
exported_auth = self.send( exported_auth = self.send(
functions.auth.ExportAuthorization( functions.auth.ExportAuthorization(
@ -1936,6 +1942,9 @@ class Client:
offset += limit offset += limit
if progress:
progress(offset, size)
r = session.send( r = session.send(
functions.upload.GetFile( functions.upload.GetFile(
location=location, location=location,
@ -2007,10 +2016,13 @@ class Client:
f.flush() f.flush()
os.fsync(f.fileno()) os.fsync(f.fileno())
offset += limit
if progress:
progress(min(offset, size), size)
if len(chunk) < limit: if len(chunk) < limit:
break break
offset += limit
except Exception as e: except Exception as e:
log.error(e) log.error(e)
finally: finally:
@ -2371,14 +2383,36 @@ class Client:
) )
) )
def download_media(self, message: types.Message, file_name: str = None, block: bool = True): def download_media(self,
done = Event() message: types.Message,
media = message.media if isinstance(message, types.Message) else message file_name: str = None,
block: bool = True,
progress: callable = None):
"""Use this method to download the media from a Message.
self.download_queue.put((media, file_name, done)) Files are saved in the *downloads* folder.
if block: Args:
done.wait() message (:obj:`Message <pyrogram.api.types.Message>`):
The Message containing the media.
file_name (:obj:`str`):
Specify a file_name to be used
"""
if isinstance(message, types.Message):
done = Event()
media = message.media
path = [None]
if media is not None:
self.download_queue.put((media, file_name, done, progress, path))
else:
return
if block:
done.wait()
return path[0]
def add_contacts(self, contacts: list): def add_contacts(self, contacts: list):
"""Use this method to add contacts to your Telegram address book. """Use this method to add contacts to your Telegram address book.
@ -2408,8 +2442,8 @@ class Client:
Args: Args:
ids (:obj:`list`): ids (:obj:`list`):
A list of unique identifiers for the target users. Can be an ID (int), a username (string) A list of unique identifiers for the target users.
or phone number (string). Can be an ID (int), a username (string) or phone number (string).
Returns: Returns:
True on success. True on success.