diff --git a/provider/cloudflare/cloudflare.go b/provider/cloudflare/cloudflare.go index af827daa1c..94aefa151e 100644 --- a/provider/cloudflare/cloudflare.go +++ b/provider/cloudflare/cloudflare.go @@ -81,15 +81,52 @@ type DNSRecordIndex struct { Content string } +func newDNSRecordIndex(r cloudflare.DNSRecord) DNSRecordIndex { + return DNSRecordIndex{Name: r.Name, Type: r.Type, Content: r.Content} +} + type DNSRecordsMap map[DNSRecordIndex]cloudflare.DNSRecord -// for faster getCustomHostname() lookup +func (m DNSRecordsMap) GetRecordID(record cloudflare.DNSRecord) string { + return m[newDNSRecordIndex(record)].ID +} + +func (m DNSRecordsMap) Set(record cloudflare.DNSRecord) { + m[newDNSRecordIndex(record)] = record +} + +func (m DNSRecordsMap) Delete(record cloudflare.DNSRecord) { + delete(m, newDNSRecordIndex(record)) +} + type CustomHostnameIndex struct { Hostname string } +func newCustomHostnameIndex(ch cloudflare.CustomHostname) CustomHostnameIndex { + return CustomHostnameIndex{Hostname: ch.Hostname} +} + type CustomHostnamesMap map[CustomHostnameIndex]cloudflare.CustomHostname +func (m CustomHostnamesMap) Get(hostname string) (cloudflare.CustomHostname, error) { + if hostname == "" { + return cloudflare.CustomHostname{}, fmt.Errorf("failed to get custom hostname: %q is empty", hostname) + } + if ch, ok := m[CustomHostnameIndex{Hostname: hostname}]; ok { + return ch, nil + } + return cloudflare.CustomHostname{}, fmt.Errorf("failed to get custom hostname: %q not found", hostname) +} + +func (m CustomHostnamesMap) Set(ch cloudflare.CustomHostname) { + m[newCustomHostnameIndex(ch)] = ch +} + +func (m CustomHostnamesMap) Delete(ch cloudflare.CustomHostname) { + delete(m, newCustomHostnameIndex(ch)) +} + var recordTypeProxyNotSupported = map[string]bool{ "LOC": true, "MX": true, @@ -120,7 +157,7 @@ type cloudFlareDNS interface { ListDNSRecords(ctx context.Context, rc *cloudflare.ResourceContainer, rp cloudflare.ListDNSRecordsParams) ([]cloudflare.DNSRecord, *cloudflare.ResultInfo, error) CreateDNSRecord(ctx context.Context, rc *cloudflare.ResourceContainer, rp cloudflare.CreateDNSRecordParams) (cloudflare.DNSRecord, error) DeleteDNSRecord(ctx context.Context, rc *cloudflare.ResourceContainer, recordID string) error - UpdateDNSRecord(ctx context.Context, rc *cloudflare.ResourceContainer, rp cloudflare.UpdateDNSRecordParams) error + UpdateDNSRecord(ctx context.Context, rc *cloudflare.ResourceContainer, rp cloudflare.UpdateDNSRecordParams) (cloudflare.DNSRecord, error) CreateDataLocalizationRegionalHostname(ctx context.Context, rc *cloudflare.ResourceContainer, rp cloudflare.CreateDataLocalizationRegionalHostnameParams) error UpdateDataLocalizationRegionalHostname(ctx context.Context, rc *cloudflare.ResourceContainer, rp cloudflare.UpdateDataLocalizationRegionalHostnameParams) error DeleteDataLocalizationRegionalHostname(ctx context.Context, rc *cloudflare.ResourceContainer, hostname string) error @@ -153,9 +190,9 @@ func (z zoneService) ListDNSRecords(ctx context.Context, rc *cloudflare.Resource return z.service.ListDNSRecords(ctx, rc, rp) } -func (z zoneService) UpdateDNSRecord(ctx context.Context, rc *cloudflare.ResourceContainer, rp cloudflare.UpdateDNSRecordParams) error { - _, err := z.service.UpdateDNSRecord(ctx, rc, rp) - return err +func (z zoneService) UpdateDNSRecord(ctx context.Context, rc *cloudflare.ResourceContainer, rp cloudflare.UpdateDNSRecordParams) (cloudflare.DNSRecord, error) { + return z.service.UpdateDNSRecord(ctx, rc, rp) + } func (z zoneService) DeleteDNSRecord(ctx context.Context, rc *cloudflare.ResourceContainer, recordID string) error { @@ -434,6 +471,58 @@ func (p *CloudFlareProvider) ApplyChanges(ctx context.Context, changes *plan.Cha return p.submitChanges(ctx, cloudflareChanges) } +// createDNSRecord creates a DNS record in the specified zone and updates the DNSRecordsMap. +func (p *CloudFlareProvider) createDNSRecord(ctx context.Context, resourceContainer *cloudflare.ResourceContainer, change *cloudFlareChange, records DNSRecordsMap) error { + record, err := p.Client.CreateDNSRecord(ctx, resourceContainer, getCreateDNSRecordParam(*change)) + if err != nil { + return err + } + records.Set(record) + return nil +} + +// updateDNSRecord updates a DNS record in the specified zone and updates the DNSRecordsMap. +func (p *CloudFlareProvider) updateDNSRecord(ctx context.Context, resourceContainer *cloudflare.ResourceContainer, recordId string, change *cloudFlareChange, records DNSRecordsMap) error { + params := updateDNSRecordParam(*change) + params.ID = recordId + record, err := p.Client.UpdateDNSRecord(ctx, resourceContainer, params) + if err != nil { + return err + } + records.Set(record) + return nil +} + +// deleteDNSRecord deletes a DNS record in the specified zone and updates the DNSRecordsMap. +func (p *CloudFlareProvider) deleteDNSRecord(ctx context.Context, resourceContainer *cloudflare.ResourceContainer, recordID string, change *cloudFlareChange, records DNSRecordsMap) error { + err := p.Client.DeleteDNSRecord(ctx, resourceContainer, recordID) + if err != nil { + return err + } + records.Delete(change.ResourceRecord) + return nil +} + +// createCustomHostname creates a custom hostname in the specified zone and updates the CustomHostnamesMap. +func (p *CloudFlareProvider) createCustomHostname(ctx context.Context, zoneID string, customHostname cloudflare.CustomHostname, customHostnames CustomHostnamesMap) error { + resp, err := p.Client.CreateCustomHostname(ctx, zoneID, customHostname) + if err != nil { + return err + } + customHostnames.Set(resp.Result) + return nil +} + +// deleteCustomHostname deletes a custom hostname in the specified zone and updates the CustomHostnamesMap. +func (p *CloudFlareProvider) deleteCustomHostname(ctx context.Context, zoneID string, customHostname cloudflare.CustomHostname, customHostnames CustomHostnamesMap) error { + err := p.Client.DeleteCustomHostname(ctx, zoneID, customHostname.ID) + if err != nil { + return err + } + customHostnames.Delete(customHostname) + return nil +} + // submitCustomHostnameChanges implements Custom Hostname functionality for the Change, returns false if it fails func (p *CloudFlareProvider) submitCustomHostnameChanges(ctx context.Context, zoneID string, change *cloudFlareChange, chs CustomHostnamesMap, logFields log.Fields) bool { failedChange := false @@ -448,11 +537,11 @@ func (p *CloudFlareProvider) submitCustomHostnameChanges(ctx context.Context, zo add, remove, _ := provider.Difference(change.CustomHostnamesPrev, slices.Collect(maps.Keys(change.CustomHostnames))) for _, changeCH := range remove { - if prevCh, err := getCustomHostname(chs, changeCH); err == nil { + if prevCh, err := chs.Get(changeCH); err == nil { prevChID := prevCh.ID if prevChID != "" { log.WithFields(logFields).Infof("Removing previous custom hostname %q/%q", prevChID, changeCH) - chErr := p.Client.DeleteCustomHostname(ctx, zoneID, prevChID) + chErr := p.deleteCustomHostname(ctx, zoneID, prevCh, chs) if chErr != nil { failedChange = true log.WithFields(logFields).Errorf("failed to remove previous custom hostname %q/%q: %v", prevChID, changeCH, chErr) @@ -462,7 +551,7 @@ func (p *CloudFlareProvider) submitCustomHostnameChanges(ctx context.Context, zo } for _, changeCH := range add { log.WithFields(logFields).Infof("Adding custom hostname %q", changeCH) - _, chErr := p.Client.CreateCustomHostname(ctx, zoneID, change.CustomHostnames[changeCH]) + chErr := p.createCustomHostname(ctx, zoneID, change.CustomHostnames[changeCH], chs) if chErr != nil { failedChange = true log.WithFields(logFields).Errorf("failed to add custom hostname %q: %v", changeCH, chErr) @@ -473,9 +562,9 @@ func (p *CloudFlareProvider) submitCustomHostnameChanges(ctx context.Context, zo for _, changeCH := range change.CustomHostnames { if recordTypeCustomHostnameSupported[change.ResourceRecord.Type] && changeCH.Hostname != "" { log.WithFields(logFields).Infof("Deleting custom hostname %q", changeCH.Hostname) - if ch, err := getCustomHostname(chs, changeCH.Hostname); err == nil { + if ch, err := chs.Get(changeCH.Hostname); err == nil { chID := ch.ID - chErr := p.Client.DeleteCustomHostname(ctx, zoneID, chID) + chErr := p.createCustomHostname(ctx, zoneID, ch, chs) if chErr != nil { failedChange = true log.WithFields(logFields).Errorf("failed to delete custom hostname %q/%q: %v", chID, changeCH.Hostname, chErr) @@ -489,7 +578,7 @@ func (p *CloudFlareProvider) submitCustomHostnameChanges(ctx context.Context, zo for _, changeCH := range change.CustomHostnames { if recordTypeCustomHostnameSupported[change.ResourceRecord.Type] && changeCH.Hostname != "" { log.WithFields(logFields).Infof("Creating custom hostname %q", changeCH.Hostname) - if ch, err := getCustomHostname(chs, changeCH.Hostname); err == nil { + if ch, err := chs.Get(changeCH.Hostname); err == nil { if changeCH.CustomOriginServer == ch.CustomOriginServer { log.WithFields(logFields).Warnf("custom hostname %q already exists with the same origin %q, continue", changeCH.Hostname, ch.CustomOriginServer) } else { @@ -497,7 +586,7 @@ func (p *CloudFlareProvider) submitCustomHostnameChanges(ctx context.Context, zo log.WithFields(logFields).Errorf("failed to create custom hostname, %q already exists with origin %q", changeCH.Hostname, ch.CustomOriginServer) } } else { - _, chErr := p.Client.CreateCustomHostname(ctx, zoneID, changeCH) + chErr := p.createCustomHostname(ctx, zoneID, changeCH, chs) if chErr != nil { failedChange = true log.WithFields(logFields).Errorf("failed to create custom hostname %q: %v", changeCH.Hostname, chErr) @@ -527,8 +616,22 @@ func (p *CloudFlareProvider) submitChanges(ctx context.Context, changes []*cloud var failedZones []string for zoneID, zoneChanges := range changesByZone { var failedChange bool + var records DNSRecordsMap + var chs CustomHostnamesMap resourceContainer := cloudflare.ZoneIdentifier(zoneID) + if !p.DryRun { + records, err = p.listDNSRecordsWithAutoPagination(ctx, zoneID) + if err != nil { + return fmt.Errorf("could not fetch records from zone, %w", err) + } + + chs, err = p.listCustomHostnamesWithPagination(ctx, zoneID) + if err != nil { + return fmt.Errorf("could not fetch custom hostnames from zone, %v", err) + } + } + for _, change := range zoneChanges { logFields := log.Fields{ "record": change.ResourceRecord.Name, @@ -544,37 +647,28 @@ func (p *CloudFlareProvider) submitChanges(ctx context.Context, changes []*cloud continue } - records, err := p.listDNSRecordsWithAutoPagination(ctx, zoneID) - if err != nil { - return fmt.Errorf("could not fetch records from zone, %w", err) - } - chs, chErr := p.listCustomHostnamesWithPagination(ctx, zoneID) - if chErr != nil { - return fmt.Errorf("could not fetch custom hostnames from zone, %w", chErr) - } if change.Action == cloudFlareUpdate { if !p.submitCustomHostnameChanges(ctx, zoneID, change, chs, logFields) { failedChange = true } - recordID := p.getRecordID(records, change.ResourceRecord) + recordID := records.GetRecordID(change.ResourceRecord) if recordID == "" { log.WithFields(logFields).Errorf("failed to find previous record: %v", change.ResourceRecord) continue } - recordParam := updateDNSRecordParam(*change) - recordParam.ID = recordID - err := p.Client.UpdateDNSRecord(ctx, resourceContainer, recordParam) + err := p.updateDNSRecord(ctx, resourceContainer, recordID, change, records) if err != nil { failedChange = true log.WithFields(logFields).Errorf("failed to update record: %v", err) } } else if change.Action == cloudFlareDelete { - recordID := p.getRecordID(records, change.ResourceRecord) + recordID := records.GetRecordID(change.ResourceRecord) if recordID == "" { log.WithFields(logFields).Errorf("failed to find previous record: %v", change.ResourceRecord) continue } - err := p.Client.DeleteDNSRecord(ctx, resourceContainer, recordID) + + err := p.deleteDNSRecord(ctx, resourceContainer, recordID, change, records) if err != nil { failedChange = true log.WithFields(logFields).Errorf("failed to delete record: %v", err) @@ -583,8 +677,7 @@ func (p *CloudFlareProvider) submitChanges(ctx context.Context, changes []*cloud failedChange = true } } else if change.Action == cloudFlareCreate { - recordParam := getCreateDNSRecordParam(*change) - _, err := p.Client.CreateDNSRecord(ctx, resourceContainer, recordParam) + err := p.createDNSRecord(ctx, resourceContainer, change, records) if err != nil { failedChange = true log.WithFields(logFields).Errorf("failed to create record: %v", err) @@ -667,23 +760,6 @@ func (p *CloudFlareProvider) changesByZone(zones []cloudflare.Zone, changeSet [] return changes } -func (p *CloudFlareProvider) getRecordID(records DNSRecordsMap, record cloudflare.DNSRecord) string { - if zoneRecord, ok := records[DNSRecordIndex{Name: record.Name, Type: record.Type, Content: record.Content}]; ok { - return zoneRecord.ID - } - return "" -} - -func getCustomHostname(chs CustomHostnamesMap, chName string) (cloudflare.CustomHostname, error) { - if chName == "" { - return cloudflare.CustomHostname{}, fmt.Errorf("failed to get custom hostname: %q is empty", chName) - } - if ch, ok := chs[CustomHostnameIndex{Hostname: chName}]; ok { - return ch, nil - } - return cloudflare.CustomHostname{}, fmt.Errorf("failed to get custom hostname: %q not found", chName) -} - func (p *CloudFlareProvider) newCustomHostname(customHostname string, origin string) cloudflare.CustomHostname { return cloudflare.CustomHostname{ Hostname: customHostname, @@ -740,10 +816,6 @@ func (p *CloudFlareProvider) newCloudFlareChange(action changeAction, ep *endpoi } } -func newDNSRecordIndex(r cloudflare.DNSRecord) DNSRecordIndex { - return DNSRecordIndex{Name: r.Name, Type: r.Type, Content: r.Content} -} - // listDNSRecordsWithAutoPagination performs automatic pagination of results on requests to cloudflare.ListDNSRecords with custom per_page values func (p *CloudFlareProvider) listDNSRecordsWithAutoPagination(ctx context.Context, zoneID string) (DNSRecordsMap, error) { // for faster getRecordID lookup @@ -764,7 +836,7 @@ func (p *CloudFlareProvider) listDNSRecordsWithAutoPagination(ctx context.Contex } for _, r := range pageRecords { - records[newDNSRecordIndex(r)] = r + records.Set(r) } params.ResultInfo = resultInfo.Next() if params.Done() { @@ -774,10 +846,6 @@ func (p *CloudFlareProvider) listDNSRecordsWithAutoPagination(ctx context.Contex return records, nil } -func newCustomHostnameIndex(ch cloudflare.CustomHostname) CustomHostnameIndex { - return CustomHostnameIndex{Hostname: ch.Hostname} -} - // listCustomHostnamesWithPagination performs automatic pagination of results on requests to cloudflare.CustomHostnames func (p *CloudFlareProvider) listCustomHostnamesWithPagination(ctx context.Context, zoneID string) (CustomHostnamesMap, error) { if !p.CustomHostnamesConfig.Enabled { diff --git a/provider/cloudflare/cloudflare_test.go b/provider/cloudflare/cloudflare_test.go index 97c20bb4d8..b1455a06ec 100644 --- a/provider/cloudflare/cloudflare_test.go +++ b/provider/cloudflare/cloudflare_test.go @@ -210,7 +210,7 @@ func (m *mockCloudFlareClient) ListDNSRecords(ctx context.Context, rc *cloudflar }, nil } -func (m *mockCloudFlareClient) UpdateDNSRecord(ctx context.Context, rc *cloudflare.ResourceContainer, rp cloudflare.UpdateDNSRecordParams) error { +func (m *mockCloudFlareClient) UpdateDNSRecord(ctx context.Context, rc *cloudflare.ResourceContainer, rp cloudflare.UpdateDNSRecordParams) (cloudflare.DNSRecord, error) { recordData := getDNSRecordFromRecordParams(rp) m.Actions = append(m.Actions, MockAction{ Name: "Update", @@ -221,12 +221,12 @@ func (m *mockCloudFlareClient) UpdateDNSRecord(ctx context.Context, rc *cloudfla if zone, ok := m.Records[rc.Identifier]; ok { if _, ok := zone[rp.ID]; ok { if strings.HasPrefix(recordData.Name, "newerror-update-") { - return errors.New("failed to update erroring DNS record") + return cloudflare.DNSRecord{}, errors.New("failed to update erroring DNS record") } zone[rp.ID] = recordData } } - return nil + return recordData, nil } func (m *mockCloudFlareClient) CreateDataLocalizationRegionalHostname(ctx context.Context, rc *cloudflare.ResourceContainer, rp cloudflare.CreateDataLocalizationRegionalHostnameParams) error { @@ -345,7 +345,12 @@ func (m *mockCloudFlareClient) CreateCustomHostname(ctx context.Context, zoneID var newCustomHostname cloudflare.CustomHostname = ch newCustomHostname.ID = fmt.Sprintf("ID-%s", ch.Hostname) m.customHostnames[zoneID] = append(m.customHostnames[zoneID], newCustomHostname) - return &cloudflare.CustomHostnameResponse{}, nil + return &cloudflare.CustomHostnameResponse{ + Result: newCustomHostname, + Response: cloudflare.Response{ + Success: true, + }, + }, nil } func (m *mockCloudFlareClient) DeleteCustomHostname(ctx context.Context, zoneID string, customHostnameID string) error { @@ -1224,7 +1229,6 @@ func TestCloudflareApplyChangesError(t *testing.T) { } func TestCloudflareGetRecordID(t *testing.T) { - p := &CloudFlareProvider{} recordsMap := DNSRecordsMap{ {Name: "foo.com", Type: endpoint.RecordTypeCNAME, Content: "foobar"}: { Name: "foo.com", @@ -1245,29 +1249,29 @@ func TestCloudflareGetRecordID(t *testing.T) { }, } - assert.Empty(t, p.getRecordID(recordsMap, cloudflare.DNSRecord{ + assert.Empty(t, recordsMap.GetRecordID(cloudflare.DNSRecord{ Name: "foo.com", Type: endpoint.RecordTypeA, Content: "foobar", })) - assert.Empty(t, p.getRecordID(recordsMap, cloudflare.DNSRecord{ + assert.Empty(t, recordsMap.GetRecordID(cloudflare.DNSRecord{ Name: "foo.com", Type: endpoint.RecordTypeCNAME, Content: "fizfuz", })) - assert.Equal(t, "1", p.getRecordID(recordsMap, cloudflare.DNSRecord{ + assert.Equal(t, "1", recordsMap.GetRecordID(cloudflare.DNSRecord{ Name: "foo.com", Type: endpoint.RecordTypeCNAME, Content: "foobar", })) - assert.Empty(t, p.getRecordID(recordsMap, cloudflare.DNSRecord{ + assert.Empty(t, recordsMap.GetRecordID(cloudflare.DNSRecord{ Name: "bar.de", Type: endpoint.RecordTypeA, Content: "2.3.4.5", })) - assert.Equal(t, "2", p.getRecordID(recordsMap, cloudflare.DNSRecord{ + assert.Equal(t, "2", recordsMap.GetRecordID(cloudflare.DNSRecord{ Name: "bar.de", Type: endpoint.RecordTypeA, Content: "1.2.3.4", @@ -1507,7 +1511,7 @@ func TestCloudflareGroupByNameAndType(t *testing.T) { for _, tc := range testCases { records := make(DNSRecordsMap) for _, r := range tc.Records { - records[newDNSRecordIndex(r)] = r + records.Set(r) } endpoints := groupByNameAndTypeWithCustomHostnames(records, CustomHostnamesMap{}) // Targets order could be random with underlying map @@ -2898,7 +2902,7 @@ func TestCloudflareCustomHostnameNotFoundOnRecordDeletion(t *testing.T) { t.Error(e) } if tc.preApplyHook == "corrupt" { - if ch, err := getCustomHostname(chs, "newerror-getCustomHostnameOrigin.foo.fancybar.com"); errors.Is(err, nil) { + if ch, err := chs.Get("newerror-getCustomHostnameOrigin.foo.fancybar.com"); errors.Is(err, nil) { chID := ch.ID t.Logf("corrupting custom hostname %q", chID) oldIdx := getCustomHostnameIdxByID(client.customHostnames[zoneID], chID)