也许是最简单的用 poll 实现的 Socket 代理

这几天翻了翻项目的代码,看到一个非常简单的代理程序:它用 poll 实现,能在代理过程中输出转发的数据流,基本上算是教科书级别的 poll 使用示例了,分享一下:

#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iterator>
#include <iostream>
#include <unistd.h>
#include <errno.h>
#include <signal.h>
#include <sys/time.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <fcntl.h>
#include <poll.h>

// TCP port the proxy listens on (bound to all interfaces in main()).
enum { PROXY_PORT = 8888 };

// Values for the optional third command-line argument (oflag), consumed by
// txtPrint(): BIN_TXT dumps bytes as hex; any other value (RAW_TXT by
// default) dumps printable bytes as-is and escapes the rest.
enum { BIN_TXT = 0, RAW_TXT = 1 };

// Backend address and dump format, filled in from argv by main().
static const char *backendHost = NULL;
static int backendPort = 0;
static int oflag = RAW_TXT;

bool proxy(int cfd);

// g++ -o proxy proxy.cpp -I. -Wall   (add -DUSE_FORK to serve clients concurrently)
/*
 * Usage: proxy BackendHost BackendPort [oflag]
 *
 * Listens on PROXY_PORT (all interfaces) and relays every accepted client to
 * BackendHost:BackendPort via proxy(). Without USE_FORK clients are served one
 * at a time in this process; with USE_FORK each client gets a forked child.
 * oflag selects the dump format used by txtPrint() (BIN_TXT or RAW_TXT).
 * Exits with -1 on bad arguments, EXIT_FAILURE if the listening socket cannot
 * be set up or accept() fails with a non-transient error.
 */
int main(int argc, char *argv[])
{
    if (argc == 3 || argc == 4) {
        backendHost = argv[1];

        // strtol instead of atoi so a malformed or out-of-range port is rejected.
        char *end = NULL;
        errno = 0;
        long port = strtol(argv[2], &end, 10);
        if (end == argv[2] || *end != '\0' || errno == ERANGE || port <= 0 || port > 65535) {
            fprintf(stderr, "[error] invalid BackendPort '%s'\n", argv[2]);
            exit(-1);
        }
        backendPort = (int) port;

        if (argc == 4) oflag = atoi(argv[3]);
    } else {
        fprintf(stderr, "usage: %s BackendHost BackendPort [oflag]\n", argv[0]);
        exit(-1);
    }

    // A peer that disconnects while we send() must yield EPIPE, not kill the process.
    signal(SIGPIPE, SIG_IGN);
#ifdef USE_FORK
    // Let the kernel reap finished children so they do not linger as zombies.
    signal(SIGCHLD, SIG_IGN);
#endif

    int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd == -1) {
        fprintf(stderr, "[error] socket error, %s\n", strerror(errno));
        return EXIT_FAILURE;
    }

    struct sockaddr_in addr;
    memset(&addr, 0x00, sizeof(addr));
    addr.sin_family = AF_INET;
    addr.sin_port = htons(PROXY_PORT);
    addr.sin_addr.s_addr = htonl(INADDR_ANY);

    int flags = 1;
    setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, (void *)&flags, sizeof(flags));

    if (bind(fd, (struct sockaddr*)&addr, sizeof(addr)) != 0) {
        fprintf(stderr, "[error] bind error, %s\n", strerror(errno));
        close(fd);
        return EXIT_FAILURE;
    }

    if (listen(fd, 10) != 0) {
        fprintf(stderr, "[error] listen error, %s\n", strerror(errno));
        close(fd);
        return EXIT_FAILURE;
    }

    for (;;) {
        // 0 is a valid descriptor; only -1 signals failure.
        int cfd = accept(fd, NULL, NULL);
        if (cfd == -1) {
            // Transient conditions: keep serving.
            if (errno == EINTR || errno == ECONNABORTED) continue;
            fprintf(stderr, "[error] accept error, %s\n", strerror(errno));
            break;
        }
#ifdef USE_FORK
        pid_t pid = fork();
        if (pid == 0) {
            close(fd);  // the child only needs the client socket
            proxy(cfd);
            close(cfd);
            exit(0);
        } else if (pid == -1) {
            fprintf(stderr, "[error] fork error, %s\n", strerror(errno));
        }
        // Parent (or failed fork): the client socket is not ours to serve.
        close(cfd);
#else
        proxy(cfd);
        close(cfd);
#endif
    }
    close(fd);

    return EXIT_FAILURE;
}

/*
 * Open a blocking TCP connection to backendHost:backendPort.
 * backendHost must be a numeric IPv4 address (parsed with inet_pton).
 * Returns the connected socket descriptor, or -1 after printing the reason
 * to stderr.
 */
static int connectBackend()
{
    struct sockaddr_in addr;
    memset(&addr, 0x00, sizeof(addr));
    addr.sin_family = AF_INET;
    addr.sin_port = htons(backendPort);
    // inet_pton returns 1 only on success; otherwise sin_addr would stay
    // 0.0.0.0 and we would silently connect to the wrong host.
    if (inet_pton(AF_INET, backendHost, &addr.sin_addr) != 1) {
        fprintf(stderr, "invalid backend address '%s'\n", backendHost);
        return -1;
    }

    int sfd = socket(AF_INET, SOCK_STREAM, 0);
    if (sfd == -1) {
        fprintf(stderr, "socket error, %s\n", strerror(errno));
        return -1;
    }

    if (connect(sfd, (struct sockaddr*)&addr, sizeof(addr)) != 0) {
        fprintf(stderr, "connect backend error, %s\n", strerror(errno));
        close(sfd);
        return -1;
    }

    fprintf(stdout, "[debug %d] connect backend ok\n", getpid());
    return sfd;
}

/*
 * Dump `size` bytes of `buffer` to stderr (unbuffered, so output appears
 * immediately). With oflag == BIN_TXT every byte is printed as two lowercase
 * hex digits; otherwise printable bytes are printed as-is and every other
 * byte as "%xx". No trailing newline is written.
 */
static void txtPrint(const char *buffer, size_t size)
{
    const unsigned char *bytes = (const unsigned char *) buffer;

    if (oflag == BIN_TXT) {
        for (size_t k = 0; k != size; k++) {
            fprintf(stderr, "%02x", bytes[k]);
        }
        return;
    }

    for (size_t k = 0; k != size; k++) {
        unsigned char c = bytes[k];
        if (!isprint(c)) {
            fprintf(stderr, "%%%02x", c);
        } else {
            fprintf(stderr, "%c", c);
        }
    }
}

/*
 * Write all `size` bytes of `data` to socket `fd`, which may be non-blocking.
 * When the send buffer is full (EAGAIN/EWOULDBLOCK), wait with poll(POLLOUT)
 * instead of dropping the unsent remainder.
 * Returns true once everything is sent, false on a send/poll error (see errno).
 */
static bool sendAll(int fd, const char *data, size_t size)
{
#ifdef MSG_NOSIGNAL
    // Report a closed peer as EPIPE instead of raising SIGPIPE.
    const int sendFlags = MSG_NOSIGNAL;
#else
    const int sendFlags = 0;
#endif
    size_t offset = 0;
    while (offset < size) {
        ssize_t nn = send(fd, data + offset, size - offset, sendFlags);
        if (nn > 0) {
            offset += (size_t) nn;
        } else if (nn == -1 && errno == EINTR) {
            continue;
        } else if (nn == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
            // Block until the peer drains its window. Simple, but the other
            // direction is not serviced while we wait here.
            struct pollfd wfd = {fd, POLLOUT, 0};
            if (poll(&wfd, 1, -1) == -1 && errno != EINTR) return false;
        } else {
            return false;
        }
    }
    return true;
}

/*
 * Relay traffic between the accepted client socket `cfd` and a fresh backend
 * connection (connectBackend()) until either side closes. Each forwarded chunk
 * is logged with a timestamp on stdout and dumped to stderr via txtPrint().
 * Both sockets are switched to non-blocking mode. The backend socket is
 * closed here; `cfd` is left for the caller to close.
 * Returns true on an orderly close, false on connect/poll/read/write errors.
 */
bool proxy(int cfd)
{
    const size_t len = 128;
    char buf[len];

    fcntl(cfd, F_SETFL, fcntl(cfd, F_GETFL) | O_NONBLOCK);

    int sfd = connectBackend();
    if (sfd == -1) return false;

    fcntl(sfd, F_SETFL, fcntl(sfd, F_GETFL) | O_NONBLOCK);

    struct pollfd pfd[2] = {
        {cfd, POLLIN, 0},
        {sfd, POLLIN, 0}
    };

    const int nfds = 2;

    bool stop = false;
    while (!stop) {
        int nready = poll(pfd, nfds, -1);
        if (nready == -1) {
            if (errno == EINTR) continue;  // interrupted by a signal: wait again
            fprintf(stderr, "[error] poll, %s\n", strerror(errno));
            close(sfd);
            return false;
        }

        for (int i = 0; i < nfds; ++i) {
            // A hang-up or socket error may be reported without POLLIN; recv()
            // then returns 0 (EOF) or the pending error, so we never spin on a
            // dead descriptor.
            if (!(pfd[i].revents & (POLLIN | POLLHUP | POLLERR))) continue;

            const bool fromClient = (pfd[i].fd == cfd);
            const int peer = fromClient ? sfd : cfd;
            ssize_t n;
            // Drain everything currently readable (the socket is non-blocking).
            while ((n = recv(pfd[i].fd, buf, len, 0)) > 0) {
                if (!sendAll(peer, buf, (size_t) n)) {
                    fprintf(stderr, "[error] write %s error, %s\n", fromClient ? "backend" : "client",
                                                                    strerror(errno));
                    close(sfd);
                    return false;
                }
                struct timeval tv;
                gettimeofday(&tv, 0);
                if (fromClient) {
                    fprintf(stdout, "[debug] %lu:%lu client read\n", (unsigned long) tv.tv_sec, (unsigned long) tv.tv_usec);
                } else {
                    fprintf(stdout, "[debug] %lu:%lu server read\n", (unsigned long) tv.tv_sec, (unsigned long) tv.tv_usec);
                }
                txtPrint(buf, (size_t) n);
            }
            if (n == 0) {
                if (fromClient) {
                    fprintf(stdout, "[debug] client closed\n");
                } else {
                    fprintf(stdout, "[debug] server closed\n");
                }
                stop = true;
            } else if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
                fprintf(stderr, "[error] read %s error, %s\n", fromClient ? "client" : "backend",
                                                             strerror(errno));
                close(sfd);
                return false;
            }
        }
    }

    fprintf(stdout, "[debug %d] disconnect\n", getpid());

    close(sfd);
    return true;
}

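使用 `g++ -o proxy proxy.cpp -I. -Wall` 命令编译。程序监听 8888 端口(PROXY_PORT),运行时指定需要代理的后端 IP 和端口即可;可选的第三个参数 oflag 选择数据流的输出格式(0 为十六进制,1 为可读文本)。编译时加上 `-DUSE_FORK` 后,每个连接会在 fork 出的子进程中处理,从而支持同时代理多条连接。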