1
0
mirror of https://github.com/StackExchange/dnscontrol.git synced 2024-05-11 05:55:12 +00:00

MSDNS: Improve PowerShell reliability (#2551)

This commit is contained in:
Tom Limoncelli
2023-09-07 10:30:18 -04:00
committed by GitHub
parent c9ce326ae1
commit a1c7a26351
2 changed files with 20 additions and 20 deletions

View File

@@ -19,12 +19,12 @@ type psHandle struct {
shell ps.Shell
}
// func eLog(s string) {
// f, _ := os.OpenFile("powershell.log", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
// f.WriteString(s)
// f.WriteString("\n")
// f.Close()
// }
func eLog(s string) {
f, _ := os.OpenFile("powershell.log", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
f.WriteString(s)
f.WriteString("\n")
f.Close()
}
func newPowerShell(config map[string]string) (*psHandle, error) {
@@ -70,7 +70,7 @@ func (psh *psHandle) Exit() {
type dnsZone map[string]interface{}
func (psh *psHandle) GetDNSServerZoneAll(dnsserver string) ([]string, error) {
stdout, stderr, err := psh.shell.Execute(generatePSZoneAll(dnsserver))
stdout, stderr, err := psh.shell.Execute("\n\r" + generatePSZoneAll(dnsserver) + "\n\r")
if err != nil {
return nil, err
}
@@ -115,7 +115,8 @@ func (psh *psHandle) GetDNSZoneRecords(dnsserver, domain string) ([]nativeRecord
filename := tmpfile.Name()
tmpfile.Close()
stdout, stderr, err := psh.shell.Execute(generatePSZoneDump(dnsserver, domain, filename))
stdout, stderr, err := psh.shell.Execute(
"\n\r" + generatePSZoneDump(dnsserver, domain, filename) + "\n\r")
if err != nil {
return nil, err
}
@@ -207,8 +208,8 @@ func (psh *psHandle) RecordDelete(dnsserver, domain string, rec *models.RecordCo
c = generatePSDelete(dnsserver, domain, rec)
}
//eLog(c)
_, stderr, err := psh.shell.Execute(c)
eLog(c)
_, stderr, err := psh.shell.Execute("\n\r" + c + "\n\r")
if err != nil {
printer.Printf("PowerShell code was:\nSTART\n%s\nEND\n", c)
return err
@@ -266,8 +267,8 @@ func (psh *psHandle) RecordCreate(dnsserver, domain string, rec *models.RecordCo
//printer.Printf("DEBUG: PScreate\n")
}
//eLog(c)
stdout, stderr, err := psh.shell.Execute(c)
eLog(c)
stdout, stderr, err := psh.shell.Execute("\n\r" + c + "\n\r")
if err != nil {
printer.Printf("PowerShell code was:\nSTART\n%s\nEND\n", c)
return err
@@ -350,8 +351,8 @@ func generatePSCreate(dnsserver, domain string, rec *models.RecordConfig) string
func (psh *psHandle) RecordModify(dnsserver, domain string, old, rec *models.RecordConfig) error {
c := generatePSModify(dnsserver, domain, old, rec)
//eLog(c)
_, stderr, err := psh.shell.Execute(c)
eLog(c)
_, stderr, err := psh.shell.Execute("\n\r" + c + "\n\r")
if err != nil {
printer.Printf("PowerShell code was:\nSTART\n%s\nEND\n", c)
return err
@@ -363,10 +364,9 @@ func (psh *psHandle) RecordModify(dnsserver, domain string, old, rec *models.Rec
}
return nil
}
func generatePSModify(dnsserver, domain string, old, rec *models.RecordConfig) string {
// The simple way is to just remove the old record and insert the new record.
return generatePSDelete(dnsserver, domain, old) + ` ; ` + generatePSCreate(dnsserver, domain, rec)
return "\n\r" + generatePSDelete(dnsserver, domain, old) + " ; " + generatePSCreate(dnsserver, domain, rec) + "\n\r"
// NB: SOA records can't be deleted. When we implement them, we'll
// need to special case them and generate an in-place modification
// command.
@@ -374,8 +374,8 @@ func generatePSModify(dnsserver, domain string, old, rec *models.RecordConfig) s
func (psh *psHandle) RecordModifyTTL(dnsserver, domain string, old *models.RecordConfig, newTTL uint32) error {
c := generatePSModifyTTL(dnsserver, domain, old, newTTL)
//eLog(c)
_, stderr, err := psh.shell.Execute(c)
eLog(c)
_, stderr, err := psh.shell.Execute("\n\r" + c + "\n\r")
if err != nil {
printer.Printf("PowerShell code was:\nSTART\n%s\nEND\n", c)
return err

View File

@@ -29,7 +29,7 @@ func Test_generatePSZoneAll(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := generatePSZoneAll(tt.args.dnsserver); got != tt.want {
if got := generatePSZoneAll(tt.args.dnsserver); got != strings.TrimSpace(tt.want) {
t.Errorf("generatePSZoneAll() = got=(\n%s\n) want=(\n%s\n)", got, tt.want)
}
})
@@ -59,7 +59,7 @@ func Test_generatePSZoneDump(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := generatePSZoneDump(tt.args.dnsserver, tt.args.domainname, "foo"); got != tt.want {
if got := generatePSZoneDump(tt.args.dnsserver, tt.args.domainname, "foo"); got != strings.TrimSpace(tt.want) {
t.Errorf("generatePSZoneDump() = got=(\n%s\n) want=(\n%s\n)", got, tt.want)
}
})