From b2fe9b03d7255b275041dfb7b11ee061cb936ca1 Mon Sep 17 00:00:00 2001 From: ekzyis Date: Sat, 9 Sep 2023 22:52:51 +0200 Subject: [PATCH] Use 405 Method Not Allowed if no LND connection exists --- src/lnd.go | 22 ++++++++++++++++++++-- src/router.go | 15 ++++++++------- src/server.go | 6 +++--- 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/src/lnd.go b/src/lnd.go index d07e59f..d236066 100644 --- a/src/lnd.go +++ b/src/lnd.go @@ -21,6 +21,7 @@ var ( LndMacaroonDir string LndHost string lnd *LndClient + lndEnabled bool ) type LndClient struct { @@ -36,16 +37,29 @@ func init() { flag.StringVar(&LndMacaroonDir, "LND_MACAROON_DIR", "", "LND macaroon directory") flag.StringVar(&LndHost, "LND_HOST", "localhost:10001", "LND gRPC server address") flag.Parse() + lndEnabled = false rpcClient, err := lndclient.NewBasicClient(LndHost, LndCert, LndMacaroonDir, "regtest") if err != nil { - panic(err) + log.Println(err) + return } lnd = &LndClient{LightningClient: rpcClient} if info, err := lnd.GetInfo(context.TODO(), &lnrpc.GetInfoRequest{}); err != nil { - panic(err) + log.Printf("LND connection error: %v\n", err) + return } else { version := strings.Split(info.Version, " ")[0] log.Printf("Connected to %s running LND v%s", LndHost, version) + lndEnabled = true + } +} + +func lndGuard(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if lndEnabled { + return next(c) + } + return serveError(c, 405) } } @@ -77,6 +91,10 @@ func (lnd *LndClient) CreateInvoice(pubkey string, msats int) (*Invoice, error) } func (lnd *LndClient) CheckInvoice(hash string) { + if !lndEnabled { + log.Printf("LND disabled, skipping checking invoice: hash=%s", hash) + return + } for { log.Printf("lookup invoice: hash=%s", hash) invoice, err := lnd.LookupInvoice(context.TODO(), &lnrpc.PaymentHash{RHashStr: hash}) diff --git a/src/router.go b/src/router.go index 85dbab3..82e00dd 100644 --- a/src/router.go +++ b/src/router.go @@ -71,17 +71,18 @@ func index(c echo.Context) error { return c.Render(http.StatusOK, "index.html", data) } -func serve500(c echo.Context) { - f, err := os.Open("public/500.html") +func serveError(c echo.Context, code int) error { + f, err := os.Open(fmt.Sprintf("public/%d.html", code)) if err != nil { c.Logger().Error(err) - return + return err } - err = c.Stream(500, "text/html", f) + err = c.Stream(code, "text/html", f) if err != nil { c.Logger().Error(err) - return + return err } + return nil } func httpErrorHandler(err error, c echo.Context) { @@ -95,13 +96,13 @@ func httpErrorHandler(err error, c echo.Context) { f, err := os.Open(filePath) if err != nil { c.Logger().Error(err) - serve500(c) + serveError(c, 500) return } err = c.Stream(code, "text/html", f) if err != nil { c.Logger().Error(err) - serve500(c) + serveError(c, 500) return } } diff --git a/src/server.go b/src/server.go index fea6638..4d60c3d 100644 --- a/src/server.go +++ b/src/server.go @@ -65,9 +65,9 @@ func main() { e.POST("/logout", logout) e.GET("/user", sessionGuard(user)) e.GET("/market/:id", sessionGuard(market)) - e.POST("/market/:id/order", sessionGuard(order)) - e.GET("/invoice/:id", sessionGuard(invoice)) - e.GET("/api/invoice/:id", sessionGuard(invoiceStatus)) + e.POST("/market/:id/order", sessionGuard(lndGuard(order))) + e.GET("/invoice/:id", sessionGuard(lndGuard(invoice))) + e.GET("/api/invoice/:id", sessionGuard(lndGuard(invoiceStatus))) e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ Format: "${time_custom} ${method} ${uri} ${status}\n", CustomTimeFormat: "2006-01-02 15:04:05.00000-0700",