login

local ffi = require "ffi"
local bit = require "bit"

ffi.cdef[[
int fcntl(int fd, int cmd, ... /* arg */ );
typedef int socklen_t;
struct sockaddr {
    unsigned short   sa_family;
    char             sa_data[14];
};
struct addrinfo {
    int              ai_flags;
    int              ai_family;
    int              ai_socktype;
    int              ai_protocol;
    socklen_t        ai_addrlen;
    struct sockaddr *ai_addr;
    char            *ai_canonname;
    struct addrinfo *ai_next;
};
typedef union epoll_data {
    void    *ptr;
    int      fd;
    uint32_t u32;
    uint64_t u64;
} epoll_data_t;
struct epoll_event {
    uint32_t     events;    /* Epoll events */
    epoll_data_t data;      /* User data variable */
};
int getaddrinfo(
    const char *restrict node,
    const char *restrict service,
    const struct addrinfo *restrict hints,
    struct addrinfo **restrict res
);
void freeaddrinfo(struct addrinfo *res);
const char *gai_strerror(int errcode);
int close(int fildes);
int setsockopt(int socket, int level, int option_name, const void *option_value, socklen_t option_len);
int socket(int domain, int type, int protocol);
int bind(int socket, const struct sockaddr *address, socklen_t address_len);
int listen(int socket, int backlog);
int accept(int socket, struct sockaddr *restrict address, socklen_t *restrict address_len);
ssize_t recv(int socket, void *buffer, size_t length, int flags);
ssize_t send(int socket, const void *buffer, size_t length, int flags);
int shutdown(int socket, int how);
int epoll_create(int size);
int epoll_ctl(int epfd, int op, int fd, struct epoll_event *event);
int epoll_wait(int epfd, struct epoll_event *events, int maxevents, int timeout);
enum {F_GETFL = 3, F_SETFL = 4};
enum {O_NONBLOCK = 2048};
enum {AF_UNSPEC = 0};
enum {SOCK_STREAM = 1};
enum {AI_PASSIVE = 1};
enum {EPOLLIN = 1, EPOLLET = -2147483648};
enum {EPOLL_CTL_ADD = 1, EPOLL_CTL_DEL = 2};
enum {SHUT_RDWR = 2};
enum {SOL_SOCKET = 1};
enum {SO_REUSEADDR = 2};
]]
local C = ffi.C

local function set_non_blocking(sockfd)
    local flags = C.fcntl(sockfd, C.F_GETFL, 0)
    flags = bit.bor(flags, C.O_NONBLOCK)
    C.fcntl(sockfd, C.F_SETFL, flags)
end

local TCP = {}
TCP.__index = TCP

function TCP:init(port)
    local hints = ffi.new("struct addrinfo[1]")
    hints[0].ai_family = C.AF_UNSPEC
    hints[0].ai_socktype = C.SOCK_STREAM
    hints[0].ai_flags = C.AI_PASSIVE

    local servinfo = ffi.new("struct addrinfo[1][1]")
    servinfo = ffi.cast("struct addrinfo **", servinfo)

    local err = C.getaddrinfo(nil, tostring(port), hints, servinfo)
    assert(err == 0, ("getaddrinfo error: %s"):format(ffi.string(C.gai_strerror(err))))
    addr = servinfo[0][0]
    local val = ffi.new("int [1]", 1) -- for setsockopt()
    local sockfd = -1
    while addr ~= nil do
        sockfd = C.socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol)
        if sockfd >= 0 then
            set_non_blocking(sockfd)
            C.setsockopt(sockfd, C.SOL_SOCKET, C.SO_REUSEADDR, val, ffi.sizeof(val))
            err = C.bind(sockfd, addr.ai_addr, addr.ai_addrlen)
            if err >= 0 then
                break
            end
            C.close(sockfd)
        end
        addr = addr.ai_next
    end
    assert(addr ~= nil, "bind error")
    C.freeaddrinfo(servinfo[0])
    assert(C.listen(sockfd, self.backlog) >= 0)
    self.sockfd = sockfd
end

function TCP:run()
    local efd = C.epoll_create(self.max_poll_size)
    local ev = ffi.new("struct epoll_event[1]")
    ev[0].events = bit.bor(C.EPOLLIN, C.EPOLLET)
    ev[0].data.fd = self.sockfd
    assert(C.epoll_ctl(efd, C.EPOLL_CTL_ADD, self.sockfd, ev) >= 0, "epoll_ctl error")

    local evs = ffi.new("struct epoll_event["..self.max_poll_size.."]")
    local curfds = 1

    local client_addr = ffi.new("struct sockaddr[1]")
    local addr_size = ffi.new("socklen_t[1]", ffi.sizeof(client_addr[0]))
    local buflen = 4096
    local buffer = ffi.new("char ["..buflen.."]")
    local running = true
    while running do
        nfds = C.epoll_wait(efd, evs, curfds, -1)
        assert(nfds >= 0, "epoll_wait error")
        for n = 0, nfds-1 do
            if evs[n].data.fd == self.sockfd then
                local newfd = C.accept(self.sockfd, client_addr, addr_size)
                if newfd < 0 then
                    break
                end
                set_non_blocking(newfd)
                ev[0].events = bit.bor(C.EPOLLIN, C.EPOLLET)
                ev[0].data.fd = newfd
                assert(C.epoll_ctl(efd, C.EPOLL_CTL_ADD, newfd, ev) >= 0, "epoll_ctl error")
                curfds = curfds + 1
            else
                C.recv(evs[n].data.fd, buffer, buflen, 0)
                local data = self:process(ffi.string(buffer))
                if data == nil then
                    running = false
                else
                    C.send(evs[n].data.fd, data, #data, 0)
                end
                C.epoll_ctl(efd, C.EPOLL_CTL_DEL, evs[n].data.fd, ev)
                C.shutdown(evs[n].data.fd, C.SHUT_RDWR)
                curfds = curfds - 1
                C.close(evs[n].data.fd)
            end
        end
    end
    C.shutdown(self.sockfd, C.SHUT_RDWR)
    C.close(self.sockfd)
end

local function new_tcp(max_poll_size, backlog)
    local self = setmetatable({}, TCP)
    self.max_poll_size = max_poll_size
    self.backlog = backlog
    return self
end

return {new_tcp=new_tcp}