Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 222 additions & 0 deletions internal/passkey/sessiontest/session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
package sessiontest

import (
"context"
"testing"
"time"

"github.com/go-webauthn/webauthn/protocol"
"github.com/go-webauthn/webauthn/webauthn"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/meigma/authkit/proof/passkey"
"github.com/meigma/authkit/proof/passkey/session"
)

const extensionCounter int64 = 42

// Run exercises a passkey ceremony session store's shared behavior.
//
//nolint:funlen // The suite keeps the full session-store contract readable in one file.
func Run(t *testing.T, newStore func(t *testing.T) session.Store) {
t.Helper()

t.Run("registration round trip", func(t *testing.T) {
store := newStore(t)
registration := testRegistration(futureExpiration())

id, err := store.PutRegistration(context.Background(), registration)
require.NoError(t, err)
assert.NotEmpty(t, id)

got, err := store.TakeRegistration(context.Background(), id)
require.NoError(t, err)
assertRegistration(t, registration, got)
})

t.Run("login round trip", func(t *testing.T) {
store := newStore(t)
login := testLogin(futureExpiration())

id, err := store.PutLogin(context.Background(), login)
require.NoError(t, err)
assert.NotEmpty(t, id)

got, err := store.TakeLogin(context.Background(), id)
require.NoError(t, err)
assertLogin(t, login, got)
})

t.Run("sessions are deleted on take", func(t *testing.T) {
store := newStore(t)
id, err := store.PutLogin(context.Background(), testLogin(futureExpiration()))
require.NoError(t, err)

_, err = store.TakeLogin(context.Background(), id)
require.NoError(t, err)
_, err = store.TakeLogin(context.Background(), id)
require.ErrorIs(t, err, session.ErrNotFound)
})

t.Run("expired sessions are deleted on take", func(t *testing.T) {
store := newStore(t)
id, err := store.PutLogin(context.Background(), testLogin(pastExpiration()))
require.NoError(t, err)

_, err = store.TakeLogin(context.Background(), id)
require.ErrorIs(t, err, session.ErrExpired)
_, err = store.TakeLogin(context.Background(), id)
require.ErrorIs(t, err, session.ErrNotFound)
})

t.Run("sessions without expiration are rejected", func(t *testing.T) {
store := newStore(t)

_, err := store.PutRegistration(context.Background(), testRegistration(time.Time{}))
require.ErrorIs(t, err, session.ErrExpired)

_, err = store.PutLogin(context.Background(), testLogin(time.Time{}))
require.ErrorIs(t, err, session.ErrExpired)
})

t.Run("expired sessions are pruned on put", func(t *testing.T) {
store := newStore(t)
expiredID, err := store.PutLogin(context.Background(), testLogin(pastExpiration()))
require.NoError(t, err)

activeID, err := store.PutLogin(context.Background(), testLogin(futureExpiration()))
require.NoError(t, err)

_, err = store.TakeLogin(context.Background(), expiredID)
require.ErrorIs(t, err, session.ErrNotFound)
_, err = store.TakeLogin(context.Background(), activeID)
require.NoError(t, err)
})

t.Run("wrong flow sessions are consumed on take", func(t *testing.T) {
store := newStore(t)
id, err := store.PutRegistration(context.Background(), testRegistration(futureExpiration()))
require.NoError(t, err)

_, err = store.TakeLogin(context.Background(), id)
require.ErrorIs(t, err, session.ErrNotFound)
_, err = store.TakeRegistration(context.Background(), id)
require.ErrorIs(t, err, session.ErrNotFound)
})

t.Run("input values are independent from stored values", func(t *testing.T) {
store := newStore(t)
registration := testRegistration(futureExpiration())

id, err := store.PutRegistration(context.Background(), registration)
require.NoError(t, err)
registration.User.Handle[0] = 'X'
registration.SessionData.UserID[0] = 'X'
registration.SessionData.AllowedCredentialIDs[0][0] = 'X'
nested, ok := registration.SessionData.Extensions["nested"].(map[string]any)
require.True(t, ok)
nested["value"] = "mutated"

got, err := store.TakeRegistration(context.Background(), id)
require.NoError(t, err)

assert.Equal(t, []byte("handle"), got.User.Handle)
assert.Equal(t, []byte("user-id"), got.SessionData.UserID)
assert.Equal(t, [][]byte{[]byte("credential-id")}, got.SessionData.AllowedCredentialIDs)
assert.Equal(t, []byte("extension"), got.SessionData.Extensions["bytes"])
assert.Equal(t, extensionCounter, got.SessionData.Extensions["counter"])
assert.Equal(t, true, got.SessionData.Extensions["enabled"])
assert.Equal(t, "extension", got.SessionData.Extensions["text"])
gotNested, ok := got.SessionData.Extensions["nested"].(map[string]any)
require.True(t, ok)
assert.Equal(t, []byte("nested"), gotNested["bytes"])
assert.Equal(t, "kept", gotNested["value"])
})

t.Run("canceled context is honored", func(t *testing.T) {
store := newStore(t)
ctx, cancel := context.WithCancel(context.Background())
cancel()

_, err := store.PutLogin(ctx, testLogin(futureExpiration()))
require.ErrorIs(t, err, context.Canceled)
_, err = store.TakeLogin(ctx, "session-id")
require.ErrorIs(t, err, context.Canceled)
})
}

