|
@@ -11,15 +11,23 @@
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
+import logging
|
|
|
+from typing import Callable
|
|
|
|
|
|
from OpenSSL import SSL
|
|
|
from twisted.internet import ssl
|
|
|
from twisted.internet.abstract import isIPAddress, isIPv6Address
|
|
|
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
|
|
|
+from twisted.protocols.tls import TLSMemoryBIOProtocol
|
|
|
+from twisted.python.failure import Failure
|
|
|
from zope.interface import implementer
|
|
|
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
|
|
|
-def _tolerateErrors(wrapped):
|
|
|
+F = Callable[[SSL.Connection, int, int], None]
|
|
|
+
|
|
|
+
|
|
|
+def _tolerateErrors(wrapped: F) -> F:
|
|
|
"""
|
|
|
Wrap up an info_callback for pyOpenSSL so that if something goes wrong
|
|
|
the error is immediately logged and the connection is dropped if possible.
|
|
@@ -27,10 +35,10 @@ def _tolerateErrors(wrapped):
|
|
|
documentation, see the twisted documentation.
|
|
|
"""
|
|
|
|
|
|
- def infoCallback(connection, where, ret):
|
|
|
+ def infoCallback(connection: SSL.Connection, where: int, ret: int) -> None:
|
|
|
try:
|
|
|
return wrapped(connection, where, ret)
|
|
|
- except: # noqa: E722, taken from the twisted implementation
|
|
|
+ except BaseException:
|
|
|
f = Failure()
|
|
|
logger.exception("Error during info_callback")
|
|
|
connection.get_app_data().failVerification(f)
|
|
@@ -38,7 +46,7 @@ def _tolerateErrors(wrapped):
|
|
|
return infoCallback
|
|
|
|
|
|
|
|
|
-def _idnaBytes(text):
|
|
|
+def _idnaBytes(text: str) -> bytes:
|
|
|
"""
|
|
|
Convert some text typed by a human into some ASCII bytes. This is a
|
|
|
copy of twisted.internet._idna._idnaBytes. For documentation, see the
|
|
@@ -60,7 +68,7 @@ class ClientTLSOptions:
|
|
|
verification left out. For documentation, see the twisted documentation.
|
|
|
"""
|
|
|
|
|
|
- def __init__(self, hostname, ctx):
|
|
|
+ def __init__(self, hostname: str, ctx: SSL.Context):
|
|
|
self._ctx = ctx
|
|
|
|
|
|
if isIPAddress(hostname) or isIPv6Address(hostname):
|
|
@@ -72,13 +80,17 @@ class ClientTLSOptions:
|
|
|
|
|
|
ctx.set_info_callback(_tolerateErrors(self._identityVerifyingInfoCallback))
|
|
|
|
|
|
- def clientConnectionForTLS(self, tlsProtocol):
|
|
|
+ def clientConnectionForTLS(
|
|
|
+ self, tlsProtocol: TLSMemoryBIOProtocol
|
|
|
+ ) -> SSL.Connection:
|
|
|
context = self._ctx
|
|
|
connection = SSL.Connection(context, None)
|
|
|
connection.set_app_data(tlsProtocol)
|
|
|
return connection
|
|
|
|
|
|
- def _identityVerifyingInfoCallback(self, connection, where, ret):
|
|
|
+ def _identityVerifyingInfoCallback(
|
|
|
+ self, connection: SSL.Connection, where: int, ret: int
|
|
|
+ ) -> None:
|
|
|
# Literal IPv4 and IPv6 addresses are not permitted
|
|
|
# as host names according to the RFCs
|
|
|
if where & SSL.SSL_CB_HANDSHAKE_START and self._sendSNI:
|
|
@@ -95,6 +107,6 @@ class ClientTLSOptionsFactory:
|
|
|
else:
|
|
|
self._options = ssl.CertificateOptions()
|
|
|
|
|
|
- def get_options(self, host):
|
|
|
+ def get_options(self, host: str) -> ClientTLSOptions:
|
|
|
# Use _makeContext so that we get a fresh OpenSSL CTX each time.
|
|
|
return ClientTLSOptions(host, self._options._makeContext())
|