Compare commits

...

6 Commits

Author SHA1 Message Date
3ac346df4f Fix route 2024-07-12 20:35:47 +02:00
36ac326767 Check status code instead of returned error 2024-07-12 20:32:50 +02:00
4e13271dbd Don't throw if user already exists on signup 2024-07-12 20:30:34 +02:00
0f3866866a Don't use assert format version unnecessarily 2024-07-12 20:21:12 +02:00
df1886632b Fix tests 2024-07-12 20:11:03 +02:00
9a92444eee Add session middleware 2024-07-12 15:30:50 +02:00
5 changed files with 102 additions and 70 deletions

View File

@ -111,7 +111,10 @@ func HandleLnAuthCallback(sc context.Context) echo.HandlerFunc {
} }
if query.Action == "register" { if query.Action == "register" {
err = tx.QueryRow("INSERT INTO users(ln_pubkey) VALUES ($1) RETURNING id", query.Key).Scan(&userId) err = tx.QueryRow(""+
"INSERT INTO users(ln_pubkey) VALUES ($1) "+
"ON CONFLICT(ln_pubkey) DO UPDATE SET ln_pubkey = $1 "+
"RETURNING id", query.Key).Scan(&userId)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
pqErr, ok = err.(*pq.Error) pqErr, ok = err.(*pq.Error)

View File

@ -39,19 +39,19 @@ func TestLnAuthSignup(t *testing.T) {
c.SetParamNames("method") c.SetParamNames("method")
c.SetParamValues("lightning") c.SetParamValues("lightning")
err = handler.HandleAuth(sc, "register")(c) handler.HandleAuth(sc, "register")(c)
assert.NoErrorf(err, "handler returned error") assert.Equal(http.StatusOK, rec.Code, "wrong status code")
// Set-Cookie header present // Set-Cookie header present
cookies = rec.Result().Cookies() cookies = rec.Result().Cookies()
assert.Equalf(1, len(cookies), "wrong number of Set-Cookie headers") assert.Equal(1, len(cookies), "wrong number of Set-Cookie headers")
assert.Equalf("session", cookies[0].Name, "wrong cookie name") assert.Equal("session", cookies[0].Name, "wrong cookie name")
// new challenge inserted which matches cookie value // new challenge inserted which matches cookie value
sessionId = cookies[0].Value sessionId = cookies[0].Value
err = db.QueryRow("SELECT session_id FROM lnauth WHERE session_id = $1", sessionId).Scan(&dbSessionId) err = db.QueryRow("SELECT session_id FROM lnauth WHERE session_id = $1", sessionId).Scan(&dbSessionId)
assert.NoError(err) assert.NoError(err)
assert.Equalf(sessionId, dbSessionId, "wrong session id") assert.Equal(sessionId, dbSessionId, "wrong session id")
} }
func TestLnAuthSignupCallbackUserNotExists(t *testing.T) { func TestLnAuthSignupCallbackUserNotExists(t *testing.T) {
@ -79,37 +79,39 @@ func TestLnAuthSignupCallbackUserNotExists(t *testing.T) {
err = db.QueryRow( err = db.QueryRow(
"INSERT INTO lnauth(k1, lnurl) VALUES($1, $2) RETURNING session_id", "INSERT INTO lnauth(k1, lnurl) VALUES($1, $2) RETURNING session_id",
lnAuth.K1, lnAuth.LNURL).Scan(&sessionId) lnAuth.K1, lnAuth.LNURL).Scan(&sessionId)
assert.NoErrorf(err, "error creating challenge") assert.NoError(err, "error creating challenge")
sk, pk, err = test.GenerateKeyPair() sk, pk, err = test.GenerateKeyPair()
assert.NoErrorf(err, "error generating keypair") assert.NoError(err, "error generating keypair")
sig, err = test.Sign(sk, lnAuth.K1) sig, err = test.Sign(sk, lnAuth.K1)
assert.NoErrorf(err, "error signing k1") assert.NoError(err, "error signing k1")
key = hex.EncodeToString(pk.SerializeCompressed()) key = hex.EncodeToString(pk.SerializeCompressed())
sc = context.Context{Db: db} sc = context.Context{Db: db}
e, req, rec = test.HTTPMocks("GET", e, req, rec = test.HTTPMocks(
fmt.Sprintf("/api/login?tag=login&k1=%s&key=%s&sig=%s&action=%s", lnAuth.K1, key, sig, "register"), nil) "GET",
fmt.Sprintf("/api/lnauth?tag=login&k1=%s&key=%s&sig=%s&action=%s", lnAuth.K1, key, sig, "register"),
nil)
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
err = handler.HandleLnAuthCallback(sc)(c) handler.HandleLnAuthCallback(sc)(c)
assert.NoErrorf(err, "handler returned error") assert.Equal(http.StatusOK, rec.Code, "wrong status code")
// user created // user created
err = db.QueryRow("SELECT id FROM users WHERE ln_pubkey = $1", key).Scan(&userId) err = db.QueryRow("SELECT id FROM users WHERE ln_pubkey = $1", key).Scan(&userId)
assert.NoErrorf(err, "error fetching user") assert.NoError(err, "error fetching user")
// session created // session created
err = db.QueryRow("SELECT COUNT(1) FROM sessions WHERE id = $1 AND user_id = $2", sessionId, userId).Scan(&count) err = db.QueryRow("SELECT COUNT(1) FROM sessions WHERE id = $1 AND user_id = $2", sessionId, userId).Scan(&count)
assert.NoErrorf(err, "error fetching session") assert.NoError(err, "error fetching session")
assert.Equalf(1, count, "invalid session count") assert.Equal(1, count, "invalid session count")
// challenge deleted // challenge deleted
err = db.QueryRow("SELECT COUNT(1) FROM lnauth WHERE k1 = $1", lnAuth.K1).Scan(&count) err = db.QueryRow("SELECT COUNT(1) FROM lnauth WHERE k1 = $1", lnAuth.K1).Scan(&count)
assert.NoErrorf(err, "error fetching challenge") assert.NoError(err, "error fetching challenge")
assert.Equalf(count, 0, "challenge not deleted") assert.Equal(count, 0, "challenge not deleted")
} }
func TestLnAuthSignupCallbackUserExists(t *testing.T) { func TestLnAuthSignupCallbackUserExists(t *testing.T) {
@ -135,28 +137,30 @@ func TestLnAuthSignupCallbackUserExists(t *testing.T) {
err = db.QueryRow( err = db.QueryRow(
"INSERT INTO lnauth(k1, lnurl) VALUES($1, $2) RETURNING session_id", "INSERT INTO lnauth(k1, lnurl) VALUES($1, $2) RETURNING session_id",
lnAuth.K1, lnAuth.LNURL).Scan(&sessionId) lnAuth.K1, lnAuth.LNURL).Scan(&sessionId)
assert.NoErrorf(err, "error creating challenge") assert.NoError(err, "error creating challenge")
sk, pk, err = test.GenerateKeyPair() sk, pk, err = test.GenerateKeyPair()
assert.NoErrorf(err, "error generating keypair") assert.NoError(err, "error generating keypair")
sig, err = test.Sign(sk, lnAuth.K1) sig, err = test.Sign(sk, lnAuth.K1)
assert.NoErrorf(err, "error signing k1") assert.NoError(err, "error signing k1")
key = hex.EncodeToString(pk.SerializeCompressed()) key = hex.EncodeToString(pk.SerializeCompressed())
// create user such that signup must fail // create user before signup
_, err = db.Exec("INSERT INTO users(ln_pubkey) VALUES($1) RETURNING id", key) _, err = db.Exec("INSERT INTO users(ln_pubkey) VALUES($1) RETURNING id", key)
assert.NoErrorf(err, "error creating user") assert.NoError(err, "error creating user")
sc = context.Context{Db: db} sc = context.Context{Db: db}
e, req, rec = test.HTTPMocks("GET", e, req, rec = test.HTTPMocks(
fmt.Sprintf("/api/login?tag=login&k1=%s&key=%s&sig=%s&action=%s", lnAuth.K1, key, sig, "register"), nil) "GET",
fmt.Sprintf("/api/lnauth?tag=login&k1=%s&key=%s&sig=%s&action=%s", lnAuth.K1, key, sig, "register"),
nil)
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
// must throw error because user already exists // does not throw an error for UX reasons
err = handler.HandleLnAuthCallback(sc)(c) handler.HandleLnAuthCallback(sc)(c)
assert.ErrorContains(err, "user already exists", "user check failed") assert.Equal(http.StatusOK, rec.Code, "wrong status code")
} }
func TestLnAuthLogin(t *testing.T) { func TestLnAuthLogin(t *testing.T) {
@ -179,18 +183,18 @@ func TestLnAuthLogin(t *testing.T) {
c.SetParamValues("lightning") c.SetParamValues("lightning")
err = handler.HandleAuth(sc, "login")(c) err = handler.HandleAuth(sc, "login")(c)
assert.NoErrorf(err, "handler returned error") assert.NoError(err, "handler returned error")
// Set-Cookie header present // Set-Cookie header present
cookies = rec.Result().Cookies() cookies = rec.Result().Cookies()
assert.Equalf(len(cookies), 1, "wrong number of Set-Cookie headers") assert.Equal(len(cookies), 1, "wrong number of Set-Cookie headers")
assert.Equalf("session", cookies[0].Name, "wrong cookie name") assert.Equal("session", cookies[0].Name, "wrong cookie name")
// new challenge inserted which matches cookie value // new challenge inserted which matches cookie value
sessionId = cookies[0].Value sessionId = cookies[0].Value
err = db.QueryRow("SELECT session_id FROM lnauth WHERE session_id = $1", sessionId).Scan(&dbSessionId) err = db.QueryRow("SELECT session_id FROM lnauth WHERE session_id = $1", sessionId).Scan(&dbSessionId)
assert.NoError(err) assert.NoError(err)
assert.Equalf(sessionId, dbSessionId, "wrong session id") assert.Equal(sessionId, dbSessionId, "wrong session id")
} }
func TestLnAuthLoginCallbackUserNotExists(t *testing.T) { func TestLnAuthLoginCallbackUserNotExists(t *testing.T) {
@ -216,24 +220,27 @@ func TestLnAuthLoginCallbackUserNotExists(t *testing.T) {
err = db.QueryRow( err = db.QueryRow(
"INSERT INTO lnauth(k1, lnurl) VALUES($1, $2) RETURNING session_id", "INSERT INTO lnauth(k1, lnurl) VALUES($1, $2) RETURNING session_id",
lnAuth.K1, lnAuth.LNURL).Scan(&sessionId) lnAuth.K1, lnAuth.LNURL).Scan(&sessionId)
assert.NoErrorf(err, "error creating challenge") assert.NoError(err, "error creating challenge")
sk, pk, err = test.GenerateKeyPair() sk, pk, err = test.GenerateKeyPair()
assert.NoErrorf(err, "error generating keypair") assert.NoError(err, "error generating keypair")
sig, err = test.Sign(sk, lnAuth.K1) sig, err = test.Sign(sk, lnAuth.K1)
assert.NoErrorf(err, "error signing k1") assert.NoError(err, "error signing k1")
key = hex.EncodeToString(pk.SerializeCompressed()) key = hex.EncodeToString(pk.SerializeCompressed())
sc = context.Context{Db: db} sc = context.Context{Db: db}
e, req, rec = test.HTTPMocks("GET", e, req, rec = test.HTTPMocks(
fmt.Sprintf("/api/login?tag=login&k1=%s&key=%s&sig=%s&action=%s", lnAuth.K1, key, sig, "login"), nil) "GET",
fmt.Sprintf("/api/lnauth?tag=login&k1=%s&key=%s&sig=%s&action=%s", lnAuth.K1, key, sig, "login"),
nil)
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
// must throw error because user does not exist // must throw error because user does not exist
err = handler.HandleLnAuthCallback(sc)(c) handler.HandleLnAuthCallback(sc)(c)
assert.ErrorContains(err, "user not found", "user check failed") assert.Equal(http.StatusNotFound, rec.Code, "wrong status code")
assert.Contains(rec.Body.String(), "\"reason\":\"user not found\"", "user check failed")
} }
func TestLnAuthLoginCallbackUserExists(t *testing.T) { func TestLnAuthLoginCallbackUserExists(t *testing.T) {
@ -261,35 +268,37 @@ func TestLnAuthLoginCallbackUserExists(t *testing.T) {
err = db.QueryRow( err = db.QueryRow(
"INSERT INTO lnauth(k1, lnurl) VALUES($1, $2) RETURNING session_id", "INSERT INTO lnauth(k1, lnurl) VALUES($1, $2) RETURNING session_id",
lnAuth.K1, lnAuth.LNURL).Scan(&sessionId) lnAuth.K1, lnAuth.LNURL).Scan(&sessionId)
assert.NoErrorf(err, "error creating challenge") assert.NoError(err, "error creating challenge")
sk, pk, err = test.GenerateKeyPair() sk, pk, err = test.GenerateKeyPair()
assert.NoErrorf(err, "error generating keypair") assert.NoError(err, "error generating keypair")
sig, err = test.Sign(sk, lnAuth.K1) sig, err = test.Sign(sk, lnAuth.K1)
assert.NoErrorf(err, "error signing k1") assert.NoError(err, "error signing k1")
key = hex.EncodeToString(pk.SerializeCompressed()) key = hex.EncodeToString(pk.SerializeCompressed())
// create user such that login does not fail // create user such that login does not fail
err = db.QueryRow("INSERT INTO users(ln_pubkey) VALUES($1) RETURNING id", key).Scan(&userId) err = db.QueryRow("INSERT INTO users(ln_pubkey) VALUES($1) RETURNING id", key).Scan(&userId)
assert.NoErrorf(err, "error creating user") assert.NoError(err, "error creating user")
sc = context.Context{Db: db} sc = context.Context{Db: db}
e, req, rec = test.HTTPMocks("GET", e, req, rec = test.HTTPMocks(
fmt.Sprintf("/api/login?tag=login&k1=%s&key=%s&sig=%s&action=%s", lnAuth.K1, key, sig, "login"), nil) "GET",
fmt.Sprintf("/api/lnauth?tag=login&k1=%s&key=%s&sig=%s&action=%s", lnAuth.K1, key, sig, "login"),
nil)
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
err = handler.HandleLnAuthCallback(sc)(c) handler.HandleLnAuthCallback(sc)(c)
assert.NoErrorf(err, "handler returned error") assert.Equal(http.StatusOK, rec.Code, "wrong status code")
// session created // session created
err = db.QueryRow("SELECT COUNT(1) FROM sessions WHERE id = $1 AND user_id = $2", sessionId, userId).Scan(&count) err = db.QueryRow("SELECT COUNT(1) FROM sessions WHERE id = $1 AND user_id = $2", sessionId, userId).Scan(&count)
assert.NoErrorf(err, "error fetching session") assert.NoError(err, "error fetching session")
assert.Equalf(1, count, "invalid session count") assert.Equal(1, count, "invalid session count")
// challenge deleted // challenge deleted
err = db.QueryRow("SELECT COUNT(1) FROM lnauth WHERE k1 = $1", lnAuth.K1).Scan(&count) err = db.QueryRow("SELECT COUNT(1) FROM lnauth WHERE k1 = $1", lnAuth.K1).Scan(&count)
assert.NoErrorf(err, "error fetching challenge") assert.NoError(err, "error fetching challenge")
assert.Equalf(count, 0, "challenge not deleted") assert.Equal(count, 0, "challenge not deleted")
} }

View File

@ -1,33 +1,41 @@
package middleware package middleware
import ( import (
"database/sql"
"net/http" "net/http"
"git.ekzyis.com/ekzyis/delphi.market/server/router/context" "git.ekzyis.com/ekzyis/delphi.market/server/router/context"
"git.ekzyis.com/ekzyis/delphi.market/types"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
func Session(sc context.Context) echo.MiddlewareFunc { func Session(sc context.Context) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
// TODO: implement session middleware var (
// var ( db = sc.Db
// cookie *http.Cookie ctx = c.Request().Context()
// err error cookie *http.Cookie
// s *db.Session err error
// u *db.User u = types.User{}
// ) )
// if cookie, err = c.Cookie("session"); err != nil { if cookie, err = c.Cookie("session"); err != nil {
// // cookie not found // cookie not found
// return next(c) return next(c)
// } }
// s = &db.Session{SessionId: cookie.Value} if err = db.QueryRowContext(
// if err = sc.Db.FetchSession(s); err == nil { ctx,
// // session found ""+
// c.Set("session", *u) "SELECT u.id, u.created_at, COALESCE(u.ln_pubkey, ''), COALESCE(u.nostr_pubkey, ''), u.msats "+
// } else if err != sql.ErrNoRows { "FROM sessions s LEFT JOIN users u ON u.id = s.user_id "+
// return err "WHERE s.id = $1",
// } cookie.Value).
Scan(&u.Id, &u.CreatedAt, &u.LnPubkey, &u.NostrPubkey, &u.Msats); err == nil {
// session found
c.Set("session", u)
} else if err != sql.ErrNoRows {
return err
}
return next(c) return next(c)
} }
} }

View File

@ -5,12 +5,13 @@ import (
"git.ekzyis.com/ekzyis/delphi.market/server/router/context" "git.ekzyis.com/ekzyis/delphi.market/server/router/context"
"git.ekzyis.com/ekzyis/delphi.market/server/router/handler" "git.ekzyis.com/ekzyis/delphi.market/server/router/handler"
"git.ekzyis.com/ekzyis/delphi.market/server/router/middleware"
) )
type Context = context.Context type Context = context.Context
func Init(e *echo.Echo, sc Context) { func Init(e *echo.Echo, sc Context) {
// e.Use(middleware.Session(sc)) e.Use(middleware.Session(sc))
e.GET("/", handler.HandleIndex(sc)) e.GET("/", handler.HandleIndex(sc))
e.GET("/about", handler.HandleAbout(sc)) e.GET("/about", handler.HandleAbout(sc))

11
types/types.go Normal file
View File

@ -0,0 +1,11 @@
package types
import "time"
type User struct {
Id int
CreatedAt time.Time
LnPubkey string
NostrPubkey string
Msats int64
}