diff --git a/db/invoice.go b/db/invoice.go index e2c9bae..161acc9 100644 --- a/db/invoice.go +++ b/db/invoice.go @@ -1,12 +1,13 @@ package db import ( + "context" "database/sql" "time" ) -func (db *DB) CreateInvoice(invoice *Invoice) error { - if err := db.QueryRow(""+ +func (db *DB) CreateInvoice(tx *sql.Tx, ctx context.Context, invoice *Invoice) error { + if err := tx.QueryRowContext(ctx, ""+ "INSERT INTO invoices(pubkey, msats, preimage, hash, bolt11, created_at, expires_at, description) "+ "VALUES($1, $2, $3, $4, $5, $6, $7, $8) "+ "RETURNING id", diff --git a/db/market.go b/db/market.go index 0ca6ad3..d2ae5e2 100644 --- a/db/market.go +++ b/db/market.go @@ -1,6 +1,9 @@ package db -import "database/sql" +import ( + "context" + "database/sql" +) type FetchOrdersWhere struct { MarketId int @@ -8,8 +11,8 @@ type FetchOrdersWhere struct { Confirmed bool } -func (db *DB) CreateMarket(market *Market) error { - if err := db.QueryRow(""+ +func (db *DB) CreateMarket(tx *sql.Tx, ctx context.Context, market *Market) error { + if err := tx.QueryRowContext(ctx, ""+ "INSERT INTO markets(description, end_date, invoice_id) "+ "VALUES($1, $2, $3) "+ "RETURNING id", market.Description, market.EndDate, market.InvoiceId).Scan(&market.Id); err != nil { @@ -62,8 +65,8 @@ func (db *DB) FetchShares(marketId int, shares *[]Share) error { return nil } -func (db *DB) FetchShare(shareId string, share *Share) error { - return db.QueryRow("SELECT id, market_id, description FROM shares WHERE id = $1", shareId).Scan(&share.Id, &share.MarketId, &share.Description) +func (db *DB) FetchShare(tx *sql.Tx, ctx context.Context, shareId string, share *Share) error { + return tx.QueryRowContext(ctx, "SELECT id, market_id, description FROM shares WHERE id = $1", shareId).Scan(&share.Id, &share.MarketId, &share.Description) } func (db *DB) FetchOrders(where *FetchOrdersWhere, orders *[]Order) error { @@ -99,8 +102,8 @@ func (db *DB) FetchOrders(where *FetchOrdersWhere, orders *[]Order) error { return nil } -func (db *DB) CreateOrder(order *Order) error { - if _, err := db.Exec(""+ +func (db *DB) CreateOrder(tx *sql.Tx, ctx context.Context, order *Order) error { + if _, err := tx.ExecContext(ctx, ""+ "INSERT INTO orders(share_id, pubkey, side, quantity, price, invoice_id) "+ "VALUES ($1, $2, $3, $4, $5, $6)", order.ShareId, order.Pubkey, order.Side, order.Quantity, order.Price, order.InvoiceId); err != nil { diff --git a/lnd/invoice.go b/lnd/invoice.go index 337fa8d..84d9163 100644 --- a/lnd/invoice.go +++ b/lnd/invoice.go @@ -2,6 +2,7 @@ package lnd import ( "context" + "database/sql" "log" "time" @@ -12,7 +13,7 @@ import ( "github.com/lightningnetwork/lnd/lnwire" ) -func (lnd *LNDClient) CreateInvoice(d *db.DB, pubkey string, msats int64, description string) (*db.Invoice, error) { +func (lnd *LNDClient) CreateInvoice(tx *sql.Tx, ctx context.Context, d *db.DB, pubkey string, msats int64, description string) (*db.Invoice, error) { var ( expiry time.Duration = time.Hour preimage lntypes.Preimage @@ -26,14 +27,14 @@ func (lnd *LNDClient) CreateInvoice(d *db.DB, pubkey string, msats int64, descri return nil, err } hash = preimage.Hash() - if paymentRequest, err = lnd.Invoices.AddHoldInvoice(context.TODO(), &invoicesrpc.AddInvoiceData{ + if paymentRequest, err = lnd.Invoices.AddHoldInvoice(ctx, &invoicesrpc.AddInvoiceData{ Hash: &hash, Value: lnwire.MilliSatoshi(msats), Expiry: int64(expiry / time.Millisecond), }); err != nil { return nil, err } - if lnInvoice, err = lnd.Client.LookupInvoice(context.TODO(), hash); err != nil { + if lnInvoice, err = lnd.Client.LookupInvoice(ctx, hash); err != nil { return nil, err } dbInvoice = &db.Invoice{ @@ -46,7 +47,7 @@ func (lnd *LNDClient) CreateInvoice(d *db.DB, pubkey string, msats int64, descri ExpiresAt: lnInvoice.CreationDate.Add(expiry), Description: description, } - if err := d.CreateInvoice(dbInvoice); err != nil { + if err := d.CreateInvoice(tx, ctx, dbInvoice); err != nil { return nil, err } return dbInvoice, nil diff --git a/server/router/handler/market.go b/server/router/handler/market.go index 3c0b917..2728aee 100644 --- a/server/router/handler/market.go +++ b/server/router/handler/market.go @@ -1,11 +1,13 @@ package handler import ( + context_ "context" "database/sql" "fmt" "net/http" "strconv" "strings" + "time" "git.ekzyis.com/ekzyis/delphi.market/db" "git.ekzyis.com/ekzyis/delphi.market/lib" @@ -50,6 +52,7 @@ func HandleMarket(sc context.ServerContext) echo.HandlerFunc { func HandleCreateMarket(sc context.ServerContext) echo.HandlerFunc { return func(c echo.Context) error { var ( + tx *sql.Tx u db.User m db.Market invoice *db.Invoice @@ -64,11 +67,20 @@ func HandleCreateMarket(sc context.ServerContext) echo.HandlerFunc { return echo.NewHTTPError(http.StatusBadRequest) } + // transaction start + ctx, cancel := context_.WithTimeout(c.Request().Context(), 5*time.Second) + defer cancel() + if tx, err = sc.Db.BeginTx(ctx, nil); err != nil { + tx.Rollback() + return err + } + defer tx.Commit() + u = c.Get("session").(db.User) msats = 1000 // TODO: add [market:] for redirect after payment invDescription = fmt.Sprintf("create market \"%s\" (%s)", m.Description, m.EndDate) - if invoice, err = sc.Lnd.CreateInvoice(sc.Db, u.Pubkey, msats, invDescription); err != nil { + if invoice, err = sc.Lnd.CreateInvoice(tx, ctx, sc.Db, u.Pubkey, msats, invDescription); err != nil { return err } if qr, err = lib.ToQR(invoice.PaymentRequest); err != nil { @@ -80,7 +92,7 @@ func HandleCreateMarket(sc context.ServerContext) echo.HandlerFunc { go sc.Lnd.CheckInvoice(sc.Db, hash) m.InvoiceId = invoice.Id - if err := sc.Db.CreateMarket(&m); err != nil { + if err := sc.Db.CreateMarket(tx, ctx, &m); err != nil { return err } @@ -97,6 +109,7 @@ func HandleCreateMarket(sc context.ServerContext) echo.HandlerFunc { func HandleOrder(sc context.ServerContext) echo.HandlerFunc { return func(c echo.Context) error { var ( + tx *sql.Tx u db.User o db.Order s db.Share @@ -122,7 +135,18 @@ func HandleOrder(sc context.ServerContext) echo.HandlerFunc { u = c.Get("session").(db.User) o.Pubkey = u.Pubkey msats = o.Quantity * o.Price * 1000 - if err = sc.Db.FetchShare(o.ShareId, &s); err != nil { + + // transaction start + ctx, cancel := context_.WithTimeout(c.Request().Context(), 5*time.Second) + defer cancel() + if tx, err = sc.Db.BeginTx(ctx, nil); err != nil { + tx.Rollback() + return err + } + defer tx.Commit() + + if err = sc.Db.FetchShare(tx, ctx, o.ShareId, &s); err != nil { + tx.Rollback() return err } description = fmt.Sprintf("%s %d %s shares @ %d sats [market:%d]", strings.ToUpper(o.Side), o.Quantity, s.Description, o.Price, s.MarketId) @@ -130,20 +154,24 @@ func HandleOrder(sc context.ServerContext) echo.HandlerFunc { // TODO: if SELL order, check share balance of user // Create HODL invoice - if invoice, err = sc.Lnd.CreateInvoice(sc.Db, o.Pubkey, msats, description); err != nil { + if invoice, err = sc.Lnd.CreateInvoice(tx, ctx, sc.Db, o.Pubkey, msats, description); err != nil { + tx.Rollback() return err } // Create QR code to pay HODL invoice if qr, err = lib.ToQR(invoice.PaymentRequest); err != nil { + tx.Rollback() return err } if hash, err = lntypes.MakeHashFromStr(invoice.Hash); err != nil { + tx.Rollback() return err } // Create (unconfirmed) order o.InvoiceId = invoice.Id - if err := sc.Db.CreateOrder(&o); err != nil { + if err := sc.Db.CreateOrder(tx, ctx, &o); err != nil { + tx.Rollback() return err }