1
0
mirror of https://github.com/StackExchange/dnscontrol.git synced 2024-05-11 05:55:12 +00:00

MSDNS: Improve PowerShell reliability (#2551)

This commit is contained in:
Tom Limoncelli
2023-09-07 10:30:18 -04:00
committed by GitHub
parent c9ce326ae1
commit a1c7a26351
2 changed files with 20 additions and 20 deletions

View File

@@ -19,12 +19,12 @@ type psHandle struct {
shell ps.Shell shell ps.Shell
} }
// func eLog(s string) { func eLog(s string) {
// f, _ := os.OpenFile("powershell.log", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) f, _ := os.OpenFile("powershell.log", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
// f.WriteString(s) f.WriteString(s)
// f.WriteString("\n") f.WriteString("\n")
// f.Close() f.Close()
// } }
func newPowerShell(config map[string]string) (*psHandle, error) { func newPowerShell(config map[string]string) (*psHandle, error) {
@@ -70,7 +70,7 @@ func (psh *psHandle) Exit() {
type dnsZone map[string]interface{} type dnsZone map[string]interface{}
func (psh *psHandle) GetDNSServerZoneAll(dnsserver string) ([]string, error) { func (psh *psHandle) GetDNSServerZoneAll(dnsserver string) ([]string, error) {
stdout, stderr, err := psh.shell.Execute(generatePSZoneAll(dnsserver)) stdout, stderr, err := psh.shell.Execute("\n\r" + generatePSZoneAll(dnsserver) + "\n\r")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -115,7 +115,8 @@ func (psh *psHandle) GetDNSZoneRecords(dnsserver, domain string) ([]nativeRecord
filename := tmpfile.Name() filename := tmpfile.Name()
tmpfile.Close() tmpfile.Close()
stdout, stderr, err := psh.shell.Execute(generatePSZoneDump(dnsserver, domain, filename)) stdout, stderr, err := psh.shell.Execute(
"\n\r" + generatePSZoneDump(dnsserver, domain, filename) + "\n\r")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -207,8 +208,8 @@ func (psh *psHandle) RecordDelete(dnsserver, domain string, rec *models.RecordCo
c = generatePSDelete(dnsserver, domain, rec) c = generatePSDelete(dnsserver, domain, rec)
} }
//eLog(c) eLog(c)
_, stderr, err := psh.shell.Execute(c) _, stderr, err := psh.shell.Execute("\n\r" + c + "\n\r")
if err != nil { if err != nil {
printer.Printf("PowerShell code was:\nSTART\n%s\nEND\n", c) printer.Printf("PowerShell code was:\nSTART\n%s\nEND\n", c)
return err return err
@@ -266,8 +267,8 @@ func (psh *psHandle) RecordCreate(dnsserver, domain string, rec *models.RecordCo
//printer.Printf("DEBUG: PScreate\n") //printer.Printf("DEBUG: PScreate\n")
} }
//eLog(c) eLog(c)
stdout, stderr, err := psh.shell.Execute(c) stdout, stderr, err := psh.shell.Execute("\n\r" + c + "\n\r")
if err != nil { if err != nil {
printer.Printf("PowerShell code was:\nSTART\n%s\nEND\n", c) printer.Printf("PowerShell code was:\nSTART\n%s\nEND\n", c)
return err return err
@@ -350,8 +351,8 @@ func generatePSCreate(dnsserver, domain string, rec *models.RecordConfig) string
func (psh *psHandle) RecordModify(dnsserver, domain string, old, rec *models.RecordConfig) error { func (psh *psHandle) RecordModify(dnsserver, domain string, old, rec *models.RecordConfig) error {
c := generatePSModify(dnsserver, domain, old, rec) c := generatePSModify(dnsserver, domain, old, rec)
//eLog(c) eLog(c)
_, stderr, err := psh.shell.Execute(c) _, stderr, err := psh.shell.Execute("\n\r" + c + "\n\r")
if err != nil { if err != nil {
printer.Printf("PowerShell code was:\nSTART\n%s\nEND\n", c) printer.Printf("PowerShell code was:\nSTART\n%s\nEND\n", c)
return err return err
@@ -363,10 +364,9 @@ func (psh *psHandle) RecordModify(dnsserver, domain string, old, rec *models.Rec
} }
return nil return nil
} }
func generatePSModify(dnsserver, domain string, old, rec *models.RecordConfig) string { func generatePSModify(dnsserver, domain string, old, rec *models.RecordConfig) string {
// The simple way is to just remove the old record and insert the new record. // The simple way is to just remove the old record and insert the new record.
return generatePSDelete(dnsserver, domain, old) + ` ; ` + generatePSCreate(dnsserver, domain, rec) return "\n\r" + generatePSDelete(dnsserver, domain, old) + " ; " + generatePSCreate(dnsserver, domain, rec) + "\n\r"
// NB: SOA records can't be deleted. When we implement them, we'll // NB: SOA records can't be deleted. When we implement them, we'll
// need to special case them and generate an in-place modification // need to special case them and generate an in-place modification
// command. // command.
@@ -374,8 +374,8 @@ func generatePSModify(dnsserver, domain string, old, rec *models.RecordConfig) s
func (psh *psHandle) RecordModifyTTL(dnsserver, domain string, old *models.RecordConfig, newTTL uint32) error { func (psh *psHandle) RecordModifyTTL(dnsserver, domain string, old *models.RecordConfig, newTTL uint32) error {
c := generatePSModifyTTL(dnsserver, domain, old, newTTL) c := generatePSModifyTTL(dnsserver, domain, old, newTTL)
//eLog(c) eLog(c)
_, stderr, err := psh.shell.Execute(c) _, stderr, err := psh.shell.Execute("\n\r" + c + "\n\r")
if err != nil { if err != nil {
printer.Printf("PowerShell code was:\nSTART\n%s\nEND\n", c) printer.Printf("PowerShell code was:\nSTART\n%s\nEND\n", c)
return err return err

View File

@@ -29,7 +29,7 @@ func Test_generatePSZoneAll(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := generatePSZoneAll(tt.args.dnsserver); got != tt.want { if got := generatePSZoneAll(tt.args.dnsserver); got != strings.TrimSpace(tt.want) {
t.Errorf("generatePSZoneAll() = got=(\n%s\n) want=(\n%s\n)", got, tt.want) t.Errorf("generatePSZoneAll() = got=(\n%s\n) want=(\n%s\n)", got, tt.want)
} }
}) })
@@ -59,7 +59,7 @@ func Test_generatePSZoneDump(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := generatePSZoneDump(tt.args.dnsserver, tt.args.domainname, "foo"); got != tt.want { if got := generatePSZoneDump(tt.args.dnsserver, tt.args.domainname, "foo"); got != strings.TrimSpace(tt.want) {
t.Errorf("generatePSZoneDump() = got=(\n%s\n) want=(\n%s\n)", got, tt.want) t.Errorf("generatePSZoneDump() = got=(\n%s\n) want=(\n%s\n)", got, tt.want)
} }
}) })