diff --git a/backend/__init__.py b/backend/__init__.py index e69de29..9e0d95f 100644 --- a/backend/__init__.py +++ b/backend/__init__.py @@ -0,0 +1,3 @@ +from .celery import app as celery_app + +__all__ = ('celery_app',) \ No newline at end of file diff --git a/backend/accounts/tasks.py b/backend/accounts/tasks.py new file mode 100644 index 0000000..56cf39a --- /dev/null +++ b/backend/accounts/tasks.py @@ -0,0 +1,52 @@ +from celery import shared_task +from backend.core.models import Organization, Document, Risk, Control, DocumentRiskControl +from backend.core.utils import get_top_risk, get_controls_for_risk +from django.shortcuts import get_object_or_404, render + + +@shared_task +def create_document_for_organization(confirmation_email): + + organization = get_object_or_404(Organization, email=confirmation_email) + + top_risk_ids = get_top_risk(organization) + top_risks = Risk.objects.filter(risk_id__in=top_risk_ids) + organization.risks.set(top_risks) + + document = Document.objects.create(organization=organization) + document.add_segment('h1', "Top 10 Risks Identified") + + risk_content = "\n\n".join([ + f"Risk: {risk.risk_id} - {risk.risk_name} \n" + f"Category: {risk.category}\n" + f"Primary Impact: {risk.primary_impact} \n" + f"Secondary Impact: {risk.secondary_impact}\n" + f"Tertiary Impact: {risk.tretiary_impact} \n" + f"Detection Difficulty: {risk.detection_difficulty} \n" + f"Recovery Complexity: {risk.recovery_complexity} \n" + f"Business Impact Severity: {risk.businnes_impact_severity}\n" + for risk in top_risks + ]) + document.add_segment('body', f"Identified Risks: \n\n{risk_content}") + + controls_content = "Mitigation Controls:\n\n" + + for risk in top_risks: + controls_content += f"Risk: {risk.risk_id} - {risk.risk_name}\n" + + selected_controls = get_controls_for_risk(risk ,organization=organization) + + for control_id, weight in selected_controls: + control = Control.objects.filter(id=control_id).first() + if control: + DocumentRiskControl.objects.create( + document=document, + risk=risk, + control=control, + weight=weight + ) + controls_content += f" - Control: {control.name} (Impact Weight: {weight}/10)\n" + + controls_content += "\n" + + document.add_segment('body', controls_content) \ No newline at end of file diff --git a/backend/accounts/views.py b/backend/accounts/views.py index ecf9c75..c644ca5 100644 --- a/backend/accounts/views.py +++ b/backend/accounts/views.py @@ -5,7 +5,7 @@ from .models import EmailConfirmation from django.shortcuts import get_object_or_404, render from django.http import HttpResponse from backend.accounts.utils import send_confirmation_email - +from .tasks import create_document_for_organization class SignUpView(CreateView): form_class = SignupForm @@ -13,12 +13,16 @@ class SignUpView(CreateView): template_name = 'accounts/signup.html' -def confirm_email(request,uuid): +def confirm_email(request, uuid): confirmation = get_object_or_404(EmailConfirmation, uuid=uuid) if confirmation.is_expired(): - return render(request,'confirmation_expired.html', {'email': confirmation.email}) + return render(request, 'confirmation_expired.html', {'email': confirmation.email}) + task = create_document_for_organization.delay(confirmation.email) + print(f"Task ID: {task.id}") + + return HttpResponse("Email is confirmed") def resend_confirmation(request,email): diff --git a/backend/celery.py b/backend/celery.py new file mode 100644 index 0000000..4b45946 --- /dev/null +++ b/backend/celery.py @@ -0,0 +1,15 @@ +# backend/celery.py +import os +from celery import Celery + +# Set the default Django settings module +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'backend.settings') + +# Create the Celery app +app = Celery('backend') + +# Load configuration from Django settings +app.config_from_object('django.conf:settings', namespace='CELERY') + +# Discover tasks in all installed apps +app.autodiscover_tasks() \ No newline at end of file diff --git a/backend/core/admin.py b/backend/core/admin.py index ead9a41..87beba0 100644 --- a/backend/core/admin.py +++ b/backend/core/admin.py @@ -1,5 +1,5 @@ from django.contrib import admin -from .models import Document, DocumentSegment, Organization, Risk, Control, DocumentTemplate +from .models import Document, DocumentSegment, Organization, Risk, Control, DocumentTemplate, DocumentRiskControl from django.urls import reverse from django.utils.html import format_html @@ -38,6 +38,8 @@ class RiskAdmin(admin.ModelAdmin): class ControlAdmin(admin.ModelAdmin): list_display = ('id', 'name') +class DocumentRiskControlAdmin(admin.ModelAdmin): + list_display = ('document', 'risk', 'control', 'weight') admin.site.register(Document, DocumentAdmin) @@ -45,3 +47,4 @@ admin.site.register(Organization, OrganizationAdmin) admin.site.register(Risk ,RiskAdmin) admin.site.register(Control, ControlAdmin) admin.site.register(DocumentTemplate, DocumentTemplateAdmin) +admin.site.register(DocumentRiskControl, DocumentRiskControlAdmin) diff --git a/backend/core/migrations/0008_documentriskcontrol.py b/backend/core/migrations/0008_documentriskcontrol.py new file mode 100644 index 0000000..b993d7e --- /dev/null +++ b/backend/core/migrations/0008_documentriskcontrol.py @@ -0,0 +1,27 @@ +# Generated by Django 5.1.3 on 2025-02-14 13:09 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('core', '0007_rename_safeguard_control_name_and_more'), + ] + + operations = [ + migrations.CreateModel( + name='DocumentRiskControl', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('weight', models.IntegerField()), + ('control', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='core.control')), + ('document', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='core.document')), + ('risk', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='core.risk')), + ], + options={ + 'unique_together': {('document', 'risk', 'control')}, + }, + ), + ] diff --git a/backend/core/models.py b/backend/core/models.py index dd89a01..6dc5587 100644 --- a/backend/core/models.py +++ b/backend/core/models.py @@ -44,28 +44,28 @@ class CreatedBy(models.Model): class Organization(models.Model): - name = models.CharField(max_length=255) - email = models.EmailField() - employee_headcount = models.CharField(max_length=20) - annual_revenue = models.CharField(max_length=20) - critical_applications = models.CharField(max_length=20) - compliance_frameworks = models.JSONField() # Stores selected compliance frameworks as a list - industry_sector = models.CharField(max_length=255) - it_dependency = models.IntegerField() - data_sensitivity = models.CharField(max_length=20) - network_infrastructure = models.CharField(max_length=20) - remote_workforce_percentage = models.CharField(max_length=20) - third_party_vendor_access = models.CharField(max_length=20) - internal_software_development = models.CharField(max_length=20) - geographic_scope = models.CharField(max_length=20, null=True, blank=True) - customer_base = models.CharField(max_length=20, null=True, blank=True) - customer_type = models.CharField(max_length=20, null=True, blank=True) - product_portfolio = models.CharField(max_length=20, null=True, blank=True) - supplier_base = models.CharField(max_length=20, null=True, blank=True) - it_infrastructure = models.JSONField(null=True, blank=True) # Stores selected IT infrastructure types as a list - intellectual_property = models.JSONField(null=True, blank=True) # Stores selected IP protection types as a list - sensitive_data = models.JSONField(null=True, blank=True) # Stores selected sensitive data types as a list - integration_level = models.CharField(max_length=20, null=True, blank=True) + name = models.CharField(max_length=255, help_text="What is the name of your organization?") + email = models.EmailField(help_text="What is your email?") + employee_headcount = models.CharField(max_length=20, help_text="What is your organization's current employee headcount?") + annual_revenue = models.CharField(max_length=20, help_text="What is your organization's annual revenue range?") + critical_applications = models.CharField(max_length=20, help_text="How many critical business applications do your employees use daily?") + compliance_frameworks = models.JSONField(help_text="Which regulatory frameworks is your organization required to comply with?") # Stores selected compliance frameworks as a list + industry_sector = models.CharField(max_length=255,help_text="What is your primary industry sector?") + it_dependency = models.IntegerField(help_text="On a scale from 1-10, how dependent is your business operations on technology?") + data_sensitivity = models.CharField(max_length=20, help_text="What level of sensitive data does your organization process?") + network_infrastructure = models.CharField(max_length=20, help_text="What best describes your organization's network infrastructure model?") + remote_workforce_percentage = models.CharField(max_length=20, help_text="What percentage of your workforce operates remotely?") + third_party_vendor_access = models.CharField(max_length=20, help_text="How many third-party vendors have access to your systems?") + internal_software_development = models.CharField(max_length=20, help_text="What is the extent of your internal software development activities?") + geographic_scope = models.CharField(max_length=20, null=True, blank=True, help_text="What is your organization's geographic operational scope?") + customer_base = models.CharField(max_length=20, null=True, blank=True, help_text="How would you characterize your customer base distribution?") + customer_type = models.CharField(max_length=20, null=True, blank=True, help_text="What is your primary customer type?") + product_portfolio = models.CharField(max_length=20, null=True, blank=True, help_text="How diversified is your product/service portfolio?") + supplier_base = models.CharField(max_length=20, null=True, blank=True, help_text="What is your supplier base structure?") + it_infrastructure = models.JSONField(null=True, blank=True, help_text="What is your primary IT infrastructure model?") # Stores selected IT infrastructure types as a list + intellectual_property = models.JSONField(null=True, blank=True, help_text="How does your organization protect and manage intellectual property?") # Stores selected IP protection types as a list + sensitive_data = models.JSONField(null=True, blank=True, help_text="What type of sensitive data does your organization handle?") # Stores selected sensitive data types as a list + integration_level = models.CharField(max_length=20, null=True, blank=True, help_text="How integrated are your critical business systems?") risks = models.ManyToManyField('Risk', related_name='organizations', blank=True) @@ -157,4 +157,13 @@ class Control(models.Model): name = models.CharField(max_length=255) def __str__(self): - return f"{self.id} ({self.name})" \ No newline at end of file + return f"{self.id} ({self.name})" + +class DocumentRiskControl(models.Model): + document = models.ForeignKey(Document, on_delete=models.CASCADE) + risk = models.ForeignKey(Risk, on_delete=models.CASCADE) + control = models.ForeignKey(Control, on_delete=models.CASCADE) + weight = models.IntegerField() + + class Meta: + unique_together = ('document', 'risk', 'control') \ No newline at end of file diff --git a/backend/core/utils.py b/backend/core/utils.py index a572c89..7e164ca 100644 --- a/backend/core/utils.py +++ b/backend/core/utils.py @@ -1,16 +1,19 @@ from openai import OpenAI from django.conf import settings -from .models import Risk +from .models import Risk, Control +import time -def extract_risk_factors(organization): - excluded_fields={"name","email"} +def extract_organization_details(organization): + excluded_fields = {"name", "email"} risk_data = {} for field in organization._meta.get_fields(): if field.name not in excluded_fields and hasattr(organization, field.name): value = getattr(organization, field.name) if value: - risk_data[field.name] = value + help_text = getattr(field, 'help_text', '').strip() + key = help_text if help_text else field.name + risk_data[key] = value return risk_data def get_top_risk(organization): @@ -25,16 +28,21 @@ def get_top_risk(organization): Category: {risk.category} Name: {risk.risk_name} Primary Impact: {risk.primary_impact} + Secondary Impact: {risk.secondary_impact} + Tertiary Impact: {risk.tretiary_impact} + Detection Difficulty: {risk.detection_difficulty} + Recovery Complexity: {risk.recovery_complexity} + Business Impact Severity: {risk.businnes_impact_severity} """) - risk_factors = extract_risk_factors(organization) + organization_details = extract_organization_details(organization) prompt = f""" You are an AI risk assessor. Based on the following company details and list of known risks, identify the 10 most critical risks for this company. Respond only with risk IDs. Company Details: - {risk_factors} + {organization_details} List of Risks: {risk_list} @@ -43,10 +51,134 @@ def get_top_risk(organization): """ response = client.chat.completions.create( - model="gpt-4", + model="gpt-4o-mini", messages=[{"role": "system", "content": prompt}] ) risk_ids = response.choices[0].message.content.strip().split(",") + print(f"Risks: {risk_ids}") return [int(risk_id) for risk_id in risk_ids if risk_id.isdigit()] + +def get_controls_for_risk(risk, organization): + client = OpenAI(api_key=settings.OPENAI_API_KEY) + all_controls = Control.objects.all() + organization_details = extract_organization_details(organization) + control_list = [f"Control ID: {control.id}, Control Name: {control.name}" for control in all_controls] + valid_control_ids = {control.id for control in all_controls} + control_map = {control.id: control.name for control in all_controls} + + def fetch_controls(prompt): + response = client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "system", "content": prompt}] + ) + return response.choices[0].message.content.strip() + + prompt = f""" + You are an expert in cybersecurity risk management. Given the risk "{risk.risk_name}" and its associated organization details "{organization_details}", + your task is to select **exactly 10 unique controls** from the provided list that best mitigate this risk. Each control should be assigned a weight between **1 and 10** based on its effectiveness in reducing the risk. + ### Rules: + 1. **Each control ID must be unique** (no duplicates). + 2. **Only return control IDs and weights** in the exact format below. + 3. **Weights must be between 1 and 10** (1 = low impact, 10 = high impact). + 4. **Do NOT add explanations, descriptions, or extra text.** + 5. **Ensure that control IDs are randomly distributed and diverse across different categories.** + ### Available Controls: + {control_list} + + ### Expected Response Format (STRICTLY FOLLOW THIS FORMAT): + : + : + ### Example Correct Response (NO DUPLICATES): + 12 : 8 + 45 : 7 + ⚠️ **If you provide duplicate control IDs, your response will be rejected. Ensure all control IDs are unique.** + ⚠️ **Follow the response format exactly. Any deviation will be considered invalid.** + """ + + selected_controls = [] + control_ids_seen = set() + + result = fetch_controls(prompt) + + for line in result.split("\n"): + line = line.strip() + parts = line.split(":") + if len(parts) == 2: + control_id_str = parts[0].replace("ID:", "").replace("id:", "").replace("Id:", "").strip() + weight_str = parts[1].strip().replace("Weight:", "").replace("weight:", "").strip() + print(f"Control:{control_id_str} Weight:{weight_str}") + print(f"ControlType: {type(control_id_str)} WeightType: {type(weight_str)}") + + control_id_str = ''.join(filter(str.isdigit, control_id_str)) + weight_str = ''.join(filter(str.isdigit, weight_str)) + + if control_id_str and weight_str: + try: + control_id = int(control_id_str) + weight = int(weight_str) + + if control_id in valid_control_ids and 1 <= weight <= 10 and control_id not in control_ids_seen: + selected_controls.append((control_id, weight)) + control_ids_seen.add(control_id) + except ValueError: + continue + + if len(selected_controls) == 10: + return selected_controls + + while len(selected_controls) < 10: + missing_count = 10 - len(selected_controls) + remaining_controls = valid_control_ids - control_ids_seen + remaining_controls_list = [f"Control ID: {cid}, Control Name: {control_map[cid]}" for cid in remaining_controls] + + retry_prompt = f""" + You are an expert in cybersecurity risk management. Given the risk "{risk.risk_name}" and its associated organization details "{organization_details}", + your task is to select **exactly {missing_count} unique controls** from the provided list that best mitigate this risk. Each control should be assigned a weight between **1 and 10** based on its effectiveness in reducing the risk. + ### Rules: + 1. **Each control ID must be unique** (no duplicates). + 2. **Only return control IDs and weights** in the exact format below. + 3. **Weights must be between 1 and 10** (1 = low impact, 10 = high impact). + 4. **Do NOT add explanations, descriptions, or extra text.** + 5. **Ensure that control IDs are randomly distributed and diverse across different categories.** + ### Available Controls: + {remaining_controls_list} + + ### Expected Response Format (STRICTLY FOLLOW THIS FORMAT): + : + : + ### Example Correct Response (NO DUPLICATES): + 12 : 8 + 45 : 7 + ⚠️ **If you provide duplicate control IDs, your response will be rejected. Ensure all control IDs are unique.** + ⚠️ **Follow the response format exactly. Any deviation will be considered invalid.** + """ + + result = fetch_controls(retry_prompt) + for line in result.split("\n"): + line = line.strip() + parts = line.split(":") + if len(parts) == 2: + control_id_str = parts[0].replace("ID:", "").replace("id:", "").replace("Id:", "").strip() + weight_str = parts[1].strip().replace("Weight:", "").replace("weight:", "").strip() + print(f"Control:{control_id} Weight:{weight_str}") + print(f"ControlType: {type(control_id)} WeightType: {type(weight_str)}") + + control_id_str = ''.join(filter(str.isdigit, control_id_str)) + weight_str = ''.join(filter(str.isdigit, weight_str)) + + if control_id_str and weight_str: + try: + control_id = int(control_id_str) + weight = int(weight_str) + + if control_id in valid_control_ids and 1 <= weight <= 10 and control_id not in control_ids_seen: + selected_controls.append((control_id, weight)) + control_ids_seen.add(control_id) + except ValueError: + continue + + if not remaining_controls: + break + return selected_controls if len(selected_controls) == 10 else [] diff --git a/backend/core/views.py b/backend/core/views.py index 7b4ecc7..1eed1e7 100644 --- a/backend/core/views.py +++ b/backend/core/views.py @@ -24,34 +24,11 @@ def signup(request): if request.method == 'POST': form = OrganizationForm(request.POST) if form.is_valid(): - organization = form.save() - top_risk_ids = get_top_risk(organization) - top_risks = Risk.objects.filter(risk_id__in = top_risk_ids) - - organization.risks.set(top_risks) - - document = Document.objects.create(organization=organization) - document.add_segment('h1', "Top 10 Risk Identified") - - risk_content = "\n\n".join([ - f"Risk: {risk.risk_id} : {risk.risk_name} \n" - f"Category: {risk.category}\n" - f"Primary Impaact: {risk.primary_impact} \n" - f"Secondary Impact: {risk.secondary_impact}\n" - f"Tertiary Impact: {risk.tretiary_impact} \n" - f"Detection Difficulty: {risk.detection_difficulty} \n" - f"Recovery Complexity: {risk.recovery_complexity} \n" - f"Business Impact Severity: {risk.businnes_impact_severity}\n" - for risk in top_risks - ]) - - document.add_segment('body',f"Identified Risks: \n\n{risk_content}") - + form.save() send_confirmation_email(form.data['email']) return render(request, 'thankyou.html', { 'email': form.data['email'], - 'document_link': reverse('core:document', args=[str(document.id)]) }) else: logging.error(form.errors) diff --git a/backend/settings.py b/backend/settings.py index f01cf3c..fee0daf 100644 --- a/backend/settings.py +++ b/backend/settings.py @@ -51,6 +51,8 @@ INSTALLED_APPS = [ 'django_extensions', 'widget_tweaks', 'django_seed', + 'django_celery_results', + # my apps 'backend.accounts.apps.AccountsConfig', 'backend.core.apps.CoreConfig', @@ -156,3 +158,11 @@ DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' LOGIN_URL = '/admin/login/' LOGIN_REDIRECT_URL = 'core:index' # LOGOUT_REDIRECT_URL = 'core:index' + +# Celery Configuration +CELERY_BROKER_URL = 'redis://localhost:6380/0' +CELERY_RESULT_BACKEND = 'django-db' # Store task results in the Django database +CELERY_ACCEPT_CONTENT = ['json'] +CELERY_TASK_SERIALIZER = 'json' +CELERY_TIMEZONE = 'UTC' +