Protected routes

This commit is contained in:
Nedim
2023-11-10 17:32:17 +01:00
parent 367b5d51f2
commit 99b9df5066
14 changed files with 172 additions and 100 deletions

View File

@@ -26,6 +26,8 @@ func Load() error {
Port: getEnv("NOVATECH_SERVICE_PORT", "9000"),
Environment: getEnv("NOVATECH_SERVICE_ENVIRONMENT", "DEV"),
MapboxAccessToken: getEnv("NOVATECH_SERVICE_MAPBOX_ACCESS_TOKEN", ""),
JwtSecretKey: getEnv("JWT_SECRET_KEY", "MDQsCiJwYWNrZXRWZXJzaW9uIjogMSwKImhhcm"),
JwtSecretKeyExpiryHours: getEnv("JWT_SECRET_KEY_EXPIRY_HOURS", "24"),
},
AdminService: Service{
// 8080 DEFAULT FOR DEV ENVIRONMENT
@@ -45,7 +47,7 @@ func Load() error {
WalletAddress: getEnv("NOVATECH_BLOCKCHAIN_WALLET_ADDRESS", ""),
WalletPrivateKey: getEnv("NOVATECH_BLOCKCHAIN_WALLET_PRIVATE_KEY", ""),
},
AWS: AWS {
AWS: AWS{
AccessKey: getEnv("AWS_ACCESS_KEY_ID", ""),
SecretKey: getEnv("AWS_SECRET_ACCESS_KEY", ""),
},

View File

@@ -15,6 +15,8 @@ type Service struct {
Environment string
WebPageURL string
MapboxAccessToken string
JwtSecretKey string
JwtSecretKeyExpiryHours string
}
// Blockchain contains configuration for blockchain

View File

@@ -35,6 +35,8 @@ func GetLatestContracts(c *gin.Context) {
deviceIDsStr := c.QueryArray("deviceIDs[]")
iDsStr := c.QueryArray("ids[]")
company := c.GetInt("companyID")
// Convert limit and offset to int
limit, err := strconv.Atoi(limitStr)
if err != nil {
@@ -93,7 +95,7 @@ func GetLatestContracts(c *gin.Context) {
}
// Fetch contracts
contracts, total, st, err := contract.GetContracts(status, companyName, companyAddress, companyEmail, companyPhone, &startTime, &endTime, contractName, deviceIDs, contractIDs, nil, limit, offset)
contracts, total, st, err := contract.GetContracts(status, companyName, companyAddress, companyEmail, companyPhone, &startTime, &endTime, contractName, deviceIDs, contractIDs, nil, company, limit, offset)
if err != nil {
c.JSON(st, gin.H{"error": err.Error()})
@@ -114,6 +116,7 @@ func GetBuyerContracts(c *gin.Context) {
dateCreatedStr := c.Query("date_created")
startTimeStr := c.Query("start_time")
endTimeStr := c.Query("end_time")
company := c.GetInt("companyID")
// Convert limit and offset to int
limit, err := strconv.Atoi(limitStr)
@@ -175,7 +178,7 @@ func GetBuyerContracts(c *gin.Context) {
}
// Fetch contracts
contracts, total, st, err := contract.GetContracts(status, "", "", "", "", startTime, endTime, qStr, nil, contractIDs, &dateCreated, limit, offset)
contracts, total, st, err := contract.GetContracts(status, "", "", "", "", startTime, endTime, qStr, nil, contractIDs, &dateCreated, company, limit, offset)
if err != nil {
c.JSON(st, gin.H{"error": err.Error()})
return
@@ -354,7 +357,7 @@ func GetContractCountByStatus(c *gin.Context) {
return
}
c.JSON(http.StatusOK, gin.H{"data": models.ActiveContractsResponse{ActiveCount: activeCount, ExecutedCount : executedCount, MonthlyContracts: monthly}})
c.JSON(http.StatusOK, gin.H{"data": models.ActiveContractsResponse{ActiveCount: activeCount, ExecutedCount: executedCount, MonthlyContracts: monthly}})
}
@@ -388,5 +391,5 @@ func GetTotalContractCount(c *gin.Context) {
return
}
c.JSON(http.StatusOK, gin.H{"data" : totalCount})
c.JSON(http.StatusOK, gin.H{"data": totalCount})
}

View File

@@ -85,7 +85,6 @@ func SaveDeviceInfo(c *gin.Context) {
}
}
log.Printf("Successfully received and saved device info: %v", deviceInfo)
c.JSON(http.StatusOK, gin.H{"message": "Successfully received and saved device info", "data": deviceInfo})
}
@@ -153,6 +152,7 @@ func GetDevicesByContract(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "Contract ID is required"})
return
}
companyID := c.GetInt("companyID")
// Convert string to uint
contractID, err := strconv.ParseUint(contractIDStr, 10, 32)
@@ -163,7 +163,7 @@ func GetDevicesByContract(c *gin.Context) {
}
log.Printf("This is the ID: %v", contractID)
devices, st, err := device.GetDevicesForContract(contractID)
devices, st, err := device.GetDevicesForContract(contractID, companyID)
if err != nil {
c.JSON(st, gin.H{"error": err.Error()})
@@ -239,7 +239,6 @@ func GetCompanyRelatedDeviceInfoCountWithTempRange(c *gin.Context) {
return
}
// Get the counts
inRangeCount, outOfRangeCount, monthlyCount, err := device.CountDeviceBreachedAndNormalDevicesByCompany(uint(companyID), startTime, endTime)
if err != nil {
@@ -253,7 +252,6 @@ func GetCompanyRelatedDeviceInfoCountWithTempRange(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"data": data})
}
func GetContractsMatchingDeviceLocation(c *gin.Context) {
startTimeStr := c.DefaultQuery("start_time", "")
endTimeStr := c.DefaultQuery("end_time", "")

View File

@@ -17,6 +17,7 @@ func GetInvoices(c *gin.Context) {
sortBy := c.Query("sort_by")
iDsStr := c.QueryArray("ids[]")
status := c.QueryArray("status")
companyID := c.GetInt("companyID")
limit, err := strconv.Atoi(limitStr)
if err != nil {
@@ -41,7 +42,7 @@ func GetInvoices(c *gin.Context) {
invoiceIDs = append(invoiceIDs, id)
}
invoices, total, err := invoice.GetInvoices(buyerName, sortBy, limit, offset, invoiceIDs, status)
invoices, total, err := invoice.GetInvoices(buyerName, sortBy, limit, offset, invoiceIDs, status, companyID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -61,7 +62,9 @@ func GetInvoiceByID(c *gin.Context) {
return
}
invoices, _, err := invoice.GetInvoices("", "", 1, 0, []int64{int64(id)}, nil)
companyID := c.GetInt("companyID")
invoices, _, err := invoice.GetInvoices("", "", 1, 0, []int64{int64(id)}, nil, companyID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -106,7 +109,6 @@ func convertToResponseModel(invoices []models.Invoice) []models.ListInvoiceRespo
return listInvoiceResponses
}
func GetInvoiceCountByStatus(c *gin.Context) {
companyID := c.DefaultQuery("company_id", "0")
startTimeStr := c.DefaultQuery("start_time", "")
@@ -137,6 +139,6 @@ func GetInvoiceCountByStatus(c *gin.Context) {
return
}
c.JSON(http.StatusOK, gin.H{"data": models.ActiveInvoiceResponse{Claimed: activeCount, Issued : executedCount, MonthlyInvoices: monthly}})
c.JSON(http.StatusOK, gin.H{"data": models.ActiveInvoiceResponse{Claimed: activeCount, Issued: executedCount, MonthlyInvoices: monthly}})
}

View File

@@ -130,7 +130,7 @@ func Login(c *gin.Context) {
if usr.CheckPassword(user.Password, req.Password) {
if user.IsActive && user.LoginAttempts < 10 {
// Proceed with creating JWT token and resetting login attempts
token, err := usr.CreateSessionToken(user.ID)
token, err := usr.CreateSessionToken(user.ID, user.CompanyID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Could not create JWT token"})
return

View File

@@ -21,7 +21,7 @@ import (
func GetContracts(status []string, companyName string, companyAddress string,
companyEmail string, companyPhone string, startTime *time.Time, endTime *time.Time,
contractName string, deviceIDs []int64, contractIDs []int64, dateCreated *time.Time, limit, offset int) ([]models.Contract, int64, int, error) {
contractName string, deviceIDs []int64, contractIDs []int64, dateCreated *time.Time, company, limit, offset int) ([]models.Contract, int64, int, error) {
var contracts []models.Contract
db := shared.GetDb()

View File

@@ -15,7 +15,7 @@ import (
"gitlab.com/pactual1/backend/shared"
)
func GetDevicesForContract(contractID uint64) ([]models.Device, int, error) {
func GetDevicesForContract(contractID uint64, companyID int) ([]models.Device, int, error) {
// Fetch the contract from the database
var contract models.Contract
if err := shared.GetDb().Where("id = ?", contractID).First(&contract).Error; err != nil {

View File

@@ -9,7 +9,7 @@ import (
"gitlab.com/pactual1/backend/shared"
)
func GetInvoices(buyerName string, sortBy string, limit int, offset int, ids []int64, status []string) ([]models.Invoice, int64, error) {
func GetInvoices(buyerName string, sortBy string, limit int, offset int, ids []int64, status []string, company int) ([]models.Invoice, int64, error) {
var invoices []models.Invoice
// Default sort by InvoiceDate DESC

View File

@@ -3,10 +3,12 @@ package user
import (
"errors"
"fmt"
"strconv"
"time"
"github.com/golang-jwt/jwt"
"github.com/jinzhu/gorm"
"gitlab.com/pactual1/backend/config"
"gitlab.com/pactual1/backend/models"
"gitlab.com/pactual1/backend/shared"
"golang.org/x/crypto/bcrypt"
@@ -50,7 +52,7 @@ func CheckPassword(hashedPassword, password string) bool {
return err == nil
}
func CreateSessionToken(userID uint) (string, error) {
func CreateSessionToken(userID, companyID uint) (string, error) {
// Generate JWT token
tokenString, err := CreateJWTToken(userID)
if err != nil {
@@ -61,6 +63,7 @@ func CreateSessionToken(userID uint) (string, error) {
sessionToken := models.SessionToken{
UserID: userID,
Token: tokenString,
CompanyID: companyID,
IsActive: true,
}
if result := shared.GetDb().Create(&sessionToken); result.Error != nil {
@@ -84,10 +87,15 @@ func IncrementLoginAttempts(user models.User) {
shared.GetDb().Save(&user)
}
var jwtKey = []byte("MDQsCiJwYWNrZXRWZXJzaW9uIjogMSwKImhhcm")
func CreateJWTToken(userID uint) (string, error) {
expirationTime := time.Now().Add(24 * time.Hour)
var jwtKey = []byte(config.AppConfig.Service.JwtSecretKey)
expiryHours, err := strconv.Atoi(config.AppConfig.Service.JwtSecretKeyExpiryHours)
if err != nil {
return "", err
}
expirationTime := time.Now().Add(time.Duration(expiryHours) * time.Hour)
claims := &jwt.StandardClaims{
Subject: fmt.Sprint(userID),
ExpiresAt: expirationTime.Unix(),

View File

@@ -15,3 +15,4 @@ NOVATECH_BLOCKCHAIN_WALLET_PRIVATE_KEY=PRIVATE_KEY
NOVATECH_SERVICE_MAPBOX_ACCESS_TOKEN=pk.ey
AWS_ACCESS_KEY_ID:access
AWS_SECRET_ACCESS_KEY:secret
JWT_SECRET_KEY:key

View File

@@ -5,10 +5,17 @@
package middlewares
import (
"github.com/dgrijalva/jwt-go"
"github.com/gin-gonic/gin"
"errors"
"net/http"
"strings"
"time"
"github.com/dgrijalva/jwt-go"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"gitlab.com/pactual1/backend/config"
"gitlab.com/pactual1/backend/models"
"gitlab.com/pactual1/backend/shared"
)
var (
@@ -102,3 +109,50 @@ func ValidateToken(tokenString string, key string) (*jwt.Token, error) {
return token, err
}
// AuthMiddleware checks the session token and validates it
func AuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var jwtKey = []byte(config.AppConfig.Service.JwtSecretKey)
tokenString := c.GetHeader("Authorization")
// Check if token is in the correct format (Bearer token)
if len(tokenString) > 7 && strings.ToUpper(tokenString[0:7]) == "BEARER " {
tokenString = tokenString[7:]
} else {
c.JSON(http.StatusForbidden, gin.H{"message": "Your request is not authorized"})
c.Abort()
return
}
// Parse and validate the token
claims := &jwt.StandardClaims{}
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
return jwtKey, nil
})
if err != nil || !token.Valid {
c.JSON(http.StatusForbidden, gin.H{"message": "Invalid authorization token"})
c.Abort()
return
}
// Check if the token is present and active in the SessionToken table
var sessionToken models.SessionToken
result := shared.GetDb().Where("token = ? AND is_active = ?", tokenString, true).First(&sessionToken)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
c.JSON(http.StatusForbidden, gin.H{"message": "Invalid session token"})
c.Abort()
return
} else if result.Error != nil {
c.JSON(http.StatusInternalServerError, gin.H{"message": "Internal server error"})
c.Abort()
return
}
// Set user ID in the Gin context
c.Set("userID", sessionToken.UserID)
c.Set("companyID", sessionToken.CompanyID)
c.Next()
}
}

View File

@@ -5,6 +5,7 @@ type SessionToken struct {
UserID uint `json:"userId"`
Token string `json:"token"`
IsActive bool `json:"isActive"`
CompanyID uint `json:"userId"`
}
func (SessionToken) Update() (bool, error) {

View File

@@ -2,6 +2,7 @@ package routes
import (
"gitlab.com/pactual1/backend/controllers"
"gitlab.com/pactual1/backend/middlewares"
"github.com/gin-gonic/gin"
)
@@ -12,42 +13,42 @@ func RegisterPublicRoutes(r *gin.Engine) {
r.GET("/health", controllers.HealthCheck)
// Map dashboard
r.GET("/dashboard/map/contract/devices", controllers.GetDevicesByContract)
r.GET("/dashboard/map/contracts", controllers.GetLatestContracts)
r.GET("/dashboard/map/device_data", controllers.GetDeviceData)
r.GET("/dashboard/map/contract/devices", middlewares.AuthMiddleware(), controllers.GetDevicesByContract)
r.GET("/dashboard/map/contracts", middlewares.AuthMiddleware(), controllers.GetLatestContracts)
r.GET("/dashboard/map/device_data", middlewares.AuthMiddleware(), controllers.GetDeviceData)
// Invoices
r.GET("/invoices", controllers.GetInvoices)
r.GET("/invoices/:id", controllers.GetInvoiceByID)
r.GET("/invoices", middlewares.AuthMiddleware(), controllers.GetInvoices)
r.GET("/invoices/:id", middlewares.AuthMiddleware(), controllers.GetInvoiceByID)
r.POST("/device_data/save", controllers.SaveDeviceInfo)
r.GET("/buyers", controllers.ListCompanies)
r.GET("/products", controllers.ListProductTemplates)
r.GET("/templates", controllers.ListTextTemplates)
r.POST("/templates/save", controllers.CreateTextTemplate)
r.GET("/products/:template_id", controllers.GetProductTemplate)
r.GET("/buyers", middlewares.AuthMiddleware(), controllers.ListCompanies)
r.GET("/products", middlewares.AuthMiddleware(), middlewares.AuthMiddleware(), controllers.ListProductTemplates)
r.GET("/templates", middlewares.AuthMiddleware(), controllers.ListTextTemplates)
r.POST("/templates/save", middlewares.AuthMiddleware(), controllers.CreateTextTemplate)
r.GET("/products/:template_id", middlewares.AuthMiddleware(), controllers.GetProductTemplate)
// Contracts
r.GET("/contracts/statuses", controllers.GetContractStatuses)
r.GET("/contracts", controllers.GetBuyerContracts)
r.POST("/contracts/create", controllers.CreateContract)
r.GET("/contracts/statuses", middlewares.AuthMiddleware(), controllers.GetContractStatuses)
r.GET("/contracts", middlewares.AuthMiddleware(), controllers.GetBuyerContracts)
r.POST("/contracts/create", middlewares.AuthMiddleware(), controllers.CreateContract)
r.GET("/contracts/:contract_id", controllers.GetContractByID)
r.PATCH("/contracts/:contract_id", controllers.UpdateContract)
r.GET("/contracts/:contract_id", middlewares.AuthMiddleware(), controllers.GetContractByID)
r.PATCH("/contracts/:contract_id", middlewares.AuthMiddleware(), controllers.UpdateContract)
// Locations
r.GET("/locations", controllers.SearchPlace)
r.GET("/locations", middlewares.AuthMiddleware(), controllers.SearchPlace)
// Notifications
r.GET("/notifications", controllers.GetNotifications)
r.GET("/notifications", middlewares.AuthMiddleware(), controllers.GetNotifications)
// Stats
r.GET("/stats/measurements", controllers.GetCompanyRelatedDeviceInfoCount)
r.GET("/stats/devices", controllers.GetCompanyRelatedDeviceInfoCountWithTempRange)
r.GET("/stats/contracts", controllers.GetContractCountByStatus)
r.GET("/stats/contracts/total", controllers.GetTotalContractCount)
r.GET("/stats/invoices", controllers.GetInvoiceCountByStatus)
r.GET("/stats/milestones", controllers.GetContractsMatchingDeviceLocation)
r.GET("/stats/measurements", middlewares.AuthMiddleware(), controllers.GetCompanyRelatedDeviceInfoCount)
r.GET("/stats/devices", middlewares.AuthMiddleware(), controllers.GetCompanyRelatedDeviceInfoCountWithTempRange)
r.GET("/stats/contracts", middlewares.AuthMiddleware(), controllers.GetContractCountByStatus)
r.GET("/stats/contracts/total", middlewares.AuthMiddleware(), controllers.GetTotalContractCount)
r.GET("/stats/invoices", middlewares.AuthMiddleware(), controllers.GetInvoiceCountByStatus)
r.GET("/stats/milestones", middlewares.AuthMiddleware(), controllers.GetContractsMatchingDeviceLocation)
//Users
r.POST("/user/reset/password", controllers.ResetPassword)