From 64f4930b4eec6cea915e6d9b07658654a68c3de5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Rudowicz?= Date: Thu, 2 Jan 2025 13:40:25 +0100 Subject: [PATCH] Avoid possible race condition with waitgroups Looks like calling wg.Add(1) from a goroutine can cause a race condition with wg.Wait() - can be easily avoided by calling Add() before the subroutine is created. --- debug_utils.go | 3 ++- filters.go | 19 ++++++++++--------- filters_test.go | 44 ++++++++++++++++++++++++++------------------ sender_worker.go | 10 ++++++---- 4 files changed, 44 insertions(+), 32 deletions(-) diff --git a/debug_utils.go b/debug_utils.go index 45034f8..b4891c0 100644 --- a/debug_utils.go +++ b/debug_utils.go @@ -21,9 +21,10 @@ func dumpMemoryProfile(log *log.Logger) { } func WriteMemoryProfilePeriodically(wg *sync.WaitGroup, log *log.Logger, close <-chan interface{}) { + wg.Add(1) go func() { - wg.Add(1) defer wg.Done() + memoryProfileTicker := time.NewTicker(24 * time.Hour) defer memoryProfileTicker.Stop() select { diff --git a/filters.go b/filters.go index 4d2afe6..3ffe4ce 100644 --- a/filters.go +++ b/filters.go @@ -26,19 +26,20 @@ func FilterByTypeOrIndex(ev <-chan satel.Event, wg *sync.WaitGroup, allowedTypes if (len(allowedTypes) == 0) && (len(allowedIndexes) == 0) { // no allowed types == all types are allowed + wg.Add(1) go func() { - wg.Add(1) defer wg.Done() + defer close(returnChan) for e := range ev { returnChan <- e } - close(returnChan) }() } else { + wg.Add(1) go func() { - wg.Add(1) defer wg.Done() + defer close(returnChan) for e := range ev { retEv := satel.Event{BasicEvents: make([]satel.BasicEventElement, 0)} @@ -51,7 +52,6 @@ func FilterByTypeOrIndex(ev <-chan satel.Event, wg *sync.WaitGroup, allowedTypes returnChan <- retEv } } - close(returnChan) }() } @@ -61,9 +61,10 @@ func FilterByTypeOrIndex(ev <-chan satel.Event, wg *sync.WaitGroup, allowedTypes func FilterByLastSeen(ev <-chan satel.Event, wg *sync.WaitGroup, dataStore *DataStore, logger *log.Logger) <-chan satel.Event { returnChan := make(chan satel.Event) + wg.Add(1) go func() { - wg.Add(1) defer wg.Done() + defer close(returnChan) for e := range ev { retEv := satel.Event{BasicEvents: make([]satel.BasicEventElement, 0)} @@ -81,7 +82,6 @@ func FilterByLastSeen(ev <-chan satel.Event, wg *sync.WaitGroup, dataStore *Data } } logger.Print("Satel disconnected.") - close(returnChan) }() return returnChan @@ -111,8 +111,11 @@ func Throttle(inputEvents <-chan GenericMessage, wg *sync.WaitGroup, sleeper Sle returnChan := make(chan GenericMessage) timeoutEvents := make(chan interface{}) + wg.Add(1) go func() { - wg.Add(1) + defer wg.Done() + defer close(returnChan) + var currentEvent *GenericMessage = nil loop: for { @@ -137,8 +140,6 @@ func Throttle(inputEvents <-chan GenericMessage, wg *sync.WaitGroup, sleeper Sle if currentEvent != nil { returnChan <- *currentEvent } - close(returnChan) - wg.Done() }() return returnChan diff --git a/filters_test.go b/filters_test.go index 978cbc4..2cc1751 100644 --- a/filters_test.go +++ b/filters_test.go @@ -16,12 +16,13 @@ func TestSatelEventTypeFiltering(t *testing.T) { receivedEvents := make([]satel.Event, 0) wg := sync.WaitGroup{} + wg.Add(1) go func() { - wg.Add(1) + defer wg.Done() + for e := range FilterByTypeOrIndex(testEvents, &wg, []SatelChangeType{{satel.ArmedPartition}, {satel.PartitionFireAlarm}}, []int{}) { receivedEvents = append(receivedEvents, e) } - wg.Done() }() testEvents <- makeTestSatelEvent(satel.ArmedPartition, 1, true) @@ -44,12 +45,13 @@ func TestSatelEventTypeFiltering_NoAllowedEventTypesMeansAllAreAllowed(t *testin receivedEvents := make([]satel.Event, 0) wg := sync.WaitGroup{} + wg.Add(1) go func() { - wg.Add(1) + defer wg.Done() + for e := range FilterByTypeOrIndex(testEvents, &wg, []SatelChangeType{}, []int{}) { receivedEvents = append(receivedEvents, e) } - wg.Done() }() for index, ct := range SUPPORTED_CHANGE_TYPES { @@ -70,12 +72,13 @@ func TestSatelIndexFiltering(t *testing.T) { receivedEvents := make([]satel.Event, 0) wg := sync.WaitGroup{} + wg.Add(1) go func() { - wg.Add(1) + defer wg.Done() + for e := range FilterByTypeOrIndex(testEvents, &wg, []SatelChangeType{}, []int{1, 3}) { receivedEvents = append(receivedEvents, e) } - wg.Done() }() testEvents <- makeTestSatelEvent(satel.ArmedPartition, 1, true) @@ -99,12 +102,13 @@ func TestSatelIndexFiltering_NoAllowedEventTypesMeansAllAreAllowed(t *testing.T) wg := sync.WaitGroup{} myReasonableMaxIndex := 100 // I wanted to use math.MaxInt at first, but it's kind of a waste of time here + wg.Add(1) go func() { - wg.Add(1) + defer wg.Done() + for e := range FilterByTypeOrIndex(testEvents, &wg, []SatelChangeType{}, []int{}) { receivedEvents = append(receivedEvents, e) } - wg.Done() }() for i := 0; i < myReasonableMaxIndex; i++ { @@ -132,12 +136,13 @@ func TestSatelLastSeenFiltering(t *testing.T) { fakeLog := log.New(io.Discard, "", log.Ltime) ds := MakeDataStore(fakeLog, tempFileName) + wg.Add(1) go func() { - wg.Add(1) + defer wg.Done() + for e := range FilterByLastSeen(testEvents, &wg, &ds, fakeLog) { receivedEvents = append(receivedEvents, e) } - wg.Done() }() testEvents <- makeTestSatelEvent(satel.ArmedPartition, 1, true) @@ -168,12 +173,13 @@ func TestSatelLastSeenFilteringWithPersistence(t *testing.T) { fakeLog := log.New(io.Discard, "", log.Ltime) ds := MakeDataStore(fakeLog, tempFileName) + wg.Add(1) go func() { - wg.Add(1) + defer wg.Done() + for e := range FilterByLastSeen(testEvents, &wg, &ds, fakeLog) { receivedEvents = append(receivedEvents, e) } - wg.Done() }() testEvents <- makeTestSatelEvent(satel.ArmedPartition, 1, true) @@ -194,12 +200,13 @@ func TestSatelLastSeenFilteringWithPersistence(t *testing.T) { testEvents = make(chan satel.Event) receivedEvents = make([]satel.Event, 0) ds = MakeDataStore(fakeLog, tempFileName) + wg.Add(1) go func() { - wg.Add(1) + defer wg.Done() + for e := range FilterByLastSeen(testEvents, &wg, &ds, fakeLog) { receivedEvents = append(receivedEvents, e) } - wg.Done() }() receivedEvents = make([]satel.Event, 0) @@ -243,12 +250,13 @@ func TestThrottle(t *testing.T) { tplMessageTest4 = satel.BasicEventElement{Type: satel.ZoneViolation, Index: 2, Value: false} ) + wg.Add(1) go func() { - wg.Add(1) + defer wg.Done() + for e := range Throttle(testEvents, &wg, &mockSleeper, fakeLog) { receivedEvents = append(receivedEvents, e) } - wg.Done() }() testEvents <- GenericMessage{[]satel.BasicEventElement{tplMessageTest1}} @@ -295,12 +303,12 @@ func TestThrottle_ManyMessagesInOneEvent(t *testing.T) { tplMessageTest4 = satel.BasicEventElement{Type: satel.ZoneViolation, Index: 2, Value: false} ) + wg.Add(1) go func() { - wg.Add(1) + defer wg.Done() for e := range Throttle(testEvents, &wg, &mockSleeper, fakeLog) { receivedEvents = append(receivedEvents, e) } - wg.Done() }() testEvents <- makeMassiveEvent(tplMessageTest1, 100) diff --git a/sender_worker.go b/sender_worker.go index db940e2..0578b85 100644 --- a/sender_worker.go +++ b/sender_worker.go @@ -31,9 +31,11 @@ func Consume(events <-chan GenericMessage) { func SendToTg(events <-chan GenericMessage, s Sender, wg *sync.WaitGroup, logger *log.Logger, tpl *template.Template) <-chan GenericMessage { returnEvents := make(chan GenericMessage) + wg.Add(1) go func() { - wg.Add(1) defer wg.Done() + defer close(returnEvents) + for e := range events { returnEvents <- e err := s.Send(e, tpl) @@ -42,7 +44,6 @@ func SendToTg(events <-chan GenericMessage, s Sender, wg *sync.WaitGroup, logger panic(err) } } - close(returnEvents) }() return returnEvents @@ -74,9 +75,11 @@ func notifyAllHttp(urls []string, logger *log.Logger, wg *sync.WaitGroup) { func NotifyViaHTTP(events <-chan GenericMessage, config AppConfig, wg *sync.WaitGroup, logger *log.Logger) <-chan GenericMessage { returnEvents := make(chan GenericMessage) + wg.Add(1) go func() { - wg.Add(1) defer wg.Done() + defer close(returnEvents) + for e := range events { returnEvents <- e inner_arm: @@ -101,7 +104,6 @@ func NotifyViaHTTP(events <-chan GenericMessage, config AppConfig, wg *sync.Wait } } - close(returnEvents) }() return returnEvents