44# This module is part of asyncpg and is released under
55# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
66
7+ from __future__ import annotations
78
89import asyncio
910import collections
11+ from collections .abc import Callable
1012import enum
1113import functools
1214import getpass
@@ -764,14 +766,21 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
764766
765767
766768class TLSUpgradeProto (asyncio .Protocol ):
767- def __init__ (self , loop , host , port , ssl_context , ssl_is_advisory ):
769+ def __init__ (
770+ self ,
771+ loop : asyncio .AbstractEventLoop ,
772+ host : str ,
773+ port : int ,
774+ ssl_context : ssl_module .SSLContext ,
775+ ssl_is_advisory : bool ,
776+ ) -> None :
768777 self .on_data = _create_future (loop )
769778 self .host = host
770779 self .port = port
771780 self .ssl_context = ssl_context
772781 self .ssl_is_advisory = ssl_is_advisory
773782
774- def data_received (self , data ) :
783+ def data_received (self , data : bytes ) -> None :
775784 if data == b'S' :
776785 self .on_data .set_result (True )
777786 elif (self .ssl_is_advisory and
@@ -789,15 +798,30 @@ def data_received(self, data):
789798 'rejected SSL upgrade' .format (
790799 host = self .host , port = self .port )))
791800
792- def connection_lost (self , exc ) :
801+ def connection_lost (self , exc : typing . Optional [ Exception ]) -> None :
793802 if not self .on_data .done ():
794803 if exc is None :
795804 exc = ConnectionError ('unexpected connection_lost() call' )
796805 self .on_data .set_exception (exc )
797806
798807
799- async def _create_ssl_connection (protocol_factory , host , port , * ,
800- loop , ssl_context , ssl_is_advisory = False ):
808+ _ProctolFactoryR = typing .TypeVar (
809+ "_ProctolFactoryR" , bound = asyncio .protocols .Protocol
810+ )
811+
812+
813+ async def _create_ssl_connection (
814+ # TODO: The return type is a specific combination of subclasses of
815+ # asyncio.protocols.Protocol that we can't express. For now, having the
816+ # return type be dependent on signature of the factory is an improvement
817+ protocol_factory : Callable [[], _ProctolFactoryR ],
818+ host : str ,
819+ port : int ,
820+ * ,
821+ loop : asyncio .AbstractEventLoop ,
822+ ssl_context : ssl_module .SSLContext ,
823+ ssl_is_advisory : bool = False ,
824+ ) -> typing .Tuple [asyncio .Transport , _ProctolFactoryR ]:
801825
802826 tr , pr = await loop .create_connection (
803827 lambda : TLSUpgradeProto (loop , host , port ,
@@ -817,6 +841,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
817841 try :
818842 new_tr = await loop .start_tls (
819843 tr , pr , ssl_context , server_hostname = host )
844+ assert new_tr is not None
820845 except (Exception , asyncio .CancelledError ):
821846 tr .close ()
822847 raise
0 commit comments