package serverconfig

import (
	"bufio"
	"bytes"
	"fmt"
	"io"
	"net"
	"net/http"
	"regexp"
	"strings"
	"time"

	"bitbucket.org/nemt/nemt-portal-api/domain"
	"bitbucket.org/nemt/nemt-portal-api/infra/cache"
	"bitbucket.org/nemt/nemt-portal-api/infra/config"
	"bitbucket.org/nemt/nemt-portal-api/server/router/routeutils"
	"github.com/labstack/echo"
	"github.com/labstack/echo/middleware"
)

// bodyResponseWriter is an http.ResponseWriter that sends body bytes to
// Writer (for example an io.MultiWriter that also captures them), while
// headers and the status code go to the wrapped ResponseWriter.
type bodyResponseWriter struct {
	io.Writer
	http.ResponseWriter
}

// WriteHeader forwards the status code to the wrapped ResponseWriter.
func (w *bodyResponseWriter) WriteHeader(code int) {
	w.ResponseWriter.WriteHeader(code)
}

// Write sends b to Writer instead of directly to the wrapped ResponseWriter.
func (w *bodyResponseWriter) Write(b []byte) (int, error) {
	return w.Writer.Write(b)
}

// Flush flushes the wrapped ResponseWriter when it implements http.Flusher
// and is a no-op otherwise (instead of panicking on a failed assertion).
func (w *bodyResponseWriter) Flush() {
	if f, ok := w.ResponseWriter.(http.Flusher); ok {
		f.Flush()
	}
}

// Hijack hands over the underlying connection when the wrapped ResponseWriter
// implements http.Hijacker, and returns http.ErrNotSupported otherwise.
func (w *bodyResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	h, ok := w.ResponseWriter.(http.Hijacker)
	if !ok {
		return nil, nil, http.ErrNotSupported
	}
	return h.Hijack()
}

// CloseNotify returns the wrapped ResponseWriter's close-notification channel.
// When the wrapped writer does not implement http.CloseNotifier it returns a
// nil channel, which never delivers a value.
func (w *bodyResponseWriter) CloseNotify() <-chan bool {
	if cn, ok := w.ResponseWriter.(http.CloseNotifier); ok {
		return cn.CloseNotify()
	}
	return nil
}

// CacheConfig defines the config for Cache middleware.
type CacheConfig struct {
	// Skipper determines if the request should skip this middleware.
	Skipper middleware.Skipper

	// Expiration is the lifetime applied to cached responses. A negative
	// value makes CacheMiddlewareWithConfig use cfg.Cache.DefaultExpiration.
	Expiration time.Duration

	// VaryByQuery lists the query parameters whose values are included in
	// the cache key.
	VaryByQuery []string
}

// DefaultCacheConfig is the default Cache middleware config. Its Expiration
// is zero; CacheMiddleware replaces it with cfg.Cache.DefaultExpiration.
var DefaultCacheConfig = CacheConfig{
	Skipper: middleware.DefaultSkipper,
}

// CacheMiddleware returns a middleware that caches GET responses, using
// DefaultCacheConfig with Expiration set to cfg.Cache.DefaultExpiration.
func CacheMiddleware(cfg *config.Config) echo.MiddlewareFunc {
	cacheConfig := DefaultCacheConfig
	cacheConfig.Expiration = cfg.Cache.DefaultExpiration
	return CacheMiddlewareWithConfig(cfg, cacheConfig)
}

// CacheMiddlewareWithConfig returns a Cache middleware with config.
// See: `CacheMiddleware()`.
func CacheMiddlewareWithConfig(cfg *config.Config, config CacheConfig) echo.MiddlewareFunc { if config.Skipper == nil { config.Skipper = DefaultCacheConfig.Skipper } if config.Expiration < 0 { config.Expiration = cfg.Cache.DefaultExpiration } cache := cache.Instance(cfg) return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { req := c.Request() skip := config.Skipper(c) || req.Method != echo.GET if !skip { for _, val := range req.Header["Cache-Control"] { if val == "no-cache" { skip = true break } } } if skip { return next(c) } res := c.Response() var cacheKey = getCacheKey(c, config) var contentTypeCacheKey = fmt.Sprintf("%s-content-type", cacheKey) var contentType = "application/json" output, err := cache.GetItem(cacheKey) if err == nil { var responseStatus = http.StatusOK if len(output) == 0 { responseStatus = http.StatusNoContent } contentType, err = cache.GetString(contentTypeCacheKey) if err != nil { return routeutils.HandleAPIError(c, err) } expiration, err := cache.GetExpiration(cacheKey) if err != nil { return routeutils.HandleAPIError(c, err) } if expiration > 0 { res.Header()["Cache-Control"] = []string{fmt.Sprintf("max-age=%v", expiration.Seconds())} } return c.Blob(responseStatus, contentType, []byte(output)) } else if err == domain.ErrCacheMiss { resBody := new(bytes.Buffer) mw := io.MultiWriter(res.Writer, resBody) writer := &bodyResponseWriter{Writer: mw, ResponseWriter: res.Writer} res.Writer = writer err = next(c) if err != nil { return err } headers := writer.Header() cache.SetExpiration(cacheKey, config.Expiration) if config.Expiration != 0 { res.Header()["Cache-Control"] = []string{fmt.Sprintf("max-age=%v", config.Expiration.Seconds())} } contentTypeHeader, ok := headers[echo.HeaderContentType] if ok { contentType = contentTypeHeader[0] } err = cache.SetItem(cacheKey, resBody.Bytes()) if err != nil { return routeutils.HandleAPIError(c, err) } err = cache.SetString(contentTypeCacheKey, contentType) if err != nil 
{ return routeutils.HandleAPIError(c, err) } } else if err != nil { return routeutils.HandleAPIError(c, err) } return nil } } } func getCacheKey(c echo.Context, config CacheConfig) string { var req = c.Request() var re = regexp.MustCompile("(?i)[^a-z0-9_]+") var key = req.URL.Path if len(config.VaryByQuery) > 0 { var query = "q" for _, queryKey := range config.VaryByQuery { for k, v := range req.URL.Query() { if k == queryKey { query = fmt.Sprintf("%s-%s-%s", query, k, v) break } } } query = strings.Trim(re.ReplaceAllString(query, "-"), "-") if strings.TrimSpace(query) != "" { key = fmt.Sprintf("%s-%v", key, query) } } return strings.Trim(re.ReplaceAllString(key, "-"), "-") }