2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-23 18:37:26 +00:00

Force keyword arguments for all TL types

This commit is contained in:
Dan 2019-03-16 16:50:40 +01:00
parent e0f1f6aaeb
commit 34b51b6481
28 changed files with 133 additions and 108 deletions

View File

@ -287,9 +287,11 @@ def start():
sorted_args = sort_args(c.args) sorted_args = sort_args(c.args)
arguments = ", " + ", ".join( arguments = (
[get_argument_type(i) for i in sorted_args if i != ("flags", "#")] ", "
) if c.args else "" + ("*, " if c.args else "")
+ (", ".join([get_argument_type(i) for i in sorted_args if i != ("flags", "#")]) if c.args else "")
)
fields = "\n ".join( fields = "\n ".join(
["self.{0} = {0} # {1}".format(i[0], i[1]) for i in c.args if i != ("flags", "#")] ["self.{0} = {0} # {1}".format(i[0], i[1]) for i in c.args if i != ("flags", "#")]
@ -456,7 +458,9 @@ def start():
fields=fields, fields=fields,
read_types=read_types, read_types=read_types,
write_types=write_types, write_types=write_types,
return_arguments=", ".join([i[0] for i in sorted_args if i != ("flags", "#")]), return_arguments=", ".join(
["{0}={0}".format(i[0]) for i in sorted_args if i != ("flags", "#")]
),
slots=", ".join(['"{}"'.format(i[0]) for i in sorted_args if i != ("flags", "#")]), slots=", ".join(['"{}"'.format(i[0]) for i in sorted_args if i != ("flags", "#")]),
qualname="{}{}".format("{}.".format(c.namespace) if c.namespace else "", c.name) qualname="{}{}".format("{}.".format(c.namespace) if c.namespace else "", c.name)
) )

View File

@ -627,9 +627,9 @@ class Client(Methods, BaseClient):
try: try:
r = self.send( r = self.send(
functions.auth.SignIn( functions.auth.SignIn(
self.phone_number, phone_number=self.phone_number,
phone_code_hash, phone_code_hash=phone_code_hash,
self.phone_code phone_code=self.phone_code
) )
) )
except PhoneNumberUnoccupied: except PhoneNumberUnoccupied:
@ -640,11 +640,11 @@ class Client(Methods, BaseClient):
try: try:
r = self.send( r = self.send(
functions.auth.SignUp( functions.auth.SignUp(
self.phone_number, phone_number=self.phone_number,
phone_code_hash, phone_code_hash=phone_code_hash,
self.phone_code, phone_code=self.phone_code,
self.first_name, first_name=self.first_name,
self.last_name last_name=self.last_name
) )
) )
except PhoneNumberOccupied: except PhoneNumberOccupied:
@ -738,7 +738,11 @@ class Client(Methods, BaseClient):
break break
if terms_of_service: if terms_of_service:
assert self.send(functions.help.AcceptTermsOfService(terms_of_service.id)) assert self.send(
functions.help.AcceptTermsOfService(
id=terms_of_service.id
)
)
self.password = None self.password = None
self.user_id = r.user.id self.user_id = r.user.id
@ -1036,10 +1040,10 @@ class Client(Methods, BaseClient):
raise ConnectionError("Client has not been started") raise ConnectionError("Client has not been started")
if self.no_updates: if self.no_updates:
data = functions.InvokeWithoutUpdates(data) data = functions.InvokeWithoutUpdates(query=data)
if self.takeout_id: if self.takeout_id:
data = functions.InvokeWithTakeout(self.takeout_id, data) data = functions.InvokeWithTakeout(takeout_id=self.takeout_id, query=data)
r = self.session.send(data, retries, timeout) r = self.session.send(data, retries, timeout)
@ -1353,7 +1357,7 @@ class Client(Methods, BaseClient):
self.fetch_peers( self.fetch_peers(
self.send( self.send(
functions.users.GetUsers( functions.users.GetUsers(
id=[types.InputUser(peer_id, 0)] id=[types.InputUser(user_id=peer_id, access_hash=0)]
) )
) )
) )
@ -1361,7 +1365,7 @@ class Client(Methods, BaseClient):
if str(peer_id).startswith("-100"): if str(peer_id).startswith("-100"):
self.send( self.send(
functions.channels.GetChannels( functions.channels.GetChannels(
id=[types.InputChannel(int(str(peer_id)[4:]), 0)] id=[types.InputChannel(channel_id=int(str(peer_id)[4:]), access_hash=0)]
) )
) )
else: else:
@ -1668,8 +1672,8 @@ class Client(Methods, BaseClient):
hashes = session.send( hashes = session.send(
functions.upload.GetCdnFileHashes( functions.upload.GetCdnFileHashes(
r.file_token, file_token=r.file_token,
offset offset=offset
) )
) )

