205 lines
4.8 KiB
Go
205 lines
4.8 KiB
Go
package serverconfig
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"regexp"
|
|
"strings"
|
|
|
|
"time"
|
|
|
|
"bitbucket.org/nemt/nemt-portal-api/domain"
|
|
"bitbucket.org/nemt/nemt-portal-api/infra/cache"
|
|
"bitbucket.org/nemt/nemt-portal-api/infra/config"
|
|
"bitbucket.org/nemt/nemt-portal-api/server/router/routeutils"
|
|
"github.com/labstack/echo"
|
|
"github.com/labstack/echo/middleware"
|
|
)
|
|
|
|
// bodyResponseWriter implements http.ResponseWriter by sending the response
// body to Writer while status and headers go to the embedded ResponseWriter.
// It also exposes the optional http.Flusher, http.Hijacker and
// http.CloseNotifier interfaces, forwarding to the wrapped ResponseWriter when
// it supports them and degrading gracefully (instead of panicking) when not.
type bodyResponseWriter struct {
	io.Writer
	http.ResponseWriter
}

// WriteHeader forwards the status code to the wrapped ResponseWriter.
func (w *bodyResponseWriter) WriteHeader(code int) {
	w.ResponseWriter.WriteHeader(code)
}

// Write sends b to Writer, not to the wrapped ResponseWriter.
func (w *bodyResponseWriter) Write(b []byte) (int, error) {
	return w.Writer.Write(b)
}

// Flush flushes the wrapped ResponseWriter when it implements http.Flusher;
// otherwise it does nothing.
func (w *bodyResponseWriter) Flush() {
	if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
		flusher.Flush()
	}
}

// Hijack hands over the underlying connection when the wrapped ResponseWriter
// implements http.Hijacker, and returns http.ErrNotSupported otherwise.
func (w *bodyResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	hijacker, ok := w.ResponseWriter.(http.Hijacker)
	if !ok {
		return nil, nil, http.ErrNotSupported
	}
	return hijacker.Hijack()
}

