Files
old-riskletpy/backend/core/utils.py

183 lines
7.3 KiB
Python

import re
import time

from django.conf import settings
from openai import OpenAI

from .models import Control, Risk
def extract_organization_details(organization):
    """Collect an organization's non-empty field values for use in LLM prompts.

    Iterates over ``organization._meta.get_fields()`` and skips:

    * the ``name`` and ``email`` fields;
    * fields the instance has no attribute for (e.g. reverse relations whose
      accessor name differs from the field name);
    * falsy values (``None``, ``""``, ``0``, ``False``, empty collections).

    Multi-valued relations (fields flagged ``many_to_many`` or
    ``one_to_many``) are read through their manager's ``.all()`` and stored
    as a list of ``str()`` of the related objects. The manager object itself
    is always truthy, and its text form does not list the related rows.

    Args:
        organization: model instance exposing ``_meta.get_fields()``.

    Returns:
        dict: maps each field's stripped ``help_text`` (or the field ``name``
        when the help text is empty or missing) to the field's value.
    """
    excluded_fields = {"name", "email"}
    details = {}
    for field in organization._meta.get_fields():
        if field.name in excluded_fields or not hasattr(organization, field.name):
            continue
        value = getattr(organization, field.name)
        # Relation objects may lack these flags; treat a missing flag as False.
        if getattr(field, "many_to_many", False) or getattr(field, "one_to_many", False):
            value = [str(related) for related in value.all()]
        if not value:
            continue
        # help_text may be a lazy translation proxy or None; str() yields a plain string.
        help_text = str(getattr(field, "help_text", "") or "").strip()
        details[help_text or field.name] = value
    return details
def _format_risk(risk):
    """Render one ``Risk`` as a block of "Label: value" lines for the prompt."""
    # ``tretiary_impact`` and ``businnes_impact_severity`` are the attribute
    # names the original code reads; keep them in sync with the Risk model.
    return "\n".join(
        [
            f"Risk ID: {risk.risk_id}",
            f"Category: {risk.category}",
            f"Name: {risk.risk_name}",
            f"Primary Impact: {risk.primary_impact}",
            f"Secondary Impact: {risk.secondary_impact}",
            f"Tertiary Impact: {risk.tretiary_impact}",
            f"Detection Difficulty: {risk.detection_difficulty}",
            f"Recovery Complexity: {risk.recovery_complexity}",
            f"Business Impact Severity: {risk.businnes_impact_severity}",
        ]
    )


def _parse_risk_ids(text):
    """Return the integer IDs in a comma/whitespace-separated model reply.

    Tokens are split on commas and any whitespace, then stripped of
    surrounding punctuation. For example, "1, 3,7\\n12." gives [1, 3, 7, 12].
    Non-numeric tokens such as "..." are ignored. A repeated ID is kept only
    once, at its first position.
    """
    risk_ids = []
    for token in text.replace(",", " ").split():
        token = token.strip(".;:\"'`()[]")
        # isdecimal() (unlike isdigit()) only accepts characters int() can parse.
        if token.isdecimal():
            risk_id = int(token)
            if risk_id not in risk_ids:
                risk_ids.append(risk_id)
    return risk_ids


def get_top_risk(organization):
    """Ask the LLM (``gpt-4o-mini``) to pick the 10 most critical risks.

    Every ``Risk`` row is sent to the model, together with the details
    returned by ``extract_organization_details(organization)``.

    Args:
        organization: the organization being assessed.

    Returns:
        list[int]: the risk IDs parsed from the reply, in the model's order,
        without duplicates. The list is empty if the reply has no content.
        It can hold more or fewer than 10 IDs if the model ignores the
        instructions. IDs are not checked against the database.
    """
    client = OpenAI(api_key=settings.OPENAI_API_KEY)
    # Join the entries; interpolating a list would embed its repr
    # (quoted strings with literal "\n" escapes) in the prompt.
    risks_text = "\n\n".join(_format_risk(risk) for risk in Risk.objects.all())
    organization_details = extract_organization_details(organization)
    details_text = "\n".join(f"- {key}: {value}" for key, value in organization_details.items())
    prompt = f"""
You are an AI risk assessor. Based on the following company details and list of known risks,
identify the 10 most critical risks for this company. Respond only with risk IDs.
Company Details:
{details_text}
List of Risks:
{risks_text}
Provide only the 10 most critical risk IDs in a simple comma-separated format, e.g "1,3,7,12,..."
"""
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "system", "content": prompt}],
    )
    # content can be None (e.g. a refusal); treat that as an empty reply.
    content = (response.choices[0].message.content or "").strip()
    print(f"Risks: {content}")
    return _parse_risk_ids(content)
def _parse_control_weights(text):
    """Extract ``(control_id, weight)`` pairs from a model reply, in order.

    Each line is split at its last ":". The control ID is the last integer
    before that colon; the weight is the first integer after it. This
    accepts the requested "<id> : <weight>" form and also variants such as
    "Control ID: 12, Weight: 8", "1. 12 : 8" and "23 : 7,". Lines that do
    not have this shape (code fences, headings, prose) are skipped. Range
    and uniqueness checks are left to the caller.
    """
    pairs = []
    for line in text.splitlines():
        before, colon, after = line.rpartition(":")
        if not colon:
            continue
        id_candidates = re.findall(r"\d+", before)
        weight_match = re.search(r"\d+", after)
        if id_candidates and weight_match:
            pairs.append((int(id_candidates[-1]), int(weight_match.group())))
    return pairs


