Extract main logic into attack function and test it
This commit is contained in:
62
main.go
62
main.go
@@ -30,14 +30,6 @@ func main() {
|
|||||||
if *rate == 0 {
|
if *rate == 0 {
|
||||||
log.Fatal("rate can't be zero")
|
log.Fatal("rate can't be zero")
|
||||||
}
|
}
|
||||||
// Magic formula that assumes each client can
|
|
||||||
// sustain 200 RPS under normal circumstances
|
|
||||||
clients := make([]*Client, int(math.Ceil(float64(*rate)/200.0)))
|
|
||||||
ratePerClient := *rate / uint(len(clients))
|
|
||||||
for i := 0; i < len(clients); i++ {
|
|
||||||
clients[i] = NewClient(ratePerClient)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse targets file
|
// Parse targets file
|
||||||
targets, err := NewTargetsFromFile(*targetsf)
|
targets, err := NewTargetsFromFile(*targetsf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -45,10 +37,8 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Parse ordering argument
|
// Parse ordering argument
|
||||||
random := false
|
|
||||||
if *ordering == "random" {
|
if *ordering == "random" {
|
||||||
rand.Seed(time.Now().UnixNano())
|
rand.Seed(time.Now().UnixNano())
|
||||||
random = true
|
|
||||||
} else if *ordering != "sequential" {
|
} else if *ordering != "sequential" {
|
||||||
log.Fatalf("Unknown ordering %s", *ordering)
|
log.Fatalf("Unknown ordering %s", *ordering)
|
||||||
}
|
}
|
||||||
@@ -58,34 +48,46 @@ func main() {
|
|||||||
log.Fatal("Duration provided is invalid")
|
log.Fatal("Duration provided is invalid")
|
||||||
}
|
}
|
||||||
|
|
||||||
hits := make(chan *http.Request, *rate*uint((*duration).Seconds()))
|
// Parse reporter
|
||||||
for i, idxs := 0, targets.Iter(random); i < cap(hits); i++ {
|
|
||||||
hits <- targets[idxs[i%len(idxs)]]
|
|
||||||
}
|
|
||||||
// Attack!
|
|
||||||
responses := make(chan *Response, cap(hits))
|
|
||||||
for _, client := range clients {
|
|
||||||
go client.Drill(hits, responses)
|
|
||||||
}
|
|
||||||
log.Printf("Vegeta is attacking ")
|
|
||||||
log.Printf("%d targets in %s order for %s with %d clients.\n", len(targets), *ordering, duration, len(clients))
|
|
||||||
|
|
||||||
var rep Reporter
|
var rep Reporter
|
||||||
switch *reporter {
|
switch *reporter {
|
||||||
case "text":
|
case "text":
|
||||||
rep = NewTextReporter(len(responses))
|
rep = NewTextReporter()
|
||||||
default:
|
default:
|
||||||
log.Println("reporter provided is not supported. using text")
|
log.Println("reporter provided is not supported. using text")
|
||||||
rep = NewTextReporter(len(responses))
|
rep = NewTextReporter()
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Vegeta is attacking %d targets in %s order for %s\n", len(targets), *ordering, *duration)
|
||||||
|
attack(targets, *ordering, *rate, *duration, rep)
|
||||||
|
|
||||||
|
// Report results!
|
||||||
|
if rep.Report(os.Stdout) != nil {
|
||||||
|
log.Fatal("Failed to report!")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func attack(targets Targets, ordering string, rate uint, duration time.Duration, rep Reporter) {
|
||||||
|
// Magic formula that assumes each client can
|
||||||
|
// sustain 200 RPS under normal circumstances
|
||||||
|
clients := make([]*Client, int(math.Ceil(float64(rate)/200.0)))
|
||||||
|
ratePerClient := rate / uint(len(clients))
|
||||||
|
for i := 0; i < len(clients); i++ {
|
||||||
|
clients[i] = NewClient(ratePerClient)
|
||||||
|
}
|
||||||
|
|
||||||
|
hits := make(chan *http.Request, rate*uint((duration).Seconds()))
|
||||||
|
defer close(hits)
|
||||||
|
for i, idxs := 0, targets.Iter(ordering); i < cap(hits); i++ {
|
||||||
|
hits <- targets[idxs[i%len(idxs)]]
|
||||||
|
}
|
||||||
|
responses := make(chan *Response, cap(hits))
|
||||||
|
defer close(responses)
|
||||||
|
for _, client := range clients {
|
||||||
|
go client.Drill(hits, responses) // Attack!
|
||||||
}
|
}
|
||||||
// Wait for all requests to finish
|
// Wait for all requests to finish
|
||||||
for i := 0; i < cap(responses); i++ {
|
for i := 0; i < cap(responses); i++ {
|
||||||
rep.Add(<-responses)
|
rep.Add(<-responses)
|
||||||
}
|
}
|
||||||
close(hits)
|
|
||||||
close(responses)
|
|
||||||
|
|
||||||
if rep.Report(os.Stdout) != nil {
|
|
||||||
log.Fatal("Failed to report!")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
27
main_test.go
Normal file
27
main_test.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAttackRate(t *testing.T) {
|
||||||
|
hitCount := uint64(0)
|
||||||
|
server := httptest.NewServer(
|
||||||
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
atomic.AddUint64(&hitCount, 1)
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
targets, err := NewTargets(bytes.NewBufferString("GET " + server.URL + "\n"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
attack(targets, "random", 50, 1*time.Second, NewTextReporter())
|
||||||
|
if hits := atomic.LoadUint64(&hitCount); hits != 50 {
|
||||||
|
t.Fatalf("Wrong number of hits: want %d, got %d\n", 50, hits)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -17,8 +17,8 @@ type TextReporter struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewTextReporter initializes a TextReporter with n responses
|
// NewTextReporter initializes a TextReporter with n responses
|
||||||
func NewTextReporter(n int) *TextReporter {
|
func NewTextReporter() *TextReporter {
|
||||||
return &TextReporter{responses: make([]*Response, n)}
|
return &TextReporter{responses: make([]*Response, 0)}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add adds a response to be used in the report
|
// Add adds a response to be used in the report
|
||||||
|
|||||||
@@ -49,8 +49,8 @@ func NewTargets(source io.Reader) (Targets, error) {
|
|||||||
return targets, nil
|
return targets, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Targets) Iter(random bool) []int {
|
func (t Targets) Iter(ordering string) []int {
|
||||||
if random {
|
if ordering == "random" {
|
||||||
return rand.Perm(len(t))
|
return rand.Perm(len(t))
|
||||||
}
|
}
|
||||||
iter := make([]int, len(t))
|
iter := make([]int, len(t))
|
||||||
|
|||||||
Reference in New Issue
Block a user