Merge pull request #23 from tsenart/headers
Ability to set request headers on attack sub command
This commit is contained in:
37
attack.go
37
attack.go
@@ -1,10 +1,13 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
vegeta "github.com/tsenart/vegeta/lib"
|
vegeta "github.com/tsenart/vegeta/lib"
|
||||||
"log"
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -15,16 +18,18 @@ func attackCmd(args []string) command {
|
|||||||
ordering := fs.String("ordering", "random", "Attack ordering [sequential, random]")
|
ordering := fs.String("ordering", "random", "Attack ordering [sequential, random]")
|
||||||
duration := fs.Duration("duration", 10*time.Second, "Duration of the test")
|
duration := fs.Duration("duration", 10*time.Second, "Duration of the test")
|
||||||
output := fs.String("output", "stdout", "Output file")
|
output := fs.String("output", "stdout", "Output file")
|
||||||
|
hdrs := headers{Header: make(http.Header)}
|
||||||
|
fs.Var(hdrs, "header", "Targets request header")
|
||||||
fs.Parse(args)
|
fs.Parse(args)
|
||||||
|
|
||||||
return func() error {
|
return func() error {
|
||||||
return attack(*rate, *duration, *targetsf, *ordering, *output)
|
return attack(*rate, *duration, *targetsf, *ordering, *output, hdrs.Header)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// attack validates the attack arguments, sets up the
|
// attack validates the attack arguments, sets up the
|
||||||
// required resources, launches the attack and writes the results
|
// required resources, launches the attack and writes the results
|
||||||
func attack(rate uint64, duration time.Duration, targetsf, ordering, output string) error {
|
func attack(rate uint64, duration time.Duration, targetsf, ordering, output string, header http.Header) error {
|
||||||
if rate == 0 {
|
if rate == 0 {
|
||||||
return fmt.Errorf(errRatePrefix + "can't be zero")
|
return fmt.Errorf(errRatePrefix + "can't be zero")
|
||||||
}
|
}
|
||||||
@@ -42,6 +47,7 @@ func attack(rate uint64, duration time.Duration, targetsf, ordering, output stri
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf(errTargetsFilePrefix+"(%s): %s", targetsf, err)
|
return fmt.Errorf(errTargetsFilePrefix+"(%s): %s", targetsf, err)
|
||||||
}
|
}
|
||||||
|
targets.SetHeader(header)
|
||||||
|
|
||||||
switch ordering {
|
switch ordering {
|
||||||
case "random":
|
case "random":
|
||||||
@@ -76,3 +82,30 @@ const (
|
|||||||
errOrderingPrefix = "Ordering: "
|
errOrderingPrefix = "Ordering: "
|
||||||
errReportingPrefix = "Reporting: "
|
errReportingPrefix = "Reporting: "
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// headers is the http.Header used in each target request
|
||||||
|
// it is defined here to implement the flag.Value interface
|
||||||
|
// in order to support multiple identical flags for request header
|
||||||
|
// specification
|
||||||
|
type headers struct{ http.Header }
|
||||||
|
|
||||||
|
func (h headers) String() string {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
if err := h.Write(buf); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h headers) Set(value string) error {
|
||||||
|
parts := strings.Split(value, ":")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return fmt.Errorf("Header '%s' has a wrong format", value)
|
||||||
|
}
|
||||||
|
key, val := strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])
|
||||||
|
if key == "" || val == "" {
|
||||||
|
return fmt.Errorf("Header '%s' has a wrong format", value)
|
||||||
|
}
|
||||||
|
h.Add(key, val)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"flag"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -14,48 +16,48 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRateValidation(t *testing.T) {
|
func TestRateValidation(t *testing.T) {
|
||||||
rate, duration, targetsf, ordering, output := defaultArguments()
|
rate, duration, targetsf, ordering, output, header := defaultArguments()
|
||||||
rate = 0
|
rate = 0
|
||||||
|
|
||||||
err := attack(rate, duration, targetsf, ordering, output)
|
err := attack(rate, duration, targetsf, ordering, output, header)
|
||||||
if err == nil || (err != nil && !strings.HasPrefix(err.Error(), errRatePrefix)) {
|
if err == nil || (err != nil && !strings.HasPrefix(err.Error(), errRatePrefix)) {
|
||||||
t.Errorf("Rate 0 shouldn't be valid: %s", err)
|
t.Errorf("Rate 0 shouldn't be valid: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDurationValidation(t *testing.T) {
|
func TestDurationValidation(t *testing.T) {
|
||||||
rate, duration, targetsf, ordering, output := defaultArguments()
|
rate, duration, targetsf, ordering, output, header := defaultArguments()
|
||||||
duration = 0
|
duration = 0
|
||||||
|
|
||||||
err := attack(rate, duration, targetsf, ordering, output)
|
err := attack(rate, duration, targetsf, ordering, output, header)
|
||||||
if err == nil || (err != nil && !strings.HasPrefix(err.Error(), errDurationPrefix)) {
|
if err == nil || (err != nil && !strings.HasPrefix(err.Error(), errDurationPrefix)) {
|
||||||
t.Errorf("Duration 0 shouldn't be valid: %s", err)
|
t.Errorf("Duration 0 shouldn't be valid: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTargetsValidation(t *testing.T) {
|
func TestTargetsValidation(t *testing.T) {
|
||||||
rate, duration, goodFile, ordering, output := defaultArguments()
|
rate, duration, goodFile, ordering, output, header := defaultArguments()
|
||||||
|
|
||||||
// Good case
|
// Good case
|
||||||
err := attack(rate, duration, goodFile, ordering, output)
|
err := attack(rate, duration, goodFile, ordering, output, header)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Targets file `%s` should be valid: %s", goodFile, err)
|
t.Errorf("Targets file `%s` should be valid: %s", goodFile, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bad case
|
// Bad case
|
||||||
badFile := "randomInexistingFile12345.txt"
|
badFile := "randomInexistingFile12345.txt"
|
||||||
err = attack(rate, duration, badFile, ordering, output)
|
err = attack(rate, duration, badFile, ordering, output, header)
|
||||||
if err == nil || (err != nil && !strings.HasPrefix(err.Error(), errTargetsFilePrefix)) {
|
if err == nil || (err != nil && !strings.HasPrefix(err.Error(), errTargetsFilePrefix)) {
|
||||||
t.Errorf("Targets file `%s` shouldn't be valid: %s", badFile, err)
|
t.Errorf("Targets file `%s` shouldn't be valid: %s", badFile, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOrderingValidation(t *testing.T) {
|
func TestOrderingValidation(t *testing.T) {
|
||||||
rate, duration, targetsf, _, output := defaultArguments()
|
rate, duration, targetsf, _, output, header := defaultArguments()
|
||||||
|
|
||||||
// Good cases
|
// Good cases
|
||||||
for _, ordering := range []string{"random", "sequential"} {
|
for _, ordering := range []string{"random", "sequential"} {
|
||||||
err := attack(rate, duration, targetsf, ordering, output)
|
err := attack(rate, duration, targetsf, ordering, output, header)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Ordering `%s` should be valid: %s", ordering, err)
|
t.Errorf("Ordering `%s` should be valid: %s", ordering, err)
|
||||||
}
|
}
|
||||||
@@ -63,12 +65,36 @@ func TestOrderingValidation(t *testing.T) {
|
|||||||
|
|
||||||
// Bad case
|
// Bad case
|
||||||
badOrdering := "lolcat"
|
badOrdering := "lolcat"
|
||||||
err := attack(rate, duration, targetsf, badOrdering, output)
|
err := attack(rate, duration, targetsf, badOrdering, output, header)
|
||||||
if err == nil || (err != nil && !strings.HasPrefix(err.Error(), errOrderingPrefix)) {
|
if err == nil || (err != nil && !strings.HasPrefix(err.Error(), errOrderingPrefix)) {
|
||||||
t.Errorf("Ordering `%s` shouldn't be valid: %s", badOrdering, err)
|
t.Errorf("Ordering `%s` shouldn't be valid: %s", badOrdering, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func defaultArguments() (uint64, time.Duration, string, string, string) {
|
func TestHeadersParsing(t *testing.T) {
|
||||||
return uint64(1000), 5 * time.Millisecond, ".targets.txt", "random", "/dev/null"
|
fs := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||||
|
fs.SetOutput(ioutil.Discard)
|
||||||
|
hdrs := headers{Header: make(http.Header)}
|
||||||
|
fs.Var(hdrs, "H", "Header")
|
||||||
|
// Good case
|
||||||
|
good := []string{"-H", "Host: lolcathost"}
|
||||||
|
if err := fs.Parse(good); err != nil {
|
||||||
|
t.Errorf("%v should be a valid header", good[1])
|
||||||
|
}
|
||||||
|
// Bad cases
|
||||||
|
bad := [][]string{[]string{"-H", "Host:"}, []string{"-H", "Host"}}
|
||||||
|
for _, args := range bad {
|
||||||
|
if err := fs.Parse(args); err == nil {
|
||||||
|
t.Errorf("%v should not be a valid header", args[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultArguments() (uint64, time.Duration, string, string, string, http.Header) {
|
||||||
|
return uint64(1000),
|
||||||
|
5 * time.Millisecond,
|
||||||
|
".targets.txt",
|
||||||
|
"random",
|
||||||
|
"/dev/null",
|
||||||
|
http.Header{}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -56,3 +56,10 @@ func (t Targets) Shuffle(seed int64) {
|
|||||||
t[i], t[rnd] = t[rnd], t[i]
|
t[i], t[rnd] = t[rnd], t[i]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetHeader sets the passed request header in all Targets
|
||||||
|
func (t Targets) SetHeader(header http.Header) {
|
||||||
|
for _, target := range t {
|
||||||
|
target.Header = header
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -52,3 +52,14 @@ func TestShuffle(t *testing.T) {
|
|||||||
}
|
}
|
||||||
t.Fatal("Targets were not shuffled correctly")
|
t.Fatal("Targets were not shuffled correctly")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSetHeader(t *testing.T) {
|
||||||
|
targets, _ := NewTargets([]string{"GET http://lolcathost:9999/", "HEAD http://lolcathost:9999/"})
|
||||||
|
want := "lolcathost.com"
|
||||||
|
targets.SetHeader(http.Header{"Host": []string{want}})
|
||||||
|
for _, target := range targets {
|
||||||
|
if got := target.Header.Get("Host"); got != want {
|
||||||
|
t.Errorf("Want: %s, Got: %s", want, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user