Fix orphaned invoices with Tx and Context usage
txs are now rolled back on order error
This commit is contained in:
parent
22c9bbdee3
commit
73162155d5
|
@ -1,12 +1,13 @@
|
||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (db *DB) CreateInvoice(invoice *Invoice) error {
|
func (db *DB) CreateInvoice(tx *sql.Tx, ctx context.Context, invoice *Invoice) error {
|
||||||
if err := db.QueryRow(""+
|
if err := tx.QueryRowContext(ctx, ""+
|
||||||
"INSERT INTO invoices(pubkey, msats, preimage, hash, bolt11, created_at, expires_at, description) "+
|
"INSERT INTO invoices(pubkey, msats, preimage, hash, bolt11, created_at, expires_at, description) "+
|
||||||
"VALUES($1, $2, $3, $4, $5, $6, $7, $8) "+
|
"VALUES($1, $2, $3, $4, $5, $6, $7, $8) "+
|
||||||
"RETURNING id",
|
"RETURNING id",
|
||||||
|
|
17
db/market.go
17
db/market.go
|
@ -1,6 +1,9 @@
|
||||||
package db
|
package db
|
||||||
|
|
||||||
import "database/sql"
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
)
|
||||||
|
|
||||||
type FetchOrdersWhere struct {
|
type FetchOrdersWhere struct {
|
||||||
MarketId int
|
MarketId int
|
||||||
|
@ -8,8 +11,8 @@ type FetchOrdersWhere struct {
|
||||||
Confirmed bool
|
Confirmed bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) CreateMarket(market *Market) error {
|
func (db *DB) CreateMarket(tx *sql.Tx, ctx context.Context, market *Market) error {
|
||||||
if err := db.QueryRow(""+
|
if err := tx.QueryRowContext(ctx, ""+
|
||||||
"INSERT INTO markets(description, end_date, invoice_id) "+
|
"INSERT INTO markets(description, end_date, invoice_id) "+
|
||||||
"VALUES($1, $2, $3) "+
|
"VALUES($1, $2, $3) "+
|
||||||
"RETURNING id", market.Description, market.EndDate, market.InvoiceId).Scan(&market.Id); err != nil {
|
"RETURNING id", market.Description, market.EndDate, market.InvoiceId).Scan(&market.Id); err != nil {
|
||||||
|
@ -62,8 +65,8 @@ func (db *DB) FetchShares(marketId int, shares *[]Share) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) FetchShare(shareId string, share *Share) error {
|
func (db *DB) FetchShare(tx *sql.Tx, ctx context.Context, shareId string, share *Share) error {
|
||||||
return db.QueryRow("SELECT id, market_id, description FROM shares WHERE id = $1", shareId).Scan(&share.Id, &share.MarketId, &share.Description)
|
return tx.QueryRowContext(ctx, "SELECT id, market_id, description FROM shares WHERE id = $1", shareId).Scan(&share.Id, &share.MarketId, &share.Description)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) FetchOrders(where *FetchOrdersWhere, orders *[]Order) error {
|
func (db *DB) FetchOrders(where *FetchOrdersWhere, orders *[]Order) error {
|
||||||
|
@ -99,8 +102,8 @@ func (db *DB) FetchOrders(where *FetchOrdersWhere, orders *[]Order) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) CreateOrder(order *Order) error {
|
func (db *DB) CreateOrder(tx *sql.Tx, ctx context.Context, order *Order) error {
|
||||||
if _, err := db.Exec(""+
|
if _, err := tx.ExecContext(ctx, ""+
|
||||||
"INSERT INTO orders(share_id, pubkey, side, quantity, price, invoice_id) "+
|
"INSERT INTO orders(share_id, pubkey, side, quantity, price, invoice_id) "+
|
||||||
"VALUES ($1, $2, $3, $4, $5, $6)",
|
"VALUES ($1, $2, $3, $4, $5, $6)",
|
||||||
order.ShareId, order.Pubkey, order.Side, order.Quantity, order.Price, order.InvoiceId); err != nil {
|
order.ShareId, order.Pubkey, order.Side, order.Quantity, order.Price, order.InvoiceId); err != nil {
|
||||||
|
|
|
@ -2,6 +2,7 @@ package lnd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"log"
|
"log"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -12,7 +13,7 @@ import (
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (lnd *LNDClient) CreateInvoice(d *db.DB, pubkey string, msats int64, description string) (*db.Invoice, error) {
|
func (lnd *LNDClient) CreateInvoice(tx *sql.Tx, ctx context.Context, d *db.DB, pubkey string, msats int64, description string) (*db.Invoice, error) {
|
||||||
var (
|
var (
|
||||||
expiry time.Duration = time.Hour
|
expiry time.Duration = time.Hour
|
||||||
preimage lntypes.Preimage
|
preimage lntypes.Preimage
|
||||||
|
@ -26,14 +27,14 @@ func (lnd *LNDClient) CreateInvoice(d *db.DB, pubkey string, msats int64, descri
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
hash = preimage.Hash()
|
hash = preimage.Hash()
|
||||||
if paymentRequest, err = lnd.Invoices.AddHoldInvoice(context.TODO(), &invoicesrpc.AddInvoiceData{
|
if paymentRequest, err = lnd.Invoices.AddHoldInvoice(ctx, &invoicesrpc.AddInvoiceData{
|
||||||
Hash: &hash,
|
Hash: &hash,
|
||||||
Value: lnwire.MilliSatoshi(msats),
|
Value: lnwire.MilliSatoshi(msats),
|
||||||
Expiry: int64(expiry / time.Millisecond),
|
Expiry: int64(expiry / time.Millisecond),
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if lnInvoice, err = lnd.Client.LookupInvoice(context.TODO(), hash); err != nil {
|
if lnInvoice, err = lnd.Client.LookupInvoice(ctx, hash); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
dbInvoice = &db.Invoice{
|
dbInvoice = &db.Invoice{
|
||||||
|
@ -46,7 +47,7 @@ func (lnd *LNDClient) CreateInvoice(d *db.DB, pubkey string, msats int64, descri
|
||||||
ExpiresAt: lnInvoice.CreationDate.Add(expiry),
|
ExpiresAt: lnInvoice.CreationDate.Add(expiry),
|
||||||
Description: description,
|
Description: description,
|
||||||
}
|
}
|
||||||
if err := d.CreateInvoice(dbInvoice); err != nil {
|
if err := d.CreateInvoice(tx, ctx, dbInvoice); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return dbInvoice, nil
|
return dbInvoice, nil
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
context_ "context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"git.ekzyis.com/ekzyis/delphi.market/db"
|
"git.ekzyis.com/ekzyis/delphi.market/db"
|
||||||
"git.ekzyis.com/ekzyis/delphi.market/lib"
|
"git.ekzyis.com/ekzyis/delphi.market/lib"
|
||||||
|
@ -50,6 +52,7 @@ func HandleMarket(sc context.ServerContext) echo.HandlerFunc {
|
||||||
func HandleCreateMarket(sc context.ServerContext) echo.HandlerFunc {
|
func HandleCreateMarket(sc context.ServerContext) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
var (
|
var (
|
||||||
|
tx *sql.Tx
|
||||||
u db.User
|
u db.User
|
||||||
m db.Market
|
m db.Market
|
||||||
invoice *db.Invoice
|
invoice *db.Invoice
|
||||||
|
@ -64,11 +67,20 @@ func HandleCreateMarket(sc context.ServerContext) echo.HandlerFunc {
|
||||||
return echo.NewHTTPError(http.StatusBadRequest)
|
return echo.NewHTTPError(http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// transaction start
|
||||||
|
ctx, cancel := context_.WithTimeout(c.Request().Context(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if tx, err = sc.Db.BeginTx(ctx, nil); err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Commit()
|
||||||
|
|
||||||
u = c.Get("session").(db.User)
|
u = c.Get("session").(db.User)
|
||||||
msats = 1000
|
msats = 1000
|
||||||
// TODO: add [market:<id>] for redirect after payment
|
// TODO: add [market:<id>] for redirect after payment
|
||||||
invDescription = fmt.Sprintf("create market \"%s\" (%s)", m.Description, m.EndDate)
|
invDescription = fmt.Sprintf("create market \"%s\" (%s)", m.Description, m.EndDate)
|
||||||
if invoice, err = sc.Lnd.CreateInvoice(sc.Db, u.Pubkey, msats, invDescription); err != nil {
|
if invoice, err = sc.Lnd.CreateInvoice(tx, ctx, sc.Db, u.Pubkey, msats, invDescription); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if qr, err = lib.ToQR(invoice.PaymentRequest); err != nil {
|
if qr, err = lib.ToQR(invoice.PaymentRequest); err != nil {
|
||||||
|
@ -80,7 +92,7 @@ func HandleCreateMarket(sc context.ServerContext) echo.HandlerFunc {
|
||||||
go sc.Lnd.CheckInvoice(sc.Db, hash)
|
go sc.Lnd.CheckInvoice(sc.Db, hash)
|
||||||
|
|
||||||
m.InvoiceId = invoice.Id
|
m.InvoiceId = invoice.Id
|
||||||
if err := sc.Db.CreateMarket(&m); err != nil {
|
if err := sc.Db.CreateMarket(tx, ctx, &m); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -97,6 +109,7 @@ func HandleCreateMarket(sc context.ServerContext) echo.HandlerFunc {
|
||||||
func HandleOrder(sc context.ServerContext) echo.HandlerFunc {
|
func HandleOrder(sc context.ServerContext) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
var (
|
var (
|
||||||
|
tx *sql.Tx
|
||||||
u db.User
|
u db.User
|
||||||
o db.Order
|
o db.Order
|
||||||
s db.Share
|
s db.Share
|
||||||
|
@ -122,7 +135,18 @@ func HandleOrder(sc context.ServerContext) echo.HandlerFunc {
|
||||||
u = c.Get("session").(db.User)
|
u = c.Get("session").(db.User)
|
||||||
o.Pubkey = u.Pubkey
|
o.Pubkey = u.Pubkey
|
||||||
msats = o.Quantity * o.Price * 1000
|
msats = o.Quantity * o.Price * 1000
|
||||||
if err = sc.Db.FetchShare(o.ShareId, &s); err != nil {
|
|
||||||
|
// transaction start
|
||||||
|
ctx, cancel := context_.WithTimeout(c.Request().Context(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if tx, err = sc.Db.BeginTx(ctx, nil); err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Commit()
|
||||||
|
|
||||||
|
if err = sc.Db.FetchShare(tx, ctx, o.ShareId, &s); err != nil {
|
||||||
|
tx.Rollback()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
description = fmt.Sprintf("%s %d %s shares @ %d sats [market:%d]", strings.ToUpper(o.Side), o.Quantity, s.Description, o.Price, s.MarketId)
|
description = fmt.Sprintf("%s %d %s shares @ %d sats [market:%d]", strings.ToUpper(o.Side), o.Quantity, s.Description, o.Price, s.MarketId)
|
||||||
|
@ -130,20 +154,24 @@ func HandleOrder(sc context.ServerContext) echo.HandlerFunc {
|
||||||
// TODO: if SELL order, check share balance of user
|
// TODO: if SELL order, check share balance of user
|
||||||
|
|
||||||
// Create HODL invoice
|
// Create HODL invoice
|
||||||
if invoice, err = sc.Lnd.CreateInvoice(sc.Db, o.Pubkey, msats, description); err != nil {
|
if invoice, err = sc.Lnd.CreateInvoice(tx, ctx, sc.Db, o.Pubkey, msats, description); err != nil {
|
||||||
|
tx.Rollback()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Create QR code to pay HODL invoice
|
// Create QR code to pay HODL invoice
|
||||||
if qr, err = lib.ToQR(invoice.PaymentRequest); err != nil {
|
if qr, err = lib.ToQR(invoice.PaymentRequest); err != nil {
|
||||||
|
tx.Rollback()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if hash, err = lntypes.MakeHashFromStr(invoice.Hash); err != nil {
|
if hash, err = lntypes.MakeHashFromStr(invoice.Hash); err != nil {
|
||||||
|
tx.Rollback()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create (unconfirmed) order
|
// Create (unconfirmed) order
|
||||||
o.InvoiceId = invoice.Id
|
o.InvoiceId = invoice.Id
|
||||||
if err := sc.Db.CreateOrder(&o); err != nil {
|
if err := sc.Db.CreateOrder(tx, ctx, &o); err != nil {
|
||||||
|
tx.Rollback()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue