Fix orphaned invoices with Tx and Context usage
txs are now rolled back on order error
This commit is contained in:
		
							parent
							
								
									22c9bbdee3
								
							
						
					
					
						commit
						73162155d5
					
				@ -1,12 +1,13 @@
 | 
			
		||||
package db
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (db *DB) CreateInvoice(invoice *Invoice) error {
 | 
			
		||||
	if err := db.QueryRow(""+
 | 
			
		||||
func (db *DB) CreateInvoice(tx *sql.Tx, ctx context.Context, invoice *Invoice) error {
 | 
			
		||||
	if err := tx.QueryRowContext(ctx, ""+
 | 
			
		||||
		"INSERT INTO invoices(pubkey, msats, preimage, hash, bolt11, created_at, expires_at, description) "+
 | 
			
		||||
		"VALUES($1, $2, $3, $4, $5, $6, $7, $8) "+
 | 
			
		||||
		"RETURNING id",
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										17
									
								
								db/market.go
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								db/market.go
									
									
									
									
									
								
							@ -1,6 +1,9 @@
 | 
			
		||||
package db
 | 
			
		||||
 | 
			
		||||
import "database/sql"
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type FetchOrdersWhere struct {
 | 
			
		||||
	MarketId  int
 | 
			
		||||
@ -8,8 +11,8 @@ type FetchOrdersWhere struct {
 | 
			
		||||
	Confirmed bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db *DB) CreateMarket(market *Market) error {
 | 
			
		||||
	if err := db.QueryRow(""+
 | 
			
		||||
func (db *DB) CreateMarket(tx *sql.Tx, ctx context.Context, market *Market) error {
 | 
			
		||||
	if err := tx.QueryRowContext(ctx, ""+
 | 
			
		||||
		"INSERT INTO markets(description, end_date, invoice_id) "+
 | 
			
		||||
		"VALUES($1, $2, $3) "+
 | 
			
		||||
		"RETURNING id", market.Description, market.EndDate, market.InvoiceId).Scan(&market.Id); err != nil {
 | 
			
		||||
@ -62,8 +65,8 @@ func (db *DB) FetchShares(marketId int, shares *[]Share) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db *DB) FetchShare(shareId string, share *Share) error {
 | 
			
		||||
	return db.QueryRow("SELECT id, market_id, description FROM shares WHERE id = $1", shareId).Scan(&share.Id, &share.MarketId, &share.Description)
 | 
			
		||||
func (db *DB) FetchShare(tx *sql.Tx, ctx context.Context, shareId string, share *Share) error {
 | 
			
		||||
	return tx.QueryRowContext(ctx, "SELECT id, market_id, description FROM shares WHERE id = $1", shareId).Scan(&share.Id, &share.MarketId, &share.Description)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db *DB) FetchOrders(where *FetchOrdersWhere, orders *[]Order) error {
 | 
			
		||||
@ -99,8 +102,8 @@ func (db *DB) FetchOrders(where *FetchOrdersWhere, orders *[]Order) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db *DB) CreateOrder(order *Order) error {
 | 
			
		||||
	if _, err := db.Exec(""+
 | 
			
		||||
func (db *DB) CreateOrder(tx *sql.Tx, ctx context.Context, order *Order) error {
 | 
			
		||||
	if _, err := tx.ExecContext(ctx, ""+
 | 
			
		||||
		"INSERT INTO orders(share_id, pubkey, side, quantity, price, invoice_id) "+
 | 
			
		||||
		"VALUES ($1, $2, $3, $4, $5, $6)",
 | 
			
		||||
		order.ShareId, order.Pubkey, order.Side, order.Quantity, order.Price, order.InvoiceId); err != nil {
 | 
			
		||||
 | 
			
		||||
@ -2,6 +2,7 @@ package lnd
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"log"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
@ -12,7 +13,7 @@ import (
 | 
			
		||||
	"github.com/lightningnetwork/lnd/lnwire"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (lnd *LNDClient) CreateInvoice(d *db.DB, pubkey string, msats int64, description string) (*db.Invoice, error) {
 | 
			
		||||
func (lnd *LNDClient) CreateInvoice(tx *sql.Tx, ctx context.Context, d *db.DB, pubkey string, msats int64, description string) (*db.Invoice, error) {
 | 
			
		||||
	var (
 | 
			
		||||
		expiry         time.Duration = time.Hour
 | 
			
		||||
		preimage       lntypes.Preimage
 | 
			
		||||
@ -26,14 +27,14 @@ func (lnd *LNDClient) CreateInvoice(d *db.DB, pubkey string, msats int64, descri
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	hash = preimage.Hash()
 | 
			
		||||
	if paymentRequest, err = lnd.Invoices.AddHoldInvoice(context.TODO(), &invoicesrpc.AddInvoiceData{
 | 
			
		||||
	if paymentRequest, err = lnd.Invoices.AddHoldInvoice(ctx, &invoicesrpc.AddInvoiceData{
 | 
			
		||||
		Hash:   &hash,
 | 
			
		||||
		Value:  lnwire.MilliSatoshi(msats),
 | 
			
		||||
		Expiry: int64(expiry / time.Millisecond),
 | 
			
		||||
	}); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	if lnInvoice, err = lnd.Client.LookupInvoice(context.TODO(), hash); err != nil {
 | 
			
		||||
	if lnInvoice, err = lnd.Client.LookupInvoice(ctx, hash); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	dbInvoice = &db.Invoice{
 | 
			
		||||
@ -46,7 +47,7 @@ func (lnd *LNDClient) CreateInvoice(d *db.DB, pubkey string, msats int64, descri
 | 
			
		||||
		ExpiresAt:      lnInvoice.CreationDate.Add(expiry),
 | 
			
		||||
		Description:    description,
 | 
			
		||||
	}
 | 
			
		||||
	if err := d.CreateInvoice(dbInvoice); err != nil {
 | 
			
		||||
	if err := d.CreateInvoice(tx, ctx, dbInvoice); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return dbInvoice, nil
 | 
			
		||||
 | 
			
		||||
@ -1,11 +1,13 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	context_ "context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.ekzyis.com/ekzyis/delphi.market/db"
 | 
			
		||||
	"git.ekzyis.com/ekzyis/delphi.market/lib"
 | 
			
		||||
@ -50,6 +52,7 @@ func HandleMarket(sc context.ServerContext) echo.HandlerFunc {
 | 
			
		||||
func HandleCreateMarket(sc context.ServerContext) echo.HandlerFunc {
 | 
			
		||||
	return func(c echo.Context) error {
 | 
			
		||||
		var (
 | 
			
		||||
			tx             *sql.Tx
 | 
			
		||||
			u              db.User
 | 
			
		||||
			m              db.Market
 | 
			
		||||
			invoice        *db.Invoice
 | 
			
		||||
@ -64,11 +67,20 @@ func HandleCreateMarket(sc context.ServerContext) echo.HandlerFunc {
 | 
			
		||||
			return echo.NewHTTPError(http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// transaction start
 | 
			
		||||
		ctx, cancel := context_.WithTimeout(c.Request().Context(), 5*time.Second)
 | 
			
		||||
		defer cancel()
 | 
			
		||||
		if tx, err = sc.Db.BeginTx(ctx, nil); err != nil {
 | 
			
		||||
			tx.Rollback()
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		defer tx.Commit()
 | 
			
		||||
 | 
			
		||||
		u = c.Get("session").(db.User)
 | 
			
		||||
		msats = 1000
 | 
			
		||||
		// TODO: add [market:<id>] for redirect after payment
 | 
			
		||||
		invDescription = fmt.Sprintf("create market \"%s\" (%s)", m.Description, m.EndDate)
 | 
			
		||||
		if invoice, err = sc.Lnd.CreateInvoice(sc.Db, u.Pubkey, msats, invDescription); err != nil {
 | 
			
		||||
		if invoice, err = sc.Lnd.CreateInvoice(tx, ctx, sc.Db, u.Pubkey, msats, invDescription); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		if qr, err = lib.ToQR(invoice.PaymentRequest); err != nil {
 | 
			
		||||
@ -80,7 +92,7 @@ func HandleCreateMarket(sc context.ServerContext) echo.HandlerFunc {
 | 
			
		||||
		go sc.Lnd.CheckInvoice(sc.Db, hash)
 | 
			
		||||
 | 
			
		||||
		m.InvoiceId = invoice.Id
 | 
			
		||||
		if err := sc.Db.CreateMarket(&m); err != nil {
 | 
			
		||||
		if err := sc.Db.CreateMarket(tx, ctx, &m); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
@ -97,6 +109,7 @@ func HandleCreateMarket(sc context.ServerContext) echo.HandlerFunc {
 | 
			
		||||
func HandleOrder(sc context.ServerContext) echo.HandlerFunc {
 | 
			
		||||
	return func(c echo.Context) error {
 | 
			
		||||
		var (
 | 
			
		||||
			tx          *sql.Tx
 | 
			
		||||
			u           db.User
 | 
			
		||||
			o           db.Order
 | 
			
		||||
			s           db.Share
 | 
			
		||||
@ -122,7 +135,18 @@ func HandleOrder(sc context.ServerContext) echo.HandlerFunc {
 | 
			
		||||
		u = c.Get("session").(db.User)
 | 
			
		||||
		o.Pubkey = u.Pubkey
 | 
			
		||||
		msats = o.Quantity * o.Price * 1000
 | 
			
		||||
		if err = sc.Db.FetchShare(o.ShareId, &s); err != nil {
 | 
			
		||||
 | 
			
		||||
		// transaction start
 | 
			
		||||
		ctx, cancel := context_.WithTimeout(c.Request().Context(), 5*time.Second)
 | 
			
		||||
		defer cancel()
 | 
			
		||||
		if tx, err = sc.Db.BeginTx(ctx, nil); err != nil {
 | 
			
		||||
			tx.Rollback()
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		defer tx.Commit()
 | 
			
		||||
 | 
			
		||||
		if err = sc.Db.FetchShare(tx, ctx, o.ShareId, &s); err != nil {
 | 
			
		||||
			tx.Rollback()
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		description = fmt.Sprintf("%s %d %s shares @ %d sats [market:%d]", strings.ToUpper(o.Side), o.Quantity, s.Description, o.Price, s.MarketId)
 | 
			
		||||
@ -130,20 +154,24 @@ func HandleOrder(sc context.ServerContext) echo.HandlerFunc {
 | 
			
		||||
		// TODO: if SELL order, check share balance of user
 | 
			
		||||
 | 
			
		||||
		// Create HODL invoice
 | 
			
		||||
		if invoice, err = sc.Lnd.CreateInvoice(sc.Db, o.Pubkey, msats, description); err != nil {
 | 
			
		||||
		if invoice, err = sc.Lnd.CreateInvoice(tx, ctx, sc.Db, o.Pubkey, msats, description); err != nil {
 | 
			
		||||
			tx.Rollback()
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		// Create QR code to pay HODL invoice
 | 
			
		||||
		if qr, err = lib.ToQR(invoice.PaymentRequest); err != nil {
 | 
			
		||||
			tx.Rollback()
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		if hash, err = lntypes.MakeHashFromStr(invoice.Hash); err != nil {
 | 
			
		||||
			tx.Rollback()
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Create (unconfirmed) order
 | 
			
		||||
		o.InvoiceId = invoice.Id
 | 
			
		||||
		if err := sc.Db.CreateOrder(&o); err != nil {
 | 
			
		||||
		if err := sc.Db.CreateOrder(tx, ctx, &o); err != nil {
 | 
			
		||||
			tx.Rollback()
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user