magicwallet/server/router/middleware/session.go

125 lines
3.2 KiB
Go

package middleware
import (
"encoding/hex"
"fmt"
"net/http"
"time"
"github.com/ekzyis/magicwallet/db/models"
"github.com/ekzyis/magicwallet/nostr/nwc"
"github.com/ekzyis/magicwallet/server/router/context"
"github.com/labstack/echo/v4"
)
// Session returns echo middleware that makes sure every request is associated
// with a user and wallet before the next handler runs.
//
// If the request carries a readable "session" cookie, the matching user and
// wallet are loaded via findSession; a lookup failure is reported to the client
// as 404 Not Found. If no such cookie can be read, a brand-new user, session and
// wallet are created via newSession; a failure there is reported as
// 500 Internal Server Error.
func Session(hc context.Context) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(ec echo.Context) error {
			if cookie, cookieErr := ec.Cookie("session"); cookieErr == nil {
				// Returning visitor: resolve the existing session.
				if err := findSession(hc, ec, cookie.Value); err != nil {
					return echo.NewHTTPError(http.StatusNotFound, err)
				}
				return next(ec)
			}

			// No usable session cookie: provision a fresh anonymous account.
			if err := newSession(hc, ec); err != nil {
				return echo.NewHTTPError(http.StatusInternalServerError, err)
			}
			return next(ec)
		}
	}
}
// newSession provisions a fresh anonymous account for the current request.
//
// It creates a new NWC connection, then inserts a user, a session owned by that
// user, and a wallet row holding the connection's wallet pubkey and secret
// (hex-encoded). On success it sets the "session" cookie on the response
// (30-day expiry) and stores the new models.User under "user" and the new
// models.Wallet under "wallet" in ec.
//
// Returned errors wrap the underlying cause with %w, so callers can inspect it
// with errors.Is / errors.As.
func newSession(hc context.Context, ec echo.Context) error {
	var (
		db      = hc.Db
		ctx     = ec.Request().Context()
		session string
		u       models.User
		w       models.Wallet
	)

	// Create the NWC connection before any database writes, so that a failure
	// here does not leave behind a user and session without a wallet.
	// (Named conn to avoid shadowing the imported nwc package.)
	conn, err := nwc.NewConnection()
	if err != nil {
		return fmt.Errorf("failed to create nwc connection uri: %w", err)
	}

	// NOTE(review): the three INSERTs below are not wrapped in a transaction, so
	// a failure part-way through still leaves orphaned rows. Consider running
	// them in a single transaction if hc.Db supports one.
	err = db.QueryRowContext(ctx, ""+
		"INSERT INTO users DEFAULT VALUES "+
		"RETURNING id, name, created_at, COALESCE(ln_pubkey, ''), COALESCE(nostr_pubkey, '')").
		Scan(&u.Id, &u.Name, &u.CreatedAt, &u.LnPubkey, &u.NostrPubkey)
	if err != nil {
		return fmt.Errorf("failed to insert new user: %w", err)
	}

	err = db.QueryRowContext(ctx, ""+
		"INSERT INTO sessions(user_id) VALUES($1) RETURNING id",
		u.Id).Scan(&session)
	if err != nil {
		return fmt.Errorf("failed to insert new session: %w", err)
	}

	err = db.QueryRowContext(ctx, ""+
		"INSERT INTO wallets(wallet_pubkey, secret, user_id) "+
		"VALUES($1, $2, $3) "+
		"RETURNING id, wallet_pubkey, secret, msats, user_id",
		hex.EncodeToString(conn.WalletPubkey.Serialize()),
		hex.EncodeToString(conn.Secret.Serialize()),
		u.Id).
		Scan(&w.Id, &w.WalletPubkey, &w.Secret, &w.Msats, &w.UserId)
	if err != nil {
		return fmt.Errorf("failed to insert new wallet: %w", err)
	}

	ec.SetCookie(&http.Cookie{
		Name:     "session",
		HttpOnly: true,
		Path:     "/",
		Value:    session,
		Secure:   true,
		// TODO: refresh session
		// if session expires and user did not register, they will lose access to their wallet
		Expires: time.Now().Add(30 * 24 * time.Hour), // 30d
	})
	ec.Set("user", u)
	ec.Set("wallet", w)
	return nil
}
// findSession loads the user that owns session sid, plus that user's wallet,
// and stores them in ec under "user" (models.User) and "wallet"
// (models.Wallet).
//
// It returns an error if no user is linked to sid or if the user's wallet
// cannot be loaded; in that case nothing is stored in ec. Errors wrap the
// underlying database error with %w so callers can inspect it with errors.Is.
func findSession(hc context.Context, ec echo.Context, sid string) error {
	var (
		db  = hc.Db
		ctx = ec.Request().Context()
		u   models.User
		w   models.Wallet
	)

	err := db.QueryRowContext(ctx, ""+
		"SELECT u.id, u.name, u.created_at, COALESCE(u.ln_pubkey, ''), COALESCE(u.nostr_pubkey, '') "+
		"FROM users u "+
		"JOIN sessions s ON s.user_id = u.id "+
		"WHERE s.id = $1", sid).
		Scan(&u.Id, &u.Name, &u.CreatedAt, &u.LnPubkey, &u.NostrPubkey)
	if err != nil {
		return fmt.Errorf("session not found: %w", err)
	}

	// The user id is already known, so wallets can be filtered on user_id
	// directly; joining users again would add nothing.
	// NOTE(review): LIMIT 1 without ORDER BY returns an arbitrary row if a user
	// can ever have more than one wallet — confirm whether that is possible.
	err = db.QueryRowContext(ctx, ""+
		"SELECT id, wallet_pubkey, secret, msats, user_id "+
		"FROM wallets "+
		"WHERE user_id = $1 "+
		"LIMIT 1", u.Id).
		Scan(&w.Id, &w.WalletPubkey, &w.Secret, &w.Msats, &w.UserId)
	if err != nil {
		// Without this check a missing wallet or a DB failure would hand
		// downstream handlers a zero-value wallet.
		return fmt.Errorf("failed to find wallet: %w", err)
	}

	ec.Set("user", u)
	ec.Set("wallet", w)
	return nil
}