enhancement: reject oversized websocket messages

This commit is contained in:
pragmaxim
2026-03-11 08:41:07 +01:00
parent 33b99cc7d4
commit f8349fcebc
2 changed files with 55 additions and 1 deletions

View File

@@ -1300,6 +1300,50 @@ func assertNoWebsocketMessage(t *testing.T, s *websocket.Conn, timeout time.Dura
}
}
func Test_WebsocketRejectsOversizedMessage(t *testing.T) {
parser, chain := setupChain(t)
s, dbpath := setupPublicHTTPServer(parser, chain, t, false)
defer closeAndDestroyPublicServer(t, s, dbpath)
s.ConnectFullPublicInterface()
ts := httptest.NewServer(s.https.Handler)
defer ts.Close()
ws := connectWebsocket(t, ts)
defer ws.Close()
// Verify the connection is healthy before sending an oversized frame.
if err := ws.WriteJSON(websocketReq{ID: "0", Method: "getInfo"}); err != nil {
t.Fatal(err)
}
resp := readWebsocketResponse(t, ws, time.Second)
if resp.ID != "0" {
t.Fatalf("got response id %q, want %q", resp.ID, "0")
}
payload := strings.Repeat("a", int(maxWebsocketMessageBytes)+1)
if err := ws.WriteMessage(websocket.TextMessage, []byte(payload)); err != nil {
t.Fatal(err)
}
if err := ws.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Fatal(err)
}
_, _, err := ws.ReadMessage()
ws.SetReadDeadline(time.Time{})
if err == nil {
t.Fatal("expected websocket read error after oversized message")
}
if websocket.IsCloseError(err, websocket.CloseMessageTooBig, websocket.CloseAbnormalClosure) {
return
}
if errors.Is(err, io.EOF) {
return
}
t.Fatalf("unexpected websocket error after oversized message: %v", err)
}
var websocketTestsBitcoinType = []websocketTest{
{
name: "websocket getInfo",

View File

@@ -27,6 +27,8 @@ const upgradeFailed = "Upgrade failed: "
const outChannelSize = 500
const defaultTimeout = 60 * time.Second
const unknownMethodLabel = "unknown"
const maxWebsocketMessageBytes int64 = 4 * 1024 * 1024
const websocketLogPreviewBytes = 256
// allRates is a special "currency" parameter that means all available currencies
const allFiatRates = "!ALL!"
@@ -199,6 +201,13 @@ func getIP(r *http.Request) string {
return r.RemoteAddr
}
func getWebsocketPayloadPreview(d []byte) string {
if len(d) <= websocketLogPreviewBytes {
return string(d)
}
return string(d[:websocketLogPreviewBytes]) + "...(truncated)"
}
// ServeHTTP sets up handler of websocket channel
func (s *WebsocketServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
@@ -210,6 +219,7 @@ func (s *WebsocketServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, upgradeFailed+err.Error(), http.StatusServiceUnavailable)
return
}
conn.SetReadLimit(maxWebsocketMessageBytes)
c := &websocketChannel{
id: atomic.AddUint64(&connectionCounter, 1),
conn: conn,
@@ -299,7 +309,7 @@ func (s *WebsocketServer) inputLoop(c *websocketChannel) {
var req WsReq
err := json.Unmarshal(d, &req)
if err != nil {
glog.Error("Error parsing message from ", c.id, ", ", string(d), ", ", err)
glog.Error("Error parsing message from ", c.id, ", len ", len(d), ", preview ", getWebsocketPayloadPreview(d), ", ", err)
s.closeChannel(c, "protocol_error")
return
}