Files
old-svijetlastrana/server/serverconfig/cache.go
2018-04-25 13:16:36 +02:00

205 lines
4.8 KiB
Go

package serverconfig
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"net/http"
"regexp"
"strings"
"time"
"bitbucket.org/nemt/nemt-portal-api/domain"
"bitbucket.org/nemt/nemt-portal-api/infra/cache"
"bitbucket.org/nemt/nemt-portal-api/infra/config"
"bitbucket.org/nemt/nemt-portal-api/server/router/routeutils"
"github.com/labstack/echo"
"github.com/labstack/echo/middleware"
)
// bodyResponseWriter implements the http.ResponseWriter interface.
//
// It pairs a wrapped http.ResponseWriter with a separate io.Writer: body
// bytes passed to Write go to the io.Writer, while WriteHeader, Flush, Hijack
// and CloseNotify are forwarded to the wrapped ResponseWriter. Both embedded
// types declare Write, which is why the type defines its own Write method.
type bodyResponseWriter struct {
// Writer is the destination of every Write call.
io.Writer
// ResponseWriter is the wrapped writer that handles headers and status.
http.ResponseWriter
}
// WriteHeader forwards the HTTP status code to the wrapped ResponseWriter.
func (w *bodyResponseWriter) WriteHeader(code int) {
	inner := w.ResponseWriter
	inner.WriteHeader(code)
}
// Write sends b to the embedded io.Writer (not directly to the wrapped
// ResponseWriter) and reports that writer's byte count and error.
func (w *bodyResponseWriter) Write(b []byte) (int, error) {
	written, err := w.Writer.Write(b)
	return written, err
}
// Flush forwards to the wrapped ResponseWriter when it implements
// http.Flusher. If it does not, Flush is a no-op rather than panicking on a
// failed type assertion.
func (w *bodyResponseWriter) Flush() {
	if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
		flusher.Flush()
	}
}
// Hijack hands the underlying connection to the caller by delegating to the
// wrapped ResponseWriter.
//
// It returns an error instead of panicking when the wrapped ResponseWriter
// does not implement http.Hijacker.
func (w *bodyResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	hijacker, ok := w.ResponseWriter.(http.Hijacker)
	if !ok {
		return nil, nil, fmt.Errorf("serverconfig: underlying ResponseWriter (%T) does not support hijacking", w.ResponseWriter)
	}
	return hijacker.Hijack()
}
// CloseNotify returns the close-notification channel of the wrapped
// ResponseWriter. The type assertion is unchecked, so this panics if the
// wrapped writer does not implement http.CloseNotifier.
func (w *bodyResponseWriter) CloseNotify() <-chan bool {
	notifier := w.ResponseWriter.(http.CloseNotifier)
	return notifier.CloseNotify()
}
// CacheConfig defines the config for Cache middleware.
type CacheConfig struct {
// Skipper determines if the request should skip this middleware.
// CacheMiddlewareWithConfig substitutes DefaultCacheConfig.Skipper when nil.
Skipper middleware.Skipper
// Expiration is the cache lifetime handed to the cache for stored responses
// and advertised through the Cache-Control max-age header. A negative value
// is replaced by cfg.Cache.DefaultExpiration in CacheMiddlewareWithConfig;
// with a zero value no max-age header is set when a response is stored.
Expiration time.Duration
// VaryByQuery contains a list of query parameters to include in cache key.
// Query parameters not listed here do not influence the key.
VaryByQuery []string
}
// DefaultCacheConfig is the default Cache middleware config. It only sets the
// Skipper (middleware.DefaultSkipper); Expiration and VaryByQuery keep their
// zero values.
var DefaultCacheConfig = CacheConfig{
Skipper: middleware.DefaultSkipper,
}
// CacheMiddleware returns a response-caching middleware. It starts from a
// copy of DefaultCacheConfig, takes the entry lifetime from
// cfg.Cache.DefaultExpiration and delegates to CacheMiddlewareWithConfig.
func CacheMiddleware(cfg *config.Config) echo.MiddlewareFunc {
	// Work on a copy so the package-level default is never mutated.
	defaults := DefaultCacheConfig
	defaults.Expiration = cfg.Cache.DefaultExpiration

	return CacheMiddlewareWithConfig(cfg, defaults)
}
// CacheMiddlewareWithConfig returns a Cache middleware with config.
// See: `CacheMiddleware()`.
//
// A nil config.Skipper is replaced by DefaultCacheConfig.Skipper and a
// negative config.Expiration by cfg.Cache.DefaultExpiration.
//
// The middleware is bypassed (the next handler runs untouched) when the
// Skipper returns true, when the method is not GET, or when the request's
// Cache-Control header carries a "no-cache" directive.
//
// On a cache hit the stored body is sent with the stored content type; the
// status is 200, or 204 when the stored body is empty. If the cache reports a
// positive remaining lifetime it is advertised as Cache-Control max-age.
//
// On a cache miss the next handler runs while its body is copied into a
// buffer. The buffered body and its content type are stored only when the
// handler succeeded with status 200 or 204, so that redirects and error
// responses are never replayed later as successful ones.
//
// Cache errors other than domain.ErrCacheMiss are passed to
// routeutils.HandleAPIError.
func CacheMiddlewareWithConfig(cfg *config.Config, config CacheConfig) echo.MiddlewareFunc {
	if config.Skipper == nil {
		config.Skipper = DefaultCacheConfig.Skipper
	}
	if config.Expiration < 0 {
		config.Expiration = cfg.Cache.DefaultExpiration
	}
	store := cache.Instance(cfg)

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			req := c.Request()
			if config.Skipper(c) || req.Method != echo.GET {
				return next(c)
			}
			// Cache-Control may list several comma-separated directives
			// ("no-cache, no-store") and directive names are case-insensitive.
			for _, val := range req.Header["Cache-Control"] {
				for _, directive := range strings.Split(val, ",") {
					if strings.EqualFold(strings.TrimSpace(directive), "no-cache") {
						return next(c)
					}
				}
			}

			res := c.Response()
			cacheKey := getCacheKey(c, config)
			contentTypeCacheKey := fmt.Sprintf("%s-content-type", cacheKey)

			output, err := store.GetItem(cacheKey)
			if err == nil {
				// Cache hit: replay the stored response.
				responseStatus := http.StatusOK
				if len(output) == 0 {
					responseStatus = http.StatusNoContent
				}
				contentType, err := store.GetString(contentTypeCacheKey)
				if err != nil {
					return routeutils.HandleAPIError(c, err)
				}
				expiration, err := store.GetExpiration(cacheKey)
				if err != nil {
					return routeutils.HandleAPIError(c, err)
				}
				if expiration > 0 {
					// max-age takes whole seconds; truncate the remaining lifetime.
					res.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", int64(expiration.Seconds())))
				}
				return c.Blob(responseStatus, contentType, []byte(output))
			}
			if err != domain.ErrCacheMiss {
				return routeutils.HandleAPIError(c, err)
			}

			// Cache miss: let the handler answer the client while a copy of the
			// body is captured for the cache.
			resBody := new(bytes.Buffer)
			writer := &bodyResponseWriter{
				Writer:         io.MultiWriter(res.Writer, resBody),
				ResponseWriter: res.Writer,
			}
			res.Writer = writer
			if err := next(c); err != nil {
				return err
			}

			// A hit is always replayed as 200 or 204, so anything else
			// (redirects, client/server errors) must not be stored.
			if res.Status != http.StatusOK && res.Status != http.StatusNoContent {
				return nil
			}

			// NOTE(review): the expiration is registered before the item itself is
			// stored and any result of SetExpiration is ignored; confirm this
			// matches the cache implementation's contract.
			store.SetExpiration(cacheKey, config.Expiration)
			if config.Expiration != 0 {
				// NOTE(review): the handler has normally written the response by
				// now, and net/http ignores header changes made after the header
				// was written, so this header likely never reaches the client on
				// a miss. Left in place; fixing it needs a status-aware hook in
				// the response writer wrapper.
				res.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", int64(config.Expiration.Seconds())))
			}

			// Header.Get is safe for a missing key and for an empty value slice.
			contentType := writer.Header().Get(echo.HeaderContentType)
			if contentType == "" {
				contentType = "application/json"
			}

			// NOTE(review): the response has already been sent when these errors
			// are handled; confirm HandleAPIError copes with a committed response.
			if err := store.SetItem(cacheKey, resBody.Bytes()); err != nil {
				return routeutils.HandleAPIError(c, err)
			}
			if err := store.SetString(contentTypeCacheKey, contentType); err != nil {
				return routeutils.HandleAPIError(c, err)
			}
			return nil
		}
	}
}
// cacheKeyUnsafeChars matches every run of characters other than ASCII
// letters, digits and underscores. It is compiled once at package
// initialisation instead of on every request.
var cacheKeyUnsafeChars = regexp.MustCompile("(?i)[^a-z0-9_]+")