View File

@ -67,10 +67,10 @@ def get_peer_id(input_peer) -> int:
def get_input_peer(peer_id: int, access_hash: int): def get_input_peer(peer_id: int, access_hash: int):
return ( return (
types.InputPeerUser(peer_id, access_hash) if peer_id > 0 types.InputPeerUser(user_id=peer_id, access_hash=access_hash) if peer_id > 0
else types.InputPeerChannel(int(str(peer_id)[4:]), access_hash) else types.InputPeerChannel(channel_id=int(str(peer_id)[4:]), access_hash=access_hash)
if (str(peer_id).startswith("-100") and access_hash) if (str(peer_id).startswith("-100") and access_hash)
else types.InputPeerChat(-peer_id) else types.InputPeerChat(chat_id=-peer_id)
) )

View File

@ -45,7 +45,7 @@ class ExportChatInviteLink(BaseClient):
if isinstance(peer, types.InputPeerChat): if isinstance(peer, types.InputPeerChat):
return self.send( return self.send(
functions.messages.ExportChatInvite( functions.messages.ExportChatInvite(
chat_id=peer.chat_id peer=peer.chat_id
) )
).link ).link
elif isinstance(peer, types.InputPeerChannel): elif isinstance(peer, types.InputPeerChannel):

View File

@ -67,10 +67,10 @@ class GetChat(BaseClient):
peer = self.resolve_peer(chat_id) peer = self.resolve_peer(chat_id)
if isinstance(peer, types.InputPeerChannel): if isinstance(peer, types.InputPeerChannel):
r = self.send(functions.channels.GetFullChannel(peer)) r = self.send(functions.channels.GetFullChannel(channel=peer))
elif isinstance(peer, (types.InputPeerUser, types.InputPeerSelf)): elif isinstance(peer, (types.InputPeerUser, types.InputPeerSelf)):
r = self.send(functions.users.GetFullUser(peer)) r = self.send(functions.users.GetFullUser(id=peer))
else: else:
r = self.send(functions.messages.GetFullChat(peer.chat_id)) r = self.send(functions.messages.GetFullChat(chat_id=peer.chat_id))
return pyrogram.Chat._parse_full(self, r) return pyrogram.Chat._parse_full(self, r)

View File

@ -92,7 +92,7 @@ class GetChatMembers(BaseClient):
self, self,
self.send( self.send(
functions.messages.GetFullChat( functions.messages.GetFullChat(
peer.chat_id chat_id=peer.chat_id
) )
) )
) )

View File

@ -39,7 +39,7 @@ class GetContacts(BaseClient):
""" """
while True: while True:
try: try:
contacts = self.send(functions.contacts.GetContacts(0)) contacts = self.send(functions.contacts.GetContacts(hash=0))
except FloodWait as e: except FloodWait as e:
log.warning("get_contacts flood: waiting {} seconds".format(e.x)) log.warning("get_contacts flood: waiting {} seconds".format(e.x))
time.sleep(e.x) time.sleep(e.x)

View File

