local ffi = require "ffi"
local bit = require "bit"
ffi.cdef[[
int fcntl(int fd, int cmd, ... /* arg */ );
typedef int socklen_t;
struct sockaddr {
unsigned short sa_family;
char sa_data[14];
};
struct addrinfo {
int ai_flags;
int ai_family;
int ai_socktype;
int ai_protocol;
socklen_t ai_addrlen;
struct sockaddr *ai_addr;
char *ai_canonname;
struct addrinfo *ai_next;
};
typedef union epoll_data {
void *ptr;
int fd;
uint32_t u32;
uint64_t u64;
} epoll_data_t;
struct epoll_event {
uint32_t events; /* Epoll events */
epoll_data_t data; /* User data variable */
};
int getaddrinfo(
const char *restrict node,
const char *restrict service,
const struct addrinfo *restrict hints,
struct addrinfo **restrict res
);
void freeaddrinfo(struct addrinfo *res);
const char *gai_strerror(int errcode);
int close(int fildes);
int setsockopt(int socket, int level, int option_name, const void *option_value, socklen_t option_len);
int socket(int domain, int type, int protocol);
int bind(int socket, const struct sockaddr *address, socklen_t address_len);
int listen(int socket, int backlog);
int accept(int socket, struct sockaddr *restrict address, socklen_t *restrict address_len);
ssize_t recv(int socket, void *buffer, size_t length, int flags);
ssize_t send(int socket, const void *buffer, size_t length, int flags);
int shutdown(int socket, int how);
int epoll_create(int size);
int epoll_ctl(int epfd, int op, int fd, struct epoll_event *event);
int epoll_wait(int epfd, struct epoll_event *events, int maxevents, int timeout);
enum {F_GETFL = 3, F_SETFL = 4};
enum {O_NONBLOCK = 2048};
enum {AF_UNSPEC = 0};
enum {SOCK_STREAM = 1};
enum {AI_PASSIVE = 1};
enum {EPOLLIN = 1, EPOLLET = -2147483648};
enum {EPOLL_CTL_ADD = 1, EPOLL_CTL_DEL = 2};
enum {SHUT_RDWR = 2};
enum {SOL_SOCKET = 1};
enum {SO_REUSEADDR = 2};
]]
local C = ffi.C
local function set_non_blocking(sockfd)
local flags = C.fcntl(sockfd, C.F_GETFL, 0)
flags = bit.bor(flags, C.O_NONBLOCK)
C.fcntl(sockfd, C.F_SETFL, flags)
end
local TCP = {}
TCP.__index = TCP
function TCP:init(port)
local hints = ffi.new("struct addrinfo[1]")
hints[0].ai_family = C.AF_UNSPEC
hints[0].ai_socktype = C.SOCK_STREAM
hints[0].ai_flags = C.AI_PASSIVE
local servinfo = ffi.new("struct addrinfo[1][1]")
servinfo = ffi.cast("struct addrinfo **", servinfo)
local err = C.getaddrinfo(nil, tostring(port), hints, servinfo)
assert(err == 0, ("getaddrinfo error: %s"):format(ffi.string(C.gai_strerror(err))))
addr = servinfo[0][0]
local val = ffi.new("int [1]", 1) -- for setsockopt()
local sockfd = -1
while addr ~= nil do
sockfd = C.socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol)
if sockfd >= 0 then
set_non_blocking(sockfd)
C.setsockopt(sockfd, C.SOL_SOCKET, C.SO_REUSEADDR, val, ffi.sizeof(val))
err = C.bind(sockfd, addr.ai_addr, addr.ai_addrlen)
if err >= 0 then
break
end
C.close(sockfd)
end
addr = addr.ai_next
end
assert(addr ~= nil, "bind error")
C.freeaddrinfo(servinfo[0])
assert(C.listen(sockfd, self.backlog) >= 0)
self.sockfd = sockfd
end
function TCP:run()
local efd = C.epoll_create(self.max_poll_size)
local ev = ffi.new("struct epoll_event[1]")
ev[0].events = bit.bor(C.EPOLLIN, C.EPOLLET)
ev[0].data.fd = self.sockfd
assert(C.epoll_ctl(efd, C.EPOLL_CTL_ADD, self.sockfd, ev) >= 0, "epoll_ctl error")
local evs = ffi.new("struct epoll_event["..self.max_poll_size.."]")
local curfds = 1
local client_addr = ffi.new("struct sockaddr[1]")
local addr_size = ffi.new("socklen_t[1]", ffi.sizeof(client_addr[0]))
local buflen = 4096
local buffer = ffi.new("char ["..buflen.."]")
local running = true
while running do
nfds = C.epoll_wait(efd, evs, curfds, -1)
assert(nfds >= 0, "epoll_wait error")
for n = 0, nfds-1 do
if evs[n].data.fd == self.sockfd then
local newfd = C.accept(self.sockfd, client_addr, addr_size)
if newfd < 0 then
break
end
set_non_blocking(newfd)
ev[0].events = bit.bor(C.EPOLLIN, C.EPOLLET)
ev[0].data.fd = newfd
assert(C.epoll_ctl(efd, C.EPOLL_CTL_ADD, newfd, ev) >= 0, "epoll_ctl error")
curfds = curfds + 1
else
C.recv(evs[n].data.fd, buffer, buflen, 0)
local data = self:process(ffi.string(buffer))
if data == nil then
running = false
else
C.send(evs[n].data.fd, data, #data, 0)
end
C.epoll_ctl(efd, C.EPOLL_CTL_DEL, evs[n].data.fd, ev)
C.shutdown(evs[n].data.fd, C.SHUT_RDWR)
curfds = curfds - 1
C.close(evs[n].data.fd)
end
end
end
C.shutdown(self.sockfd, C.SHUT_RDWR)
C.close(self.sockfd)
end
local function new_tcp(max_poll_size, backlog)
local self = setmetatable({}, TCP)
self.max_poll_size = max_poll_size
self.backlog = backlog
return self
end
return {new_tcp=new_tcp}