Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions src/internal.c
Original file line number Diff line number Diff line change
Expand Up @@ -13349,10 +13349,34 @@ int SendUserAuthKeyboardRequest(WOLFSSH* ssh, WS_UserAuthData* authData)
}

if (ret == WS_SUCCESS) {
ret = ssh->ctx->keyboardAuthCb(&authData->sf.keyboard,
ssh->keyboardAuthCtx);
/* Set responseCount to 0 to indicate this is a prompt setup call */
authData->sf.keyboard.responseCount = 0;

/* First try using userAuthCb if it's set */
if (ssh->ctx->userAuthCb != NULL) {
WLOG(WS_LOG_DEBUG, "SUAKR: Calling userAuthCb for prompt setup");
ret = ssh->ctx->userAuthCb(WOLFSSH_USERAUTH_KEYBOARD,
authData, ssh->userAuthCtx);

/* If userAuthCb doesn't return SUCCESS_ANOTHER, fall back to keyboardAuthCb */
if (ret != WOLFSSH_USERAUTH_SUCCESS_ANOTHER) {
WLOG(WS_LOG_DEBUG, "SUAKR: userAuthCb didn't return SUCCESS_ANOTHER, falling back");
ret = ssh->ctx->keyboardAuthCb(&authData->sf.keyboard,
ssh->keyboardAuthCtx);
}
else {
WLOG(WS_LOG_DEBUG, "SUAKR: userAuthCb returned SUCCESS_ANOTHER, proceeding");
ret = WS_SUCCESS;
}
}
else {
/* Fall back to keyboardAuthCb if userAuthCb is not set */
ret = ssh->ctx->keyboardAuthCb(&authData->sf.keyboard,
ssh->keyboardAuthCtx);
}
}

