diff --git a/filters_sync.go b/filters_sync.go index 6b709a2..0ae8aac 100644 --- a/filters_sync.go +++ b/filters_sync.go @@ -1,6 +1,9 @@ package main -import "sync" +import ( + "log" + "sync" +) type SyncFilter[MsgType any] interface { Then(what SyncFilter[MsgType]) SyncFilter[MsgType] @@ -38,3 +41,54 @@ func (collect CollectFromChannel) Collect(events <-chan GenericMessage, wg *sync } }() } + +type ThrottleSync struct { + SyncFilterImpl[GenericMessage] + + events chan GenericMessage +} + +func MakeThrottleSync(sleeper Sleeper, logger *log.Logger, wg *sync.WaitGroup) *ThrottleSync { + events := make(chan GenericMessage) + + throttle := ThrottleSync{SyncFilterImpl[GenericMessage]{}, events} + + wg.Add(1) + go func() { + timeoutEvents := make(chan interface{}) + var currentEvent *GenericMessage = nil + loop: + for { + select { + case ev, ok := <-events: + if !ok { + break loop + } + if currentEvent == nil { + logger.Print("Waiting for more messages to arrive before sending...") + sleeper.Sleep(timeoutEvents) + } + currentEvent = appendToGenericMessage(currentEvent, &ev) + case <-timeoutEvents: + logger.Print("Time's up, sending all messages we've got for now.") + throttle.CallNext(*currentEvent) + currentEvent = nil + } + } + + // If anything is left to be sent, send it now + if currentEvent != nil { + throttle.CallNext(*currentEvent) + } + wg.Done() + logger.Print("Throttling goroutine finishing") + }() + + return &throttle +} + +func (throttle *ThrottleSync) Close() { close(throttle.events) } + +func (throttle *ThrottleSync) Call(msg GenericMessage) { + throttle.events <- msg +} diff --git a/filters_test.go b/filters_test.go index 1ba83d4..cc3e48d 100644 --- a/filters_test.go +++ b/filters_test.go @@ -379,3 +379,40 @@ func TestSyncCollect(t *testing.T) { assert.Contains(t, mock2.collected, makeGenericMessage(satel.TroublePart1, 5, true)) assert.Contains(t, mock2.collected, makeGenericMessage(satel.ZoneTamper, 6, true)) } + +func TestThrottleSync(t *testing.T) { + wg := sync.WaitGroup{} + fakeLog := log.New(io.Discard, "", log.Ltime) + mockSleeper := MockSleeper{nil, 0} + + var ( + tplMessageTest1 = satel.BasicEventElement{Type: satel.ArmedPartition, Index: 1, Value: true} + tplMessageTest2 = satel.BasicEventElement{Type: satel.ZoneViolation, Index: 2, Value: true} + tplMessageTest3 = satel.BasicEventElement{Type: satel.ArmedPartition, Index: 1, Value: false} + tplMessageTest4 = satel.BasicEventElement{Type: satel.ZoneViolation, Index: 2, Value: false} + ) + + tested := MakeThrottleSync(&mockSleeper, fakeLog, &wg) + mock := &SyncMockFilter{} + + tested.Then(mock) + + tested.Call(GenericMessage{[]satel.BasicEventElement{tplMessageTest1}}) + tested.Call(GenericMessage{[]satel.BasicEventElement{tplMessageTest2}}) + tested.Call(GenericMessage{[]satel.BasicEventElement{tplMessageTest3}}) + *mockSleeper.ch <- nil + + tested.Call(GenericMessage{[]satel.BasicEventElement{tplMessageTest4}}) + + tested.Close() + wg.Wait() + + assert.Equal(t, 2, mockSleeper.callCount) + assert.Len(t, mock.collected, 2) + assert.Contains(t, mock.collected[0].Messages, tplMessageTest2) + assert.Contains(t, mock.collected[0].Messages, tplMessageTest3) + assert.Len(t, mock.collected[0].Messages, 2) + + assert.Contains(t, mock.collected[1].Messages, tplMessageTest4) + assert.Len(t, mock.collected[1].Messages, 1) +}