diff --git a/internal/passkey/sessiontest/session.go b/internal/passkey/sessiontest/session.go new file mode 100644 index 0000000..29c3d6e --- /dev/null +++ b/internal/passkey/sessiontest/session.go @@ -0,0 +1,222 @@ +package sessiontest + +import ( + "context" + "testing" + "time" + + "github.com/go-webauthn/webauthn/protocol" + "github.com/go-webauthn/webauthn/webauthn" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/meigma/authkit/proof/passkey" + "github.com/meigma/authkit/proof/passkey/session" +) + +const extensionCounter int64 = 42 + +// Run exercises a passkey ceremony session store's shared behavior. +// +//nolint:funlen // The suite keeps the full session-store contract readable in one file. +func Run(t *testing.T, newStore func(t *testing.T) session.Store) { + t.Helper() + + t.Run("registration round trip", func(t *testing.T) { + store := newStore(t) + registration := testRegistration(futureExpiration()) + + id, err := store.PutRegistration(context.Background(), registration) + require.NoError(t, err) + assert.NotEmpty(t, id) + + got, err := store.TakeRegistration(context.Background(), id) + require.NoError(t, err) + assertRegistration(t, registration, got) + }) + + t.Run("login round trip", func(t *testing.T) { + store := newStore(t) + login := testLogin(futureExpiration()) + + id, err := store.PutLogin(context.Background(), login) + require.NoError(t, err) + assert.NotEmpty(t, id) + + got, err := store.TakeLogin(context.Background(), id) + require.NoError(t, err) + assertLogin(t, login, got) + }) + + t.Run("sessions are deleted on take", func(t *testing.T) { + store := newStore(t) + id, err := store.PutLogin(context.Background(), testLogin(futureExpiration())) + require.NoError(t, err) + + _, err = store.TakeLogin(context.Background(), id) + require.NoError(t, err) + _, err = store.TakeLogin(context.Background(), id) + require.ErrorIs(t, err, session.ErrNotFound) + }) + + t.Run("expired sessions are deleted on take", func(t *testing.T) { + store := newStore(t) + id, err := store.PutLogin(context.Background(), testLogin(pastExpiration())) + require.NoError(t, err) + + _, err = store.TakeLogin(context.Background(), id) + require.ErrorIs(t, err, session.ErrExpired) + _, err = store.TakeLogin(context.Background(), id) + require.ErrorIs(t, err, session.ErrNotFound) + }) + + t.Run("sessions without expiration are rejected", func(t *testing.T) { + store := newStore(t) + + _, err := store.PutRegistration(context.Background(), testRegistration(time.Time{})) + require.ErrorIs(t, err, session.ErrExpired) + + _, err = store.PutLogin(context.Background(), testLogin(time.Time{})) + require.ErrorIs(t, err, session.ErrExpired) + }) + + t.Run("expired sessions are pruned on put", func(t *testing.T) { + store := newStore(t) + expiredID, err := store.PutLogin(context.Background(), testLogin(pastExpiration())) + require.NoError(t, err) + + activeID, err := store.PutLogin(context.Background(), testLogin(futureExpiration())) + require.NoError(t, err) + + _, err = store.TakeLogin(context.Background(), expiredID) + require.ErrorIs(t, err, session.ErrNotFound) + _, err = store.TakeLogin(context.Background(), activeID) + require.NoError(t, err) + }) + + t.Run("wrong flow sessions are consumed on take", func(t *testing.T) { + store := newStore(t) + id, err := store.PutRegistration(context.Background(), testRegistration(futureExpiration())) + require.NoError(t, err) + + _, err = store.TakeLogin(context.Background(), id) + require.ErrorIs(t, err, session.ErrNotFound) + _, err = store.TakeRegistration(context.Background(), id) + require.ErrorIs(t, err, session.ErrNotFound) + }) + + t.Run("input values are independent from stored values", func(t *testing.T) { + store := newStore(t) + registration := testRegistration(futureExpiration()) + + id, err := store.PutRegistration(context.Background(), registration) + require.NoError(t, err) + registration.User.Handle[0] = 'X' + registration.SessionData.UserID[0] = 'X' + registration.SessionData.AllowedCredentialIDs[0][0] = 'X' + nested, ok := registration.SessionData.Extensions["nested"].(map[string]any) + require.True(t, ok) + nested["value"] = "mutated" + + got, err := store.TakeRegistration(context.Background(), id) + require.NoError(t, err) + + assert.Equal(t, []byte("handle"), got.User.Handle) + assert.Equal(t, []byte("user-id"), got.SessionData.UserID) + assert.Equal(t, [][]byte{[]byte("credential-id")}, got.SessionData.AllowedCredentialIDs) + assert.Equal(t, []byte("extension"), got.SessionData.Extensions["bytes"]) + assert.Equal(t, extensionCounter, got.SessionData.Extensions["counter"]) + assert.Equal(t, true, got.SessionData.Extensions["enabled"]) + assert.Equal(t, "extension", got.SessionData.Extensions["text"]) + gotNested, ok := got.SessionData.Extensions["nested"].(map[string]any) + require.True(t, ok) + assert.Equal(t, []byte("nested"), gotNested["bytes"]) + assert.Equal(t, "kept", gotNested["value"]) + }) + + t.Run("canceled context is honored", func(t *testing.T) { + store := newStore(t) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := store.PutLogin(ctx, testLogin(futureExpiration())) + require.ErrorIs(t, err, context.Canceled) + _, err = store.TakeLogin(ctx, "session-id") + require.ErrorIs(t, err, context.Canceled) + }) +} + +func testRegistration(expires time.Time) session.Registration { + return session.Registration{ + User: passkey.User{ + RPID: "localhost", + PrincipalID: "principal-1", + Handle: []byte("handle"), + Name: "ada@example.test", + DisplayName: "Ada Lovelace", + }, + SessionData: testSessionData(expires), + } +} + +func assertRegistration(t *testing.T, want session.Registration, got session.Registration) { + t.Helper() + + assert.Equal(t, want.User, got.User) + assertSessionData(t, want.SessionData, got.SessionData) +} + +func assertLogin(t *testing.T, want session.Login, got session.Login) { + t.Helper() + + assertSessionData(t, want.SessionData, got.SessionData) +} + +func assertSessionData(t *testing.T, want webauthn.SessionData, got webauthn.SessionData) { + t.Helper() + + assert.Equal(t, want.Challenge, got.Challenge) + assert.Equal(t, want.RelyingPartyID, got.RelyingPartyID) + assert.Equal(t, want.UserID, got.UserID) + assert.Equal(t, want.AllowedCredentialIDs, got.AllowedCredentialIDs) + assert.True(t, got.Expires.Equal(want.Expires), "expected expiration %s, got %s", want.Expires, got.Expires) + assert.Equal(t, want.UserVerification, got.UserVerification) + assert.Equal(t, want.Extensions, got.Extensions) + assert.Equal(t, want.CredParams, got.CredParams) + assert.Equal(t, want.Mediation, got.Mediation) +} + +func testLogin(expires time.Time) session.Login { + return session.Login{ + SessionData: testSessionData(expires), + } +} + +func testSessionData(expires time.Time) webauthn.SessionData { + return webauthn.SessionData{ + Challenge: "challenge", + RelyingPartyID: "localhost", + UserID: []byte("user-id"), + AllowedCredentialIDs: [][]byte{[]byte("credential-id")}, + Expires: expires, + UserVerification: protocol.VerificationRequired, + Extensions: protocol.AuthenticationExtensions{ + "bytes": []byte("extension"), + "counter": extensionCounter, + "enabled": true, + "text": "extension", + "nested": map[string]any{ + "bytes": []byte("nested"), + "value": "kept", + }, + }, + } +} + +func futureExpiration() time.Time { + return time.Now().UTC().Truncate(time.Second).Add(time.Hour) +} + +func pastExpiration() time.Time { + return time.Now().UTC().Truncate(time.Second).Add(-time.Hour) +} diff --git a/proof/passkey/session/memory_test.go b/proof/passkey/session/memory_test.go index ce0a7e6..5a2f0dd 100644 --- a/proof/passkey/session/memory_test.go +++ b/proof/passkey/session/memory_test.go @@ -1,131 +1,24 @@ package session_test import ( - "context" "testing" "time" "github.com/go-webauthn/webauthn/protocol" "github.com/go-webauthn/webauthn/webauthn" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/meigma/authkit/internal/passkey/sessiontest" "github.com/meigma/authkit/proof/passkey" "github.com/meigma/authkit/proof/passkey/session" ) -func TestMemoryStoreRegistrationRoundTrip(t *testing.T) { - store := session.NewMemoryStore(session.WithClock(fixedTime)) - registration := testRegistration(fixedTime().Add(time.Minute)) +func TestMemoryStoreBehavior(t *testing.T) { + sessiontest.Run(t, func(t *testing.T) session.Store { + t.Helper() - id, err := store.PutRegistration(context.Background(), registration) - require.NoError(t, err) - assert.NotEmpty(t, id) - - got, err := store.TakeRegistration(context.Background(), id) - require.NoError(t, err) - - assert.Equal(t, registration, got) -} - -func TestMemoryStoreLoginRoundTrip(t *testing.T) { - store := session.NewMemoryStore(session.WithClock(fixedTime)) - login := testLogin(fixedTime().Add(time.Minute)) - - id, err := store.PutLogin(context.Background(), login) - require.NoError(t, err) - assert.NotEmpty(t, id) - - got, err := store.TakeLogin(context.Background(), id) - require.NoError(t, err) - - assert.Equal(t, login, got) -} - -func TestMemoryStoreDeletesSessionsOnTake(t *testing.T) { - store := session.NewMemoryStore(session.WithClock(fixedTime)) - id, err := store.PutLogin(context.Background(), testLogin(fixedTime().Add(time.Minute))) - require.NoError(t, err) - - _, err = store.TakeLogin(context.Background(), id) - require.NoError(t, err) - _, err = store.TakeLogin(context.Background(), id) - require.ErrorIs(t, err, session.ErrNotFound) -} - -func TestMemoryStoreDeletesExpiredSessionsOnTake(t *testing.T) { - store := session.NewMemoryStore(session.WithClock(fixedTime)) - id, err := store.PutLogin(context.Background(), testLogin(fixedTime().Add(-time.Minute))) - require.NoError(t, err) - - _, err = store.TakeLogin(context.Background(), id) - require.ErrorIs(t, err, session.ErrExpired) - _, err = store.TakeLogin(context.Background(), id) - require.ErrorIs(t, err, session.ErrNotFound) -} - -func TestMemoryStoreRejectsSessionsWithoutExpiration(t *testing.T) { - store := session.NewMemoryStore(session.WithClock(fixedTime)) - - _, err := store.PutRegistration(context.Background(), testRegistration(time.Time{})) - require.ErrorIs(t, err, session.ErrExpired) - - _, err = store.PutLogin(context.Background(), testLogin(time.Time{})) - require.ErrorIs(t, err, session.ErrExpired) -} - -func TestMemoryStorePrunesExpiredSessionsOnPut(t *testing.T) { - now := fixedTime() - store := session.NewMemoryStore(session.WithClock(func() time.Time { - return now - })) - expiredID, err := store.PutLogin(context.Background(), testLogin(now.Add(time.Minute))) - require.NoError(t, err) - - now = now.Add(2 * time.Minute) - activeID, err := store.PutLogin(context.Background(), testLogin(now.Add(time.Minute))) - require.NoError(t, err) - - _, err = store.TakeLogin(context.Background(), expiredID) - require.ErrorIs(t, err, session.ErrNotFound) - _, err = store.TakeLogin(context.Background(), activeID) - require.NoError(t, err) -} - -func TestMemoryStoreDeletesWrongFlowSessionsOnTake(t *testing.T) { - store := session.NewMemoryStore(session.WithClock(fixedTime)) - id, err := store.PutRegistration( - context.Background(), - testRegistration(fixedTime().Add(time.Minute)), - ) - require.NoError(t, err) - - _, err = store.TakeLogin(context.Background(), id) - require.ErrorIs(t, err, session.ErrNotFound) - _, err = store.TakeRegistration(context.Background(), id) - require.ErrorIs(t, err, session.ErrNotFound) -} - -func TestMemoryStoreClonesInputBeforeStoring(t *testing.T) { - store := session.NewMemoryStore(session.WithClock(fixedTime)) - registration := testRegistration(fixedTime().Add(time.Minute)) - - id, err := store.PutRegistration(context.Background(), registration) - require.NoError(t, err) - registration.User.Handle[0] = 'X' - registration.SessionData.UserID[0] = 'X' - registration.SessionData.AllowedCredentialIDs[0][0] = 'X' - registration.SessionData.Extensions["bytes"].([]byte)[0] = 'X' - registration.SessionData.Extensions["nested"].(map[string]any)["bytes"].([]byte)[0] = 'X' - - got, err := store.TakeRegistration(context.Background(), id) - require.NoError(t, err) - - assert.Equal(t, []byte("handle"), got.User.Handle) - assert.Equal(t, []byte("user-id"), got.SessionData.UserID) - assert.Equal(t, [][]byte{[]byte("credential-id")}, got.SessionData.AllowedCredentialIDs) - assert.Equal(t, []byte("extension"), got.SessionData.Extensions["bytes"]) - assert.Equal(t, []byte("nested"), got.SessionData.Extensions["nested"].(map[string]any)["bytes"]) + return session.NewMemoryStore() + }) } func TestSessionHelpersCloneBeginResults(t *testing.T) { @@ -148,17 +41,6 @@ func TestSessionHelpersCloneBeginResults(t *testing.T) { assert.Equal(t, [][]byte{[]byte("credential-id")}, login.SessionData.AllowedCredentialIDs) } -func TestMemoryStoreHonorsCanceledContext(t *testing.T) { - store := session.NewMemoryStore(session.WithClock(fixedTime)) - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - _, err := store.PutLogin(ctx, testLogin(fixedTime().Add(time.Minute))) - require.ErrorIs(t, err, context.Canceled) - _, err = store.TakeLogin(ctx, "session-id") - require.ErrorIs(t, err, context.Canceled) -} - func testRegistration(expires time.Time) session.Registration { return session.Registration{ User: passkey.User{ @@ -172,12 +54,6 @@ func testRegistration(expires time.Time) session.Registration { } } -func testLogin(expires time.Time) session.Login { - return session.Login{ - SessionData: testSessionData(expires), - } -} - func testSessionData(expires time.Time) webauthn.SessionData { return webauthn.SessionData{ Challenge: "challenge", diff --git a/store/postgres/doc.go b/store/postgres/doc.go index 682c8fb..8a44d2a 100644 --- a/store/postgres/doc.go +++ b/store/postgres/doc.go @@ -1,7 +1,8 @@ // Package postgres provides PostgreSQL storage for authkit. // // Store persists principals, local roles, provisioning rules, identity links, -// API tokens, and OIDC provider trust. Applications own connection pool -// configuration and must call Migrate explicitly before constructing a Store. -// NewStore only validates and wraps the supplied pool. +// API tokens, OIDC provider trust, passkeys, and temporary passkey ceremony +// sessions. Applications own connection pool configuration and must call +// Migrate explicitly before constructing a Store. NewStore only validates and +// wraps the supplied pool. package postgres diff --git a/store/postgres/migrations/000007_passkey_sessions.sql b/store/postgres/migrations/000007_passkey_sessions.sql new file mode 100644 index 0000000..bc32af0 --- /dev/null +++ b/store/postgres/migrations/000007_passkey_sessions.sql @@ -0,0 +1,14 @@ +CREATE TABLE IF NOT EXISTS authkit_passkey_sessions ( + id text PRIMARY KEY, + kind text NOT NULL CHECK (kind IN ('registration', 'login')), + payload jsonb NOT NULL CHECK (jsonb_typeof(payload) = 'object'), + expires_at timestamptz NOT NULL, + created_at timestamptz NOT NULL DEFAULT now() +); + +CREATE INDEX IF NOT EXISTS authkit_passkey_sessions_expires_at_idx + ON authkit_passkey_sessions (expires_at); + +INSERT INTO authkit_schema_migrations (version, name) +VALUES (7, '000007_passkey_sessions') +ON CONFLICT (version) DO NOTHING; diff --git a/store/postgres/passkey_session.go b/store/postgres/passkey_session.go new file mode 100644 index 0000000..d31bd61 --- /dev/null +++ b/store/postgres/passkey_session.go @@ -0,0 +1,377 @@ +package postgres + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/go-webauthn/webauthn/webauthn" + "github.com/jackc/pgx/v5" + + "github.com/meigma/authkit/proof/passkey" + passkeysession "github.com/meigma/authkit/proof/passkey/session" +) + +const ( + passkeySessionIDBytes = 32 + passkeySessionIDAttempts = 4 +) + +type passkeySessionKind string + +const ( + passkeySessionKindRegistration passkeySessionKind = "registration" + passkeySessionKindLogin passkeySessionKind = "login" +) + +var _ passkeysession.Store = (*Store)(nil) + +// passkeySessionPayload is the JSONB record stored for one temporary WebAuthn +// ceremony. The WebAuthn session itself is stored as upstream MessagePack +// bytes so extension values are not reshaped by JSON decoding. +type passkeySessionPayload struct { + Registration *passkeyRegistrationSessionPayload `json:"registration,omitempty"` + Login *passkeyLoginSessionPayload `json:"login,omitempty"` +} + +// passkeyRegistrationSessionPayload is the tagged JSON shape for registration +// ceremony state. +type passkeyRegistrationSessionPayload struct { + User passkeyUserSessionPayload `json:"user"` + SessionData []byte `json:"session_data"` +} + +// passkeyLoginSessionPayload is the tagged JSON shape for login ceremony state. +type passkeyLoginSessionPayload struct { + SessionData []byte `json:"session_data"` +} + +// passkeyUserSessionPayload is the tagged JSON shape for a passkey user. +type passkeyUserSessionPayload struct { + RPID string `json:"rp_id"` + PrincipalID string `json:"principal_id"` + Handle []byte `json:"handle"` + Name string `json:"name"` + DisplayName string `json:"display_name"` +} + +// storedPasskeySession is the row shape returned by a one-time delete. +type storedPasskeySession struct { + kind passkeySessionKind + payload string + expired bool +} + +// PutRegistration stores registration and returns an opaque session ID. +func (s *Store) PutRegistration(ctx context.Context, registration passkeysession.Registration) (string, error) { + if err := ctx.Err(); err != nil { + return "", err + } + if registration.SessionData.Expires.IsZero() { + return "", passkeysession.ErrExpired + } + + payload, err := passkeySessionPayloadFromRegistration(registration) + if err != nil { + return "", err + } + + return s.putPasskeySession(ctx, passkeySessionKindRegistration, payload, registration.SessionData.Expires) +} + +// TakeRegistration returns and removes the registration session identified by id. +func (s *Store) TakeRegistration(ctx context.Context, id string) (passkeysession.Registration, error) { + if err := ctx.Err(); err != nil { + return passkeysession.Registration{}, err + } + + stored, err := s.takePasskeySession(ctx, id) + if err != nil { + return passkeysession.Registration{}, err + } + if stored.kind != passkeySessionKindRegistration { + return passkeysession.Registration{}, passkeysession.ErrNotFound + } + if stored.expired { + return passkeysession.Registration{}, passkeysession.ErrExpired + } + + payload, err := decodePasskeySessionPayload(stored.payload) + if err != nil { + return passkeysession.Registration{}, err + } + if payload.Registration == nil { + return passkeysession.Registration{}, errors.New("postgres: passkey registration session payload is empty") + } + + registration, err := payload.Registration.registration() + if err != nil { + return passkeysession.Registration{}, err + } + + return registration, nil +} + +// PutLogin stores login and returns an opaque session ID. +func (s *Store) PutLogin(ctx context.Context, login passkeysession.Login) (string, error) { + if err := ctx.Err(); err != nil { + return "", err + } + if login.SessionData.Expires.IsZero() { + return "", passkeysession.ErrExpired + } + + payload, err := passkeySessionPayloadFromLogin(login) + if err != nil { + return "", err + } + + return s.putPasskeySession(ctx, passkeySessionKindLogin, payload, login.SessionData.Expires) +} + +// TakeLogin returns and removes the login session identified by id. +func (s *Store) TakeLogin(ctx context.Context, id string) (passkeysession.Login, error) { + if err := ctx.Err(); err != nil { + return passkeysession.Login{}, err + } + + stored, err := s.takePasskeySession(ctx, id) + if err != nil { + return passkeysession.Login{}, err + } + if stored.kind != passkeySessionKindLogin { + return passkeysession.Login{}, passkeysession.ErrNotFound + } + if stored.expired { + return passkeysession.Login{}, passkeysession.ErrExpired + } + + payload, err := decodePasskeySessionPayload(stored.payload) + if err != nil { + return passkeysession.Login{}, err + } + if payload.Login == nil { + return passkeysession.Login{}, errors.New("postgres: passkey login session payload is empty") + } + + login, err := payload.Login.login() + if err != nil { + return passkeysession.Login{}, err + } + + return login, nil +} + +// passkeySessionPayloadFromRegistration returns the JSON DTO for registration. +func passkeySessionPayloadFromRegistration( + registration passkeysession.Registration, +) (passkeySessionPayload, error) { + sessionData, err := encodeWebAuthnSessionData(registration.SessionData) + if err != nil { + return passkeySessionPayload{}, err + } + payload := passkeyRegistrationSessionPayload{ + User: passkeyUserSessionPayloadFromUser(registration.User), + SessionData: sessionData, + } + + return passkeySessionPayload{Registration: &payload}, nil +} + +// passkeySessionPayloadFromLogin returns the JSON DTO for login. +func passkeySessionPayloadFromLogin(login passkeysession.Login) (passkeySessionPayload, error) { + sessionData, err := encodeWebAuthnSessionData(login.SessionData) + if err != nil { + return passkeySessionPayload{}, err + } + payload := passkeyLoginSessionPayload{SessionData: sessionData} + + return passkeySessionPayload{Login: &payload}, nil +} + +// passkeyUserSessionPayloadFromUser returns the JSON DTO for user. +func passkeyUserSessionPayloadFromUser(user passkey.User) passkeyUserSessionPayload { + return passkeyUserSessionPayload{ + RPID: user.RPID, + PrincipalID: user.PrincipalID, + Handle: user.Handle, + Name: user.Name, + DisplayName: user.DisplayName, + } +} + +// registration returns the public session value from payload. +func (p passkeyRegistrationSessionPayload) registration() (passkeysession.Registration, error) { + sessionData, err := decodeWebAuthnSessionData(p.SessionData) + if err != nil { + return passkeysession.Registration{}, err + } + + return passkeysession.Registration{ + User: p.User.user(), + SessionData: sessionData, + }, nil +} + +// login returns the public session value from payload. +func (p passkeyLoginSessionPayload) login() (passkeysession.Login, error) { + sessionData, err := decodeWebAuthnSessionData(p.SessionData) + if err != nil { + return passkeysession.Login{}, err + } + + return passkeysession.Login{ + SessionData: sessionData, + }, nil +} + +// user returns the public passkey user value from payload. +func (p passkeyUserSessionPayload) user() passkey.User { + return passkey.User{ + RPID: p.RPID, + PrincipalID: p.PrincipalID, + Handle: p.Handle, + Name: p.Name, + DisplayName: p.DisplayName, + } +} + +// putPasskeySession stores payload under a fresh opaque ID and prunes expired +// sessions first so abandoned ceremonies do not accumulate indefinitely. +func (s *Store) putPasskeySession( + ctx context.Context, + kind passkeySessionKind, + payload passkeySessionPayload, + expiresAt time.Time, +) (string, error) { + if err := pruneExpiredPasskeySessions(ctx, s.pool); err != nil { + return "", err + } + + encodedPayload, err := encodePasskeySessionPayload(payload) + if err != nil { + return "", err + } + + for range passkeySessionIDAttempts { + id, idErr := newPasskeySessionID() + if idErr != nil { + return "", idErr + } + + _, err = s.pool.Exec( + ctx, + `insert into authkit_passkey_sessions (id, kind, payload, expires_at) + values ($1, $2, $3::jsonb, $4)`, + id, + string(kind), + encodedPayload, + expiresAt, + ) + if err == nil { + return id, nil + } + if isPostgresCode(err, uniqueViolation) { + continue + } + + return "", fmt.Errorf("postgres: store passkey session: %w", err) + } + + return "", errors.New("postgres: generate unique passkey session ID") +} + +// takePasskeySession deletes and returns a passkey session row in one +// statement. The delete happens before callers check kind or expiration so +// wrong-flow and expired redemptions consume the session ID. +func (s *Store) takePasskeySession(ctx context.Context, id string) (storedPasskeySession, error) { + var stored storedPasskeySession + var kind string + err := s.pool.QueryRow( + ctx, + `delete from authkit_passkey_sessions + where id = $1 + returning kind, payload::text, expires_at <= now()`, + id, + ).Scan(&kind, &stored.payload, &stored.expired) + if errors.Is(err, pgx.ErrNoRows) { + return storedPasskeySession{}, passkeysession.ErrNotFound + } + if err != nil { + return storedPasskeySession{}, fmt.Errorf("postgres: take passkey session: %w", err) + } + + stored.kind = passkeySessionKind(kind) + + return stored, nil +} + +// pruneExpiredPasskeySessions removes sessions whose ceremony window has +// closed. It is called opportunistically on Put; no background worker is +// required for correctness. +func pruneExpiredPasskeySessions(ctx context.Context, exec sqlExecutor) error { + if _, err := exec.Exec(ctx, `delete from authkit_passkey_sessions where expires_at <= now()`); err != nil { + return fmt.Errorf("postgres: prune expired passkey sessions: %w", err) + } + + return nil +} + +// encodePasskeySessionPayload returns the JSON encoding of payload for +// storage as a JSONB column. +func encodePasskeySessionPayload(payload passkeySessionPayload) (string, error) { + encoded, err := json.Marshal(payload) + if err != nil { + return "", fmt.Errorf("postgres: encode passkey session: %w", err) + } + + return string(encoded), nil +} + +// decodePasskeySessionPayload parses a JSONB-encoded passkey session payload. +func decodePasskeySessionPayload(encoded string) (passkeySessionPayload, error) { + var payload passkeySessionPayload + if err := json.Unmarshal([]byte(encoded), &payload); err != nil { + return passkeySessionPayload{}, fmt.Errorf("postgres: decode passkey session: %w", err) + } + + return payload, nil +} + +// encodeWebAuthnSessionData returns the upstream MessagePack encoding of data. +func encodeWebAuthnSessionData(data webauthn.SessionData) ([]byte, error) { + encoded, err := data.MarshalMsg(nil) + if err != nil { + return nil, fmt.Errorf("postgres: encode WebAuthn session data: %w", err) + } + + return encoded, nil +} + +// decodeWebAuthnSessionData parses upstream MessagePack-encoded session data. +func decodeWebAuthnSessionData(encoded []byte) (webauthn.SessionData, error) { + var data webauthn.SessionData + left, err := data.UnmarshalMsg(encoded) + if err != nil { + return webauthn.SessionData{}, fmt.Errorf("postgres: decode WebAuthn session data: %w", err) + } + if len(left) != 0 { + return webauthn.SessionData{}, errors.New("postgres: decode WebAuthn session data: trailing bytes") + } + + return data, nil +} + +// newPasskeySessionID returns a cryptographically-random base64url session ID. +func newPasskeySessionID() (string, error) { + raw := make([]byte, passkeySessionIDBytes) + if _, err := rand.Read(raw); err != nil { + return "", fmt.Errorf("postgres: generate passkey session ID: %w", err) + } + + return base64.RawURLEncoding.EncodeToString(raw), nil +} diff --git a/store/postgres/store.go b/store/postgres/store.go index 4ebeac4..9f55992 100644 --- a/store/postgres/store.go +++ b/store/postgres/store.go @@ -17,10 +17,11 @@ const ( ) // Store persists authkit principals, roles, provisioning rules, identity -// links, API tokens, and OIDC provider trust in PostgreSQL. A single Store -// implements every authkit storage port (see `internal/storetest.Store`); -// domain methods are split across sibling files in this package by domain -// (principal, role, provisioning, identity, oidc, token, passkey). +// links, API tokens, OIDC provider trust, passkeys, and temporary passkey +// ceremony sessions in PostgreSQL. A single Store implements every authkit +// storage port (see `internal/storetest.Store`); domain methods are split +// across sibling files in this package by domain (principal, role, +// provisioning, identity, oidc, token, passkey). // // All exported methods are safe for concurrent use. Multi-write operations // (CreateProvisioningRule, UpdateProvisioningRule, ProvisionIdentity, and diff --git a/store/postgres/store_integration_test.go b/store/postgres/store_integration_test.go index c60edfa..a64b1d2 100644 --- a/store/postgres/store_integration_test.go +++ b/store/postgres/store_integration_test.go @@ -17,7 +17,9 @@ import ( "github.com/testcontainers/testcontainers-go/wait" "github.com/meigma/authkit" + "github.com/meigma/authkit/internal/passkey/sessiontest" "github.com/meigma/authkit/internal/storetest" + passkeysession "github.com/meigma/authkit/proof/passkey/session" ) func TestSharedStoreBehavior(t *testing.T) { @@ -84,6 +86,22 @@ func TestProvisionIdentityConcurrentCallsLeaveOnePrincipal(t *testing.T) { assert.Equal(t, 1, links) } +func TestPasskeySessionStoreBehavior(t *testing.T) { + ctx := context.Background() + pool := newPostgresPool(t) + require.NoError(t, Migrate(ctx, pool)) + + sessiontest.Run(t, func(t *testing.T) passkeysession.Store { + t.Helper() + resetStore(t, pool) + + store, err := NewStore(pool) + require.NoError(t, err) + + return store + }) +} + func TestMigrateCreatesSchema(t *testing.T) { ctx := context.Background() pool := newPostgresPool(t) @@ -104,6 +122,7 @@ func TestMigrateCreatesSchema(t *testing.T) { "authkit_provisioning_rule_roles", "authkit_passkey_users", "authkit_passkey_credentials", + "authkit_passkey_sessions", } { t.Run(table, func(t *testing.T) { var exists bool @@ -122,10 +141,10 @@ func TestMigrateCreatesSchema(t *testing.T) { } var migrationRows int - err := pool.QueryRow(ctx, `select count(*) from authkit_schema_migrations where version in (1, 2, 3, 4, 5, 6)`). + err := pool.QueryRow(ctx, `select count(*) from authkit_schema_migrations where version in (1, 2, 3, 4, 5, 6, 7)`). Scan(&migrationRows) require.NoError(t, err) - assert.Equal(t, 6, migrationRows) + assert.Equal(t, 7, migrationRows) } func TestMigrateConcurrentCalls(t *testing.T) { @@ -150,10 +169,10 @@ func TestMigrateConcurrentCalls(t *testing.T) { } var migrationRows int - err := pool.QueryRow(ctx, `select count(*) from authkit_schema_migrations where version in (1, 2, 3, 4, 5, 6)`). + err := pool.QueryRow(ctx, `select count(*) from authkit_schema_migrations where version in (1, 2, 3, 4, 5, 6, 7)`). Scan(&migrationRows) require.NoError(t, err) - assert.Equal(t, 6, migrationRows) + assert.Equal(t, 7, migrationRows) } func newPostgresPool(t *testing.T) *pgxpool.Pool { @@ -191,6 +210,7 @@ func resetStore(t *testing.T, pool *pgxpool.Pool) { _, err := pool.Exec( context.Background(), `truncate table + authkit_passkey_sessions, authkit_passkey_credentials, authkit_passkey_users, authkit_provisioning_rule_roles, diff --git a/testkit/README.md b/testkit/README.md index 553db19..624b22f 100644 --- a/testkit/README.md +++ b/testkit/README.md @@ -85,7 +85,7 @@ routes. ## Persistence -By default, testkit stores pastes in process memory. Restarting the server clears them. +By default, testkit stores pastes and passkey ceremony sessions in process memory. Restarting the server clears them. Set `TESTKIT_DATABASE_URL` to use PostgreSQL paste persistence instead: @@ -94,7 +94,7 @@ TESTKIT_DATABASE_URL='postgres://testkit:testkit@localhost:5432/testkit?sslmode= go run ./testkit/cmd/testkit ``` -When `TESTKIT_DATABASE_URL` is set, startup opens a Postgres pool, runs testkit's `testkit_*` paste migrations, runs authkit's Postgres migrations, stores paste data in `testkit_*` tables, and stores authkit principals/API tokens in `authkit_*` tables. +When `TESTKIT_DATABASE_URL` is set, startup opens a Postgres pool, runs testkit's `testkit_*` paste migrations, runs authkit's Postgres migrations, stores paste data in `testkit_*` tables, and stores authkit principals/API tokens/passkey ceremony sessions in `authkit_*` tables. Without `TESTKIT_DATABASE_URL`, both paste data and authkit state are in memory. diff --git a/testkit/cmd/testkit/main.go b/testkit/cmd/testkit/main.go index 6626698..d95b3e6 100644 --- a/testkit/cmd/testkit/main.go +++ b/testkit/cmd/testkit/main.go @@ -13,6 +13,7 @@ import ( "github.com/jackc/pgx/v5/pgxpool" + passkeysession "github.com/meigma/authkit/proof/passkey/session" authmemory "github.com/meigma/authkit/store/memory" authpostgres "github.com/meigma/authkit/store/postgres" "github.com/meigma/authkit/testkit/internal/authflow" @@ -72,7 +73,11 @@ func run(ctx context.Context, out io.Writer) error { if err != nil { return err } - handler, err := httpui.NewServer(pasteService, authRuntime) + handler, err := httpui.NewServer( + pasteService, + authRuntime, + httpui.WithPasskeySessionStore(stores.passkeySessions), + ) if err != nil { return err } @@ -112,8 +117,9 @@ func run(ctx context.Context, out io.Writer) error { } type stores struct { - pastes paste.Repository - auth authflow.Store + pastes paste.Repository + auth authflow.Store + passkeySessions passkeysession.Store } func newStores(ctx context.Context) (stores, func(), error) { @@ -125,8 +131,9 @@ func newStores(ctx context.Context) (stores, func(), error) { } return stores{ - pastes: testkitmemory.NewStore(), - auth: authStore, + pastes: testkitmemory.NewStore(), + auth: authStore, + passkeySessions: passkeysession.NewMemoryStore(), }, func() {}, nil } @@ -164,7 +171,8 @@ func newStores(ctx context.Context) (stores, func(), error) { } return stores{ - pastes: pasteStore, - auth: authStore, + pastes: pasteStore, + auth: authStore, + passkeySessions: authStore, }, pool.Close, nil } diff --git a/testkit/cmd/testkit/main_test.go b/testkit/cmd/testkit/main_test.go index 8d47350..962e03e 100644 --- a/testkit/cmd/testkit/main_test.go +++ b/testkit/cmd/testkit/main_test.go @@ -18,6 +18,7 @@ import ( "github.com/meigma/authkit" "github.com/meigma/authkit/proof/oidc" + passkeysession "github.com/meigma/authkit/proof/passkey/session" "github.com/meigma/authkit/testkit/internal/authflow" ) @@ -256,6 +257,17 @@ func TestNewStoresTrustsConfiguredOIDCProvider(t *testing.T) { assert.Equal(t, result.Principal.ID, result.AccessToken.PrincipalID) } +func TestNewStoresUsesMemoryPasskeySessionStoreByDefault(t *testing.T) { + ctx := context.Background() + + stores, cleanup, err := newStores(ctx) + require.NoError(t, err) + t.Cleanup(cleanup) + + _, ok := stores.passkeySessions.(*passkeysession.MemoryStore) + assert.True(t, ok) +} + func mapGetenv(values map[string]string) func(string) string { return func(key string) string { return values[key] diff --git a/testkit/internal/httpui/server.go b/testkit/internal/httpui/server.go index 6cf71c3..2d23411 100644 --- a/testkit/internal/httpui/server.go +++ b/testkit/internal/httpui/server.go @@ -32,6 +32,24 @@ type Server struct { templates *templateSet } +type options struct { + passkeySessions passkeysession.Store +} + +// Option configures a Server. +type Option func(*options) + +// WithPasskeySessionStore configures the store used for temporary passkey +// registration and login ceremony sessions. Nil leaves the default in-memory +// store in place. +func WithPasskeySessionStore(store passkeysession.Store) Option { + return func(opts *options) { + if store != nil { + opts.passkeySessions = store + } + } +} + type authRuntime interface { ExchangeAPIToken(ctx context.Context, token string) (exchange.APITokenResult, error) ExchangeOIDCToken(ctx context.Context, plaintext string) (exchange.IdentityResult, error) @@ -58,15 +76,15 @@ type authRuntime interface { } // NewServer constructs a testkit HTTP UI server. -func NewServer(pastes *paste.Service, auth *authflow.Runtime) (*Server, error) { +func NewServer(pastes *paste.Service, auth *authflow.Runtime, opts ...Option) (*Server, error) { if auth == nil { return nil, errors.New("httpui: auth runtime is required") } - return newServer(pastes, auth) + return newServer(pastes, auth, opts...) } -func newServer(pastes *paste.Service, auth authRuntime) (*Server, error) { +func newServer(pastes *paste.Service, auth authRuntime, opts ...Option) (*Server, error) { if pastes == nil { return nil, errors.New("httpui: paste service is required") } @@ -84,10 +102,17 @@ func newServer(pastes *paste.Service, auth authRuntime) (*Server, error) { return nil, fmt.Errorf("httpui: prepare static assets: %w", err) } + cfg := defaultOptions() + for _, opt := range opts { + if opt != nil { + opt(&cfg) + } + } + server := &Server{ auth: auth, csrf: newCSRFProtector(), - passkeySessions: passkeysession.NewMemoryStore(), + passkeySessions: cfg.passkeySessions, pastes: pastes, templates: templates, } @@ -121,6 +146,12 @@ func newServer(pastes *paste.Service, auth authRuntime) (*Server, error) { return server, nil } +func defaultOptions() options { + return options{ + passkeySessions: passkeysession.NewMemoryStore(), + } +} + // ServeHTTP serves an HTTP request. func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { s.handler.ServeHTTP(w, req) diff --git a/testkit/internal/httpui/server_test.go b/testkit/internal/httpui/server_test.go index f29882a..373966a 100644 --- a/testkit/internal/httpui/server_test.go +++ b/testkit/internal/httpui/server_test.go @@ -78,6 +78,20 @@ func TestServerRendersPublicPages(t *testing.T) { assert.NotEmpty(t, findCookie(t, loginRecorder, csrfCookieName).Value) } +func TestServerUsesDefaultPasskeySessionStore(t *testing.T) { + server := newFakeAuthServer(t, &fakeAuthRuntime{}) + + _, ok := server.passkeySessions.(*passkeysession.MemoryStore) + assert.True(t, ok) +} + +func TestServerUsesConfiguredPasskeySessionStore(t *testing.T) { + store := passkeysession.NewMemoryStore() + server := newFakeAuthServer(t, &fakeAuthRuntime{}, WithPasskeySessionStore(store)) + + assert.Same(t, store, server.passkeySessions) +} + func TestServerRequiresAuthenticationForPasteCreation(t *testing.T) { server := newTestServer(t, testPasteID) @@ -1005,7 +1019,7 @@ func newTestServerWithAuth( } } -func newFakeAuthServer(t *testing.T, auth *fakeAuthRuntime) *Server { +func newFakeAuthServer(t *testing.T, auth *fakeAuthRuntime, opts ...Option) *Server { t.Helper() sequence := sequentialIDs(testPasteID) @@ -1015,7 +1029,7 @@ func newFakeAuthServer(t *testing.T, auth *fakeAuthRuntime) *Server { paste.WithClock(fixedTime), ) require.NoError(t, err) - server, err := newServer(service, auth) + server, err := newServer(service, auth, opts...) require.NoError(t, err) return server