// CloseNotify forwards to the wrapped ResponseWriter's http.CloseNotifier.
// When the wrapped writer does not implement it, a nil channel is returned;
// a nil channel never delivers, i.e. no close notification is available.
// http.CloseNotifier is deprecated upstream in favour of Request.Context();
// this method only serves callers that still assert the interface.
func (w *bodyResponseWriter) CloseNotify() <-chan bool {
	if notifier, ok := w.ResponseWriter.(http.CloseNotifier); ok {
		return notifier.CloseNotify()
	}
	return nil
}
|
|
|
|
// CacheConfig defines the config for Cache middleware.
type CacheConfig struct {
	// Skipper determines if the request should skip this middleware.
	// CacheMiddlewareWithConfig substitutes DefaultCacheConfig.Skipper when nil.
	Skipper middleware.Skipper

	// Expiration is the cache lifetime passed to the cache store for each
	// stored response; the middleware also uses it to build a Cache-Control
	// max-age header. CacheMiddlewareWithConfig replaces a negative value with
	// cfg.Cache.DefaultExpiration.
	// NOTE(review): what a zero Expiration means to the cache store (no expiry
	// vs. immediate expiry) is not visible here — confirm.
	Expiration time.Duration

	// VaryByQuery lists the query parameters whose values are folded into the
	// cache key. Parameters not listed are ignored, so requests that differ
	// only in unlisted parameters share one cache entry.
	VaryByQuery []string
}
|
|
|
|
// DefaultCacheConfig is the default Cache middleware config.
|
|
var DefaultCacheConfig = CacheConfig{
|
|
Skipper: middleware.DefaultSkipper,
|
|
}
|
|
|
|
// CacheMiddleware returns a middleware that protects requests agains Cache attacks.
|
|
func CacheMiddleware(cfg *config.Config) echo.MiddlewareFunc {
|
|
config := DefaultCacheConfig
|
|
config.Expiration = cfg.Cache.DefaultExpiration
|
|
|
|
return CacheMiddlewareWithConfig(cfg, config)
|
|
}
|
|
|
|
// CacheMiddlewareWithConfig returns a Cache middleware with config.
|
|
// See: `CacheMiddleware()`.
|
|
func CacheMiddlewareWithConfig(cfg *config.Config, config CacheConfig) echo.MiddlewareFunc {
|
|
if config.Skipper == nil {
|
|
config.Skipper = DefaultCacheConfig.Skipper
|
|
}
|
|
|
|
if config.Expiration < 0 {
|
|
config.Expiration = cfg.Cache.DefaultExpiration
|
|
}
|
|
|
|
cache := cache.Instance(cfg)
|
|
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
req := c.Request()
|
|
|
|
skip := config.Skipper(c) || req.Method != echo.GET
|
|
|
|
if !skip {
|
|
for _, val := range req.Header["Cache-Control"] {
|
|
if val == "no-cache" {
|
|
skip = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
if skip {
|
|
return next(c)
|
|
}
|
|
|
|
res := c.Response()
|
|
|
|
var cacheKey = getCacheKey(c, config)
|
|
var contentTypeCacheKey = fmt.Sprintf("%s-content-type", cacheKey)
|
|
var contentType = "application/json"
|
|
|
|
output, err := cache.GetItem(cacheKey)
|
|
if err == nil {
|
|
var responseStatus = http.StatusOK
|
|
|
|
if len(output) == 0 {
|
|
responseStatus = http.StatusNoContent
|
|
}
|
|
|
|
contentType, err = cache.GetString(contentTypeCacheKey)
|
|
if err != nil {
|
|
return routeutils.HandleAPIError(c, err)
|
|
}
|
|
|
|
expiration, err := cache.GetExpiration(cacheKey)
|
|
if err != nil {
|
|
return routeutils.HandleAPIError(c, err)
|
|
}
|
|
|
|
if expiration > 0 {
|
|
res.Header()["Cache-Control"] = []string{fmt.Sprintf("max-age=%v", expiration.Seconds())}
|
|
}
|
|
|
|
return c.Blob(responseStatus, contentType, []byte(output))
|
|
} else if err == domain.ErrCacheMiss {
|
|
resBody := new(bytes.Buffer)
|
|
|
|
mw := io.MultiWriter(res.Writer, resBody)
|
|
|
|
writer := &bodyResponseWriter{Writer: mw, ResponseWriter: res.Writer}
|
|
|
|
res.Writer = writer
|
|
|
|
err = next(c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
headers := writer.Header()
|
|
|
|
cache.SetExpiration(cacheKey, config.Expiration)
|
|
|
|
if config.Expiration != 0 {
|
|
res.Header()["Cache-Control"] = []string{fmt.Sprintf("max-age=%v", config.Expiration.Seconds())}
|
|
}
|
|
|
|
contentTypeHeader, ok := headers[echo.HeaderContentType]
|
|
if ok {
|
|
contentType = contentTypeHeader[0]
|
|
}
|
|
|
|
err = cache.SetItem(cacheKey, resBody.Bytes())
|
|
if err != nil {
|
|
return routeutils.HandleAPIError(c, err)
|
|
}
|
|
err = cache.SetString(contentTypeCacheKey, contentType)
|
|
if err != nil {
|
|
return routeutils.HandleAPIError(c, err)
|
|
}
|
|
} else if err != nil {
|
|
return routeutils.HandleAPIError(c, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
func getCacheKey(c echo.Context, config CacheConfig) string {
|
|
var req = c.Request()
|
|
|
|
var re = regexp.MustCompile("(?i)[^a-z0-9_]+")
|
|
|
|
var key = req.URL.Path
|
|
|
|
if len(config.VaryByQuery) > 0 {
|
|
var query = "q"
|
|
|
|
for _, queryKey := range config.VaryByQuery {
|
|
for k, v := range req.URL.Query() {
|
|
if k == queryKey {
|
|
query = fmt.Sprintf("%s-%s-%s", query, k, v)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
query = strings.Trim(re.ReplaceAllString(query, "-"), "-")
|
|
|
|
if strings.TrimSpace(query) != "" {
|
|
key = fmt.Sprintf("%s-%v", key, query)
|
|
}
|
|
}
|
|
|
|
return strings.Trim(re.ReplaceAllString(key, "-"), "-")
|
|
}
|