From 738d511f014efa70a494dcc20d1b413c38894d9f Mon Sep 17 00:00:00 2001 From: ekzyis Date: Sun, 14 Jul 2024 12:08:13 +0200 Subject: [PATCH] Call tx.Rollback() in bail --- server/router/handler/auth.go | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/server/router/handler/auth.go b/server/router/handler/auth.go index 0af5646..b96676d 100644 --- a/server/router/handler/auth.go +++ b/server/router/handler/auth.go @@ -78,6 +78,10 @@ func HandleLnAuthCallback(sc context.Context) echo.HandlerFunc { ) bail := func(code int, reason string) error { + if tx != nil { + // manual rollback is only required for tests afaik + tx.Rollback() + } return c.JSON(code, map[string]string{"status": "ERROR", "reason": reason}) } @@ -93,19 +97,15 @@ func HandleLnAuthCallback(sc context.Context) echo.HandlerFunc { err = tx.QueryRow("SELECT session_id FROM lnauth WHERE k1 = $1 LIMIT 1", query.K1).Scan(&sessionId) if err == sql.ErrNoRows { - tx.Rollback() return bail(http.StatusNotFound, "session not found") } else if err != nil { - tx.Rollback() return bail(http.StatusInternalServerError, err.Error()) } ok, err = auth.VerifyLNAuth(&query) if err != nil { - tx.Rollback() return bail(http.StatusInternalServerError, err.Error()) } else if !ok { - tx.Rollback() return bail(http.StatusBadRequest, "bad signature") } @@ -115,7 +115,6 @@ func HandleLnAuthCallback(sc context.Context) echo.HandlerFunc { "ON CONFLICT(ln_pubkey) DO UPDATE SET ln_pubkey = $1 "+ "RETURNING id", query.Key).Scan(&userId) if err != nil { - tx.Rollback() pqErr, ok = err.(*pq.Error) if ok && pqErr.Code == "23505" { return bail(http.StatusBadRequest, "user already exists") @@ -125,10 +124,8 @@ func HandleLnAuthCallback(sc context.Context) echo.HandlerFunc { } else if query.Action == "login" { err = tx.QueryRow("SELECT id FROM users WHERE ln_pubkey = $1", query.Key).Scan(&userId) if err == sql.ErrNoRows { - tx.Rollback() return bail(http.StatusNotFound, "user not found") } else if err != nil { - tx.Rollback() return bail(http.StatusInternalServerError, err.Error()) } } else { @@ -136,17 +133,14 @@ func HandleLnAuthCallback(sc context.Context) echo.HandlerFunc { } if _, err = tx.Exec("INSERT INTO sessions(id, user_id) VALUES($1, $2)", sessionId, userId); err != nil { - tx.Rollback() return bail(http.StatusInternalServerError, err.Error()) } if _, err = tx.Exec("DELETE FROM lnauth WHERE k1 = $1", query.K1); err != nil { - tx.Rollback() return bail(http.StatusInternalServerError, err.Error()) } if err = tx.Commit(); err != nil { - tx.Rollback() return bail(http.StatusInternalServerError, err.Error()) }