-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathcors.py
More file actions
150 lines (122 loc) · 5.49 KB
/
cors.py
File metadata and controls
150 lines (122 loc) · 5.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import os
import re
from collections.abc import Awaitable, Callable
from typing import Any
# CORS configuration
ALLOWED_ORIGINS: list[str | re.Pattern] = [
"https://www.braintrust.dev",
"https://www.braintrustdata.com",
re.compile(r"https://.*\.preview\.braintrust\.dev"),
]
# Request headers advertised in Access-Control-Allow-Headers.
ALLOWED_HEADERS = [
    # General-purpose and X-Amz-* headers.
    "Content-Type",
    "X-Amz-Date",
    "Authorization",
    "X-Api-Key",
    "X-Amz-Security-Token",
    # x-bt-* headers.
    "x-bt-auth-token",
    "x-bt-parent",
    "x-bt-org-name",
    "x-bt-project-id",
    "x-bt-stream-fmt",
    "x-bt-use-cache",
    "x-bt-use-gateway",
    # x-stainless-* headers.
    "x-stainless-os",
    "x-stainless-lang",
    "x-stainless-package-version",
    "x-stainless-runtime",
    "x-stainless-runtime-version",
    "x-stainless-arch",
]
# Response headers advertised in Access-Control-Expose-Headers.
EXPOSED_HEADERS = [
    "x-bt-cursor",
    "x-bt-found-existing-experiment",
    "x-bt-span-id",
    "x-bt-span-export",
]
def check_origin(origin: str) -> bool:
    """Return whether *origin* is allowed to make cross-origin requests.

    An origin is accepted when it equals the value of the
    ``WHITELISTED_ORIGIN`` or ``BRAINTRUST_APP_URL`` environment variable,
    equals one of the string entries of ``ALLOWED_ORIGINS``, or is matched
    in full by one of its compiled patterns.

    Args:
        origin: Value of the request's ``Origin`` header. An empty value is
            never allowed.

    Returns:
        ``True`` if the origin is allowed, ``False`` otherwise.
    """
    if not origin:
        return False

    # Environment-configured origins. The environment is read on every call,
    # so changes take effect without a restart.
    for env_name in ("WHITELISTED_ORIGIN", "BRAINTRUST_APP_URL"):
        # An Origin header never ends with "/", so a configured URL such as
        # "https://app.example.com/" could otherwise never match.
        configured = os.environ.get(env_name, "").rstrip("/")
        if configured and origin == configured:
            return True

    # Static origins and regex patterns.
    for allowed in ALLOWED_ORIGINS:
        if isinstance(allowed, str):
            if origin == allowed:
                return True
        elif isinstance(allowed, re.Pattern) and allowed.fullmatch(origin):
            # fullmatch, not match: match() only anchors the start, which
            # would accept "https://x.preview.braintrust.dev.evil.com".
            return True

    return False
# Value sent in Access-Control-Allow-Methods.
_CORS_ALLOW_METHODS = b"GET, POST, PUT, DELETE, OPTIONS, PATCH"


def _with_cors_headers(
    response_headers: list[tuple[bytes, bytes]],
    origin: bytes,
    *,
    allow_private_network: bool,
) -> list[tuple[bytes, bytes]]:
    """Return *response_headers* with the CORS headers for *origin* applied.

    Headers are kept as a list of pairs (never round-tripped through a dict)
    so that repeated response headers such as ``set-cookie`` are preserved.
    Existing headers with the same name as a CORS header being set are
    replaced.
    """
    cors_headers = [
        (b"access-control-allow-origin", origin),
        (b"access-control-allow-methods", _CORS_ALLOW_METHODS),
        (b"access-control-allow-headers", ", ".join(ALLOWED_HEADERS).encode()),
        (b"access-control-expose-headers", ", ".join(EXPOSED_HEADERS).encode()),
        (b"access-control-allow-credentials", b"true"),
        (b"access-control-max-age", b"86400"),
    ]
    if allow_private_network:
        cors_headers.append((b"access-control-allow-private-network", b"true"))

    replaced = {name for name, _ in cors_headers}
    merged = [
        (name, value)
        for name, value in response_headers
        if name.lower() not in replaced
    ]
    merged.extend(cors_headers)

    # The allowed origin is echoed back per request, so caches must key the
    # response on the Origin request header.
    varies_on_origin = False
    for name, value in merged:
        if name.lower() != b"vary":
            continue
        tokens = {token.strip().lower() for token in value.split(b",")}
        if b"origin" in tokens or b"*" in tokens:
            varies_on_origin = True
            break
    if not varies_on_origin:
        merged.append((b"vary", b"Origin"))

    return merged


def create_cors_middleware() -> type:
    """Create a Starlette-compatible ASGI CORS middleware class.

    The returned class wraps an ASGI app. For ``http`` scopes it:

    * answers every ``OPTIONS`` request itself with an empty ``200``
      response (the wrapped app is not called), adding CORS headers when
      the request's ``Origin`` passes ``check_origin``;
    * for all other methods, calls the wrapped app and adds CORS headers to
      the response start message when the ``Origin`` passes
      ``check_origin``.

    Non-``http`` scopes are passed through unchanged.

    Returns:
        The middleware class (not an instance); instantiate it with the
        ASGI app to wrap.
    """

    class CORSMiddleware:
        def __init__(self, app: Any) -> None:
            self.app = app

        async def __call__(
            self,
            scope: dict[str, Any],
            receive: Callable[[], Awaitable[dict[str, Any]]],
            send: Callable[[dict[str, Any]], Awaitable[None]],
        ) -> None:
            if scope["type"] != "http":
                await self.app(scope, receive, send)
                return

            request_headers = dict(scope["headers"])
            origin_bytes = request_headers.get(b"origin", b"")
            # latin-1 can decode any byte sequence, so a malformed Origin
            # header is simply rejected by check_origin instead of raising
            # UnicodeDecodeError. The original bytes are echoed back as-is.
            origin = origin_bytes.decode("latin-1")
            origin_allowed = bool(origin) and check_origin(origin)
            wants_private_network = bool(
                request_headers.get(b"access-control-request-private-network")
            )

            # Handle OPTIONS requests: reply directly with an empty response.
            if scope["method"] == "OPTIONS":
                preflight_headers: list[tuple[bytes, bytes]] = []
                if origin_allowed:
                    preflight_headers = _with_cors_headers(
                        [],
                        origin_bytes,
                        allow_private_network=wants_private_network,
                    )
                await send(
                    {
                        "type": "http.response.start",
                        "status": 200,
                        "headers": preflight_headers,
                    }
                )
                await send(
                    {
                        "type": "http.response.body",
                        "body": b"",
                    }
                )
                return

            # For other requests, add CORS headers if the origin is valid.
            async def send_wrapper(message: dict[str, Any]) -> None:
                if origin_allowed and message["type"] == "http.response.start":
                    message["headers"] = _with_cors_headers(
                        message.get("headers", []),
                        origin_bytes,
                        allow_private_network=wants_private_network,
                    )
                await send(message)

            await self.app(scope, receive, send_wrapper)

    return CORSMiddleware