mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Limit origin check to when host is loopback.
This should still prevent the exploit without breaking things for people who use reverse proxies.
This commit is contained in:
parent
81778a7feb
commit
36c83cdbba
34
server.py
34
server.py
@ -12,6 +12,8 @@ import json
|
|||||||
import glob
|
import glob
|
||||||
import struct
|
import struct
|
||||||
import ssl
|
import ssl
|
||||||
|
import socket
|
||||||
|
import ipaddress
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from PIL.PngImagePlugin import PngInfo
|
from PIL.PngImagePlugin import PngInfo
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@ -80,6 +82,32 @@ def create_cors_middleware(allowed_origin: str):
|
|||||||
|
|
||||||
return cors_middleware
|
return cors_middleware
|
||||||
|
|
||||||
|
def is_loopback(host):
|
||||||
|
if host is None:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
if ipaddress.ip_address(host).is_loopback:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
loopback = False
|
||||||
|
for family in (socket.AF_INET, socket.AF_INET6):
|
||||||
|
try:
|
||||||
|
r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
|
||||||
|
for family, _, _, _, sockaddr in r:
|
||||||
|
if not ipaddress.ip_address(sockaddr[0]).is_loopback:
|
||||||
|
return loopback
|
||||||
|
else:
|
||||||
|
loopback = True
|
||||||
|
except socket.gaierror:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return loopback
|
||||||
|
|
||||||
|
|
||||||
def create_origin_only_middleware():
|
def create_origin_only_middleware():
|
||||||
@web.middleware
|
@web.middleware
|
||||||
async def origin_only_middleware(request: web.Request, handler):
|
async def origin_only_middleware(request: web.Request, handler):
|
||||||
@ -93,12 +121,16 @@ def create_origin_only_middleware():
|
|||||||
parsed = urllib.parse.urlparse(origin)
|
parsed = urllib.parse.urlparse(origin)
|
||||||
origin_domain = parsed.netloc.lower()
|
origin_domain = parsed.netloc.lower()
|
||||||
host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)
|
host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)
|
||||||
|
|
||||||
|
#limit the check to when the host domain is localhost, this makes it slightly less safe but should still prevent the exploit
|
||||||
|
loopback = is_loopback(host_domain_parsed.hostname)
|
||||||
|
|
||||||
if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host
|
if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host
|
||||||
host_domain = host_domain_parsed.hostname
|
host_domain = host_domain_parsed.hostname
|
||||||
if host_domain_parsed.port is None:
|
if host_domain_parsed.port is None:
|
||||||
origin_domain = parsed.hostname
|
origin_domain = parsed.hostname
|
||||||
|
|
||||||
if host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
|
if loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
|
||||||
if host_domain != origin_domain:
|
if host_domain != origin_domain:
|
||||||
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
|
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
|
||||||
return web.Response(status=403)
|
return web.Response(status=403)
|
||||||
|
Loading…
Reference in New Issue
Block a user