diff --git a/filters_sync.go b/filters_sync.go index 0ae8aac..c8a3f5e 100644 --- a/filters_sync.go +++ b/filters_sync.go @@ -25,16 +25,17 @@ func (impl *SyncFilterImpl[MsgType]) CallNext(msg MsgType) { } } -type CollectFromChannel struct{ SyncFilterImpl[GenericMessage] } +type CollectFromChannel[MsgType any] struct{ SyncFilterImpl[MsgType] } -func (collect *CollectFromChannel) Call(msg GenericMessage) { +func (collect *CollectFromChannel[MsgType]) Call(msg MsgType) { collect.CallNext(msg) } -func (collect CollectFromChannel) Collect(events <-chan GenericMessage, wg *sync.WaitGroup) { +func (collect CollectFromChannel[MsgType]) Collect(events <-chan MsgType, wg *sync.WaitGroup, onClose func()) { wg.Add(1) go func() { defer wg.Done() + defer onClose() for e := range events { collect.Call(e) diff --git a/filters_test.go b/filters_test.go index cc3e48d..54cd0c1 100644 --- a/filters_test.go +++ b/filters_test.go @@ -345,13 +345,14 @@ func TestSyncCollect(t *testing.T) { testEvents := make(chan GenericMessage) wg := sync.WaitGroup{} - tested := CollectFromChannel{} + tested := CollectFromChannel[GenericMessage]{} mock := &SyncMockFilter{} mock2 := &SyncMockFilter{} tested.Then(mock).Then(mock2) - tested.Collect(testEvents, &wg) + wg.Add(1) + tested.Collect(testEvents, &wg, func() { wg.Done() }) testEvents <- makeGenericMessage(satel.ArmedPartition, 1, true) testEvents <- makeGenericMessage(satel.DoorOpened, 2, true)