package api

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"os"
	"strings"

	"arimelody-web/controller"
	"arimelody-web/model"
)

// Handler returns the HTTP handler for the v1 JSON API.
//
// Every request first has its session resolved by getSession and stored in
// the request context under the "session" key, and is then routed by an
// internal ServeMux. Routes marked "(admin)" are wrapped in requireAccount and
// answer 401 Unauthorized when the session has no associated account. Unknown
// methods on the resource routes answer 404.
//
// Patterns are registered without an "/api" prefix while the route comments
// show "/api/..." paths; presumably the caller mounts this handler under
// "/api" with the prefix stripped — confirm against the router setup.
func Handler(app *model.AppState) http.Handler {
	mux := http.NewServeMux()

	// TODO: generate API keys on the frontend

	// ARTIST ENDPOINTS

	mux.Handle("/v1/artist/{id}", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var artistID = r.PathValue("id")
		artist, err := controller.GetArtist(app.DB, artistID)
		if err != nil {
			// The controller's "not found" error is detected by message text.
			if strings.Contains(err.Error(), "no rows") {
				http.NotFound(w, r)
				return
			}
			fmt.Printf("WARN: Error while retrieving artist %s: %s\n", artistID, err)
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			return
		}

		switch r.Method {
		case http.MethodGet:
			// GET /api/v1/artist/{id}
			ServeArtist(app, artist).ServeHTTP(w, r)
		case http.MethodPut:
			// PUT /api/v1/artist/{id} (admin)
			requireAccount(UpdateArtist(app, artist)).ServeHTTP(w, r)
		case http.MethodDelete:
			// DELETE /api/v1/artist/{id} (admin)
			requireAccount(DeleteArtist(app, artist)).ServeHTTP(w, r)
		default:
			http.NotFound(w, r)
		}
	}))

	artistIndexHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodGet:
			// GET /api/v1/artist
			ServeAllArtists(app).ServeHTTP(w, r)
		case http.MethodPost:
			// POST /api/v1/artist (admin)
			requireAccount(CreateArtist(app)).ServeHTTP(w, r)
		default:
			http.NotFound(w, r)
		}
	})
	mux.Handle("/v1/artist/", artistIndexHandler)
	mux.Handle("/v1/artist", artistIndexHandler)

	// RELEASE ENDPOINTS

	mux.Handle("/v1/music/{id}", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var releaseID = r.PathValue("id")
		release, err := controller.GetRelease(app.DB, releaseID, true)
		if err != nil {
			if strings.Contains(err.Error(), "no rows") {
				http.NotFound(w, r)
				return
			}
			fmt.Printf("WARN: Error while retrieving release %s: %s\n", releaseID, err)
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			return
		}

		switch r.Method {
		case http.MethodGet:
			// GET /api/v1/music/{id}
			ServeRelease(app, release).ServeHTTP(w, r)
		case http.MethodPut:
			// PUT /api/v1/music/{id} (admin)
			requireAccount(UpdateRelease(app, release)).ServeHTTP(w, r)
		case http.MethodDelete:
			// DELETE /api/v1/music/{id} (admin)
			requireAccount(DeleteRelease(app, release)).ServeHTTP(w, r)
		default:
			http.NotFound(w, r)
		}
	}))

	musicIndexHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodGet:
			// GET /api/v1/music
			ServeCatalog(app).ServeHTTP(w, r)
		case http.MethodPost:
			// POST /api/v1/music (admin)
			requireAccount(CreateRelease(app)).ServeHTTP(w, r)
		default:
			http.NotFound(w, r)
		}
	})
	mux.Handle("/v1/music/", musicIndexHandler)
	mux.Handle("/v1/music", musicIndexHandler)

	// TRACK ENDPOINTS

	// Every method on this route is admin-only, so authentication is checked
	// BEFORE the database lookup. Checking it afterwards would let anonymous
	// clients distinguish existing track IDs (401) from missing ones (404) and
	// would run a query for every unauthenticated request.
	mux.Handle("/v1/track/{id}", requireAccount(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var trackID = r.PathValue("id")
		track, err := controller.GetTrack(app.DB, trackID)
		if err != nil {
			if strings.Contains(err.Error(), "no rows") {
				http.NotFound(w, r)
				return
			}
			fmt.Printf("WARN: Error while retrieving track %s: %s\n", trackID, err)
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			return
		}

		switch r.Method {
		case http.MethodGet:
			// GET /api/v1/track/{id} (admin)
			ServeTrack(app, track).ServeHTTP(w, r)
		case http.MethodPut:
			// PUT /api/v1/track/{id} (admin)
			UpdateTrack(app, track).ServeHTTP(w, r)
		case http.MethodDelete:
			// DELETE /api/v1/track/{id} (admin)
			DeleteTrack(app, track).ServeHTTP(w, r)
		default:
			http.NotFound(w, r)
		}
	})))

	trackIndexHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodGet:
			// GET /api/v1/track (admin)
			requireAccount(ServeAllTracks(app)).ServeHTTP(w, r)
		case http.MethodPost:
			// POST /api/v1/track (admin)
			requireAccount(CreateTrack(app)).ServeHTTP(w, r)
		default:
			http.NotFound(w, r)
		}
	})
	mux.Handle("/v1/track/", trackIndexHandler)
	mux.Handle("/v1/track", trackIndexHandler)

	// BLOG ENDPOINTS

	mux.Handle("GET /v1/blog/{id}", ServeBlog(app))
	mux.Handle("PUT /v1/blog/{id}", requireAccount(UpdateBlog(app)))
	mux.Handle("DELETE /v1/blog/{id}", requireAccount(DeleteBlog(app)))
	mux.Handle("GET /v1/blog", ServeAllBlogs(app))
	mux.Handle("POST /v1/blog", requireAccount(CreateBlog(app)))

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		session, err := getSession(app, r)
		if err != nil {
			fmt.Fprintf(os.Stderr, "WARN: Failed to get session: %v\n", err)
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			return
		}

		// The session (possibly a nil *model.Session for anonymous requests) is
		// stored under the plain string key "session". NOTE(review): a private
		// key type would be safer, but other handlers in this package presumably
		// read Value("session") directly, so the key is left unchanged.
		ctx := context.WithValue(r.Context(), "session", session)
		mux.ServeHTTP(w, r.WithContext(ctx))
	})
}

