2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-09-07 17:55:24 +00:00
This commit is contained in:
Dan
2020-07-09 00:20:46 +02:00
parent 6b2d6ffacf
commit 4a8e6fb855
12 changed files with 216 additions and 544 deletions

View File

@@ -15,7 +15,7 @@
#
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import io
import logging
import math
import os
@@ -1231,9 +1231,9 @@ class Client(Methods, BaseClient):
temp_file_path = ""
final_file_path = ""
path = [None]
try:
data, done, progress, progress_args, out, path, to_file = packet
data, directory, file_name, done, progress, progress_args, path = packet
temp_file_path = self.get_file(
media_type=data.media_type,
@@ -1250,15 +1250,13 @@ class Client(Methods, BaseClient):
file_size=data.file_size,
is_big=data.is_big,
progress=progress,
progress_args=progress_args,
out=out
progress_args=progress_args
)
if to_file:
final_file_path = out.name
else:
final_file_path = ''
if to_file:
out.close()
if temp_file_path:
final_file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name)))
os.makedirs(directory, exist_ok=True)
shutil.move(temp_file_path, final_file_path)
except Exception as e:
log.error(e, exc_info=True)
@@ -1715,7 +1713,7 @@ class Client(Methods, BaseClient):
def save_file(
self,
path: Union[str, io.IOBase],
path: str,
file_id: int = None,
file_part: int = 0,
progress: callable = None,
@@ -1767,20 +1765,9 @@ class Client(Methods, BaseClient):
Raises:
RPCError: In case of a Telegram RPC error.
ValueError: if path is not str or file-like readable object
"""
part_size = 512 * 1024
if isinstance(path, str):
fp = open(path, 'rb')
filename = os.path.basename(path)
elif hasattr(path, 'write'):
fp = path
filename = fp.name
else:
raise ValueError("Invalid path passed! Pass file pointer or path to file")
fp.seek(0, os.SEEK_END)
file_size = fp.tell()
fp.seek(0)
file_size = os.path.getsize(path)
if file_size == 0:
raise ValueError("File size equals to 0 B")
@@ -1798,74 +1785,67 @@ class Client(Methods, BaseClient):
session.start()
try:
fp.seek(part_size * file_part)
with open(path, "rb") as f:
f.seek(part_size * file_part)
while True:
chunk = fp.read(part_size)
while True:
chunk = f.read(part_size)
if not chunk:
if not is_big:
md5_sum = "".join([hex(i)[2:].zfill(2) for i in md5_sum.digest()])
break
for _ in range(3):
if is_big:
rpc = functions.upload.SaveBigFilePart(
file_id=file_id,
file_part=file_part,
file_total_parts=file_total_parts,
bytes=chunk
)
else:
rpc = functions.upload.SaveFilePart(
file_id=file_id,
file_part=file_part,
bytes=chunk
)
if session.send(rpc):
if not chunk:
if not is_big:
md5_sum = "".join([hex(i)[2:].zfill(2) for i in md5_sum.digest()])
break
else:
raise AssertionError("Telegram didn't accept chunk #{} of {}".format(file_part, path))
if is_missing_part:
return
for _ in range(3):
if is_big:
rpc = functions.upload.SaveBigFilePart(
file_id=file_id,
file_part=file_part,
file_total_parts=file_total_parts,
bytes=chunk
)
else:
rpc = functions.upload.SaveFilePart(
file_id=file_id,
file_part=file_part,
bytes=chunk
)
if not is_big:
md5_sum.update(chunk)
if session.send(rpc):
break
else:
raise AssertionError("Telegram didn't accept chunk #{} of {}".format(file_part, path))
file_part += 1
if is_missing_part:
return
if progress:
progress(min(file_part * part_size, file_size), file_size, *progress_args)
if not is_big:
md5_sum.update(chunk)
file_part += 1
if progress:
progress(min(file_part * part_size, file_size), file_size, *progress_args)
except Client.StopTransmission:
if isinstance(path, str):
fp.close()
raise
except Exception as e:
if isinstance(path, str):
fp.close()
log.error(e, exc_info=True)
else:
if isinstance(path, str):
fp.close()
if is_big:
return types.InputFileBig(
id=file_id,
parts=file_total_parts,
name=filename,
name=os.path.basename(path),
)
else:
return types.InputFile(
id=file_id,
parts=file_total_parts,
name=filename,
name=os.path.basename(path),
md5_checksum=md5_sum
)
finally:
if isinstance(path, str):
fp.close()
session.stop()
def get_file(
@@ -1884,8 +1864,7 @@ class Client(Methods, BaseClient):
file_size: int,
is_big: bool,
progress: callable,
progress_args: tuple = (),
out: io.IOBase = None
progress_args: tuple = ()
) -> str:
with self.media_sessions_lock:
session = self.media_sessions.get(dc_id, None)
@@ -1971,10 +1950,7 @@ class Client(Methods, BaseClient):
limit = 1024 * 1024
offset = 0
file_name = ""
if not out:
f = tempfile.NamedTemporaryFile("wb", delete=False)
else:
f = out
try:
r = session.send(
functions.upload.GetFile(
@@ -1985,37 +1961,36 @@ class Client(Methods, BaseClient):
)
if isinstance(r, types.upload.File):
if hasattr(f, "name"):
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
file_name = f.name
while True:
chunk = r.bytes
while True:
chunk = r.bytes
if not chunk:
break
if not chunk:
break
f.write(chunk)
f.write(chunk)
offset += limit
offset += limit
if progress:
progress(
if progress:
progress(
min(offset, file_size)
if file_size != 0
else offset,
file_size,
*progress_args
)
min(offset, file_size)
if file_size != 0
else offset,
file_size,
*progress_args
r = session.send(
functions.upload.GetFile(
location=location,
offset=offset,
limit=limit
)
)
r = session.send(
functions.upload.GetFile(
location=location,
offset=offset,
limit=limit
)
)
elif isinstance(r, types.upload.FileCdnRedirect):
with self.media_sessions_lock:
cdn_session = self.media_sessions.get(r.dc_id, None)
@@ -2028,71 +2003,70 @@ class Client(Methods, BaseClient):
self.media_sessions[r.dc_id] = cdn_session
try:
if hasattr(f, "name"):
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
file_name = f.name
while True:
r2 = cdn_session.send(
functions.upload.GetCdnFile(
file_token=r.file_token,
offset=offset,
limit=limit
)
)
if isinstance(r2, types.upload.CdnFileReuploadNeeded):
try:
session.send(
functions.upload.ReuploadCdnFile(
file_token=r.file_token,
request_token=r2.request_token
)
while True:
r2 = cdn_session.send(
functions.upload.GetCdnFile(
file_token=r.file_token,
offset=offset,
limit=limit
)
except VolumeLocNotFound:
)
if isinstance(r2, types.upload.CdnFileReuploadNeeded):
try:
session.send(
functions.upload.ReuploadCdnFile(
file_token=r.file_token,
request_token=r2.request_token
)
)
except VolumeLocNotFound:
break
else:
continue
chunk = r2.bytes
# https://core.telegram.org/cdn#decrypting-files
decrypted_chunk = AES.ctr256_decrypt(
chunk,
r.encryption_key,
bytearray(
r.encryption_iv[:-4]
+ (offset // 16).to_bytes(4, "big")
)
)
hashes = session.send(
functions.upload.GetCdnFileHashes(
file_token=r.file_token,
offset=offset
)
)
# https://core.telegram.org/cdn#verifying-files
for i, h in enumerate(hashes):
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i)
f.write(decrypted_chunk)
offset += limit
if progress:
progress(
min(offset, file_size)
if file_size != 0
else offset,
file_size,
*progress_args
)
if len(chunk) < limit:
break
else:
continue
chunk = r2.bytes
# https://core.telegram.org/cdn#decrypting-files
decrypted_chunk = AES.ctr256_decrypt(
chunk,
r.encryption_key,
bytearray(
r.encryption_iv[:-4]
+ (offset // 16).to_bytes(4, "big")
)
)
hashes = session.send(
functions.upload.GetCdnFileHashes(
file_token=r.file_token,
offset=offset
)
)
# https://core.telegram.org/cdn#verifying-files
for i, h in enumerate(hashes):
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i)
f.write(decrypted_chunk)
offset += limit
if progress:
progress(
min(offset, file_size)
if file_size != 0
else offset,
file_size,
*progress_args
)
if len(chunk) < limit:
break
except Exception as e:
raise e
except Exception as e:
@@ -2100,8 +2074,7 @@ class Client(Methods, BaseClient):
log.error(e, exc_info=True)
try:
if out:
os.remove(file_name)
os.remove(file_name)
except OSError:
pass