func testRegistration(expires time.Time) session.Registration {
return session.Registration{
User: passkey.User{
RPID: "localhost",
PrincipalID: "principal-1",
Handle: []byte("handle"),
Name: "ada@example.test",
DisplayName: "Ada Lovelace",
},
SessionData: testSessionData(expires),
}
}

func assertRegistration(t *testing.T, want session.Registration, got session.Registration) {
t.Helper()

assert.Equal(t, want.User, got.User)
assertSessionData(t, want.SessionData, got.SessionData)
}

func assertLogin(t *testing.T, want session.Login, got session.Login) {
t.Helper()

assertSessionData(t, want.SessionData, got.SessionData)
}

func assertSessionData(t *testing.T, want webauthn.SessionData, got webauthn.SessionData) {
t.Helper()

assert.Equal(t, want.Challenge, got.Challenge)
assert.Equal(t, want.RelyingPartyID, got.RelyingPartyID)
assert.Equal(t, want.UserID, got.UserID)
assert.Equal(t, want.AllowedCredentialIDs, got.AllowedCredentialIDs)
assert.True(t, got.Expires.Equal(want.Expires), "expected expiration %s, got %s", want.Expires, got.Expires)
assert.Equal(t, want.UserVerification, got.UserVerification)
assert.Equal(t, want.Extensions, got.Extensions)
assert.Equal(t, want.CredParams, got.CredParams)
assert.Equal(t, want.Mediation, got.Mediation)
}

func testLogin(expires time.Time) session.Login {
return session.Login{
SessionData: testSessionData(expires),
}
}

func testSessionData(expires time.Time) webauthn.SessionData {
return webauthn.SessionData{
Challenge: "challenge",
RelyingPartyID: "localhost",
UserID: []byte("user-id"),
AllowedCredentialIDs: [][]byte{[]byte("credential-id")},
Expires: expires,
UserVerification: protocol.VerificationRequired,
Extensions: protocol.AuthenticationExtensions{
"bytes": []byte("extension"),
"counter": extensionCounter,
"enabled": true,
"text": "extension",
"nested": map[string]any{
"bytes": []byte("nested"),
"value": "kept",
},
},
}
}

func futureExpiration() time.Time {
return time.Now().UTC().Truncate(time.Second).Add(time.Hour)
}

func pastExpiration() time.Time {
return time.Now().UTC().Truncate(time.Second).Add(-time.Hour)
}
136 changes: 6 additions & 130 deletions proof/passkey/session/memory_test.go
Original file line number Diff line number Diff line change
@@ -1,131 +1,24 @@
package session_test

import (
"context"
"testing"
"time"

"github.com/go-webauthn/webauthn/protocol"
"github.com/go-webauthn/webauthn/webauthn"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/meigma/authkit/internal/passkey/sessiontest"
"github.com/meigma/authkit/proof/passkey"
"github.com/meigma/authkit/proof/passkey/session"
)

