diff --git a/.gitignore b/.gitignore index 782015b..41e6063 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,4 @@ _testmain.go vegeta vegeta.test +targets.txt diff --git a/.targets.txt b/.targets.txt new file mode 100755 index 0000000..ff4c478 --- /dev/null +++ b/.targets.txt @@ -0,0 +1,2 @@ +HEAD http://lolcathost:9194 + diff --git a/main.go b/main.go index de69f52..e540acf 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "flag" + "fmt" vegeta "github.com/tsenart/vegeta/lib" "io" "log" @@ -30,58 +31,76 @@ func main() { return } - if *rate == 0 { - log.Fatal("rate can't be zero") - } - - targets, err := vegeta.NewTargetsFromFile(*targetsf) - if err != nil { + if err := run(*rate, *duration, *targetsf, *ordering, *reporter, *output); err != nil { log.Fatal(err) } +} - switch *ordering { - case "random": - targets.Shuffle(time.Now().UnixNano()) - case "sequential": - break - default: - log.Fatalf("Unknown ordering %s", *ordering) +var ( + errRatePrefix = "Rate: " + errDurationPrefix = "Duration: " + errOutputFilePrefix = "Output file: " + errTargetsFilePrefix = "Targets file: " + errOrderingPrefix = "Ordering: " + errReportingPrefix = "Reporting: " +) + +// attack is an utility function that validates the attack arguments, sets up the +// required resources, launches the attack and reports the results +func run(rate uint64, duration time.Duration, targetsf, ordering, reporter, output string) error { + if rate == 0 { + return fmt.Errorf(errRatePrefix + "can't be zero") } - if *duration == 0 { - log.Fatal("Duration provided is invalid") - } - - var rep vegeta.Reporter - switch *reporter { - case "text": - rep = vegeta.NewTextReporter() - case "plot:timings": - rep = vegeta.NewTimingsPlotReporter() - default: - log.Println("Reporter provided is not supported. using text") - rep = vegeta.NewTextReporter() + if duration == 0 { + return fmt.Errorf(errDurationPrefix + "can't be zero") } var out io.Writer - switch *output { + switch output { case "stdout": out = os.Stdout default: - file, err := os.Create(*output) + file, err := os.Create(output) if err != nil { - log.Fatalf("Couldn't open `%s` for writing report: %s", *output, err) + return fmt.Errorf(errOutputFilePrefix+"(%s): %s", output, err) } defer file.Close() out = file } - log.Printf("Vegeta is attacking %d targets in %s order for %s...\n", len(targets), *ordering, *duration) - vegeta.Attack(targets, *rate, *duration, rep) + var rep vegeta.Reporter + switch reporter { + case "text": + rep = vegeta.NewTextReporter() + case "plot:timings": + rep = vegeta.NewTimingsPlotReporter() + default: + log.Println("Reporter provided is not supported. Using text") + rep = vegeta.NewTextReporter() + } + + targets, err := vegeta.NewTargetsFromFile(targetsf) + if err != nil { + return fmt.Errorf(errTargetsFilePrefix+"(%s): %s", targetsf, err) + } + + switch ordering { + case "random": + targets.Shuffle(time.Now().UnixNano()) + case "sequential": + break + default: + return fmt.Errorf(errOrderingPrefix+"`%s` is invalid", ordering) + } + + log.Printf("Vegeta is attacking %d targets in %s order for %s...\n", len(targets), ordering, duration) + vegeta.Attack(targets, rate, duration, rep) log.Println("Done!") - log.Printf("Writing report to '%s'...", *output) + log.Printf("Writing report to '%s'...", output) if err = rep.Report(out); err != nil { - log.Printf("Failed to report: %s", err) + return fmt.Errorf(errReportingPrefix+"%s", err) } + return nil } diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..24ec69a --- /dev/null +++ b/main_test.go @@ -0,0 +1,102 @@ +package main + +import ( + "io/ioutil" + "log" + "strings" + "testing" + "time" +) + +func init() { + // Discard default log output + log.SetOutput(ioutil.Discard) +} + +func TestRateValidation(t *testing.T) { + rate, duration, targetsf, ordering, reporter, output := defaultArguments() + rate = 0 + + err := run(rate, duration, targetsf, ordering, reporter, output) + if err == nil || (err != nil && !strings.HasPrefix(err.Error(), errRatePrefix)) { + t.Errorf("Rate 0 shouldn't be valid: %s", err) + } +} + +func TestDurationValidation(t *testing.T) { + rate, duration, targetsf, ordering, reporter, output := defaultArguments() + duration = 0 + + err := run(rate, duration, targetsf, ordering, reporter, output) + if err == nil || (err != nil && !strings.HasPrefix(err.Error(), errDurationPrefix)) { + t.Errorf("Duration 0 shouldn't be valid: %s", err) + } +} + +func TestOutputValidation(t *testing.T) { + rate, duration, targetsf, ordering, reporter, _ := defaultArguments() + + // Good cases + for _, output := range []string{"stdout", "/dev/null"} { + err := run(rate, duration, targetsf, ordering, reporter, output) + if err != nil { + t.Errorf("Output file `%s` should be valid: %s", output, err) + } + } + + // Bad case + badOutput := "" + err := run(rate, duration, targetsf, ordering, reporter, badOutput) + if err == nil || (err != nil && !strings.HasPrefix(err.Error(), errOutputFilePrefix)) { + t.Errorf("Output file `%s` shouldn't be valid: %s", badOutput, err) + } +} + +func TestReporter(t *testing.T) { + rate, duration, targetsf, ordering, reporter, output := defaultArguments() + + err := run(rate, duration, targetsf, ordering, reporter, output) + if err != nil { + t.Errorf("Reporter shouldn't return an error: %s", err) + } +} + +func TestTargetsValidation(t *testing.T) { + rate, duration, goodFile, ordering, reporter, output := defaultArguments() + + // Good case + err := run(rate, duration, goodFile, ordering, reporter, output) + if err != nil { + t.Errorf("Targets file `%s` should be valid: %s", goodFile, err) + } + + // Bad case + badFile := "randomInexistingFile12345.txt" + err = run(rate, duration, badFile, ordering, reporter, output) + if err == nil || (err != nil && !strings.HasPrefix(err.Error(), errTargetsFilePrefix)) { + t.Errorf("Targets file `%s` shouldn't be valid: %s", badFile, err) + } +} + +func TestOrderingValidation(t *testing.T) { + rate, duration, targetsf, _, reporter, output := defaultArguments() + + // Good cases + for _, ordering := range []string{"random", "sequential"} { + err := run(rate, duration, targetsf, ordering, reporter, output) + if err != nil { + t.Errorf("Ordering `%s` should be valid: %s", ordering, err) + } + } + + // Bad case + badOrdering := "lolcat" + err := run(rate, duration, targetsf, badOrdering, reporter, output) + if err == nil || (err != nil && !strings.HasPrefix(err.Error(), errOrderingPrefix)) { + t.Errorf("Ordering `%s` shouldn't be valid: %s", badOrdering, err) + } +} + +func defaultArguments() (uint64, time.Duration, string, string, string, string) { + return uint64(1000), 5 * time.Millisecond, ".targets.txt", "random", "text", "/dev/null" +}