From 2653e816bbcbbeca954b1200700d0e2adcceb8e1 Mon Sep 17 00:00:00 2001 From: ekzyis Date: Fri, 12 Jul 2024 13:26:27 +0200 Subject: [PATCH] Implement LNURL-auth --- db/db.go | 4 +- db/init.sql | 80 +------ db/schema.sql | 79 +++++++ docker-compose.yml | 1 + public/css/_tw-input.css | 10 +- server/auth/lnauth.go | 68 +++--- server/router/context/context.go | 2 + server/router/handler/auth.go | 191 ++++++++++++++++ server/router/handler/lnauth_test.go | 264 +++++++++++++++++++++-- server/router/handler/login.go | 13 -- server/router/pages/auth.templ | 66 ++++++ server/router/pages/components/nav.templ | 2 +- server/router/pages/lnAuth.templ | 33 +++ server/router/pages/login.templ | 40 +++- server/router/router.go | 8 +- 15 files changed, 718 insertions(+), 143 deletions(-) create mode 100644 db/schema.sql create mode 100644 server/router/handler/auth.go delete mode 100644 server/router/handler/login.go create mode 100644 server/router/pages/auth.templ create mode 100644 server/router/pages/lnAuth.templ diff --git a/db/db.go b/db/db.go index 31429ab..b9732ae 100644 --- a/db/db.go +++ b/db/db.go @@ -13,7 +13,7 @@ type DB struct { } var ( - initSqlPath = "./db/init.sql" + schemaPath = "./db/schema.sql" ) func New(dbUrl string) (*DB, error) { @@ -42,7 +42,7 @@ func (db *DB) Reset(dbName string) error { if err = db.Clear(dbName); err != nil { return err } - if f, err = ioutil.ReadFile(initSqlPath); err != nil { + if f, err = ioutil.ReadFile(schemaPath); err != nil { return err } if _, err = db.Exec(string(f)); err != nil { diff --git a/db/init.sql b/db/init.sql index fe05bb0..eac30ae 100644 --- a/db/init.sql +++ b/db/init.sql @@ -1,79 +1 @@ -CREATE TABLE users( - id SERIAL PRIMARY KEY, - created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, - ln_pubkey TEXT, - nostr_pubkey TEXT, - msats BIGINT NOT NULL DEFAULT 0 -); - -CREATE TABLE sessions( - user_id INTEGER NOT NULL REFERENCES users(id), - created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, - session_id VARCHAR(48) -); - -CREATE TABLE lnauth( - k1 VARCHAR(64) PRIMARY KEY, - created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, - lnurl TEXT NOT NULL, - session_id VARCHAR(48) NOT NULL DEFAULT encode(gen_random_uuid()::text::bytea, 'base64') -); - -CREATE TABLE invoices( - id SERIAL PRIMARY KEY, - user_id INTEGER NOT NULL REFERENCES users(id), - msats BIGINT NOT NULL, - msats_received BIGINT, - preimage TEXT NOT NULL UNIQUE, - hash TEXT NOT NULL UNIQUE, - bolt11 TEXT NOT NULL, - created_at TIMESTAMP WITH TIME ZONE NOT NULL, - expires_at TIMESTAMP WITH TIME ZONE NOT NULL, - confirmed_at TIMESTAMP WITH TIME ZONE, - held_since TIMESTAMP WITH TIME ZONE, - description TEXT -); - -CREATE TABLE markets( - id SERIAL PRIMARY KEY, - created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, - description TEXT NOT NULL, - end_date TIMESTAMP WITH TIME ZONE NOT NULL, - settled_at TIMESTAMP WITH TIME ZONE, - user_id INTEGER NOT NULL REFERENCES users(id), - invoice_id INTEGER NOT NULL UNIQUE REFERENCES invoices(id) -); - -CREATE TABLE shares( - id SERIAL PRIMARY KEY, - market_id INTEGER NOT NULL REFERENCES markets(id), - description TEXT NOT NULL, - win BOOLEAN -); - -CREATE TYPE order_side AS ENUM ('BUY', 'SELL'); - -CREATE TABLE orders( - id SERIAL PRIMARY KEY, - created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, - deleted_at TIMESTAMP WITH TIME ZONE, - share_id INTEGER NOT NULL REFERENCES shares(id), - user_id INTEGER NOT NULL REFERENCES users(id), - side ORDER_SIDE NOT NULL, - quantity BIGINT NOT NULL, - price BIGINT NOT NULL, - invoice_id INTEGER REFERENCES invoices(id), - order_id INTEGER REFERENCES orders(id) -); - -ALTER TABLE orders ADD CONSTRAINT order_price CHECK(price > 0 AND price < 100); -ALTER TABLE orders ADD CONSTRAINT order_quantity CHECK(quantity > 0); - -CREATE TABLE withdrawals( - id SERIAL PRIMARY KEY, - created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, - canceled_at TIMESTAMP WITH TIME ZONE, - user_id INTEGER NOT NULL REFERENCES users(id), - bolt11 TEXT NOT NULL UNIQUE, - paid_at TIMESTAMP WITH TIME ZONE -); +CREATE DATABASE "delphi_test"; diff --git a/db/schema.sql b/db/schema.sql new file mode 100644 index 0000000..e2f2e31 --- /dev/null +++ b/db/schema.sql @@ -0,0 +1,79 @@ +CREATE TABLE users( + id SERIAL PRIMARY KEY, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + ln_pubkey TEXT UNIQUE, + nostr_pubkey TEXT UNIQUE, + msats BIGINT NOT NULL DEFAULT 0 +); + +CREATE TABLE sessions( + id VARCHAR(48) PRIMARY KEY, + user_id INTEGER NOT NULL REFERENCES users(id), + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE lnauth( + k1 VARCHAR(64) PRIMARY KEY, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + lnurl TEXT NOT NULL, + session_id VARCHAR(48) NOT NULL DEFAULT encode(gen_random_uuid()::text::bytea, 'base64') +); + +CREATE TABLE invoices( + id SERIAL PRIMARY KEY, + user_id INTEGER NOT NULL REFERENCES users(id), + msats BIGINT NOT NULL, + msats_received BIGINT, + preimage TEXT NOT NULL UNIQUE, + hash TEXT NOT NULL UNIQUE, + bolt11 TEXT NOT NULL, + created_at TIMESTAMP WITH TIME ZONE NOT NULL, + expires_at TIMESTAMP WITH TIME ZONE NOT NULL, + confirmed_at TIMESTAMP WITH TIME ZONE, + held_since TIMESTAMP WITH TIME ZONE, + description TEXT +); + +CREATE TABLE markets( + id SERIAL PRIMARY KEY, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + description TEXT NOT NULL, + end_date TIMESTAMP WITH TIME ZONE NOT NULL, + settled_at TIMESTAMP WITH TIME ZONE, + user_id INTEGER NOT NULL REFERENCES users(id), + invoice_id INTEGER NOT NULL UNIQUE REFERENCES invoices(id) +); + +CREATE TABLE shares( + id SERIAL PRIMARY KEY, + market_id INTEGER NOT NULL REFERENCES markets(id), + description TEXT NOT NULL, + win BOOLEAN +); + +CREATE TYPE order_side AS ENUM ('BUY', 'SELL'); + +CREATE TABLE orders( + id SERIAL PRIMARY KEY, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + deleted_at TIMESTAMP WITH TIME ZONE, + share_id INTEGER NOT NULL REFERENCES shares(id), + user_id INTEGER NOT NULL REFERENCES users(id), + side ORDER_SIDE NOT NULL, + quantity BIGINT NOT NULL, + price BIGINT NOT NULL, + invoice_id INTEGER REFERENCES invoices(id), + order_id INTEGER REFERENCES orders(id) +); + +ALTER TABLE orders ADD CONSTRAINT order_price CHECK(price > 0 AND price < 100); +ALTER TABLE orders ADD CONSTRAINT order_quantity CHECK(quantity > 0); + +CREATE TABLE withdrawals( + id SERIAL PRIMARY KEY, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + canceled_at TIMESTAMP WITH TIME ZONE, + user_id INTEGER NOT NULL REFERENCES users(id), + bolt11 TEXT NOT NULL UNIQUE, + paid_at TIMESTAMP WITH TIME ZONE +); diff --git a/docker-compose.yml b/docker-compose.yml index be365d7..44880dc 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -13,6 +13,7 @@ services: volumes: - delphi:/var/lib/postgresql/data - ./db/init.sql:/docker-entrypoint-initdb.d/init.sql + - ./db/schema.sql:/docker-entrypoint-initdb.d/schema.sql - ./postgresql.conf:/var/lib/postgresql/data/postgresql.conf # for some reason this can't be mounted on first run volumes: diff --git a/public/css/_tw-input.css b/public/css/_tw-input.css index 2e0291c..a9ba781 100644 --- a/public/css/_tw-input.css +++ b/public/css/_tw-input.css @@ -29,14 +29,14 @@ @apply pb-1; } - a, + a:not(.no-link), button[hx-get], button[hx-post] { text-decoration: underline; transition: background-color 150ms ease-in, color 150ms ease-in; } - a:hover, + a:not(.no-link):hover, button[hx-get]:hover, button[hx-post]:hover { background-color: var(--color); @@ -70,10 +70,12 @@ @apply my-3 } - .login { + .login, .signup { + text-decoration: none !important; + transition: none !important; + padding: 0.25em 1em !important; width: fit-content; margin: 0 auto; - padding: 0.25em 1em; border-radius: 5px; font-weight: bold; } diff --git a/server/auth/lnauth.go b/server/auth/lnauth.go index 43f190c..87ddf6c 100644 --- a/server/auth/lnauth.go +++ b/server/auth/lnauth.go @@ -8,58 +8,74 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcutil/bech32" + "github.com/decred/dcrd/dcrec/secp256k1/v4" "git.ekzyis.com/ekzyis/delphi.market/env" ) -type LNAuth struct { +type LnAuth struct { K1 string LNURL string } -type LNAuthResponse struct { - K1 string `query:"k1"` - Sig string `query:"sig"` - Key string `query:"key"` +type LnAuthCallback struct { + K1 string `query:"k1"` + Sig string `query:"sig"` + Key string `query:"key"` + Action string `query:"action"` } -func NewLNAuth() (*LNAuth, error) { - k1 := make([]byte, 32) - _, err := rand.Read(k1) - if err != nil { +func NewLnAuth(action string) (*LnAuth, error) { + var ( + k1 = make([]byte, 32) + k1hex string + url []byte + bech32Url []byte + lnurl string + err error + ) + + if _, err := rand.Read(k1); err != nil { return nil, fmt.Errorf("rand.Read error: %w", err) } - k1hex := hex.EncodeToString(k1) - url := []byte(fmt.Sprintf("https://%s/api/login/callback?tag=login&k1=%s&action=login", env.PublicURL, k1hex)) - conv, err := bech32.ConvertBits(url, 8, 5, true) - if err != nil { + + k1hex = hex.EncodeToString(k1) + url = []byte(fmt.Sprintf("https://%s/api/lnauth/callback?tag=login&k1=%s&action=%s", env.PublicURL, k1hex, action)) + + if bech32Url, err = bech32.ConvertBits(url, 8, 5, true); err != nil { return nil, fmt.Errorf("bech32.ConvertBits error: %w", err) } - lnurl, err := bech32.Encode("lnurl", conv) - if err != nil { + + if lnurl, err = bech32.Encode("lnurl", bech32Url); err != nil { return nil, fmt.Errorf("bech32.Encode error: %w", err) } - return &LNAuth{k1hex, lnurl}, nil + + return &LnAuth{k1hex, lnurl}, nil } -func VerifyLNAuth(r *LNAuthResponse) (bool, error) { - var k1Bytes, sigBytes, keyBytes []byte - k1Bytes, err := hex.DecodeString(r.K1) - if err != nil { +func VerifyLNAuth(r *LnAuthCallback) (bool, error) { + var ( + k1Bytes, sigBytes, keyBytes []byte + key *secp256k1.PublicKey + err error + ) + + if k1Bytes, err = hex.DecodeString(r.K1); err != nil { return false, fmt.Errorf("k1 decode error: %w", err) } - sigBytes, err = hex.DecodeString(r.Sig) - if err != nil { + + if sigBytes, err = hex.DecodeString(r.Sig); err != nil { return false, fmt.Errorf("sig decode error: %w", err) } - keyBytes, err = hex.DecodeString(r.Key) - if err != nil { + + if keyBytes, err = hex.DecodeString(r.Key); err != nil { return false, fmt.Errorf("key decode error: %w", err) } - key, err := btcec.ParsePubKey(keyBytes) - if err != nil { + + if key, err = btcec.ParsePubKey(keyBytes); err != nil { return false, fmt.Errorf("key parse error: %w", err) } + ecdsaKey := ecdsa.PublicKey{Curve: btcec.S256(), X: key.X(), Y: key.Y()} return ecdsa.VerifyASN1(&ecdsaKey, k1Bytes, sigBytes), nil } diff --git a/server/router/context/context.go b/server/router/context/context.go index 4bf1d6d..445beda 100644 --- a/server/router/context/context.go +++ b/server/router/context/context.go @@ -25,6 +25,7 @@ var ( EnvContextKey RenderContextKey = "env" SessionContextKey RenderContextKey = "session" CommitContextKey RenderContextKey = "commit" + ReqPathContextKey RenderContextKey = "reqPath" ) func RenderContext(sc Context, c echo.Context) context.Context { @@ -32,5 +33,6 @@ func RenderContext(sc Context, c echo.Context) context.Context { ctx = context.WithValue(ctx, EnvContextKey, sc.Environment) ctx = context.WithValue(ctx, SessionContextKey, c.Get("session")) ctx = context.WithValue(ctx, CommitContextKey, sc.CommitShortSha) + ctx = context.WithValue(ctx, ReqPathContextKey, c.Request().URL.Path) return ctx } diff --git a/server/router/handler/auth.go b/server/router/handler/auth.go new file mode 100644 index 0000000..ad7bad2 --- /dev/null +++ b/server/router/handler/auth.go @@ -0,0 +1,191 @@ +package handler + +import ( + "database/sql" + "net/http" + "time" + + "git.ekzyis.com/ekzyis/delphi.market/lib" + "git.ekzyis.com/ekzyis/delphi.market/server/auth" + "git.ekzyis.com/ekzyis/delphi.market/server/router/context" + "git.ekzyis.com/ekzyis/delphi.market/server/router/pages" + "github.com/labstack/echo/v4" + "github.com/lib/pq" +) + +func HandleAuth(sc context.Context, action string) echo.HandlerFunc { + return func(c echo.Context) error { + + if c.Param("method") == "lightning" { + return LnAuth(sc, c, action) + } + + return pages.Auth(mapAction(action)).Render(context.RenderContext(sc, c), c.Response().Writer) + } +} + +func LnAuth(sc context.Context, c echo.Context, action string) error { + var ( + db = sc.Db + ctx = c.Request().Context() + lnAuth *auth.LnAuth + sessionId string + // sessions expire in 30 days. TODO: refresh sessions + expires = time.Now().Add(60 * 60 * 24 * 30 * time.Second) + qr string + err error + ) + + if lnAuth, err = auth.NewLnAuth(action); err != nil { + return err + } + + if err = db.QueryRowContext( + ctx, + "INSERT INTO lnauth(k1, lnurl) VALUES($1, $2) RETURNING session_id", + lnAuth.K1, lnAuth.LNURL).Scan(&sessionId); err != nil { + return err + } + + if qr, err = lib.ToQR(lnAuth.LNURL); err != nil { + return err + } + + c.SetCookie(&http.Cookie{ + Name: "session", + HttpOnly: true, + Path: "/", + Value: sessionId, + Secure: true, + Expires: expires, + }) + + return pages.LnAuth(qr, lnAuth.LNURL, mapAction(action)).Render(context.RenderContext(sc, c), c.Response().Writer) +} + +func HandleLnAuthCallback(sc context.Context) echo.HandlerFunc { + return func(c echo.Context) error { + var ( + db = sc.Db + tx *sql.Tx + ctx = c.Request().Context() + query auth.LnAuthCallback + sessionId string + userId int + ok bool + err error + pqErr *pq.Error + ) + + bail := func(code int, reason string) error { + c.JSON(code, map[string]string{"status": "ERROR", "reason": reason}) + return nil + } + + if err = c.Bind(&query); err != nil { + return bail(http.StatusInternalServerError, err.Error()) + } else if query.K1 == "" || query.Sig == "" || query.Key == "" || query.Action == "" { + return bail(http.StatusBadRequest, "bad query") + } + + if tx, err = db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted}); err != nil { + return bail(http.StatusInternalServerError, err.Error()) + } + + err = tx.QueryRow("SELECT session_id FROM lnauth WHERE k1 = $1 LIMIT 1", query.K1).Scan(&sessionId) + if err == sql.ErrNoRows { + tx.Rollback() + return bail(http.StatusNotFound, "session not found") + } else if err != nil { + tx.Rollback() + return bail(http.StatusInternalServerError, err.Error()) + } + + ok, err = auth.VerifyLNAuth(&query) + if err != nil { + tx.Rollback() + return bail(http.StatusInternalServerError, err.Error()) + } else if !ok { + tx.Rollback() + return bail(http.StatusBadRequest, "bad signature") + } + + if query.Action == "register" { + err = tx.QueryRow("INSERT INTO users(ln_pubkey) VALUES ($1) RETURNING id", query.Key).Scan(&userId) + if err != nil { + tx.Rollback() + pqErr, ok = err.(*pq.Error) + if ok && pqErr.Code == "23505" { + return bail(http.StatusBadRequest, "user already exists") + } + return bail(http.StatusInternalServerError, err.Error()) + } + } else if query.Action == "login" { + err = tx.QueryRow("SELECT id FROM users WHERE ln_pubkey = $1", query.Key).Scan(&userId) + if err == sql.ErrNoRows { + tx.Rollback() + return bail(http.StatusNotFound, "user not found") + } else if err != nil { + tx.Rollback() + return bail(http.StatusInternalServerError, err.Error()) + } + } else { + return bail(http.StatusBadRequest, "bad action") + } + + if _, err = tx.Exec("INSERT INTO sessions(id, user_id) VALUES($1, $2)", sessionId, userId); err != nil { + tx.Rollback() + return bail(http.StatusInternalServerError, err.Error()) + } + + if _, err = tx.Exec("DELETE FROM lnauth WHERE k1 = $1", query.K1); err != nil { + tx.Rollback() + return bail(http.StatusInternalServerError, err.Error()) + } + + if err = tx.Commit(); err != nil { + tx.Rollback() + return bail(http.StatusInternalServerError, err.Error()) + } + + return c.JSON(http.StatusOK, map[string]string{"status": "OK"}) + } +} + +func HandleSessionCheck(sc context.Context) echo.HandlerFunc { + return func(c echo.Context) error { + var ( + db = sc.Db + ctx = c.Request().Context() + cookie *http.Cookie + userId int + err error + ) + + if cookie, err = c.Cookie("session"); err != nil { + return c.JSON(http.StatusUnauthorized, "no session cookie") + } + + if err = db.QueryRowContext(ctx, + "SELECT user_id FROM sessions WHERE id = $1", cookie.Value). + Scan(&userId); err != nil { + return c.JSON(http.StatusNotFound, "session not found") + } + + c.Response().Header().Set("HX-Location", "/") + // htmx requires a 200 response to follow redirects + // see https://github.com/bigskysoftware/htmx/issues/2052 + return c.HTML(http.StatusOK, "/") + } +} + +func mapAction(action string) string { + // LNURL spec uses "register" but we want to show "signup" to the user + // see https://github.com/lnurl/luds/blob/luds/04.md + switch action { + case "register": + return "signup" + default: + return action + } +} diff --git a/server/router/handler/lnauth_test.go b/server/router/handler/lnauth_test.go index 2832bfb..0cb923b 100644 --- a/server/router/handler/lnauth_test.go +++ b/server/router/handler/lnauth_test.go @@ -1,13 +1,17 @@ package handler_test import ( + "encoding/hex" + "fmt" "net/http" "net/http/httptest" "testing" + "git.ekzyis.com/ekzyis/delphi.market/server/auth" "git.ekzyis.com/ekzyis/delphi.market/server/router/context" "git.ekzyis.com/ekzyis/delphi.market/server/router/handler" "git.ekzyis.com/ekzyis/delphi.market/test" + "github.com/decred/dcrd/dcrec/secp256k1/v4" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" ) @@ -16,12 +20,12 @@ func init() { test.Init(&db) } -func TestLnAuth(t *testing.T) { +func TestLnAuthSignup(t *testing.T) { var ( assert = assert.New(t) + sc = context.Context{Db: db} e *echo.Echo c echo.Context - sc context.Context req *http.Request rec *httptest.ResponseRecorder cookies []*http.Cookie @@ -29,29 +33,263 @@ func TestLnAuth(t *testing.T) { dbSessionId string err error ) + + e, req, rec = test.HTTPMocks("GET", "/signup/lightning", nil) + c = e.NewContext(req, rec) + c.SetParamNames("method") + c.SetParamValues("lightning") + + err = handler.HandleAuth(sc, "register")(c) + assert.NoErrorf(err, "handler returned error") + + // Set-Cookie header present + cookies = rec.Result().Cookies() + assert.Equalf(1, len(cookies), "wrong number of Set-Cookie headers") + assert.Equalf("session", cookies[0].Name, "wrong cookie name") + + // new challenge inserted which matches cookie value + sessionId = cookies[0].Value + err = db.QueryRow("SELECT session_id FROM lnauth WHERE session_id = $1", sessionId).Scan(&dbSessionId) + assert.NoError(err) + assert.Equalf(sessionId, dbSessionId, "wrong session id") +} + +func TestLnAuthSignupCallbackUserNotExists(t *testing.T) { + var ( + assert = assert.New(t) + e *echo.Echo + c echo.Context + sc context.Context + req *http.Request + rec *httptest.ResponseRecorder + lnAuth *auth.LnAuth + sk *secp256k1.PrivateKey + pk *secp256k1.PublicKey + sig string + key string + sessionId string + userId int + count int + err error + ) + + lnAuth, err = auth.NewLnAuth("register") + assert.NoErrorf(err, "error creating challenge") + + err = db.QueryRow( + "INSERT INTO lnauth(k1, lnurl) VALUES($1, $2) RETURNING session_id", + lnAuth.K1, lnAuth.LNURL).Scan(&sessionId) + assert.NoErrorf(err, "error creating challenge") + + sk, pk, err = test.GenerateKeyPair() + assert.NoErrorf(err, "error generating keypair") + + sig, err = test.Sign(sk, lnAuth.K1) + assert.NoErrorf(err, "error signing k1") + + key = hex.EncodeToString(pk.SerializeCompressed()) + sc = context.Context{Db: db} - e, req, rec = test.HTTPMocks("GET", "/login", nil) + e, req, rec = test.HTTPMocks("GET", + fmt.Sprintf("/api/login?tag=login&k1=%s&key=%s&sig=%s&action=%s", lnAuth.K1, key, sig, "register"), nil) c = e.NewContext(req, rec) - err = handler.HandleLogin(sc)(c) + err = handler.HandleLnAuthCallback(sc)(c) + assert.NoErrorf(err, "handler returned error") + + // user created + err = db.QueryRow("SELECT id FROM users WHERE ln_pubkey = $1", key).Scan(&userId) + assert.NoErrorf(err, "error fetching user") + + // session created + err = db.QueryRow("SELECT COUNT(1) FROM sessions WHERE id = $1 AND user_id = $2", sessionId, userId).Scan(&count) + assert.NoErrorf(err, "error fetching session") + assert.Equalf(1, count, "invalid session count") + + // challenge deleted + err = db.QueryRow("SELECT COUNT(1) FROM lnauth WHERE k1 = $1", lnAuth.K1).Scan(&count) + assert.NoErrorf(err, "error fetching challenge") + assert.Equalf(count, 0, "challenge not deleted") +} + +func TestLnAuthSignupCallbackUserExists(t *testing.T) { + var ( + assert = assert.New(t) + e *echo.Echo + c echo.Context + sc context.Context + req *http.Request + rec *httptest.ResponseRecorder + lnAuth *auth.LnAuth + sk *secp256k1.PrivateKey + pk *secp256k1.PublicKey + sig string + key string + sessionId string + err error + ) + + lnAuth, err = auth.NewLnAuth("register") + assert.NoErrorf(err, "error creating challenge") + + err = db.QueryRow( + "INSERT INTO lnauth(k1, lnurl) VALUES($1, $2) RETURNING session_id", + lnAuth.K1, lnAuth.LNURL).Scan(&sessionId) + assert.NoErrorf(err, "error creating challenge") + + sk, pk, err = test.GenerateKeyPair() + assert.NoErrorf(err, "error generating keypair") + + sig, err = test.Sign(sk, lnAuth.K1) + assert.NoErrorf(err, "error signing k1") + + key = hex.EncodeToString(pk.SerializeCompressed()) + + // create user such that signup must fail + _, err = db.Exec("INSERT INTO users(ln_pubkey) VALUES($1) RETURNING id", key) + assert.NoErrorf(err, "error creating user") + + sc = context.Context{Db: db} + e, req, rec = test.HTTPMocks("GET", + fmt.Sprintf("/api/login?tag=login&k1=%s&key=%s&sig=%s&action=%s", lnAuth.K1, key, sig, "register"), nil) + c = e.NewContext(req, rec) + + // must throw error because user already exists + err = handler.HandleLnAuthCallback(sc)(c) + assert.ErrorContains(err, "user already exists", "user check failed") +} + +func TestLnAuthLogin(t *testing.T) { + var ( + assert = assert.New(t) + sc = context.Context{Db: db} + e *echo.Echo + c echo.Context + req *http.Request + rec *httptest.ResponseRecorder + cookies []*http.Cookie + sessionId string + dbSessionId string + err error + ) + + e, req, rec = test.HTTPMocks("GET", "/login/lightning", nil) + c = e.NewContext(req, rec) + c.SetParamNames("method") + c.SetParamValues("lightning") + + err = handler.HandleAuth(sc, "login")(c) assert.NoErrorf(err, "handler returned error") // Set-Cookie header present cookies = rec.Result().Cookies() assert.Equalf(len(cookies), 1, "wrong number of Set-Cookie headers") - assert.Equalf(cookies[0].Name, "session", "wrong cookie name") + assert.Equalf("session", cookies[0].Name, "wrong cookie name") - // new challenge inserted + // new challenge inserted which matches cookie value sessionId = cookies[0].Value err = db.QueryRow("SELECT session_id FROM lnauth WHERE session_id = $1", sessionId).Scan(&dbSessionId) - if !assert.NoError(err) { - return - } - - // inserted challenge matches cookie value + assert.NoError(err) assert.Equalf(sessionId, dbSessionId, "wrong session id") } -func TestLnAuthCallback(t *testing.T) { - t.Skip() +func TestLnAuthLoginCallbackUserNotExists(t *testing.T) { + var ( + assert = assert.New(t) + e *echo.Echo + c echo.Context + sc context.Context + req *http.Request + rec *httptest.ResponseRecorder + lnAuth *auth.LnAuth + sk *secp256k1.PrivateKey + pk *secp256k1.PublicKey + sig string + key string + sessionId string + err error + ) + + lnAuth, err = auth.NewLnAuth("login") + assert.NoErrorf(err, "error creating challenge") + + err = db.QueryRow( + "INSERT INTO lnauth(k1, lnurl) VALUES($1, $2) RETURNING session_id", + lnAuth.K1, lnAuth.LNURL).Scan(&sessionId) + assert.NoErrorf(err, "error creating challenge") + + sk, pk, err = test.GenerateKeyPair() + assert.NoErrorf(err, "error generating keypair") + + sig, err = test.Sign(sk, lnAuth.K1) + assert.NoErrorf(err, "error signing k1") + + key = hex.EncodeToString(pk.SerializeCompressed()) + + sc = context.Context{Db: db} + e, req, rec = test.HTTPMocks("GET", + fmt.Sprintf("/api/login?tag=login&k1=%s&key=%s&sig=%s&action=%s", lnAuth.K1, key, sig, "login"), nil) + c = e.NewContext(req, rec) + + // must throw error because user does not exist + err = handler.HandleLnAuthCallback(sc)(c) + assert.ErrorContains(err, "user not found", "user check failed") +} + +func TestLnAuthLoginCallbackUserExists(t *testing.T) { + var ( + assert = assert.New(t) + e *echo.Echo + c echo.Context + sc context.Context + req *http.Request + rec *httptest.ResponseRecorder + lnAuth *auth.LnAuth + sk *secp256k1.PrivateKey + pk *secp256k1.PublicKey + sig string + key string + sessionId string + userId int + count int + err error + ) + + lnAuth, err = auth.NewLnAuth("login") + assert.NoErrorf(err, "error creating challenge") + + err = db.QueryRow( + "INSERT INTO lnauth(k1, lnurl) VALUES($1, $2) RETURNING session_id", + lnAuth.K1, lnAuth.LNURL).Scan(&sessionId) + assert.NoErrorf(err, "error creating challenge") + + sk, pk, err = test.GenerateKeyPair() + assert.NoErrorf(err, "error generating keypair") + + sig, err = test.Sign(sk, lnAuth.K1) + assert.NoErrorf(err, "error signing k1") + + key = hex.EncodeToString(pk.SerializeCompressed()) + + // create user such that login does not fail + err = db.QueryRow("INSERT INTO users(ln_pubkey) VALUES($1) RETURNING id", key).Scan(&userId) + assert.NoErrorf(err, "error creating user") + + sc = context.Context{Db: db} + e, req, rec = test.HTTPMocks("GET", + fmt.Sprintf("/api/login?tag=login&k1=%s&key=%s&sig=%s&action=%s", lnAuth.K1, key, sig, "login"), nil) + c = e.NewContext(req, rec) + + err = handler.HandleLnAuthCallback(sc)(c) + assert.NoErrorf(err, "handler returned error") + + // session created + err = db.QueryRow("SELECT COUNT(1) FROM sessions WHERE id = $1 AND user_id = $2", sessionId, userId).Scan(&count) + assert.NoErrorf(err, "error fetching session") + assert.Equalf(1, count, "invalid session count") + + // challenge deleted + err = db.QueryRow("SELECT COUNT(1) FROM lnauth WHERE k1 = $1", lnAuth.K1).Scan(&count) + assert.NoErrorf(err, "error fetching challenge") + assert.Equalf(count, 0, "challenge not deleted") } diff --git a/server/router/handler/login.go b/server/router/handler/login.go deleted file mode 100644 index d715af0..0000000 --- a/server/router/handler/login.go +++ /dev/null @@ -1,13 +0,0 @@ -package handler - -import ( - "git.ekzyis.com/ekzyis/delphi.market/server/router/context" - "git.ekzyis.com/ekzyis/delphi.market/server/router/pages" - "github.com/labstack/echo/v4" -) - -func HandleLogin(sc context.Context) echo.HandlerFunc { - return func(c echo.Context) error { - return pages.Login().Render(context.RenderContext(sc, c), c.Response().Writer) - } -} diff --git a/server/router/pages/auth.templ b/server/router/pages/auth.templ new file mode 100644 index 0000000..f9ac357 --- /dev/null +++ b/server/router/pages/auth.templ @@ -0,0 +1,66 @@ +package pages + +import ( + "fmt" + "git.ekzyis.com/ekzyis/delphi.market/server/router/pages/components" +) + +templ Auth(action string) { + + @components.Head() + + @components.Nav() +
+ @components.Figlet("random", action) +
+ + +
+
+ if action == "signup" { + + not your first time? + + } else { + first time? + } +
+
+ @components.Footer() + + +} diff --git a/server/router/pages/components/nav.templ b/server/router/pages/components/nav.templ index 207f83d..78248ff 100644 --- a/server/router/pages/components/nav.templ +++ b/server/router/pages/components/nav.templ @@ -13,7 +13,7 @@ templ Nav() { if ctx.Value(c.SessionContextKey) != nil { } else { - + } diff --git a/server/router/pages/lnAuth.templ b/server/router/pages/lnAuth.templ new file mode 100644 index 0000000..08df7a8 --- /dev/null +++ b/server/router/pages/lnAuth.templ @@ -0,0 +1,33 @@ +package pages + +import "git.ekzyis.com/ekzyis/delphi.market/server/router/pages/components" + +templ LnAuth(qr string, lnurl string, action string) { + + @components.Head() + + @components.Nav() +
+ @components.Figlet("random", action) + with lightning +
+ + + + { lnurl } +
+
+
+ @components.Footer() + + +} diff --git a/server/router/pages/login.templ b/server/router/pages/login.templ index a7c641e..ed0de6d 100644 --- a/server/router/pages/login.templ +++ b/server/router/pages/login.templ @@ -9,12 +9,44 @@ templ Login() { @components.Nav()
@components.Figlet("random", "login") -
- - +
+ +
@components.Footer() diff --git a/server/router/router.go b/server/router/router.go index f80e061..e524325 100644 --- a/server/router/router.go +++ b/server/router/router.go @@ -14,5 +14,11 @@ func Init(e *echo.Echo, sc Context) { e.GET("/", handler.HandleIndex(sc)) e.GET("/about", handler.HandleAbout(sc)) - e.GET("/login", handler.HandleLogin(sc)) + + e.GET("/login", handler.HandleAuth(sc, "login")) + e.GET("/login/:method", handler.HandleAuth(sc, "login")) + e.GET("/signup", handler.HandleAuth(sc, "register")) + e.GET("/signup/:method", handler.HandleAuth(sc, "register")) + e.GET("/api/lnauth/callback", handler.HandleLnAuthCallback(sc)) + e.GET("/session", handler.HandleSessionCheck(sc)) }