Skip to content

Commit

Permalink
refactor(agent): ♻️ simplify MQTT worker handling
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuar committed Jan 11, 2025
1 parent 9055ce0 commit 7791538
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 51 deletions.
29 changes: 8 additions & 21 deletions internal/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,6 @@ func newAgent(ctx context.Context) *Agent {

// Run is invoked when Go Hass Agent is run with the `run` command-line option
// (i.e., `go-hass-agent run`).
//
//nolint:funlen
//revive:disable:function-length
func Run(ctx context.Context, dataCh chan any) error {
var (
wg sync.WaitGroup
Expand Down Expand Up @@ -111,7 +108,8 @@ func Run(ctx context.Context, dataCh chan any) error {
// Initialize and add the script worker.
scriptsWorkers, err := scripts.NewScriptsWorker(ctx)
if err != nil {
agent.logger.Warn("Could not init scripts workers.", slog.Any("error", err))
agent.logger.Warn("Could not init scripts workers.",
slog.Any("error", err))
} else {
sensorWorkers = append(sensorWorkers, scriptsWorkers)
}
Expand All @@ -130,23 +128,12 @@ func Run(ctx context.Context, dataCh chan any) error {
processWorkers(ctx, dataCh, eventWorkers...)
}()

// If MQTT is enabled, init MQTT workers and process them.
if preferences.MQTTEnabled() {
if mqttPrefs, err := preferences.GetMQTTPreferences(); err != nil {
agent.logger.Warn("Could not init mqtt workers.",
slog.Any("error", err))
} else {
ctx = MQTTPrefsToCtx(ctx, mqttPrefs)
mqttWorkers := setupMQTT(ctx)

wg.Add(1)

go func() {
defer wg.Done()
processMQTTWorkers(ctx, mqttWorkers...)
}()
}
}
wg.Add(1)
// Process MQTT workers.
go func() {
defer wg.Done()
processMQTTWorkers(ctx)
}()

wg.Add(1)
// Listen for notifications from Home Assistant.
Expand Down
85 changes: 55 additions & 30 deletions internal/agent/mqtt.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,24 @@ type mqttEntities struct {
cameras []*mqtthass.CameraEntity
}

// setupMQTT will create a slice of MQTTWorker from the custom commands
// configuration and any OS-specific MQTT workers.
func setupMQTT(ctx context.Context) []MQTTWorker {
var workers []MQTTWorker

// Create an MQTT device, used to configure MQTT functionality for some
// controllers.
// setupMQTT will load a context with MQTT preferences and device configuration.
func setupMQTT(ctx context.Context) (context.Context, error) {
// Get the MQTT preferences.
prefs, err := preferences.GetMQTTPreferences()
if err != nil {
return ctx, fmt.Errorf("could not get MQTT preferences: %w", err)
}
// Add MQTT preferences to context.
ctx = MQTTPrefsToCtx(ctx, prefs)
// Get MQTT device and add to context.
ctx = MQTTDeviceToCtx(ctx, preferences.GetMQTTDevice())

return ctx, nil
}

// createMQTTWorkers creates the MQTT workers.
func createMQTTWorkers(ctx context.Context) []MQTTWorker {
var workers []MQTTWorker
// Set up custom MQTT commands worker.
customCommandsWorker, err := commands.NewCommandsWorker(ctx, MQTTDeviceFromFromCtx(ctx))
if err != nil {
Expand All @@ -78,36 +87,47 @@ func setupMQTT(ctx context.Context) []MQTTWorker {
} else {
workers = append(workers, customCommandsWorker)
}

osWorker := setupOSMQTTWorker(ctx)
workers = append(workers, osWorker)
// Set up OS MQTT worker.
workers = append(workers, setupOSMQTTWorker(ctx))

return workers
}

// processMQTTWorkers will connect to MQTT, publish configs and subscriptions and
// listen for any messages from all MQTT workers defined by the passed in
// MQTT controllers.
func processMQTTWorkers(ctx context.Context, controllers ...MQTTWorker) {
// listen for any messages from all MQTT workers.
func processMQTTWorkers(ctx context.Context) {
var ( //nolint:prealloc
subscriptions []*mqttapi.Subscription
configs []*mqttapi.Msg
msgCh []<-chan *mqttapi.Msg
err error
)

// Add the subscriptions and configs from the controllers.
for _, controller := range controllers {
subscriptions = append(subscriptions, controller.Subscriptions()...)
configs = append(configs, controller.Configs()...)
msgCh = append(msgCh, controller.Msgs())
if !preferences.MQTTEnabled() {
return
}

// Create a new connection to the MQTT broker. This will also publish the
// device subscriptions.
// Get the MQTT preferences and device.
ctx, err = setupMQTT(ctx)
if err != nil {
logging.FromContext(ctx).Error("Could not set-up MQTT.",
slog.Any("error", err))
return
}
// Create the workers.
workers := createMQTTWorkers(ctx)
// Add the subscriptions and configs from the workers.
for _, worker := range workers {
subscriptions = append(subscriptions, worker.Subscriptions()...)
configs = append(configs, worker.Configs()...)
msgCh = append(msgCh, worker.Msgs())
}
// Create a new connection to the MQTT broker, publish subscriptions and
// configs.
client, err := mqttapi.NewClient(ctx, MQTTPrefsFromFromCtx(ctx), subscriptions, configs)
if err != nil {
logging.FromContext(ctx).Error("Could not connect to MQTT.", slog.Any("error", err))
logging.FromContext(ctx).Error("Could not connect to MQTT.",
slog.Any("error", err))
return
}

Expand All @@ -128,21 +148,26 @@ func processMQTTWorkers(ctx context.Context, controllers ...MQTTWorker) {
}
}

// resetMQTTWorkers will unpublish configs for all defined MQTTWorkers.
// resetMQTTWorkers will unpublish configs for all defined MQTT workers.
func resetMQTTWorkers(ctx context.Context) error {
var configs []*mqttapi.Msg
var (
configs []*mqttapi.Msg
err error
)

workers := setupMQTT(ctx)
for _, worker := range workers {
configs = append(configs, worker.Configs()...)
// Get the MQTT preferences and device.
ctx, err = setupMQTT(ctx)
if err != nil {
return errors.New("could not reset MQTT: set-up failed")
}
// Create the workers.
workers := createMQTTWorkers(ctx)

mqttPrefs, err := preferences.GetMQTTPreferences()
if err != nil {
return fmt.Errorf("could reset MQTT: %w", err)
for _, worker := range workers {
configs = append(configs, worker.Configs()...)
}

client, err := mqttapi.NewClient(ctx, mqttPrefs, nil, nil)
client, err := mqttapi.NewClient(ctx, MQTTPrefsFromFromCtx(ctx), nil, nil)
if err != nil {
return fmt.Errorf("could not connect to MQTT: %w", err)
}
Expand Down

0 comments on commit 7791538

Please sign in to comment.