Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion kafka/net/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def data_received(self, data):
try:
(req_correlation_id, future, sent_time, _timeout_at) = self.in_flight_requests.popleft()
except IndexError:
return self.close(Errors.KafakConnectionError('Received response with no in-flight-requests!'))
return self.close(Errors.KafkaConnectionError('Received response with no in-flight-requests!'))

if req_correlation_id != resp_correlation_id:
return self.close(Errors.KafkaConnectionError('Received unrecognized correlation id'))
Expand Down
51 changes: 49 additions & 2 deletions kafka/net/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import logging
import random
import socket
import ssl
import time

from .inet import create_connection
from .connection import KafkaConnection
from .transport import KafkaTCPTransport
from .transport import KafkaSSLTransport, KafkaTCPTransport
import kafka.errors as Errors
from kafka.protocol.broker_version_data import BrokerVersionData
from kafka.future import Future
Expand All @@ -23,6 +24,14 @@ class KafkaConnectionManager:
'socket_options': [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)],
'max_in_flight_requests_per_connection': 5,
'connections_max_idle_ms': 9 * 60 * 1000,
'security_protocol': 'PLAINTEXT',
'ssl_context': None,
'ssl_check_hostname': True,
'ssl_cafile': None,
'ssl_certfile': None,
'ssl_keyfile': None,
'ssl_password': None,
'ssl_crlfile': None,
}
def __init__(self, net, cluster, **configs):
self.config = copy.copy(self.DEFAULT_CONFIG)
Expand Down Expand Up @@ -112,6 +121,31 @@ def close_idle_connections(self):
log.debug('next idle close check in %d secs', next_idle_at - time.monotonic())
self._net.call_at(next_idle_at, self.close_idle_connections)

@property
def ssl_enabled(self):
return self.config['security_protocol'] in ('SSL', 'SASL_SSL')

def _build_ssl_context(self):
if self.config['ssl_context'] is not None:
return self.config['ssl_context']
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
ctx.check_hostname = self.config['ssl_check_hostname']
if self.config['ssl_cafile']:
ctx.load_verify_locations(self.config['ssl_cafile'])
else:
ctx.load_default_certs()
if self.config['ssl_certfile']:
ctx.load_cert_chain(
certfile=self.config['ssl_certfile'],
keyfile=self.config['ssl_keyfile'],
password=self.config['ssl_password'],
)
if self.config['ssl_crlfile']:
ctx.load_verify_locations(crl=self.config['ssl_crlfile'])
ctx.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF
return ctx

async def _connect(self, node, conn):
conn.close_future.add_both(lambda _: self._conns.pop(node.node_id, None))
conn.close_future.add_errback(lambda _: self.cluster.request_update())
Expand All @@ -124,7 +158,20 @@ async def _connect(self, node, conn):
self.update_backoff(node.node_id)
return

conn.connection_made(KafkaTCPTransport(self._net, sock))
if self.ssl_enabled:
hostname = node.host if self.config['ssl_check_hostname'] else None
transport = KafkaSSLTransport(self._net, sock, self._build_ssl_context(), hostname)
else:
transport = KafkaTCPTransport(self._net, sock)

try:
await transport.handshake()
except Exception as e:
conn.connection_lost(Errors.KafkaConnectionError('Handshake failed: %s' % e))
self.update_backoff(node.node_id)
return

conn.connection_made(transport)

try:
await conn.init_future
Expand Down
72 changes: 70 additions & 2 deletions kafka/net/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import selectors
import socket
import ssl
import time

import kafka.errors as Errors
Expand Down Expand Up @@ -115,7 +116,7 @@ def _sock_recv(self):
else:
recvd.append(data)

except (BlockingIOError, InterruptedError): #, SSLWantReadError, SSLWantWriteError):
except (BlockingIOError, InterruptedError):
break
except BaseException as e:
log.exception('%s: Error receiving network data'
Expand Down Expand Up @@ -202,7 +203,7 @@ def _sock_send(self):
sent_bytes = self._sock.send(next_chunk)
total_bytes += sent_bytes
next_chunk = next_chunk[sent_bytes:]
except (BlockingIOError, InterruptedError): # SSLWantReadError, SSLWantWriteError):
except (BlockingIOError, InterruptedError):
self._write_buffer.appendleft(next_chunk)
return total_bytes, err
except BaseException as e:
Expand Down Expand Up @@ -316,5 +317,72 @@ def writeSequence(self, data):
"""
return self.writelines(data)

async def handshake(self):
pass

def __str__(self):
return ("<KafkaTCPTransport [%s:%d]" % self.getPeer()[0:2]) + (" (closed)>" if self._closed else ">")


class KafkaSSLTransport(KafkaTCPTransport):
def __init__(self, net, sock, ssl_context, server_hostname=None):
self._ssl_context = ssl_context
sock = ssl_context.wrap_socket(
sock, server_hostname=server_hostname, do_handshake_on_connect=False)
super().__init__(net, sock)

async def handshake(self):
while True:
try:
self._sock.do_handshake()
return
except ssl.SSLWantReadError:
await self._net.wait_read(self._sock)
except ssl.SSLWantWriteError:
await self._net.wait_write(self._sock)

def _sock_recv(self):
recvd = []
err = None
while True:
try:
data = self._sock.recv(4096)
if not data:
log.error('%s: socket disconnected', self)
err = Errors.KafkaConnectionError('socket disconnected')
break
else:
recvd.append(data)
except (BlockingIOError, InterruptedError,
ssl.SSLWantReadError, ssl.SSLWantWriteError):
break
except BaseException as e:
log.exception('%s: Error receiving network data'
' closing socket', self)
err = Errors.KafkaConnectionError(e)
break
recvd_data = b''.join(recvd)
return recvd_data, err

def _sock_send(self):
total_bytes = 0
err = None
while self._write_buffer:
next_chunk = self._write_buffer.popleft()
while next_chunk:
try:
sent_bytes = self._sock.send(next_chunk)
total_bytes += sent_bytes
next_chunk = next_chunk[sent_bytes:]
except (BlockingIOError, InterruptedError,
ssl.SSLWantReadError, ssl.SSLWantWriteError):
self._write_buffer.appendleft(next_chunk)
return total_bytes, err
except BaseException as e:
log.exception("%s: Error sending request data: %s", self, e)
err = Errors.KafkaConnectionError(e)
return total_bytes, err
return total_bytes, err

def __str__(self):
return ("<KafkaSSLTransport [%s:%d]" % self.getPeer()[0:2]) + (" (closed)>" if self._closed else ">")
Loading