func TestMemoryStoreRegistrationRoundTrip(t *testing.T) {
store := session.NewMemoryStore(session.WithClock(fixedTime))
registration := testRegistration(fixedTime().Add(time.Minute))
func TestMemoryStoreBehavior(t *testing.T) {
sessiontest.Run(t, func(t *testing.T) session.Store {
t.Helper()

id, err := store.PutRegistration(context.Background(), registration)
require.NoError(t, err)
assert.NotEmpty(t, id)

got, err := store.TakeRegistration(context.Background(), id)
require.NoError(t, err)

assert.Equal(t, registration, got)
}

func TestMemoryStoreLoginRoundTrip(t *testing.T) {
store := session.NewMemoryStore(session.WithClock(fixedTime))
login := testLogin(fixedTime().Add(time.Minute))

id, err := store.PutLogin(context.Background(), login)
require.NoError(t, err)
assert.NotEmpty(t, id)

got, err := store.TakeLogin(context.Background(), id)
require.NoError(t, err)

assert.Equal(t, login, got)
}

func TestMemoryStoreDeletesSessionsOnTake(t *testing.T) {
store := session.NewMemoryStore(session.WithClock(fixedTime))
id, err := store.PutLogin(context.Background(), testLogin(fixedTime().Add(time.Minute)))
require.NoError(t, err)

_, err = store.TakeLogin(context.Background(), id)
require.NoError(t, err)
_, err = store.TakeLogin(context.Background(), id)
require.ErrorIs(t, err, session.ErrNotFound)
}

func TestMemoryStoreDeletesExpiredSessionsOnTake(t *testing.T) {
store := session.NewMemoryStore(session.WithClock(fixedTime))
id, err := store.PutLogin(context.Background(), testLogin(fixedTime().Add(-time.Minute)))
require.NoError(t, err)

_, err = store.TakeLogin(context.Background(), id)
require.ErrorIs(t, err, session.ErrExpired)
_, err = store.TakeLogin(context.Background(), id)
require.ErrorIs(t, err, session.ErrNotFound)
}

func TestMemoryStoreRejectsSessionsWithoutExpiration(t *testing.T) {
store := session.NewMemoryStore(session.WithClock(fixedTime))

_, err := store.PutRegistration(context.Background(), testRegistration(time.Time{}))
require.ErrorIs(t, err, session.ErrExpired)

_, err = store.PutLogin(context.Background(), testLogin(time.Time{}))
require.ErrorIs(t, err, session.ErrExpired)
}

func TestMemoryStorePrunesExpiredSessionsOnPut(t *testing.T) {
now := fixedTime()
store := session.NewMemoryStore(session.WithClock(func() time.Time {
return now
}))
expiredID, err := store.PutLogin(context.Background(), testLogin(now.Add(time.Minute)))
require.NoError(t, err)

now = now.Add(2 * time.Minute)
activeID, err := store.PutLogin(context.Background(), testLogin(now.Add(time.Minute)))
require.NoError(t, err)

_, err = store.TakeLogin(context.Background(), expiredID)
require.ErrorIs(t, err, session.ErrNotFound)
_, err = store.TakeLogin(context.Background(), activeID)
require.NoError(t, err)
}

func TestMemoryStoreDeletesWrongFlowSessionsOnTake(t *testing.T) {
store := session.NewMemoryStore(session.WithClock(fixedTime))
id, err := store.PutRegistration(
context.Background(),
testRegistration(fixedTime().Add(time.Minute)),
)
require.NoError(t, err)

_, err = store.TakeLogin(context.Background(), id)
require.ErrorIs(t, err, session.ErrNotFound)
_, err = store.TakeRegistration(context.Background(), id)
require.ErrorIs(t, err, session.ErrNotFound)
}

func TestMemoryStoreClonesInputBeforeStoring(t *testing.T) {
store := session.NewMemoryStore(session.WithClock(fixedTime))
registration := testRegistration(fixedTime().Add(time.Minute))

id, err := store.PutRegistration(context.Background(), registration)
require.NoError(t, err)
registration.User.Handle[0] = 'X'
registration.SessionData.UserID[0] = 'X'
registration.SessionData.AllowedCredentialIDs[0][0] = 'X'
registration.SessionData.Extensions["bytes"].([]byte)[0] = 'X'
registration.SessionData.Extensions["nested"].(map[string]any)["bytes"].([]byte)[0] = 'X'

got, err := store.TakeRegistration(context.Background(), id)
require.NoError(t, err)

assert.Equal(t, []byte("handle"), got.User.Handle)
assert.Equal(t, []byte("user-id"), got.SessionData.UserID)
assert.Equal(t, [][]byte{[]byte("credential-id")}, got.SessionData.AllowedCredentialIDs)
assert.Equal(t, []byte("extension"), got.SessionData.Extensions["bytes"])
assert.Equal(t, []byte("nested"), got.SessionData.Extensions["nested"].(map[string]any)["bytes"])
return session.NewMemoryStore()
})
}

func TestSessionHelpersCloneBeginResults(t *testing.T) {
Expand All @@ -148,17 +41,6 @@ func TestSessionHelpersCloneBeginResults(t *testing.T) {
assert.Equal(t, [][]byte{[]byte("credential-id")}, login.SessionData.AllowedCredentialIDs)
}

func TestMemoryStoreHonorsCanceledContext(t *testing.T) {
store := session.NewMemoryStore(session.WithClock(fixedTime))
ctx, cancel := context.WithCancel(context.Background())
cancel()

_, err := store.PutLogin(ctx, testLogin(fixedTime().Add(time.Minute)))
require.ErrorIs(t, err, context.Canceled)
_, err = store.TakeLogin(ctx, "session-id")
require.ErrorIs(t, err, context.Canceled)
}

func testRegistration(expires time.Time) session.Registration {
return session.Registration{
User: passkey.User{
Expand All @@ -172,12 +54,6 @@ func testRegistration(expires time.Time) session.Registration {
}
}

func testLogin(expires time.Time) session.Login {
return session.Login{
SessionData: testSessionData(expires),
}
}

func testSessionData(expires time.Time) webauthn.SessionData {
return webauthn.SessionData{
Challenge: "challenge",
Expand Down
Loading