From 499fa82d96e971e61f8a7f68406d616e0e13f339 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Rudowicz?= Date: Tue, 7 Jan 2025 00:10:21 +0100 Subject: [PATCH] Collect improvements --- filters_sync.go | 7 ++++--- filters_test.go | 5 +++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/filters_sync.go b/filters_sync.go index 0ae8aac..c8a3f5e 100644 --- a/filters_sync.go +++ b/filters_sync.go @@ -25,16 +25,17 @@ func (impl *SyncFilterImpl[MsgType]) CallNext(msg MsgType) { } } -type CollectFromChannel struct{ SyncFilterImpl[GenericMessage] } +type CollectFromChannel[MsgType any] struct{ SyncFilterImpl[MsgType] } -func (collect *CollectFromChannel) Call(msg GenericMessage) { +func (collect *CollectFromChannel[MsgType]) Call(msg MsgType) { collect.CallNext(msg) } -func (collect CollectFromChannel) Collect(events <-chan GenericMessage, wg *sync.WaitGroup) { +func (collect CollectFromChannel[MsgType]) Collect(events <-chan MsgType, wg *sync.WaitGroup, onClose func()) { wg.Add(1) go func() { defer wg.Done() + defer onClose() for e := range events { collect.Call(e) diff --git a/filters_test.go b/filters_test.go index cc3e48d..54cd0c1 100644 --- a/filters_test.go +++ b/filters_test.go @@ -345,13 +345,14 @@ func TestSyncCollect(t *testing.T) { testEvents := make(chan GenericMessage) wg := sync.WaitGroup{} - tested := CollectFromChannel{} + tested := CollectFromChannel[GenericMessage]{} mock := &SyncMockFilter{} mock2 := &SyncMockFilter{} tested.Then(mock).Then(mock2) - tested.Collect(testEvents, &wg) + wg.Add(1) + tested.Collect(testEvents, &wg, func() { wg.Done() }) testEvents <- makeGenericMessage(satel.ArmedPartition, 1, true) testEvents <- makeGenericMessage(satel.DoorOpened, 2, true)