diff --git a/workers/entity/controller.go b/workers/entity/controller.go index 3ad521086..5110fc0cb 100644 --- a/workers/entity/controller.go +++ b/workers/entity/controller.go @@ -41,7 +41,6 @@ func NewController(ctx context.Context, store dbCommon.Store, providers map[stri ctx: ctx, store: store, providers: providers, - Entities: make(map[string]*Worker), }, nil } @@ -53,7 +52,8 @@ type Controller struct { store dbCommon.Store providers map[string]common.Provider - Entities map[string]*Worker + // sync.Map[string]*Worker + Entities sync.Map running bool quit chan struct{} @@ -62,8 +62,6 @@ type Controller struct { } func (c *Controller) loadAllRepositories() error { - c.mux.Lock() - defer c.mux.Unlock() repos, err := c.store.ListRepositories(c.ctx, params.RepositoryFilter{}) if err != nil { return fmt.Errorf("fetching repositories: %w", err) @@ -83,7 +81,7 @@ func (c *Controller) loadAllRepositories() error { if err := worker.Start(); err != nil { return fmt.Errorf("starting worker: %w", err) } - c.Entities[entity.ID] = worker + c.Entities.Store(entity.ID, worker) return nil }) } @@ -94,8 +92,6 @@ func (c *Controller) loadAllRepositories() error { } func (c *Controller) loadAllOrganizations() error { - c.mux.Lock() - defer c.mux.Unlock() orgs, err := c.store.ListOrganizations(c.ctx, params.OrganizationFilter{}) if err != nil { return fmt.Errorf("fetching organizations: %w", err) @@ -115,7 +111,7 @@ func (c *Controller) loadAllOrganizations() error { if err := worker.Start(); err != nil { return fmt.Errorf("starting worker: %w", err) } - c.Entities[entity.ID] = worker + c.Entities.Store(entity.ID, worker) return nil }) } @@ -126,8 +122,6 @@ func (c *Controller) loadAllOrganizations() error { } func (c *Controller) loadAllEnterprises() error { - c.mux.Lock() - defer c.mux.Unlock() enterprises, err := c.store.ListEnterprises(c.ctx, params.EnterpriseFilter{}) if err != nil { return fmt.Errorf("fetching enterprises: %w", err) @@ -148,7 +142,7 @@ func (c *Controller) loadAllEnterprises() error { if err := worker.Start(); err != nil { return fmt.Errorf("starting worker: %w", err) } - c.Entities[entity.ID] = worker + c.Entities.Store(entity.ID, worker) return nil }) } @@ -220,11 +214,14 @@ func (c *Controller) Stop() error { } slog.DebugContext(c.ctx, "stopping entity controller") - for entityID, worker := range c.Entities { + c.Entities.Range(func(key, value any) bool { + entityID := key.(string) + worker := value.(*Worker) if err := worker.Stop(); err != nil { slog.ErrorContext(c.ctx, "stopping worker for entity", "entity_id", entityID, "error", err) } - } + return true + }) c.running = false close(c.quit) diff --git a/workers/entity/controller_watcher.go b/workers/entity/controller_watcher.go index d907d25a7..3b5802c75 100644 --- a/workers/entity/controller_watcher.go +++ b/workers/entity/controller_watcher.go @@ -76,14 +76,12 @@ func (c *Controller) handleWatcherEvent(event dbCommon.ChangePayload) { } func (c *Controller) handleWatcherUpdateOperation(entity params.ForgeEntity) { - c.mux.Lock() - defer c.mux.Unlock() - - worker, ok := c.Entities[entity.ID] + val, ok := c.Entities.Load(entity.ID) if !ok { slog.InfoContext(c.ctx, "entity not found in worker list", "entity_id", entity.ID) return } + worker := val.(*Worker) if worker.IsRunning() { // The worker is running. It watches for updates to its own entity. We only care about updates @@ -104,9 +102,6 @@ func (c *Controller) handleWatcherUpdateOperation(entity params.ForgeEntity) { } func (c *Controller) handleWatcherCreateOperation(entity params.ForgeEntity) { - c.mux.Lock() - defer c.mux.Unlock() - worker, err := NewWorker(c.ctx, c.store, entity, c.providers) if err != nil { slog.ErrorContext(c.ctx, "creating worker from repository", "entity_type", entity.EntityType, "error", err) @@ -119,22 +114,20 @@ func (c *Controller) handleWatcherCreateOperation(entity params.ForgeEntity) { return } - c.Entities[entity.ID] = worker + c.Entities.Store(entity.ID, worker) } func (c *Controller) handleWatcherDeleteOperation(entity params.ForgeEntity) { - c.mux.Lock() - defer c.mux.Unlock() - - worker, ok := c.Entities[entity.ID] + val, ok := c.Entities.Load(entity.ID) if !ok { slog.InfoContext(c.ctx, "entity not found in worker list", "entity_id", entity.ID) return } + worker := val.(*Worker) slog.InfoContext(c.ctx, "stopping entity worker", "entity_id", entity.ID, "entity_type", entity.EntityType) if err := worker.Stop(); err != nil { slog.ErrorContext(c.ctx, "stopping worker", "entity_id", entity.ID, "error", err) return } - delete(c.Entities, entity.ID) + c.Entities.Delete(entity.ID) } diff --git a/workers/provider/instance_manager.go b/workers/provider/instance_manager.go index 2c2a5b748..fd001c6ab 100644 --- a/workers/provider/instance_manager.go +++ b/workers/provider/instance_manager.go @@ -19,6 +19,7 @@ import ( "fmt" "log/slog" "sync" + "sync/atomic" "time" runnerErrors "github.com/cloudbase/garm-provider-common/errors" @@ -71,7 +72,7 @@ type instanceManager struct { updates chan dbCommon.ChangePayload mux sync.Mutex - running bool + running atomic.Bool quit chan struct{} } @@ -80,11 +81,11 @@ func (i *instanceManager) Start() error { defer i.mux.Unlock() slog.DebugContext(i.ctx, "starting instance manager", "instance", i.instance.Name) - if i.running { + if i.running.Load() { return nil } - i.running = true + i.running.Store(true) i.quit = make(chan struct{}) i.updates = make(chan dbCommon.ChangePayload) @@ -97,11 +98,11 @@ func (i *instanceManager) Stop() error { i.mux.Lock() defer i.mux.Unlock() - if !i.running { + if !i.running.Load() { return nil } - i.running = false + i.running.Store(false) close(i.quit) close(i.updates) return nil @@ -292,7 +293,7 @@ func (i *instanceManager) consolidateState() error { i.mux.Lock() defer i.mux.Unlock() - if !i.running { + if !i.running.Load() { return nil } @@ -365,7 +366,7 @@ func (i *instanceManager) handleUpdate(update dbCommon.ChangePayload) error { // We need a better way to handle instance state. Database updates may fail, and we // end up with an inconsistent state between what we know about the instance and what // is reflected in the database. - if !i.running { + if !i.running.Load() { return nil } @@ -374,12 +375,14 @@ func (i *instanceManager) handleUpdate(update dbCommon.ChangePayload) error { return runnerErrors.NewBadRequestError("invalid payload type") } + i.mux.Lock() i.instance = instance + i.mux.Unlock() return nil } func (i *instanceManager) Update(instance dbCommon.ChangePayload) error { - if !i.running { + if !i.running.Load() { return runnerErrors.NewBadRequestError("instance manager is not running") } diff --git a/workers/provider/provider.go b/workers/provider/provider.go index 124775acc..8c6e0a1be 100644 --- a/workers/provider/provider.go +++ b/workers/provider/provider.go @@ -41,8 +41,6 @@ func NewWorker(ctx context.Context, store dbCommon.Store, providers map[string]c consumerID: consumerID, providers: providers, tokenGetter: tokenGetter, - scaleSets: make(map[uint]params.ScaleSet), - runners: make(map[string]*instanceManager), }, nil } @@ -62,8 +60,10 @@ type Provider struct { providers map[string]common.Provider // A cache of all scale sets kept updated by the watcher. // This helps us avoid a bunch of queries to the database. - scaleSets map[uint]params.ScaleSet - runners map[string]*instanceManager + // sync.Map[uint]params.ScaleSet + scaleSets sync.Map + // sync.Map[string]*instanceManager + runners sync.Map mux sync.Mutex running bool @@ -77,7 +77,7 @@ func (p *Provider) loadAllScaleSets() error { } for _, scaleSet := range scaleSets { - p.scaleSets[scaleSet.ID] = scaleSet + p.scaleSets.Store(scaleSet.ID, scaleSet) } return nil @@ -111,11 +111,12 @@ func (p *Provider) loadAllRunners() error { continue } - scaleSet, ok := p.scaleSets[runner.ScaleSetID] + val, ok := p.scaleSets.Load(runner.ScaleSetID) if !ok { slog.ErrorContext(p.ctx, "scale set not found", "scale_set_id", runner.ScaleSetID) continue } + scaleSet := val.(params.ScaleSet) provider, ok := p.providers[scaleSet.ProviderName] if !ok { slog.ErrorContext(p.ctx, "provider not found", "provider_name", runner.ProviderName) @@ -130,7 +131,7 @@ func (p *Provider) loadAllRunners() error { return fmt.Errorf("starting instance manager: %w", err) } - p.runners[runner.Name] = instanceManager + p.runners.Store(runner.Name, instanceManager) } return nil @@ -211,9 +212,6 @@ func (p *Provider) handleWatcherEvent(payload dbCommon.ChangePayload) { } func (p *Provider) handleScaleSetEvent(event dbCommon.ChangePayload) { - p.mux.Lock() - defer p.mux.Unlock() - scaleSet, ok := event.Payload.(params.ScaleSet) if !ok { slog.ErrorContext(p.ctx, "invalid payload type", "payload_type", fmt.Sprintf("%T", event.Payload)) @@ -223,10 +221,10 @@ func (p *Provider) handleScaleSetEvent(event dbCommon.ChangePayload) { switch event.Operation { case dbCommon.CreateOperation, dbCommon.UpdateOperation: slog.DebugContext(p.ctx, "got create/update operation") - p.scaleSets[scaleSet.ID] = scaleSet + p.scaleSets.Store(scaleSet.ID, scaleSet) case dbCommon.DeleteOperation: slog.DebugContext(p.ctx, "got delete operation") - delete(p.scaleSets, scaleSet.ID) + p.scaleSets.Delete(scaleSet.ID) default: slog.ErrorContext(p.ctx, "invalid operation type", "operation_type", event.Operation) return @@ -234,10 +232,11 @@ func (p *Provider) handleScaleSetEvent(event dbCommon.ChangePayload) { } func (p *Provider) handleInstanceAdded(instance params.Instance) error { - scaleSet, ok := p.scaleSets[instance.ScaleSetID] + val, ok := p.scaleSets.Load(instance.ScaleSetID) if !ok { return fmt.Errorf("scale set not found for instance %s", instance.Name) } + scaleSet := val.(params.ScaleSet) instanceManager, err := newInstanceManager( p.ctx, instance, scaleSet, p.providers[instance.ProviderName], p) if err != nil { @@ -246,7 +245,7 @@ func (p *Provider) handleInstanceAdded(instance params.Instance) error { if err := instanceManager.Start(); err != nil { return fmt.Errorf("starting instance manager: %w", err) } - p.runners[instance.Name] = instanceManager + p.runners.Store(instance.Name, instanceManager) return nil } @@ -254,20 +253,19 @@ func (p *Provider) stopAndDeleteInstance(instance params.Instance) error { if instance.Status != commonParams.InstanceDeleted { return nil } - existingInstance, ok := p.runners[instance.Name] - if ok { - if err := existingInstance.Stop(); err != nil { - return fmt.Errorf("failed to stop instance manager: %w", err) - } - delete(p.runners, instance.Name) + val, ok := p.runners.Load(instance.Name) + if !ok { + return nil } + existingInstance := val.(*instanceManager) + if err := existingInstance.Stop(); err != nil { + return fmt.Errorf("failed to stop instance manager: %w", err) + } + p.runners.Delete(instance.Name) return nil } func (p *Provider) handleInstanceEvent(event dbCommon.ChangePayload) { - p.mux.Lock() - defer p.mux.Unlock() - instance, ok := event.Payload.(params.Instance) if !ok { slog.ErrorContext(p.ctx, "invalid payload type", "payload_type", fmt.Sprintf("%T", event.Payload)) @@ -289,19 +287,25 @@ func (p *Provider) handleInstanceEvent(event dbCommon.ChangePayload) { } case dbCommon.UpdateOperation: slog.DebugContext(p.ctx, "got update operation") - existingInstance, ok := p.runners[instance.Name] + + val, ok := p.runners.Load(instance.Name) + if !ok { + if instance.Status == commonParams.InstanceDeleted { + // No manager running for this instance and it's already deleted. + return + } slog.DebugContext(p.ctx, "instance not found, creating new instance", "instance_name", instance.Name) if err := p.handleInstanceAdded(instance); err != nil { slog.ErrorContext(p.ctx, "failed to handle instance added", "error", err) return } } else { + existingInstance := val.(*instanceManager) slog.DebugContext(p.ctx, "updating instance", "instance_name", instance.Name) if instance.Status == commonParams.InstanceDeleted { if err := p.stopAndDeleteInstance(instance); err != nil { slog.ErrorContext(p.ctx, "failed to clean up instance manager", "error", err) - return } return } diff --git a/workers/provider/provider_helper.go b/workers/provider/provider_helper.go index 7ea23d1bd..a5f309e79 100644 --- a/workers/provider/provider_helper.go +++ b/workers/provider/provider_helper.go @@ -48,9 +48,6 @@ func (p *Provider) updateArgsFromProviderInstance(instanceName string, providerI } func (p *Provider) GetControllerInfo() (params.ControllerInfo, error) { - p.mux.Lock() - defer p.mux.Unlock() - info, err := p.store.ControllerInfo() if err != nil { return params.ControllerInfo{}, fmt.Errorf("getting controller info: %w", err) @@ -60,10 +57,7 @@ func (p *Provider) GetControllerInfo() (params.ControllerInfo, error) { } func (p *Provider) SetInstanceStatus(instanceName string, status commonParams.InstanceStatus, providerFault []byte, force bool) error { - p.mux.Lock() - defer p.mux.Unlock() - - if _, ok := p.runners[instanceName]; !ok { + if _, ok := p.runners.Load(instanceName); !ok { return errors.ErrNotFound } diff --git a/workers/scaleset/controller.go b/workers/scaleset/controller.go index 2d44b7585..c8574fc19 100644 --- a/workers/scaleset/controller.go +++ b/workers/scaleset/controller.go @@ -41,7 +41,6 @@ func NewController(ctx context.Context, store dbCommon.Store, entity params.Forg return &Controller{ ctx: ctx, consumerID: consumerID, - ScaleSets: make(map[uint]*scaleSet), Entity: entity, providers: providers, store: store, @@ -55,6 +54,18 @@ type scaleSet struct { mux sync.Mutex } +func (s *scaleSet) SetWorker(worker *Worker) { + s.mux.Lock() + defer s.mux.Unlock() + s.worker = worker +} + +func (s *scaleSet) SetScaleSet(sSet params.ScaleSet) { + s.mux.Lock() + defer s.mux.Unlock() + s.scaleSet = sSet +} + func (s *scaleSet) Stop() error { s.mux.Lock() defer s.mux.Unlock() @@ -71,7 +82,8 @@ type Controller struct { ctx context.Context consumerID string - ScaleSets map[uint]*scaleSet + // sync.Map[uint]*scaleSet + ScaleSets sync.Map Entity params.ForgeEntity @@ -149,13 +161,16 @@ func (c *Controller) Stop() error { } slog.DebugContext(c.ctx, "stopping scaleset controller", "entity", c.Entity.String()) - for scaleSetID, scaleSet := range c.ScaleSets { - if err := scaleSet.Stop(); err != nil { + c.ScaleSets.Range(func(key, value any) bool { + scaleSetID := key.(uint) + set := value.(*scaleSet) + if err := set.Stop(); err != nil { slog.ErrorContext(c.ctx, "stopping worker for scale set", "scale_set_id", scaleSetID, "error", err) - continue + return true } - delete(c.ScaleSets, scaleSetID) - } + c.ScaleSets.Delete(scaleSetID) + return true + }) c.running = false close(c.quit) @@ -170,16 +185,18 @@ func (c *Controller) Stop() error { // runners in either github or the providers. func (c *Controller) ConsolidateRunnerState(byScaleSetID map[int][]params.RunnerReference) error { g, ctx := errgroup.WithContext(c.ctx) - for _, scaleSet := range c.ScaleSets { - runners := byScaleSetID[scaleSet.scaleSet.ScaleSetID] + c.ScaleSets.Range(func(_, value any) bool { + set := value.(*scaleSet) + runners := byScaleSetID[set.scaleSet.ScaleSetID] g.Go(func() error { - slog.DebugContext(ctx, "consolidating runners for scale set", "scale_set_id", scaleSet.scaleSet.ScaleSetID, "runners", runners) - if err := scaleSet.worker.consolidateRunnerState(runners); err != nil { - return fmt.Errorf("consolidating runners for scale set %d: %w", scaleSet.scaleSet.ScaleSetID, err) + slog.DebugContext(ctx, "consolidating runners for scale set", "scale_set_id", set.scaleSet.ScaleSetID, "runners", runners) + if err := set.worker.consolidateRunnerState(runners); err != nil { + return fmt.Errorf("consolidating runners for scale set %d: %w", set.scaleSet.ScaleSetID, err) } return nil }) - } + return true + }) if err := c.waitForErrorGroupOrContextCancelled(g); err != nil { return fmt.Errorf("waiting for error group: %w", err) } diff --git a/workers/scaleset/controller_watcher.go b/workers/scaleset/controller_watcher.go index e3c32ea6f..51c61d82e 100644 --- a/workers/scaleset/controller_watcher.go +++ b/workers/scaleset/controller_watcher.go @@ -82,10 +82,7 @@ func (c *Controller) createScaleSetWorker(scaleSet params.ScaleSet) (*Worker, er } func (c *Controller) handleScaleSetCreateOperation(sSet params.ScaleSet) error { - c.mux.Lock() - defer c.mux.Unlock() - - if _, ok := c.ScaleSets[sSet.ID]; ok { + if _, ok := c.ScaleSets.Load(sSet.ID); ok { slog.DebugContext(c.ctx, "scale set already exists in worker list", "scale_set_id", sSet.ID) return nil } @@ -102,48 +99,44 @@ func (c *Controller) handleScaleSetCreateOperation(sSet params.ScaleSet) error { // can continue to work. return fmt.Errorf("error starting scale set worker: %w", err) } - c.ScaleSets[sSet.ID] = &scaleSet{ + c.ScaleSets.Store(sSet.ID, &scaleSet{ scaleSet: sSet, worker: worker, - } + }) return nil } func (c *Controller) handleScaleSetDeleteOperation(sSet params.ScaleSet) error { - c.mux.Lock() - defer c.mux.Unlock() - - set, ok := c.ScaleSets[sSet.ID] + val, ok := c.ScaleSets.Load(sSet.ID) if !ok { slog.DebugContext(c.ctx, "scale set not found in worker list", "scale_set_id", sSet.ID) return nil } + set := val.(*scaleSet) slog.DebugContext(c.ctx, "stopping scale set worker", "scale_set_id", sSet.ID) if err := set.worker.Stop(); err != nil { return fmt.Errorf("stopping scale set worker: %w", err) } - delete(c.ScaleSets, sSet.ID) + c.ScaleSets.Delete(sSet.ID) return nil } func (c *Controller) handleScaleSetUpdateOperation(sSet params.ScaleSet) error { - c.mux.Lock() - defer c.mux.Unlock() - - set, ok := c.ScaleSets[sSet.ID] + val, ok := c.ScaleSets.Load(sSet.ID) if !ok { // Some error may have occurred when the scale set was first created, so we // attempt to create it after the user updated the scale set, hopefully // fixing the reason for the failure. return c.handleScaleSetCreateOperation(sSet) } + set := val.(*scaleSet) if set.worker != nil && !set.worker.IsRunning() { worker, err := c.createScaleSetWorker(sSet) if err != nil { return fmt.Errorf("creating scale set worker: %w", err) } - set.worker = worker + set.SetWorker(worker) defer func() { if err := worker.Start(); err != nil { slog.ErrorContext(c.ctx, "failed to start worker", "error", err, "scaleset", sSet.Name) @@ -151,8 +144,8 @@ func (c *Controller) handleScaleSetUpdateOperation(sSet params.ScaleSet) error { }() } - set.scaleSet = sSet - c.ScaleSets[sSet.ID] = set + set.SetScaleSet(sSet) + c.ScaleSets.Store(sSet.ID, set) // We let the watcher in the scale set worker handle the update operation. return nil } @@ -183,8 +176,8 @@ func (c *Controller) handleEntityEvent(event dbCommon.ChangePayload) { case dbCommon.UpdateOperation: slog.DebugContext(c.ctx, "got update operation") c.mux.Lock() - defer c.mux.Unlock() c.Entity = entity + c.mux.Unlock() default: slog.ErrorContext(c.ctx, "invalid operation type", "operation_type", event.Operation) return diff --git a/workers/scaleset/scaleset.go b/workers/scaleset/scaleset.go index 915f4ff9a..24779a48e 100644 --- a/workers/scaleset/scaleset.go +++ b/workers/scaleset/scaleset.go @@ -768,14 +768,18 @@ Loop: slog.DebugContext(w.ctx, "listener is stopped; attempting to restart") w.mux.Lock() if !w.scaleSet.Enabled { - w.listener.Stop() // cleanup + if err := w.listener.Stop(); err != nil { + slog.ErrorContext(w.ctx, "failed to stop listener", "error", err) + } w.mux.Unlock() continue Loop } w.mux.Unlock() for { w.mux.Lock() - w.listener.Stop() // cleanup + if err := w.listener.Stop(); err != nil { + slog.ErrorContext(w.ctx, "failed to stop listener", "error", err) + } if !w.scaleSet.Enabled { w.mux.Unlock() continue Loop diff --git a/workers/scaleset/scaleset_listener.go b/workers/scaleset/scaleset_listener.go index d89d3450c..0d634525d 100644 --- a/workers/scaleset/scaleset_listener.go +++ b/workers/scaleset/scaleset_listener.go @@ -19,6 +19,7 @@ import ( "fmt" "log/slog" "sync" + "sync/atomic" runnerErrors "github.com/cloudbase/garm-provider-common/errors" "github.com/cloudbase/garm/params" @@ -58,7 +59,7 @@ type scaleSetListener struct { messageSession *scalesets.MessageSession mux sync.Mutex - running bool + running atomic.Bool quit chan struct{} loopExited chan struct{} } @@ -68,7 +69,7 @@ func (l *scaleSetListener) Start() error { l.mux.Lock() defer l.mux.Unlock() - if l.running { + if l.running.Load() { return nil } @@ -96,7 +97,7 @@ func (l *scaleSetListener) Start() error { } l.messageSession = session l.quit = make(chan struct{}) - l.running = true + l.running.Store(true) l.loopExited = make(chan struct{}) go l.loop() @@ -107,24 +108,25 @@ func (l *scaleSetListener) Stop() error { l.mux.Lock() defer l.mux.Unlock() - if !l.running { + if !l.running.Load() { return nil } - scaleSetClient, err := l.scaleSetHelper.GetScaleSetClient() - if err != nil { - return fmt.Errorf("getting scale set client: %w", err) - } if l.messageSession != nil { slog.DebugContext(l.ctx, "closing message session", "scale_set", l.scaleSetHelper.GetScaleSet().ScaleSetID) if err := l.messageSession.Close(); err != nil { slog.ErrorContext(l.ctx, "closing message session", "error", err) } - if err := scaleSetClient.DeleteMessageSession(context.Background(), l.messageSession); err != nil { - slog.ErrorContext(l.ctx, "error deleting message session", "error", err) + scaleSetClient, err := l.scaleSetHelper.GetScaleSetClient() + if err != nil { + slog.ErrorContext(l.ctx, "error getting scale set client", "error", err) + } else { + if err := scaleSetClient.DeleteMessageSession(context.Background(), l.messageSession); err != nil { + slog.ErrorContext(l.ctx, "error deleting message session", "error", err) + } } } - l.running = false + l.running.Store(false) close(l.quit) l.cancelFunc() return nil @@ -133,7 +135,7 @@ func (l *scaleSetListener) Stop() error { func (l *scaleSetListener) IsRunning() bool { l.mux.Lock() defer l.mux.Unlock() - return l.running + return l.running.Load() } func (l *scaleSetListener) handleSessionMessage(msg params.RunnerScaleSetMessage) { @@ -297,7 +299,7 @@ func (l *scaleSetListener) Wait() <-chan struct{} { l.mux.Lock() defer l.mux.Unlock() - if !l.running { + if !l.running.Load() { slog.DebugContext(l.ctx, "scale set listener is not running") return closed } diff --git a/workers/websocket/agent/agent.go b/workers/websocket/agent/agent.go index 6d04210fd..1d52c2a68 100644 --- a/workers/websocket/agent/agent.go +++ b/workers/websocket/agent/agent.go @@ -8,6 +8,7 @@ import ( "log/slog" "net" "sync" + "sync/atomic" "time" "github.com/google/uuid" @@ -70,7 +71,7 @@ type Agent struct { consumerID string consumer common.Consumer - running bool + running atomic.Bool done chan struct{} shellSessions map[string]*ClientSession @@ -130,14 +131,14 @@ func (a *Agent) Done() <-chan struct{} { } func (a *Agent) IsRunning() bool { - return a.running + return a.running.Load() } func (a *Agent) Start() error { a.mux.Lock() defer a.mux.Unlock() - if a.running { + if a.running.Load() { return nil } @@ -157,7 +158,7 @@ func (a *Agent) Start() error { a.consumer = consumer a.done = make(chan struct{}) - a.running = true + a.running.Store(true) go a.agentReader() go a.loop() return nil @@ -167,7 +168,7 @@ func (a *Agent) Stop() error { a.mux.Lock() defer a.mux.Unlock() - if !a.running { + if !a.running.Load() { return nil } slog.InfoContext(a.ctx, "removing sessions") @@ -176,7 +177,7 @@ func (a *Agent) Stop() error { a.RemoveClientSession(val.sessionID, true) } - a.running = false + a.running.Store(false) slog.InfoContext(a.ctx, "sending websocket close message") a.writeMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) slog.InfoContext(a.ctx, "closing connection")