// getCacheKey derives the cache key for the current request.
//
// The key starts from the URL path. When config.VaryByQuery is non-empty, a
// "q" segment is appended, followed by the name and values of each listed
// query parameter that is present in the request, in VaryByQuery order.
// Every run of characters other than ASCII letters, digits and underscores
// is then collapsed into a single "-", and leading/trailing "-" are trimmed.
//
// NOTE(review): because of that collapsing, distinct requests can share a
// key (for example the paths "/a/b" and "/a-b").
func getCacheKey(c echo.Context, config CacheConfig) string {
	req := c.Request()
	key := req.URL.Path

	if len(config.VaryByQuery) > 0 {
		// Parse the query string once and look each parameter up directly.
		params := req.URL.Query()
		query := "q"
		for _, name := range config.VaryByQuery {
			if values, ok := params[name]; ok {
				// values is a []string; %s renders it as "[v1 v2]" and the
				// brackets and spaces are sanitised below.
				query = fmt.Sprintf("%s-%s-%s", query, name, values)
			}
		}
		query = strings.Trim(cacheKeyUnsafeChars.ReplaceAllString(query, "-"), "-")
		if strings.TrimSpace(query) != "" {
			key = fmt.Sprintf("%s-%v", key, query)
		}
	}

	return strings.Trim(cacheKeyUnsafeChars.ReplaceAllString(key, "-"), "-")
}