libH/Socket.cpp

Go to the documentation of this file.
00001 
00012 /*
00013   
00014   Copyright (c) 2007, Tim Burrell
00015   Licensed under the Apache License, Version 2.0 (the "License");
00016   you may not use this file except in compliance with the License.
00017   You may obtain a copy of the License at 
00018 
00019         http://www.apache.org/licenses/LICENSE-2.0
00020 
00021   Unless required by applicable law or agreed to in writing, software
00022   distributed under the License is distributed on an "AS IS" BASIS,
00023   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00024   See the License for the specific language governing permissions and 
00025   limitations under the License. 
00026   
00027 */
00028 
00029 #include "Socket.hpp"
00030 #include "SocketException.hpp"
00031 #include "Debug.hpp"
00032 #include "UtilTime.hpp"
00033 #include <boost/lexical_cast.hpp>
00034 #include <boost/thread/thread.hpp>
00035 #include <errno.h>
00036 #include <signal.h>
00037 #include <poll.h>
00038 #include <netdb.h>
00039 #include <sys/socket.h>
00040 #include <netinet/in.h>
00041 #include <arpa/inet.h>
00042 
00043 using namespace H;
00044 using namespace boost;
00045 using namespace std;
00046 
00048 // Defines / Type Defs
00050 
/**
 * \def    DEFAULT_BACKLOG
 * \brief  Pending-connection queue length; init() copies it into mBacklog, which listen() passes to ::listen()
 */
#define DEFAULT_BACKLOG	(64)

/**
 * \def    PACKET_SIZE
 * \brief  Size in bytes of the stack buffer used for each read in readIntoBuffer()
 */
#define PACKET_SIZE	(4096)

/**
 * \def    POLL_TIMEOUT
 * \brief  Milliseconds each poll() call waits before the processing flag is checked again
 */
#define POLL_TIMEOUT	(1000)

/**
 * \def    STOP_CODON
 * \brief  One-byte string appended by writeMessage() to terminate a message
 *
 * Note: "\255" is an octal escape, so the byte value is 0xAD (173), not 255.
 */
#define STOP_CODON	"\255"

/**
 * \def    STOP_CODON_CHAR
 * \brief  The same terminator byte as a char; addToMessageBuffer() splits messages on it
 */
