// Package contract contains the database queries and updates for contracts.
package contract

import (
	"errors"
	"fmt"
	"log"
	"net/http"
	"time"

	"github.com/jinzhu/gorm"
	"github.com/lib/pq"
	"gitlab.com/pactual1/backend/database/device"
	"gitlab.com/pactual1/backend/models"
	"gitlab.com/pactual1/backend/shared"
)

// GetContracts returns one page of contracts matching the given filters,
// together with the total number of matching rows (ignoring limit/offset)
// and an HTTP status code describing the outcome.
//
// Every filter is optional: empty slices, empty strings and nil/zero times
// are skipped. String filters are substring matches (LIKE %value%). The
// deviceIDs filter matches contracts whose device_ids array overlaps the
// given IDs. Results are ordered by created_at descending.
func GetContracts(status []string, companyName string, companyAddress string, companyEmail string, companyPhone string, startTime *time.Time, endTime *time.Time, contractName string, deviceIDs []int64, contractIDs []int64, dateCreated *time.Time, limit, offset int) ([]models.Contract, int64, int, error) {
	var contracts []models.Contract

	// Columns selected for the page query; buyer_name comes from the joined
	// companies table.
	const customFields = "distinct contracts.*, array_length(contracts.device_ids, 1) as number_of_devices, companies.name as buyer_name"

	// A single filtered query is built once and then reused for both the
	// count and the page fetch. This is safe with jinzhu/gorm because every
	// chain method returns a clone rather than mutating its receiver.
	query := shared.GetDb().Joins("left join companies on companies.id = contracts.buyer_id")

	// Search by statuses.
	if len(status) > 0 {
		query = query.Where("contracts.status IN (?)", status)
	}

	// Search by IDs.
	if len(contractIDs) > 0 {
		query = query.Where("contracts.id IN (?)", contractIDs)
	}

	// Substring searches on company fields and on the contract name. The
	// column names are compile-time constants; only the values are bound.
	likeFilters := []struct {
		column string
		value  string
	}{
		{"companies.name", companyName},
		{"companies.address", companyAddress},
		{"companies.email", companyEmail},
		{"companies.phone", companyPhone},
		{"contracts.name", contractName},
	}
	for _, f := range likeFilters {
		if f.value != "" {
			query = query.Where(f.column+" LIKE ?", "%"+f.value+"%")
		}
	}

	// Search by start time and end time.
	if startTime != nil && !startTime.IsZero() {
		query = query.Where("start_time >= ?", startTime)
	}
	// Each bound is checked against its own pointer: testing startTime here
	// would panic when only endTime is supplied.
	if endTime != nil && !endTime.IsZero() {
		query = query.Where("end_time <= ?", endTime)
	}
	if dateCreated != nil && !dateCreated.IsZero() {
		query = query.Where("contracts.created_at = ?", dateCreated)
	}

	// Search by device IDs (array overlap).
	if len(deviceIDs) > 0 {
		query = query.Where("device_ids && ?", pq.Array(deviceIDs))
	}

	// Fetch total count of filtered records.
	var total int64
	if err := query.Model(&models.Contract{}).Count(&total).Error; err != nil {
		log.Printf("GetContracts Error: Database error: %v", err)
		return contracts, total, http.StatusInternalServerError, err
	}

	// Fetch the requested page with the custom fields.
	err := query.Select(customFields).
		Order("created_at desc").
		Limit(limit).
		Offset(offset).
		Find(&contracts).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			log.Printf("GetContracts Error: No contracts found: %v", err)
			return contracts, total, http.StatusNotFound, err
		}
		log.Printf("GetContracts Error: Database error: %v", err)
		return contracts, total, http.StatusInternalServerError, err
	}

	return contracts, total, http.StatusOK, nil
}

// UpdateContract persists the non-zero fields of contract and, when
// contract.DeviceIDs is set, points those devices at this contract. Both
// writes happen in one transaction. On success the stored contract is
// re-read and returned; otherwise the input contract is returned together
// with an HTTP status code and the error.
func UpdateContract(contract models.Contract) (models.Contract, int, error) {
	var devices []models.Device
	var status int
	var err error

	if contract.DeviceIDs != nil {
		devices, status, err = device.GetDevicesByID(contract.DeviceIDs)
		if err != nil {
			return contract, status, err
		}

		status, err = validateContractDevices(contract.ID, devices)
		if err != nil {
			log.Printf("UpdateContract Error: Invalid Device ID: %v", err)
			return contract, status, err
		}
	}

	err = shared.GetDb().Transaction(func(tx *gorm.DB) error {
		// Update contract. gorm's documented usage passes the model by pointer.
		if err := tx.Model(&contract).Updates(contract).Error; err != nil {
			log.Printf("UpdateContract Error: Could not update contract: %v", err)
			return err
		}

		if devices != nil {
			// Update devices so they reference this contract.
			if err := tx.Model(&models.Device{}).Where("id IN (?)", contract.DeviceIDs).Updates(models.Device{CurrentContractID: &contract.ID}).Error; err != nil {
				log.Printf("UpdateContract Error: Could not update devices: %v", err)
				return err
			}
		}

		// Returning nil commits the whole transaction.
		return nil
	})
	if err != nil {
		log.Printf("UpdateContract Error: Could not update contract: %v", err)
		return contract, http.StatusInternalServerError, err
	}

	return GetContractByID(uint64(contract.ID))
}

// validateContractDevices checks every device that is currently linked to a
// contract other than contractID. It returns http.StatusBadRequest and an
// error naming the first device whose current contract fails the status
// check, the lookup's status and error if that contract cannot be loaded,
// and http.StatusOK otherwise.
func validateContractDevices(contractID uint, devices []models.Device) (int, error) {
	for _, dev := range devices {
		if dev.CurrentContractID == nil || *dev.CurrentContractID == contractID {
			continue
		}

		current, status, err := GetContractByID(uint64(*dev.CurrentContractID))
		if err != nil {
			return status, err
		}

		// NOTE(review): as written this condition is equivalent to
		// "Status == ContractStatusRevoked"; the first clause is redundant.
		// It may have been meant as "!= Executed && != Revoked" — confirm the
		// intended rule before changing it. Behaviour is preserved here.
		if current.Status != models.ContractStatusExecuted && current.Status == models.ContractStatusRevoked {
			return http.StatusBadRequest, fmt.Errorf("device id %d is linked to contract id - %d name %s", dev.ID, current.ID, current.Name)
		}
	}

	return http.StatusOK, nil
}

// GetContractByID loads the contract with the given ID. The query is
// Unscoped, so gorm's soft-delete filter is not applied. It returns
// http.StatusNotFound when no row matches, http.StatusInternalServerError on
// any other database error, and http.StatusOK on success.
func GetContractByID(contractID uint64) (models.Contract, int, error) {
	var contract models.Contract

	if err := shared.GetDb().Unscoped().Where("id = ?", contractID).First(&contract).Error; err != nil {
		log.Printf("GetContractByID Error: Could not fetch contract: %v", err)
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return contract, http.StatusNotFound, err
		}
		return contract, http.StatusInternalServerError, err
	}

	return contract, http.StatusOK, nil
}