Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 76 additions & 1 deletion kafka/net/connection.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import collections
import copy
import logging
import struct
import time

import kafka.errors as Errors
from kafka.future import Future
from kafka.protocol.metadata import ApiVersionsRequest
from kafka.protocol.sasl import SaslAuthenticateRequest, SaslHandshakeRequest, SaslBytesRequest
from kafka.protocol.broker_version_data import BrokerVersionData
from kafka.protocol.parser import KafkaProtocol
from kafka.sasl import get_sasl_mechanism
from kafka.version import __version__


Expand All @@ -21,6 +24,14 @@ class KafkaConnection:
'client_software_version': __version__,
'request_timeout_ms': 30000,
'max_in_flight_requests_per_connection': 5,
'security_protocol': 'PLAINTEXT',
'sasl_mechanism': None,
'sasl_plain_username': None,
'sasl_plain_password': None,
'sasl_kerberos_name': None,
'sasl_kerberos_service_name': 'kafka',
'sasl_kerberos_domain_name': None,
'sasl_oauth_token_provider': None,
}

def __init__(self, net, node_id=None, **configs):
Expand Down Expand Up @@ -336,11 +347,75 @@ async def _check_version(self):
for api_version in response.api_keys}
self.broker_version_data = BrokerVersionData(api_versions=api_versions)
log.info('%s: Broker version identified as %s', self, '.'.join(map(str, self.broker_version)))
self._init_complete()
if self.sasl_enabled:
await self._sasl_authenticate()
if self.initializing:
self._init_complete()
return

self.close(Errors.KafkaConnectionError('Unable to determine broker version.'))

@property
def sasl_enabled(self):
return self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL')

async def _sasl_authenticate(self):
# Step 1: SaslHandshake to negotiate mechanism
version = self.broker_version_data.api_version(SaslHandshakeRequest, max_version=1)
request = SaslHandshakeRequest[version](self.config['sasl_mechanism'])
try:
response = await self._send_request(request)
except Exception as exc:
self.close(Errors.KafkaConnectionError('SaslHandshake failed: %s' % exc))
return

error_type = Errors.for_code(response.error_code)
if error_type is not Errors.NoError:
self.close(error_type())
return

if self.config['sasl_mechanism'] not in response.mechanisms:
self.close(Errors.UnsupportedSaslMechanismError(
'Kafka broker does not support %s sasl mechanism. Enabled mechanisms: %s'
% (self.config['sasl_mechanism'], response.mechanisms)))
return

# Step 2: SASL authentication exchange
mechanism = get_sasl_mechanism(self.config['sasl_mechanism'])(
host=self.transport.getPeer()[0], **self.config)

while not mechanism.is_done():
token = mechanism.auth_bytes()
if version == 1:
auth_request = SaslAuthenticateRequest[0](token)
else:
auth_request = SaslBytesRequest(token)

try:
auth_response = await self._send_request(auth_request)
except Exception as exc:
self.close(Errors.KafkaConnectionError('SaslAuthenticate failed: %s' % exc))
return

error_type = Errors.for_code(auth_response.error_code)
if error_type is not Errors.NoError:
self.close(Errors.SaslAuthenticationFailedError(
'%s: %s' % (error_type.__name__, auth_response.error_message)))
return

# GSSAPI does not get a final recv in v0 unframed mode
if version == 0 and mechanism.is_done():
break

mechanism.receive(auth_response.auth_bytes)

if not mechanism.is_authenticated():
self.close(Errors.SaslAuthenticationFailedError(
'Failed to authenticate via SASL %s' % self.config['sasl_mechanism']))
return

log.info('%s: %s', self, mechanism.auth_details())

def _init_complete(self):
self.initializing = False
self.connected = True
Expand Down
7 changes: 7 additions & 0 deletions kafka/net/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ class KafkaConnectionManager:
'ssl_keyfile': None,
'ssl_password': None,
'ssl_crlfile': None,
'sasl_mechanism': None,
'sasl_plain_username': None,
'sasl_plain_password': None,
'sasl_kerberos_name': None,
'sasl_kerberos_service_name': 'kafka',
'sasl_kerberos_domain_name': None,
'sasl_oauth_token_provider': None,
}
def __init__(self, net, cluster, **configs):
self.config = copy.copy(self.DEFAULT_CONFIG)
Expand Down
38 changes: 38 additions & 0 deletions kafka/protocol/sasl.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import struct

from .api_message import ApiMessage


Expand All @@ -8,7 +10,43 @@ class SaslAuthenticateRequest(ApiMessage): pass
class SaslAuthenticateResponse(ApiMessage): pass


class SaslBytesRequest:
"""Request for raw SASL v0 exchange -- length-prefixed raw bytes."""
API_VERSION = 0

def __init__(self, data):
self._data = data
self.header = None

def with_header(self, correlation_id=None, **kwargs):
self.header = SaslBytesResponse(correlation_id)

def encode(self, framed=True, header=True):
return struct.pack('>I', len(self._data)) + self._data

def expect_response(self):
return True


class SaslBytesResponse:
"""Response for raw SASL v0 exchange -- returns bytes as-is."""
def __init__(self, correlation_id):
self.correlation_id = correlation_id
self.error_code = 0

def parse_header(self, read_buffer):
return self

def decode(self, read_buffer):
self.auth_bytes = read_buffer.read()
return self

def get_response_class(self):
return self


__all__ = [
'SaslHandshakeRequest', 'SaslHandshakeResponse',
'SaslAuthenticateRequest', 'SaslAuthenticateResponse',
'SaslBytesRequest', 'SaslBytesResponse',
]
8 changes: 7 additions & 1 deletion kafka/protocol/sasl.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ from typing import Any, Self
from kafka.protocol.api_message import ApiMessage
from kafka.protocol.api_data import ApiData

__all__ = ['SaslHandshakeRequest', 'SaslHandshakeResponse', 'SaslAuthenticateRequest', 'SaslAuthenticateResponse']
__all__ = ['SaslHandshakeRequest', 'SaslHandshakeResponse', 'SaslAuthenticateRequest', 'SaslAuthenticateResponse', 'SaslBytesRequest', 'SaslBytesResponse']

class SaslHandshakeRequest(ApiMessage):
mechanism: str
Expand Down Expand Up @@ -118,3 +118,9 @@ class SaslAuthenticateResponse(ApiMessage):
def is_request(cls) -> bool: ...
def expect_response(self) -> bool: ...
def with_header(self, correlation_id: int = 0, client_id: str = "kafka-python") -> None: ...

class SaslBytesRequest:
API_VERSION: int

class SaslBytesResponse:
...
Loading