Skip to content

Commit 09d0296

Browse files
committed
feat: refactor logging to use slog and enhance error handling
1 parent d5df190 commit 09d0296

1 file changed

Lines changed: 147 additions & 55 deletions

File tree

main.go

Lines changed: 147 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
package main
22

33
import (
4+
"context"
5+
"crypto/subtle"
46
"crypto/tls"
57
"encoding/base64"
68
"flag"
79
"fmt"
810
"io"
9-
"log"
11+
"log/slog"
1012
"net"
1113
"net/http"
1214
"os"
15+
"os/signal"
1316
"strings"
17+
"sync"
18+
"syscall"
1419
"time"
1520

1621
"gopkg.in/yaml.v3"
@@ -25,26 +30,75 @@ type Config struct {
2530
KeyPath string `yaml:"key_path"`
2631
}
2732

28-
var config Config
33+
// Hop-by-hop headers that should not be forwarded by proxies (RFC 2616 §13.5.1).
34+
var hopByHopHeaders = []string{
35+
"Connection",
36+
"Keep-Alive",
37+
"Proxy-Authorization",
38+
"Proxy-Connection",
39+
"TE",
40+
"Trailer",
41+
"Transfer-Encoding",
42+
"Upgrade",
43+
}
44+
45+
// Buffer pool to reduce GC pressure during data transfer.
46+
var bufPool = sync.Pool{
47+
New: func() any {
48+
buf := make([]byte, 32*1024)
49+
return &buf
50+
},
51+
}
52+
53+
var (
54+
config Config
55+
httpClient *http.Client
56+
)
2957

3058
func main() {
3159
configPath := flag.String("config", "config.yaml", "Path to the config file")
3260
flag.Parse()
3361

3462
content, err := os.ReadFile(*configPath)
3563
if err != nil {
36-
log.Fatalf("Error reading config file: %v", err)
64+
slog.Error("Error reading config file", "error", err)
65+
os.Exit(1)
3766
}
3867

3968
err = yaml.Unmarshal(content, &config)
4069
if err != nil {
41-
log.Fatalf("Error parsing config file: %v", err)
70+
slog.Error("Error parsing config file", "error", err)
71+
os.Exit(1)
72+
}
73+
74+
// Global HTTP client with connection pooling and timeouts.
75+
httpClient = &http.Client{
76+
Transport: &http.Transport{
77+
MaxIdleConns: 100,
78+
MaxIdleConnsPerHost: 10,
79+
IdleConnTimeout: 90 * time.Second,
80+
TLSHandshakeTimeout: 10 * time.Second,
81+
DialContext: (&net.Dialer{
82+
Timeout: 10 * time.Second,
83+
KeepAlive: 30 * time.Second,
84+
}).DialContext,
85+
},
86+
CheckRedirect: func(req *http.Request, via []*http.Request) error {
87+
if len(via) >= 5 {
88+
return fmt.Errorf("too many redirects (>%d) while following %s", len(via), via[len(via)-1].URL.String())
89+
}
90+
return nil
91+
},
92+
Timeout: 60 * time.Second,
4293
}
4394

4495
server := &http.Server{
45-
Addr: config.ProxyAddr,
96+
Addr: config.ProxyAddr,
97+
ReadTimeout: 30 * time.Second,
98+
WriteTimeout: 60 * time.Second,
99+
IdleTimeout: 120 * time.Second,
46100
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
47-
log.Println("Received request:", r.Method, r.URL)
101+
slog.Info("Received request", "method", r.Method, "url", r.URL.String())
48102
if !basicAuth(w, r) {
49103
return
50104
}
@@ -57,58 +111,84 @@ func main() {
57111
}),
58112
}
59113

60-
log.Printf("Starting proxy server on %s\n", config.ProxyAddr)
114+
// Graceful shutdown on SIGINT/SIGTERM.
115+
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
116+
defer stop()
117+
118+
go func() {
119+
<-ctx.Done()
120+
slog.Info("Shutting down proxy server...")
121+
shutdownCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
122+
defer cancel()
123+
if err := server.Shutdown(shutdownCtx); err != nil {
124+
slog.Error("Server shutdown error", "error", err)
125+
}
126+
}()
127+
128+
slog.Info("Starting proxy server", "addr", config.ProxyAddr, "proto", config.Proto)
61129
if config.Proto == "https" {
62130
ln, err := net.Listen("tcp", config.ProxyAddr)
63131
if err != nil {
64-
log.Fatalf("Error creating listener: %v", err)
132+
slog.Error("Error creating listener", "error", err)
133+
os.Exit(1)
65134
}
66135

67136
cert, err := tls.LoadX509KeyPair(config.CertPath, config.KeyPath)
68137
if err != nil {
69-
log.Fatalf("Error loading certificate: %v", err)
138+
slog.Error("Error loading certificate", "error", err)
139+
os.Exit(1)
70140
}
71141

72142
server.TLSConfig = &tls.Config{
73143
Certificates: []tls.Certificate{cert},
144+
MinVersion: tls.VersionTLS12,
74145
}
75146

76147
tlsListener := tls.NewListener(ln, server.TLSConfig)
77148

78-
log.Fatal(server.Serve(tlsListener))
79-
80-
// log.Fatal(server.ListenAndServeTLS(config.CertPath, config.KeyPath))
149+
if err := server.Serve(tlsListener); err != http.ErrServerClosed {
150+
slog.Error("Server error", "error", err)
151+
os.Exit(1)
152+
}
81153
} else {
82-
log.Fatal(server.ListenAndServe())
154+
if err := server.ListenAndServe(); err != http.ErrServerClosed {
155+
slog.Error("Server error", "error", err)
156+
os.Exit(1)
157+
}
83158
}
159+
160+
slog.Info("Server stopped")
84161
}
85162

