Browse Source

connman is in charge of pushing messages

The changes here are dense and subtle, but hopefully all is more explicit
than before.

- CConnman is now in charge of sending data rather than the nodes themselves.
  This is necessary because many decisions need to be made with all nodes in
  mind, and a model that requires the nodes calling up to their manager quickly
  turns to spaghetti.

- The per-node-serializer (ssSend) has been replaced with a (quasi-)const
  send-version. Since the send version for serialization can only change once
  per connection, we now explicitly tag messages with INIT_PROTO_VERSION if
  they are sent before the handshake. With this done, there's no need to lock
  for access to nSendVersion.

  Also, a new stream is used for each message, so there's no need to lock
  during the serialization process.

- This takes care of accounting for optimistic sends, so the
  nOptimisticBytesWritten hack can be removed.

- -dropmessagestest and -fuzzmessagestest have not been preserved, as I suspect
  they haven't been used in years.
pull/1/head
Cory Fields 6 years ago committed by Pieter Wuille
parent
commit
3e32cd09f6
  1. 14
      src/main.cpp
  2. 91
      src/net.cpp
  3. 58
      src/net.h
  4. 4
      src/test/DoS_tests.cpp

14
src/main.cpp

@ -5047,7 +5047,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv, @@ -5047,7 +5047,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv,
// Each connection can only send one version message
if (pfrom->nVersion != 0)
{
pfrom->PushMessage(NetMsgType::REJECT, strCommand, REJECT_DUPLICATE, string("Duplicate version message"));
connman.PushMessageWithVersion(pfrom, INIT_PROTO_VERSION, NetMsgType::REJECT, strCommand, REJECT_DUPLICATE, string("Duplicate version message"));
LOCK(cs_main);
Misbehaving(pfrom->GetId(), 1);
return false;
@ -5067,7 +5067,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv, @@ -5067,7 +5067,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv,
if (pfrom->nServicesExpected & ~pfrom->nServices)
{
LogPrint("net", "peer=%d does not offer the expected services (%08x offered, %08x expected); disconnecting\n", pfrom->id, pfrom->nServices, pfrom->nServicesExpected);
pfrom->PushMessage(NetMsgType::REJECT, strCommand, REJECT_NONSTANDARD,
connman.PushMessageWithVersion(pfrom, INIT_PROTO_VERSION, NetMsgType::REJECT, strCommand, REJECT_NONSTANDARD,
strprintf("Expected to offer services %08x", pfrom->nServicesExpected));
pfrom->fDisconnect = true;
return false;
@ -5077,7 +5077,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv, @@ -5077,7 +5077,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv,
{
// disconnect from peers older than this proto version
LogPrintf("peer=%d using obsolete version %i; disconnecting\n", pfrom->id, pfrom->nVersion);
pfrom->PushMessage(NetMsgType::REJECT, strCommand, REJECT_OBSOLETE,
connman.PushMessageWithVersion(pfrom, INIT_PROTO_VERSION, NetMsgType::REJECT, strCommand, REJECT_OBSOLETE,
strprintf("Version must be %d or greater", MIN_PEER_PROTO_VERSION));
pfrom->fDisconnect = true;
return false;
@ -5118,7 +5118,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv, @@ -5118,7 +5118,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv,
// Be shy and don't send version until we hear
if (pfrom->fInbound)
pfrom->PushVersion();
connman.PushVersion(pfrom, GetAdjustedTime());
pfrom->fClient = !(pfrom->nServices & NODE_NETWORK);
@ -5135,8 +5135,8 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv, @@ -5135,8 +5135,8 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv,
}
// Change version
pfrom->PushMessage(NetMsgType::VERACK);
pfrom->ssSend.SetVersion(min(pfrom->nVersion, PROTOCOL_VERSION));
connman.PushMessageWithVersion(pfrom, INIT_PROTO_VERSION, NetMsgType::VERACK);
pfrom->SetSendVersion(min(pfrom->nVersion, PROTOCOL_VERSION));
if (!pfrom->fInbound)
{
@ -6391,7 +6391,7 @@ bool ProcessMessages(CNode* pfrom, CConnman& connman) @@ -6391,7 +6391,7 @@ bool ProcessMessages(CNode* pfrom, CConnman& connman)
}
catch (const std::ios_base::failure& e)
{
pfrom->PushMessage(NetMsgType::REJECT, strCommand, REJECT_MALFORMED, string("error parsing message"));
connman.PushMessageWithVersion(pfrom, INIT_PROTO_VERSION, NetMsgType::REJECT, strCommand, REJECT_MALFORMED, string("error parsing message"));
if (strstr(e.what(), "end of data"))
{
// Allow exceptions from under-length message on vRecv

91
src/net.cpp

@ -394,6 +394,9 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo @@ -394,6 +394,9 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo
uint64_t nonce = GetDeterministicRandomizer(RANDOMIZER_ID_LOCALHOSTNONCE).Write(id).Finalize();
CNode* pnode = new CNode(id, nLocalServices, GetBestHeight(), hSocket, addrConnect, CalculateKeyedNetGroup(addrConnect), nonce, pszDest ? pszDest : "", false);
PushVersion(pnode, GetTime());
GetNodeSignals().InitializeNode(pnode->GetId(), pnode);
pnode->AddRef();
@ -415,6 +418,24 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo @@ -415,6 +418,24 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo
return NULL;
}
void CConnman::PushVersion(CNode* pnode, int64_t nTime)
{
ServiceFlags nLocalNodeServices = pnode->GetLocalServices();
CAddress addrYou = (pnode->addr.IsRoutable() && !IsProxy(pnode->addr) ? pnode->addr : CAddress(CService(), pnode->addr.nServices));
CAddress addrMe = CAddress(CService(), nLocalNodeServices);
uint64_t nonce = pnode->GetLocalNonce();
int nNodeStartingHeight = pnode->nMyStartingHeight;
NodeId id = pnode->GetId();
PushMessageWithVersion(pnode, INIT_PROTO_VERSION, NetMsgType::VERSION, PROTOCOL_VERSION, (uint64_t)nLocalNodeServices, nTime, addrYou, addrMe,
nonce, strSubVersion, nNodeStartingHeight, ::fRelayTxes);
if (fLogIPs)
LogPrint("net", "send version message: version %d, blocks=%d, us=%s, them=%s, peer=%d\n", PROTOCOL_VERSION, nNodeStartingHeight, addrMe.ToString(), addrYou.ToString(), id);
else
LogPrint("net", "send version message: version %d, blocks=%d, us=%s, peer=%d\n", PROTOCOL_VERSION, nNodeStartingHeight, addrMe.ToString(), id);
}
void CConnman::DumpBanlist()
{
SweepBanned(); // clean unused entries (if bantime has expired)
@ -450,23 +471,6 @@ void CNode::CloseSocketDisconnect() @@ -450,23 +471,6 @@ void CNode::CloseSocketDisconnect()
vRecvMsg.clear();
}
void CNode::PushVersion()
{
int64_t nTime = (fInbound ? GetAdjustedTime() : GetTime());
CAddress addrYou = (addr.IsRoutable() && !IsProxy(addr) ? addr : CAddress(CService(), addr.nServices));
CAddress addrMe = CAddress(CService(), nLocalServices);
if (fLogIPs)
LogPrint("net", "send version message: version %d, blocks=%d, us=%s, them=%s, peer=%d\n", PROTOCOL_VERSION, nMyStartingHeight, addrMe.ToString(), addrYou.ToString(), id);
else
LogPrint("net", "send version message: version %d, blocks=%d, us=%s, peer=%d\n", PROTOCOL_VERSION, nMyStartingHeight, addrMe.ToString(), id);
PushMessage(NetMsgType::VERSION, PROTOCOL_VERSION, (uint64_t)nLocalServices, nTime, addrYou, addrMe,
nLocalHostNonce, strSubVersion, nMyStartingHeight, ::fRelayTxes);
}
void CConnman::ClearBanned()
{
{
@ -2530,7 +2534,8 @@ CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, int nMyStartingHeightIn @@ -2530,7 +2534,8 @@ CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, int nMyStartingHeightIn
filterInventoryKnown(50000, 0.000001),
nLocalHostNonce(nLocalHostNonceIn),
nLocalServices(nLocalServicesIn),
nMyStartingHeight(nMyStartingHeightIn)
nMyStartingHeight(nMyStartingHeightIn),
nSendVersion(0)
{
nServices = NODE_NONE;
nServicesExpected = NODE_NONE;
@ -2587,10 +2592,6 @@ CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, int nMyStartingHeightIn @@ -2587,10 +2592,6 @@ CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, int nMyStartingHeightIn
LogPrint("net", "Added connection to %s peer=%d\n", addrName, id);
else
LogPrint("net", "Added connection peer=%d\n", id);
// Be shy and don't send version until we hear
if (hSocket != INVALID_SOCKET && !fInbound)
PushVersion();
}
CNode::~CNode()
@ -2696,6 +2697,52 @@ void CNode::EndMessage(const char* pszCommand) UNLOCK_FUNCTION(cs_vSend) @@ -2696,6 +2697,52 @@ void CNode::EndMessage(const char* pszCommand) UNLOCK_FUNCTION(cs_vSend)
LEAVE_CRITICAL_SECTION(cs_vSend);
}
CDataStream CConnman::BeginMessage(CNode* pnode, int nVersion, int flags, const std::string& sCommand)
{
return {SER_NETWORK, (nVersion ? nVersion : pnode->GetSendVersion()) | flags, CMessageHeader(Params().MessageStart(), sCommand.c_str(), 0) };
}
void CConnman::EndMessage(CDataStream& strm)
{
// Set the size
assert(strm.size () >= CMessageHeader::HEADER_SIZE);
unsigned int nSize = strm.size() - CMessageHeader::HEADER_SIZE;
WriteLE32((uint8_t*)&strm[CMessageHeader::MESSAGE_SIZE_OFFSET], nSize);
// Set the checksum
uint256 hash = Hash(strm.begin() + CMessageHeader::HEADER_SIZE, strm.end());
memcpy((char*)&strm[CMessageHeader::CHECKSUM_OFFSET], hash.begin(), CMessageHeader::CHECKSUM_SIZE);
}
void CConnman::PushMessage(CNode* pnode, CDataStream& strm, const std::string& sCommand)
{
if(strm.empty())
return;
unsigned int nSize = strm.size() - CMessageHeader::HEADER_SIZE;
LogPrint("net", "sending %s (%d bytes) peer=%d\n", SanitizeString(sCommand.c_str()), nSize, pnode->id);
size_t nBytesSent = 0;
{
LOCK(pnode->cs_vSend);
if(pnode->hSocket == INVALID_SOCKET) {
return;
}
bool optimisticSend(pnode->vSendMsg.empty());
pnode->vSendMsg.emplace_back(strm.begin(), strm.end());
//log total amount of bytes per command
pnode->mapSendBytesPerMsgCmd[sCommand] += strm.size();
pnode->nSendSize += strm.size();
// If write queue empty, attempt "optimistic write"
if (optimisticSend == true)
nBytesSent = SocketSendData(pnode);
}
if (nBytesSent)
RecordBytesSent(nBytesSent);
}
bool CConnman::ForNode(NodeId id, std::function<bool(CNode* pnode)> func)
{
CNode* found = nullptr;

58
src/net.h

@ -136,6 +136,36 @@ public: @@ -136,6 +136,36 @@ public:
bool ForNode(NodeId id, std::function<bool(CNode* pnode)> func);
template <typename... Args>
void PushMessageWithVersionAndFlag(CNode* pnode, int nVersion, int flag, const std::string& sCommand, Args&&... args)
{
auto msg(BeginMessage(pnode, nVersion, flag, sCommand));
::SerializeMany(msg, msg.nType, msg.nVersion, std::forward<Args>(args)...);
EndMessage(msg);
PushMessage(pnode, msg, sCommand);
}
template <typename... Args>
void PushMessageWithFlag(CNode* pnode, int flag, const std::string& sCommand, Args&&... args)
{
PushMessageWithVersionAndFlag(pnode, 0, flag, sCommand, std::forward<Args>(args)...);
}
template <typename... Args>
void PushMessageWithVersion(CNode* pnode, int nVersion, const std::string& sCommand, Args&&... args)
{
PushMessageWithVersionAndFlag(pnode, nVersion, 0, sCommand, std::forward<Args>(args)...);
}
template <typename... Args>
void PushMessage(CNode* pnode, const std::string& sCommand, Args&&... args)
{
PushMessageWithVersionAndFlag(pnode, 0, 0, sCommand, std::forward<Args>(args)...);
}
void PushVersion(CNode* pnode, int64_t nTime);
template<typename Callable>
bool ForEachNodeContinueIf(Callable&& func)
{
@ -345,6 +375,10 @@ private: @@ -345,6 +375,10 @@ private:
unsigned int GetReceiveFloodSize() const;
CDataStream BeginMessage(CNode* node, int nVersion, int flags, const std::string& sCommand);
void PushMessage(CNode* pnode, CDataStream& strm, const std::string& sCommand);
void EndMessage(CDataStream& strm);
// Network stats
void RecordBytesRecv(uint64_t bytes);
void RecordBytesSent(uint64_t bytes);
@ -553,6 +587,7 @@ public: @@ -553,6 +587,7 @@ public:
/** Information about a peer */
class CNode
{
friend class CConnman;
public:
// socket
ServiceFlags nServices;
@ -681,6 +716,7 @@ private: @@ -681,6 +716,7 @@ private:
// Services offered to this peer
const ServiceFlags nLocalServices;
const int nMyStartingHeight;
int nSendVersion;
public:
NodeId GetId() const {
@ -716,6 +752,25 @@ public: @@ -716,6 +752,25 @@ public:
BOOST_FOREACH(CNetMessage &msg, vRecvMsg)
msg.SetVersion(nVersionIn);
}
void SetSendVersion(int nVersionIn)
{
// Send version may only be changed in the version message, and
// only one version message is allowed per session. We can therefore
// treat this value as const and even atomic as long as it's only used
// once the handshake is complete. Any attempt to set this twice is an
// error.
assert(nSendVersion == 0);
nSendVersion = nVersionIn;
}
int GetSendVersion() const
{
// The send version should always be explicitly set to
// INIT_PROTO_VERSION rather than using this value until the handshake
// is complete. See PushMessageWithVersion().
assert(nSendVersion != 0);
return nSendVersion;
}
CNode* AddRef()
{
@ -787,9 +842,6 @@ public: @@ -787,9 +842,6 @@ public:
// TODO: Document the precondition of this function. Is cs_vSend locked?
void EndMessage(const char* pszCommand) UNLOCK_FUNCTION(cs_vSend);
void PushVersion();
void PushMessage(const char* pszCommand)
{
try

4
src/test/DoS_tests.cpp

@ -49,6 +49,7 @@ BOOST_AUTO_TEST_CASE(DoS_banning) @@ -49,6 +49,7 @@ BOOST_AUTO_TEST_CASE(DoS_banning)
connman->ClearBanned();
CAddress addr1(ip(0xa0b0c001), NODE_NONE);
CNode dummyNode1(id++, NODE_NETWORK, 0, INVALID_SOCKET, addr1, 0, 0, "", true);
dummyNode1.SetSendVersion(PROTOCOL_VERSION);
GetNodeSignals().InitializeNode(dummyNode1.GetId(), &dummyNode1);
dummyNode1.nVersion = 1;
Misbehaving(dummyNode1.GetId(), 100); // Should get banned
@ -58,6 +59,7 @@ BOOST_AUTO_TEST_CASE(DoS_banning) @@ -58,6 +59,7 @@ BOOST_AUTO_TEST_CASE(DoS_banning)
CAddress addr2(ip(0xa0b0c002), NODE_NONE);
CNode dummyNode2(id++, NODE_NETWORK, 0, INVALID_SOCKET, addr2, 1, 1, "", true);
dummyNode2.SetSendVersion(PROTOCOL_VERSION);
GetNodeSignals().InitializeNode(dummyNode2.GetId(), &dummyNode2);
dummyNode2.nVersion = 1;
Misbehaving(dummyNode2.GetId(), 50);
@ -75,6 +77,7 @@ BOOST_AUTO_TEST_CASE(DoS_banscore) @@ -75,6 +77,7 @@ BOOST_AUTO_TEST_CASE(DoS_banscore)
mapArgs["-banscore"] = "111"; // because 11 is my favorite number
CAddress addr1(ip(0xa0b0c001), NODE_NONE);
CNode dummyNode1(id++, NODE_NETWORK, 0, INVALID_SOCKET, addr1, 3, 1, "", true);
dummyNode1.SetSendVersion(PROTOCOL_VERSION);
GetNodeSignals().InitializeNode(dummyNode1.GetId(), &dummyNode1);
dummyNode1.nVersion = 1;
Misbehaving(dummyNode1.GetId(), 100);
@ -97,6 +100,7 @@ BOOST_AUTO_TEST_CASE(DoS_bantime) @@ -97,6 +100,7 @@ BOOST_AUTO_TEST_CASE(DoS_bantime)
CAddress addr(ip(0xa0b0c001), NODE_NONE);
CNode dummyNode(id++, NODE_NETWORK, 0, INVALID_SOCKET, addr, 4, 4, "", true);
dummyNode.SetSendVersion(PROTOCOL_VERSION);
GetNodeSignals().InitializeNode(dummyNode.GetId(), &dummyNode);
dummyNode.nVersion = 1;

Loading…
Cancel
Save