package main

import (
	"fmt"
	"sort"
	"sync"
	"time"

	"github.com/nats-io/go-nats"
)

// Benchmark configuration.
const (
	// NatsServer is the URL handed to nats.Connect for both the subscribing
	// and the publishing connection.
	NatsServer = "nats://127.0.0.1:4222"
	//NatsServer    = "nats://demo.nats.io:4222"

	// MeasureTimes is the number of messages main publishes and waits to
	// receive. TargetPubRate is the publish rate the throttle in main steers
	// towards; it is used as a constant divisor there, so it must not be 0.
	MeasureTimes  = 10000
	TargetPubRate = 1000 // pubs per second, -1 == none (throttle delay becomes 0)
)

// main measures publish-to-receive latency through a NATS server.
//
// It opens two connections to NatsServer. The first subscribes to subject
// "foo"; the second publishes MeasureTimes messages whose payload is the
// binary encoding of the time the message was created. The subscriber decodes
// that timestamp and records how long ago it was. Once MeasureTimes messages
// have been received, the total time, the achieved publish rate and the
// min/median/max latency are printed.
//
// Any connection, subscription or publish failure aborts the run with a panic
// describing the failing step, since the measurement cannot complete (and
// would otherwise block forever waiting for messages that never arrive).
func main() {
	c1, err := nats.Connect(NatsServer)
	if err != nil {
		panic(fmt.Sprintf("connecting subscriber to %s: %v", NatsServer, err))
	}
	defer c1.Close()

	c2, err := nats.Connect(NatsServer)
	if err != nil {
		panic(fmt.Sprintf("connecting publisher to %s: %v", NatsServer, err))
	}
	defer c2.Close()

	// Duration tracking
	durations := make([]time.Duration, 0, MeasureTimes)

	// Wait for all messages to be received.
	var wg sync.WaitGroup
	wg.Add(1)

	received := 0

	// Async Subscriber (Runs in own Go routine)
	_, err = c1.Subscribe("foo", func(msg *nats.Msg) {
		// Once the target count is reached main owns durations (it sorts and
		// reads it), so any further message on the subject must not touch it,
		// and wg.Done must not run a second time.
		if received >= MeasureTimes {
			return
		}
		var sent time.Time
		if err := sent.UnmarshalBinary(msg.Data); err != nil {
			// Not one of our timestamp payloads; counting it would record a
			// meaningless latency.
			fmt.Printf("Ignoring message with undecodable timestamp: %v\n", err)
			return
		}
		// The decoded time carries no monotonic clock reading, so this is a
		// wall-clock difference.
		durations = append(durations, time.Since(sent))
		received++
		if received == MeasureTimes {
			wg.Done()
		}
	})
	if err != nil {
		panic(fmt.Sprintf("subscribing to foo: %v", err))
	}
	// Make sure interest is set for subscribe before publish since a different connection.
	if err := c1.Flush(); err != nil {
		panic(fmt.Sprintf("flushing subscription: %v", err))
	}

	// For publish throttle. A negative TargetPubRate yields a negative delay,
	// which is clamped to 0 and disables throttling.
	delay := time.Second / TargetPubRate
	if delay < time.Microsecond {
		delay = 0
	}

	start := time.Now()

	// Now publish
	for i := 0; i < MeasureTimes; i++ {
		payload, err := time.Now().MarshalBinary()
		if err != nil {
			panic(fmt.Sprintf("encoding timestamp: %v", err))
		}
		if err := c2.Publish("foo", payload); err != nil {
			panic(fmt.Sprintf("publishing message %d: %v", i+1, err))
		}

		// Throttle logic, crude I know: nudge the sleep by 5% towards the
		// target rate after every publish.
		if delay > 0 {
			adj := delay / 20 // 5%
			switch r := rps(i+1, time.Since(start)); {
			case r < TargetPubRate:
				delay -= adj
			case r > TargetPubRate:
				delay += adj
			}
			time.Sleep(delay)
		}
	}
	pubDur := time.Since(start)
	wg.Wait()

	// Print results
	fmt.Printf("Total time: %v\n", time.Since(start))
	fmt.Printf("Time to publish: %v\n", pubDur)
	fmt.Printf("Publish rate (desired vs actual): %d, %d\n", TargetPubRate, rps(MeasureTimes, pubDur))
	sort.Slice(durations, func(i, j int) bool { return durations[i] < durations[j] })
	fmt.Printf("Receive Latency (min, median, max): %v, %v, %v\n",
		durations[0],
		durations[len(durations)/2],
		durations[len(durations)-1])
}

// fsecs is the number of nanoseconds in one second as a float64, used to
// convert a time.Duration into fractional seconds.
const fsecs = float64(time.Second)

// rps returns the average rate, in events per second, of count events that
// took elapsed time, truncated towards zero.
//
// When elapsed is not positive no finite rate can be computed: the result is
// then 0 if count is not positive, and the largest int otherwise. Rates too
// large for an int are likewise saturated at the largest int, so the result
// never depends on the implementation-defined conversion of an out-of-range
// float to int.
func rps(count int, elapsed time.Duration) int {
	const maxInt = int(^uint(0) >> 1)

	if elapsed <= 0 {
		if count > 0 {
			return maxInt
		}
		return 0
	}

	rate := float64(count) / (float64(elapsed) / fsecs)
	if rate >= float64(maxInt) {
		return maxInt
	}
	return int(rate)
}