#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdarg.h>
#include <errno.h>
#include <fcntl.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/ioctl.h>
#include <sys/select.h>
#include <sys/time.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#ifndef INADDR_NONE
#define INADDR_NONE 0xffffffff
#endif
#include "common.h"
/*
 * Wait until socket s is ready for one kind of I/O, or the timeout expires.
 *
 * s    - socket descriptor; must be in [0, FD_SETSIZE)
 * sec  - timeout seconds
 * usec - timeout microseconds
 * x    - readiness to wait for: READ_STATUS, WRITE_STATUS or EXCPT_STATUS
 *        (constants provided by common.h)
 *
 * Returns the result of select(): > 0 if s is ready, 0 on timeout, -1 on
 * error with errno set by select(). An out-of-range descriptor returns -1
 * with errno = EBADF; an unrecognised x returns -1 with errno = EINVAL.
 */
int net_select(int s, int sec, int usec, short x)
{
    struct timeval to;
    fd_set fs;
    fd_set *readfds = NULL;
    fd_set *writefds = NULL;
    fd_set *exceptfds = NULL;

    /* FD_SET with a descriptor outside [0, FD_SETSIZE) is undefined behaviour. */
    if (s < 0 || s >= FD_SETSIZE) {
        errno = EBADF;
        return -1;
    }

    /* Route the single fd_set into the slot matching the requested status. */
    switch (x) {
    case READ_STATUS:
        readfds = &fs;
        break;
    case WRITE_STATUS:
        writefds = &fs;
        break;
    case EXCPT_STATUS:
        exceptfds = &fs;
        break;
    default:
        /* Returning a stale errno here would be indistinguishable from a
         * "ready" count or a timeout, so report a proper error instead. */
        errno = EINVAL;
        return -1;
    }

    to.tv_sec = sec;
    to.tv_usec = usec;
    FD_ZERO(&fs);
    FD_SET(s, &fs);
    return select(s + 1, readfds, writefds, exceptfds, &to);
}
/*
 * Open a TCP/IPv4 connection to host:port, bounding connect() by a
 * 10-second timeout.
 *
 * host - host name or dotted-decimal IPv4 address (must not be NULL)
 * port - TCP port in host byte order; 0 is rejected
 *
 * Returns a connected socket descriptor with its original (blocking) file
 * status flags restored. On any failure (invalid port, unresolvable host,
 * socket/connect error, timeout) a DEBUG message is emitted and the process
 * terminates via exit(1); this function never returns an error code.
 */
int tcp_connect(const char *host, const unsigned short port)
{
    const char *transport = "tcp";
    struct hostent *phe;    /* host information entry */
    struct protoent *ppe;   /* protocol information entry */
    struct sockaddr_in sin; /* IPv4 endpoint address */
    int s;                  /* socket descriptor */
    int flags;              /* original file status flags of s */

    if (host == NULL) {
        DEBUG("%s", "can't connect: host is NULL\n");
        exit(1);
    }

    memset(&sin, 0, sizeof(sin));
    sin.sin_family = AF_INET;
    if ((sin.sin_port = htons(port)) == 0) {
        DEBUG("invalid port \"%d\"\n", port);
        exit(1);
    }

    /* Map host name to IP address, allowing for dotted decimal. Only accept
     * an IPv4 entry whose address length matches sin_addr before copying. */
    phe = gethostbyname(host);
    if (phe != NULL && phe->h_addrtype == AF_INET &&
        phe->h_length == (int)sizeof(sin.sin_addr) &&
        phe->h_addr_list[0] != NULL) {
        memcpy(&sin.sin_addr, phe->h_addr_list[0], sizeof(sin.sin_addr));
    } else if ((sin.sin_addr.s_addr = inet_addr(host)) == INADDR_NONE) {
        DEBUG("can't get \"%s\" host entry\n", host);
        exit(1);
    }

    /* Map transport protocol name to protocol number */
    if ((ppe = getprotobyname(transport)) == NULL) {
        DEBUG("can't get \"%s\" protocol entry\n", transport);
        exit(1);
    }

    /* Allocate a socket */
    s = socket(PF_INET, SOCK_STREAM, ppe->p_proto);
    if (s < 0) {
        DEBUG("can't create socket: %s\n", strerror(errno));
        exit(1);
    }

    /* FD_SET on a descriptor >= FD_SETSIZE is undefined behaviour. */
    if (s >= FD_SETSIZE)
        goto error_ret;

    /* Switch to non-blocking mode so connect() can be bounded by select(). */
    flags = fcntl(s, F_GETFL, 0);
    if (flags < 0 || fcntl(s, F_SETFL, flags | O_NONBLOCK) < 0)
        goto error_ret;

    if (connect(s, (struct sockaddr *)&sin, sizeof(sin)) == -1) {
        struct timeval tv;
        fd_set writefds;
        fd_set readfds;
        int ready;
        int error = 0;
        socklen_t len = sizeof(error);

        /* Only EINPROGRESS means the connection is still pending; any other
         * errno is an immediate, hard failure. */
        if (errno != EINPROGRESS)
            goto error_ret;

        /* Connect timeout: 10 seconds, 0 microseconds. */
        tv.tv_sec = 10;
        tv.tv_usec = 0;
        FD_ZERO(&writefds);
        FD_SET(s, &writefds);
        readfds = writefds;

        ready = select(s + 1, &readfds, &writefds, NULL, &tv);
        if (ready <= 0) /* 0 = timeout, -1 = select error (sets unspecified) */
            goto error_ret;
        if (!FD_ISSET(s, &writefds) && !FD_ISSET(s, &readfds))
            goto error_ret;

        /* Readiness alone does not mean success (e.g. refused or blocked by
         * a firewall); the pending connect's outcome is reported in SO_ERROR. */
        if (getsockopt(s, SOL_SOCKET, SO_ERROR, &error, &len) < 0 || error != 0)
            goto error_ret;
    }

    DEBUG("%s", "nonblock connect over\n");
    /* Restore the socket's original (blocking) file status flags. */
    if (fcntl(s, F_SETFL, flags) < 0)
        goto error_ret;
    return s;

error_ret:
    close(s);
    DEBUG("can't connect to %s:%d\n", host, port);
    exit(1);
}
|