@ -131,7 +131,9 @@ class EditMessageMedia(BaseClient):
w=media.width, w=media.width,
h=media.height h=media.height
), ),
types.DocumentAttributeFilename(os.path.basename(media.media)) types.DocumentAttributeFilename(
file_name=os.path.basename(media.media)
)
] ]
) )
) )
@ -187,7 +189,9 @@ class EditMessageMedia(BaseClient):
performer=media.performer, performer=media.performer,
title=media.title title=media.title
), ),
types.DocumentAttributeFilename(os.path.basename(media.media)) types.DocumentAttributeFilename(
file_name=os.path.basename(media.media)
)
] ]
) )
) )
@ -244,7 +248,9 @@ class EditMessageMedia(BaseClient):
w=media.width, w=media.width,
h=media.height h=media.height
), ),
types.DocumentAttributeFilename(os.path.basename(media.media)), types.DocumentAttributeFilename(
file_name=os.path.basename(media.media)
),
types.DocumentAttributeAnimated() types.DocumentAttributeAnimated()
] ]
) )
@ -296,7 +302,9 @@ class EditMessageMedia(BaseClient):
thumb=None if media.thumb is None else self.save_file(media.thumb), thumb=None if media.thumb is None else self.save_file(media.thumb),
file=self.save_file(media.media), file=self.save_file(media.media),
attributes=[ attributes=[
types.DocumentAttributeFilename(os.path.basename(media.media)) types.DocumentAttributeFilename(
file_name=os.path.basename(media.media)
)
] ]
) )
) )

View File

@ -76,7 +76,7 @@ class GetMessages(BaseClient):
is_iterable = not isinstance(ids, int) is_iterable = not isinstance(ids, int)
ids = list(ids) if is_iterable else [ids] ids = list(ids) if is_iterable else [ids]
ids = [ids_type(i) for i in ids] ids = [ids_type(id=i) for i in ids]
if isinstance(peer, types.InputPeerChannel): if isinstance(peer, types.InputPeerChannel):
rpc = functions.channels.GetMessages(channel=peer, id=ids) rpc = functions.channels.GetMessages(channel=peer, id=ids)

View File

@ -141,7 +141,7 @@ class SendAnimation(BaseClient):
w=width, w=width,
h=height h=height
), ),
types.DocumentAttributeFilename(os.path.basename(animation)), types.DocumentAttributeFilename(file_name=os.path.basename(animation)),
types.DocumentAttributeAnimated() types.DocumentAttributeAnimated()
] ]
) )

View File

@ -142,7 +142,7 @@ class SendAudio(BaseClient):
performer=performer, performer=performer,
title=title title=title
), ),
types.DocumentAttributeFilename(os.path.basename(audio)) types.DocumentAttributeFilename(file_name=os.path.basename(audio))
] ]
) )
elif audio.startswith("http"): elif audio.startswith("http"):

View File

@ -123,7 +123,7 @@ class SendDocument(BaseClient):
file=file, file=file,
thumb=thumb, thumb=thumb,
attributes=[ attributes=[
types.DocumentAttributeFilename(os.path.basename(document)) types.DocumentAttributeFilename(file_name=os.path.basename(document))
] ]
) )
elif document.startswith("http"): elif document.startswith("http"):

View File

@ -69,9 +69,9 @@ class SendLocation(BaseClient):
functions.messages.SendMedia( functions.messages.SendMedia(
peer=self.resolve_peer(chat_id), peer=self.resolve_peer(chat_id),
media=types.InputMediaGeoPoint( media=types.InputMediaGeoPoint(
types.InputGeoPoint( geo_point=types.InputGeoPoint(
latitude, lat=latitude,
longitude long=longitude
) )
), ),
message="", message="",

View File

@ -137,7 +137,7 @@ class SendMediaGroup(BaseClient):
w=i.width, w=i.width,
h=i.height h=i.height
), ),
types.DocumentAttributeFilename(os.path.basename(i.media)) types.DocumentAttributeFilename(file_name=os.path.basename(i.media))
] ]
) )
) )

View File

@ -103,7 +103,7 @@ class SendSticker(BaseClient):
mime_type="image/webp", mime_type="image/webp",
file=file, file=file,
attributes=[ attributes=[
types.DocumentAttributeFilename(os.path.basename(sticker)) types.DocumentAttributeFilename(file_name=os.path.basename(sticker))
] ]
) )
elif sticker.startswith("http"): elif sticker.startswith("http"):

View File

