login

local bit = require "bit"
local ffi = require "ffi"

local bnot = bit.bnot
local band, bor, bxor = bit.band, bit.bor, bit.bxor
local lshift, rshift = bit.lshift, bit.rshift
local rol, ror = bit.rol, bit.ror

local char, byte = string.char, string.byte

ffi.cdef[[
ssize_t getrandom(void *buf, size_t buflen, unsigned int flags);
]]
local C = ffi.C

local SHA256_H = {
    0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a,
    0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19
}
local SHA256_K = {
    0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5,
    0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
    0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,
    0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
    0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc,
    0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
    0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
    0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
    0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13,
    0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
    0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3,
    0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
    0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5,
    0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
    0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,
    0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2
}

local function hex(h)
    local s = ""
    for i = 1, #h do
        s = s .. ("%02x"):format(byte(h, i))
    end
    return s
end

local function urandom(len)
    local buf = ffi.new("char ["..len.."]")
    assert(C.getrandom(buf, len, 0) == len)
    return ffi.string(buf, len)
end

local function uuid4()
    local id = urandom(16)
    local version = char(bor(0x40, band(byte(id, 7), 0x0F)))
    local variant = char(bor(0x40, band(byte(id, 9), 0x1F)))
    id = id:sub(1, 6) .. version .. id:sub(8, 8) .. variant .. id:sub(10, 16)
    return id
end

-- returns b-bytes big endian integer encoded as string
local function be_to_str(n, b)
    local s = ""
    for i = 1, b do
        s = char(bit.band(n, 0xFF)) .. s
        n = rshift(n, 8)
    end
    return s
end

