diff --git a/config.go b/config.go index 557d858..b14f93a 100644 --- a/config.go +++ b/config.go @@ -10,17 +10,80 @@ import ( "time" "git.sr.ht/~michalr/go-satel" + "gopkg.in/yaml.v3" ) +const ( + ConfigFilePath = "hswro-alarm-bot.yml" +) + +type OwnDuration struct { + duration time.Duration +} + +type SatelChangeType struct { + changeType satel.ChangeType +} + type AppConfig struct { - SatelAddr string - ChatIds []int64 - AllowedTypes []satel.ChangeType - AllowedIndexes []int - PoolInterval time.Duration - ArmCallbackUrls []string - DisarmCallbackUrls []string - AlarmCallbackUrls []string + SatelAddr string `yaml:"satel-addr"` + ChatIds []int64 `yaml:"tg-chat-ids"` + AllowedTypes []SatelChangeType `yaml:"allowed-types"` + AllowedIndexes []int `yaml:"allowed-indexes"` + PoolInterval OwnDuration `yaml:"pool-interval"` + ArmCallbackUrls []string `yaml:"arm-callback-urls"` + DisarmCallbackUrls []string `yaml:"disarm-callback-urls"` + AlarmCallbackUrls []string `yaml:"alarm-callback-urls"` +} + +func (m *SatelChangeType) UnmarshalYAML(unmarshal func(interface{}) error) error { + var inputStr string + err := unmarshal(&inputStr) + if err != nil { + return err + } + ct, err := StringToSatelChangeType(inputStr) + if err != nil { + return err + } + *m = SatelChangeType{ct} + return nil +} + +func (m *OwnDuration) UnmarshalYAML(unmarshal func(interface{}) error) error { + var inputStr string + err := unmarshal(&inputStr) + if err != nil { + return err + } + duration, err := time.ParseDuration(inputStr) + if err != nil { + return err + } + *m = OwnDuration{duration} + return nil +} + +func (self SatelChangeType) GetChangeType() satel.ChangeType { return self.changeType } +func (self OwnDuration) GetDuration() time.Duration { return self.duration } + +func loadConfigFromFile(filePath string, logger *log.Logger) AppConfig { + f, err := os.ReadFile(filePath) + if err != nil { + logger.Print("Error opening config file: ", err, ". Trying to continue without it") + return AppConfig{} + } + return parseConfigFromFile(f, logger) +} + +func parseConfigFromFile(contents []byte, logger *log.Logger) AppConfig { + var config AppConfig + err := yaml.Unmarshal(contents, &config) + if err != nil { + logger.Print("Error while parsing config file: ", err, ". Trying to continue without it") + return AppConfig{} + } + return config } func getCmdLineParams(config *AppConfig, logger *log.Logger) { @@ -45,7 +108,7 @@ func getCmdLineParams(config *AppConfig, logger *log.Logger) { chatIds = append(chatIds, chatId) } allowedTypesStrings := strings.Split(*allowedTypesRaw, ",") - var allowedTypes []satel.ChangeType + var allowedTypes []SatelChangeType for _, allowedTypeStr := range allowedTypesStrings { if len(allowedTypeStr) == 0 { continue @@ -54,7 +117,7 @@ func getCmdLineParams(config *AppConfig, logger *log.Logger) { if err != nil { logger.Fatalf("Error trying to understand an allowed type: %s.", err) } - allowedTypes = append(allowedTypes, allowedType) + allowedTypes = append(allowedTypes, SatelChangeType{allowedType}) } allowedIndexesStrings := strings.Split(*allowedIndexesRaw, ",") var allowedIndexes []int @@ -75,11 +138,11 @@ func getCmdLineParams(config *AppConfig, logger *log.Logger) { config.ChatIds = chatIds config.AllowedTypes = allowedTypes config.AllowedIndexes = allowedIndexes - config.PoolInterval = *satelPoolInterval + config.PoolInterval = OwnDuration{*satelPoolInterval} } func MakeConfig(logger *log.Logger) AppConfig { - config := AppConfig{} + config := loadConfigFromFile(ConfigFilePath, logger) config.ArmCallbackUrls = []string{} config.DisarmCallbackUrls = []string{} config.AlarmCallbackUrls = []string{} diff --git a/config_test.go b/config_test.go new file mode 100644 index 0000000..5007a99 --- /dev/null +++ b/config_test.go @@ -0,0 +1,50 @@ +package main + +import ( + "log" + "os" + "testing" + "time" + + "git.sr.ht/~michalr/go-satel" + "github.com/stretchr/testify/assert" +) + +const data = ` +satel-addr: "test satel address" +tg-chat-ids: + - 1234 + - 5678 + - 9876 +allowed-types: + - "zone-isolate" + - "zone-alarm" +allowed-indexes: + - 5678 + - 1337 +pool-interval: 5m +arm-callback-urls: + - "test arm callback url" + - "second test arm callback url" +disarm-callback-urls: + - "test disarm callback url" + - "second test disarm callback url" +alarm-callback-urls: + - "test alarm callback url" + - "second test alarm callback url" +` + +func TestParseYamlConfig(t *testing.T) { + a := assert.New(t) + + actualConfig := parseConfigFromFile([]byte(data), log.New(os.Stderr, "", log.Ltime)) + + a.Equal("test satel address", actualConfig.SatelAddr) + a.ElementsMatch([]int64{1234, 5678, 9876}, actualConfig.ChatIds) + a.ElementsMatch([]int{5678, 1337}, actualConfig.AllowedIndexes) + a.ElementsMatch([]SatelChangeType{{satel.ZoneIsolate}, {satel.ZoneAlarm}}, actualConfig.AllowedTypes) + a.Equal(5*time.Minute, actualConfig.PoolInterval.GetDuration()) + a.ElementsMatch([]string{"test arm callback url", "second test arm callback url"}, actualConfig.ArmCallbackUrls) + a.ElementsMatch([]string{"test disarm callback url", "second test disarm callback url"}, actualConfig.DisarmCallbackUrls) + a.ElementsMatch([]string{"test alarm callback url", "second test alarm callback url"}, actualConfig.AlarmCallbackUrls) +} diff --git a/filters.go b/filters.go index a262852..4d2afe6 100644 --- a/filters.go +++ b/filters.go @@ -7,9 +7,9 @@ import ( "git.sr.ht/~michalr/go-satel" ) -func isBasicEventElementOkay(basicEventElement satel.BasicEventElement, allowedTypes []satel.ChangeType, allowedIndexes []int) bool { +func isBasicEventElementOkay(basicEventElement satel.BasicEventElement, allowedTypes []SatelChangeType, allowedIndexes []int) bool { for _, allowedType := range allowedTypes { - if allowedType == basicEventElement.Type { + if allowedType.GetChangeType() == basicEventElement.Type { return true } } @@ -21,7 +21,7 @@ func isBasicEventElementOkay(basicEventElement satel.BasicEventElement, allowedT return false } -func FilterByTypeOrIndex(ev <-chan satel.Event, wg *sync.WaitGroup, allowedTypes []satel.ChangeType, allowedIndexes []int) <-chan satel.Event { +func FilterByTypeOrIndex(ev <-chan satel.Event, wg *sync.WaitGroup, allowedTypes []SatelChangeType, allowedIndexes []int) <-chan satel.Event { returnChan := make(chan satel.Event) if (len(allowedTypes) == 0) && (len(allowedIndexes) == 0) { diff --git a/filters_test.go b/filters_test.go index 8580c0c..978cbc4 100644 --- a/filters_test.go +++ b/filters_test.go @@ -18,7 +18,7 @@ func TestSatelEventTypeFiltering(t *testing.T) { go func() { wg.Add(1) - for e := range FilterByTypeOrIndex(testEvents, &wg, []satel.ChangeType{satel.ArmedPartition, satel.PartitionFireAlarm}, []int{}) { + for e := range FilterByTypeOrIndex(testEvents, &wg, []SatelChangeType{{satel.ArmedPartition}, {satel.PartitionFireAlarm}}, []int{}) { receivedEvents = append(receivedEvents, e) } wg.Done() @@ -46,7 +46,7 @@ func TestSatelEventTypeFiltering_NoAllowedEventTypesMeansAllAreAllowed(t *testin go func() { wg.Add(1) - for e := range FilterByTypeOrIndex(testEvents, &wg, []satel.ChangeType{}, []int{}) { + for e := range FilterByTypeOrIndex(testEvents, &wg, []SatelChangeType{}, []int{}) { receivedEvents = append(receivedEvents, e) } wg.Done() @@ -72,7 +72,7 @@ func TestSatelIndexFiltering(t *testing.T) { go func() { wg.Add(1) - for e := range FilterByTypeOrIndex(testEvents, &wg, []satel.ChangeType{}, []int{1, 3}) { + for e := range FilterByTypeOrIndex(testEvents, &wg, []SatelChangeType{}, []int{1, 3}) { receivedEvents = append(receivedEvents, e) } wg.Done() @@ -101,7 +101,7 @@ func TestSatelIndexFiltering_NoAllowedEventTypesMeansAllAreAllowed(t *testing.T) go func() { wg.Add(1) - for e := range FilterByTypeOrIndex(testEvents, &wg, []satel.ChangeType{}, []int{}) { + for e := range FilterByTypeOrIndex(testEvents, &wg, []SatelChangeType{}, []int{}) { receivedEvents = append(receivedEvents, e) } wg.Done() diff --git a/main.go b/main.go index 9003652..4cdee58 100644 --- a/main.go +++ b/main.go @@ -58,7 +58,7 @@ func main() { stopRequested.Store(false) config := MakeConfig(logger) - s := makeSatel(config.SatelAddr, config.PoolInterval) + s := makeSatel(config.SatelAddr, config.PoolInterval.GetDuration()) logger.Printf("Connected to Satel: %s", config.SatelAddr) bot, err := tgbotapi.NewBotAPI(os.Getenv("TELEGRAM_APITOKEN"))