mirror of
https://github.com/pyrogram/pyrogram
synced 2025-09-08 02:05:51 +00:00
Add support for downloading files to file pointer, fix for https://github.com/pyrogram/pyrogram/issues/284
This commit is contained in:
@@ -15,7 +15,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import io
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
@@ -1231,9 +1231,9 @@ class Client(Methods, BaseClient):
|
||||
|
||||
temp_file_path = ""
|
||||
final_file_path = ""
|
||||
|
||||
path = [None]
|
||||
try:
|
||||
data, directory, file_name, done, progress, progress_args, path = packet
|
||||
data, done, progress, progress_args, out, path, to_file = packet
|
||||
|
||||
temp_file_path = self.get_file(
|
||||
media_type=data.media_type,
|
||||
@@ -1250,13 +1250,15 @@ class Client(Methods, BaseClient):
|
||||
file_size=data.file_size,
|
||||
is_big=data.is_big,
|
||||
progress=progress,
|
||||
progress_args=progress_args
|
||||
progress_args=progress_args,
|
||||
out=out
|
||||
)
|
||||
|
||||
if temp_file_path:
|
||||
final_file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name)))
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
shutil.move(temp_file_path, final_file_path)
|
||||
if to_file:
|
||||
final_file_path = out.name
|
||||
else:
|
||||
final_file_path = ''
|
||||
if to_file:
|
||||
out.close()
|
||||
except Exception as e:
|
||||
log.error(e, exc_info=True)
|
||||
|
||||
@@ -1864,7 +1866,8 @@ class Client(Methods, BaseClient):
|
||||
file_size: int,
|
||||
is_big: bool,
|
||||
progress: callable,
|
||||
progress_args: tuple = ()
|
||||
progress_args: tuple = (),
|
||||
out: io.IOBase = None
|
||||
) -> str:
|
||||
with self.media_sessions_lock:
|
||||
session = self.media_sessions.get(dc_id, None)
|
||||
@@ -1950,7 +1953,10 @@ class Client(Methods, BaseClient):
|
||||
limit = 1024 * 1024
|
||||
offset = 0
|
||||
file_name = ""
|
||||
|
||||
if not out:
|
||||
f = tempfile.NamedTemporaryFile("wb", delete=False)
|
||||
else:
|
||||
f = out
|
||||
try:
|
||||
r = session.send(
|
||||
functions.upload.GetFile(
|
||||
@@ -1961,36 +1967,37 @@ class Client(Methods, BaseClient):
|
||||
)
|
||||
|
||||
if isinstance(r, types.upload.File):
|
||||
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
|
||||
if hasattr(f, "name"):
|
||||
file_name = f.name
|
||||
|
||||
while True:
|
||||
chunk = r.bytes
|
||||
while True:
|
||||
chunk = r.bytes
|
||||
|
||||
if not chunk:
|
||||
break
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
f.write(chunk)
|
||||
f.write(chunk)
|
||||
|
||||
offset += limit
|
||||
offset += limit
|
||||
|
||||
if progress:
|
||||
progress(
|
||||
min(offset, file_size)
|
||||
if file_size != 0
|
||||
else offset,
|
||||
file_size,
|
||||
*progress_args
|
||||
)
|
||||
if progress:
|
||||
progress(
|
||||
|
||||
r = session.send(
|
||||
functions.upload.GetFile(
|
||||
location=location,
|
||||
offset=offset,
|
||||
limit=limit
|
||||
)
|
||||
min(offset, file_size)
|
||||
if file_size != 0
|
||||
else offset,
|
||||
file_size,
|
||||
*progress_args
|
||||
)
|
||||
|
||||
r = session.send(
|
||||
functions.upload.GetFile(
|
||||
location=location,
|
||||
offset=offset,
|
||||
limit=limit
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(r, types.upload.FileCdnRedirect):
|
||||
with self.media_sessions_lock:
|
||||
cdn_session = self.media_sessions.get(r.dc_id, None)
|
||||
@@ -2003,70 +2010,71 @@ class Client(Methods, BaseClient):
|
||||
self.media_sessions[r.dc_id] = cdn_session
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
|
||||
if hasattr(f, "name"):
|
||||
file_name = f.name
|
||||
|
||||
while True:
|
||||
r2 = cdn_session.send(
|
||||
functions.upload.GetCdnFile(
|
||||
file_token=r.file_token,
|
||||
offset=offset,
|
||||
limit=limit
|
||||
)
|
||||
while True:
|
||||
r2 = cdn_session.send(
|
||||
functions.upload.GetCdnFile(
|
||||
file_token=r.file_token,
|
||||
offset=offset,
|
||||
limit=limit
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(r2, types.upload.CdnFileReuploadNeeded):
|
||||
try:
|
||||
session.send(
|
||||
functions.upload.ReuploadCdnFile(
|
||||
file_token=r.file_token,
|
||||
request_token=r2.request_token
|
||||
)
|
||||
if isinstance(r2, types.upload.CdnFileReuploadNeeded):
|
||||
try:
|
||||
session.send(
|
||||
functions.upload.ReuploadCdnFile(
|
||||
file_token=r.file_token,
|
||||
request_token=r2.request_token
|
||||
)
|
||||
except VolumeLocNotFound:
|
||||
break
|
||||
else:
|
||||
continue
|
||||
|
||||
chunk = r2.bytes
|
||||
|
||||
# https://core.telegram.org/cdn#decrypting-files
|
||||
decrypted_chunk = AES.ctr256_decrypt(
|
||||
chunk,
|
||||
r.encryption_key,
|
||||
bytearray(
|
||||
r.encryption_iv[:-4]
|
||||
+ (offset // 16).to_bytes(4, "big")
|
||||
)
|
||||
)
|
||||
|
||||
hashes = session.send(
|
||||
functions.upload.GetCdnFileHashes(
|
||||
file_token=r.file_token,
|
||||
offset=offset
|
||||
)
|
||||
)
|
||||
|
||||
# https://core.telegram.org/cdn#verifying-files
|
||||
for i, h in enumerate(hashes):
|
||||
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
|
||||
assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i)
|
||||
|
||||
f.write(decrypted_chunk)
|
||||
|
||||
offset += limit
|
||||
|
||||
if progress:
|
||||
progress(
|
||||
min(offset, file_size)
|
||||
if file_size != 0
|
||||
else offset,
|
||||
file_size,
|
||||
*progress_args
|
||||
)
|
||||
|
||||
if len(chunk) < limit:
|
||||
except VolumeLocNotFound:
|
||||
break
|
||||
else:
|
||||
continue
|
||||
|
||||
chunk = r2.bytes
|
||||
|
||||
# https://core.telegram.org/cdn#decrypting-files
|
||||
decrypted_chunk = AES.ctr256_decrypt(
|
||||
chunk,
|
||||
r.encryption_key,
|
||||
bytearray(
|
||||
r.encryption_iv[:-4]
|
||||
+ (offset // 16).to_bytes(4, "big")
|
||||
)
|
||||
)
|
||||
|
||||
hashes = session.send(
|
||||
functions.upload.GetCdnFileHashes(
|
||||
file_token=r.file_token,
|
||||
offset=offset
|
||||
)
|
||||
)
|
||||
|
||||
# https://core.telegram.org/cdn#verifying-files
|
||||
for i, h in enumerate(hashes):
|
||||
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
|
||||
assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i)
|
||||
|
||||
f.write(decrypted_chunk)
|
||||
|
||||
offset += limit
|
||||
|
||||
if progress:
|
||||
progress(
|
||||
|
||||
min(offset, file_size)
|
||||
if file_size != 0
|
||||
else offset,
|
||||
file_size,
|
||||
*progress_args
|
||||
)
|
||||
|
||||
if len(chunk) < limit:
|
||||
break
|
||||
except Exception as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
@@ -2074,7 +2082,8 @@ class Client(Methods, BaseClient):
|
||||
log.error(e, exc_info=True)
|
||||
|
||||
try:
|
||||
os.remove(file_name)
|
||||
if out:
|
||||
os.remove(file_name)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
Reference in New Issue
Block a user