From 2832746a473e43d7c5b39fced52e1fe9ae8efcbd Mon Sep 17 00:00:00 2001 From: Georg Date: Wed, 21 Jul 2021 17:44:10 +0200 Subject: [PATCH] deSEC implement pagination (#1208) * deSEC: Implement pagination for domain list #1177 * deSEC: add debug logging for pagination * deSEC: simplify get/post methods by allowing url / api endpoints as target * deSEC: implement pagination for getRecords function * deSEC: fix linter warnings * deSEC: replace domainIndexInitalized variable with checking if the domainIndex == nil * deSEC: add mutex for domainIndex Co-authored-by: Tom Limoncelli --- providers/desec/desecProvider.go | 12 +- providers/desec/protocol.go | 185 +++++++++++++++++++++++++------ 2 files changed, 160 insertions(+), 37 deletions(-) diff --git a/providers/desec/desecProvider.go b/providers/desec/desecProvider.go index e59bc7dd6..fa62afe6a 100644 --- a/providers/desec/desecProvider.go +++ b/providers/desec/desecProvider.go @@ -24,13 +24,16 @@ Info required in `creds.json`: func NewDeSec(m map[string]string, metadata json.RawMessage) (providers.DNSServiceProvider, error) { c := &desecProvider{} c.creds.token = m["auth-token"] - c.domainIndex = map[string]uint32{} if c.creds.token == "" { return nil, fmt.Errorf("missing deSEC auth-token") } if err := c.authenticate(); err != nil { return nil, fmt.Errorf("authentication failed") } + //DomainIndex is used for corrections (minttl) and domain creation + if err := c.initializeDomainIndex(); err != nil { + return nil, err + } return c, nil } @@ -81,11 +84,13 @@ func (c *desecProvider) GetDomainCorrections(dc *models.DomainConfig) ([]*models models.PostProcessRecords(existing) clean := PrepFoundRecords(existing) var minTTL uint32 + c.mutex.Lock() if ttl, ok := c.domainIndex[dc.Name]; !ok { minTTL = 3600 } else { minTTL = ttl } + c.mutex.Unlock() PrepDesiredRecords(dc, minTTL) return c.GenerateDomainCorrections(dc, clean) } @@ -108,10 +113,9 @@ func (c *desecProvider) GetZoneRecords(domain string) (models.Records, error) { // EnsureDomainExists returns an error if domain doesn't exist. func (c *desecProvider) EnsureDomainExists(domain string) error { - if err := c.fetchDomain(domain); err != nil { - return err - } // domain already exists + c.mutex.Lock() + defer c.mutex.Unlock() if _, ok := c.domainIndex[domain]; ok { return nil } diff --git a/providers/desec/protocol.go b/providers/desec/protocol.go index 1a580b35d..df1f7413e 100644 --- a/providers/desec/protocol.go +++ b/providers/desec/protocol.go @@ -6,7 +6,10 @@ import ( "fmt" "io/ioutil" "net/http" + "regexp" "strconv" + "strings" + "sync" "time" "github.com/StackExchange/dnscontrol/v3/pkg/printer" @@ -16,14 +19,14 @@ const apiBase = "https://desec.io/api/v1" // Api layer for desec type desecProvider struct { - domainIndex map[string]uint32 //stores the minimum ttl of each domain. (key = domain and value = ttl) - nameserversNames []string - creds struct { + domainIndex map[string]uint32 //stores the minimum ttl of each domain. (key = domain and value = ttl) + creds struct { tokenid string token string user string password string } + mutex sync.Mutex } type domainObject struct { @@ -71,37 +74,141 @@ func (c *desecProvider) authenticate() error { } return nil } +func (c *desecProvider) initializeDomainIndex() error { + c.mutex.Lock() + defer c.mutex.Unlock() + if c.domainIndex != nil { + return nil + } + endpoint := "/domains/" + var bodyString, resp, err = c.get(endpoint, "GET") + if resp.StatusCode == 400 && resp.Header.Get("Link") != "" { + //pagination is required + links := c.convertLinks(resp.Header.Get("Link")) + endpoint = links["first"] + printer.Debugf("initial endpoint %s\n", endpoint) + for endpoint != "" { + bodyString, resp, err = c.get(endpoint, "GET") + if err != nil { + if resp.StatusCode == 404 { + return nil + } + return fmt.Errorf("failed fetching domains: %s", err) + } + err = c.buildIndexFromResponse(bodyString) + if err != nil { + return fmt.Errorf("failed fetching domains: %s", err) + } + links = c.convertLinks(resp.Header.Get("Link")) + endpoint = links["next"] + printer.Debugf("next endpoint %s\n", endpoint) + } + printer.Debugf("Domain Index initilized with pagination (%d domains)\n", len(c.domainIndex)) + return nil //domainIndex was build using pagination without errors + } -func (c *desecProvider) fetchDomain(domain string) error { - endpoint := fmt.Sprintf("/domains/%s", domain) - var dr domainObject - var bodyString, statuscode, err = c.get(endpoint, "GET") - if err != nil { - if statuscode == 404 { + //no pagination required + if err != nil && resp.StatusCode != 400 { + if resp.StatusCode == 404 { return nil } - return fmt.Errorf("Failed fetching domain: %s", err) + return fmt.Errorf("failed fetching domains: %s", err) } - err = json.Unmarshal(bodyString, &dr) + err = c.buildIndexFromResponse(bodyString) + if err == nil { + printer.Debugf("Domain Index initilized without pagination (%d domains)\n", len(c.domainIndex)) + } + return err +} + +//buildIndexFromResponse takes the bodyString from initializeDomainIndex and builds the domainIndex +func (c *desecProvider) buildIndexFromResponse(bodyString []byte) error { + if c.domainIndex == nil { + c.domainIndex = map[string]uint32{} + } + var dr []domainObject + err := json.Unmarshal(bodyString, &dr) if err != nil { return err } - - //deSEC allows different minimum ttls per domain - //we store the actual minimum ttl to use it in desecProvider.go GetDomainCorrections() to enforce the minimum ttl and avoid api errors. - c.domainIndex[dr.Name] = dr.MinimumTTL + for _, domain := range dr { + //deSEC allows different minimum ttls per domain + //we store the actual minimum ttl to use it in desecProvider.go GetDomainCorrections() to enforce the minimum ttl and avoid api errors. + c.domainIndex[domain.Name] = domain.MinimumTTL + } return nil } +//Parses the Link Header into a map (https://github.com/desec-io/desec-tools/blob/master/fetch_zone.py#L13) +func (c *desecProvider) convertLinks(links string) map[string]string { + mapping := make(map[string]string) + printer.Debugf("Header: %s\n", links) + for _, link := range strings.Split(links, ", ") { + tmpurl := strings.Split(link, "; ") + if len(tmpurl) != 2 { + fmt.Printf("unexpected link header %s", link) + continue + } + r := regexp.MustCompile(`rel="(.*)"`) + matches := r.FindStringSubmatch(tmpurl[1]) + if len(matches) != 2 { + fmt.Printf("unexpected label %s", tmpurl[1]) + continue + } + // mapping["$label"] = "$URL" + //URL = https://desec.io/api/v1/domains/{domain}/rrsets/?cursor=:next_cursor + mapping[matches[1]] = strings.TrimSuffix(strings.TrimPrefix(tmpurl[0], "<"), ">") + } + return mapping +} + func (c *desecProvider) getRecords(domain string) ([]resourceRecord, error) { endpoint := "/domains/%s/rrsets/" + var rrsNew []resourceRecord + var bodyString, resp, err = c.get(fmt.Sprintf(endpoint, domain), "GET") + if resp.StatusCode == 400 && resp.Header.Get("Link") != "" { + //pagination required + links := c.convertLinks(resp.Header.Get("Link")) + endpoint = links["first"] + printer.Debugf("getRecords: initial endpoint %s\n", fmt.Sprintf(endpoint, domain)) + for endpoint != "" { + bodyString, resp, err = c.get(endpoint, "GET") + if err != nil { + if resp.StatusCode == 404 { + return rrsNew, nil + } + return rrsNew, fmt.Errorf("getRecords: failed fetching rrsets: %s", err) + } + tmp, err := generateRRSETfromResponse(bodyString) + if err != nil { + return rrsNew, fmt.Errorf("failed fetching records for domain %s (deSEC): %s", domain, err) + } + rrsNew = append(rrsNew, tmp...) + links = c.convertLinks(resp.Header.Get("Link")) + endpoint = links["next"] + printer.Debugf("getRecords: next endpoint %s\n", endpoint) + } + printer.Debugf("Build rrset using pagination (%d rrs)\n", len(rrsNew)) + return rrsNew, nil //domainIndex was build using pagination without errors + } + //no pagination + if err != nil { + return rrsNew, fmt.Errorf("failed fetching records for domain %s (deSEC): %s", domain, err) + } + tmp, err := generateRRSETfromResponse(bodyString) + if err != nil { + return rrsNew, err + } + rrsNew = append(rrsNew, tmp...) + printer.Debugf("Build rrset without pagination (%d rrs)\n", len(rrsNew)) + return rrsNew, nil +} + +//generateRRSETfromResponse takes the response rrset api calls and returns []resourceRecord +func generateRRSETfromResponse(bodyString []byte) ([]resourceRecord, error) { var rrs []rrResponse var rrsNew []resourceRecord - var bodyString, _, err = c.get(fmt.Sprintf(endpoint, domain), "GET") - if err != nil { - return rrsNew, fmt.Errorf("Failed fetching records for domain %s (deSEC): %s", domain, err) - } - err = json.Unmarshal(bodyString, &rrs) + err := json.Unmarshal(bodyString, &rrs) if err != nil { return rrsNew, err } @@ -126,7 +233,7 @@ func (c *desecProvider) createDomain(domain string) error { var resp []byte var err error if resp, err = c.post(endpoint, "POST", byt); err != nil { - return fmt.Errorf("Failed domain create (deSEC): %v", err) + return fmt.Errorf("failed domain create (deSEC): %v", err) } dm := domainObject{} err = json.Unmarshal(resp, &dm) @@ -143,7 +250,7 @@ func (c *desecProvider) upsertRR(rr []resourceRecord, domain string) error { endpoint := fmt.Sprintf("/domains/%s/rrsets/", domain) byt, _ := json.Marshal(rr) if _, err := c.post(endpoint, "PUT", byt); err != nil { - return fmt.Errorf("Failed create RRset (deSEC): %v", err) + return fmt.Errorf("failed create RRset (deSEC): %v", err) } return nil } @@ -151,16 +258,22 @@ func (c *desecProvider) upsertRR(rr []resourceRecord, domain string) error { func (c *desecProvider) deleteRR(domain, shortname, t string) error { endpoint := fmt.Sprintf("/domains/%s/rrsets/%s/%s/", domain, shortname, t) if _, _, err := c.get(endpoint, "DELETE"); err != nil { - return fmt.Errorf("Failed delete RRset (deSEC): %v", err) + return fmt.Errorf("failed delete RRset (deSEC): %v", err) } return nil } -func (c *desecProvider) get(endpoint, method string) ([]byte, int, error) { +func (c *desecProvider) get(target, method string) ([]byte, *http.Response, error) { retrycnt := 0 + var endpoint string + if strings.Contains(target, "http") { + endpoint = target + } else { + endpoint = apiBase + target + } retry: client := &http.Client{} - req, _ := http.NewRequest(method, apiBase+endpoint, nil) + req, _ := http.NewRequest(method, endpoint, nil) q := req.URL.Query() req.Header.Add("Authorization", fmt.Sprintf("Token %s", c.creds.token)) @@ -168,7 +281,7 @@ retry: resp, err := client.Do(req) if err != nil { - return []byte{}, 0, err + return []byte{}, resp, err } bodyString, _ := ioutil.ReadAll(resp.Body) @@ -182,7 +295,7 @@ retry: wait, err := strconv.ParseInt(waitfor, 10, 64) if err == nil { if wait > 180 { - return []byte{}, 0, fmt.Errorf("rate limiting exceeded") + return []byte{}, resp, fmt.Errorf("rate limiting exceeded") } printer.Warnf("Rate limiting.. waiting for %s seconds", waitfor) time.Sleep(time.Duration(wait+1) * time.Second) @@ -197,24 +310,30 @@ retry: var nfieldErrors []nonFieldError err = json.Unmarshal(bodyString, &errResp) if err == nil { - return bodyString, resp.StatusCode, fmt.Errorf("%s", errResp.Detail) + return bodyString, resp, fmt.Errorf("%s", errResp.Detail) } err = json.Unmarshal(bodyString, &nfieldErrors) if err == nil && len(nfieldErrors) > 0 { if len(nfieldErrors[0].Errors) > 0 { - return bodyString, resp.StatusCode, fmt.Errorf("%s", nfieldErrors[0].Errors[0]) + return bodyString, resp, fmt.Errorf("%s", nfieldErrors[0].Errors[0]) } } - return bodyString, resp.StatusCode, fmt.Errorf("HTTP status %s Body: %s, the API does not provide more information", resp.Status, bodyString) + return bodyString, resp, fmt.Errorf("HTTP status %s Body: %s, the API does not provide more information", resp.Status, bodyString) } - return bodyString, resp.StatusCode, nil + return bodyString, resp, nil } -func (c *desecProvider) post(endpoint, method string, payload []byte) ([]byte, error) { +func (c *desecProvider) post(target, method string, payload []byte) ([]byte, error) { retrycnt := 0 + var endpoint string + if strings.Contains(target, "http") { + endpoint = target + } else { + endpoint = apiBase + target + } retry: client := &http.Client{} - req, err := http.NewRequest(method, apiBase+endpoint, bytes.NewReader(payload)) + req, err := http.NewRequest(method, endpoint, bytes.NewReader(payload)) if err != nil { return []byte{}, err }