@ -145,7 +145,7 @@ class SendVideo(BaseClient):
w=width, w=width,
h=height h=height
), ),
types.DocumentAttributeFilename(os.path.basename(video)) types.DocumentAttributeFilename(file_name=os.path.basename(video))
] ]
) )
elif video.startswith("http"): elif video.startswith("http"):

View File

@ -101,4 +101,4 @@ def compute_check(r: types.account.Password, password: str) -> types.InputCheckP
+ K_bytes + K_bytes
) )
return types.InputCheckPasswordSRP(srp_id, A_bytes, M1_bytes) return types.InputCheckPasswordSRP(srp_id=srp_id, A=A_bytes, M1=M1_bytes)

View File

@ -35,7 +35,7 @@ class GetMe(BaseClient):
self, self,
self.send( self.send(
functions.users.GetFullUser( functions.users.GetFullUser(
types.InputPeerSelf() id=types.InputPeerSelf()
) )
).user ).user
) )

View File

@ -43,7 +43,7 @@ class SetUserProfilePhoto(BaseClient):
return bool( return bool(
self.send( self.send(
functions.photos.UploadProfilePhoto( functions.photos.UploadProfilePhoto(
self.save_file(photo) file=self.save_file(photo)
) )
) )
) )

View File

@ -55,20 +55,20 @@ class HTML:
input_user = self.peers_by_id.get(user_id, None) input_user = self.peers_by_id.get(user_id, None)
entity = ( entity = (
Mention(start, len(body), input_user) Mention(offset=start, length=len(body), user_id=input_user)
if input_user else MentionInvalid(start, len(body), user_id) if input_user else MentionInvalid(offset=start, length=len(body), user_id=user_id)
) )
else: else:
entity = Url(start, len(body), url) entity = Url(offset=start, length=len(body), url=url)
else: else:
if style == "b" or style == "strong": if style == "b" or style == "strong":
entity = Bold(start, len(body)) entity = Bold(offset=start, length=len(body))
elif style == "i" or style == "em": elif style == "i" or style == "em":
entity = Italic(start, len(body)) entity = Italic(offset=start, length=len(body))
elif style == "code": elif style == "code":
entity = Code(start, len(body)) entity = Code(offset=start, length=len(body))
elif style == "pre": elif style == "pre":
entity = Pre(start, len(body), "") entity = Pre(offset=start, length=len(body), language="")
else: else:
continue continue

View File

@ -72,24 +72,24 @@ class Markdown:
input_user = self.peers_by_id.get(user_id, None) input_user = self.peers_by_id.get(user_id, None)
entity = ( entity = (
Mention(start, len(text), input_user) Mention(offset=start, length=len(text), user_id=input_user)
if input_user if input_user
else MentionInvalid(start, len(text), user_id) else MentionInvalid(offset=start, length=len(text), user_id=user_id)
) )
else: else:
entity = Url(start, len(text), url) entity = Url(offset=start, length=len(text), url=url)
body = text body = text
offset += len(url) + 4 offset += len(url) + 4
else: else:
if style == self.BOLD_DELIMITER: if style == self.BOLD_DELIMITER:
entity = Bold(start, len(body)) entity = Bold(offset=start, length=len(body))
elif style == self.ITALIC_DELIMITER: elif style == self.ITALIC_DELIMITER:
entity = Italic(start, len(body)) entity = Italic(offset=start, length=len(body))
elif style == self.CODE_DELIMITER: elif style == self.CODE_DELIMITER:
entity = Code(start, len(body)) entity = Code(offset=start, length=len(body))
elif style == self.PRE_DELIMITER: elif style == self.PRE_DELIMITER:
entity = Pre(start, len(body), "") entity = Pre(offset=start, length=len(body), language="")
else: else:
continue continue

View File

