package controllers import ( "html/template" "log" "net/http" "net/url" "os" "path/filepath" "risklet/db" ) func Signup(w http.ResponseWriter, r *http.Request) { if r.Method == "GET" { handleGet(w, r) } else if r.Method == "POST" { handlePost(w, r) } else { http.Error(w, "Method not allowed.", http.StatusMethodNotAllowed) return } } func handlePost(w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { log.Println("Error processing form: ", err) handleGet(w, r) } company := createCompany(r.PostForm) companyId, err := db.InsertCompany(company) if err != nil { log.Println("Error inserting company into database ", err) handleGet(w, r) } basicProfile := createBasicProfile(companyId, r.PostForm) _, err = db.InsertBasicProfile(basicProfile) if err != nil { log.Println("Error inserting into database ", err) handleGet(w, r) } } func handleGet(w http.ResponseWriter, r *http.Request) { lp := filepath.Join("application", "layouts", "main.html") fp := filepath.Join("application", "views", "signup.html") log.Println("Hitting Signup") // Return a 404 if the template doesn't exist info, err := os.Stat(fp) if err != nil { if os.IsNotExist(err) { http.NotFound(w, r) return } } // Return a 404 if the request is for a directory if info.IsDir() { http.NotFound(w, r) return } tmpl, err := template.ParseFiles(lp, fp) if err != nil { // Log the detailed error log.Print(err.Error()) // Return a generic "Internal Server Error" message http.Error(w, http.StatusText(500), 500) return } err = tmpl.ExecuteTemplate(w, "main.html", nil) if err != nil { log.Print(err.Error()) http.Error(w, http.StatusText(500), 500) } } func createBasicProfile(companyId int, f url.Values) db.BasicProfile { return db.BasicProfile{ CompanyId: companyId, Employees: f.Get("Employees"), Revenue: f.Get("Revenue"), Applications: f.Get("Applications"), Compliance: f.Get("Compliance"), Industry: f.Get("Industry"), ITDependency: f.Get("ITDependency"), DataSensitivity: f.Get("DataSensitivity"), DataVolume: f.Get("DataVolume"), NetworkSegmentation: f.Get("NetworkSegmentation"), LegacySystems: f.Get("LegacySystems"), IoTIntegration: f.Get("IoTIntegration"), RemoteWork: f.Get("RemoteWork"), BYOD: f.Get("BYOD"), VPN: f.Get("VPN"), API: f.Get("API"), VendorAccess: f.Get("VendorAccess"), InternalDev: f.Get("InternalDev"), } } func createCompany(f url.Values) db.Company { return db.Company{ UUID: db.GenerateRandomString(), Name: f.Get("Name"), Email: f.Get("Email"), TaxId: f.Get("TaxId"), Password: db.GenerateRandomString(), } }