package contract import ( "context" "errors" "fmt" "log" "net/http" "strings" "time" "github.com/jinzhu/gorm" "github.com/lib/pq" "gitlab.com/pactual1/backend/config" "gitlab.com/pactual1/backend/database/device" "gitlab.com/pactual1/backend/models" "gitlab.com/pactual1/backend/services/blockchain" "gitlab.com/pactual1/backend/shared" ) func GetContracts(status []string, companyName string, companyAddress string, companyEmail string, companyPhone string, startTime *time.Time, endTime *time.Time, contractName string, deviceIDs []int64, contractIDs []int64, dateCreated *time.Time, limit, offset int) ([]models.Contract, int64, int, error) { var contracts []models.Contract db := shared.GetDb() countDb := db // Define custom fields to be selected, varies based on joined tables customFields := "distinct contracts.*, array_length(contracts.device_ids, 1) as number_of_devices" // Search by Statuses if len(status) > 0 { db = db.Where("contracts.status IN (?)", status) countDb = countDb.Where("contracts.status IN (?)", status) } // Search by IDs if len(contractIDs) > 0 { db = db.Where("contracts.id IN (?)", contractIDs) countDb = countDb.Where("contracts.id IN (?)", contractIDs) } // Search by Company Fields db = db.Joins("left join companies on companies.id = contracts.buyer_id") countDb = countDb.Joins("left join companies on companies.id = contracts.buyer_id") customFields += ", companies.name as buyer_name" if companyName != "" { db = db.Where("companies.name LIKE ?", "%"+companyName+"%") countDb = countDb.Where("companies.name LIKE ?", "%"+companyName+"%") } if companyAddress != "" { db = db.Where("companies.address LIKE ?", "%"+companyAddress+"%") countDb = countDb.Where("companies.address LIKE ?", "%"+companyAddress+"%") } if companyEmail != "" { db = db.Where("companies.email LIKE ?", "%"+companyEmail+"%") countDb = countDb.Where("companies.email LIKE ?", "%"+companyEmail+"%") } if companyPhone != "" { db = db.Where("companies.phone LIKE ?", "%"+companyPhone+"%") countDb = countDb.Where("companies.phone LIKE ?", "%"+companyPhone+"%") } // Search by Contract Name if contractName != "" { db = db.Where("lower(contracts.name) LIKE ?", "%"+strings.ToLower(contractName)+"%") countDb = countDb.Where("lower(contracts.name) LIKE ?", "%"+strings.ToLower(contractName)+"%") } // Search by Start Time and End Time if startTime != nil && !startTime.IsZero() { db = db.Where("start_time >= ?", startTime) countDb = countDb.Where("start_time >= ?", startTime) } if endTime != nil && !startTime.IsZero() { db = db.Where("end_time <= ?", endTime) countDb = countDb.Where("end_time <= ?", endTime) } if dateCreated != nil && !dateCreated.IsZero() { db = db.Where("contracts.created_at = ?", dateCreated) countDb = countDb.Where("contracts.created_at = ?", dateCreated) } // Search by Device IDs if len(deviceIDs) > 0 { db = db.Where("device_ids && ?", pq.Array(deviceIDs)) countDb = countDb.Where("device_ids && ?", pq.Array(deviceIDs)) } // Fetch total count of filtered records var total int64 if err := countDb.Model(&models.Contract{}).Count(&total).Error; err != nil { log.Printf("GetContracts Error: Database error: %v", err) return contracts, total, http.StatusInternalServerError, err } // Fetch contracts with custom fields if err := db.Select(customFields). Order("created_at desc"). Limit(limit). Offset(offset). Find(&contracts).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { log.Printf("GetContracts Error: No contracts found: %v", err) return contracts, total, http.StatusNotFound, err } else { log.Printf("GetContracts Error: Database error: %v", err) return contracts, total, http.StatusInternalServerError, err } } return contracts, total, http.StatusOK, nil } func UpdateContract(contract models.Contract) (models.Contract, int, error) { var devices []models.Device var status int var err error if contract.DeviceIDs != nil { devices, status, err = device.GetDevicesByID(contract.DeviceIDs) if err != nil { return contract, status, err } status, err = validateContractDevices(contract.ID, devices) if err != nil { log.Printf("UpdateContract Error: Invalid Device ID: %v", err) return contract, status, err } } // get old contract to compare updates oldContract, status, err := GetContractByID(contract.ID) if err != nil { return contract, status, err } err = shared.GetDb().Transaction(func(tx *gorm.DB) error { // Update contract if err := tx.Model(contract).Updates(contract).Error; err != nil { log.Printf("UpdateContract Error: Could not update contract: %v", err) return err } if devices != nil { // Update devices if err := tx.Model(models.Device{}).Where("id IN (?)", []int64(contract.DeviceIDs)).Updates(models.Device{CurrentContractID: &contract.ID}).Error; err != nil { log.Printf("UpdateContract Error: Could not update devices: %v", err) return err } } // return nil will commit the whole transaction return nil }) if err != nil { log.Printf("UpdateContract Error: Could not update contract: %v", err) return contract, http.StatusInternalServerError, err } contract, status, err = GetContractByID(contract.ID) if err != nil { return contract, status, err } // Create contract in blockchain only when it is signed if oldContract.Status != contract.Status && contract.Status == models.ContractStatusSigned { err = blockchain.NewService(config.AppConfig.Blockchain).CreateContract(context.Background(), shared.CovertUintToByte32(contract.ID)) if err != nil { log.Printf("UpdateContract Error: Could not create contract in blockchain: %v", err) return contract, http.StatusInternalServerError, err } // Register devices in blockchain when contract is signed for _, device := range devices { err = blockchain.NewService(config.AppConfig.Blockchain).RegisterNewDeviceID(context.Background(), shared.CovertUintToByte32(contract.ID), shared.CovertUintToByte32(device.ID)) if err != nil { log.Printf("UpdateContract Error: Could not register contract device in blockchain: %v", err) return contract, http.StatusInternalServerError, err } } } return contract, status, err } func validateContractDevices(contractID uint, devices []models.Device) (int, error) { for _, device := range devices { if device.CurrentContractID != nil && *device.CurrentContractID != contractID { currentDeviceContract, status, err := GetContractByID(*device.CurrentContractID) if err != nil { return status, err } if currentDeviceContract.Status != models.ContractStatusExecuted && currentDeviceContract.Status == models.ContractStatusRevoked { return http.StatusBadRequest, fmt.Errorf("device id %d is linked to contract id - %d name %s", device.ID, currentDeviceContract.ID, currentDeviceContract.Name) } } } return http.StatusOK, nil } func GetContractByID(contractID uint) (models.Contract, int, error) { // Fetch the contract creation date based on contractID var contract models.Contract if err := shared.GetDb().Unscoped().Where("id = ?", contractID).First(&contract).Error; err != nil { log.Printf("GetContractByID Error: Could not fetch contract: %v", err) return contract, http.StatusInternalServerError, err } // Fetch the product var product models.ProductTemplate if err := shared.GetDb().Unscoped().Where("id = ?", contract.ProductID).First(&product).Error; err != nil { log.Printf("GetContractByID Error: Could not fetch product: %v", err) return contract, http.StatusInternalServerError, err } contract.ProductName = product.Name // Fetch the seller var seller models.Company if err := shared.GetDb().Unscoped().Where("id = ?", contract.SellerID).First(&seller).Error; err != nil { log.Printf("GetContractByID Error: Could not fetch seller: %v", err) return contract, http.StatusInternalServerError, err } contract.SellerName = seller.Name // Fetch the buyer var buyer models.Company if err := shared.GetDb().Unscoped().Where("id = ?", contract.BuyerID).First(&buyer).Error; err != nil { log.Printf("GetContractByID Error: Could not fetch buyer: %v", err) return contract, http.StatusInternalServerError, err } contract.BuyerName = buyer.Name return contract, http.StatusOK, nil }