diff --git a/netbox/dcim/views.py b/netbox/dcim/views.py index b3b99d804..12f7a5046 100644 --- a/netbox/dcim/views.py +++ b/netbox/dcim/views.py @@ -1904,23 +1904,15 @@ class CableTraceView(ObjectView): }) -class CableCreateView(PermissionRequiredMixin, GetReturnURLMixin, View): - permission_required = 'dcim.add_cable' +class CableCreateView(ObjectEditView): + queryset = Cable.objects.all() template_name = 'dcim/cable_connect.html' + default_return_url = 'dcim:cable_list' def dispatch(self, request, *args, **kwargs): - termination_a_type = kwargs.get('termination_a_type') - termination_a_id = kwargs.get('termination_a_id') - - termination_b_type_name = kwargs.get('termination_b_type') - self.termination_b_type = ContentType.objects.get(model=termination_b_type_name.replace('-', '')) - - self.obj = Cable( - termination_a=termination_a_type.objects.get(pk=termination_a_id), - termination_b_type=self.termination_b_type - ) - self.form_class = { + # Set the model_form class based on the type of component being connected + self.model_form = { 'console-port': forms.ConnectCableToConsolePortForm, 'console-server-port': forms.ConnectCableToConsoleServerPortForm, 'power-port': forms.ConnectCableToPowerPortForm, @@ -1930,59 +1922,42 @@ class CableCreateView(PermissionRequiredMixin, GetReturnURLMixin, View): 'rear-port': forms.ConnectCableToRearPortForm, 'power-feed': forms.ConnectCableToPowerFeedForm, 'circuit-termination': forms.ConnectCableToCircuitTerminationForm, - }[termination_b_type_name] + }[kwargs.get('termination_b_type')] return super().dispatch(request, *args, **kwargs) + def alter_obj(self, obj, request, url_args, url_kwargs): + termination_a_type = url_kwargs.get('termination_a_type') + termination_a_id = url_kwargs.get('termination_a_id') + termination_b_type_name = url_kwargs.get('termination_b_type') + self.termination_b_type = ContentType.objects.get(model=termination_b_type_name.replace('-', '')) + + # Initialize Cable termination attributes + obj.termination_a = termination_a_type.objects.get(pk=termination_a_id) + obj.termination_b_type = self.termination_b_type + + return obj + def get(self, request, *args, **kwargs): + obj = self.alter_obj(self.get_object(kwargs), request, args, kwargs) # Parse initial data manually to avoid setting field values as lists initial_data = {k: request.GET[k] for k in request.GET} # Set initial site and rack based on side A termination (if not already set) if 'termination_b_site' not in initial_data: - initial_data['termination_b_site'] = getattr(self.obj.termination_a.parent, 'site', None) + initial_data['termination_b_site'] = getattr(obj.termination_a.parent, 'site', None) if 'termination_b_rack' not in initial_data: - initial_data['termination_b_rack'] = getattr(self.obj.termination_a.parent, 'rack', None) + initial_data['termination_b_rack'] = getattr(obj.termination_a.parent, 'rack', None) - form = self.form_class(instance=self.obj, initial=initial_data) + form = self.model_form(instance=obj, initial=initial_data) return render(request, self.template_name, { - 'obj': self.obj, + 'obj': obj, 'obj_type': Cable._meta.verbose_name, 'termination_b_type': self.termination_b_type.name, 'form': form, - 'return_url': self.get_return_url(request, self.obj), - }) - - def post(self, request, *args, **kwargs): - - form = self.form_class(request.POST, request.FILES, instance=self.obj) - - if form.is_valid(): - obj = form.save() - - msg = 'Created cable {}'.format( - obj.get_absolute_url(), - escape(obj) - ) - messages.success(request, mark_safe(msg)) - - if '_addanother' in request.POST: - return redirect(request.get_full_path()) - - return_url = form.cleaned_data.get('return_url') - if return_url is not None and is_safe_url(url=return_url, allowed_hosts=request.get_host()): - return redirect(return_url) - else: - return redirect(self.get_return_url(request, obj)) - - return render(request, self.template_name, { - 'obj': self.obj, - 'obj_type': Cable._meta.verbose_name, - 'termination_b_type': self.termination_b_type.name, - 'form': form, - 'return_url': self.get_return_url(request, self.obj), + 'return_url': self.get_return_url(request, obj), }) diff --git a/netbox/utilities/views.py b/netbox/utilities/views.py index e448f2934..9271e1c64 100644 --- a/netbox/utilities/views.py +++ b/netbox/utilities/views.py @@ -346,7 +346,7 @@ class ObjectEditView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): form = self.model_form( data=request.POST, files=request.FILES, - instance=self.alter_obj(self.get_object(kwargs), request, args, kwargs) + instance=obj ) if form.is_valid():