mirror of
https://github.com/trezor/blockbook.git
synced 2026-03-24 16:37:19 +01:00
enhancement: reject oversized websocket messages
This commit is contained in:
@@ -1300,6 +1300,50 @@ func assertNoWebsocketMessage(t *testing.T, s *websocket.Conn, timeout time.Dura
|
||||
}
|
||||
}
|
||||
|
||||
func Test_WebsocketRejectsOversizedMessage(t *testing.T) {
|
||||
parser, chain := setupChain(t)
|
||||
|
||||
s, dbpath := setupPublicHTTPServer(parser, chain, t, false)
|
||||
defer closeAndDestroyPublicServer(t, s, dbpath)
|
||||
s.ConnectFullPublicInterface()
|
||||
|
||||
ts := httptest.NewServer(s.https.Handler)
|
||||
defer ts.Close()
|
||||
|
||||
ws := connectWebsocket(t, ts)
|
||||
defer ws.Close()
|
||||
|
||||
// Verify the connection is healthy before sending an oversized frame.
|
||||
if err := ws.WriteJSON(websocketReq{ID: "0", Method: "getInfo"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp := readWebsocketResponse(t, ws, time.Second)
|
||||
if resp.ID != "0" {
|
||||
t.Fatalf("got response id %q, want %q", resp.ID, "0")
|
||||
}
|
||||
|
||||
payload := strings.Repeat("a", int(maxWebsocketMessageBytes)+1)
|
||||
if err := ws.WriteMessage(websocket.TextMessage, []byte(payload)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := ws.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _, err := ws.ReadMessage()
|
||||
ws.SetReadDeadline(time.Time{})
|
||||
if err == nil {
|
||||
t.Fatal("expected websocket read error after oversized message")
|
||||
}
|
||||
if websocket.IsCloseError(err, websocket.CloseMessageTooBig, websocket.CloseAbnormalClosure) {
|
||||
return
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return
|
||||
}
|
||||
t.Fatalf("unexpected websocket error after oversized message: %v", err)
|
||||
}
|
||||
|
||||
var websocketTestsBitcoinType = []websocketTest{
|
||||
{
|
||||
name: "websocket getInfo",
|
||||
|
||||
@@ -27,6 +27,8 @@ const upgradeFailed = "Upgrade failed: "
|
||||
const outChannelSize = 500
|
||||
const defaultTimeout = 60 * time.Second
|
||||
const unknownMethodLabel = "unknown"
|
||||
const maxWebsocketMessageBytes int64 = 4 * 1024 * 1024
|
||||
const websocketLogPreviewBytes = 256
|
||||
|
||||
// allRates is a special "currency" parameter that means all available currencies
|
||||
const allFiatRates = "!ALL!"
|
||||
@@ -199,6 +201,13 @@ func getIP(r *http.Request) string {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
func getWebsocketPayloadPreview(d []byte) string {
|
||||
if len(d) <= websocketLogPreviewBytes {
|
||||
return string(d)
|
||||
}
|
||||
return string(d[:websocketLogPreviewBytes]) + "...(truncated)"
|
||||
}
|
||||
|
||||
// ServeHTTP sets up handler of websocket channel
|
||||
func (s *WebsocketServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "GET" {
|
||||
@@ -210,6 +219,7 @@ func (s *WebsocketServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, upgradeFailed+err.Error(), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
conn.SetReadLimit(maxWebsocketMessageBytes)
|
||||
c := &websocketChannel{
|
||||
id: atomic.AddUint64(&connectionCounter, 1),
|
||||
conn: conn,
|
||||
@@ -299,7 +309,7 @@ func (s *WebsocketServer) inputLoop(c *websocketChannel) {
|
||||
var req WsReq
|
||||
err := json.Unmarshal(d, &req)
|
||||
if err != nil {
|
||||
glog.Error("Error parsing message from ", c.id, ", ", string(d), ", ", err)
|
||||
glog.Error("Error parsing message from ", c.id, ", len ", len(d), ", preview ", getWebsocketPayloadPreview(d), ", ", err)
|
||||
s.closeChannel(c, "protocol_error")
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user