11package main
22
33import (
4+ "context"
5+ "crypto/subtle"
46 "crypto/tls"
57 "encoding/base64"
68 "flag"
79 "fmt"
810 "io"
9- "log"
11+ "log/slog "
1012 "net"
1113 "net/http"
1214 "os"
15+ "os/signal"
1316 "strings"
17+ "sync"
18+ "syscall"
1419 "time"
1520
1621 "gopkg.in/yaml.v3"
@@ -25,26 +30,75 @@ type Config struct {
2530 KeyPath string `yaml:"key_path"`
2631}
2732
28- var config Config
33+ // Hop-by-hop headers that should not be forwarded by proxies (RFC 2616 §13.5.1).
34+ var hopByHopHeaders = []string {
35+ "Connection" ,
36+ "Keep-Alive" ,
37+ "Proxy-Authorization" ,
38+ "Proxy-Connection" ,
39+ "TE" ,
40+ "Trailer" ,
41+ "Transfer-Encoding" ,
42+ "Upgrade" ,
43+ }
44+
45+ // Buffer pool to reduce GC pressure during data transfer.
46+ var bufPool = sync.Pool {
47+ New : func () any {
48+ buf := make ([]byte , 32 * 1024 )
49+ return & buf
50+ },
51+ }
52+
53+ var (
54+ config Config
55+ httpClient * http.Client
56+ )
2957
3058func main () {
3159 configPath := flag .String ("config" , "config.yaml" , "Path to the config file" )
3260 flag .Parse ()
3361
3462 content , err := os .ReadFile (* configPath )
3563 if err != nil {
36- log .Fatalf ("Error reading config file: %v" , err )
64+ slog .Error ("Error reading config file" , "error" , err )
65+ os .Exit (1 )
3766 }
3867
3968 err = yaml .Unmarshal (content , & config )
4069 if err != nil {
41- log .Fatalf ("Error parsing config file: %v" , err )
70+ slog .Error ("Error parsing config file" , "error" , err )
71+ os .Exit (1 )
72+ }
73+
74+ // Global HTTP client with connection pooling and timeouts.
75+ httpClient = & http.Client {
76+ Transport : & http.Transport {
77+ MaxIdleConns : 100 ,
78+ MaxIdleConnsPerHost : 10 ,
79+ IdleConnTimeout : 90 * time .Second ,
80+ TLSHandshakeTimeout : 10 * time .Second ,
81+ DialContext : (& net.Dialer {
82+ Timeout : 10 * time .Second ,
83+ KeepAlive : 30 * time .Second ,
84+ }).DialContext ,
85+ },
86+ CheckRedirect : func (req * http.Request , via []* http.Request ) error {
87+ if len (via ) >= 5 {
88+ return fmt .Errorf ("too many redirects (>%d) while following %s" , len (via ), via [len (via )- 1 ].URL .String ())
89+ }
90+ return nil
91+ },
92+ Timeout : 60 * time .Second ,
4293 }
4394
4495 server := & http.Server {
45- Addr : config .ProxyAddr ,
96+ Addr : config .ProxyAddr ,
97+ ReadTimeout : 30 * time .Second ,
98+ WriteTimeout : 60 * time .Second ,
99+ IdleTimeout : 120 * time .Second ,
46100 Handler : http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
47- log . Println ("Received request: " , r .Method , r .URL )
101+ slog . Info ("Received request" , "method" , r .Method , "url" , r .URL . String () )
48102 if ! basicAuth (w , r ) {
49103 return
50104 }
@@ -57,58 +111,84 @@ func main() {
57111 }),
58112 }
59113
60- log .Printf ("Starting proxy server on %s\n " , config .ProxyAddr )
114+ // Graceful shutdown on SIGINT/SIGTERM.
115+ ctx , stop := signal .NotifyContext (context .Background (), syscall .SIGINT , syscall .SIGTERM )
116+ defer stop ()
117+
118+ go func () {
119+ <- ctx .Done ()
120+ slog .Info ("Shutting down proxy server..." )
121+ shutdownCtx , cancel := context .WithTimeout (context .Background (), 15 * time .Second )
122+ defer cancel ()
123+ if err := server .Shutdown (shutdownCtx ); err != nil {
124+ slog .Error ("Server shutdown error" , "error" , err )
125+ }
126+ }()
127+
128+ slog .Info ("Starting proxy server" , "addr" , config .ProxyAddr , "proto" , config .Proto )
61129 if config .Proto == "https" {
62130 ln , err := net .Listen ("tcp" , config .ProxyAddr )
63131 if err != nil {
64- log .Fatalf ("Error creating listener: %v" , err )
132+ slog .Error ("Error creating listener" , "error" , err )
133+ os .Exit (1 )
65134 }
66135
67136 cert , err := tls .LoadX509KeyPair (config .CertPath , config .KeyPath )
68137 if err != nil {
69- log .Fatalf ("Error loading certificate: %v" , err )
138+ slog .Error ("Error loading certificate" , "error" , err )
139+ os .Exit (1 )
70140 }
71141
72142 server .TLSConfig = & tls.Config {
73143 Certificates : []tls.Certificate {cert },
144+ MinVersion : tls .VersionTLS12 ,
74145 }
75146
76147 tlsListener := tls .NewListener (ln , server .TLSConfig )
77148
78- log .Fatal (server .Serve (tlsListener ))
79-
80- // log.Fatal(server.ListenAndServeTLS(config.CertPath, config.KeyPath))
149+ if err := server .Serve (tlsListener ); err != http .ErrServerClosed {
150+ slog .Error ("Server error" , "error" , err )
151+ os .Exit (1 )
152+ }
81153 } else {
82- log .Fatal (server .ListenAndServe ())
154+ if err := server .ListenAndServe (); err != http .ErrServerClosed {
155+ slog .Error ("Server error" , "error" , err )
156+ os .Exit (1 )
157+ }
83158 }
159+
160+ slog .Info ("Server stopped" )
84161}
85162
86163func basicAuth (w http.ResponseWriter , r * http.Request ) bool {
87164 auth := r .Header .Get ("Proxy-Authorization" )
88165 if auth == "" {
89- log . Println ("No Proxy-Authorization header" )
166+ slog . Debug ("No Proxy-Authorization header" , "remote" , r . RemoteAddr )
90167 w .Header ().Set ("Proxy-Authenticate" , `Basic realm="Proxy Authorization Required"` )
91168 w .WriteHeader (http .StatusProxyAuthRequired )
92169 return false
93170 }
94171
95172 payload , err := base64 .StdEncoding .DecodeString (strings .TrimPrefix (auth , "Basic " ))
96173 if err != nil {
97- log . Println ("Error decoding auth: " , err )
174+ slog . Warn ("Error decoding auth" , "error" , err , "remote" , r . RemoteAddr )
98175 w .WriteHeader (http .StatusBadRequest )
99176 return false
100177 }
101178
102179 pair := strings .SplitN (string (payload ), ":" , 2 )
103180 if len (pair ) != 2 {
104- log . Printf ("Invalid auth format: %v \n " , pair )
181+ slog . Warn ("Invalid auth format" , "remote" , r . RemoteAddr )
105182 w .Header ().Set ("Proxy-Authenticate" , `Basic realm="Proxy Authorization Required"` )
106183 w .WriteHeader (http .StatusProxyAuthRequired )
107184 return false
108185 }
109186
110- if pair [0 ] != config .Username || pair [1 ] != config .Password {
111- log .Printf ("Invalid credentials: %s:%s\n " , pair [0 ], pair [1 ])
187+ // Constant-time comparison to prevent timing attacks.
188+ usernameMatch := subtle .ConstantTimeCompare ([]byte (pair [0 ]), []byte (config .Username ))
189+ passwordMatch := subtle .ConstantTimeCompare ([]byte (pair [1 ]), []byte (config .Password ))
190+ if usernameMatch & passwordMatch != 1 {
191+ slog .Warn ("Invalid credentials" , "user" , pair [0 ], "remote" , r .RemoteAddr )
112192 w .Header ().Set ("Proxy-Authenticate" , `Basic realm="Proxy Authorization Required"` )
113193 w .WriteHeader (http .StatusProxyAuthRequired )
114194 return false
@@ -119,30 +199,22 @@ func basicAuth(w http.ResponseWriter, r *http.Request) bool {
119199
120200func handleHTTP (w http.ResponseWriter , r * http.Request ) {
121201 r .RequestURI = ""
122- r .Header .Del ("Proxy-Connection" )
123- r .Header .Del ("Proxy-Authorization" )
124-
125202 r .Host = r .URL .Host
126203
127- client := & http.Client {
128- CheckRedirect : func (req * http.Request , via []* http.Request ) error {
129- if len (via ) >= 5 {
130- lastURL := via [len (via )- 1 ].URL .String ()
131- err := fmt .Errorf ("too many redirects (>%d) while following %s" , len (via ), lastURL )
132- return err
133- }
134- return nil
135- },
204+ // Remove hop-by-hop headers.
205+ for _ , h := range hopByHopHeaders {
206+ r .Header .Del (h )
136207 }
137208
138- resp , err := client .Do (r )
209+ resp , err := httpClient .Do (r )
139210 if err != nil {
140- log . Printf ("Error forwarding request: %v \n " , err )
211+ slog . Error ("Error forwarding request" , "error" , err , "url" , r . URL . String () )
141212 http .Error (w , err .Error (), http .StatusServiceUnavailable )
142213 return
143214 }
144215 defer resp .Body .Close ()
145216
217+ // Copy response headers, skipping hop-by-hop headers.
146218 for key , values := range resp .Header {
147219 for _ , value := range values {
148220 w .Header ().Add (key , value )
@@ -151,54 +223,74 @@ func handleHTTP(w http.ResponseWriter, r *http.Request) {
151223
152224 w .WriteHeader (resp .StatusCode )
153225
154- written , err := io .Copy (w , resp .Body )
226+ bufPtr := bufPool .Get ().(* []byte )
227+ defer bufPool .Put (bufPtr )
228+ written , err := io .CopyBuffer (w , resp .Body , * bufPtr )
155229 if err != nil {
156- log . Printf ( "Error copying response body after %d bytes: %v \n " , written , err )
157- http .Error (w , err . Error (), http . StatusServiceUnavailable )
230+ // Headers already sent — cannot call http.Error, just log.
231+ slog .Error ("Error copying response body" , "written" , written , "error" , err )
158232 return
159233 }
160- log . Printf ( "Successfully copied %d bytes from response \n " , written )
234+ slog . Debug ( "Response copied" , " bytes" , written , "url" , r . URL . String () )
161235}
162236
163237func handleTunneling (w http.ResponseWriter , r * http.Request ) {
164- if r .Method != http .MethodConnect {
165- log .Printf ("Error: Method not allowed: %s\n " , r .Method )
166- http .Error (w , "Method not allowed" , http .StatusMethodNotAllowed )
167- return
168- }
169-
170238 destConn , err := net .DialTimeout ("tcp" , r .Host , 10 * time .Second )
171239 if err != nil {
172- log . Printf ( "Error: Can't connect to host: %s, %v \n " , r .Host , err )
240+ slog . Error ( " Can't connect to host" , "host" , r .Host , "error" , err )
173241 http .Error (w , err .Error (), http .StatusServiceUnavailable )
174242 return
175243 }
244+
176245 w .WriteHeader (http .StatusOK )
177246
178247 hijacker , ok := w .(http.Hijacker )
179248 if ! ok {
180- log . Printf ( "Error: Hijacking not supported\n " )
181- http . Error ( w , "Hijacking not supported" , http . StatusInternalServerError )
249+ slog . Error ( " Hijacking not supported" )
250+ destConn . Close ( )
182251 return
183252 }
184253 clientConn , _ , err := hijacker .Hijack ()
185254 if err != nil {
186- log . Printf ( "Error: Client connection error: %v \n " , err )
187- http . Error ( w , err . Error (), http . StatusServiceUnavailable )
255+ slog . Error ( " Client connection hijack error" , "error " , err )
256+ destConn . Close ( )
188257 return
189258 }
190259
191- go transfer (destConn , clientConn )
192- go transfer (clientConn , destConn )
260+ // Use a WaitGroup to wait for both directions to finish,
261+ // then close both connections cleanly.
262+ var wg sync.WaitGroup
263+ wg .Add (2 )
264+
265+ go func () {
266+ defer wg .Done ()
267+ transfer (destConn , clientConn )
268+ // Signal the other direction to stop by setting a read deadline.
269+ if tc , ok := destConn .(* net.TCPConn ); ok {
270+ tc .SetReadDeadline (time .Now ())
271+ }
272+ }()
273+
274+ go func () {
275+ defer wg .Done ()
276+ transfer (clientConn , destConn )
277+ if tc , ok := clientConn .(* net.TCPConn ); ok {
278+ tc .SetReadDeadline (time .Now ())
279+ }
280+ }()
281+
282+ wg .Wait ()
283+ destConn .Close ()
284+ clientConn .Close ()
193285}
194286
195- func transfer (destination io.WriteCloser , source io.ReadCloser ) {
196- defer destination . Close ( )
197- defer source . Close ( )
198- bytes , err := io .Copy (destination , source )
287+ func transfer (destination io.Writer , source io.Reader ) {
288+ bufPtr := bufPool . Get ().( * [] byte )
289+ defer bufPool . Put ( bufPtr )
290+ written , err := io .CopyBuffer (destination , source , * bufPtr )
199291 if err != nil {
200- log . Printf ("Transfer error: %v \n " , err )
292+ slog . Debug ("Transfer finished with error" , "bytes" , written , "error " , err )
201293 } else {
202- log . Printf ( "Transferred %d bytes\n " , bytes )
294+ slog . Debug ( "Transfer complete" , " bytes" , written )
203295 }
204296}
0 commit comments