diff --git a/models/update.go b/models/update.go index 857dcb6..9f37c8f 100644 --- a/models/update.go +++ b/models/update.go @@ -2,10 +2,11 @@ package models import ( "fmt" + "strconv" ) type Update struct { - key string + id int64 } func NewUpdate(userId int64, body string) (*Update, error) { @@ -24,46 +25,46 @@ func NewUpdate(userId int64, body string) (*Update, error) { if err != nil { return nil, err } - return &Update{key}, nil + return &Update{id}, nil } func (update *Update) GetBody() (string, error) { - return client.HGet(update.key, "body").Result() + key := fmt.Sprintf("update:%d", update.id) + return client.HGet(key, "body").Result() } func (update *Update) GetUser() (*User, error) { - userId, err := client.HGet(update.key, "user_id").Int64() + key := fmt.Sprintf("update:%d", update.id) + userId, err := client.HGet(key, "user_id").Int64() if err != nil { return nil, err } return GetUserById(userId) } -func GetAllUpdates() ([]*Update, error) { - updateIds, err := client.LRange("updates", 0, 10).Result() +func queryUpdates(key string) ([]*Update, error) { + updateIds, err := client.LRange(key, 0, 10).Result() if err != nil { return nil, err } updates := make([]*Update, len(updateIds)) - for i, id := range updateIds { - key := "update:" + id - updates[i] = &Update{key} + for i, strId := range updateIds { + id, err := strconv.Atoi(strId) + if err != nil { + return nil, err + } + updates[i] = &Update{int64(id)} } return updates, nil } +func GetAllUpdates() ([]*Update, error) { + return queryUpdates("updates") +} + func GetUpdates(userId int64) ([]*Update, error) { key := fmt.Sprintf("user:%d:updates", userId) - updateIds, err := client.LRange(key, 0, 10).Result() - if err != nil { - return nil, err - } - updates := make([]*Update, len(updateIds)) - for i, id := range updateIds { - key := "update:" + id - updates[i] = &Update{key} - } - return updates, nil + return queryUpdates(key) } func PostUpdate(userId int64, body string) error { diff --git a/models/user.go b/models/user.go index 5574d88..a5c595a 100644 --- a/models/user.go +++ b/models/user.go @@ -10,13 +10,18 @@ import ( var ( ErrUserNotFound = errors.New("user not found") ErrInvalidLogin = errors.New("invalid login") + ErrUsernameTaken = errors.New("username taken") ) type User struct { - key string + id int64 } func NewUser(username string, hash []byte) (*User, error) { + exists, err := client.HExists("user:by-username", username).Result() + if exists { + return nil, ErrUsernameTaken + } id, err := client.Incr("user:next-id").Result() if err != nil { return nil, err @@ -31,19 +36,21 @@ func NewUser(username string, hash []byte) (*User, error) { if err != nil { return nil, err } - return &User{key}, nil + return &User{id}, nil } func (user *User) GetId() (int64, error) { - return client.HGet(user.key, "id").Int64() + return user.id, nil } func (user *User) GetUsername() (string, error) { - return client.HGet(user.key, "username").Result() + key := fmt.Sprintf("user:%d", user.id) + return client.HGet(key, "username").Result() } func (user *User) GetHash() ([]byte, error) { - return client.HGet(user.key, "hash").Bytes() + key := fmt.Sprintf("user:%d", user.id) + return client.HGet(key, "hash").Bytes() } func (user *User) Authenticate(password string) error { @@ -59,8 +66,7 @@ func (user *User) Authenticate(password string) error { } func GetUserById(id int64) (*User, error) { - key := fmt.Sprintf("user:%d", id) - return &User{key}, nil + return &User{id}, nil } func GetUserByUsername(username string) (*User, error) { diff --git a/routes/routes.go b/routes/routes.go index b743e8f..e434842 100644 --- a/routes/routes.go +++ b/routes/routes.go @@ -15,6 +15,7 @@ func NewRouter() *mux.Router { r.HandleFunc("/", middleware.AuthRequired(indexPostHandler)).Methods("POST") r.HandleFunc("/login", loginGetHandler).Methods("GET") r.HandleFunc("/login", loginPostHandler).Methods("POST") + r.HandleFunc("/logout", logoutGetHandler).Methods("GET") r.HandleFunc("/register", registerGetHandler).Methods("GET") r.HandleFunc("/register", registerPostHandler).Methods("POST") fs := http.FileServer(http.Dir("./static/")) @@ -27,16 +28,17 @@ func NewRouter() *mux.Router { func indexGetHandler(w http.ResponseWriter, r *http.Request) { updates, err := models.GetAllUpdates() if err != nil { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("Internal server error")) + utils.InternalServerError(w) return } utils.ExecuteTemplate(w, "index.html", struct { Title string Updates []*models.Update + DisplayForm bool } { Title: "All updates", Updates: updates, + DisplayForm: true, }) } @@ -45,48 +47,52 @@ func indexPostHandler(w http.ResponseWriter, r *http.Request) { untypedUserId := session.Values["user_id"] userId, ok := untypedUserId.(int64) if !ok { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("Internal server error")) + utils.InternalServerError(w) return } r.ParseForm() body := r.PostForm.Get("update") err := models.PostUpdate(userId, body) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("Internal server error")) + utils.InternalServerError(w) return } http.Redirect(w, r, "/", 302) } func userGetHandler(w http.ResponseWriter, r *http.Request) { + session, _ := sessions.Store.Get(r, "session") + untypedUserId := session.Values["user_id"] + currentUserId, ok := untypedUserId.(int64) + if !ok { + utils.InternalServerError(w) + return + } vars := mux.Vars(r) username := vars["username"] user, err := models.GetUserByUsername(username) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("Internal server error")) + utils.InternalServerError(w) return } userId, err := user.GetId() if err != nil { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("Internal server error")) + utils.InternalServerError(w) return } updates, err := models.GetUpdates(userId) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("Internal server error")) + utils.InternalServerError(w) return } utils.ExecuteTemplate(w, "index.html", struct { Title string Updates []*models.Update - } { + DisplayForm bool + } { Title: username, Updates: updates, + DisplayForm: currentUserId == userId, }) } @@ -106,15 +112,13 @@ func loginPostHandler(w http.ResponseWriter, r *http.Request) { case models.ErrInvalidLogin: utils.ExecuteTemplate(w, "login.html", "invalid login") default: - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("Internal server error")) + utils.InternalServerError(w) } return } userId, err := user.GetId() if err != nil { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("Internal server error")) + utils.InternalServerError(w) return } session, _ := sessions.Store.Get(r, "session") @@ -123,6 +127,13 @@ func loginPostHandler(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/", 302) } +func logoutGetHandler(w http.ResponseWriter, r *http.Request) { + session, _ := sessions.Store.Get(r, "session") + delete(session.Values, "user_id") + session.Save(r, w) + http.Redirect(w, r, "/login", 302) +} + func registerGetHandler(w http.ResponseWriter, r *http.Request) { utils.ExecuteTemplate(w, "register.html", nil) } @@ -132,9 +143,11 @@ func registerPostHandler(w http.ResponseWriter, r *http.Request) { username := r.PostForm.Get("username") password := r.PostForm.Get("password") err := models.RegisterUser(username, password) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("Internal server error")) + if err == models.ErrUsernameTaken { + utils.ExecuteTemplate(w, "register.html", "username taken") + return + } else if err != nil { + utils.InternalServerError(w) return } http.Redirect(w, r, "/login", 302) diff --git a/static/index.css b/static/index.css index 7e0429c..bf74c61 100644 --- a/static/index.css +++ b/static/index.css @@ -1,7 +1,38 @@ -body > div { +* { + margin: 0; + padding: 0; +} + +body { + background: #f0f0f0; +} + +nav { padding: 0.5em; - width: 200px; - margin: 1em 0em; - background: #ccc; + background: #fff; + border-bottom: 1px solid #aaa; + text-align: right; +} + +main { + margin: 0 auto; + max-width: 640px; +} + +main > .update { + padding: 0.5em; + background: #fff; border: 1px solid #aaa; -} \ No newline at end of file + margin-bottom: 1em; +} + +#update-form { + text-align: right; + margin-bottom: 1em; +} + +#update-form textarea { + width: 100%; + margin-bottom: 0.5em; + resize: vertical; +} diff --git a/templates/index.html b/templates/index.html index a7f0c0a..0dcc0f4 100644 --- a/templates/index.html +++ b/templates/index.html @@ -4,20 +4,29 @@ -

{{ .Title }}

-
- -
- + +
+

{{ .Title }}

+ {{ if .DisplayForm }} +
+ + +
+ +
+
- - {{ range .Updates }} -
-
- {{ .GetUser.GetUsername }} wrote: + {{ end }} + {{ range .Updates }} +
+ +
{{ .GetBody }}
-
{{ .GetBody }}
-
- {{ end }} + {{ end }} +
\ No newline at end of file diff --git a/templates/register.html b/templates/register.html index e05a413..97403fc 100644 --- a/templates/register.html +++ b/templates/register.html @@ -3,6 +3,9 @@ Register + {{ if . }} +
{{ . }}
+ {{ end }}
Username:
Password:
diff --git a/utils/errors.go b/utils/errors.go new file mode 100644 index 0000000..1b528e6 --- /dev/null +++ b/utils/errors.go @@ -0,0 +1,10 @@ +package utils + +import ( + "net/http" +) + +func InternalServerError(w http.ResponseWriter) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("Internal server error")) +} \ No newline at end of file