// requireAccount wraps next so that it only runs for requests whose context
// session (stored under "session" by Handler) has an associated account.
// Otherwise it responds 401 Unauthorized. A missing or wrongly typed context
// value is treated as "not logged in" rather than panicking.
func requireAccount(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		session, _ := r.Context().Value("session").(*model.Session)
		if session == nil || session.Account == nil {
			http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
			return
		}
		// The session is already in r's context; pass the request through as-is.
		next.ServeHTTP(w, r)
	})
}

// getSession resolves the session for r.
//
// It takes the token from the model.COOKIE_TOKEN cookie. If that cookie is
// absent or empty, it uses the Authorization header with an optional
// "Bearer " prefix stripped. It returns (nil, nil) when no token was supplied
// or no session matches the token. A non-nil error means the lookup itself
// failed.
func getSession(app *model.AppState, r *http.Request) (*model.Session, error) {
	var token string

	// check cookies first
	sessionCookie, err := r.Cookie(model.COOKIE_TOKEN)
	if err != nil && !errors.Is(err, http.ErrNoCookie) {
		return nil, fmt.Errorf("failed to retrieve session cookie: %w", err)
	}
	if sessionCookie != nil {
		token = sessionCookie.Value
	}
	if token == "" {
		// fall back to the Authorization header. NOTE: a header without the
		// "Bearer " prefix is used verbatim as the token (original behaviour).
		token = strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
	}
	if token == "" {
		return nil, nil
	}

	// fetch existing session
	session, err := controller.GetSession(app.DB, token)
	if err != nil {
		// An unknown token is reported by the controller as a "no rows" error;
		// that just means "anonymous", not a failure.
		if strings.Contains(err.Error(), "no rows") {
			return nil, nil
		}
		return nil, fmt.Errorf("failed to retrieve session: %w", err)
	}

	// TODO: consider running security checks here (i.e. user agent mismatches)

	return session, nil
}