From b0d87a2d64ebfa9ca93fa368e33eed980ed5ea02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Rudowicz?= Date: Wed, 8 Jan 2025 21:16:18 +0100 Subject: [PATCH] Synchronous filter for type conversion in stream --- filters_sync.go | 22 ++++++++++++++++++++++ filters_test.go | 40 ++++++++++++++++++++++++++++++++++++---- 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/filters_sync.go b/filters_sync.go index 747c95f..0232569 100644 --- a/filters_sync.go +++ b/filters_sync.go @@ -115,3 +115,25 @@ func (throttle *ThrottleSync) Close() { close(throttle.events) } func (throttle *ThrottleSync) Call(msg GenericMessage) { throttle.events <- msg } + +type Convert[InMsgType any] struct { + SyncFilterImpl[InMsgType] + out SyncFilter[GenericMessage] + convert func(InMsgType) GenericMessage +} + +func MakeConvert[InMsgType any](convertFunc func(InMsgType) GenericMessage) *Convert[InMsgType] { + return &Convert[InMsgType]{SyncFilterImpl[InMsgType]{}, nil, convertFunc} +} + +func (convert *Convert[InMsgType]) Call(msg InMsgType) { + convert.out.Call(convert.convert(msg)) +} + +func (convert *Convert[InMsgType]) ConvertTo(out SyncFilter[GenericMessage]) { + convert.out = out +} + +func (convert *Convert[InMsgType]) Then(_ SyncFilter[InMsgType]) { + panic("Use ConvertTo() with Convert object") +} diff --git a/filters_test.go b/filters_test.go index feb1ddb..0c1d82f 100644 --- a/filters_test.go +++ b/filters_test.go @@ -236,16 +236,18 @@ func (self *MockSleeper) Sleep(ch chan<- interface{}) { self.callCount += 1 } -type SyncMockFilter struct { - SyncFilterImpl[GenericMessage] - collected []GenericMessage +type GenericSyncMockFilter[T any] struct { + SyncFilterImpl[T] + collected []T } -func (self *SyncMockFilter) Call(msg GenericMessage) { +func (self *GenericSyncMockFilter[T]) Call(msg T) { self.collected = append(self.collected, msg) self.CallNext(msg) } +type SyncMockFilter = GenericSyncMockFilter[GenericMessage] + func TestSyncCollect(t *testing.T) { testEvents := make(chan GenericMessage) wg := sync.WaitGroup{} @@ -322,3 +324,33 @@ func TestThrottleSync(t *testing.T) { assert.Contains(t, mock.collected[1].Messages, tplMessageTest4) assert.Len(t, mock.collected[1].Messages, 1) } + +func TestConvert_failsWhenNotConverting(t *testing.T) { + a := assert.New(t) + tested := MakeConvert[int](func(in int) GenericMessage { + a.Equal(in, 1) + return GenericMessage{} + }) + mock := &GenericSyncMockFilter[int]{} + + a.Panics(func() { + tested.Then(mock) + tested.Call(1) + }) +} + +func TestConvert(t *testing.T) { + a := assert.New(t) + numCalled := 0 + tested := MakeConvert[int](func(in int) GenericMessage { + a.Equal(in, 1) + numCalled += 1 + return GenericMessage{} + }) + mock := &SyncMockFilter{} + + tested.ConvertTo(mock) + tested.Call(1) + + a.Equal(numCalled, 1) +}