86163
func basicAuth(w http.ResponseWriter, r *http.Request) bool {
87164
auth := r.Header.Get("Proxy-Authorization")
88165
if auth == "" {
89-
log.Println("No Proxy-Authorization header")
166+
slog.Debug("No Proxy-Authorization header", "remote", r.RemoteAddr)
90167
w.Header().Set("Proxy-Authenticate", `Basic realm="Proxy Authorization Required"`)
91168
w.WriteHeader(http.StatusProxyAuthRequired)
92169
return false
93170
}
94171

95172
payload, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(auth, "Basic "))
96173
if err != nil {
97-
log.Println("Error decoding auth:", err)
174+
slog.Warn("Error decoding auth", "error", err, "remote", r.RemoteAddr)
98175
w.WriteHeader(http.StatusBadRequest)
99176
return false
100177
}
101178

102179
pair := strings.SplitN(string(payload), ":", 2)
103180
if len(pair) != 2 {
104-
log.Printf("Invalid auth format: %v\n", pair)
181+
slog.Warn("Invalid auth format", "remote", r.RemoteAddr)
105182
w.Header().Set("Proxy-Authenticate", `Basic realm="Proxy Authorization Required"`)
106183
w.WriteHeader(http.StatusProxyAuthRequired)
107184
return false
108185
}
109186