@ -111,16 +111,20 @@ class InlineKeyboardButton(PyrogramType):
def write(self): def write(self):
if self.callback_data: if self.callback_data:
return KeyboardButtonCallback(self.text, self.callback_data) return KeyboardButtonCallback(text=self.text, data=self.callback_data)
if self.url: if self.url:
return KeyboardButtonUrl(self.text, self.url) return KeyboardButtonUrl(text=self.text, url=self.url)
if self.switch_inline_query: if self.switch_inline_query:
return KeyboardButtonSwitchInline(self.text, self.switch_inline_query) return KeyboardButtonSwitchInline(text=self.text, query=self.switch_inline_query)
if self.switch_inline_query_current_chat: if self.switch_inline_query_current_chat:
return KeyboardButtonSwitchInline(self.text, self.switch_inline_query_current_chat, same_peer=True) return KeyboardButtonSwitchInline(
text=self.text,
query=self.switch_inline_query_current_chat,
same_peer=True
)
if self.callback_game: if self.callback_game:
return KeyboardButtonGame(self.text) return KeyboardButtonGame(text=self.text)

View File

@ -59,7 +59,7 @@ class InlineKeyboardMarkup(PyrogramType):
def write(self): def write(self):
return ReplyInlineMarkup( return ReplyInlineMarkup(
[KeyboardButtonRow( rows=[KeyboardButtonRow(
[j.write() for j in i] buttons=[j.write() for j in i]
) for i in self.inline_keyboard] ) for i in self.inline_keyboard]
) )

View File

@ -75,8 +75,8 @@ class KeyboardButton(PyrogramType):
# TODO: Enforce optional args mutual exclusiveness # TODO: Enforce optional args mutual exclusiveness
if self.request_contact: if self.request_contact:
return KeyboardButtonRequestPhone(self.text) return KeyboardButtonRequestPhone(text=self.text)
elif self.request_location: elif self.request_location:
return KeyboardButtonRequestGeoLocation(self.text) return KeyboardButtonRequestGeoLocation(text=self.text)
else: else:
return RawKeyboardButton(self.text) return RawKeyboardButton(text=self.text)

View File

