118 lines
2.9 KiB
Go
118 lines
2.9 KiB
Go
package controllers
|
|
|
|
import (
|
|
"html/template"
|
|
"log"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"risklet/db"
|
|
)
|
|
|
|
func Signup(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method == "GET" {
|
|
handleGet(w, r)
|
|
} else if r.Method == "POST" {
|
|
handlePost(w, r)
|
|
} else {
|
|
http.Error(w, "Method not allowed.", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
}
|
|
|
|
func handlePost(w http.ResponseWriter, r *http.Request) {
|
|
if err := r.ParseForm(); err != nil {
|
|
log.Println("Error processing form: ", err)
|
|
handleGet(w, r)
|
|
}
|
|
company := createCompany(r.PostForm)
|
|
companyId, err := db.InsertCompany(company)
|
|
if err != nil {
|
|
log.Println("Error inserting company into database ", err)
|
|
handleGet(w, r)
|
|
}
|
|
|
|
basicProfile := createBasicProfile(companyId, r.PostForm)
|
|
|
|
_, err = db.InsertBasicProfile(basicProfile)
|
|
if err != nil {
|
|
log.Println("Error inserting into database ", err)
|
|
handleGet(w, r)
|
|
}
|
|
|
|
http.Redirect(w, r, "/thankyou", http.StatusSeeOther)
|
|
}
|
|
|
|
// handleGet renders the signup page.
//
// It combines the layout application/layouts/main.html with the view
// application/views/signup.html and executes the "main.html" template with
// nil data. Paths are relative to the process working directory.
//
// Responses:
//   - 404 if the view file does not exist or is a directory;
//   - 500 if the view cannot be stat'ed for any other reason, or if the
//     templates fail to parse or execute. Details are logged, and only a
//     generic message is sent to the client.
func handleGet(w http.ResponseWriter, r *http.Request) {
	lp := filepath.Join("application", "layouts", "main.html")
	fp := filepath.Join("application", "views", "signup.html")

	// Restrict scripts to same-origin sources. 'unsafe-eval' is also allowed,
	// so those scripts may still use eval() / new Function().
	w.Header().Set("Content-Security-Policy", "script-src 'unsafe-eval' 'self'")

	info, err := os.Stat(fp)
	if err != nil {
		if os.IsNotExist(err) {
			http.NotFound(w, r)
			return
		}
		// Any other stat failure (e.g. permission denied, or a path component
		// that is not a directory) leaves info nil. Without this branch,
		// info.IsDir() below would panic.
		log.Print(err)
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}

	// A directory at the view path is treated as a missing template.
	if info.IsDir() {
		http.NotFound(w, r)
		return
	}

	tmpl, err := template.ParseFiles(lp, fp)
	if err != nil {
		// Log the detailed error, but send only a generic message.
		log.Print(err)
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}

	// NOTE(review): if execution fails after some output has been written,
	// the error text is appended to a partially rendered page. Rendering into
	// a buffer first would avoid that.
	if err := tmpl.ExecuteTemplate(w, "main.html", nil); err != nil {
		log.Print(err)
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
	}
}
|
|
|
|
func createBasicProfile(companyId int, f url.Values) db.BasicProfile {
|
|
return db.BasicProfile{
|
|
CompanyId: companyId,
|
|
Employees: f.Get("Employees"),
|
|
Revenue: f.Get("Revenue"),
|
|
Applications: f.Get("Applications"),
|
|
Compliance: f.Get("Compliance"),
|
|
Industry: f.Get("Industry"),
|
|
ITDependency: f.Get("ITDependency"),
|
|
DataSensitivity: f.Get("DataSensitivity"),
|
|
DataVolume: f.Get("DataVolume"),
|
|
NetworkSegmentation: f.Get("NetworkSegmentation"),
|
|
LegacySystems: f.Get("LegacySystems"),
|
|
IoTIntegration: f.Get("IoTIntegration"),
|
|
RemoteWork: f.Get("RemoteWork"),
|
|
BYOD: f.Get("BYOD"),
|
|
VPN: f.Get("VPN"),
|
|
API: f.Get("API"),
|
|
VendorAccess: f.Get("VendorAccess"),
|
|
InternalDev: f.Get("InternalDev"),
|
|
}
|
|
}
|
|
|
|
func createCompany(f url.Values) db.Company {
|
|
return db.Company{
|
|
UUID: db.GenerateRandomString(),
|
|
Name: f.Get("Name"),
|
|
Email: f.Get("Email"),
|
|
TaxId: f.Get("TaxId"),
|
|
Password: db.GenerateRandomString(),
|
|
}
|
|
|
|
}
|