/* Only check for NULL pointers if we actually have prompts */
if (authData->sf.keyboard.promptCount > 0 &&
(authData->sf.keyboard.prompts == NULL ||
authData->sf.keyboard.promptLengths == NULL ||
Expand Down
133 changes: 106 additions & 27 deletions tests/auth.c
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ word32 kbResponseCount;
byte kbMultiRound = 0;
byte currentRound = 0;
byte unbalanced = 0;
byte useUserAuthCb = 0; /* Flag to test userAuthCb for keyboard-interactive */

WS_UserAuthData_Keyboard promptData;

Expand Down Expand Up @@ -223,38 +224,73 @@ static int load_key(byte isEcc, byte* buf, word32 bufSz)
static int serverUserAuth(byte authType, WS_UserAuthData* authData, void* ctx)
{
(void) ctx;
if (authType != WOLFSSH_USERAUTH_KEYBOARD) {
return WOLFSSH_USERAUTH_FAILURE;
}

if (authData->sf.keyboard.responseCount != kbResponseCount) {
return WOLFSSH_USERAUTH_FAILURE;
}

for (word32 resp = 0; resp < kbResponseCount; resp++) {
if (authData->sf.keyboard.responseLengths[resp] !=
kbResponseLengths[resp]) {

/* Handle keyboard-interactive auth */
if (authType == WOLFSSH_USERAUTH_KEYBOARD) {
/* If responseCount is 0, this is a prompt setup call */
if (authData->sf.keyboard.responseCount == 0) {
/* Set up prompts - only copy the necessary fields, not the entire structure */
authData->sf.keyboard.promptCount = promptData.promptCount;
authData->sf.keyboard.promptName = promptData.promptName;
authData->sf.keyboard.promptNameSz = promptData.promptNameSz;
authData->sf.keyboard.promptInstruction = promptData.promptInstruction;
authData->sf.keyboard.promptInstructionSz = promptData.promptInstructionSz;
authData->sf.keyboard.promptLanguage = promptData.promptLanguage;
authData->sf.keyboard.promptLanguageSz = promptData.promptLanguageSz;
authData->sf.keyboard.prompts = promptData.prompts;
authData->sf.keyboard.promptLengths = promptData.promptLengths;
authData->sf.keyboard.promptEcho = promptData.promptEcho;

/* Return SUCCESS_ANOTHER to proceed with sending prompts */
if (useUserAuthCb) {
return WOLFSSH_USERAUTH_SUCCESS_ANOTHER;
}
/* When not testing userAuthCb, return FAILURE to fall back to keyboardAuthCb */
return WOLFSSH_USERAUTH_FAILURE;

}
if (WSTRCMP((const char*)authData->sf.keyboard.responses[resp],
(const char*)kbResponses[resp]) != 0) {

/* Validate responses */
if (authData->sf.keyboard.responseCount != kbResponseCount) {
return WOLFSSH_USERAUTH_FAILURE;
}

for (word32 resp = 0; resp < kbResponseCount; resp++) {
if (authData->sf.keyboard.responseLengths[resp] !=
kbResponseLengths[resp]) {
return WOLFSSH_USERAUTH_FAILURE;

}
if (WSTRCMP((const char*)authData->sf.keyboard.responses[resp],
(const char*)kbResponses[resp]) != 0) {
return WOLFSSH_USERAUTH_FAILURE;
}
}
if (kbMultiRound && currentRound == 0) {
currentRound++;
kbResponses[0] = (byte*)testText2;
kbResponseLengths[0] = 8;
return WOLFSSH_USERAUTH_SUCCESS_ANOTHER;
}
return WOLFSSH_USERAUTH_SUCCESS;
}
if (kbMultiRound && currentRound == 0) {
currentRound++;
kbResponses[0] = (byte*)testText2;
kbResponseLengths[0] = 8;
return WOLFSSH_USERAUTH_SUCCESS_ANOTHER;
}
return WOLFSSH_USERAUTH_SUCCESS;

return WOLFSSH_USERAUTH_FAILURE;
}

static int serverKeyboardCallback(WS_UserAuthData_Keyboard *kbAuth, void *ctx)
{
(void) ctx;
WMEMCPY(kbAuth, &promptData, sizeof(WS_UserAuthData_Keyboard));
/* Copy individual fields instead of the entire structure to avoid memory issues */
kbAuth->promptCount = promptData.promptCount;
kbAuth->promptName = promptData.promptName;
kbAuth->promptNameSz = promptData.promptNameSz;
kbAuth->promptInstruction = promptData.promptInstruction;
kbAuth->promptInstructionSz = promptData.promptInstructionSz;
kbAuth->promptLanguage = promptData.promptLanguage;
kbAuth->promptLanguageSz = promptData.promptLanguageSz;
kbAuth->prompts = promptData.prompts;
kbAuth->promptLengths = promptData.promptLengths;
kbAuth->promptEcho = promptData.promptEcho;

return WS_SUCCESS;
}
Expand Down Expand Up @@ -332,7 +368,12 @@ static THREAD_RETURN WOLFSSH_THREAD server_thread(void* args)
}

wolfSSH_SetUserAuth(ctx, serverUserAuth);
wolfSSH_SetKeyboardAuthPrompts(ctx, serverKeyboardCallback);

/* Only set keyboard auth callback when not testing userAuthCb */
if (!useUserAuthCb) {
wolfSSH_SetKeyboardAuthPrompts(ctx, serverKeyboardCallback);
}

ssh = wolfSSH_new(ctx);
if (ssh == NULL) {
ES_ERROR("Couldn't allocate SSH data.\n");
Expand Down Expand Up @@ -394,16 +435,24 @@ static int keyboardUserAuth(byte authType, WS_UserAuthData* authData, void* ctx)

if (authType == WOLFSSH_USERAUTH_KEYBOARD) {
AssertIntEQ(kbResponseCount, authData->sf.keyboard.promptCount);
for (word32 prompt = 0; prompt < kbResponseCount; prompt++) {
AssertStrEQ("Password: ", authData->sf.keyboard.prompts[prompt]);

/* Only check prompts if there are any */
if (kbResponseCount > 0) {
for (word32 prompt = 0; prompt < kbResponseCount; prompt++) {
AssertStrEQ("Password: ", authData->sf.keyboard.prompts[prompt]);
}
}

authData->sf.keyboard.responseCount = kbResponseCount;
if (unbalanced) {
authData->sf.keyboard.responseCount++;
}
authData->sf.keyboard.responseLengths = kbResponseLengths;
authData->sf.keyboard.responses = (byte**)kbResponses;

/* Only set response pointers if there are responses */
if (kbResponseCount > 0) {
authData->sf.keyboard.responseLengths = kbResponseLengths;
authData->sf.keyboard.responses = (byte**)kbResponses;
}
ret = WS_SUCCESS;
}
return ret;
Expand Down Expand Up @@ -574,6 +623,34 @@ static void test_unbalanced_client_KeyboardInteractive(void)
test_client();
unbalanced = 0;
}

static void test_userAuthCb_KeyboardInteractive(void)
{
printf("Testing keyboard-interactive auth via userAuthCb\n");
kbResponses[0] = (byte*)testText1;
kbResponseLengths[0] = 4;
kbResponseCount = 1;
useUserAuthCb = 1;

test_client();
useUserAuthCb = 0;
}

static void test_userAuthCb_multi_round_KeyboardInteractive(void)
{
printf("Testing multiple prompt rounds via userAuthCb\n");
kbResponses[0] = (byte*)testText1;
kbResponseLengths[0] = 4;
kbResponseCount = 1;
kbMultiRound = 1;
useUserAuthCb = 1;

test_client();
AssertIntEQ(currentRound, 1);
currentRound = 0;
kbMultiRound = 0;
useUserAuthCb = 0;
}
#endif /* WOLFSSH_TEST_BLOCK */

int wolfSSH_AuthTest(int argc, char** argv)
Expand Down Expand Up @@ -603,6 +680,8 @@ int wolfSSH_AuthTest(int argc, char** argv)
test_multi_prompt_KeyboardInteractive();
test_multi_round_KeyboardInteractive();
test_unbalanced_client_KeyboardInteractive();
test_userAuthCb_KeyboardInteractive();
test_userAuthCb_multi_round_KeyboardInteractive();

AssertIntEQ(wolfSSH_Cleanup(), WS_SUCCESS);

Expand Down
7 changes: 7 additions & 0 deletions wolfssh/ssh.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,13 @@ typedef struct WS_UserAuthData {
} sf;
} WS_UserAuthData;

/* User Authentication callback
* For keyboard-interactive authentication:
* - When responseCount is 0, the callback is being called to set up prompts
* Return WOLFSSH_USERAUTH_SUCCESS_ANOTHER to proceed with sending prompts
* - When responseCount > 0, the callback is being called to validate responses
* Return WOLFSSH_USERAUTH_SUCCESS_ANOTHER to request more prompts
*/
typedef int (*WS_CallbackUserAuth)(byte, WS_UserAuthData*, void*);
WOLFSSH_API void wolfSSH_SetUserAuth(WOLFSSH_CTX*, WS_CallbackUserAuth);
typedef int (*WS_CallbackUserAuthTypes)(WOLFSSH* ssh, void* ctx);
Expand Down
Loading