00001
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029 #include "Socket.hpp"
00030 #include "SocketException.hpp"
00031 #include "Debug.hpp"
00032 #include "UtilTime.hpp"
00033 #include <boost/lexical_cast.hpp>
00034 #include <boost/thread/thread.hpp>
00035 #include <errno.h>
00036 #include <signal.h>
00037 #include <poll.h>
00038 #include <netdb.h>
00039 #include <sys/socket.h>
00040 #include <netinet/in.h>
00041 #include <arpa/inet.h>
00042
00043 using namespace H;
00044 using namespace boost;
00045 using namespace std;
00046
00048
00050
/** Listen-queue length stored into mBacklog by init() and later passed to ::listen(). */
#define DEFAULT_BACKLOG (64)

/** Size in bytes of the stack buffer used for each read() in readIntoBuffer(). */
#define PACKET_SIZE (4096)

/** Timeout in milliseconds handed to poll() by accept() and threadProcRead(). */
#define POLL_TIMEOUT (1000)

/**
 * Message terminator (as a one-character C string) appended by writeMessage().
 * Note: "\255" is an octal escape, i.e. the byte 0xAD (173), not 0xFF.
 */
#define STOP_CODON "\255"

/** The same terminator as a single char; addToMessageBuffer() splits on it. */
#define STOP_CODON_CHAR '\255'
00080
00082
00084
00088 Socket::Socket() : mThreadProcRead(this) {
00089 mpEventWatcher = NULL;
00090 init();
00091 }
00092
00096 Socket::Socket(Socket const & InitFrom) : mThreadProcRead(this) {
00097 mpEventWatcher = NULL;
00098 init();
00099 setTo(InitFrom);
00100 }
00101
00102
00106 Socket::~Socket() {
00107 shutdown();
00108 }
00109
00111
00113
00120 boost::shared_ptr<Socket> Socket::accept() {
00121
00122 if (mSocket == SOCKET_ERROR)
00123 throw SocketException("Accept Attempted on Invalid Socket!" + lexical_cast<string>(mPort), __FILE__, __FUNCTION__, __LINE__);
00124
00125
00126 struct pollfd PollFD;
00127 PollFD.fd = mSocket;
00128 PollFD.events = POLLIN | POLLOUT;
00129 PollFD.revents = 0;
00130
00131
00132 shared_ptr<Socket> pSocket = shared_ptr<Socket>(new Socket(*this));
00133
00134
00135 int ret;
00136 do {
00137 if ((ret = poll(&PollFD, 1, POLL_TIMEOUT)) < 0) {
00138
00139 cdbg1 << "Poll error: " << strerror(errno) << endl;
00140 return pSocket;
00141 }
00142 } while (mProcessing && (ret <= 0));
00143
00144
00145 if (!mProcessing)
00146 return pSocket;
00147
00148
00149 pSocket->mSocket = ::accept(mSocket, (struct sockaddr *) &pSocket->mSockAddr, &pSocket->mSockAddrLen);
00150 pSocket->setAddress();
00151 pSocket->mOldSocket = pSocket->mSocket;
00152
00153 return pSocket;
00154 }
00155
00159 void Socket::addToMessageBuffer(char * Data, int BufLen) {
00160 if (!mMessageMode)
00161 return;
00162
00163
00164 int Index = -1;
00165 for (int lp = 0; lp < BufLen; lp ++) {
00166 if (Data[lp] == STOP_CODON_CHAR) {
00167 Index = lp;
00168 break;
00169 }
00170 }
00171 if (Index > -1) {
00172
00173 string Message;
00174 if (mMessageBuffer.length())
00175 Message += mMessageBuffer.getBuffer();
00176 Message += string(Data, Index);
00177
00178
00179 if (mpEventWatcher)
00180 mpEventWatcher->onSocketMessage(*this, Message);
00181
00182
00183 mMessageBuffer.clear();
00184 if (BufLen - Index > 1)
00185 addToMessageBuffer(Data + Index + 1, BufLen - Index - 1);
00186 } else {
00187
00188 mMessageBuffer.addToBuffer(Data, BufLen);
00189 }
00190 }
00191
00196 void Socket::bind(int Port) {
00197 mPort = Port;
00198 memset(&mSockAddr, 0, sizeof(mSockAddr));
00199 mSockAddr.sin_family = mType;
00200 mSockAddr.sin_port = htons(mPort);
00201 mSockAddr.sin_addr.s_addr = INADDR_ANY;
00202
00203
00204 if (::bind(mSocket, (struct sockaddr *) &mSockAddr, sizeof(mSockAddr)))
00205 throw SocketException("Failed to Bind to Port [" + lexical_cast<string>(mPort) + "]", __FILE__, __FUNCTION__, __LINE__);
00206 }
00207
00211 void Socket::closeSocket() {
00212 if (mSocket != SOCKET_ERROR) {
00213 #ifndef WIN32
00214 if (::close(mSocket) == SOCKET_ERROR)
00215 #else
00216 if (::closesocket(mSocket) == SOCKET_ERROR)
00217 #endif
00218 throw SocketException("Failed to Close Socket [" + lexical_cast<string>(mSocket) + "]", __FILE__, __FUNCTION__, __LINE__);
00219 }
00220
00221 #ifdef HAVE_OPENSSL
00222 if (mSSL) {
00223 SSL_free(mSSL);
00224 mSSL = NULL;
00225 }
00226 #endif
00227
00228 mOldSocket = mSocket;
00229 init();
00230 }
00231
00237 void Socket::connect(std::string Host, int Port) {
00238
00239 if (mSocket == SOCKET_ERROR)
00240 throw SocketException("Connect Attempted on Invalid Socket!", __FILE__, __FUNCTION__, __LINE__);
00241
00242
00243 struct hostent * hp = gethostbyname(Host.c_str());
00244 if (!hp)
00245 throw SocketException("Connect Failed to Resolve Host [" + Host + "]", __FILE__, __FUNCTION__, __LINE__);
00246
00247
00248 mPort = Port;
00249 #ifndef WIN32
00250 struct in_addr address;
00251 memcpy(&address, *(hp->h_addr_list), sizeof(struct in_addr));
00252 mSockAddr.sin_addr = address;
00253 #else
00254 memset(&mSockAddr, 0, sizeof(mSockAddr));
00255 mSockAddr.sin_addr.s_addr = ((struct in_addr *)(hp->h_addr))->s_addr;
00256 #endif
00257 mSockAddr.sin_family = AF_INET;
00258 mSockAddr.sin_port = htons(mPort);
00259
00260 cdbg4 << "Initiating connection to [" << Host << ":" << mPort << "]" << endl;
00261
00262 #ifdef WIN32
00263 return tryConnectingWindows();
00264 #endif
00265
00266 if (::connect(mSocket, (struct sockaddr *) &mSockAddr, sizeof(mSockAddr)) == -1) {
00267 closeSocket();
00268 throw SocketException("Connect Attempted to [" + Host + ":" + lexical_cast<string>(Port) + "] Failed -- " + strerror(errno), __FILE__, __FUNCTION__, __LINE__);
00269 }
00270 setAddress();
00271
00272
00273 if (mpEventWatcher)
00274 mpEventWatcher->onSocketConnect(*this);
00275 }
00276
00280 void Socket::createSocket(SocketDomain Domain, SocketType Type) {
00281 mDomain = Domain;
00282 mType = Type;
00283
00284 #ifdef HAVE_OPENSSL
00285 if ( (mSSLMode) && (!initializeCTX()) )
00286 throw SocketException("Failed to Initialize OpenSSL", __FILE__, __FUNCTION__, __LINE__);
00287 #endif
00288
00289 if ((mSocket = socket(mDomain, mType, mProtocol)) == -1)
00290 throw SocketException(string("Failed to Create Socket -- ") + strerror(errno), __FILE__, __FUNCTION__, __LINE__);
00291
00292 mOldSocket = mSocket;
00293 }
00294
00299 std::string Socket::getAddress() const {
00300 return mAddress;
00301 }
00302
00307 int Socket::getOldSocket() const {
00308 return mOldSocket;
00309 }
00310
00315 int Socket::getSocket() const {
00316 return mSocket;
00317 }
00318
00322 void Socket::handleSocketDisconnect() {
00323 closeSocket();
00324 if (mpEventWatcher)
00325 mpEventWatcher->onSocketDisconnect(*this);
00326 }
00327
00331 void Socket::handleSocketRead(DynamicBuffer<char> & ReadBuffer) {
00332 if (mpEventWatcher)
00333 mpEventWatcher->onSocketRead(*this, ReadBuffer);
00334 }
00335
00339 void Socket::init() {
00340 mBacklog = DEFAULT_BACKLOG;
00341 mDomain = SOCKET_INTERNET;
00342 mMessageMode = false;
00343 mPort = 0;
00344 mProcessing = false;
00345 mProtocol = SOCKET_PROTO_TCP;
00346 mSockAddrLen = sizeof(struct sockaddr);
00347 mSocket = SOCKET_ERROR;
00348 mType = SOCKET_STREAM;
00349 }
00350
00355 bool Socket::isSocketValid() const {
00356 return (mSocket != SOCKET_ERROR);
00357 }
00358
00362 void Socket::listen() {
00363 if (::listen(mSocket, mBacklog) == -1)
00364 throw SocketException(string("Failed to Listen on Socket -- ") + strerror(errno), __FILE__, __FUNCTION__, __LINE__);
00365 }
00366
00370 void Socket::processEvents() {
00371
00372 thread thrd(mThreadProcRead);
00373 }
00374
00380 int Socket::read(char * Buffer, int BufLen) {
00381
00382 int ret;
00383 #ifdef HAVE_OPENSSL
00384 if (m_SSLMode) {
00385 SSL_set_bio(mSSL, mSSLbio, mSSLbio);
00386 ret = SSL_read(mSSL, Buffer, BufLen);
00387 } else
00388 ret = ::recv(mSocket, Buffer, sizeof(char) * BufLen, 0);
00389 #else
00390 ret = ::recv(mSocket, Buffer, sizeof(char) * BufLen, 0);
00391 #endif
00392
00393
00394 if (ret == 0)
00395 handleSocketDisconnect();
00396 return ret;
00397 }
00398
00404 int Socket::readIntoBuffer(DynamicBuffer<char> & Buffer) {
00405 char Packet[PACKET_SIZE];
00406 int TotalBytesRead = 0;
00407 int BytesRead;
00408 do {
00409 if ((BytesRead = read(Packet, PACKET_SIZE)) == -1) {
00410 #ifndef WIN32
00411 switch (errno) {
00412 case EINPROGRESS:
00413 case EALREADY:
00414 case EAGAIN:
00415 #else
00416 switch (WSAGetLastError()) {
00417 case WSAEINPROGRESS:
00418 case WSAEALREADY:
00419 case WSAEWOULDBLOCK:
00420 #endif
00421
00422 return TotalBytesRead;
00423 default:
00424 cdbg << "Socket Read Error -- " << strerror(errno) << endl;
00425 handleSocketDisconnect();
00426 return TotalBytesRead;
00427 }
00428 }
00429
00430 if (BytesRead > 0) {
00431 Buffer.addToBuffer(Packet, BytesRead);
00432 addToMessageBuffer(Packet, BytesRead);
00433 TotalBytesRead += BytesRead;
00434 }
00435 } while ( (BytesRead == PACKET_SIZE) && (BytesRead > 0) );
00436
00437 return TotalBytesRead;
00438 }
00439
00443 void Socket::setAddress() {
00444 mAddress = inet_ntoa(mSockAddr.sin_addr);
00445 }
00446
00451 void Socket::setEventWatcher(SocketEventWatcher * pWatcher) {
00452 mpEventWatcher = pWatcher;
00453 }
00454
00459 void Socket::setMessageMode(bool Enabled) {
00460 mMessageMode = Enabled;
00461 }
00462
00467 void Socket::setTo(Socket const & SocketToBecome) {
00468 mProtocol = SocketToBecome.mProtocol;
00469 mDomain = SocketToBecome.mDomain;
00470 mType = SocketToBecome.mType;
00471 mPort = SocketToBecome.mPort;
00472 mBacklog = SocketToBecome.mBacklog;
00473
00474 #ifdef HAVE_OPENSSL
00475 mSSL = SocketToBecome.mSSL;
00476 mSSLMode = SocketToBecome.mSSLMode;
00477 #endif
00478 }
00479
00483 void Socket::shutdown() {
00484 mProcessing = false;
00485 }
00486
00490 void Socket::threadProcRead() {
00491
00492 struct pollfd PollFD;
00493 PollFD.fd = mSocket;
00494 PollFD.events = POLLIN | POLLOUT;
00495 PollFD.revents = 0;
00496
00497
00498 mProcessing = true;
00499 while ( mProcessing && isSocketValid() ) {
00500
00501 int ret;
00502 do {
00503 if ((ret = poll(&PollFD, 1, POLL_TIMEOUT)) < 0) {
00504
00505 handleSocketDisconnect();
00506 return;
00507 }
00508 } while (mProcessing && (ret <= 0));
00509
00510 DynamicBuffer<char> ReadBuffer;
00511 if (readIntoBuffer(ReadBuffer) > 0)
00512 handleSocketRead(ReadBuffer);
00513 }
00514 }
00515
00522 int Socket::write(const char * Buffer, int BufLen) {
00523 #ifdef HAVE_OPENSSL
00524 if (mSSLMode) {
00525 SSL_set_bio(mSSL, mSSLbio, mSSLbio);
00526 return SSL_write(mSSL, Buffer, BufLen);
00527 } else
00528 #ifndef WIN32
00529 return ::write(mSocket, Buffer, sizeof(char) * BufLen);
00530 #else
00531 return ::send(mSocket, Buffer, sizeof(char) * BufLen, 0);
00532 #endif
00533 #else
00534 #ifndef WIN32
00535 return ::write(mSocket, Buffer, sizeof(char) * BufLen);
00536 #else
00537 return ::send(mSocket, Buffer, sizeof(char) * BufLen, 0);
00538 #endif
00539 #endif
00540 }
00541
00552 void Socket::writeMessage(std::string const & Message, bool FormatMessage) {
00553
00554 string OutMessage = Message;
00555 if (FormatMessage)
00556 OutMessage += STOP_CODON;
00557
00558 size_t CurPos = 0;
00559 int BytesWritten;
00560 do {
00561 if ((BytesWritten = write(OutMessage.c_str() + CurPos, OutMessage.length() - CurPos)) == -1)
00562 throw SocketException(string("Failed to Write Message to Socket -- ") + strerror(errno), __FILE__, __FUNCTION__, __LINE__);
00563 CurPos += BytesWritten;
00564 } while (CurPos < OutMessage.length());
00565 }