Skip to content

Commit 6e3ff5d

Browse files
committed
refactoring of main
1 parent 13e64fb commit 6e3ff5d

File tree

3 files changed

+107
-84
lines changed

3 files changed

+107
-84
lines changed

cmd/web/app.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package main
22

33
import (
4+
"net/http"
5+
46
"github.com/alexedwards/scs"
57
"snippetbox.org/pkg/models"
68
)
@@ -11,6 +13,9 @@ type App struct {
1113
staticDir string
1214
databaseFile string
1315
secret string
16+
tlsCert string
17+
tlsKey string
18+
server *http.Server
1419
sessions *scs.Manager
1520
database *models.Database
1621
}

cmd/web/main.go

Lines changed: 13 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,18 @@
11
package main
22

33
import (
4-
"context"
5-
"database/sql"
64
"flag"
7-
"fmt"
85
"log"
9-
"net/http"
106
"os"
11-
"os/signal"
127
"time"
138

149
"github.com/alexedwards/scs"
1510

1611
_ "github.com/mattn/go-sqlite3"
17-
"snippetbox.org/pkg/models"
1812
)
1913

20-
func existDir(path string) bool {
21-
if _, err := os.Stat(path); os.IsNotExist(err) {
14+
func existDir(path *string) bool {
15+
if _, err := os.Stat(*path); os.IsNotExist(err) {
2216
return false
2317
}
2418
return true
@@ -33,105 +27,40 @@ func main() {
3327
flag.StringVar(&app.htmlDir, "html-dir", "./ui/html", "Path to html templates")
3428
flag.StringVar(&app.databaseFile, "db-file", "./info.db", "Path to database file")
3529
flag.StringVar(&app.secret, "secret", "8sB9ozuKkqWtN3b6lEiInd1dSISxPWogpaGV5HG4wKs=", "Secret key for cookies encryption")
36-
tlsCert := flag.String("tls-cert", "./tls/cert.pem", "TLS certificate")
37-
tlsKey := flag.String("tls-key", "./tls/key.pem", "TLS private-key")
30+
flag.StringVar(&app.tlsCert, "tls-cert", "./tls/cert.pem", "TLS certificate")
31+
flag.StringVar(&app.tlsKey, "tls-key", "./tls/key.pem", "TLS private-key")
3832
flag.Parse()
3933

40-
if !existDir(app.staticDir) {
34+
if !existDir(&app.staticDir) {
4135
log.Fatal("Folder for static-dir was not found")
4236
}
4337

44-
if !existDir(app.htmlDir) {
38+
if !existDir(&app.htmlDir) {
4539
log.Fatal("Folder for html-dir was not found")
4640
}
4741

48-
if !existDir(*tlsCert) {
42+
if !existDir(&app.tlsCert) {
4943
log.Fatal("TLS certificate was not found")
5044
}
5145

52-
if !existDir(*tlsKey) {
46+
if !existDir(&app.tlsKey) {
5347
log.Fatal("TLS key was not found")
5448
}
5549

56-
if err := app.connectDb(); err != nil {
50+
if err := app.ConnectDb(); err != nil {
5751
log.Fatal("Failed to establish database connection")
5852
}
59-
defer func() {
60-
log.Println("Closing database connection")
61-
app.closeDB()
62-
}()
53+
defer app.CloseDB()
6354

6455
sessionManager := scs.NewCookieManager(app.secret)
6556
sessionManager.Lifetime(12 * time.Hour)
6657
sessionManager.Persist(true)
6758
sessionManager.Secure(true)
6859
app.sessions = sessionManager
6960

70-
server := &http.Server{
71-
Addr: app.addr,
72-
Handler: app.Routes(),
73-
WriteTimeout: 15 * time.Second,
74-
ReadTimeout: 15 * time.Second,
75-
IdleTimeout: 60 * time.Second,
76-
}
77-
78-
c := make(chan os.Signal, 1)
79-
signal.Notify(c, os.Interrupt)
80-
go func() {
81-
for sig := range c {
82-
log.Printf("Terminating (signal caught - %s)\n", sig)
83-
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
84-
defer func() {
85-
log.Println("Closing context")
86-
cancel()
87-
}()
88-
server.Shutdown(ctx)
89-
}
90-
}()
91-
92-
log.Printf("Listening on %s\n", app.addr)
93-
94-
if err := server.ListenAndServeTLS(*tlsCert, *tlsKey); err != nil {
95-
if err != http.ErrServerClosed {
96-
log.Println("The error below raised after shutdown:")
97-
log.Println(err)
98-
}
99-
}
100-
}
101-
102-
func (app *App) connectDb() error {
103-
initDb := !existDir(app.databaseFile)
104-
105-
dsn := fmt.Sprintf("file:%s?cache=shared&_loc=auto", app.databaseFile)
106-
db, err := sql.Open("sqlite3", dsn)
107-
if err != nil {
108-
return err
109-
}
61+
app.InitServer()
11062

111-
defer func() {
112-
if err != nil {
113-
db.Close()
114-
}
115-
}()
63+
app.MonitorInterrupts()
11664

117-
if err = db.Ping(); err != nil {
118-
return err
119-
}
120-
121-
app.database = &models.Database{DB: db}
122-
123-
if initDb {
124-
log.Println("Initializing database...")
125-
if err := app.database.InitializeDb(); err != nil {
126-
return err
127-
}
128-
}
129-
130-
return nil
131-
}
132-
133-
func (app *App) closeDB() {
134-
if app.database != nil {
135-
app.database.Close()
136-
}
65+
app.RunServer()
13766
}

cmd/web/server.go

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"fmt"
7+
"log"
8+
"net/http"
9+
"os"
10+
"os/signal"
11+
"time"
12+
13+
"snippetbox.org/pkg/models"
14+
)
15+
16+
func (app *App) MonitorInterrupts() {
17+
c := make(chan os.Signal, 1)
18+
signal.Notify(c, os.Interrupt)
19+
go func() {
20+
for sig := range c {
21+
log.Printf("Terminating (signal caught - %s)\n", sig)
22+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
23+
defer func() {
24+
log.Println("Closing context")
25+
cancel()
26+
}()
27+
app.server.Shutdown(ctx)
28+
}
29+
}()
30+
}
31+
32+
func (app *App) ConnectDb() error {
33+
initDb := !existDir(&app.databaseFile)
34+
35+
dsn := fmt.Sprintf("file:%s?cache=shared&_loc=auto", app.databaseFile)
36+
db, err := sql.Open("sqlite3", dsn)
37+
if err != nil {
38+
return err
39+
}
40+
41+
defer func() {
42+
if err != nil {
43+
db.Close()
44+
}
45+
}()
46+
47+
if err = db.Ping(); err != nil {
48+
return err
49+
}
50+
51+
app.database = &models.Database{DB: db}
52+
53+
if initDb {
54+
log.Println("Initializing database...")
55+
if err := app.database.InitializeDb(); err != nil {
56+
return err
57+
}
58+
}
59+
60+
return nil
61+
}
62+
63+
func (app *App) CloseDB() {
64+
log.Println("Closing database connection")
65+
if app.database != nil {
66+
app.database.Close()
67+
}
68+
}
69+
70+
func (app *App) InitServer() {
71+
app.server = &http.Server{
72+
Addr: app.addr,
73+
Handler: app.Routes(),
74+
WriteTimeout: 15 * time.Second,
75+
ReadTimeout: 15 * time.Second,
76+
IdleTimeout: 60 * time.Second,
77+
}
78+
}
79+
80+
func (app *App) RunServer() {
81+
log.Printf("Listening on %s\n", app.addr)
82+
83+
if err := app.server.ListenAndServeTLS(app.tlsCert, app.tlsKey); err != nil {
84+
if err != http.ErrServerClosed {
85+
log.Println("The error below raised after shutdown:")
86+
log.Println(err)
87+
}
88+
}
89+
}

0 commit comments

Comments
 (0)