110-
if pair[0] != config.Username || pair[1] != config.Password {
111-
log.Printf("Invalid credentials: %s:%s\n", pair[0], pair[1])
187+
// Constant-time comparison to prevent timing attacks.
188+
usernameMatch := subtle.ConstantTimeCompare([]byte(pair[0]), []byte(config.Username))
189+
passwordMatch := subtle.ConstantTimeCompare([]byte(pair[1]), []byte(config.Password))
190+
if usernameMatch&passwordMatch != 1 {
191+
slog.Warn("Invalid credentials", "user", pair[0], "remote", r.RemoteAddr)
112192
w.Header().Set("Proxy-Authenticate", `Basic realm="Proxy Authorization Required"`)
113193
w.WriteHeader(http.StatusProxyAuthRequired)
114194
return false
@@ -119,30 +199,22 @@ func basicAuth(w http.ResponseWriter, r *http.Request) bool {
119199

120200
func handleHTTP(w http.ResponseWriter, r *http.Request) {
121201
r.RequestURI = ""
122-
r.Header.Del("Proxy-Connection")
123-
r.Header.Del("Proxy-Authorization")
124-
125202
r.Host = r.URL.Host
126203

127-
client := &http.Client{
128-
CheckRedirect: func(req *http.Request, via []*http.Request) error {
129-
if len(via) >= 5 {
130-
lastURL := via[len(via)-1].URL.String()
131-
err := fmt.Errorf("too many redirects (>%d) while following %s", len(via), lastURL)
132-
return err
133-
}
134-
return nil
135-
},
204+
// Remove hop-by-hop headers.
205+
for _, h := range hopByHopHeaders {
206+
r.Header.Del(h)
136207
}
137208

138-
resp, err := client.Do(r)
209+
resp, err := httpClient.Do(r)
139210
if err != nil {
140-
log.Printf("Error forwarding request: %v\n", err)
211+
slog.Error("Error forwarding request", "error", err, "url", r.URL.String())
141212
http.Error(w, err.Error(), http.StatusServiceUnavailable)
142213
return
143214
}
144215
defer resp.Body.Close()
145216

217+
// Copy response headers, skipping hop-by-hop headers.
146218
for key, values := range resp.Header {
147219
for _, value := range values {
148220
w.Header().Add(key, value)
@@ -151,54 +223,74 @@ func handleHTTP(w http.ResponseWriter, r *http.Request) {
151223

152224
w.WriteHeader(resp.StatusCode)
153225

154-
written, err := io.Copy(w, resp.Body)
226+
bufPtr := bufPool.Get().(*[]byte)
227+
defer bufPool.Put(bufPtr)
228+
written, err := io.CopyBuffer(w, resp.Body, *bufPtr)
155229
if err != nil {
156-
log.Printf("Error copying response body after %d bytes: %v\n", written, err)
157-
http.Error(w, err.Error(), http.StatusServiceUnavailable)
230+
// Headers already sent — cannot call http.Error, just log.
231+
slog.Error("Error copying response body", "written", written, "error", err)
158232
return
159233
}
160-
log.Printf("Successfully copied %d bytes from response\n", written)
234+
slog.Debug("Response copied", "bytes", written, "url", r.URL.String())
161235
}
162236

163237
func handleTunneling(w http.ResponseWriter, r *http.Request) {
164-
if r.Method != http.MethodConnect {
165-
log.Printf("Error: Method not allowed: %s\n", r.Method)
166-
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
167-
return
168-
}
169-
170238
destConn, err := net.DialTimeout("tcp", r.Host, 10*time.Second)
171239
if err != nil {
172-
log.Printf("Error: Can't connect to host: %s, %v\n", r.Host, err)
240+
slog.Error("Can't connect to host", "host", r.Host, "error", err)
173241
http.Error(w, err.Error(), http.StatusServiceUnavailable)
174242
return
175243
}
244+
176245
w.WriteHeader(http.StatusOK)
177246

178247
hijacker, ok := w.(http.Hijacker)
179248
if !ok {
180-
log.Printf("Error: Hijacking not supported\n")
181-
http.Error(w, "Hijacking not supported", http.StatusInternalServerError)
249+
slog.Error("Hijacking not supported")
250+
destConn.Close()
182251
return
183252
}
184253
clientConn, _, err := hijacker.Hijack()
185254
if err != nil {
186-
log.Printf("Error: Client connection error: %v\n", err)
187-
http.Error(w, err.Error(), http.StatusServiceUnavailable)
255+
slog.Error("Client connection hijack error", "error", err)
256+
destConn.Close()
188257
return
189258
}
190259

191-
go transfer(destConn, clientConn)
192-
go transfer(clientConn, destConn)
260+
// Use a WaitGroup to wait for both directions to finish,
261+
// then close both connections cleanly.
262+
var wg sync.WaitGroup
263+
wg.Add(2)
264+
265+
go func() {
266+
defer wg.Done()
267+
transfer(destConn, clientConn)
268+
// Signal the other direction to stop by setting a read deadline.
269+
if tc, ok := destConn.(*net.TCPConn); ok {
270+
tc.SetReadDeadline(time.Now())
271+
}
272+
}()
273+
274+
go func() {
275+
defer wg.Done()
276+
transfer(clientConn, destConn)
277+
if tc, ok := clientConn.(*net.TCPConn); ok {
278+
tc.SetReadDeadline(time.Now())
279+
}
280+
}()
281+
282+
wg.Wait()
283+
destConn.Close()
284+
clientConn.Close()
193285
}
194286

195-
func transfer(destination io.WriteCloser, source io.ReadCloser) {
196-
defer destination.Close()
197-
defer source.Close()
198-
bytes, err := io.Copy(destination, source)
287+
func transfer(destination io.Writer, source io.Reader) {
288+
bufPtr := bufPool.Get().(*[]byte)
289+
defer bufPool.Put(bufPtr)
290+
written, err := io.CopyBuffer(destination, source, *bufPtr)
199291
if err != nil {
200-
log.Printf("Transfer error: %v\n", err)
292+
slog.Debug("Transfer finished with error", "bytes", written, "error", err)
201293
} else {
202-
log.Printf("Transferred %d bytes\n", bytes)
294+
slog.Debug("Transfer complete", "bytes", written)
203295
}
204296
}

0 commit comments

Comments
 (0)