diff --git a/golang/cmd/frontend/main.go b/golang/cmd/frontend/main.go index 2aabbe7..930649b 100644 --- a/golang/cmd/frontend/main.go +++ b/golang/cmd/frontend/main.go @@ -2,11 +2,8 @@ package main import ( "crowsnest/internal/model/database" - "database/sql" - "fmt" "log" "net/http" - "os" _ "github.com/lib/pq" ) @@ -16,35 +13,9 @@ type App struct { } func main() { - // collect environement variables - dbPass := os.Getenv("DB_PASS") - if dbPass == "" { - log.Fatal("empty DB_PASS") - } - dbHost := os.Getenv("DB_HOST") - if dbHost == "" { - log.Fatal("empty DB_HOST") - } - dbPort := os.Getenv("DB_PORT") - if dbPort == "" { - dbPort = "5432" - } - dbUser := os.Getenv("DB_USER") - if dbUser == "" { - dbUser = "postgres" - } - dbName := os.Getenv("DB_NAME") - - // connect to database - databaseURL := fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%s sslmode=disable", - dbUser, dbPass, dbName, dbHost, dbPort) - db, err := sql.Open("postgres", databaseURL) + db, err := database.DbConnection() if err != nil { - log.Fatal(err) - } - defer db.Close() - if err = db.Ping(); err != nil { - log.Fatal(err) + log.Fatal("failed to connect to database due to", err.Error()) } // define app @@ -54,10 +25,10 @@ func main() { // start web server server := http.Server{ - Addr: ":8080", + Addr: ":80", Handler: app.routes(), } - log.Println("server started, listening on :8080") + log.Println("server started, listening on :80") server.ListenAndServe() } diff --git a/golang/internal/model/database/connect.go b/golang/internal/model/database/connect.go new file mode 100644 index 0000000..9f63057 --- /dev/null +++ b/golang/internal/model/database/connect.go @@ -0,0 +1,45 @@ +package database + +import ( + "database/sql" + "errors" + "fmt" + "os" +) + +// Will try to connect to the database defined by the env variables. If the +// connection fails a error will be returned. Ensure that the returned sql.DB +// object is closed after use. +func DbConnection() (*sql.DB, error) { + // collect environement variables + dbPass := os.Getenv("DB_PASS") + if dbPass == "" { + return nil, errors.New("empty env. variable DB_PASS") + } + dbHost := os.Getenv("DB_HOST") + if dbHost == "" { + return nil, errors.New("empty env. variable DB_HOST") + } + dbPort := os.Getenv("DB_PORT") + if dbPort == "" { + dbPort = "5432" + } + dbUser := os.Getenv("DB_USER") + if dbUser == "" { + dbUser = "postgres" + } + dbName := os.Getenv("DB_NAME") + + // connect to database + databaseURL := fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%s sslmode=disable", + dbUser, dbPass, dbName, dbHost, dbPort) + db, err := sql.Open("postgres", databaseURL) + if err != nil { + return nil, err + } + if err = db.Ping(); err != nil { + return nil, err + } + + return db, nil +}