diff --git a/README.md b/README.md index 7396b14..e425a59 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,3 @@ # golang-webapp -Step by step tour of writing a web application in Go. - -[Video Playlist](https://www.youtube.com/playlist?list=PLmxT2pVYo5LDMV0epL4z4CUbxvIw6umg_) +This is the source code for the [Write a Web App in Go](https://www.youtube.com/playlist?list=PLmxT2pVYo5LDMV0epL4z4CUbxvIw6umg_) video series. diff --git a/main.go b/main.go index 3c951d1..c18741e 100644 --- a/main.go +++ b/main.go @@ -2,117 +2,15 @@ package main import ( "net/http" - "github.com/gorilla/mux" - "github.com/gorilla/sessions" - "github.com/go-redis/redis" - "golang.org/x/crypto/bcrypt" - "html/template" + "./routes" + "./models" + "./utils" ) -var client *redis.Client -var store = sessions.NewCookieStore([]byte("t0p-s3cr3t")) -var templates *template.Template - func main() { - client = redis.NewClient(&redis.Options{ - Addr: "localhost:6379", - }) - templates = template.Must(template.ParseGlob("templates/*.html")) - r := mux.NewRouter() - r.HandleFunc("/", AuthRequired(indexGetHandler)).Methods("GET") - r.HandleFunc("/", AuthRequired(indexPostHandler)).Methods("POST") - r.HandleFunc("/login", loginGetHandler).Methods("GET") - r.HandleFunc("/login", loginPostHandler).Methods("POST") - r.HandleFunc("/register", registerGetHandler).Methods("GET") - r.HandleFunc("/register", registerPostHandler).Methods("POST") - fs := http.FileServer(http.Dir("./static/")) - r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", fs)) + models.Init() + utils.LoadTemplates("templates/*.html") + r := routes.NewRouter() http.Handle("/", r) http.ListenAndServe(":8080", nil) } - -func AuthRequired(handler http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - session, _ := store.Get(r, "session") - _, ok := session.Values["username"] - if !ok { - http.Redirect(w, r, "/login", 302) - return - } - handler.ServeHTTP(w, r) - } -} - -func indexGetHandler(w http.ResponseWriter, r *http.Request) { - comments, err := client.LRange("comments", 0, 10).Result() - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("Internal server error")) - return - } - templates.ExecuteTemplate(w, "index.html", comments) -} - -func indexPostHandler(w http.ResponseWriter, r *http.Request) { - r.ParseForm() - comment := r.PostForm.Get("comment") - err := client.LPush("comments", comment).Err() - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("Internal server error")) - return - } - http.Redirect(w, r, "/", 302) -} - -func loginGetHandler(w http.ResponseWriter, r *http.Request) { - templates.ExecuteTemplate(w, "login.html", nil) -} - -func loginPostHandler(w http.ResponseWriter, r *http.Request) { - r.ParseForm() - username := r.PostForm.Get("username") - password := r.PostForm.Get("password") - hash, err := client.Get("user:" + username).Bytes() - if err == redis.Nil { - templates.ExecuteTemplate(w, "login.html", "unknown user") - return - } else if err != nil { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("Internal server error")) - return - } - err = bcrypt.CompareHashAndPassword(hash, []byte(password)) - if err != nil { - templates.ExecuteTemplate(w, "login.html", "invalid login") - return - } - session, _ := store.Get(r, "session") - session.Values["username"] = username - session.Save(r, w) - http.Redirect(w, r, "/", 302) -} - -func registerGetHandler(w http.ResponseWriter, r *http.Request) { - templates.ExecuteTemplate(w, "register.html", nil) -} - -func registerPostHandler(w http.ResponseWriter, r *http.Request) { - r.ParseForm() - username := r.PostForm.Get("username") - password := r.PostForm.Get("password") - cost := bcrypt.DefaultCost - hash, err := bcrypt.GenerateFromPassword([]byte(password), cost) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("Internal server error")) - return - } - err = client.Set("user:" + username, hash, 0).Err() - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("Internal server error")) - return - } - http.Redirect(w, r, "/login", 302) -} diff --git a/middleware/middleware.go b/middleware/middleware.go new file mode 100644 index 0000000..e8af467 --- /dev/null +++ b/middleware/middleware.go @@ -0,0 +1,18 @@ +package middleware + +import ( + "net/http" + "../sessions" +) + +func AuthRequired(handler http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + session, _ := sessions.Store.Get(r, "session") + _, ok := session.Values["user_id"] + if !ok { + http.Redirect(w, r, "/login", 302) + return + } + handler.ServeHTTP(w, r) + } +} diff --git a/models/db.go b/models/db.go new file mode 100644 index 0000000..f95a7d7 --- /dev/null +++ b/models/db.go @@ -0,0 +1,13 @@ +package models + +import ( + "github.com/go-redis/redis" +) + +var client *redis.Client + +func Init() { + client = redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + }) +} \ No newline at end of file diff --git a/models/update.go b/models/update.go new file mode 100644 index 0000000..9f37c8f --- /dev/null +++ b/models/update.go @@ -0,0 +1,73 @@ +package models + +import ( + "fmt" + "strconv" +) + +type Update struct { + id int64 +} + +func NewUpdate(userId int64, body string) (*Update, error) { + id, err := client.Incr("update:next-id").Result() + if err != nil { + return nil, err + } + key := fmt.Sprintf("update:%d", id) + pipe := client.Pipeline() + pipe.HSet(key, "id", id) + pipe.HSet(key, "user_id", userId) + pipe.HSet(key, "body", body) + pipe.LPush("updates", id) + pipe.LPush(fmt.Sprintf("user:%d:updates", userId), id) + _, err = pipe.Exec() + if err != nil { + return nil, err + } + return &Update{id}, nil +} + +func (update *Update) GetBody() (string, error) { + key := fmt.Sprintf("update:%d", update.id) + return client.HGet(key, "body").Result() +} + +func (update *Update) GetUser() (*User, error) { + key := fmt.Sprintf("update:%d", update.id) + userId, err := client.HGet(key, "user_id").Int64() + if err != nil { + return nil, err + } + return GetUserById(userId) +} + +func queryUpdates(key string) ([]*Update, error) { + updateIds, err := client.LRange(key, 0, 10).Result() + if err != nil { + return nil, err + } + updates := make([]*Update, len(updateIds)) + for i, strId := range updateIds { + id, err := strconv.Atoi(strId) + if err != nil { + return nil, err + } + updates[i] = &Update{int64(id)} + } + return updates, nil +} + +func GetAllUpdates() ([]*Update, error) { + return queryUpdates("updates") +} + +func GetUpdates(userId int64) ([]*Update, error) { + key := fmt.Sprintf("user:%d:updates", userId) + return queryUpdates(key) +} + +func PostUpdate(userId int64, body string) error { + _, err := NewUpdate(userId, body) + return err +} \ No newline at end of file diff --git a/models/user.go b/models/user.go new file mode 100644 index 0000000..a5c595a --- /dev/null +++ b/models/user.go @@ -0,0 +1,98 @@ +package models + +import ( + "fmt" + "errors" + "github.com/go-redis/redis" + "golang.org/x/crypto/bcrypt" +) + +var ( + ErrUserNotFound = errors.New("user not found") + ErrInvalidLogin = errors.New("invalid login") + ErrUsernameTaken = errors.New("username taken") +) + +type User struct { + id int64 +} + +func NewUser(username string, hash []byte) (*User, error) { + exists, err := client.HExists("user:by-username", username).Result() + if exists { + return nil, ErrUsernameTaken + } + id, err := client.Incr("user:next-id").Result() + if err != nil { + return nil, err + } + key := fmt.Sprintf("user:%d", id) + pipe := client.Pipeline() + pipe.HSet(key, "id", id) + pipe.HSet(key, "username", username) + pipe.HSet(key, "hash", hash) + pipe.HSet("user:by-username", username, id) + _, err = pipe.Exec() + if err != nil { + return nil, err + } + return &User{id}, nil +} + +func (user *User) GetId() (int64, error) { + return user.id, nil +} + +func (user *User) GetUsername() (string, error) { + key := fmt.Sprintf("user:%d", user.id) + return client.HGet(key, "username").Result() +} + +func (user *User) GetHash() ([]byte, error) { + key := fmt.Sprintf("user:%d", user.id) + return client.HGet(key, "hash").Bytes() +} + +func (user *User) Authenticate(password string) error { + hash, err := user.GetHash() + if err != nil { + return err + } + err = bcrypt.CompareHashAndPassword(hash, []byte(password)) + if err == bcrypt.ErrMismatchedHashAndPassword { + return ErrInvalidLogin + } + return err +} + +func GetUserById(id int64) (*User, error) { + return &User{id}, nil +} + +func GetUserByUsername(username string) (*User, error) { + id, err := client.HGet("user:by-username", username).Int64() + if err == redis.Nil { + return nil, ErrUserNotFound + } else if err != nil { + return nil, err + } + return GetUserById(id) +} + +func AuthenticateUser(username, password string) (*User, error) { + user, err := GetUserByUsername(username) + if err != nil { + return nil, err + } + return user, user.Authenticate(password) +} + +func RegisterUser(username, password string) error { + cost := bcrypt.DefaultCost + hash, err := bcrypt.GenerateFromPassword([]byte(password), cost) + if err != nil { + return err + } + _, err = NewUser(username, hash) + return err +} \ No newline at end of file diff --git a/routes/routes.go b/routes/routes.go new file mode 100644 index 0000000..e434842 --- /dev/null +++ b/routes/routes.go @@ -0,0 +1,154 @@ +package routes + +import ( + "net/http" + "github.com/gorilla/mux" + "../middleware" + "../models" + "../sessions" + "../utils" +) + +func NewRouter() *mux.Router { + r := mux.NewRouter() + r.HandleFunc("/", middleware.AuthRequired(indexGetHandler)).Methods("GET") + r.HandleFunc("/", middleware.AuthRequired(indexPostHandler)).Methods("POST") + r.HandleFunc("/login", loginGetHandler).Methods("GET") + r.HandleFunc("/login", loginPostHandler).Methods("POST") + r.HandleFunc("/logout", logoutGetHandler).Methods("GET") + r.HandleFunc("/register", registerGetHandler).Methods("GET") + r.HandleFunc("/register", registerPostHandler).Methods("POST") + fs := http.FileServer(http.Dir("./static/")) + r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", fs)) + r.HandleFunc("/{username}", + middleware.AuthRequired(userGetHandler)).Methods("GET") + return r +} + +func indexGetHandler(w http.ResponseWriter, r *http.Request) { + updates, err := models.GetAllUpdates() + if err != nil { + utils.InternalServerError(w) + return + } + utils.ExecuteTemplate(w, "index.html", struct { + Title string + Updates []*models.Update + DisplayForm bool + } { + Title: "All updates", + Updates: updates, + DisplayForm: true, + }) +} + +func indexPostHandler(w http.ResponseWriter, r *http.Request) { + session, _ := sessions.Store.Get(r, "session") + untypedUserId := session.Values["user_id"] + userId, ok := untypedUserId.(int64) + if !ok { + utils.InternalServerError(w) + return + } + r.ParseForm() + body := r.PostForm.Get("update") + err := models.PostUpdate(userId, body) + if err != nil { + utils.InternalServerError(w) + return + } + http.Redirect(w, r, "/", 302) +} + +func userGetHandler(w http.ResponseWriter, r *http.Request) { + session, _ := sessions.Store.Get(r, "session") + untypedUserId := session.Values["user_id"] + currentUserId, ok := untypedUserId.(int64) + if !ok { + utils.InternalServerError(w) + return + } + vars := mux.Vars(r) + username := vars["username"] + user, err := models.GetUserByUsername(username) + if err != nil { + utils.InternalServerError(w) + return + } + userId, err := user.GetId() + if err != nil { + utils.InternalServerError(w) + return + } + updates, err := models.GetUpdates(userId) + if err != nil { + utils.InternalServerError(w) + return + } + utils.ExecuteTemplate(w, "index.html", struct { + Title string + Updates []*models.Update + DisplayForm bool + } { + Title: username, + Updates: updates, + DisplayForm: currentUserId == userId, + }) +} + +func loginGetHandler(w http.ResponseWriter, r *http.Request) { + utils.ExecuteTemplate(w, "login.html", nil) +} + +func loginPostHandler(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + username := r.PostForm.Get("username") + password := r.PostForm.Get("password") + user, err := models.AuthenticateUser(username, password) + if err != nil { + switch err { + case models.ErrUserNotFound: + utils.ExecuteTemplate(w, "login.html", "unknown user") + case models.ErrInvalidLogin: + utils.ExecuteTemplate(w, "login.html", "invalid login") + default: + utils.InternalServerError(w) + } + return + } + userId, err := user.GetId() + if err != nil { + utils.InternalServerError(w) + return + } + session, _ := sessions.Store.Get(r, "session") + session.Values["user_id"] = userId + session.Save(r, w) + http.Redirect(w, r, "/", 302) +} + +func logoutGetHandler(w http.ResponseWriter, r *http.Request) { + session, _ := sessions.Store.Get(r, "session") + delete(session.Values, "user_id") + session.Save(r, w) + http.Redirect(w, r, "/login", 302) +} + +func registerGetHandler(w http.ResponseWriter, r *http.Request) { + utils.ExecuteTemplate(w, "register.html", nil) +} + +func registerPostHandler(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + username := r.PostForm.Get("username") + password := r.PostForm.Get("password") + err := models.RegisterUser(username, password) + if err == models.ErrUsernameTaken { + utils.ExecuteTemplate(w, "register.html", "username taken") + return + } else if err != nil { + utils.InternalServerError(w) + return + } + http.Redirect(w, r, "/login", 302) +} diff --git a/sessions/sessions.go b/sessions/sessions.go new file mode 100644 index 0000000..55521e4 --- /dev/null +++ b/sessions/sessions.go @@ -0,0 +1,7 @@ +package sessions + +import ( + "github.com/gorilla/sessions" +) + +var Store = sessions.NewCookieStore([]byte("t0p-s3cr3t")) diff --git a/static/index.css b/static/index.css index 7e0429c..bf74c61 100644 --- a/static/index.css +++ b/static/index.css @@ -1,7 +1,38 @@ -body > div { +* { + margin: 0; + padding: 0; +} + +body { + background: #f0f0f0; +} + +nav { padding: 0.5em; - width: 200px; - margin: 1em 0em; - background: #ccc; + background: #fff; + border-bottom: 1px solid #aaa; + text-align: right; +} + +main { + margin: 0 auto; + max-width: 640px; +} + +main > .update { + padding: 0.5em; + background: #fff; border: 1px solid #aaa; -} \ No newline at end of file + margin-bottom: 1em; +} + +#update-form { + text-align: right; + margin-bottom: 1em; +} + +#update-form textarea { + width: 100%; + margin-bottom: 0.5em; + resize: vertical; +} diff --git a/templates/index.html b/templates/index.html index 528e29f..0dcc0f4 100644 --- a/templates/index.html +++ b/templates/index.html @@ -1,18 +1,32 @@
-