mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Limit origin check to when host is loopback.
This should still prevent the exploit without breaking things for people who use reverse proxies.
This commit is contained in:
parent
81778a7feb
commit
36c83cdbba
34
server.py
34
server.py
@ -12,6 +12,8 @@ import json
|
||||
import glob
|
||||
import struct
|
||||
import ssl
|
||||
import socket
|
||||
import ipaddress
|
||||
from PIL import Image, ImageOps
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
from io import BytesIO
|
||||
@ -80,6 +82,32 @@ def create_cors_middleware(allowed_origin: str):
|
||||
|
||||
return cors_middleware
|
||||
|
||||
def is_loopback(host):
|
||||
if host is None:
|
||||
return False
|
||||
try:
|
||||
if ipaddress.ip_address(host).is_loopback:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except:
|
||||
pass
|
||||
|
||||
loopback = False
|
||||
for family in (socket.AF_INET, socket.AF_INET6):
|
||||
try:
|
||||
r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
|
||||
for family, _, _, _, sockaddr in r:
|
||||
if not ipaddress.ip_address(sockaddr[0]).is_loopback:
|
||||
return loopback
|
||||
else:
|
||||
loopback = True
|
||||
except socket.gaierror:
|
||||
pass
|
||||
|
||||
return loopback
|
||||
|
||||
|
||||
def create_origin_only_middleware():
|
||||
@web.middleware
|
||||
async def origin_only_middleware(request: web.Request, handler):
|
||||
@ -93,12 +121,16 @@ def create_origin_only_middleware():
|
||||
parsed = urllib.parse.urlparse(origin)
|
||||
origin_domain = parsed.netloc.lower()
|
||||
host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)
|
||||
|
||||
#limit the check to when the host domain is localhost, this makes it slightly less safe but should still prevent the exploit
|
||||
loopback = is_loopback(host_domain_parsed.hostname)
|
||||
|
||||
if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host
|
||||
host_domain = host_domain_parsed.hostname
|
||||
if host_domain_parsed.port is None:
|
||||
origin_domain = parsed.hostname
|
||||
|
||||
if host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
|
||||
if loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
|
||||
if host_domain != origin_domain:
|
||||
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
|
||||
return web.Response(status=403)
|
||||
|
Loading…
Reference in New Issue
Block a user