Call tx.Rollback() in bail

This commit is contained in:
ekzyis 2024-07-14 12:08:13 +02:00
parent 23ab67e8fc
commit 738d511f01
1 changed files with 4 additions and 10 deletions

View File

@ -78,6 +78,10 @@ func HandleLnAuthCallback(sc context.Context) echo.HandlerFunc {
) )
bail := func(code int, reason string) error { bail := func(code int, reason string) error {
if tx != nil {
// manual rollback is only required for tests afaik
tx.Rollback()
}
return c.JSON(code, map[string]string{"status": "ERROR", "reason": reason}) return c.JSON(code, map[string]string{"status": "ERROR", "reason": reason})
} }
@ -93,19 +97,15 @@ func HandleLnAuthCallback(sc context.Context) echo.HandlerFunc {
err = tx.QueryRow("SELECT session_id FROM lnauth WHERE k1 = $1 LIMIT 1", query.K1).Scan(&sessionId) err = tx.QueryRow("SELECT session_id FROM lnauth WHERE k1 = $1 LIMIT 1", query.K1).Scan(&sessionId)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
tx.Rollback()
return bail(http.StatusNotFound, "session not found") return bail(http.StatusNotFound, "session not found")
} else if err != nil { } else if err != nil {
tx.Rollback()
return bail(http.StatusInternalServerError, err.Error()) return bail(http.StatusInternalServerError, err.Error())
} }
ok, err = auth.VerifyLNAuth(&query) ok, err = auth.VerifyLNAuth(&query)
if err != nil { if err != nil {
tx.Rollback()
return bail(http.StatusInternalServerError, err.Error()) return bail(http.StatusInternalServerError, err.Error())
} else if !ok { } else if !ok {
tx.Rollback()
return bail(http.StatusBadRequest, "bad signature") return bail(http.StatusBadRequest, "bad signature")
} }
@ -115,7 +115,6 @@ func HandleLnAuthCallback(sc context.Context) echo.HandlerFunc {
"ON CONFLICT(ln_pubkey) DO UPDATE SET ln_pubkey = $1 "+ "ON CONFLICT(ln_pubkey) DO UPDATE SET ln_pubkey = $1 "+
"RETURNING id", query.Key).Scan(&userId) "RETURNING id", query.Key).Scan(&userId)
if err != nil { if err != nil {
tx.Rollback()
pqErr, ok = err.(*pq.Error) pqErr, ok = err.(*pq.Error)
if ok && pqErr.Code == "23505" { if ok && pqErr.Code == "23505" {
return bail(http.StatusBadRequest, "user already exists") return bail(http.StatusBadRequest, "user already exists")
@ -125,10 +124,8 @@ func HandleLnAuthCallback(sc context.Context) echo.HandlerFunc {
} else if query.Action == "login" { } else if query.Action == "login" {
err = tx.QueryRow("SELECT id FROM users WHERE ln_pubkey = $1", query.Key).Scan(&userId) err = tx.QueryRow("SELECT id FROM users WHERE ln_pubkey = $1", query.Key).Scan(&userId)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
tx.Rollback()
return bail(http.StatusNotFound, "user not found") return bail(http.StatusNotFound, "user not found")
} else if err != nil { } else if err != nil {
tx.Rollback()
return bail(http.StatusInternalServerError, err.Error()) return bail(http.StatusInternalServerError, err.Error())
} }
} else { } else {
@ -136,17 +133,14 @@ func HandleLnAuthCallback(sc context.Context) echo.HandlerFunc {
} }
if _, err = tx.Exec("INSERT INTO sessions(id, user_id) VALUES($1, $2)", sessionId, userId); err != nil { if _, err = tx.Exec("INSERT INTO sessions(id, user_id) VALUES($1, $2)", sessionId, userId); err != nil {
tx.Rollback()
return bail(http.StatusInternalServerError, err.Error()) return bail(http.StatusInternalServerError, err.Error())
} }
if _, err = tx.Exec("DELETE FROM lnauth WHERE k1 = $1", query.K1); err != nil { if _, err = tx.Exec("DELETE FROM lnauth WHERE k1 = $1", query.K1); err != nil {
tx.Rollback()
return bail(http.StatusInternalServerError, err.Error()) return bail(http.StatusInternalServerError, err.Error())
} }
if err = tx.Commit(); err != nil { if err = tx.Commit(); err != nil {
tx.Rollback()
return bail(http.StatusInternalServerError, err.Error()) return bail(http.StatusInternalServerError, err.Error())
} }