mirror of
https://github.com/peeringdb/peeringdb.git
synced 2024-05-11 05:55:09 +00:00
Support 202203 fixes (#1148)
* remove survey notifications * fixing old reference of IXF_IMPORTER_DAYS_UNTIL_TICKET through EnvironmentSettings, this setting is no longer controlled through that and should come straight from settings * fix session auth not setting x-auth-id header (#1120) fix basic auth not setting x-auth-id header on success (#1120) fix api key auth only setting prefix in x-auth-id header (#1120) fix x-auth-id header not being cleared between requests (#1120) * fix issue with rest throttling breaking api-cache generation (#1146) * add caching for get_permission_holder_from_request - fixes perfomance issues in #1147 * fix intermediate issue with api_cache rest throttle tests * sanitize cache key names for state normalization (#1079) each state normalization lookup moved into its own transaction so errors dont cause us to lose already obtained data (#1079) write cache regardess of --commit on or off (#1079) add a sanity check for running non-committal mode without --limit (#1079) * fix issue with ip block rate limiting if x-forwarded-for is set (#1126) * better handling of melissa timeouts through retrying (#1079) fix state normalization cache timeout to have no expiry (#1079) normalization command will display validation errors at the end and exit with a return code if there are any (#1079) * automatically apply address field normalization for `state` (#1079) * additional tests * only do a sanity check for --limit if no specific object is targeted * linting Co-authored-by: Stefan Pratter <stefan@20c.com>
This commit is contained in:
@@ -2,9 +2,12 @@
|
|||||||
Utilities for geocoding and geo normalization.
|
Utilities for geocoding and geo normalization.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
import googlemaps
|
import googlemaps
|
||||||
import requests
|
import requests
|
||||||
import structlog
|
import structlog
|
||||||
|
import unidecode
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
|
||||||
@@ -268,15 +271,25 @@ class Melissa:
|
|||||||
This will use django-cache if it exists
|
This will use django-cache if it exists
|
||||||
"""
|
"""
|
||||||
|
|
||||||
key = f"geo.normalize.state.{country_code}.{state}"
|
if not state:
|
||||||
|
return state
|
||||||
|
|
||||||
|
# clean up state value for cache key
|
||||||
|
|
||||||
|
state_clean = unidecode.unidecode(state.lower()).strip()
|
||||||
|
state_clean = re.sub(r"[^a-zA-Z]+", "", state_clean)
|
||||||
|
|
||||||
|
key = f"geo.normalize.state.{country_code}.{state_clean}"
|
||||||
|
|
||||||
value = cache.get(key)
|
value = cache.get(key)
|
||||||
if value is None:
|
if value is None:
|
||||||
|
|
||||||
result = self.global_address(country=country_code, address1=state)
|
result = self.global_address(country=country_code, address1=state)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
record = result["Records"][0]
|
record = result["Records"][0]
|
||||||
value = record.get("AdministrativeArea") or state
|
value = record.get("AdministrativeArea") or state
|
||||||
except (KeyError, IndexError):
|
except (KeyError, IndexError):
|
||||||
value = state
|
value = state
|
||||||
cache.set(key, value)
|
cache.set(key, value, timeout=None)
|
||||||
return value
|
return value
|
||||||
|
@@ -4,16 +4,18 @@ Normalize existing address fields based on Google Maps API response.
|
|||||||
import csv
|
import csv
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
|
|
||||||
import reversion
|
import reversion
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.core.exceptions import ValidationError
|
from django.core.exceptions import ValidationError
|
||||||
from django.core.management.base import BaseCommand
|
from django.core.management.base import BaseCommand, CommandError
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
|
|
||||||
from peeringdb_server import models
|
from peeringdb_server import models
|
||||||
from peeringdb_server.geo import Melissa
|
from peeringdb_server.geo import Melissa, Timeout
|
||||||
from peeringdb_server.serializers import AddressSerializer
|
from peeringdb_server.serializers import AddressSerializer
|
||||||
|
|
||||||
API_KEY = settings.MELISSA_KEY
|
API_KEY = settings.MELISSA_KEY
|
||||||
@@ -82,6 +84,7 @@ class Command(BaseCommand):
|
|||||||
self.state_only = options.get("state_only", False)
|
self.state_only = options.get("state_only", False)
|
||||||
self.pprint = options.get("pprint", False)
|
self.pprint = options.get("pprint", False)
|
||||||
self.csv_file = options.get("csv")
|
self.csv_file = options.get("csv")
|
||||||
|
self.validation_errors = {}
|
||||||
|
|
||||||
self.melissa = Melissa(API_KEY)
|
self.melissa = Melissa(API_KEY)
|
||||||
|
|
||||||
@@ -96,6 +99,9 @@ class Command(BaseCommand):
|
|||||||
else:
|
else:
|
||||||
_id = 0
|
_id = 0
|
||||||
|
|
||||||
|
if not limit and _id == 0 and not self.commit:
|
||||||
|
raise CommandError("Cannot run in pretend mode without a --limit supplied")
|
||||||
|
|
||||||
output_list = self.normalize(reftag, _id, limit=limit)
|
output_list = self.normalize(reftag, _id, limit=limit)
|
||||||
|
|
||||||
if self.csv_file:
|
if self.csv_file:
|
||||||
@@ -108,6 +114,12 @@ class Command(BaseCommand):
|
|||||||
pprint(entry)
|
pprint(entry)
|
||||||
self.log("\n")
|
self.log("\n")
|
||||||
|
|
||||||
|
if self.validation_errors:
|
||||||
|
self.log("Some objects had validation errors:")
|
||||||
|
for entity_id, err in self.validation_errors.items():
|
||||||
|
self.log(f"Object #{entity_id}: {err}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
def parse_suite(self, instance):
|
def parse_suite(self, instance):
|
||||||
|
|
||||||
# Case: "Suite 1" or "Suite B"
|
# Case: "Suite 1" or "Suite B"
|
||||||
@@ -166,10 +178,9 @@ class Command(BaseCommand):
|
|||||||
dict_writer.writeheader()
|
dict_writer.writeheader()
|
||||||
dict_writer.writerows(output_list)
|
dict_writer.writerows(output_list)
|
||||||
|
|
||||||
@reversion.create_revision()
|
|
||||||
@transaction.atomic()
|
|
||||||
def normalize(self, reftag, _id, limit=0):
|
def normalize(self, reftag, _id, limit=0):
|
||||||
model = models.REFTAG_MAP.get(reftag)
|
model = models.REFTAG_MAP.get(reftag)
|
||||||
|
|
||||||
if not model:
|
if not model:
|
||||||
raise ValueError(f"Unknown reftag: {reftag}")
|
raise ValueError(f"Unknown reftag: {reftag}")
|
||||||
if not hasattr(model, "geocode_status"):
|
if not hasattr(model, "geocode_status"):
|
||||||
@@ -200,18 +211,32 @@ class Command(BaseCommand):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
proceed_to_next = False
|
||||||
if self.state_only:
|
|
||||||
self._normalize_state(entity, output_dict, self.commit)
|
while not proceed_to_next:
|
||||||
else:
|
|
||||||
self._normalize(entity, output_dict, self.commit)
|
try:
|
||||||
|
if self.state_only:
|
||||||
|
self._normalize_state(entity, output_dict, self.commit)
|
||||||
|
else:
|
||||||
|
self._normalize(entity, output_dict, self.commit)
|
||||||
|
|
||||||
|
proceed_to_next = True
|
||||||
|
|
||||||
|
except ValidationError as exc:
|
||||||
|
self.log(f"Validation error: {exc}")
|
||||||
|
self.validation_errors[entity.id] = exc
|
||||||
|
proceed_to_next = True
|
||||||
|
except Timeout:
|
||||||
|
self.log("Request has timed out, retrying ...")
|
||||||
|
time.sleep(1.0)
|
||||||
|
|
||||||
except ValidationError as exc:
|
|
||||||
self.log(str(exc))
|
|
||||||
output_list.append(output_dict)
|
output_list.append(output_dict)
|
||||||
|
|
||||||
return output_list
|
return output_list
|
||||||
|
|
||||||
|
@reversion.create_revision()
|
||||||
|
@transaction.atomic()
|
||||||
def _normalize(self, instance, output_dict, save):
|
def _normalize(self, instance, output_dict, save):
|
||||||
|
|
||||||
suite = self.parse_suite(instance)
|
suite = self.parse_suite(instance)
|
||||||
|
@@ -14,6 +14,8 @@ from peeringdb_server.context import current_request
|
|||||||
from peeringdb_server.models import OrganizationAPIKey, UserAPIKey
|
from peeringdb_server.models import OrganizationAPIKey, UserAPIKey
|
||||||
from peeringdb_server.permissions import get_key_from_request
|
from peeringdb_server.permissions import get_key_from_request
|
||||||
|
|
||||||
|
ERR_MULTI_AUTH = "Cannot authenticate through Authorization header while logged in. Please log out and try again."
|
||||||
|
|
||||||
|
|
||||||
class CurrentRequestContext:
|
class CurrentRequestContext:
|
||||||
|
|
||||||
@@ -70,8 +72,6 @@ class PDBPermissionMiddleware(MiddlewareMixin):
|
|||||||
to access the requested resource.
|
to access the requested resource.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
auth_id = None
|
|
||||||
|
|
||||||
def get_username_and_password(self, http_auth):
|
def get_username_and_password(self, http_auth):
|
||||||
"""
|
"""
|
||||||
Get the username and password from the HTTP auth header.
|
Get the username and password from the HTTP auth header.
|
||||||
@@ -103,15 +103,43 @@ class PDBPermissionMiddleware(MiddlewareMixin):
|
|||||||
req_key = get_key_from_request(request)
|
req_key = get_key_from_request(request)
|
||||||
api_key = None
|
api_key = None
|
||||||
|
|
||||||
|
# session auth already exists, set x-auth-id value and return
|
||||||
|
|
||||||
|
if request.user.is_authenticated:
|
||||||
|
request.auth_id = request.user.username
|
||||||
|
|
||||||
|
# request attempting to provide separate authentication while
|
||||||
|
# already authenticated through session cookie, fail with
|
||||||
|
# bad request
|
||||||
|
|
||||||
|
if req_key or http_auth:
|
||||||
|
return self.response_unauthorized(
|
||||||
|
request,
|
||||||
|
message=ERR_MULTI_AUTH,
|
||||||
|
status=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
# Check if HTTP auth is valid and if the request is made with basic auth.
|
# Check if HTTP auth is valid and if the request is made with basic auth.
|
||||||
|
|
||||||
if http_auth and http_auth.startswith("Basic "):
|
if http_auth and http_auth.startswith("Basic "):
|
||||||
|
|
||||||
# Get the username and password from the HTTP auth header.
|
# Get the username and password from the HTTP auth header.
|
||||||
username, password = self.get_username_and_password(http_auth)
|
username, password = self.get_username_and_password(http_auth)
|
||||||
# Check if the username and password are valid.
|
# Check if the username and password are valid.
|
||||||
user = authenticate(username=username, password=password)
|
user = authenticate(username=username, password=password)
|
||||||
|
|
||||||
|
# return username input in x-auth-id header
|
||||||
|
request.auth_id = username
|
||||||
|
|
||||||
# if user is not authenticated return 401 Unauthorized
|
# if user is not authenticated return 401 Unauthorized
|
||||||
if not user:
|
if not user:
|
||||||
self.auth_id = username
|
|
||||||
|
# truncate the username if needed.
|
||||||
|
if len(username) > 255:
|
||||||
|
request.auth_id = username[:255]
|
||||||
|
|
||||||
return self.response_unauthorized(
|
return self.response_unauthorized(
|
||||||
request, message="Invalid username or password", status=401
|
request, message="Invalid username or password", status=401
|
||||||
)
|
)
|
||||||
@@ -132,16 +160,16 @@ class PDBPermissionMiddleware(MiddlewareMixin):
|
|||||||
|
|
||||||
# If api key is not valid return 401 Unauthorized
|
# If api key is not valid return 401 Unauthorized
|
||||||
if not api_key:
|
if not api_key:
|
||||||
self.auth_id = "apikey_%s" % (req_key)
|
|
||||||
if len(req_key) > 16:
|
if len(req_key) > 16:
|
||||||
self.auth_id = self.auth_id[:16]
|
req_key = req_key[:16]
|
||||||
|
request.auth_id = f"apikey_{req_key}"
|
||||||
return self.response_unauthorized(
|
return self.response_unauthorized(
|
||||||
request, message="Invalid API key", status=401
|
request, message="Invalid API key", status=401
|
||||||
)
|
)
|
||||||
|
|
||||||
# If API key is provided, check if the user has an active session
|
# If API key is provided, check if the user has an active session
|
||||||
if api_key:
|
if api_key:
|
||||||
self.auth_id = "apikey_%s" % req_key
|
request.auth_id = f"apikey_{api_key.prefix}"
|
||||||
if request.session.get("_auth_user_id") and request.user.id:
|
if request.session.get("_auth_user_id") and request.user.id:
|
||||||
if int(request.user.id) == int(
|
if int(request.user.id) == int(
|
||||||
request.session.get("_auth_user_id")
|
request.session.get("_auth_user_id")
|
||||||
@@ -149,19 +177,19 @@ class PDBPermissionMiddleware(MiddlewareMixin):
|
|||||||
|
|
||||||
return self.response_unauthorized(
|
return self.response_unauthorized(
|
||||||
request,
|
request,
|
||||||
message="Cannot authenticate through Authorization header while logged in. Please log out and try again.",
|
message=ERR_MULTI_AUTH,
|
||||||
status=400,
|
status=400,
|
||||||
)
|
)
|
||||||
|
|
||||||
def process_response(self, request, response):
|
def process_response(self, request, response):
|
||||||
|
|
||||||
if self.auth_id:
|
if hasattr(request, "auth_id"):
|
||||||
# Sanitizes the auth_id
|
# Sanitizes the auth_id
|
||||||
self.auth_id = self.auth_id.replace(" ", "_")
|
request.auth_id = request.auth_id.replace(" ", "_")
|
||||||
# If auth_id ends with a 401 make sure is it limited to 16 bytes
|
# If auth_id ends with a 401 make sure is it limited to 16 bytes
|
||||||
if response.status_code == 401 and len(self.auth_id) > 16:
|
if response.status_code == 401 and len(request.auth_id) > 16:
|
||||||
if not self.auth_id.startswith("apikey_"):
|
if not request.auth_id.startswith("apikey_"):
|
||||||
self.auth_id = self.auth_id[:16]
|
request.auth_id = request.auth_id[:16]
|
||||||
|
|
||||||
response["X-Auth-ID"] = self.auth_id
|
response["X-Auth-ID"] = request.auth_id
|
||||||
return response
|
return response
|
||||||
|
@@ -56,10 +56,15 @@ def get_permission_holder_from_request(request):
|
|||||||
"""Return either an API Key instance or User instance
|
"""Return either an API Key instance or User instance
|
||||||
depending on how the request is Authenticated.
|
depending on how the request is Authenticated.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if hasattr(request, "_permission_holder"):
|
||||||
|
return request._permission_holder
|
||||||
|
|
||||||
key = get_key_from_request(request)
|
key = get_key_from_request(request)
|
||||||
if key is not None:
|
if key is not None:
|
||||||
try:
|
try:
|
||||||
api_key = OrganizationAPIKey.objects.get_from_key(key)
|
api_key = OrganizationAPIKey.objects.get_from_key(key)
|
||||||
|
request._permission_holder = api_key
|
||||||
return api_key
|
return api_key
|
||||||
|
|
||||||
except OrganizationAPIKey.DoesNotExist:
|
except OrganizationAPIKey.DoesNotExist:
|
||||||
@@ -67,15 +72,19 @@ def get_permission_holder_from_request(request):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
api_key = UserAPIKey.objects.get_from_key(key)
|
api_key = UserAPIKey.objects.get_from_key(key)
|
||||||
|
request._permission_holder = api_key
|
||||||
return api_key
|
return api_key
|
||||||
|
|
||||||
except UserAPIKey.DoesNotExist:
|
except UserAPIKey.DoesNotExist:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if hasattr(request, "user"):
|
if hasattr(request, "user"):
|
||||||
|
request._permission_holder = request.user
|
||||||
return request.user
|
return request.user
|
||||||
|
|
||||||
return AnonymousUser()
|
anon = AnonymousUser()
|
||||||
|
request._permission_holder = anon
|
||||||
|
return anon
|
||||||
|
|
||||||
|
|
||||||
def get_user_from_request(request):
|
def get_user_from_request(request):
|
||||||
|
@@ -89,6 +89,10 @@ class TargetedRateThrottle(throttling.SimpleRateThrottle):
|
|||||||
|
|
||||||
ip_address = self.get_ident(request)
|
ip_address = self.get_ident(request)
|
||||||
|
|
||||||
|
# handle XFF
|
||||||
|
|
||||||
|
ip_address = ip_address.split(",")[0].strip()
|
||||||
|
|
||||||
if self.check_ip(request):
|
if self.check_ip(request):
|
||||||
self.ident = ip_address
|
self.ident = ip_address
|
||||||
self.ident = f"{ident_prefix}{self.ident}"
|
self.ident = f"{ident_prefix}{self.ident}"
|
||||||
@@ -135,6 +139,11 @@ class TargetedRateThrottle(throttling.SimpleRateThrottle):
|
|||||||
|
|
||||||
def allow_request(self, request, view):
|
def allow_request(self, request, view):
|
||||||
|
|
||||||
|
# skip rate throttling for the api-cache generate process
|
||||||
|
|
||||||
|
if getattr(settings, "GENERATING_API_CACHE", False):
|
||||||
|
return True
|
||||||
|
|
||||||
self.is_authenticated(request)
|
self.is_authenticated(request)
|
||||||
|
|
||||||
ident_prefix = self.ident_prefix(request)
|
ident_prefix = self.ident_prefix(request)
|
||||||
|
@@ -250,7 +250,7 @@ class GeocodeSerializerMixin:
|
|||||||
if not suggested_address:
|
if not suggested_address:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
for key in ["address1", "city", "state", "zipcode"]:
|
for key in ["address1", "city", "zipcode"]:
|
||||||
suggested_val = suggested_address.get(key, None)
|
suggested_val = suggested_address.get(key, None)
|
||||||
instance_val = getattr(instance, key, None)
|
instance_val = getattr(instance, key, None)
|
||||||
if instance_val != suggested_val:
|
if instance_val != suggested_val:
|
||||||
@@ -278,6 +278,12 @@ class GeocodeSerializerMixin:
|
|||||||
try:
|
try:
|
||||||
suggested_address = instance.process_geo_location()
|
suggested_address = instance.process_geo_location()
|
||||||
|
|
||||||
|
# normalize state if needed
|
||||||
|
if suggested_address.get("state") != instance.state:
|
||||||
|
instance.state = suggested_address.get("state")
|
||||||
|
instance.save()
|
||||||
|
|
||||||
|
# provide other normalization options as suggestion to the user
|
||||||
if self.needs_address_suggestion(suggested_address, instance):
|
if self.needs_address_suggestion(suggested_address, instance):
|
||||||
self._add_meta_information(
|
self._add_meta_information(
|
||||||
{
|
{
|
||||||
|
@@ -122,3 +122,49 @@ class APICacheTests(TestCase, api_test.TestJSON, api_test.Command):
|
|||||||
settings.API_CACHE_ALL_LIMITS = False
|
settings.API_CACHE_ALL_LIMITS = False
|
||||||
settings.API_CACHE_ENABLED = False
|
settings.API_CACHE_ENABLED = False
|
||||||
super().tearDown()
|
super().tearDown()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
def test_no_api_throttle():
|
||||||
|
guest_group = Group.objects.create(name="guest")
|
||||||
|
user_group = Group.objects.create(name="user")
|
||||||
|
reset_group_ids()
|
||||||
|
|
||||||
|
models.EnvironmentSetting.objects.create(
|
||||||
|
setting="API_THROTTLE_RESPONSE_SIZE_ENABLED_IP", value_bool=True
|
||||||
|
)
|
||||||
|
models.EnvironmentSetting.objects.create(
|
||||||
|
setting="API_THROTTLE_RESPONSE_SIZE_THRESHOLD_IP", value_int=1
|
||||||
|
)
|
||||||
|
models.EnvironmentSetting.objects.create(
|
||||||
|
setting="API_THROTTLE_RESPONSE_SIZE_RATE_IP", value_str="1/minute"
|
||||||
|
)
|
||||||
|
|
||||||
|
models.EnvironmentSetting.objects.create(
|
||||||
|
setting="API_THROTTLE_RESPONSE_SIZE_ENABLED_CIDR", value_bool=True
|
||||||
|
)
|
||||||
|
models.EnvironmentSetting.objects.create(
|
||||||
|
setting="API_THROTTLE_RESPONSE_SIZE_THRESHOLD_CIDR", value_int=1
|
||||||
|
)
|
||||||
|
models.EnvironmentSetting.objects.create(
|
||||||
|
setting="API_THROTTLE_RESPONSE_SIZE_RATE_CIDR", value_str="1/minute"
|
||||||
|
)
|
||||||
|
|
||||||
|
models.EnvironmentSetting.objects.create(
|
||||||
|
setting="API_THROTTLE_RATE_ANON", value_str="1/minute"
|
||||||
|
)
|
||||||
|
|
||||||
|
call_command("pdb_generate_test_data", limit=2, commit=True)
|
||||||
|
now = datetime.datetime.now() + datetime.timedelta(days=1)
|
||||||
|
call_command("pdb_api_cache", date=now.strftime("%Y%m%d"))
|
||||||
|
settings.GENERATING_API_CACHE = False
|
||||||
|
|
||||||
|
for (dirpath, dirnames, filenames) in os.walk(settings.API_CACHE_ROOT):
|
||||||
|
for f in filenames:
|
||||||
|
if f in ["log.log"]:
|
||||||
|
continue
|
||||||
|
path = os.path.join(settings.API_CACHE_ROOT, f)
|
||||||
|
with open(path, "r") as fh:
|
||||||
|
data_raw = fh.read()
|
||||||
|
data = json.loads(data_raw)
|
||||||
|
assert not data.get("message")
|
||||||
|
@@ -211,6 +211,64 @@ class APIThrottleTests(TestCase):
|
|||||||
response = ResponseSizeMockView.as_view({"get": "get"})(request)
|
response = ResponseSizeMockView.as_view({"get": "get"})(request)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_response_size_ip_block_x_forwarded(self):
|
||||||
|
"""
|
||||||
|
Ensure request rate is limited based on response size
|
||||||
|
for ip-block with HTTP_X_FORWARDED_FOR set
|
||||||
|
"""
|
||||||
|
|
||||||
|
request = self.factory.get("/")
|
||||||
|
request.META.update({"HTTP_X_FORWARDED_FOR": "10.10.10.10,77.77.77.77"})
|
||||||
|
|
||||||
|
# by default ip-block response size rate limiting is disabled
|
||||||
|
# ip 10.10.10.10 requesting 10 times (all should be ok)
|
||||||
|
|
||||||
|
for dummy in range(10):
|
||||||
|
response = ResponseSizeMockView.as_view({"get": "get"})(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# turn on response size throttling for responses bigger than 500 bytes
|
||||||
|
# for ip blocks
|
||||||
|
|
||||||
|
thold = models.EnvironmentSetting.objects.create(
|
||||||
|
setting="API_THROTTLE_RESPONSE_SIZE_THRESHOLD_CIDR", value_int=500
|
||||||
|
)
|
||||||
|
models.EnvironmentSetting.objects.create(
|
||||||
|
setting="API_THROTTLE_RESPONSE_SIZE_RATE_CIDR", value_str="3/minute"
|
||||||
|
)
|
||||||
|
models.EnvironmentSetting.objects.create(
|
||||||
|
setting="API_THROTTLE_RESPONSE_SIZE_ENABLED_CIDR", value_bool=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# ip 10.10.10.10 requesting 3 times (all should be ok)
|
||||||
|
for dummy in range(3):
|
||||||
|
response = ResponseSizeMockView.as_view({"get": "get"})(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# ip 10.10.10.10 requesting 4th time (rate limited)
|
||||||
|
response = ResponseSizeMockView.as_view({"get": "get"})(request)
|
||||||
|
assert response.status_code == 429
|
||||||
|
|
||||||
|
# ip 10.10.10.11 requesting 1st time (rate limited)
|
||||||
|
request.META.update(HTTP_X_FORWARDED_FOR="10.10.10.11,77.77.77.77")
|
||||||
|
response = ResponseSizeMockView.as_view({"get": "get"})(request)
|
||||||
|
assert response.status_code == 429
|
||||||
|
|
||||||
|
# ip 20.10.10.10 requesting 1st time (ok)
|
||||||
|
request.META.update(HTTP_X_FORWARDED_FOR="20.10.10.10,77.77.77.77")
|
||||||
|
response = ResponseSizeMockView.as_view({"get": "get"})(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# increase threshold, no longer rate limited
|
||||||
|
thold.value_int = 5000
|
||||||
|
thold.save()
|
||||||
|
|
||||||
|
# 10.10.10.10 requesting 3 times (all should be ok)
|
||||||
|
request.META.update(HTTP_X_FORWARDED_FOR="10.10.10.10,77.77.77.77")
|
||||||
|
for dummy in range(3):
|
||||||
|
response = ResponseSizeMockView.as_view({"get": "get"})(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
def test_response_size_ip(self):
|
def test_response_size_ip(self):
|
||||||
"""
|
"""
|
||||||
Ensure request rate is limited based on response size
|
Ensure request rate is limited based on response size
|
||||||
|
@@ -1,3 +1,5 @@
|
|||||||
|
import base64
|
||||||
|
|
||||||
from django.http import HttpResponse
|
from django.http import HttpResponse
|
||||||
from django.test import (
|
from django.test import (
|
||||||
RequestFactory,
|
RequestFactory,
|
||||||
@@ -34,26 +36,25 @@ class PDBCommonMiddlewareTest(SimpleTestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
class PDBPermissionMiddlewareTest(APITestCase):
|
class PDBPermissionMiddlewareTest(APITestCase):
|
||||||
|
def setUp(self):
|
||||||
client = APIClient()
|
self.client = APIClient()
|
||||||
|
self.factory = RequestFactory()
|
||||||
|
|
||||||
def test_bogus_apikey_auth_id_response(self):
|
def test_bogus_apikey_auth_id_response(self):
|
||||||
|
|
||||||
self.client.credentials(HTTP_AUTHORIZATION="Api-Key bogus")
|
self.client.credentials(HTTP_AUTHORIZATION="Api-Key bogus")
|
||||||
response = self.client.get("/api/fac")
|
response = self.client.get("/api/fac")
|
||||||
self.assertEqual(response.status_code, 401)
|
self.assertEqual(response.status_code, 401)
|
||||||
self.assertEqual(response.headers.get("X-Auth-ID"), "apikey_bogus")
|
self.assertEqual(response.headers.get("X-Auth-ID"), "apikey_bogus")
|
||||||
|
|
||||||
def test_bogus_credentials_auth_id_response(self):
|
def test_bogus_credentials_auth_id_response(self):
|
||||||
|
|
||||||
self.client.credentials(HTTP_AUTHORIZATION="Basic Ym9ndXM6Ym9ndXM=")
|
self.client.credentials(HTTP_AUTHORIZATION="Basic Ym9ndXM6Ym9ndXM=")
|
||||||
response = self.client.get("/api/fac")
|
response = self.client.get("/api/fac")
|
||||||
self.assertEqual(response.status_code, 401)
|
self.assertEqual(response.status_code, 401)
|
||||||
self.assertEqual(response.headers.get("X-Auth-ID"), "bogus")
|
self.assertEqual(response.headers.get("X-Auth-ID"), "bogus")
|
||||||
|
|
||||||
def test_auth_id_response(self):
|
def test_auth_id_api_key(self):
|
||||||
user = User.objects.create(username="bogus")
|
user = User.objects.create(username="test_user")
|
||||||
user.set_password("bogus")
|
user.set_password("test_user")
|
||||||
user.save()
|
user.save()
|
||||||
|
|
||||||
# Create an API key for the user
|
# Create an API key for the user
|
||||||
@@ -63,7 +64,47 @@ class PDBPermissionMiddlewareTest(APITestCase):
|
|||||||
readonly=False,
|
readonly=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.client.credentials(HTTP_AUTHORIZATION="Api-Key %s" % key)
|
self.client.credentials(HTTP_AUTHORIZATION=f"Api-Key {key}")
|
||||||
response = self.client.get("/api/fac")
|
response = self.client.get("/api/fac")
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
assert response.headers.get("X-Auth-ID").startswith("apikey_")
|
assert response.headers.get("X-Auth-ID").startswith("apikey_")
|
||||||
|
|
||||||
|
# test that header gets cleared between requests
|
||||||
|
other_client = APIClient()
|
||||||
|
response = other_client.get("/api/fac")
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
assert response.headers.get("X-Auth-ID") is None
|
||||||
|
|
||||||
|
def test_auth_id_session_auth(self):
|
||||||
|
user = User.objects.create(username="test_user")
|
||||||
|
user.set_password("test_user")
|
||||||
|
user.save()
|
||||||
|
|
||||||
|
self.client.force_login(user)
|
||||||
|
response = self.client.get("/api/fac")
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
assert response.headers.get("X-Auth-ID") == user.username
|
||||||
|
|
||||||
|
# test that header gets cleared between requests
|
||||||
|
other_client = APIClient()
|
||||||
|
response = other_client.get("/api/fac")
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
assert response.headers.get("X-Auth-ID") is None
|
||||||
|
|
||||||
|
def test_auth_id_basic_auth(self):
|
||||||
|
user = User.objects.create(username="test_user")
|
||||||
|
user.set_password("test_user")
|
||||||
|
user.save()
|
||||||
|
|
||||||
|
auth = base64.b64encode(b"test_user:test_user").decode("utf-8")
|
||||||
|
self.client.credentials(HTTP_AUTHORIZATION=f"Basic {auth}")
|
||||||
|
|
||||||
|
response = self.client.get("/api/fac")
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
assert response.headers.get("X-Auth-ID") == user.username
|
||||||
|
|
||||||
|
# test that header gets cleared between requests
|
||||||
|
other_client = APIClient()
|
||||||
|
response = other_client.get("/api/fac")
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
assert response.headers.get("X-Auth-ID") is None
|
||||||
|
Reference in New Issue
Block a user