diff --git a/pyrdp/mitm/RDPMITM.py b/pyrdp/mitm/RDPMITM.py index 07d1437dc..c7fa3d42a 100644 --- a/pyrdp/mitm/RDPMITM.py +++ b/pyrdp/mitm/RDPMITM.py @@ -7,7 +7,9 @@ import asyncio import datetime import typing +import socket +from OpenSSL import SSL, crypto from twisted.internet import reactor from twisted.internet.protocol import Protocol @@ -218,7 +220,22 @@ async def connectToServer(self): self.log.error("Failed to connect to recording host: timeout expired") def doClientTls(self): - cert = self.server.tcp.transport.getPeerCertificate() + if self.state.isRedirected(): + self.log.info( + "Fetching certificate of the original host %(host)s:%(port)d because of NLA redirection", + { + "host": self.state.config.targetHost, + "port": self.state.config.targetPort, + }, + ) + # Use context from pyrdp + context = ClientTLSContext().getContext() + connection = SSL.Connection(context, socket.socket(socket.AF_INET, socket.SOCK_STREAM)) + connection.connect((self.state.config.targetHost, self.state.config.targetPort)) + connection.do_handshake() + cert = connection.get_peer_certificate() + else: + cert = self.server.tcp.transport.getPeerCertificate() if not cert: # Wait for server certificate reactor.callLater(1, self.doClientTls)