// Package db opens a Postgres connection through database/sql and provides
// helpers to drop and recreate the application schema.
package db

import (
	"database/sql"
	"fmt"
	"io/ioutil"

	_ "github.com/lib/pq"
)

// DB embeds *sql.DB so that all of its methods are available directly, and
// adds the schema helpers Reset and Clear.
type DB struct {
	*sql.DB
}

var (
	// schemaPath is the location of the SQL schema file executed by Reset.
	// It is a relative path, so it is resolved against the process working
	// directory.
	schemaPath = "./db/schema.sql"
)

// New opens a database handle for dbURL using the "postgres" driver and
// verifies the connection by executing "SELECT 1".
//
// It returns the error from sql.Open or from the connection test unchanged.
// If the connection test fails, the handle is closed before returning so the
// caller is not left with a leaked pool.
func New(dbURL string) (*DB, error) {
	conn, err := sql.Open("postgres", dbURL)
	if err != nil {
		return nil, err
	}

	// test connection
	if _, err = conn.Exec("SELECT 1"); err != nil {
		// The connection-test error is the one worth reporting; a failure to
		// close an unusable handle adds nothing, so it is deliberately ignored.
		_ = conn.Close()
		return nil, err
	}

	// TODO: run migrations
	return &DB{DB: conn}, nil
}

// Reset drops the known schema objects via Clear and then executes the
// contents of the file at schemaPath as a single statement batch.
//
// dbName is passed through to Clear; it is not otherwise used here. The first
// error encountered (clearing, reading the file, or executing it) is returned
// unchanged.
func (db *DB) Reset(dbName string) error {
	if err := db.Clear(dbName); err != nil {
		return err
	}

	schema, err := ioutil.ReadFile(schemaPath)
	if err != nil {
		return err
	}

	_, err = db.Exec(string(schema))
	return err
}

// Clear drops the application's tables, the "uuid-ossp" extension and the
// order_side type, in that order, stopping at and returning the first error.
//
// Every statement uses IF EXISTS. dbName is currently unused; it is kept so
// the signature stays compatible with existing callers.
func (db *DB) Clear(dbName string) error {
	// NOTE(review): "order_side" is also dropped as a TYPE below; it is kept
	// in this list to preserve the existing statement sequence — confirm
	// against schema.sql whether it is ever a table.
	tables := []string{
		"lnauth",
		"users",
		"sessions",
		"invoices",
		"markets",
		"shares",
		"order_side",
		"orders",
		"withdrawals",
	}

	stmts := make([]string, 0, len(tables)+2)
	for _, t := range tables {
		stmts = append(stmts, fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", t))
	}
	stmts = append(stmts,
		"DROP EXTENSION IF EXISTS \"uuid-ossp\"",
		"DROP TYPE IF EXISTS order_side",
	)

	for _, s := range stmts {
		if _, err := db.Exec(s); err != nil {
			return err
		}
	}
	return nil
}