Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 46 additions & 6 deletions packages/wabe/src/authentication/Session.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ describe('Session', () => {
const mockCreateObject = mock(() => Promise.resolve({ id: 'userId' })) as any
const mockDeleteObject = mock(() => Promise.resolve()) as any
const mockUpdateObject = mock(() => Promise.resolve()) as any
const mockUpdateObjects = mock(() => Promise.resolve([{ id: 'sessionId' }])) as any

const controllers = {
database: {
Expand All @@ -27,6 +28,7 @@ describe('Session', () => {
createObject: mockCreateObject,
deleteObject: mockDeleteObject,
updateObject: mockUpdateObject,
updateObjects: mockUpdateObjects,
},
}

Expand All @@ -36,6 +38,7 @@ describe('Session', () => {
mockCreateObject.mockClear()
mockDeleteObject.mockClear()
mockUpdateObject.mockClear()
mockUpdateObjects.mockClear()
})

const context = {
Expand Down Expand Up @@ -488,23 +491,38 @@ describe('Session', () => {
context: expect.any(Object),
})

expect(mockUpdateObject).toHaveBeenCalledTimes(1)
expect(mockUpdateObject).toHaveBeenCalledWith({
expect(mockUpdateObjects).toHaveBeenCalledTimes(1)
expect(mockUpdateObjects).toHaveBeenCalledWith({
className: '_Session',
context: expect.any(Object),
id: 'sessionId',
where: {
id: {
equalTo: 'sessionId',
},
accessTokenEncrypted: {
equalTo: encryptToken(oldAccessToken, 'dev'),
},
refreshTokenEncrypted: {
equalTo: encryptToken(oldRefreshToken, 'dev'),
},
refreshTokenExpiresAt: {
greaterThanOrEqualTo: expect.any(Date),
},
},
data: {
accessTokenEncrypted: expect.any(String),
accessTokenExpiresAt: expect.any(Date),
refreshTokenEncrypted: expect.any(String),
refreshTokenExpiresAt: expect.any(Date),
},
select: {},
select: { id: true },
first: 1,
})

const accessTokenExpiresAt = mockUpdateObject.mock.calls[0][0].data.accessTokenExpiresAt as Date
const accessTokenExpiresAt = mockUpdateObjects.mock.calls[0][0].data
.accessTokenExpiresAt as Date

const refreshTokenExpiresAt = mockUpdateObject.mock.calls[0][0].data
const refreshTokenExpiresAt = mockUpdateObjects.mock.calls[0][0].data
.refreshTokenExpiresAt as Date

// -1000 to avoid flaky
Expand Down Expand Up @@ -669,4 +687,26 @@ describe('Session', () => {
'Invalid refresh token',
)
})

it('should throw an error when refresh rotation update is stale', async () => {
const session = new Session<any>()
const { refreshToken, accessToken } = await session.create('userId', context)

mockGetObjects.mockResolvedValue([
{
id: 'sessionId',
refreshTokenEncrypted: encryptToken(refreshToken, 'dev'),
refreshTokenExpiresAt: new Date(Date.now() + 1000 * 60 * 60 * 24),
user: {
id: 'userId',
email: 'userEmail',
},
},
])
mockUpdateObjects.mockResolvedValue([])

await expect(session.refresh(accessToken, refreshToken, context)).rejects.toThrow(
'Invalid refresh token',
)
})
})
264 changes: 155 additions & 109 deletions packages/wabe/src/authentication/Session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,25 @@ const getJwtVerifyOptions = <T extends WabeTypes>(context: WabeContext<T>) => {
export class Session<T extends WabeTypes> {
private accessToken: string | undefined = undefined
private refreshToken: string | undefined = undefined
private static refreshLocks = new Map<string, Promise<void>>()

private async acquireRefreshLock(lockKey: string) {
while (Session.refreshLocks.has(lockKey)) {
const existingLock = Session.refreshLocks.get(lockKey)
if (existingLock) await existingLock
}

let release = () => {}
const lockPromise = new Promise<void>((resolve) => {
release = resolve
})
Session.refreshLocks.set(lockKey, lockPromise)

return () => {
release()
if (Session.refreshLocks.get(lockKey) === lockPromise) Session.refreshLocks.delete(lockKey)
}
}

getAccessTokenExpireAt(config: WabeConfig<T>) {
const customExpiresInMs = config?.authentication?.session?.accessTokenExpiresInMs
Expand Down Expand Up @@ -371,138 +390,165 @@ export class Session<T extends WabeTypes> {
refreshToken,
getTokenEncryptionKey(context),
)
const releaseRefreshLock = await this.acquireRefreshLock(
`${accessTokenEncrypted}:${incomingRefreshTokenEncrypted}`,
)

const session = await context.wabe.controllers.database.getObjects({
className: '_Session',
// @ts-expect-error
where: {
accessTokenEncrypted: { equalTo: accessTokenEncrypted },
refreshTokenEncrypted: {
equalTo: incomingRefreshTokenEncrypted,
},
},
select: {
try {
const session = await context.wabe.controllers.database.getObjects({
className: '_Session',
// @ts-expect-error
id: true,
// @ts-expect-error
user: {
where: {
accessTokenEncrypted: { equalTo: accessTokenEncrypted },
refreshTokenEncrypted: {
equalTo: incomingRefreshTokenEncrypted,
},
},
select: {
// @ts-expect-error
id: true,
role: {
// @ts-expect-error
user: {
id: true,
name: true,
role: {
id: true,
name: true,
},
},
// @ts-expect-error
refreshTokenEncrypted: true,
// @ts-expect-error
refreshTokenExpiresAt: true,
},
// @ts-expect-error
refreshTokenEncrypted: true,
// @ts-expect-error
refreshTokenExpiresAt: true,
},
context: contextWithRoot(context),
})
context: contextWithRoot(context),
})

if (!session.length)
return {
accessToken: null,
refreshToken: null,
}
if (!session.length)
return {
accessToken: null,
refreshToken: null,
}

if (!session[0]) throw new Error('Session not found')
if (!session[0]) throw new Error('Session not found')

const {
refreshTokenExpiresAt,
user,
refreshTokenEncrypted: storedRefreshTokenEncrypted,
id,
} = session[0]
const {
refreshTokenExpiresAt,
user,
refreshTokenEncrypted: storedRefreshTokenEncrypted,
id,
} = session[0]

if (new Date(refreshTokenExpiresAt) < new Date(Date.now()))
throw new Error('Refresh token expired')
if (new Date(refreshTokenExpiresAt) < new Date(Date.now()))
throw new Error('Refresh token expired')

const decryptedRefreshToken =
decryptDeterministicToken(storedRefreshTokenEncrypted, getTokenEncryptionKey(context)) ||
refreshToken
const decryptedRefreshToken =
decryptDeterministicToken(storedRefreshTokenEncrypted, getTokenEncryptionKey(context)) ||
refreshToken

if (!decryptedRefreshToken || decryptedRefreshToken !== refreshToken)
throw new Error('Invalid refresh token')
if (!decryptedRefreshToken || decryptedRefreshToken !== refreshToken)
throw new Error('Invalid refresh token')

// Always rotate tokens on refresh
const userId = user?.id
// Always rotate tokens on refresh
const userId = user?.id

if (!userId)
return {
accessToken: null,
refreshToken: null,
}
if (!userId)
return {
accessToken: null,
refreshToken: null,
}

const jwtTokenFields = context.wabe.config.authentication?.session?.jwtTokenFields
const jwtTokenFields = context.wabe.config.authentication?.session?.jwtTokenFields

const result = jwtTokenFields
? await context.wabe.controllers.database.getObject({
className: 'User',
select: jwtTokenFields,
context,
id: userId,
})
: undefined
const result = jwtTokenFields
? await context.wabe.controllers.database.getObject({
className: 'User',
select: jwtTokenFields,
context,
id: userId,
})
: undefined

const nowSeconds = Math.floor(Date.now() / 1000)
const nowSeconds = Math.floor(Date.now() / 1000)

const signOptions: SignOptions = {
jwtid: crypto.randomUUID(),
algorithm: JWT_ALGORITHM,
}
const audience = context.wabe.config.authentication?.session?.jwtAudience
const issuer = context.wabe.config.authentication?.session?.jwtIssuer
if (audience) signOptions.audience = audience
if (issuer) signOptions.issuer = issuer
const signOptions: SignOptions = {
jwtid: crypto.randomUUID(),
algorithm: JWT_ALGORITHM,
}
const audience = context.wabe.config.authentication?.session?.jwtAudience
const issuer = context.wabe.config.authentication?.session?.jwtIssuer
if (audience) signOptions.audience = audience
if (issuer) signOptions.issuer = issuer

const newAccessToken = jwt.sign(
{
userId,
user: result,
iat: nowSeconds,
exp: Math.floor(this.getAccessTokenExpireAt(context.wabe.config).getTime() / 1000),
},
secretKey,
{ ...signOptions, algorithm: JWT_ALGORITHM },
)

const newAccessToken = jwt.sign(
{
userId,
user: result,
iat: nowSeconds,
exp: Math.floor(this.getAccessTokenExpireAt(context.wabe.config).getTime() / 1000),
},
secretKey,
{ ...signOptions, algorithm: JWT_ALGORITHM },
)
const newRefreshToken = jwt.sign(
{
userId,
user: result,
iat: nowSeconds,
exp: Math.floor(this.getRefreshTokenExpireAt(context.wabe.config).getTime() / 1000),
},
secretKey,
{ ...signOptions, algorithm: JWT_ALGORITHM },
)

const newRefreshToken = jwt.sign(
{
userId,
user: result,
iat: nowSeconds,
exp: Math.floor(this.getRefreshTokenExpireAt(context.wabe.config).getTime() / 1000),
},
secretKey,
{ ...signOptions, algorithm: JWT_ALGORITHM },
)
const newAccessTokenEncrypted = encryptDeterministicToken(
newAccessToken,
getTokenEncryptionKey(context),
)
const newRefreshTokenEncrypted = encryptDeterministicToken(
newRefreshToken,
getTokenEncryptionKey(context),
)

const newAccessTokenEncrypted = encryptDeterministicToken(
newAccessToken,
getTokenEncryptionKey(context),
)
const newRefreshTokenEncrypted = encryptDeterministicToken(
newRefreshToken,
getTokenEncryptionKey(context),
)
const updatedSessions = await context.wabe.controllers.database.updateObjects({
className: '_Session',
context: contextWithRoot(context),
// @ts-expect-error _Session where input is valid at runtime; WhereType is narrower than GraphQL filters.
where: {
id: {
equalTo: id,
},
accessTokenEncrypted: {
equalTo: accessTokenEncrypted,
},
refreshTokenEncrypted: {
equalTo: incomingRefreshTokenEncrypted,
},
refreshTokenExpiresAt: {
greaterThanOrEqualTo: new Date(),
},
},
data: {
accessTokenEncrypted: newAccessTokenEncrypted,
accessTokenExpiresAt: this.getAccessTokenExpireAt(context.wabe.config),
refreshTokenEncrypted: newRefreshTokenEncrypted,
refreshTokenExpiresAt: this.getRefreshTokenExpireAt(context.wabe.config),
} as any,
select: {
// @ts-expect-error
id: true,
},
first: 1,
})

await context.wabe.controllers.database.updateObject({
className: '_Session',
context: contextWithRoot(context),
id,
data: {
accessTokenEncrypted: newAccessTokenEncrypted,
accessTokenExpiresAt: this.getAccessTokenExpireAt(context.wabe.config),
refreshTokenEncrypted: newRefreshTokenEncrypted,
refreshTokenExpiresAt: this.getRefreshTokenExpireAt(context.wabe.config),
} as any,
select: {},
})
if (!updatedSessions.length) throw new Error('Invalid refresh token')

return {
accessToken: newAccessToken,
refreshToken: newRefreshToken,
return {
accessToken: newAccessToken,
refreshToken: newRefreshToken,
}
} finally {
releaseRefreshLock()
}
}

Expand Down
Loading
Loading