Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions workers/entity/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ func NewController(ctx context.Context, store dbCommon.Store, providers map[stri
ctx: ctx,
store: store,
providers: providers,
Entities: make(map[string]*Worker),
}, nil
}

Expand All @@ -53,7 +52,8 @@ type Controller struct {
store dbCommon.Store

providers map[string]common.Provider
Entities map[string]*Worker
// sync.Map[string]*Worker
Entities sync.Map

running bool
quit chan struct{}
Expand All @@ -62,8 +62,6 @@ type Controller struct {
}

func (c *Controller) loadAllRepositories() error {
c.mux.Lock()
defer c.mux.Unlock()
repos, err := c.store.ListRepositories(c.ctx, params.RepositoryFilter{})
if err != nil {
return fmt.Errorf("fetching repositories: %w", err)
Expand All @@ -83,7 +81,7 @@ func (c *Controller) loadAllRepositories() error {
if err := worker.Start(); err != nil {
return fmt.Errorf("starting worker: %w", err)
}
c.Entities[entity.ID] = worker
c.Entities.Store(entity.ID, worker)
return nil
})
}
Expand All @@ -94,8 +92,6 @@ func (c *Controller) loadAllRepositories() error {
}

func (c *Controller) loadAllOrganizations() error {
c.mux.Lock()
defer c.mux.Unlock()
orgs, err := c.store.ListOrganizations(c.ctx, params.OrganizationFilter{})
if err != nil {
return fmt.Errorf("fetching organizations: %w", err)
Expand All @@ -115,7 +111,7 @@ func (c *Controller) loadAllOrganizations() error {
if err := worker.Start(); err != nil {
return fmt.Errorf("starting worker: %w", err)
}
c.Entities[entity.ID] = worker
c.Entities.Store(entity.ID, worker)
return nil
})
}
Expand All @@ -126,8 +122,6 @@ func (c *Controller) loadAllOrganizations() error {
}

func (c *Controller) loadAllEnterprises() error {
c.mux.Lock()
defer c.mux.Unlock()
enterprises, err := c.store.ListEnterprises(c.ctx, params.EnterpriseFilter{})
if err != nil {
return fmt.Errorf("fetching enterprises: %w", err)
Expand All @@ -148,7 +142,7 @@ func (c *Controller) loadAllEnterprises() error {
if err := worker.Start(); err != nil {
return fmt.Errorf("starting worker: %w", err)
}
c.Entities[entity.ID] = worker
c.Entities.Store(entity.ID, worker)
return nil
})
}
Expand Down Expand Up @@ -220,11 +214,14 @@ func (c *Controller) Stop() error {
}
slog.DebugContext(c.ctx, "stopping entity controller")

for entityID, worker := range c.Entities {
c.Entities.Range(func(key, value any) bool {
entityID := key.(string)
worker := value.(*Worker)
if err := worker.Stop(); err != nil {
slog.ErrorContext(c.ctx, "stopping worker for entity", "entity_id", entityID, "error", err)
}
}
return true
})

c.running = false
close(c.quit)
Expand Down
19 changes: 6 additions & 13 deletions workers/entity/controller_watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,12 @@ func (c *Controller) handleWatcherEvent(event dbCommon.ChangePayload) {
}

func (c *Controller) handleWatcherUpdateOperation(entity params.ForgeEntity) {
c.mux.Lock()
defer c.mux.Unlock()

worker, ok := c.Entities[entity.ID]
val, ok := c.Entities.Load(entity.ID)
if !ok {
slog.InfoContext(c.ctx, "entity not found in worker list", "entity_id", entity.ID)
return
}
worker := val.(*Worker)

if worker.IsRunning() {
// The worker is running. It watches for updates to its own entity. We only care about updates
Expand All @@ -104,9 +102,6 @@ func (c *Controller) handleWatcherUpdateOperation(entity params.ForgeEntity) {
}

func (c *Controller) handleWatcherCreateOperation(entity params.ForgeEntity) {
c.mux.Lock()
defer c.mux.Unlock()

worker, err := NewWorker(c.ctx, c.store, entity, c.providers)
if err != nil {
slog.ErrorContext(c.ctx, "creating worker from repository", "entity_type", entity.EntityType, "error", err)
Expand All @@ -119,22 +114,20 @@ func (c *Controller) handleWatcherCreateOperation(entity params.ForgeEntity) {
return
}

c.Entities[entity.ID] = worker
c.Entities.Store(entity.ID, worker)
}

// handleWatcherDeleteOperation stops the worker registered under entity.ID and
// removes it from c.Entities. A missing entry is logged at info level and
// otherwise ignored. If worker.Stop returns an error, the error is logged and
// the entry is left in c.Entities.
//
// NOTE(review): this span is rendered from a diff without +/- markers, so the
// mutex/plain-map statements and their sync.Map counterparts appear side by
// side. Presumably only the sync.Map variants remain after the change; confirm
// against the merged file.
func (c *Controller) handleWatcherDeleteOperation(entity params.ForgeEntity) {
c.mux.Lock()
defer c.mux.Unlock()

worker, ok := c.Entities[entity.ID]
val, ok := c.Entities.Load(entity.ID)
if !ok {
slog.InfoContext(c.ctx, "entity not found in worker list", "entity_id", entity.ID)
return
}
// Unchecked type assertion: this panics if the stored value is not a *Worker.
worker := val.(*Worker)
slog.InfoContext(c.ctx, "stopping entity worker", "entity_id", entity.ID, "entity_type", entity.EntityType)
if err := worker.Stop(); err != nil {
slog.ErrorContext(c.ctx, "stopping worker", "entity_id", entity.ID, "error", err)
// Returning here skips the removal below, so the entry stays registered.
return
}
delete(c.Entities, entity.ID)
c.Entities.Delete(entity.ID)
}
19 changes: 11 additions & 8 deletions workers/provider/instance_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"

runnerErrors "github.com/cloudbase/garm-provider-common/errors"
Expand Down Expand Up @@ -71,7 +72,7 @@ type instanceManager struct {

updates chan dbCommon.ChangePayload
mux sync.Mutex
running bool
running atomic.Bool
quit chan struct{}
}

Expand All @@ -80,11 +81,11 @@ func (i *instanceManager) Start() error {
defer i.mux.Unlock()

slog.DebugContext(i.ctx, "starting instance manager", "instance", i.instance.Name)
if i.running {
if i.running.Load() {
return nil
}

i.running = true
i.running.Store(true)
i.quit = make(chan struct{})
i.updates = make(chan dbCommon.ChangePayload)

Expand All @@ -97,11 +98,11 @@ func (i *instanceManager) Stop() error {
i.mux.Lock()
defer i.mux.Unlock()

if !i.running {
if !i.running.Load() {
return nil
}

i.running = false
i.running.Store(false)
close(i.quit)
close(i.updates)
return nil
Expand Down Expand Up @@ -292,7 +293,7 @@ func (i *instanceManager) consolidateState() error {
i.mux.Lock()
defer i.mux.Unlock()

if !i.running {
if !i.running.Load() {
return nil
}

Expand Down Expand Up @@ -365,7 +366,7 @@ func (i *instanceManager) handleUpdate(update dbCommon.ChangePayload) error {
// We need a better way to handle instance state. Database updates may fail, and we
// end up with an inconsistent state between what we know about the instance and what
// is reflected in the database.
if !i.running {
if !i.running.Load() {
return nil
}

Expand All @@ -374,12 +375,14 @@ func (i *instanceManager) handleUpdate(update dbCommon.ChangePayload) error {
return runnerErrors.NewBadRequestError("invalid payload type")
}

i.mux.Lock()
i.instance = instance
i.mux.Unlock()
return nil
}

func (i *instanceManager) Update(instance dbCommon.ChangePayload) error {
if !i.running {
if !i.running.Load() {
return runnerErrors.NewBadRequestError("instance manager is not running")
}

Expand Down
54 changes: 29 additions & 25 deletions workers/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ func NewWorker(ctx context.Context, store dbCommon.Store, providers map[string]c
consumerID: consumerID,
providers: providers,
tokenGetter: tokenGetter,
scaleSets: make(map[uint]params.ScaleSet),
runners: make(map[string]*instanceManager),
}, nil
}

Expand All @@ -62,8 +60,10 @@ type Provider struct {
providers map[string]common.Provider
// A cache of all scale sets kept updated by the watcher.
// This helps us avoid a bunch of queries to the database.
scaleSets map[uint]params.ScaleSet
runners map[string]*instanceManager
// sync.Map[uint]params.ScaleSet
scaleSets sync.Map
// sync.Map[string]*instanceManager
runners sync.Map

mux sync.Mutex
running bool
Expand All @@ -77,7 +77,7 @@ func (p *Provider) loadAllScaleSets() error {
}

for _, scaleSet := range scaleSets {
p.scaleSets[scaleSet.ID] = scaleSet
p.scaleSets.Store(scaleSet.ID, scaleSet)
}

return nil
Expand Down Expand Up @@ -111,11 +111,12 @@ func (p *Provider) loadAllRunners() error {
continue
}

scaleSet, ok := p.scaleSets[runner.ScaleSetID]
val, ok := p.scaleSets.Load(runner.ScaleSetID)
if !ok {
slog.ErrorContext(p.ctx, "scale set not found", "scale_set_id", runner.ScaleSetID)
continue
}
scaleSet := val.(params.ScaleSet)
provider, ok := p.providers[scaleSet.ProviderName]
if !ok {
slog.ErrorContext(p.ctx, "provider not found", "provider_name", runner.ProviderName)
Expand All @@ -130,7 +131,7 @@ func (p *Provider) loadAllRunners() error {
return fmt.Errorf("starting instance manager: %w", err)
}

p.runners[runner.Name] = instanceManager
p.runners.Store(runner.Name, instanceManager)
}

return nil
Expand Down Expand Up @@ -211,9 +212,6 @@ func (p *Provider) handleWatcherEvent(payload dbCommon.ChangePayload) {
}

func (p *Provider) handleScaleSetEvent(event dbCommon.ChangePayload) {
p.mux.Lock()
defer p.mux.Unlock()

scaleSet, ok := event.Payload.(params.ScaleSet)
if !ok {
slog.ErrorContext(p.ctx, "invalid payload type", "payload_type", fmt.Sprintf("%T", event.Payload))
Expand All @@ -223,21 +221,22 @@ func (p *Provider) handleScaleSetEvent(event dbCommon.ChangePayload) {
switch event.Operation {
case dbCommon.CreateOperation, dbCommon.UpdateOperation:
slog.DebugContext(p.ctx, "got create/update operation")
p.scaleSets[scaleSet.ID] = scaleSet
p.scaleSets.Store(scaleSet.ID, scaleSet)
case dbCommon.DeleteOperation:
slog.DebugContext(p.ctx, "got delete operation")
delete(p.scaleSets, scaleSet.ID)
p.scaleSets.Delete(scaleSet.ID)
default:
slog.ErrorContext(p.ctx, "invalid operation type", "operation_type", event.Operation)
return
}
}

func (p *Provider) handleInstanceAdded(instance params.Instance) error {
scaleSet, ok := p.scaleSets[instance.ScaleSetID]
val, ok := p.scaleSets.Load(instance.ScaleSetID)
if !ok {
return fmt.Errorf("scale set not found for instance %s", instance.Name)
}
scaleSet := val.(params.ScaleSet)
instanceManager, err := newInstanceManager(
p.ctx, instance, scaleSet, p.providers[instance.ProviderName], p)
if err != nil {
Expand All @@ -246,28 +245,27 @@ func (p *Provider) handleInstanceAdded(instance params.Instance) error {
if err := instanceManager.Start(); err != nil {
return fmt.Errorf("starting instance manager: %w", err)
}
p.runners[instance.Name] = instanceManager
p.runners.Store(instance.Name, instanceManager)
return nil
}

// stopAndDeleteInstance stops the instance manager stored under instance.Name
// and removes it from p.runners. It returns nil without doing anything unless
// instance.Status is commonParams.InstanceDeleted, and also when no manager is
// stored for that name. If Stop fails, the wrapped error is returned and the
// entry is not removed.
//
// NOTE(review): this span is rendered from a diff without +/- markers, so the
// plain-map statements and their sync.Map counterparts appear side by side.
// Presumably only the sync.Map variants remain after the change; confirm
// against the merged file.
func (p *Provider) stopAndDeleteInstance(instance params.Instance) error {
if instance.Status != commonParams.InstanceDeleted {
return nil
}
existingInstance, ok := p.runners[instance.Name]
if ok {
if err := existingInstance.Stop(); err != nil {
return fmt.Errorf("failed to stop instance manager: %w", err)
}
delete(p.runners, instance.Name)
// NOTE(review): Load, Stop and Delete are separate calls and nothing in this
// function makes the sequence atomic. Confirm whether callers serialize
// instance events.
val, ok := p.runners.Load(instance.Name)
if !ok {
return nil
}
// Unchecked type assertion: this panics if the stored value is not an *instanceManager.
existingInstance := val.(*instanceManager)
if err := existingInstance.Stop(); err != nil {
return fmt.Errorf("failed to stop instance manager: %w", err)
}
p.runners.Delete(instance.Name)
return nil
}

func (p *Provider) handleInstanceEvent(event dbCommon.ChangePayload) {
p.mux.Lock()
defer p.mux.Unlock()

instance, ok := event.Payload.(params.Instance)
if !ok {
slog.ErrorContext(p.ctx, "invalid payload type", "payload_type", fmt.Sprintf("%T", event.Payload))
Expand All @@ -289,19 +287,25 @@ func (p *Provider) handleInstanceEvent(event dbCommon.ChangePayload) {
}
case dbCommon.UpdateOperation:
slog.DebugContext(p.ctx, "got update operation")
existingInstance, ok := p.runners[instance.Name]

val, ok := p.runners.Load(instance.Name)

if !ok {
if instance.Status == commonParams.InstanceDeleted {
// No manager running for this instance and it's already deleted.
return
}
slog.DebugContext(p.ctx, "instance not found, creating new instance", "instance_name", instance.Name)
if err := p.handleInstanceAdded(instance); err != nil {
slog.ErrorContext(p.ctx, "failed to handle instance added", "error", err)
return
}
} else {
existingInstance := val.(*instanceManager)
slog.DebugContext(p.ctx, "updating instance", "instance_name", instance.Name)
if instance.Status == commonParams.InstanceDeleted {
if err := p.stopAndDeleteInstance(instance); err != nil {
slog.ErrorContext(p.ctx, "failed to clean up instance manager", "error", err)
return
}
return
}
Expand Down
8 changes: 1 addition & 7 deletions workers/provider/provider_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,6 @@ func (p *Provider) updateArgsFromProviderInstance(instanceName string, providerI
}

func (p *Provider) GetControllerInfo() (params.ControllerInfo, error) {
p.mux.Lock()
defer p.mux.Unlock()

info, err := p.store.ControllerInfo()
if err != nil {
return params.ControllerInfo{}, fmt.Errorf("getting controller info: %w", err)
Expand All @@ -60,10 +57,7 @@ func (p *Provider) GetControllerInfo() (params.ControllerInfo, error) {
}

func (p *Provider) SetInstanceStatus(instanceName string, status commonParams.InstanceStatus, providerFault []byte, force bool) error {
p.mux.Lock()
defer p.mux.Unlock()

if _, ok := p.runners[instanceName]; !ok {
if _, ok := p.runners.Load(instanceName); !ok {
return errors.ErrNotFound
}

Expand Down
Loading
Loading