From 057690d5a996d3d436b7413818192a0cda2f5ebc Mon Sep 17 00:00:00 2001 From: ekzyis Date: Sat, 9 Sep 2023 22:52:51 +0200 Subject: [PATCH] Use HODL invoices --- src/db.go | 19 ++++++-- src/lnd.go | 124 +++++++++++++++++++++++++++++++++++--------------- src/market.go | 7 ++- src/worker.go | 13 +++++- 4 files changed, 122 insertions(+), 41 deletions(-) diff --git a/src/db.go b/src/db.go index 4cf3228..b142dfb 100644 --- a/src/db.go +++ b/src/db.go @@ -129,9 +129,22 @@ func (db *DB) CreateInvoice(invoice *Invoice) error { return nil } -func (db *DB) FetchInvoice(invoiceId string, invoice *Invoice) error { - if err := db.QueryRow(""+ - "SELECT id, pubkey, msats, preimage, hash, bolt11, created_at, expires_at, confirmed_at, held_since FROM invoices WHERE id = $1", invoiceId).Scan(&invoice.Id, &invoice.Pubkey, &invoice.Msats, &invoice.Preimage, &invoice.PaymentHash, &invoice.PaymentRequest, &invoice.CreatedAt, &invoice.ExpiresAt, &invoice.ConfirmedAt, &invoice.HeldSince); err != nil { +type FetchInvoiceWhere struct { + Id string + Hash string +} + +func (db *DB) FetchInvoice(where *FetchInvoiceWhere, invoice *Invoice) error { + query := "SELECT id, pubkey, msats, preimage, hash, bolt11, created_at, expires_at, confirmed_at, held_since FROM invoices " + var args []any + if where.Id != "" { + query += "WHERE id = $1" + args = append(args, where.Id) + } else if where.Hash != "" { + query += "WHERE hash = $1" + args = append(args, where.Hash) + } + if err := db.QueryRow(query, args...).Scan(&invoice.Id, &invoice.Pubkey, &invoice.Msats, &invoice.Preimage, &invoice.PaymentHash, &invoice.PaymentRequest, &invoice.CreatedAt, &invoice.ExpiresAt, &invoice.ConfirmedAt, &invoice.HeldSince); err != nil { return err } return nil diff --git a/src/lnd.go b/src/lnd.go index d236066..d37aa02 100644 --- a/src/lnd.go +++ b/src/lnd.go @@ -2,17 +2,19 @@ package main import ( "context" + "crypto/rand" "database/sql" - "encoding/hex" + "io" "log" "net/http" - "strings" "time" "github.com/joho/godotenv" "github.com/labstack/echo/v4" "github.com/lightninglabs/lndclient" - "github.com/lightningnetwork/lnd/lnrpc" + "github.com/lightningnetwork/lnd/lnrpc/invoicesrpc" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" "github.com/namsral/flag" ) @@ -25,7 +27,7 @@ var ( ) type LndClient struct { - lnrpc.LightningClient + lndclient.GrpcLndServices } func init() { @@ -38,20 +40,21 @@ func init() { flag.StringVar(&LndHost, "LND_HOST", "localhost:10001", "LND gRPC server address") flag.Parse() lndEnabled = false - rpcClient, err := lndclient.NewBasicClient(LndHost, LndCert, LndMacaroonDir, "regtest") + rpcLndServices, err := lndclient.NewLndServices(&lndclient.LndServicesConfig{ + LndAddress: LndHost, + MacaroonDir: LndMacaroonDir, + TLSPath: LndCert, + Network: lndclient.NetworkRegtest, + }) if err != nil { log.Println(err) return } - lnd = &LndClient{LightningClient: rpcClient} - if info, err := lnd.GetInfo(context.TODO(), &lnrpc.GetInfoRequest{}); err != nil { - log.Printf("LND connection error: %v\n", err) - return - } else { - version := strings.Split(info.Version, " ")[0] - log.Printf("Connected to %s running LND v%s", LndHost, version) - lndEnabled = true - } + lnd = &LndClient{GrpcLndServices: *rpcLndServices} + ver := lnd.Version + log.Printf("Connected to %s running LND v%s", LndHost, ver.Version) + lndEnabled = true + } func lndGuard(next echo.HandlerFunc) echo.HandlerFunc { @@ -63,26 +66,46 @@ func lndGuard(next echo.HandlerFunc) echo.HandlerFunc { } } +func (lnd *LndClient) GenerateNewPreimage() (lntypes.Preimage, error) { + randomBytes := make([]byte, 32) + _, err := io.ReadFull(rand.Reader, randomBytes) + if err != nil { + return lntypes.Preimage{}, err + } + preimage, err := lntypes.MakePreimage(randomBytes) + if err != nil { + return lntypes.Preimage{}, err + } + return preimage, nil +} + func (lnd *LndClient) CreateInvoice(pubkey string, msats int) (*Invoice, error) { - addInvoiceResponse, err := lnd.AddInvoice(context.TODO(), &lnrpc.Invoice{ - ValueMsat: int64(msats), - Expiry: 3600, + expiry := time.Hour + preimage, err := lnd.GenerateNewPreimage() + if err != nil { + return nil, err + } + hash := preimage.Hash() + paymentRequest, err := lnd.Invoices.AddHoldInvoice(context.TODO(), &invoicesrpc.AddInvoiceData{ + Hash: &hash, + Value: lnwire.MilliSatoshi(msats), + Expiry: int64(expiry), }) if err != nil { return nil, err } - lnInvoice, err := lnd.LookupInvoice(context.TODO(), &lnrpc.PaymentHash{RHash: addInvoiceResponse.RHash}) + lnInvoice, err := lnd.Client.LookupInvoice(context.TODO(), hash) if err != nil { return nil, err } dbInvoice := Invoice{ Session: Session{pubkey}, Msats: msats, - Preimage: hex.EncodeToString(lnInvoice.RPreimage), - PaymentRequest: lnInvoice.PaymentRequest, - PaymentHash: hex.EncodeToString(lnInvoice.RHash), - CreatedAt: time.Unix(lnInvoice.CreationDate, 0), - ExpiresAt: time.Unix(lnInvoice.CreationDate+lnInvoice.Expiry, 0), + Preimage: preimage.String(), + PaymentRequest: paymentRequest, + PaymentHash: hash.String(), + CreatedAt: lnInvoice.CreationDate, + ExpiresAt: lnInvoice.CreationDate.Add(expiry), } if err := db.CreateInvoice(&dbInvoice); err != nil { return nil, err @@ -90,40 +113,65 @@ func (lnd *LndClient) CreateInvoice(pubkey string, msats int) (*Invoice, error) return &dbInvoice, nil } -func (lnd *LndClient) CheckInvoice(hash string) { +func (lnd *LndClient) CheckInvoice(hash lntypes.Hash) { if !lndEnabled { log.Printf("LND disabled, skipping checking invoice: hash=%s", hash) return } + + var invoice Invoice + if err := db.FetchInvoice(&FetchInvoiceWhere{Hash: hash.String()}, &invoice); err != nil { + log.Println(err) + return + } + + loopPause := 5 * time.Second + handleLoopError := func(err error) { + log.Println(err) + time.Sleep(loopPause) + } + for { log.Printf("lookup invoice: hash=%s", hash) - invoice, err := lnd.LookupInvoice(context.TODO(), &lnrpc.PaymentHash{RHashStr: hash}) + lnInvoice, err := lnd.Client.LookupInvoice(context.TODO(), hash) if err != nil { - log.Println(err) - time.Sleep(5 * time.Second) + handleLoopError(err) continue } - if time.Now().After(time.Unix(invoice.CreationDate+invoice.Expiry, 0)) { + if time.Now().After(invoice.ExpiresAt) { + if err := lnd.Invoices.CancelInvoice(context.TODO(), hash); err != nil { + handleLoopError(err) + continue + } log.Printf("invoice expired: hash=%s", hash) break } - if invoice.SettleDate != 0 && invoice.AmtPaidMsat > 0 { - if err := db.ConfirmInvoice(hash, time.Unix(invoice.SettleDate, 0), int(invoice.AmtPaidMsat)); err != nil { - log.Println(err) - time.Sleep(5 * time.Second) + if lnInvoice.AmountPaid > 0 { + preimage, err := lntypes.MakePreimageFromStr(invoice.Preimage) + if err != nil { + handleLoopError(err) + continue + } + // TODO settle invoice after matching order was found + if err := lnd.Invoices.SettleInvoice(context.TODO(), preimage); err != nil { + handleLoopError(err) + continue + } + if err := db.ConfirmInvoice(hash.String(), time.Now(), int(lnInvoice.AmountPaid)); err != nil { + handleLoopError(err) continue } log.Printf("invoice confirmed: hash=%s", hash) break } - time.Sleep(5 * time.Second) + time.Sleep(loopPause) } } func invoice(c echo.Context) error { invoiceId := c.Param("id") var invoice Invoice - if err := db.FetchInvoice(invoiceId, &invoice); err == sql.ErrNoRows { + if err := db.FetchInvoice(&FetchInvoiceWhere{Id: invoiceId}, &invoice); err == sql.ErrNoRows { return echo.NewHTTPError(http.StatusNotFound, "Not Found") } else if err != nil { return err @@ -132,7 +180,11 @@ func invoice(c echo.Context) error { if invoice.Pubkey != session.Pubkey { return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") } - go lnd.CheckInvoice(invoice.PaymentHash) + hash, err := lntypes.MakeHashFromStr(invoice.PaymentHash) + if err != nil { + return err + } + go lnd.CheckInvoice(hash) qr, err := ToQR(invoice.PaymentRequest) if err != nil { return err @@ -156,7 +208,7 @@ func invoice(c echo.Context) error { func invoiceStatus(c echo.Context) error { invoiceId := c.Param("id") var invoice Invoice - if err := db.FetchInvoice(invoiceId, &invoice); err == sql.ErrNoRows { + if err := db.FetchInvoice(&FetchInvoiceWhere{Id: invoiceId}, &invoice); err == sql.ErrNoRows { return echo.NewHTTPError(http.StatusNotFound, "Not Found") } else if err != nil { return err diff --git a/src/market.go b/src/market.go index a7663ae..477a912 100644 --- a/src/market.go +++ b/src/market.go @@ -9,6 +9,7 @@ import ( "time" "github.com/labstack/echo/v4" + "github.com/lightningnetwork/lnd/lntypes" "gopkg.in/guregu/null.v4" ) @@ -90,7 +91,11 @@ func order(c echo.Context) error { if err != nil { return err } - go lnd.CheckInvoice(invoice.PaymentHash) + hash, err := lntypes.MakeHashFromStr(invoice.PaymentHash) + if err != nil { + return err + } + go lnd.CheckInvoice(hash) data := map[string]any{ "session": c.Get("session"), "ENV": ENV, diff --git a/src/worker.go b/src/worker.go index d9e1175..2e01877 100644 --- a/src/worker.go +++ b/src/worker.go @@ -1,12 +1,23 @@ package main +import ( + "log" + + "github.com/lightningnetwork/lnd/lntypes" +) + func RunJobs() error { var invoices []Invoice if err := db.FetchInvoices(&FetchInvoicesWhere{Expired: false}, &invoices); err != nil { return err } for _, inv := range invoices { - go lnd.CheckInvoice(inv.PaymentHash) + hash, err := lntypes.MakeHashFromStr(inv.PaymentHash) + if err != nil { + log.Println(err) + continue + } + go lnd.CheckInvoice(hash) } return nil }