Use HODL invoices

This commit is contained in:
ekzyis 2023-09-09 22:52:51 +02:00
parent 2d2ceec140
commit 057690d5a9
4 changed files with 122 additions and 41 deletions

View File

@ -129,9 +129,22 @@ func (db *DB) CreateInvoice(invoice *Invoice) error {
return nil return nil
} }
func (db *DB) FetchInvoice(invoiceId string, invoice *Invoice) error { type FetchInvoiceWhere struct {
if err := db.QueryRow(""+ Id string
"SELECT id, pubkey, msats, preimage, hash, bolt11, created_at, expires_at, confirmed_at, held_since FROM invoices WHERE id = $1", invoiceId).Scan(&invoice.Id, &invoice.Pubkey, &invoice.Msats, &invoice.Preimage, &invoice.PaymentHash, &invoice.PaymentRequest, &invoice.CreatedAt, &invoice.ExpiresAt, &invoice.ConfirmedAt, &invoice.HeldSince); err != nil { Hash string
}
func (db *DB) FetchInvoice(where *FetchInvoiceWhere, invoice *Invoice) error {
query := "SELECT id, pubkey, msats, preimage, hash, bolt11, created_at, expires_at, confirmed_at, held_since FROM invoices "
var args []any
if where.Id != "" {
query += "WHERE id = $1"
args = append(args, where.Id)
} else if where.Hash != "" {
query += "WHERE hash = $1"
args = append(args, where.Hash)
}
if err := db.QueryRow(query, args...).Scan(&invoice.Id, &invoice.Pubkey, &invoice.Msats, &invoice.Preimage, &invoice.PaymentHash, &invoice.PaymentRequest, &invoice.CreatedAt, &invoice.ExpiresAt, &invoice.ConfirmedAt, &invoice.HeldSince); err != nil {
return err return err
} }
return nil return nil

View File

@ -2,17 +2,19 @@ package main
import ( import (
"context" "context"
"crypto/rand"
"database/sql" "database/sql"
"encoding/hex" "io"
"log" "log"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/joho/godotenv" "github.com/joho/godotenv"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/lightninglabs/lndclient" "github.com/lightninglabs/lndclient"
"github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc/invoicesrpc"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/namsral/flag" "github.com/namsral/flag"
) )
@ -25,7 +27,7 @@ var (
) )
type LndClient struct { type LndClient struct {
lnrpc.LightningClient lndclient.GrpcLndServices
} }
func init() { func init() {
@ -38,20 +40,21 @@ func init() {
flag.StringVar(&LndHost, "LND_HOST", "localhost:10001", "LND gRPC server address") flag.StringVar(&LndHost, "LND_HOST", "localhost:10001", "LND gRPC server address")
flag.Parse() flag.Parse()
lndEnabled = false lndEnabled = false
rpcClient, err := lndclient.NewBasicClient(LndHost, LndCert, LndMacaroonDir, "regtest") rpcLndServices, err := lndclient.NewLndServices(&lndclient.LndServicesConfig{
LndAddress: LndHost,
MacaroonDir: LndMacaroonDir,
TLSPath: LndCert,
Network: lndclient.NetworkRegtest,
})
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return return
} }
lnd = &LndClient{LightningClient: rpcClient} lnd = &LndClient{GrpcLndServices: *rpcLndServices}
if info, err := lnd.GetInfo(context.TODO(), &lnrpc.GetInfoRequest{}); err != nil { ver := lnd.Version
log.Printf("LND connection error: %v\n", err) log.Printf("Connected to %s running LND v%s", LndHost, ver.Version)
return
} else {
version := strings.Split(info.Version, " ")[0]
log.Printf("Connected to %s running LND v%s", LndHost, version)
lndEnabled = true lndEnabled = true
}
} }
func lndGuard(next echo.HandlerFunc) echo.HandlerFunc { func lndGuard(next echo.HandlerFunc) echo.HandlerFunc {
@ -63,26 +66,46 @@ func lndGuard(next echo.HandlerFunc) echo.HandlerFunc {
} }
} }
func (lnd *LndClient) GenerateNewPreimage() (lntypes.Preimage, error) {
randomBytes := make([]byte, 32)
_, err := io.ReadFull(rand.Reader, randomBytes)
if err != nil {
return lntypes.Preimage{}, err
}
preimage, err := lntypes.MakePreimage(randomBytes)
if err != nil {
return lntypes.Preimage{}, err
}
return preimage, nil
}
func (lnd *LndClient) CreateInvoice(pubkey string, msats int) (*Invoice, error) { func (lnd *LndClient) CreateInvoice(pubkey string, msats int) (*Invoice, error) {
addInvoiceResponse, err := lnd.AddInvoice(context.TODO(), &lnrpc.Invoice{ expiry := time.Hour
ValueMsat: int64(msats), preimage, err := lnd.GenerateNewPreimage()
Expiry: 3600, if err != nil {
return nil, err
}
hash := preimage.Hash()
paymentRequest, err := lnd.Invoices.AddHoldInvoice(context.TODO(), &invoicesrpc.AddInvoiceData{
Hash: &hash,
Value: lnwire.MilliSatoshi(msats),
Expiry: int64(expiry),
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
lnInvoice, err := lnd.LookupInvoice(context.TODO(), &lnrpc.PaymentHash{RHash: addInvoiceResponse.RHash}) lnInvoice, err := lnd.Client.LookupInvoice(context.TODO(), hash)
if err != nil { if err != nil {
return nil, err return nil, err
} }
dbInvoice := Invoice{ dbInvoice := Invoice{
Session: Session{pubkey}, Session: Session{pubkey},
Msats: msats, Msats: msats,
Preimage: hex.EncodeToString(lnInvoice.RPreimage), Preimage: preimage.String(),
PaymentRequest: lnInvoice.PaymentRequest, PaymentRequest: paymentRequest,
PaymentHash: hex.EncodeToString(lnInvoice.RHash), PaymentHash: hash.String(),
CreatedAt: time.Unix(lnInvoice.CreationDate, 0), CreatedAt: lnInvoice.CreationDate,
ExpiresAt: time.Unix(lnInvoice.CreationDate+lnInvoice.Expiry, 0), ExpiresAt: lnInvoice.CreationDate.Add(expiry),
} }
if err := db.CreateInvoice(&dbInvoice); err != nil { if err := db.CreateInvoice(&dbInvoice); err != nil {
return nil, err return nil, err
@ -90,40 +113,65 @@ func (lnd *LndClient) CreateInvoice(pubkey string, msats int) (*Invoice, error)
return &dbInvoice, nil return &dbInvoice, nil
} }
func (lnd *LndClient) CheckInvoice(hash string) { func (lnd *LndClient) CheckInvoice(hash lntypes.Hash) {
if !lndEnabled { if !lndEnabled {
log.Printf("LND disabled, skipping checking invoice: hash=%s", hash) log.Printf("LND disabled, skipping checking invoice: hash=%s", hash)
return return
} }
var invoice Invoice
if err := db.FetchInvoice(&FetchInvoiceWhere{Hash: hash.String()}, &invoice); err != nil {
log.Println(err)
return
}
loopPause := 5 * time.Second
handleLoopError := func(err error) {
log.Println(err)
time.Sleep(loopPause)
}
for { for {
log.Printf("lookup invoice: hash=%s", hash) log.Printf("lookup invoice: hash=%s", hash)
invoice, err := lnd.LookupInvoice(context.TODO(), &lnrpc.PaymentHash{RHashStr: hash}) lnInvoice, err := lnd.Client.LookupInvoice(context.TODO(), hash)
if err != nil { if err != nil {
log.Println(err) handleLoopError(err)
time.Sleep(5 * time.Second) continue
}
if time.Now().After(invoice.ExpiresAt) {
if err := lnd.Invoices.CancelInvoice(context.TODO(), hash); err != nil {
handleLoopError(err)
continue continue
} }
if time.Now().After(time.Unix(invoice.CreationDate+invoice.Expiry, 0)) {
log.Printf("invoice expired: hash=%s", hash) log.Printf("invoice expired: hash=%s", hash)
break break
} }
if invoice.SettleDate != 0 && invoice.AmtPaidMsat > 0 { if lnInvoice.AmountPaid > 0 {
if err := db.ConfirmInvoice(hash, time.Unix(invoice.SettleDate, 0), int(invoice.AmtPaidMsat)); err != nil { preimage, err := lntypes.MakePreimageFromStr(invoice.Preimage)
log.Println(err) if err != nil {
time.Sleep(5 * time.Second) handleLoopError(err)
continue
}
// TODO settle invoice after matching order was found
if err := lnd.Invoices.SettleInvoice(context.TODO(), preimage); err != nil {
handleLoopError(err)
continue
}
if err := db.ConfirmInvoice(hash.String(), time.Now(), int(lnInvoice.AmountPaid)); err != nil {
handleLoopError(err)
continue continue
} }
log.Printf("invoice confirmed: hash=%s", hash) log.Printf("invoice confirmed: hash=%s", hash)
break break
} }
time.Sleep(5 * time.Second) time.Sleep(loopPause)
} }
} }
func invoice(c echo.Context) error { func invoice(c echo.Context) error {
invoiceId := c.Param("id") invoiceId := c.Param("id")
var invoice Invoice var invoice Invoice
if err := db.FetchInvoice(invoiceId, &invoice); err == sql.ErrNoRows { if err := db.FetchInvoice(&FetchInvoiceWhere{Id: invoiceId}, &invoice); err == sql.ErrNoRows {
return echo.NewHTTPError(http.StatusNotFound, "Not Found") return echo.NewHTTPError(http.StatusNotFound, "Not Found")
} else if err != nil { } else if err != nil {
return err return err
@ -132,7 +180,11 @@ func invoice(c echo.Context) error {
if invoice.Pubkey != session.Pubkey { if invoice.Pubkey != session.Pubkey {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }
go lnd.CheckInvoice(invoice.PaymentHash) hash, err := lntypes.MakeHashFromStr(invoice.PaymentHash)
if err != nil {
return err
}
go lnd.CheckInvoice(hash)
qr, err := ToQR(invoice.PaymentRequest) qr, err := ToQR(invoice.PaymentRequest)
if err != nil { if err != nil {
return err return err
@ -156,7 +208,7 @@ func invoice(c echo.Context) error {
func invoiceStatus(c echo.Context) error { func invoiceStatus(c echo.Context) error {
invoiceId := c.Param("id") invoiceId := c.Param("id")
var invoice Invoice var invoice Invoice
if err := db.FetchInvoice(invoiceId, &invoice); err == sql.ErrNoRows { if err := db.FetchInvoice(&FetchInvoiceWhere{Id: invoiceId}, &invoice); err == sql.ErrNoRows {
return echo.NewHTTPError(http.StatusNotFound, "Not Found") return echo.NewHTTPError(http.StatusNotFound, "Not Found")
} else if err != nil { } else if err != nil {
return err return err

View File

@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/lightningnetwork/lnd/lntypes"
"gopkg.in/guregu/null.v4" "gopkg.in/guregu/null.v4"
) )
@ -90,7 +91,11 @@ func order(c echo.Context) error {
if err != nil { if err != nil {
return err return err
} }
go lnd.CheckInvoice(invoice.PaymentHash) hash, err := lntypes.MakeHashFromStr(invoice.PaymentHash)
if err != nil {
return err
}
go lnd.CheckInvoice(hash)
data := map[string]any{ data := map[string]any{
"session": c.Get("session"), "session": c.Get("session"),
"ENV": ENV, "ENV": ENV,

View File

@ -1,12 +1,23 @@
package main package main
import (
"log"
"github.com/lightningnetwork/lnd/lntypes"
)
func RunJobs() error { func RunJobs() error {
var invoices []Invoice var invoices []Invoice
if err := db.FetchInvoices(&FetchInvoicesWhere{Expired: false}, &invoices); err != nil { if err := db.FetchInvoices(&FetchInvoicesWhere{Expired: false}, &invoices); err != nil {
return err return err
} }
for _, inv := range invoices { for _, inv := range invoices {
go lnd.CheckInvoice(inv.PaymentHash) hash, err := lntypes.MakeHashFromStr(inv.PaymentHash)
if err != nil {
log.Println(err)
continue
}
go lnd.CheckInvoice(hash)
} }
return nil return nil
} }