@ -87,9 +87,11 @@ class ReplyKeyboardMarkup(PyrogramType):
def write(self): def write(self):
return RawReplyKeyboardMarkup( return RawReplyKeyboardMarkup(
rows=[KeyboardButtonRow( rows=[KeyboardButtonRow(
[KeyboardButton(j).write() buttons=[
KeyboardButton(j).write()
if isinstance(j, str) else j.write() if isinstance(j, str) else j.write()
for j in i] for j in i
]
) for i in self.keyboard], ) for i in self.keyboard],
resize=self.resize_keyboard or None, resize=self.resize_keyboard or None,
single_use=self.one_time_keyboard or None, single_use=self.one_time_keyboard or None,

View File

@ -103,7 +103,10 @@ class Sticker(PyrogramType):
try: try:
return send( return send(
functions.messages.GetStickerSet( functions.messages.GetStickerSet(
types.InputStickerSetID(*input_sticker_set_id) stickerset=types.InputStickerSetID(
id=input_sticker_set_id[0],
access_hash=input_sticker_set_id[1]
)
) )
).set.short_name ).set.short_name
except StickersetInvalid: except StickersetInvalid:

View File

@ -83,7 +83,7 @@ class Auth:
# Step 1; Step 2 # Step 1; Step 2
nonce = int.from_bytes(urandom(16), "little", signed=True) nonce = int.from_bytes(urandom(16), "little", signed=True)
log.debug("Send req_pq: {}".format(nonce)) log.debug("Send req_pq: {}".format(nonce))
res_pq = self.send(functions.ReqPqMulti(nonce)) res_pq = self.send(functions.ReqPqMulti(nonce=nonce))
log.debug("Got ResPq: {}".format(res_pq.server_nonce)) log.debug("Got ResPq: {}".format(res_pq.server_nonce))
log.debug("Server public key fingerprints: {}".format(res_pq.server_public_key_fingerprints)) log.debug("Server public key fingerprints: {}".format(res_pq.server_public_key_fingerprints))
@ -110,12 +110,12 @@ class Auth:
new_nonce = int.from_bytes(urandom(32), "little", signed=True) new_nonce = int.from_bytes(urandom(32), "little", signed=True)
data = types.PQInnerData( data = types.PQInnerData(
res_pq.pq, pq=res_pq.pq,
p.to_bytes(4, "big"), p=p.to_bytes(4, "big"),
q.to_bytes(4, "big"), q=q.to_bytes(4, "big"),
nonce, nonce=nonce,
server_nonce, server_nonce=server_nonce,
new_nonce, new_nonce=new_nonce,
).write() ).write()
sha = sha1(data).digest() sha = sha1(data).digest()
@ -129,12 +129,12 @@ class Auth:
log.debug("Send req_DH_params") log.debug("Send req_DH_params")
server_dh_params = self.send( server_dh_params = self.send(
functions.ReqDHParams( functions.ReqDHParams(
nonce, nonce=nonce,
server_nonce, server_nonce=server_nonce,
p.to_bytes(4, "big"), p=p.to_bytes(4, "big"),
q.to_bytes(4, "big"), q=q.to_bytes(4, "big"),
public_key_fingerprint, public_key_fingerprint=public_key_fingerprint,
encrypted_data encrypted_data=encrypted_data
) )
) )
@ -175,10 +175,10 @@ class Auth:
retry_id = 0 retry_id = 0
data = types.ClientDHInnerData( data = types.ClientDHInnerData(
nonce, nonce=nonce,
server_nonce, server_nonce=server_nonce,
retry_id, retry_id=retry_id,
g_b g_b=g_b
).write() ).write()
sha = sha1(data).digest() sha = sha1(data).digest()
@ -189,9 +189,9 @@ class Auth:
log.debug("Send set_client_DH_params") log.debug("Send set_client_DH_params")
set_client_dh_params_answer = self.send( set_client_dh_params_answer = self.send(
functions.SetClientDHParams( functions.SetClientDHParams(
nonce, nonce=nonce,
server_nonce, server_nonce=server_nonce,
encrypted_data encrypted_data=encrypted_data
) )
) )

View File

@ -134,11 +134,11 @@ class Session:
self.current_salt = FutureSalt( self.current_salt = FutureSalt(
0, 0, 0, 0,
self._send( self._send(
functions.Ping(0), functions.Ping(ping_id=0),
timeout=self.START_TIMEOUT timeout=self.START_TIMEOUT
).new_server_salt ).new_server_salt
) )
self.current_salt = self._send(functions.GetFutureSalts(1), timeout=self.START_TIMEOUT).salts[0] self.current_salt = self._send(functions.GetFutureSalts(num=1), timeout=self.START_TIMEOUT).salts[0]
self.next_salt_thread = Thread(target=self.next_salt, name="NextSaltThread") self.next_salt_thread = Thread(target=self.next_salt, name="NextSaltThread")
self.next_salt_thread.start() self.next_salt_thread.start()
@ -146,8 +146,8 @@ class Session:
if not self.is_cdn: if not self.is_cdn:
self._send( self._send(
functions.InvokeWithLayer( functions.InvokeWithLayer(
layer, layer=layer,
functions.InitConnection( query=functions.InitConnection(
api_id=self.client.api_id, api_id=self.client.api_id,
app_version=self.client.app_version, app_version=self.client.app_version,
device_model=self.client.device_model, device_model=self.client.device_model,
@ -314,7 +314,7 @@ class Session:
log.info("Send {} acks".format(len(self.pending_acks))) log.info("Send {} acks".format(len(self.pending_acks)))
try: try:
self._send(types.MsgsAck(list(self.pending_acks)), False) self._send(types.MsgsAck(msg_ids=list(self.pending_acks)), False)
except (OSError, TimeoutError): except (OSError, TimeoutError):
pass pass
else: else:
@ -335,7 +335,7 @@ class Session:
try: try:
self._send(functions.PingDelayDisconnect( self._send(functions.PingDelayDisconnect(
0, self.WAIT_TIMEOUT + 10 ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10
), False) ), False)
except (OSError, TimeoutError, Error): except (OSError, TimeoutError, Error):
pass pass
@ -365,7 +365,7 @@ class Session:
break break
try: try:
self.current_salt = self._send(functions.GetFutureSalts(1)).salts[0] self.current_salt = self._send(functions.GetFutureSalts(num=1)).salts[0]
except (OSError, TimeoutError, Error): except (OSError, TimeoutError, Error):
self.connection.close() self.connection.close()
break break