From d54b6433782f90dfec5adb7cccb18a1613817328 Mon Sep 17 00:00:00 2001 From: Nedim Date: Tue, 21 Nov 2023 18:28:30 +0100 Subject: [PATCH] Added uuid middleware --- controllers/contracts_controller.go | 9 +++-- controllers/devices_controller.go | 16 +++++--- database/contract/contract.go | 30 ++++++++++----- database/device/device.go | 20 +++++++--- go.mod | 2 +- go.sum | 4 ++ main.go | 14 +++++++ middlewares/jwt.go | 57 +++++++++++++++++++++++++++++ models/contract.go | 1 + routes/public_routes.go | 6 +-- shared/database.go | 1 + 11 files changed, 133 insertions(+), 27 deletions(-) diff --git a/controllers/contracts_controller.go b/controllers/contracts_controller.go index b6fbb7e..a043bbe 100644 --- a/controllers/contracts_controller.go +++ b/controllers/contracts_controller.go @@ -9,6 +9,7 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/google/uuid" "gitlab.com/pactual1/backend/database/contract" "gitlab.com/pactual1/backend/models" "gitlab.com/pactual1/backend/shared" @@ -36,6 +37,7 @@ func GetLatestContracts(c *gin.Context) { iDsStr := c.QueryArray("ids[]") company := c.GetInt("companyID") + uuid := c.Query("uuid") // Convert limit and offset to int limit, err := strconv.Atoi(limitStr) @@ -95,7 +97,7 @@ func GetLatestContracts(c *gin.Context) { } // Fetch contracts - contracts, total, st, err := contract.GetContracts(status, companyName, companyAddress, companyEmail, companyPhone, &startTime, &endTime, contractName, deviceIDs, contractIDs, nil, company, limit, offset) + contracts, total, st, err := contract.GetContracts(status, companyName, companyAddress, companyEmail, companyPhone, &startTime, &endTime, contractName, deviceIDs, contractIDs, nil, uuid, company, limit, offset) if err != nil { c.JSON(st, gin.H{"error": err.Error()}) @@ -178,7 +180,7 @@ func GetBuyerContracts(c *gin.Context) { } // Fetch contracts - contracts, total, st, err := contract.GetContracts(status, "", "", "", "", startTime, endTime, qStr, nil, contractIDs, &dateCreated, company, limit, offset) + contracts, total, st, err := contract.GetContracts(status, "", "", "", "", startTime, endTime, qStr, nil, contractIDs, &dateCreated, "", company, limit, offset) if err != nil { c.JSON(st, gin.H{"error": err.Error()}) return @@ -207,6 +209,7 @@ func CreateContract(c *gin.Context) { db := shared.GetDb() newContract := models.Contract{ + UUID: uuid.New().String(), BuyerID: payload.BuyerID, SellerID: payload.SellerID, Description: payload.Description, @@ -256,7 +259,7 @@ func GetContractByID(c *gin.Context) { } // Fetch contract - contract, st, err := contract.GetContractByID(uint(contractID)) + contract, st, err := contract.GetContractByID(uint(contractID), "") if err != nil { c.JSON(st, gin.H{"error": err.Error()}) return diff --git a/controllers/devices_controller.go b/controllers/devices_controller.go index 8750fb2..f8376f8 100644 --- a/controllers/devices_controller.go +++ b/controllers/devices_controller.go @@ -41,7 +41,7 @@ func SaveDeviceInfo(c *gin.Context) { } if currentDevice.CurrentContractID != nil { - deviceContract, _, err := contract.GetContractByID(*currentDevice.CurrentContractID) + deviceContract, _, err := contract.GetContractByID(*currentDevice.CurrentContractID, "") if err != nil { log.Printf("SaveDeviceInfo - GetContractByID error : %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Could not fetch device contract"}) @@ -94,6 +94,8 @@ func GetDeviceData(c *gin.Context) { deviceIDStr := c.DefaultQuery("device_id", "") contractIDStr := c.DefaultQuery("contract_id", "") + uuid := c.DefaultQuery("uuid", "") + if deviceIDStr == "" { log.Printf("GetDeviceData Error: Device ID is required") c.JSON(http.StatusBadRequest, gin.H{"error": "Device ID is required"}) @@ -114,7 +116,7 @@ func GetDeviceData(c *gin.Context) { return } - contract, st, err := contract.GetContractByID(uint(contractID)) + contract, st, err := contract.GetContractByID(uint(contractID), uuid) // Update this line to pass UUID if err != nil { c.JSON(st, gin.H{"error": err.Error()}) return @@ -145,13 +147,16 @@ func GetDeviceData(c *gin.Context) { } func GetDevicesByContract(c *gin.Context) { - // Get the contract ID from query parameter + // Get the contract ID and UUID from query parameters contractIDStr := c.DefaultQuery("contract_id", "") + uuid := c.DefaultQuery("uuid", "") // Add this line to get the UUID + if contractIDStr == "" { log.Printf("GetDevicesByContract Error: Contract ID is required") c.JSON(http.StatusBadRequest, gin.H{"error": "Contract ID is required"}) return } + companyID := c.GetInt("companyID") // Convert string to uint @@ -162,13 +167,14 @@ func GetDevicesByContract(c *gin.Context) { return } - log.Printf("This is the ID: %v", contractID) - devices, st, err := device.GetDevicesForContract(contractID, companyID) + log.Printf("This is the Contract ID: %v, UUID: %s", contractID, uuid) + devices, st, err := device.GetDevicesForContract(contractID, uuid, companyID) // Pass UUID here if err != nil { c.JSON(st, gin.H{"error": err.Error()}) return } + // Respond with the devices c.JSON(http.StatusOK, gin.H{"data": models.ConvertDeviceToResponse(devices)}) } diff --git a/database/contract/contract.go b/database/contract/contract.go index d176fcb..42a7278 100644 --- a/database/contract/contract.go +++ b/database/contract/contract.go @@ -19,9 +19,9 @@ import ( "gitlab.com/pactual1/backend/shared" ) -func GetContracts(status []string, companyName string, companyAddress string, - companyEmail string, companyPhone string, startTime *time.Time, endTime *time.Time, - contractName string, deviceIDs []int64, contractIDs []int64, dateCreated *time.Time, company, limit, offset int) ([]models.Contract, int64, int, error) { +func GetContracts(status []string, companyName string, companyAddress string, companyEmail string, companyPhone string, + startTime *time.Time, endTime *time.Time, contractName string, deviceIDs []int64, contractIDs []int64, dateCreated *time.Time, + uuid string, company, limit, offset int) ([]models.Contract, int64, int, error) { var contracts []models.Contract db := shared.GetDb() @@ -69,6 +69,12 @@ func GetContracts(status []string, companyName string, companyAddress string, countDb = countDb.Where("lower(contracts.name) LIKE ?", "%"+strings.ToLower(contractName)+"%") } + // Check if uuid is present + if uuid != "" { + db = db.Where("contracts.uuid = ?", uuid) + countDb = countDb.Where("contracts.uuid = ?", uuid) + } + // Search by Start Time and End Time if startTime != nil && !startTime.IsZero() { db = db.Where("start_time >= ?", startTime) @@ -139,7 +145,7 @@ func UpdateContract(contract models.Contract) (models.Contract, int, error) { } // get old contract to compare updates - oldContract, status, err := GetContractByID(contract.ID) + oldContract, status, err := GetContractByID(contract.ID, "") if err != nil { return contract, status, err } @@ -167,7 +173,7 @@ func UpdateContract(contract models.Contract) (models.Contract, int, error) { return contract, http.StatusInternalServerError, err } - contract, status, err = GetContractByID(contract.ID) + contract, status, err = GetContractByID(contract.ID, "") if err != nil { return contract, status, err } @@ -218,7 +224,7 @@ func validateContractDevices(contractID uint, devices []models.Device) (int, err for _, device := range devices { if device.CurrentContractID != nil && *device.CurrentContractID != contractID { - currentDeviceContract, status, err := GetContractByID(*device.CurrentContractID) + currentDeviceContract, status, err := GetContractByID(*device.CurrentContractID, "") if err != nil { return status, err } @@ -233,11 +239,17 @@ func validateContractDevices(contractID uint, devices []models.Device) (int, err return http.StatusOK, nil } -func GetContractByID(contractID uint) (models.Contract, int, error) { +func GetContractByID(contractID uint, uuid string) (models.Contract, int, error) { - // Fetch the contract creation date based on contractID var contract models.Contract - if err := shared.GetDb().Unscoped().Where("id = ?", contractID).First(&contract).Error; err != nil { + db := shared.GetDb().Unscoped().Where("id = ?", contractID) + + // Include UUID in the query if provided + if uuid != "" { + db = db.Where("uuid = ?", uuid) + } + + if err := db.First(&contract).Error; err != nil { log.Printf("GetContractByID Error: Could not fetch contract: %v", err) return contract, http.StatusInternalServerError, err } diff --git a/database/device/device.go b/database/device/device.go index 89a4596..c35e3e2 100644 --- a/database/device/device.go +++ b/database/device/device.go @@ -15,19 +15,27 @@ import ( "gitlab.com/pactual1/backend/shared" ) -func GetDevicesForContract(contractID uint64, companyID int) ([]models.Device, int, error) { - // Fetch the contract from the database +func GetDevicesForContract(contractID uint64, uuid string, companyID int) ([]models.Device, int, error) { + // Fetch the contract from the database using both contractID and UUID var contract models.Contract - if err := shared.GetDb().Where("id = ?", contractID).First(&contract).Error; err != nil { + query := shared.GetDb().Where("id = ?", contractID) + + // If UUID is provided, include it in the query + if uuid != "" { + query = query.Where("uuid = ?", uuid) + } + + if err := query.First(&contract).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - log.Printf("GetDevicesByContract Error: No contract found: %v", err) + log.Printf("GetDevicesForContract Error: No contract found: %v", err) return nil, http.StatusNotFound, err } else { - log.Printf("GetDevicesByContract Error: Database error: %v", err) + log.Printf("GetDevicesForContract Error: Database error: %v", err) return nil, http.StatusInternalServerError, err } } - log.Printf("This is the device IDS ID: %v", contract.DeviceIDs) + + log.Printf("This is the device IDs: %v", contract.DeviceIDs) return GetDevicesByID(contract.DeviceIDs) } diff --git a/go.mod b/go.mod index 79785f3..330f530 100644 --- a/go.mod +++ b/go.mod @@ -58,7 +58,7 @@ require ( github.com/go-stack/stack v1.8.1 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/golang-jwt/jwt/v4 v4.5.0 // indirect - github.com/google/uuid v1.3.1 // indirect + github.com/google/uuid v1.4.0 // indirect github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/css v1.0.0 // indirect github.com/gorilla/securecookie v1.1.1 // indirect diff --git a/go.sum b/go.sum index df6c0d4..201f65e 100644 --- a/go.sum +++ b/go.sum @@ -138,6 +138,8 @@ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/ github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= +github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY= @@ -334,6 +336,8 @@ github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFA github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= +github.com/twinj/uuid v1.0.0 h1:fzz7COZnDrXGTAOHGuUGYd6sG+JMq+AoE7+Jlu0przk= +github.com/twinj/uuid v1.0.0/go.mod h1:mMgcE1RHFUFqe5AfiwlINXisXfDGro23fWdPUfOMjRY= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/tyler-smith/go-bip39 v1.1.0 h1:5eUemwrMargf3BSLRRCalXT93Ns6pQJIjYQN2nyfOP8= diff --git a/main.go b/main.go index 492180f..96a9fe2 100644 --- a/main.go +++ b/main.go @@ -13,6 +13,7 @@ import ( "gitlab.com/pactual1/backend/shared" "github.com/gin-gonic/gin" + "github.com/google/uuid" "github.com/jinzhu/gorm" "github.com/qor/admin" ) @@ -80,6 +81,19 @@ func main() { http.ListenAndServe(":"+port, mux) }() + // Temp fix to add UUID to existing contracts + var contracts []models.Contract + result := shared.GetDb().Where("uuid IS NULL").Find(&contracts) + if result.Error != nil { + log.Printf("Error fetching contracts: %v", result.Error) + return + } + + for _, contract := range contracts { + contract.UUID = uuid.New().String() + shared.GetDb().Save(&contract) + } + // Initialize channels // messagingChannel := make(chan string) erpChannel := make(chan string) diff --git a/middlewares/jwt.go b/middlewares/jwt.go index 3eb7c16..c497b1f 100644 --- a/middlewares/jwt.go +++ b/middlewares/jwt.go @@ -6,6 +6,7 @@ package middlewares import ( "errors" + "log" "net/http" "strings" "time" @@ -113,6 +114,15 @@ func ValidateToken(tokenString string, key string) (*jwt.Token, error) { // AuthMiddleware checks the session token and validates it func AuthMiddleware() gin.HandlerFunc { return func(c *gin.Context) { + + // Check if contractCheckPassed is set in the context + if passed, exists := c.Get("contractCheckPassed"); exists && passed.(bool) { + log.Printf("checjedpass auth %v", true) + // Skip further checks and proceed to the next middleware + c.Next() + return + } + log.Printf("checjedpass auth%v", false) var jwtKey = []byte(config.AppConfig.Service.JwtSecretKey) tokenString := c.GetHeader("Authorization") @@ -156,3 +166,50 @@ func AuthMiddleware() gin.HandlerFunc { c.Next() } } + +func ContractCheckMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + db := shared.GetDb() + var contractID uint + var uuid string + + // Handling for POST requests + if c.Request.Method == "POST" { + var payload struct { + UUID string `json:"uuid"` + } + if err := c.ShouldBindJSON(&payload); err == nil { + uuid = payload.UUID + } + } + + // Handling for GET requests + if c.Request.Method == "GET" { + uuid = c.Query("uuid") + } + + log.Printf("uuid %v", uuid) + log.Printf("contractID %v", contractID) + + // Perform the check only if both contractID and uuid are provided + if uuid != "" { + var contract models.Contract + result := db.Where("uuid = ?", uuid).First(&contract) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + c.JSON(http.StatusUnauthorized, gin.H{"message": "Invalid contract ID or UUID"}) + c.Abort() + return + } else if result.Error != nil { + c.JSON(http.StatusInternalServerError, gin.H{"message": "Internal server error"}) + c.Abort() + return + } + + // Set a flag in the context to indicate a successful contract check + c.Set("contractCheckPassed", true) + log.Printf("checjedpass %v", true) + } + + c.Next() + } +} diff --git a/models/contract.go b/models/contract.go index 1a18929..686e98f 100644 --- a/models/contract.go +++ b/models/contract.go @@ -11,6 +11,7 @@ import ( type Contract struct { BaseModel Name string `json:"name"` + UUID string `json:"uuid" gorm:"type:uuid;"` DeviceIDs pq.Int64Array `json:"deviceIds" gorm:"type:integer[]"` BuyerID uint `json:"buyerId"` SellerID uint `json:"sellerId"` diff --git a/routes/public_routes.go b/routes/public_routes.go index 822ea76..e394f28 100644 --- a/routes/public_routes.go +++ b/routes/public_routes.go @@ -13,9 +13,9 @@ func RegisterPublicRoutes(r *gin.Engine) { r.GET("/health", controllers.HealthCheck) // Map dashboard - r.GET("/dashboard/map/contract/devices", middlewares.AuthMiddleware(), controllers.GetDevicesByContract) - r.GET("/dashboard/map/contracts", middlewares.AuthMiddleware(), controllers.GetLatestContracts) - r.GET("/dashboard/map/device_data", middlewares.AuthMiddleware(), controllers.GetDeviceData) + r.GET("/dashboard/map/contract/devices", middlewares.ContractCheckMiddleware(), middlewares.AuthMiddleware(), controllers.GetDevicesByContract) + r.GET("/dashboard/map/contracts", middlewares.ContractCheckMiddleware(), middlewares.AuthMiddleware(), controllers.GetLatestContracts) + r.GET("/dashboard/map/device_data", middlewares.ContractCheckMiddleware(), middlewares.AuthMiddleware(), controllers.GetDeviceData) // Invoices r.GET("/invoices", middlewares.AuthMiddleware(), controllers.GetInvoices) diff --git a/shared/database.go b/shared/database.go index 1ddf02d..12a66cc 100644 --- a/shared/database.go +++ b/shared/database.go @@ -32,6 +32,7 @@ func Init() error { log.Println("Error initializing the database: ", err) return err } + //TODO AUTOMIGRATE models once we have them db.AutoMigrate(&models.User{}, &models.Company{}, &models.Device{}, &models.DeviceInfo{}, &models.Contract{}, &models.ContractInfo{},