From f8349fcebcdf2a9066a06245c1e68c67f3e608bd Mon Sep 17 00:00:00 2001 From: pragmaxim Date: Wed, 11 Mar 2026 08:41:07 +0100 Subject: [PATCH] enhancement: reject oversized websocket messages --- server/public_test.go | 44 +++++++++++++++++++++++++++++++++++++++++++ server/websocket.go | 12 +++++++++++- 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/server/public_test.go b/server/public_test.go index e9ff62b1..bdd103e2 100644 --- a/server/public_test.go +++ b/server/public_test.go @@ -1300,6 +1300,50 @@ func assertNoWebsocketMessage(t *testing.T, s *websocket.Conn, timeout time.Dura } } +func Test_WebsocketRejectsOversizedMessage(t *testing.T) { + parser, chain := setupChain(t) + + s, dbpath := setupPublicHTTPServer(parser, chain, t, false) + defer closeAndDestroyPublicServer(t, s, dbpath) + s.ConnectFullPublicInterface() + + ts := httptest.NewServer(s.https.Handler) + defer ts.Close() + + ws := connectWebsocket(t, ts) + defer ws.Close() + + // Verify the connection is healthy before sending an oversized frame. + if err := ws.WriteJSON(websocketReq{ID: "0", Method: "getInfo"}); err != nil { + t.Fatal(err) + } + resp := readWebsocketResponse(t, ws, time.Second) + if resp.ID != "0" { + t.Fatalf("got response id %q, want %q", resp.ID, "0") + } + + payload := strings.Repeat("a", int(maxWebsocketMessageBytes)+1) + if err := ws.WriteMessage(websocket.TextMessage, []byte(payload)); err != nil { + t.Fatal(err) + } + + if err := ws.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatal(err) + } + _, _, err := ws.ReadMessage() + ws.SetReadDeadline(time.Time{}) + if err == nil { + t.Fatal("expected websocket read error after oversized message") + } + if websocket.IsCloseError(err, websocket.CloseMessageTooBig, websocket.CloseAbnormalClosure) { + return + } + if errors.Is(err, io.EOF) { + return + } + t.Fatalf("unexpected websocket error after oversized message: %v", err) +} + var websocketTestsBitcoinType = []websocketTest{ { name: "websocket getInfo", diff --git a/server/websocket.go b/server/websocket.go index ff17068a..5b69d99c 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -27,6 +27,8 @@ const upgradeFailed = "Upgrade failed: " const outChannelSize = 500 const defaultTimeout = 60 * time.Second const unknownMethodLabel = "unknown" +const maxWebsocketMessageBytes int64 = 4 * 1024 * 1024 +const websocketLogPreviewBytes = 256 // allRates is a special "currency" parameter that means all available currencies const allFiatRates = "!ALL!" @@ -199,6 +201,13 @@ func getIP(r *http.Request) string { return r.RemoteAddr } +func getWebsocketPayloadPreview(d []byte) string { + if len(d) <= websocketLogPreviewBytes { + return string(d) + } + return string(d[:websocketLogPreviewBytes]) + "...(truncated)" +} + // ServeHTTP sets up handler of websocket channel func (s *WebsocketServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method != "GET" { @@ -210,6 +219,7 @@ func (s *WebsocketServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, upgradeFailed+err.Error(), http.StatusServiceUnavailable) return } + conn.SetReadLimit(maxWebsocketMessageBytes) c := &websocketChannel{ id: atomic.AddUint64(&connectionCounter, 1), conn: conn, @@ -299,7 +309,7 @@ func (s *WebsocketServer) inputLoop(c *websocketChannel) { var req WsReq err := json.Unmarshal(d, &req) if err != nil { - glog.Error("Error parsing message from ", c.id, ", ", string(d), ", ", err) + glog.Error("Error parsing message from ", c.id, ", len ", len(d), ", preview ", getWebsocketPayloadPreview(d), ", ", err) s.closeChannel(c, "protocol_error") return }