diff --git a/db/invoice.go b/db/invoice.go index 161acc9..964e671 100644 --- a/db/invoice.go +++ b/db/invoice.go @@ -98,8 +98,8 @@ func (db *DB) FetchUserInvoices(pubkey string, invoices *[]Invoice) error { return nil } -func (db *DB) ConfirmInvoice(hash string, confirmedAt time.Time, msatsReceived int) error { - if _, err := db.Exec("UPDATE invoices SET confirmed_at = $2, msats_received = $3 WHERE hash = $1", hash, confirmedAt, msatsReceived); err != nil { +func (db *DB) ConfirmInvoice(tx *sql.Tx, c context.Context, hash string, confirmedAt time.Time, msatsReceived int) error { + if _, err := tx.ExecContext(c, "UPDATE invoices SET confirmed_at = $2, msats_received = $3 WHERE hash = $1", hash, confirmedAt, msatsReceived); err != nil { return err } return nil diff --git a/lnd/invoice.go b/lnd/invoice.go index 83c833d..917214e 100644 --- a/lnd/invoice.go +++ b/lnd/invoice.go @@ -67,24 +67,34 @@ func (lnd *LNDClient) CheckInvoice(d *db.DB, hash lntypes.Hash) { return } - handleLoopError := func(err error) { - log.Println(err) - time.Sleep(pollInterval) - } - for { + ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second) + var tx *sql.Tx + if tx, err = d.BeginTx(ctx, nil); err != nil { + cancel() + continue + } + + handleLoopError := func(err error) { + log.Println(err) + tx.Rollback() + cancel() + time.Sleep(pollInterval) + } + log.Printf("lookup invoice: hash=%s", hash) - if lnInvoice, err = lnd.Client.LookupInvoice(context.TODO(), hash); err != nil { + if lnInvoice, err = lnd.Client.LookupInvoice(ctx, hash); err != nil { handleLoopError(err) continue } if time.Now().After(invoice.ExpiresAt) { // cancel invoices after expiration if no matching order found yet - if err = lnd.Invoices.CancelInvoice(context.TODO(), hash); err != nil { + if err = lnd.Invoices.CancelInvoice(ctx, hash); err != nil { handleLoopError(err) continue } log.Printf("invoice expired: hash=%s", hash) + tx.Commit() break } if lnInvoice.AmountPaid == lnInvoice.Amount { @@ -93,15 +103,16 @@ func (lnd *LNDClient) CheckInvoice(d *db.DB, hash lntypes.Hash) { continue } // TODO settle invoice after matching order was found - if err = lnd.Invoices.SettleInvoice(context.TODO(), preimage); err != nil { + if err = lnd.Invoices.SettleInvoice(ctx, preimage); err != nil { handleLoopError(err) continue } - if err = d.ConfirmInvoice(hash.String(), time.Now(), int(lnInvoice.AmountPaid)); err != nil { + if err = d.ConfirmInvoice(tx, ctx, hash.String(), time.Now(), int(lnInvoice.AmountPaid)); err != nil { handleLoopError(err) continue } log.Printf("invoice confirmed: hash=%s", hash) + tx.Commit() break } time.Sleep(pollInterval)