diff --git a/server/websocket.go b/server/websocket.go index d25500b9..d347baaa 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -1059,6 +1059,35 @@ func setEthereumReceiptIfAvailable(tx *bchain.Tx, getReceipt func(string) (*bcha tx.CoinSpecificData = csd } +func populateBitcoinVinAddrDescs(vins []bchain.MempoolVin, getAddrDesc func(string, uint32) (bchain.AddressDescriptor, error)) { + if getAddrDesc == nil { + return + } + for i := range vins { + if len(vins[i].AddrDesc) > 0 || vins[i].Txid == "" { + continue + } + addrDesc, err := getAddrDesc(vins[i].Txid, vins[i].Vout) + if err == nil && len(addrDesc) > 0 { + vins[i].AddrDesc = addrDesc + } + } +} + +func (s *WebsocketServer) getBitcoinVinAddrDesc(txid string, vout uint32) (bchain.AddressDescriptor, error) { + if s.txCache == nil { + return nil, bchain.ErrTxNotFound + } + prevTx, _, err := s.txCache.GetTransaction(txid) + if err != nil { + return nil, err + } + if int(vout) >= len(prevTx.Vout) { + return nil, bchain.ErrAddressMissing + } + return s.chainParser.GetAddrDescFromVout(&prevTx.Vout[vout]) +} + func (s *WebsocketServer) publishNewBlockTxsByAddr(block *bchain.Block) { for _, tx := range block.Txs { setConfirmedBlockTxMetadata(&tx, block.Time) @@ -1072,6 +1101,9 @@ func (s *WebsocketServer) publishNewBlockTxsByAddr(block *bchain.Block) { for i, vin := range tx.Vin { vins[i] = bchain.MempoolVin{Vin: vin} } + if s.chainParser.GetChainType() == bchain.ChainBitcoinType { + populateBitcoinVinAddrDescs(vins, s.getBitcoinVinAddrDesc) + } subscribed := s.getNewTxSubscriptions(vins, tx.Vout, tokenTransfers, internalTransfers) if len(subscribed) > 0 { go func(tx bchain.Tx, subscribed map[string]struct{}) { diff --git a/server/websocket_test.go b/server/websocket_test.go index 37913545..6442c95b 100644 --- a/server/websocket_test.go +++ b/server/websocket_test.go @@ -6,7 +6,9 @@ import ( "errors" "testing" + "github.com/trezor/blockbook/api" "github.com/trezor/blockbook/bchain" + "github.com/trezor/blockbook/tests/dbtestdata" ) func TestSetConfirmedBlockTxMetadataSetsConfirmedFields(t *testing.T) { @@ -123,3 +125,84 @@ func TestSetEthereumReceiptIfAvailableSetsReceipt(t *testing.T) { t.Fatalf("Receipt = %+v, want %+v", csd.Receipt, wantReceipt) } } + +func TestSendOnNewTxAddrFiltersNewBlockTxSubscriptions(t *testing.T) { + parser, _ := setupChain(t) + s := &WebsocketServer{ + chainParser: parser, + addressSubscriptions: make(map[string]map[*websocketChannel]*addressDetails), + } + addrDesc, err := parser.GetAddrDescFromAddress(dbtestdata.Addr1) + if err != nil { + t.Fatal(err) + } + stringAddrDesc := string(addrDesc) + onlyMempool := &websocketChannel{out: make(chan *WsRes, 1), alive: true} + withNewBlockTxs := &websocketChannel{out: make(chan *WsRes, 1), alive: true} + s.addressSubscriptions[stringAddrDesc] = map[*websocketChannel]*addressDetails{ + onlyMempool: { + requestID: "mempool-only", + publishNewBlockTxs: false, + }, + withNewBlockTxs: { + requestID: "with-new-block-txs", + publishNewBlockTxs: true, + }, + } + + s.sendOnNewTxAddr(stringAddrDesc, &api.Tx{Txid: "new-block-tx"}, true) + + if len(onlyMempool.out) != 0 { + t.Fatalf("mempool-only subscriber received %d messages, want 0", len(onlyMempool.out)) + } + if len(withNewBlockTxs.out) != 1 { + t.Fatalf("newBlockTxs subscriber received %d messages, want 1", len(withNewBlockTxs.out)) + } +} + +func TestPopulateBitcoinVinAddrDescsEnablesSenderOnlyMatching(t *testing.T) { + parser, _ := setupChain(t) + block := dbtestdata.GetTestBitcoinTypeBlock2(parser) + tx := block.Txs[0] // spends Addr3/Addr2 and pays Addr6/Addr7 + + vins := make([]bchain.MempoolVin, len(tx.Vin)) + for i := range tx.Vin { + vins[i] = bchain.MempoolVin{Vin: tx.Vin[i]} + } + addr3Desc, err := parser.GetAddrDescFromAddress(dbtestdata.Addr3) + if err != nil { + t.Fatal(err) + } + addr2Desc, err := parser.GetAddrDescFromAddress(dbtestdata.Addr2) + if err != nil { + t.Fatal(err) + } + dummy := &websocketChannel{} + s := &WebsocketServer{ + chainParser: parser, + addressSubscriptions: map[string]map[*websocketChannel]*addressDetails{ + string(addr3Desc): {dummy: {requestID: "sender", publishNewBlockTxs: true}}, + }, + } + + withoutResolvedVins := s.getNewTxSubscriptions(vins, tx.Vout, nil, nil) + if _, ok := withoutResolvedVins[string(addr3Desc)]; ok { + t.Fatal("sender subscription unexpectedly matched before vin descriptor resolution") + } + + populateBitcoinVinAddrDescs(vins, func(txid string, vout uint32) (bchain.AddressDescriptor, error) { + switch { + case txid == dbtestdata.TxidB1T2 && vout == 0: + return addr3Desc, nil + case txid == dbtestdata.TxidB1T1 && vout == 1: + return addr2Desc, nil + default: + return nil, errors.New("not found") + } + }) + + withResolvedVins := s.getNewTxSubscriptions(vins, tx.Vout, nil, nil) + if _, ok := withResolvedVins[string(addr3Desc)]; !ok { + t.Fatal("sender subscription did not match after vin descriptor resolution") + } +}