#define STOP_CODON_CHAR	'\255'
00080 
00082 // Construction / Deconstruction
00084 
00088 Socket::Socket() : mThreadProcRead(this) {
00089         mpEventWatcher = NULL;
00090         init();
00091 }
00092 
00096 Socket::Socket(Socket const & InitFrom) : mThreadProcRead(this) {
00097         mpEventWatcher = NULL;
00098         init();
00099         setTo(InitFrom);
00100 }
00101 
00102 
00106 Socket::~Socket() {
00107         shutdown();
00108 }
00109 
00111 // Class Body
00113 
00120 boost::shared_ptr<Socket> Socket::accept() {
00121         // error checking
00122         if (mSocket == SOCKET_ERROR)
00123                 throw SocketException("Accept Attempted on Invalid Socket!" + lexical_cast<string>(mPort), __FILE__, __FUNCTION__, __LINE__);
00124         
00125         // set up the poll structure
00126         struct pollfd PollFD;
00127         PollFD.fd = mSocket;
00128         PollFD.events = POLLIN | POLLOUT;
00129         PollFD.revents = 0;
00130         
00131         // create the new socket
00132         shared_ptr<Socket> pSocket = shared_ptr<Socket>(new Socket(*this));
00133         
00134         // wait until there's a connection on the socket
00135         int ret;
00136         do {
00137                 if ((ret = poll(&PollFD, 1, POLL_TIMEOUT)) < 0) {
00138                         // error
00139                         cdbg1 << "Poll error: " << strerror(errno) << endl;
00140                         return pSocket;
00141                 }
00142         } while (mProcessing && (ret <= 0));
00143         
00144         // don't try to accept if we're quitting
00145         if (!mProcessing)
00146                 return pSocket;
00147         
00148         // accept the connection
00149         pSocket->mSocket = ::accept(mSocket, (struct sockaddr *) &pSocket->mSockAddr, &pSocket->mSockAddrLen);
00150         pSocket->setAddress();
00151         pSocket->mOldSocket = pSocket->mSocket;
00152 
00153         return pSocket;
00154 }
00155 
00159 void Socket::addToMessageBuffer(char * Data, int BufLen) {
00160         if (!mMessageMode)
00161                 return;
00162         
00163         // search for the stop codon
00164         int Index = -1;
00165         for (int lp = 0; lp < BufLen; lp ++) {
00166                 if (Data[lp] == STOP_CODON_CHAR) {
00167                         Index = lp;
00168                         break;
00169                 }
00170         }
00171         if (Index > -1) {
00172                 // message found!
00173                 string Message;
00174                 if (mMessageBuffer.length())
00175                         Message += mMessageBuffer.getBuffer(); 
00176                 Message += string(Data, Index);
00177         
00178                 // fire the event
00179                 if (mpEventWatcher)
00180                         mpEventWatcher->onSocketMessage(*this, Message);
00181                 
00182                 // Check if there's another event 
00183                 mMessageBuffer.clear();
00184                 if (BufLen - Index > 1)
00185                         addToMessageBuffer(Data + Index + 1, BufLen - Index - 1);
00186         } else {
00187                 // message not found
00188                 mMessageBuffer.addToBuffer(Data, BufLen);
00189         }
00190 }
00191 
00196 void Socket::bind(int Port) {
00197         mPort = Port;
00198         memset(&mSockAddr, 0, sizeof(mSockAddr));
00199         mSockAddr.sin_family = mType;
00200         mSockAddr.sin_port = htons(mPort);
00201         mSockAddr.sin_addr.s_addr = INADDR_ANY;
00202         
00203         /* bind the socket to the newly formed address**/
00204         if (::bind(mSocket, (struct sockaddr *) &mSockAddr, sizeof(mSockAddr)))
00205                 throw SocketException("Failed to Bind to Port [" + lexical_cast<string>(mPort) + "]", __FILE__, __FUNCTION__, __LINE__);
00206 }
00207 
00211 void Socket::closeSocket() {
00212         if (mSocket != SOCKET_ERROR) {
00213                 #ifndef WIN32
00214                 if (::close(mSocket) == SOCKET_ERROR)
00215                 #else
00216                 if (::closesocket(mSocket) == SOCKET_ERROR)
00217                 #endif
00218                         throw SocketException("Failed to Close Socket [" + lexical_cast<string>(mSocket) + "]", __FILE__, __FUNCTION__, __LINE__);
00219         }
00220         
00221         #ifdef HAVE_OPENSSL
00222         if (mSSL) {
00223                 SSL_free(mSSL);
00224                 mSSL = NULL;
00225         }
00226         #endif
00227                 
00228         mOldSocket = mSocket;
00229         init();
00230 }
00231 
00237 void Socket::connect(std::string Host, int Port) {
00238         // error checking
00239         if (mSocket == SOCKET_ERROR)
00240                 throw SocketException("Connect Attempted on Invalid Socket!", __FILE__, __FUNCTION__, __LINE__);
00241         
00242         // get hostname
00243         struct hostent * hp = gethostbyname(Host.c_str());
00244         if (!hp)
00245                 throw SocketException("Connect Failed to Resolve Host [" + Host + "]", __FILE__, __FUNCTION__, __LINE__);
00246 
00247         // Set up the data structures
00248         mPort = Port;
00249 #ifndef WIN32
00250         struct in_addr address;
00251         memcpy(&address, *(hp->h_addr_list), sizeof(struct in_addr));
00252         mSockAddr.sin_addr = address;
00253 #else
00254         memset(&mSockAddr, 0, sizeof(mSockAddr));
00255         mSockAddr.sin_addr.s_addr = ((struct in_addr *)(hp->h_addr))->s_addr;
00256 #endif
00257         mSockAddr.sin_family = AF_INET;
00258         mSockAddr.sin_port = htons(mPort);
00259         
00260         cdbg4 << "Initiating connection to [" << Host << ":" << mPort << "]" << endl;
00261 
00262 #ifdef WIN32
00263         return tryConnectingWindows();
00264 #endif
00265                 
00266         if (::connect(mSocket, (struct sockaddr *) &mSockAddr, sizeof(mSockAddr)) == -1) {
00267                 closeSocket();
00268                 throw SocketException("Connect Attempted to [" + Host + ":" + lexical_cast<string>(Port) + "] Failed -- " + strerror(errno), __FILE__, __FUNCTION__, __LINE__);
00269         }
00270         setAddress();
00271                 
00272         // connection!
00273         if (mpEventWatcher)
00274                 mpEventWatcher->onSocketConnect(*this);
00275 }
00276 
00280 void Socket::createSocket(SocketDomain Domain, SocketType Type) {
00281         mDomain = Domain;
00282         mType = Type;
00283         
00284         #ifdef HAVE_OPENSSL
00285         if ( (mSSLMode) && (!initializeCTX()) )
00286                 throw SocketException("Failed to Initialize OpenSSL", __FILE__, __FUNCTION__, __LINE__);
00287         #endif
00288 
00289         if ((mSocket = socket(mDomain, mType, mProtocol)) == -1)
00290                 throw SocketException(string("Failed to Create Socket -- ") + strerror(errno), __FILE__, __FUNCTION__, __LINE__);
00291 
00292         mOldSocket = mSocket;
00293 }
00294 
00299 std::string Socket::getAddress() const {
00300         return mAddress;
00301 }
00302 
00307 int Socket::getOldSocket() const {
00308         return mOldSocket;
00309 }
00310 
00315 int Socket::getSocket() const {
00316         return mSocket;
00317 }
00318 
00322 void Socket::handleSocketDisconnect() {
00323         closeSocket();
00324         if (mpEventWatcher)
00325                 mpEventWatcher->onSocketDisconnect(*this);
00326 }
00327 
00331 void Socket::handleSocketRead(DynamicBuffer<char> & ReadBuffer) {
00332         if (mpEventWatcher)
00333                 mpEventWatcher->onSocketRead(*this, ReadBuffer);
00334 }
00335 
00339 void Socket::init() {
00340         mBacklog = DEFAULT_BACKLOG;
00341         mDomain = SOCKET_INTERNET;
00342         mMessageMode = false;
00343         mPort = 0;
00344         mProcessing = false;
00345         mProtocol = SOCKET_PROTO_TCP;
00346         mSockAddrLen = sizeof(struct sockaddr);
00347         mSocket = SOCKET_ERROR;
00348         mType = SOCKET_STREAM;
00349 }
00350 
00355 bool Socket::isSocketValid() const {
00356         return (mSocket != SOCKET_ERROR);
00357 }
00358 
00362 void Socket::listen() {
00363         if (::listen(mSocket, mBacklog) == -1)
00364                 throw SocketException(string("Failed to Listen on Socket -- ") + strerror(errno), __FILE__, __FUNCTION__, __LINE__);
00365 }
00366 
00370 void Socket::processEvents() {
00371         // initialize the read thread
00372         thread thrd(mThreadProcRead);
00373 }
00374 
00380 int Socket::read(char * Buffer, int BufLen) {
00381         // do the receiving
00382         int ret;
00383         #ifdef HAVE_OPENSSL
00384                 if (m_SSLMode) {
00385                         SSL_set_bio(mSSL, mSSLbio, mSSLbio);
00386                         ret = SSL_read(mSSL, Buffer, BufLen);
00387                 } else
00388                         ret = ::recv(mSocket, Buffer, sizeof(char) * BufLen, 0);
00389         #else
00390                 ret = ::recv(mSocket, Buffer, sizeof(char) * BufLen, 0);
00391         #endif
00392 
00393         // value is 0 if socket has disconnected
00394         if (ret == 0)
00395                 handleSocketDisconnect();
00396         return ret;
00397 }
00398 
00404 int Socket::readIntoBuffer(DynamicBuffer<char> & Buffer) {
00405         char Packet[PACKET_SIZE];
00406         int TotalBytesRead = 0;
00407         int BytesRead;
00408         do {
00409                 if ((BytesRead = read(Packet, PACKET_SIZE)) == -1) {
00410                         #ifndef WIN32
00411                         switch (errno) {
00412                                 case EINPROGRESS:
00413                                 case EALREADY:
00414                                 case EAGAIN:
00415                         #else
00416                         switch (WSAGetLastError()) {
00417                                 case WSAEINPROGRESS:
00418                                 case WSAEALREADY:
00419                                 case WSAEWOULDBLOCK:
00420                         #endif
00421                                 // this is just fine
00422                                 return TotalBytesRead;
00423                         default:
00424                                 cdbg << "Socket Read Error -- " << strerror(errno) << endl;
00425                                 handleSocketDisconnect();
00426                                 return TotalBytesRead;
00427                         }
00428                 }
00429                 
00430                 if (BytesRead > 0) {
00431                         Buffer.addToBuffer(Packet, BytesRead);
00432                         addToMessageBuffer(Packet, BytesRead);
00433                         TotalBytesRead += BytesRead;
00434                 }
00435         } while ( (BytesRead == PACKET_SIZE) && (BytesRead > 0) );
00436         
00437         return TotalBytesRead;
00438 }
00439 
00443 void Socket::setAddress() {
00444         mAddress = inet_ntoa(mSockAddr.sin_addr);
00445 }
00446 
00451 void Socket::setEventWatcher(SocketEventWatcher * pWatcher) {
00452         mpEventWatcher = pWatcher;
00453 }
00454 
00459 void Socket::setMessageMode(bool Enabled) {
00460         mMessageMode = Enabled;
00461 }
00462 
00467 void Socket::setTo(Socket const & SocketToBecome) {
00468         mProtocol = SocketToBecome.mProtocol;
00469         mDomain = SocketToBecome.mDomain;
00470         mType = SocketToBecome.mType;
00471         mPort = SocketToBecome.mPort;
00472         mBacklog = SocketToBecome.mBacklog;
00473         
00474         #ifdef HAVE_OPENSSL
00475         mSSL = SocketToBecome.mSSL;
00476         mSSLMode = SocketToBecome.mSSLMode;
00477         #endif
00478 }
00479 
00483 void Socket::shutdown() {
00484         mProcessing = false;
00485 }
00486 
00490 void Socket::threadProcRead() {
00491         // set up the poll structure
00492         struct pollfd PollFD;
00493         PollFD.fd = mSocket;
00494         PollFD.events = POLLIN | POLLOUT;
00495         PollFD.revents = 0;
00496         
00497         // loop
00498         mProcessing = true;
00499         while ( mProcessing && isSocketValid() ) {
00500                 // wait until there's data on the socket
00501                 int ret;
00502                 do {
00503                         if ((ret = poll(&PollFD, 1, POLL_TIMEOUT)) < 0) {
00504                                 // error
00505                                 handleSocketDisconnect();
00506                                 return;
00507                         }               
00508                 } while (mProcessing && (ret <= 0));            
00509                 
00510                 DynamicBuffer<char> ReadBuffer;
00511                 if (readIntoBuffer(ReadBuffer) > 0)
00512                         handleSocketRead(ReadBuffer);
00513         }
00514 }
00515 
00522 int Socket::write(const char * Buffer, int BufLen) {
00523         #ifdef HAVE_OPENSSL
00524         if (mSSLMode) {
00525                 SSL_set_bio(mSSL, mSSLbio, mSSLbio);
00526                 return SSL_write(mSSL, Buffer, BufLen);
00527         } else
00528                 #ifndef WIN32
00529                         return ::write(mSocket, Buffer, sizeof(char) * BufLen);
00530                 #else
00531                         return ::send(mSocket, Buffer, sizeof(char) * BufLen, 0);
00532                 #endif
00533         #else
00534                 #ifndef WIN32
00535                         return ::write(mSocket, Buffer, sizeof(char) * BufLen);
00536                 #else
00537                         return ::send(mSocket, Buffer, sizeof(char) * BufLen, 0);
00538                 #endif
00539         #endif
00540 }
00541 
00552 void Socket::writeMessage(std::string const & Message, bool FormatMessage) {
00553         // format the message
00554         string OutMessage = Message;
00555         if (FormatMessage)
00556                 OutMessage += STOP_CODON;
00557                 
00558         size_t CurPos = 0;
00559         int BytesWritten;
00560         do {
00561                 if ((BytesWritten = write(OutMessage.c_str() + CurPos, OutMessage.length() - CurPos)) == -1)
00562                         throw SocketException(string("Failed to Write Message to Socket -- ") + strerror(errno), __FILE__, __FUNCTION__, __LINE__);
00563                 CurPos += BytesWritten;
00564         } while (CurPos < OutMessage.length());
00565 }

Generated on Wed Nov 7 10:04:16 2007 for gizmod by  doxygen 1.5.3