diff --git a/attack.go b/attack.go index 8d7e83b..49ed04e 100644 --- a/attack.go +++ b/attack.go @@ -1,10 +1,13 @@ package main import ( + "bytes" "flag" "fmt" vegeta "github.com/tsenart/vegeta/lib" "log" + "net/http" + "strings" "time" ) @@ -15,16 +18,18 @@ func attackCmd(args []string) command { ordering := fs.String("ordering", "random", "Attack ordering [sequential, random]") duration := fs.Duration("duration", 10*time.Second, "Duration of the test") output := fs.String("output", "stdout", "Output file") + hdrs := headers{Header: make(http.Header)} + fs.Var(hdrs, "header", "Targets request header") fs.Parse(args) return func() error { - return attack(*rate, *duration, *targetsf, *ordering, *output) + return attack(*rate, *duration, *targetsf, *ordering, *output, hdrs.Header) } } // attack validates the attack arguments, sets up the // required resources, launches the attack and writes the results -func attack(rate uint64, duration time.Duration, targetsf, ordering, output string) error { +func attack(rate uint64, duration time.Duration, targetsf, ordering, output string, header http.Header) error { if rate == 0 { return fmt.Errorf(errRatePrefix + "can't be zero") } @@ -42,6 +47,7 @@ func attack(rate uint64, duration time.Duration, targetsf, ordering, output stri if err != nil { return fmt.Errorf(errTargetsFilePrefix+"(%s): %s", targetsf, err) } + targets.SetHeader(header) switch ordering { case "random": @@ -76,3 +82,30 @@ const ( errOrderingPrefix = "Ordering: " errReportingPrefix = "Reporting: " ) + +// headers is the http.Header used in each target request +// it is defined here to implement the flag.Value interface +// in order to support multiple identical flags for request header +// specification +type headers struct{ http.Header } + +func (h headers) String() string { + buf := &bytes.Buffer{} + if err := h.Write(buf); err != nil { + return "" + } + return buf.String() +} + +func (h headers) Set(value string) error { + parts := strings.Split(value, ":") + if len(parts) != 2 { + return fmt.Errorf("Header '%s' has a wrong format", value) + } + key, val := strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]) + if key == "" || val == "" { + return fmt.Errorf("Header '%s' has a wrong format", value) + } + h.Add(key, val) + return nil +} diff --git a/attack_test.go b/attack_test.go index f9b1216..16e69c0 100644 --- a/attack_test.go +++ b/attack_test.go @@ -1,8 +1,10 @@ package main import ( + "flag" "io/ioutil" "log" + "net/http" "strings" "testing" "time" @@ -14,48 +16,48 @@ func init() { } func TestRateValidation(t *testing.T) { - rate, duration, targetsf, ordering, output := defaultArguments() + rate, duration, targetsf, ordering, output, header := defaultArguments() rate = 0 - err := attack(rate, duration, targetsf, ordering, output) + err := attack(rate, duration, targetsf, ordering, output, header) if err == nil || (err != nil && !strings.HasPrefix(err.Error(), errRatePrefix)) { t.Errorf("Rate 0 shouldn't be valid: %s", err) } } func TestDurationValidation(t *testing.T) { - rate, duration, targetsf, ordering, output := defaultArguments() + rate, duration, targetsf, ordering, output, header := defaultArguments() duration = 0 - err := attack(rate, duration, targetsf, ordering, output) + err := attack(rate, duration, targetsf, ordering, output, header) if err == nil || (err != nil && !strings.HasPrefix(err.Error(), errDurationPrefix)) { t.Errorf("Duration 0 shouldn't be valid: %s", err) } } func TestTargetsValidation(t *testing.T) { - rate, duration, goodFile, ordering, output := defaultArguments() + rate, duration, goodFile, ordering, output, header := defaultArguments() // Good case - err := attack(rate, duration, goodFile, ordering, output) + err := attack(rate, duration, goodFile, ordering, output, header) if err != nil { t.Errorf("Targets file `%s` should be valid: %s", goodFile, err) } // Bad case badFile := "randomInexistingFile12345.txt" - err = attack(rate, duration, badFile, ordering, output) + err = attack(rate, duration, badFile, ordering, output, header) if err == nil || (err != nil && !strings.HasPrefix(err.Error(), errTargetsFilePrefix)) { t.Errorf("Targets file `%s` shouldn't be valid: %s", badFile, err) } } func TestOrderingValidation(t *testing.T) { - rate, duration, targetsf, _, output := defaultArguments() + rate, duration, targetsf, _, output, header := defaultArguments() // Good cases for _, ordering := range []string{"random", "sequential"} { - err := attack(rate, duration, targetsf, ordering, output) + err := attack(rate, duration, targetsf, ordering, output, header) if err != nil { t.Errorf("Ordering `%s` should be valid: %s", ordering, err) } @@ -63,12 +65,36 @@ func TestOrderingValidation(t *testing.T) { // Bad case badOrdering := "lolcat" - err := attack(rate, duration, targetsf, badOrdering, output) + err := attack(rate, duration, targetsf, badOrdering, output, header) if err == nil || (err != nil && !strings.HasPrefix(err.Error(), errOrderingPrefix)) { t.Errorf("Ordering `%s` shouldn't be valid: %s", badOrdering, err) } } -func defaultArguments() (uint64, time.Duration, string, string, string) { - return uint64(1000), 5 * time.Millisecond, ".targets.txt", "random", "/dev/null" +func TestHeadersParsing(t *testing.T) { + fs := flag.NewFlagSet("test", flag.ContinueOnError) + fs.SetOutput(ioutil.Discard) + hdrs := headers{Header: make(http.Header)} + fs.Var(hdrs, "H", "Header") + // Good case + good := []string{"-H", "Host: lolcathost"} + if err := fs.Parse(good); err != nil { + t.Errorf("%v should be a valid header", good[1]) + } + // Bad cases + bad := [][]string{[]string{"-H", "Host:"}, []string{"-H", "Host"}} + for _, args := range bad { + if err := fs.Parse(args); err == nil { + t.Errorf("%v should not be a valid header", args[1]) + } + } +} + +func defaultArguments() (uint64, time.Duration, string, string, string, http.Header) { + return uint64(1000), + 5 * time.Millisecond, + ".targets.txt", + "random", + "/dev/null", + http.Header{} }