Files
old-riskletpy/application/controllers/signup.go
2024-11-17 19:41:30 +01:00

118 lines
2.9 KiB
Go

package controllers
import (
	"bytes"
	"html/template"
	"log"
	"net/http"
	"net/url"
	"os"
	"path/filepath"

	"risklet/db"
)
// Signup is the HTTP handler for the signup page. GET renders the signup
// form, POST processes a submitted form, and any other method receives
// 405 Method Not Allowed together with an Allow header listing the
// supported methods.
func Signup(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet:
		handleGet(w, r)
	case http.MethodPost:
		handlePost(w, r)
	default:
		// RFC 9110 requires a 405 response to carry an Allow header naming
		// the methods the resource does support.
		w.Header().Set("Allow", "GET, POST")
		http.Error(w, "Method not allowed.", http.StatusMethodNotAllowed)
	}
}
// handlePost processes a submitted signup form. It stores a company built
// from the form fields, then a basic profile linked to the new company's ID,
// and finally redirects the client to /thankyou with 303 See Other.
//
// If parsing the form or either database insert fails, the error is logged
// and the signup form is rendered again via handleGet. Processing stops at the
// first failure: handleGet has already written the response, so no further
// inserts are attempted and no redirect is sent.
func handlePost(w http.ResponseWriter, r *http.Request) {
	if err := r.ParseForm(); err != nil {
		log.Println("Error processing form: ", err)
		handleGet(w, r)
		return
	}

	company := createCompany(r.PostForm)
	companyID, err := db.InsertCompany(company)
	if err != nil {
		log.Println("Error inserting company into database ", err)
		handleGet(w, r)
		return
	}

	// Only reached once the company insert succeeded, so companyID is valid.
	// NOTE(review): if the profile insert below fails, the company row stays
	// in the database; nothing here rolls it back.
	basicProfile := createBasicProfile(companyID, r.PostForm)
	if _, err := db.InsertBasicProfile(basicProfile); err != nil {
		log.Println("Error inserting into database ", err)
		handleGet(w, r)
		return
	}

	http.Redirect(w, r, "/thankyou", http.StatusSeeOther)
}
// handleGet renders the signup page. It combines the layout
// application/layouts/main.html with the view application/views/signup.html
// (both relative to the process working directory) and executes the
// "main.html" template with nil data.
//
// Responses:
//   - 404 if the view file does not exist or is a directory;
//   - 500 if the view cannot be stat'ed for any other reason, the templates
//     fail to parse, or template execution fails;
//   - 200 with the rendered page otherwise.
//
// The page is rendered into a buffer before anything is sent. That way an
// execution error never leaves a half-written page followed by an error
// message.
func handleGet(w http.ResponseWriter, r *http.Request) {
	lp := filepath.Join("application", "layouts", "main.html")
	fp := filepath.Join("application", "views", "signup.html")

	// Allow scripts only from this page's own origin ('self'). 'unsafe-eval'
	// additionally lets those scripts use eval() and similar string-to-code APIs.
	w.Header().Set("Content-Security-Policy", "script-src 'unsafe-eval' 'self'")

	info, err := os.Stat(fp)
	switch {
	case os.IsNotExist(err):
		// Return a 404 if the template doesn't exist.
		http.NotFound(w, r)
		return
	case err != nil:
		// Any other stat failure (permissions, a path component that is not a
		// directory, ...) leaves info nil, so it must be handled before IsDir.
		log.Print(err.Error())
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	case info.IsDir():
		// Return a 404 if the request is for a directory.
		http.NotFound(w, r)
		return
	}

	tmpl, err := template.ParseFiles(lp, fp)
	if err != nil {
		// Log the detailed error but send the client only a generic message.
		log.Print(err.Error())
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}

	var buf bytes.Buffer
	if err := tmpl.ExecuteTemplate(&buf, "main.html", nil); err != nil {
		log.Print(err.Error())
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}
	if _, err := buf.WriteTo(w); err != nil {
		// Headers are already sent at this point, so logging is all we can do.
		log.Print(err.Error())
	}
}
// createBasicProfile maps a submitted signup form onto a db.BasicProfile
// linked to companyId. Each profile field is copied from the form key of the
// same name; url.Values.Get yields "" for keys that were not submitted.
func createBasicProfile(companyId int, f url.Values) db.BasicProfile {
	field := f.Get

	var profile db.BasicProfile
	profile.CompanyId = companyId

	// Organisation size and sector.
	profile.Employees = field("Employees")
	profile.Revenue = field("Revenue")
	profile.Applications = field("Applications")
	profile.Compliance = field("Compliance")
	profile.Industry = field("Industry")

	// Data and infrastructure.
	profile.ITDependency = field("ITDependency")
	profile.DataSensitivity = field("DataSensitivity")
	profile.DataVolume = field("DataVolume")
	profile.NetworkSegmentation = field("NetworkSegmentation")
	profile.LegacySystems = field("LegacySystems")
	profile.IoTIntegration = field("IoTIntegration")

	// Access patterns.
	profile.RemoteWork = field("RemoteWork")
	profile.BYOD = field("BYOD")
	profile.VPN = field("VPN")
	profile.API = field("API")
	profile.VendorAccess = field("VendorAccess")
	profile.InternalDev = field("InternalDev")

	return profile
}
// createCompany builds a db.Company from a submitted signup form. Name, Email
// and TaxId are copied from the form keys of the same name. UUID and Password
// are each filled by a separate call to db.GenerateRandomString, with UUID
// generated first.
//
// NOTE(review): whether Password is hashed before it is stored is not visible
// here; confirm in db.InsertCompany.
func createCompany(f url.Values) db.Company {
	var c db.Company
	c.UUID = db.GenerateRandomString()
	c.Name = f.Get("Name")
	c.Email = f.Get("Email")
	c.TaxId = f.Get("TaxId")
	c.Password = db.GenerateRandomString()
	return c
}