local function sha256_str_to_be32(s)
    assert(#s == 4)
    local n = 0
    for i = 1, 4 do
        n = lshift(n, 8)
        n = bor(n, byte(s, i))
    end
    return n
end

-- returns padded message
local function sha256_pad(s)
    -- all lengths in bytes
    local L = #s                         -- input length
    local n = math.ceil((L+9) / 64) * 64 -- padded length
    local k = n - (L+9)                  -- zero-padding length
    s = s .. char(128)
    s = s .. char(0):rep(k)
    s = s .. be_to_str(L*8, 8)
    assert(#s % 64 == 0)
    return s
end

local function sha256_mod(n)
    return band(n, 0xFFFFFFFF)
end

-- compute hash of a 512-bit chunk
local function sha256_chunk(s, H)
    local s0, s1
    local K = SHA256_K
    local W = {}
    for i = 0, 15 do
        local w_str = s:sub(i*4+1, (i+1)*4)
        W[i] = sha256_str_to_be32(w_str)
    end
    for i = 16, 63 do
        s0 = bxor(ror(W[i-15],  7), ror(W[i-15], 18), rshift(W[i-15],  3))
        s1 = bxor(ror(W[i- 2], 17), ror(W[i- 2], 19), rshift(W[i- 2], 10))
        W[i] = sha256_mod(W[i-16] + s0 + W[i-7] + s1)
    end
    local a, b, c, d, e, f, g, h = unpack(H)
    for i = 0, 63 do
        local ch, maj, t1, t2
        s1 = bxor(ror(e, 6), ror(e, 11), ror(e, 25))
        ch = bxor(band(e, f), band(bnot(e), g))
        t1 = sha256_mod(h + s1 + ch + K[i+1] + W[i])
        s0 = bxor(ror(a, 2), ror(a, 13), ror(a, 22))
        maj = bxor(band(a, b), band(a, c), band(b, c))
        t2 = sha256_mod(s0 + maj)
        a, b, c, d, e, f, g, h = sha256_mod(t1+t2), a, b, c, sha256_mod(d+t1), e, f, g
    end
    local S = {a, b, c, d, e, f, g, h}
    for i = 1, 8 do
        H[i] = sha256_mod(H[i] + S[i])
    end
end

local function sha256(s)
    s = sha256_pad(s)
    local nchunks = #s / 64
    local H = {}
    for i = 1, 8 do
        H[i] = SHA256_H[i]
    end
    for i = 1, nchunks do
        local chunk = s:sub((i-1)*64+1, i*64)
        sha256_chunk(chunk, H)
    end
    local h = ""
    for i = 1, 8 do
        h = h .. be_to_str(H[i], 4)
    end
    return h
end

local function hmac(key, msg)
    if #key > 64 then
        key = sha256(key)
    end
    key = key .. char(0):rep(64 - #key)
    assert(#key == 64)
    local o_key_pad = ""
    local i_key_pad = ""
    for i = 1, 64 do
        local k = byte(key, i)
        o_key_pad = o_key_pad .. char(bxor(k, 0x5C))
        i_key_pad = i_key_pad .. char(bxor(k, 0x36))
    end
    return sha256(o_key_pad .. sha256(i_key_pad .. msg))
end

local function pbkdf2(password, salt, c, dklen)
    assert(dklen % 32 == 0)
    local n = dklen / 32
    local dk = ""
    for i = 1, n do
        local f = {}
        for j = 1, 32 do
            f[j] = 0
        end
        local u = salt .. be_to_str(i, 4)
        for j = 1, c do
            u = hmac(password, u)
            for k = 1, 32 do
                f[k] = bxor(f[k], byte(u, k))
            end
        end
        dk = dk .. char(unpack(f))
    end
    assert(#dk == dklen)
    return dk
end

local B64_C = {}
for i = byte("A"), byte("Z") do
    table.insert(B64_C, char(i))
end
for i = byte("a"), byte("z") do
    table.insert(B64_C, char(i))
end
for i = byte("0"), byte("9") do
    table.insert(B64_C, char(i))
end
table.insert(B64_C, "+")
table.insert(B64_C, "/")
local B64_D = {}
for i, c in ipairs(B64_C) do
    B64_D[c] = i - 1
end
local B64_P = "="

local function b64_tri_to_quad(bin)
    if #bin == 0 then
        return ""
    end
    local n = lshift(bin[1], 16)
    if #bin > 1 then
        n = bor(n, lshift(bin[2], 8))
        if #bin > 2 then
            n = bor(n, bin[3])
        end
    end
    local i1 = rshift(n, 18) + 1
    local i2 = band(rshift(n, 12), 0x3F) + 1
    local i3 = band(rshift(n, 6), 0x3F) + 1
    local i4 = band(n, 0x3F) + 1
    local quad = B64_C[i1] .. B64_C[i2]
    if #bin > 1 then
        quad = quad .. B64_C[i3]
        if #bin > 2 then
            quad = quad .. B64_C[i4]
        end
    end
    quad = quad .. B64_P:rep(4 - #quad)
    return quad
end

local function b64_enc(data)
    local b64 = ""
    local bin = {}
    for i = 1, #data do
        table.insert(bin, byte(data, i))
        if #bin == 3 then
            b64 = b64 .. b64_tri_to_quad(bin)
            bin = {}
        end
    end
    b64 = b64 .. b64_tri_to_quad(bin)
    return b64
end

local function base64_pad(b64)
    local padlen = -#b64 % 4
    return b64 .. B64_P:rep(padlen)
end

local function b64_dec(b64)
    b64 = base64_pad(b64)
    local data = ""
    for i = 1, #b64/4 do
        local n = 0
        for j = (i-1)*4+1, i*4 do
            local c = b64:sub(j, j)
            n = lshift(n, 6)
            if c ~= B64_P then
                n = bor(n, B64_D[c])
            end
        end
        local tri = be_to_str(n, 3)
        local j = i*4
        if b64:sub(j, j) == B64_P then
            if b64:sub(j-1, j-1) == B64_P then
                tri = tri:sub(1, 1)
            else
                tri = tri:sub(1, 2)
            end
        end
        data = data .. tri
    end
    return data
end

return {
    hex=hex, urandom=urandom, uuid4=uuid4, sha256=sha256, hmac=hmac,
    pbkdf2=pbkdf2, b64_enc=b64_enc, b64_dec=b64_dec
}