diff --git a/docs/release-notes/version-3.0.md b/docs/release-notes/version-3.0.md index 476c185ae..7de296936 100644 --- a/docs/release-notes/version-3.0.md +++ b/docs/release-notes/version-3.0.md @@ -5,6 +5,7 @@ ### Bug Fixes * [#7612](https://github.com/netbox-community/netbox/issues/7612) - Strip HTML from custom field descriptions +* [#7628](https://github.com/netbox-community/netbox/issues/7628) - Fix `load_yaml` method for custom scripts --- diff --git a/netbox/extras/scripts.py b/netbox/extras/scripts.py index fd84747b9..9c46278ae 100644 --- a/netbox/extras/scripts.py +++ b/netbox/extras/scripts.py @@ -4,7 +4,6 @@ import logging import os import pkgutil import traceback -import warnings from collections import OrderedDict import yaml @@ -345,9 +344,14 @@ class BaseScript: """ Return data from a YAML file """ + try: + from yaml import CLoader as Loader + except ImportError: + from yaml import Loader + file_path = os.path.join(settings.SCRIPTS_ROOT, filename) with open(file_path, 'r') as datafile: - data = yaml.load(datafile) + data = yaml.load(datafile, Loader=Loader) return data diff --git a/netbox/extras/tests/test_scripts.py b/netbox/extras/tests/test_scripts.py index 4518548d3..64971f1dc 100644 --- a/netbox/extras/tests/test_scripts.py +++ b/netbox/extras/tests/test_scripts.py @@ -1,3 +1,5 @@ +import tempfile + from django.core.files.uploadedfile import SimpleUploadedFile from django.test import TestCase from netaddr import IPAddress, IPNetwork @@ -11,6 +13,50 @@ CHOICES = ( ('0000ff', 'Blue') ) +YAML_DATA = """ +Foo: 123 +Bar: 456 +Baz: + - A + - B + - C +""" + +JSON_DATA = """ +{ + "Foo": 123, + "Bar": 456, + "Baz": ["A", "B", "C"] +} +""" + + +class ScriptTest(TestCase): + + def test_load_yaml(self): + datafile = tempfile.NamedTemporaryFile() + datafile.write(bytes(YAML_DATA, 'UTF-8')) + datafile.seek(0) + + data = Script().load_yaml(datafile.name) + self.assertEqual(data, { + 'Foo': 123, + 'Bar': 456, + 'Baz': ['A', 'B', 'C'], + }) + + def test_load_json(self): + datafile = tempfile.NamedTemporaryFile() + datafile.write(bytes(JSON_DATA, 'UTF-8')) + datafile.seek(0) + + data = Script().load_json(datafile.name) + self.assertEqual(data, { + 'Foo': 123, + 'Bar': 456, + 'Baz': ['A', 'B', 'C'], + }) + class ScriptVariablesTest(TestCase):