Add -header flag to attack sub command

This commit is contained in:
Tomás Senart
2013-09-21 18:28:59 +02:00
parent e75a244b61
commit 172e7492f2
2 changed files with 73 additions and 14 deletions

View File

@@ -1,10 +1,13 @@
package main
import (
"bytes"
"flag"
"fmt"
vegeta "github.com/tsenart/vegeta/lib"
"log"
"net/http"
"strings"
"time"
)
@@ -15,16 +18,18 @@ func attackCmd(args []string) command {
ordering := fs.String("ordering", "random", "Attack ordering [sequential, random]")
duration := fs.Duration("duration", 10*time.Second, "Duration of the test")
output := fs.String("output", "stdout", "Output file")
hdrs := headers{Header: make(http.Header)}
fs.Var(hdrs, "header", "Targets request header")
fs.Parse(args)
return func() error {
return attack(*rate, *duration, *targetsf, *ordering, *output)
return attack(*rate, *duration, *targetsf, *ordering, *output, hdrs.Header)
}
}
// attack validates the attack arguments, sets up the
// required resources, launches the attack and writes the results
func attack(rate uint64, duration time.Duration, targetsf, ordering, output string) error {
func attack(rate uint64, duration time.Duration, targetsf, ordering, output string, header http.Header) error {
if rate == 0 {
return fmt.Errorf(errRatePrefix + "can't be zero")
}
@@ -42,6 +47,7 @@ func attack(rate uint64, duration time.Duration, targetsf, ordering, output stri
if err != nil {
return fmt.Errorf(errTargetsFilePrefix+"(%s): %s", targetsf, err)
}
targets.SetHeader(header)
switch ordering {
case "random":
@@ -76,3 +82,30 @@ const (
errOrderingPrefix = "Ordering: "
errReportingPrefix = "Reporting: "
)
// headers is the http.Header used in each target request
// it is defined here to implement the flag.Value interface
// in order to support multiple identical flags for request header
// specification
type headers struct{ http.Header }
func (h headers) String() string {
buf := &bytes.Buffer{}
if err := h.Write(buf); err != nil {
return ""
}
return buf.String()
}
func (h headers) Set(value string) error {
parts := strings.Split(value, ":")
if len(parts) != 2 {
return fmt.Errorf("Header '%s' has a wrong format", value)
}
key, val := strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])
if key == "" || val == "" {
return fmt.Errorf("Header '%s' has a wrong format", value)
}
h.Add(key, val)
return nil
}

View File

@@ -1,8 +1,10 @@
package main
import (
"flag"
"io/ioutil"
"log"
"net/http"
"strings"
"testing"
"time"
@@ -14,48 +16,48 @@ func init() {
}
func TestRateValidation(t *testing.T) {
rate, duration, targetsf, ordering, output := defaultArguments()
rate, duration, targetsf, ordering, output, header := defaultArguments()
rate = 0
err := attack(rate, duration, targetsf, ordering, output)
err := attack(rate, duration, targetsf, ordering, output, header)
if err == nil || (err != nil && !strings.HasPrefix(err.Error(), errRatePrefix)) {
t.Errorf("Rate 0 shouldn't be valid: %s", err)
}
}
func TestDurationValidation(t *testing.T) {
rate, duration, targetsf, ordering, output := defaultArguments()
rate, duration, targetsf, ordering, output, header := defaultArguments()
duration = 0
err := attack(rate, duration, targetsf, ordering, output)
err := attack(rate, duration, targetsf, ordering, output, header)
if err == nil || (err != nil && !strings.HasPrefix(err.Error(), errDurationPrefix)) {
t.Errorf("Duration 0 shouldn't be valid: %s", err)
}
}
func TestTargetsValidation(t *testing.T) {
rate, duration, goodFile, ordering, output := defaultArguments()
rate, duration, goodFile, ordering, output, header := defaultArguments()
// Good case
err := attack(rate, duration, goodFile, ordering, output)
err := attack(rate, duration, goodFile, ordering, output, header)
if err != nil {
t.Errorf("Targets file `%s` should be valid: %s", goodFile, err)
}
// Bad case
badFile := "randomInexistingFile12345.txt"
err = attack(rate, duration, badFile, ordering, output)
err = attack(rate, duration, badFile, ordering, output, header)
if err == nil || (err != nil && !strings.HasPrefix(err.Error(), errTargetsFilePrefix)) {
t.Errorf("Targets file `%s` shouldn't be valid: %s", badFile, err)
}
}
func TestOrderingValidation(t *testing.T) {
rate, duration, targetsf, _, output := defaultArguments()
rate, duration, targetsf, _, output, header := defaultArguments()
// Good cases
for _, ordering := range []string{"random", "sequential"} {
err := attack(rate, duration, targetsf, ordering, output)
err := attack(rate, duration, targetsf, ordering, output, header)
if err != nil {
t.Errorf("Ordering `%s` should be valid: %s", ordering, err)
}
@@ -63,12 +65,36 @@ func TestOrderingValidation(t *testing.T) {
// Bad case
badOrdering := "lolcat"
err := attack(rate, duration, targetsf, badOrdering, output)
err := attack(rate, duration, targetsf, badOrdering, output, header)
if err == nil || (err != nil && !strings.HasPrefix(err.Error(), errOrderingPrefix)) {
t.Errorf("Ordering `%s` shouldn't be valid: %s", badOrdering, err)
}
}
func defaultArguments() (uint64, time.Duration, string, string, string) {
return uint64(1000), 5 * time.Millisecond, ".targets.txt", "random", "/dev/null"
func TestHeadersParsing(t *testing.T) {
fs := flag.NewFlagSet("test", flag.ContinueOnError)
fs.SetOutput(ioutil.Discard)
hdrs := headers{Header: make(http.Header)}
fs.Var(hdrs, "H", "Header")
// Good case
good := []string{"-H", "Host: lolcathost"}
if err := fs.Parse(good); err != nil {
t.Errorf("%v should be a valid header", good[1])
}
// Bad cases
bad := [][]string{[]string{"-H", "Host:"}, []string{"-H", "Host"}}
for _, args := range bad {
if err := fs.Parse(args); err == nil {
t.Errorf("%v should not be a valid header", args[1])
}
}
}
func defaultArguments() (uint64, time.Duration, string, string, string, http.Header) {
return uint64(1000),
5 * time.Millisecond,
".targets.txt",
"random",
"/dev/null",
http.Header{}
}