diff --git a/client.go b/client.go index 53ccb4e..e4805cd 100644 --- a/client.go +++ b/client.go @@ -9,10 +9,7 @@ import ( // Client is an http.Client with rate limiting // TODO: Add timeouts -type Client struct { - http.Client - rate uint64 -} +type Client struct{ http.Client } // Response represents the metrics we want out of an http.Response type Response struct { @@ -24,15 +21,10 @@ type Response struct { err error } -// NewClient returns an initialized Client -func NewClient(rate uint64) *Client { - return &Client{http.Client{}, rate} -} - // Drill loops over the passed reqs channel and executes each request. -// It is throttled to the rate specified in the initializer -func (c *Client) Drill(reqs chan *http.Request, res chan *Response) { - throttle := time.Tick(time.Duration(1e9 / c.rate)) +// It is throttled to the rate specified +func (c *Client) Drill(rate uint64, reqs chan *http.Request, res chan *Response) { + throttle := time.Tick(time.Duration(1e9 / rate)) for req := range reqs { <-throttle go c.Do(req, res) diff --git a/main.go b/main.go index 0b89248..0b608fe 100644 --- a/main.go +++ b/main.go @@ -89,8 +89,8 @@ func attack(targets Targets, ordering string, rate uint64, duration time.Duratio defer close(hits) responses := make(chan *Response, cap(hits)) defer close(responses) - client := NewClient(rate) - go client.Drill(hits, responses) // Attack! + client := Client{} + go client.Drill(rate, hits, responses) // Attack! for i := 0; i < cap(hits); i++ { hits <- targets[i%len(targets)] } diff --git a/main_test.go b/main_test.go index 50736b4..f64e124 100644 --- a/main_test.go +++ b/main_test.go @@ -21,10 +21,10 @@ func TestAttackRate(t *testing.T) { if err != nil { t.Fatal(err) } - rate := uint(5000) + rate := uint64(5000) rep := NewTextReporter() attack(targets, "random", rate, 1*time.Second, rep) - if hits := atomic.LoadUint64(&hitCount); uint(hits) != rate { + if hits := atomic.LoadUint64(&hitCount); hits != rate { rep.Report(os.Stdout) t.Fatalf("Wrong number of hits: want %d, got %d\n", rate, hits) }