def _build_controls_prompt(risk_name, details_text, controls_text, count):
    """Return the first-attempt prompt asking for ``count`` weighted controls."""
    return f"""
You are an expert in cybersecurity risk management. Given the risk "{risk_name}" and the organization details below,
your task is to select **exactly {count} unique controls** from the provided list that best mitigate this risk. Each control should be assigned a weight between **1 and 10** based on its effectiveness in reducing the risk.
### Organization Details:
{details_text}
### Rules:
1. **Each control ID must be unique** (no duplicates).
2. **Only return control IDs and weights** in the exact format below.
3. **Weights must be between 1 and 10** (1 = low impact, 10 = high impact).
4. **Do NOT add explanations, descriptions, or extra text.**
5. **Ensure that control IDs are randomly distributed and diverse across different categories.**
### Available Controls:
{controls_text}
### Expected Response Format (STRICTLY FOLLOW THIS FORMAT):
```
<control_id> : <weight>
<control_id> : <weight>
```
### Example Correct Response (NO DUPLICATES):
```
12 : 8
45 : 7
```
⚠️ **If you provide duplicate control IDs, your response will be rejected. Ensure all control IDs are unique.**
⚠️ **Follow the response format exactly. Any deviation will be considered invalid.**
"""


def _build_missing_controls_prompt(risk_name, details_text, controls_text, already_selected, count):
    """Return the follow-up prompt asking for ``count`` more controls from ``controls_text``."""
    return f"""
You are an expert in cybersecurity risk management. Previously, you selected {already_selected} controls for the risk "{risk_name}"
and the organization details below.
Now, your task is to select **exactly {count} additional unique controls** from the remaining list that best mitigate this risk.
Each selected control should be assigned a weight between **1 and 10**, based on its effectiveness in reducing the risk.
### Organization Details:
{details_text}
### Rules:
1. **Each control ID must be unique** (no duplicates with the previously selected controls).
2. **Only return control IDs and weights** in the exact format below.
3. **Weights must be between 1 and 10** (1 = low impact, 10 = high impact).
4. **Do NOT add explanations, descriptions, or extra text.**
5. **Ensure that control IDs are randomly distributed and diverse across different categories.**
### Remaining Available Controls:
{controls_text}
### Expected Response Format (STRICTLY FOLLOW THIS FORMAT):
```
<control_id> : <weight>
<control_id> : <weight>
```
### Example Correct Response (NO DUPLICATES):
```
23 : 7
45 : 3
```
**If you provide duplicate control IDs or controls outside the remaining list, your response will be rejected.**
**Follow the response format exactly. Any deviation will be considered invalid.**
"""


def get_controls_for_risk(risk, organization):
    """Ask the LLM (``gpt-4o-mini``) for 10 distinct controls, weighted 1-10, for *risk*.

    The first request lists every ``Control``. Valid picks from each reply
    are accumulated across attempts. While fewer than 10 have been
    collected, the next request lists only the not-yet-selected controls and
    asks for the missing number.

    A ``(control_id, weight)`` pair is accepted when all of these hold:

    * the ID belongs to an existing ``Control``;
    * it has not been selected already;
    * the weight is in 1..10.

    Pairs beyond the 10th accepted one are ignored. Up to 10 requests are
    made, with a 2-second pause between them.

    Args:
        risk: object whose ``risk_name`` is quoted in the prompt.
        organization: passed to ``extract_organization_details``.

    Returns:
        list[tuple[int, int]]: exactly 10 ``(control_id, weight)`` pairs in
        the order they were accepted. The list is empty if fewer than 10
        controls exist or if 10 valid picks were not collected in time.
    """
    required = 10
    max_attempts = 10
    client = OpenAI(api_key=settings.OPENAI_API_KEY)
    # One query; the dict preserves queryset order for the prompt listing.
    control_names = {control.id: control.name for control in Control.objects.all()}
    if len(control_names) < required:
        # 10 distinct valid IDs cannot exist, so every attempt would fail.
        print(f"Only {len(control_names)} controls available; {required} are required.")
        return []
    organization_details = extract_organization_details(organization)
    details_text = "\n".join(f"- {key}: {value}" for key, value in organization_details.items())
    selected = {}  # control_id -> weight, in acceptance order

    for attempt in range(max_attempts):
        # Keep picks across attempts: the follow-up prompt only asks for the
        # missing controls, so its reply alone never contains all of them.
        missing_count = required - len(selected)
        controls_text = "\n".join(
            f"Control ID: {cid}, Control Name: {name}"
            for cid, name in control_names.items()
            if cid not in selected
        )
        if selected:
            prompt = _build_missing_controls_prompt(
                risk.risk_name, details_text, controls_text, len(selected), missing_count
            )
        else:
            prompt = _build_controls_prompt(risk.risk_name, details_text, controls_text, required)

        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "system", "content": prompt}],
        )
        # content can be None (e.g. a refusal); treat that as an empty reply.
        result = (response.choices[0].message.content or "").strip()
        print(f"AI Response (Attempt {attempt+1}):\n{result}")

        for control_id, weight in _parse_control_weights(result):
            if len(selected) == required:
                break
            if control_id in control_names and control_id not in selected and 1 <= weight <= 10:
                selected[control_id] = weight

        if len(selected) == required:
            return list(selected.items())
        print(f"Received {len(selected)} of {required} valid controls. Retrying for missing ones...\n")
        if attempt < max_attempts - 1:
            time.sleep(2)

    print("Failed to get a valid response after multiple attempts.")
    return []