/*
 * Decompiled with CFR 0.152.
 */
package org.infinispan.server.resp.scripting;

import io.netty.channel.ChannelHandlerContext;
import java.lang.reflect.InvocationTargetException;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;
import org.infinispan.commons.CacheListenerException;
import org.infinispan.commons.util.Util;
import org.infinispan.commons.util.Version;
import org.infinispan.remoting.RemoteException;
import org.infinispan.server.resp.AclCategory;
import org.infinispan.server.resp.Resp3Handler;
import org.infinispan.server.resp.RespCommand;
import org.infinispan.server.resp.RespRequestHandler;
import org.infinispan.server.resp.RespVersion;
import org.infinispan.server.resp.logging.Log;
import org.infinispan.server.resp.scripting.LuaCode;
import org.infinispan.server.resp.scripting.LuaContextPool;
import org.infinispan.server.resp.scripting.LuaTaskEngine;
import org.infinispan.server.resp.scripting.ScriptFlags;
import org.jboss.logging.Logger;
import party.iroiro.luajava.JFunction;
import party.iroiro.luajava.Lua;
import party.iroiro.luajava.lua51.Lua51;

public class LuaContext
implements AutoCloseable {
    /** Name of the Lua global under which {@code installRedisAPI()} publishes the API table. */
    private static final String REDIS_API_NAME = "redis";
    /** Lua libraries opened by the constructor; their names are also accepted as globals. */
    private static final Set<String> LIBRARIES_ALLOW_LIST = Set.of("string", "math", "table", "os");
    /** Global names created by this class itself: the API table and the error handler function. */
    private static final Set<String> REDIS_API_ALLOW_LIST = Set.of("redis", "__redis__err__handler");
    /** Names of Lua built-ins that may be assigned as globals. */
    private static final Set<String> LUA_BUILTINS_ALLOW_LIST = Set.of("xpcall", "tostring", "setmetatable", "next", "assert", "tonumber", "rawequal", "collectgarbage", "getmetatable", "rawset", "pcall", "coroutine", "type", "_G", "select", "unpack", "gcinfo", "pairs", "rawget", "loadstring", "ipairs", "_VERSION", "load", "error");
    /** {@code debug} is opened by the constructor and then set to nil by the error-handler chunk. */
    private static final Set<String> LUA_BUILTINS_REMOVED_AFTER_INITIALIZATION_ALLOW_LIST = Set.of("debug");
    /** Union of the four allow lists above (built in the static initializer); consulted by {@code luaNewIndexAllowList}. */
    private static final Set<String> ALLOW_LISTS;
    /** Names that {@code luaNewIndexAllowList} drops silently, i.e. without logging a warning. */
    private static final Set<String> DENY_LIST;
    // Log levels accepted by redis.log(); each value is an index into LEVEL_MAP.
    public static final int LOG_DEBUG = 0;
    public static final int LOG_VERBOSE = 1;
    public static final int LOG_NOTICE = 2;
    public static final int LOG_WARNING = 3;
    /** JBoss logging level for each LOG_* value, in order: TRACE, DEBUG, INFO, WARN. */
    private static final Logger.Level[] LEVEL_MAP;
    // Replication flag bits; the same numeric values are exposed to scripts as redis.REPL_*.
    public static final int PROPAGATE_AOF = 1;
    public static final int PROPAGATE_REPL = 2;
    public static final int PROPAGATE_NONE = 0;
    public static final int PROPAGATE_ALL = 3;
    /** The Lua state owned by this context; created in the constructor and closed by {@code shutdown()}. */
    final Lua lua;
    /** Script flag bits; {@code executeRespCommand} tests them with {@code ScriptFlags.NO_WRITES}. */
    long flags;
    /** Handler used to run commands issued by scripts and to reach the response writer and the server's random source. */
    Resp3Handler handler;
    /** Channel context passed to {@code handler.handleRequest(...)} when a script issues a command. */
    ChannelHandlerContext ctx;
    /** Current mode; defaults to {@link Mode#USER}. Not read anywhere in this class. */
    Mode mode = Mode.USER;
    /** Pool this context is handed back to by {@code close()}; set to null by {@code shutdown()}. */
    LuaContextPool pool;

    /**
     * Creates a Lua 5.1 state and prepares it for running scripts.
     * <p>
     * Opens every library named in {@code LIBRARIES_ALLOW_LIST} plus {@code debug}, then, in this
     * order, replaces the math random functions, installs the error handler chunk, publishes the
     * {@code redis} API table and finally installs the error metatable.
     */
    LuaContext() {
        lua = new Lua51();
        LIBRARIES_ALLOW_LIST.forEach(lua::openLibrary);
        // "debug" is needed by the error handler chunk, which nils the global once it has captured it.
        lua.openLibrary("debug");
        installMathRandom();
        installErrorHandler();
        installRedisAPI();
        luaSetErrorMetatable();
    }

    /**
     * Builds the scripting API table and publishes it as the Lua global {@code redis}.
     * <p>
     * The allow-list protection is installed first, then a new table is filled with:
     * <ul>
     *   <li>{@code sha1hex(s)} - hex SHA-1 of its single argument;</li>
     *   <li>{@code call(...)} / {@code pcall(...)} - run a command through
     *       {@link #executeRespCommand(Lua, boolean)}, with and without raising errors;</li>
     *   <li>{@code setresp(v)} - switch the response writer's RESP version;</li>
     *   <li>{@code error_reply(s)} / {@code status_reply(s)} - build {@code {err=...}} / {@code {ok=...}} tables;</li>
     *   <li>{@code set_repl(flags)} - validates the flags only; the value is not stored;</li>
     *   <li>{@code log(level, ...)} - logs the message arguments on {@code Log.SERVER};</li>
     *   <li>the {@code REPL_*}, {@code LOG_*} and {@code REDIS_VERSION*} constants.</li>
     * </ul>
     * {@code log} joins its message arguments with a single space and skips arguments that have no
     * string form, which is how Redis glues {@code redis.log} arguments together. Previously they were
     * concatenated with no separator and unconvertible arguments were logged as the text "null".
     * <p>
     * NOTE(review): the {@code return -1} paths assume the binding treats a negative result as
     * "raise the value on top of the stack as an error" - confirm against the luajava version in use.
     */
    private void installRedisAPI() {
        this.luaSetAllowListProtection();
        this.lua.newTable();
        LuaContext.tableAdd(this.lua, "sha1hex", l -> {
            int argc = l.getTop();
            if (argc != 1) {
                l.error("wrong number of arguments");
            }
            l.push(LuaContext.sha1hex(l.toString(1)));
            return 1;
        });
        LuaContext.tableAdd(this.lua, "call", l -> this.executeRespCommand(l, true));
        LuaContext.tableAdd(this.lua, "pcall", l -> this.executeRespCommand(l, false));
        LuaContext.tableAdd(this.lua, "setresp", l -> {
            int argc = l.getTop();
            if (argc != 1) {
                l.error("redis.setresp() requires one argument.");
            }
            try {
                this.handler.writer().version(RespVersion.of((int) l.toInteger(-argc)));
            } catch (IllegalArgumentException e) {
                l.error("RESP version must be 2 or 3.");
            }
            return 0;
        });
        LuaContext.tableAdd(this.lua, "error_reply", l -> {
            if (l.getTop() != 1 || l.type(-1) != Lua.LuaType.STRING) {
                l.error("wrong number or type of arguments");
                return 1;
            }
            String err = l.toString(-1);
            // luaPushError expects the leading '-' of a RESP error line.
            if (!err.startsWith("-")) {
                err = "-" + err;
            }
            LuaContext.luaPushError(this.lua, err);
            return 1;
        });
        LuaContext.tableAdd(this.lua, "status_reply", l -> {
            if (l.getTop() != 1 || l.type(-1) != Lua.LuaType.STRING) {
                l.error("wrong number or type of arguments");
            }
            // Stack: [arg] -> [arg, {}] -> [arg, {}, "ok"] -> [arg, {}, "ok", arg] -> [arg, {ok=arg}]
            l.newTable();
            l.push("ok");
            l.pushValue(-3);
            l.setTable(-3);
            return 1;
        });
        LuaContext.tableAdd(this.lua, "set_repl", l -> {
            int argc = l.getTop();
            if (argc != 1) {
                l.error("redis.set_repl() requires one argument.");
            }
            long replFlags = l.toInteger(-1);
            // Any bit outside PROPAGATE_ALL (AOF | REPL) is invalid.
            if ((replFlags & ~((long) PROPAGATE_ALL)) != 0L) {
                l.error("Invalid replication flags. Use REPL_AOF, REPL_REPLICA, REPL_ALL or REPL_NONE.");
            }
            return 0;
        });
        LuaContext.tableAdd(this.lua, "REPL_NONE", PROPAGATE_NONE);
        LuaContext.tableAdd(this.lua, "REPL_AOF", PROPAGATE_AOF);
        LuaContext.tableAdd(this.lua, "REPL_SLAVE", PROPAGATE_REPL);
        LuaContext.tableAdd(this.lua, "REPL_REPLICA", PROPAGATE_REPL);
        LuaContext.tableAdd(this.lua, "REPL_ALL", PROPAGATE_ALL);
        LuaContext.tableAdd(this.lua, "log", l -> {
            int argc = l.getTop();
            if (argc < 2) {
                LuaContext.luaPushError(this.lua, "redis.log() requires two arguments or more.");
                return -1;
            }
            if (!l.isNumber(-argc)) {
                LuaContext.luaPushError(this.lua, "First argument must be a number (log level).");
                return -1;
            }
            int level = (int) l.toInteger(-argc);
            if (level < LOG_DEBUG || level > LOG_WARNING) {
                LuaContext.luaPushError(this.lua, "Invalid log level.");
                return -1;
            }
            StringBuilder sb = new StringBuilder();
            // Arguments 2..argc are the message parts; j - argc addresses them from the stack top.
            for (int j = 1; j < argc; ++j) {
                String part = l.toString(j - argc);
                if (part == null) {
                    // No string form (nil, table, ...): skip rather than log "null".
                    continue;
                }
                if (j != 1) {
                    sb.append(' ');
                }
                sb.append(part);
            }
            Log.SERVER.log(LEVEL_MAP[level], sb);
            return 0;
        });
        LuaContext.tableAdd(this.lua, "LOG_DEBUG", LOG_DEBUG);
        LuaContext.tableAdd(this.lua, "LOG_VERBOSE", LOG_VERBOSE);
        LuaContext.tableAdd(this.lua, "LOG_NOTICE", LOG_NOTICE);
        LuaContext.tableAdd(this.lua, "LOG_WARNING", LOG_WARNING);
        LuaContext.tableAdd(this.lua, "REDIS_VERSION_NUM", Version.getVersionShort());
        LuaContext.tableAdd(this.lua, "REDIS_VERSION", Version.getVersion());
        this.lua.setGlobal(REDIS_API_NAME);
    }

    /**
     * Replaces {@code math.random} and {@code math.randomseed} with versions backed by the
     * server's random source ({@code handler.respServer().random()}).
     * <p>
     * {@code math.random} follows the Lua 5.1 contract:
     * <ul>
     *   <li>no arguments - a double from {@code nextDouble()};</li>
     *   <li>{@code random(m)} - an integer in the closed range {@code [1, m]}; raises
     *       "interval is empty" when {@code m < 1};</li>
     *   <li>{@code random(m, n)} - an integer in the closed range {@code [m, n]}; raises
     *       "interval is empty" when {@code m > n};</li>
     *   <li>any other argument count raises "wrong number of arguments".</li>
     * </ul>
     * The previous code passed the Lua upper limit straight to {@code nextLong(origin, bound)},
     * whose bound is exclusive. The upper limit could therefore never be returned,
     * {@code random(1)} was rejected, and {@code random(m, n)} with {@code m >= n} failed with an
     * {@code IllegalArgumentException} instead of a Lua error.
     */
    private void installMathRandom() {
        this.lua.getGlobal("math");
        this.lua.push("random");
        this.lua.push(l -> {
            switch (l.getTop()) {
                case 0: {
                    this.lua.push((Number) this.handler.respServer().random().nextDouble());
                    break;
                }
                case 1: {
                    long upper = this.lua.toInteger(1);
                    if (upper < 1L) {
                        this.lua.error("interval is empty");
                    }
                    this.lua.push(this.nextLongInclusive(1L, upper));
                    break;
                }
                case 2: {
                    long lower = this.lua.toInteger(1);
                    long upper = this.lua.toInteger(2);
                    if (lower > upper) {
                        this.lua.error("interval is empty");
                    }
                    this.lua.push(this.nextLongInclusive(lower, upper));
                    break;
                }
                default: {
                    this.lua.error("wrong number of arguments");
                }
            }
            return 1;
        });
        this.lua.setTable(-3);
        this.lua.push("randomseed");
        this.lua.push(l -> {
            this.handler.respServer().random().setSeed(l.toInteger(1));
            return 0;
        });
        this.lua.setTable(-3);
        // Pops the math table that getGlobal pushed, leaving the stack balanced.
        this.lua.setGlobal("math");
    }

    /**
     * Draws a value uniformly from the closed range {@code [lower, upper]} using the server's
     * random source, without overflowing when {@code upper} is {@code Long.MAX_VALUE}.
     *
     * @param lower smallest value that may be returned
     * @param upper largest value that may be returned; expected to be {@code >= lower}
     * @return a value between {@code lower} and {@code upper}, both inclusive
     */
    private long nextLongInclusive(long lower, long upper) {
        if (lower == upper) {
            return lower;
        }
        if (upper < Long.MAX_VALUE) {
            return this.handler.respServer().random().nextLong(lower, upper + 1L);
        }
        if (lower > Long.MIN_VALUE) {
            // upper + 1 would overflow: draw from [lower - 1, upper) and shift the result up by one.
            return this.handler.respServer().random().nextLong(lower - 1L, upper) + 1L;
        }
        // The full long range was requested.
        return this.handler.respServer().random().nextLong();
    }

    /**
     * Loads and runs a small Lua chunk that defines the global {@code __redis__err__handler}.
     * <p>
     * The chunk copies the {@code debug} global into a local, sets the global to nil, and defines a
     * handler that wraps non-table errors as {@code {err='ERR ' .. tostring(err)}} and, when
     * {@code debug.getinfo} returns call information, adds {@code source} and {@code line} fields.
     * If the first frame it inspects is a native ('C') call, it uses the next frame instead.
     */
    private void installErrorHandler() {
        String err_handler = "-- copy the `debug` global to a local, and nil it so it cannot be used by user scripts\nlocal dbg = debug\ndebug = nil\nfunction __redis__err__handler(err)\n  -- get debug information for the previous call (type, source and line)\n  local i = dbg.getinfo(2,'nSl')\n  -- if it was a native call, get the information for the previous element in the stack\n  if i and i.what == 'C' then\n    i = dbg.getinfo(3,'nSl')\n  end\n  if type(err) ~= 'table' then\n    err = {err='ERR ' .. tostring(err)}\n  end\n  if i then\n    err['source'] = i.source\n    err['line'] = i.currentline\n  end\n  return err\nend\n";
        // The chunk is pure ASCII, so this encoding is lossless here.
        byte[] bytes = err_handler.getBytes(StandardCharsets.US_ASCII);
        // A direct buffer sized exactly to the chunk is handed to the Lua loader.
        ByteBuffer buffer = ByteBuffer.allocateDirect(bytes.length);
        buffer.put(bytes);
        // NOTE(review): the buffer is passed without flip(); presumably the binding reads from
        // offset 0 up to limit() rather than from position() - confirm.
        this.lua.load((Buffer)buffer, "@err_handler_def");
        // Run the loaded chunk with no arguments and no results so the handler becomes defined.
        this.lua.pCall(0, 0);
    }

    /**
     * Implements {@code redis.call} / {@code redis.pcall}: runs the command found on the Lua stack
     * through the RESP handler and leaves its reply on the stack.
     * <p>
     * The first Lua argument is the command name and the remaining ones are its arguments. The call
     * is refused, by pushing a message and returning {@code -1}, when:
     * <ul>
     *   <li>no arguments were given;</li>
     *   <li>an argument has no string form (for example nil or a table);</li>
     *   <li>the command is unknown;</li>
     *   <li>the command is in the CONNECTION ACL category;</li>
     *   <li>the command is a WRITE command and the script has the NO_WRITES flag.</li>
     * </ul>
     * Changes from the previous version: the zero-argument and unconvertible-argument cases are now
     * reported as script errors instead of surfacing as Java exceptions; arguments are encoded as
     * UTF-8 rather than US-ASCII, which replaced every non-ASCII character with '?'; and the thread's
     * interrupt flag is restored when waiting for the command is interrupted.
     * <p>
     * NOTE(review): the {@code return -1} paths assume the binding raises the value on top of the
     * stack as a Lua error when a negative count is returned - confirm.
     *
     * @param l          Lua state the function was invoked on; arguments are read from it
     * @param raiseError when true, a reply table with a string {@code err} field is raised as a
     *                   Lua error instead of being returned
     * @return {@code 1} (the reply left on the stack) or {@code -1} for a refused call
     */
    public int executeRespCommand(Lua l, boolean raiseError) {
        int argc = l.getTop();
        if (argc == 0) {
            l.push("Please specify at least one argument for this redis lib call");
            return -1;
        }
        String command = l.toString(-argc);
        if (command == null) {
            l.push("Lua redis lib command arguments must be strings or integers");
            return -1;
        }
        RespCommand respCommand = RespCommand.fromString(command);
        if (respCommand == null) {
            l.push("Unknown Redis command called from script");
            return -1;
        }
        long commandMask = respCommand.aclMask();
        if (AclCategory.CONNECTION.matches(commandMask)) {
            l.push("This Redis command is not allowed from script");
            return -1;
        }
        if (ScriptFlags.NO_WRITES.isSet(this.flags) && AclCategory.WRITE.matches(commandMask)) {
            l.push("Write commands are not allowed from read-only scripts.");
            return -1;
        }
        ArrayList<byte[]> args = new ArrayList<byte[]>(argc - 1);
        // Negative indices -argc+1 .. -1 address the arguments that follow the command name.
        for (int i = -argc + 1; i < 0; ++i) {
            String arg = l.toString(i);
            if (arg == null) {
                l.push("Lua redis lib command arguments must be strings or integers");
                return -1;
            }
            args.add(arg.getBytes(StandardCharsets.UTF_8));
        }
        CompletableFuture<RespRequestHandler> future = this.handler.handleRequest(this.ctx, respCommand, args).toCompletableFuture();
        try {
            // Scripts are synchronous: block until the command has completed.
            future.get();
        } catch (Throwable t) {
            if (t instanceof InterruptedException) {
                // Keep the interruption visible to the code that owns this thread.
                Thread.currentThread().interrupt();
            }
            this.handler.writer().error(t);
            Throwable cause = LuaContext.filterCause(t);
            Log.SERVER.debugf(cause, "Error while processing command '%s'", respCommand);
        }
        // If the value on top of the stack is a table with a string 'err' field, either raise it
        // (call) or leave the table in place as the return value (pcall).
        if (this.lua.type(-1) == Lua.LuaType.TABLE) {
            this.lua.push("err");
            this.lua.rawGet(-2);
            if (this.lua.type(-1) == Lua.LuaType.STRING && raiseError) {
                String error = this.lua.toString(-1);
                this.lua.pop(2);
                this.lua.error(error);
            }
            this.lua.pop(1);
        }
        return 1;
    }

    /**
     * Strips wrapper exceptions to find the throwable that is worth reporting.
     * <p>
     * A throwable is unwrapped only when it has a non-null cause and its class is exactly one of
     * {@code ExecutionException}, {@code CompletionException}, {@code InvocationTargetException},
     * {@code RemoteException}, {@code RuntimeException} or {@code CacheListenerException}
     * (subclasses do not count). Unwrapping repeats on the cause.
     *
     * @param re the throwable to inspect, may be null
     * @return the first throwable in the cause chain that is not unwrapped, or null if {@code re} is null
     */
    public static Throwable filterCause(Throwable re) {
        if (re == null) {
            return null;
        }
        Throwable inner = re.getCause();
        if (inner == null) {
            return re;
        }
        Class<?> type = re.getClass();
        boolean isPlainWrapper = type == ExecutionException.class
                || type == CompletionException.class
                || type == InvocationTargetException.class
                || type == RemoteException.class
                || type == RuntimeException.class
                || type == CacheListenerException.class;
        return isPlainWrapper ? LuaContext.filterCause(inner) : re;
    }

    /**
     * Hands this context back to its pool. Does nothing when no pool is set, which is the state
     * {@code shutdown()} leaves the context in. The Lua state itself is not closed here.
     */
    @Override
    public void close() {
        if (this.pool == null) {
            return;
        }
        this.pool.returnToPool(this);
    }

    /**
     * Detaches this context from its pool and closes the underlying Lua state.
     * After this call {@code close()} no longer returns the context to a pool.
     */
    void shutdown() {
        // Clear the pool reference first so a later close() becomes a no-op.
        this.pool = null;
        this.lua.close();
    }

    /**
     * Stores a Java function under {@code name} in the table currently on top of the Lua stack.
     * The key and value pushed here are consumed by {@code setTable(-3)}, so the table is on top
     * again when the method returns.
     *
     * @param lua      Lua state whose stack top holds the target table
     * @param name     key to store the function under
     * @param function function to store
     */
    private static void tableAdd(Lua lua, String name, JFunction function) {
        lua.push(name);
        lua.push(function);
        // After the two pushes the target table sits at index -3.
        lua.setTable(-3);
    }

    /**
     * Stores an integer under {@code name} in the table currently on top of the Lua stack.
     * The value is widened to {@code long} before being pushed.
     *
     * @param lua  Lua state whose stack top holds the target table
     * @param name key to store the value under
     * @param i    value to store
     */
    private static void tableAdd(Lua lua, String name, int i) {
        lua.push(name);
        lua.push((long)i);
        // After the two pushes the target table sits at index -3.
        lua.setTable(-3);
    }

    /**
     * Stores a string under {@code name} in the table currently on top of the Lua stack.
     *
     * @param lua   Lua state whose stack top holds the target table
     * @param name  key to store the value under
     * @param value string to store
     */
    private static void tableAdd(Lua lua, String name, String value) {
        lua.push(name);
        lua.push(value);
        // After the two pushes the target table sits at index -3.
        lua.setTable(-3);
    }

    /**
     * Computes the SHA-1 digest of a string's UTF-8 bytes and renders it with
     * {@code Util.toHexString}.
     *
     * @param s text to hash; must not be null
     * @return the digest as produced by {@code Util.toHexString}
     * @throws RuntimeException wrapping {@link NoSuchAlgorithmException} if SHA-1 is unavailable
     */
    public static String sha1hex(String s) {
        MessageDigest digest;
        try {
            digest = MessageDigest.getInstance("SHA-1");
        } catch (NoSuchAlgorithmException e) {
            throw new RuntimeException(e);
        }
        byte[] hash = digest.digest(s.getBytes(StandardCharsets.UTF_8));
        return Util.toHexString(hash);
    }

    /**
     * Pushes a new table of the form {@code {err = message}} onto the Lua stack.
     * <p>
     * The message is derived from {@code error} as follows: a trailing CRLF is dropped; if the text
     * starts with '-' and contains a space anywhere, the text after the '-' is used unchanged;
     * in every other case the text (minus a leading '-', if any) is prefixed with {@code "ERR "}.
     *
     * @param lua   Lua state to push the table onto
     * @param error error text, optionally in RESP error-line form; must not be null
     */
    public static void luaPushError(Lua lua, String error) {
        int end = error.endsWith("\r\n") ? error.length() - 2 : error.length();
        String message;
        if (!error.startsWith("-")) {
            message = "ERR " + error.substring(0, end);
        } else {
            String body = error.substring(1, end);
            // With no space there is no separate error-code word, so the generic one is supplied.
            message = error.indexOf(' ') < 0 ? "ERR " + body : body;
        }
        lua.newTable();
        LuaContext.tableAdd(lua, "err", message);
    }

    /**
     * Handler installed as {@code __index} by {@code luaSetErrorMetatable}: always raises an error
     * that names the key which was looked up.
     * <p>
     * Expects exactly two values on the stack, the second being a string or a number.
     *
     * @param lua Lua state the handler was invoked on
     * @return {@code 0}; only reached if {@code lua.error} returns normally
     */
    private static int luaProtectedTableError(Lua lua) {
        if (lua.getTop() != 2) {
            lua.error("Wrong number of arguments to luaProtectedTableError");
        }
        boolean keyHasName = lua.isString(-1) || lua.isNumber(-1);
        if (!keyHasName) {
            lua.error("Second argument to luaProtectedTableError must be a string or number");
        }
        String missing = lua.toString(-1);
        lua.error("Script attempted to access nonexistent global variable '" + missing + "'");
        return 0;
    }

    /**
     * Installs {@code luaProtectedTableError} as an {@code __index} metamethod.
     * <p>
     * Stack sequence: push the long {@code -10002}; push a new table (the metatable); store the
     * handler in its {@code __index} field; {@code setMetatable(-2)} attaches the table to the value
     * pushed first; finally that value is popped, leaving the stack as it was.
     * <p>
     * NOTE(review): {@code push(long)} pushes the integer -10002 itself, not the table found at
     * pseudo-index -10002 (which is LUA_GLOBALSINDEX in Lua 5.1). If the intent is to protect the
     * globals table this probably wants {@code pushValue(-10002)}. Before changing it, confirm that
     * the binding's own global reads and writes would still work with the metatable in place.
     */
    private void luaSetErrorMetatable() {
        this.lua.push(-10002L);
        this.lua.newTable();
        this.lua.push(LuaContext::luaProtectedTableError);
        this.lua.setField(-2, "__index");
        this.lua.setMetatable(-2);
        this.lua.pop(1);
    }

    /**
     * Handler installed as {@code __newindex} by {@code luaSetAllowListProtection}.
     * <p>
     * Expects (table, key, value) on the stack. If the key is in {@code ALLOW_LISTS}, the
     * assignment is performed with a raw set. Otherwise the assignment is dropped, and a warning
     * is logged unless the key is in {@code DENY_LIST}.
     *
     * @param lua Lua state the handler was invoked on
     * @return always {@code 0}
     */
    private static int luaNewIndexAllowList(Lua lua) {
        if (lua.getTop() != 3) {
            lua.error("Wrong number of arguments to luaNewIndexAllowList");
        }
        if (!lua.isTable(-3)) {
            lua.error("first argument to luaNewIndexAllowList must be a table");
        }
        boolean keyHasName = lua.isString(-2) || lua.isNumber(-2);
        if (!keyHasName) {
            lua.error("Second argument to luaNewIndexAllowList must be a string or number");
        }
        String key = lua.toString(-2);
        if (ALLOW_LISTS.contains(key)) {
            // Key and value are already on top of the stack; store them bypassing metamethods.
            lua.rawSet(-3);
            return 0;
        }
        if (!DENY_LIST.contains(key)) {
            Log.SERVER.warnf("A key '%s' was added to Lua globals which is not on the globals allow list nor listed on the deny list.", key);
        }
        return 0;
    }

    /**
     * Installs {@code luaNewIndexAllowList} as a {@code __newindex} metamethod.
     * <p>
     * Stack sequence: push the long {@code -10002}; push a new table (the metatable); store the
     * handler in its {@code __newindex} field; {@code setMetatable(-2)} attaches the table to the
     * value pushed first; finally that value is popped, leaving the stack as it was.
     * <p>
     * NOTE(review): {@code push(long)} pushes the integer -10002 itself, not the table found at
     * pseudo-index -10002 (which is LUA_GLOBALSINDEX in Lua 5.1). If the intent is to guard
     * assignments to globals this probably wants {@code pushValue(-10002)} - confirm.
     */
    private void luaSetAllowListProtection() {
        this.lua.push(-10002L);
        this.lua.newTable();
        this.lua.push(LuaContext::luaNewIndexAllowList);
        this.lua.setField(-2, "__newindex");
        this.lua.setMetatable(-2);
        this.lua.pop(1);
    }

    /**
     * Compiles a script and caches the resulting function under a name derived from the script's
     * SHA ({@code LuaTaskEngine.fName(code.sha())}). Does nothing if an entry with that name
     * already exists.
     * <p>
     * The source is encoded as UTF-8 before it is handed to the Lua loader. The previous US-ASCII
     * encoding replaced every non-ASCII character (for example inside string literals) with '?',
     * and disagreed with {@link #sha1hex(String)}, which hashes UTF-8 bytes.
     * <p>
     * The cache is the table at pseudo-index {@code -10000} (LUA_REGISTRYINDEX in Lua 5.1).
     *
     * @param code script whose source and SHA are used
     */
    void registerScript(LuaCode code) {
        String name = LuaTaskEngine.fName(code.sha());
        this.lua.getField(-10000, name);
        // get() pops the looked-up value, so the stack is balanced on both paths.
        if (this.lua.get().type() == Lua.LuaType.NIL) {
            byte[] bytes = code.code().getBytes(StandardCharsets.UTF_8);
            ByteBuffer buffer = ByteBuffer.allocateDirect(bytes.length);
            buffer.put(bytes);
            this.lua.load((Buffer) buffer, "@user_script");
            // Pops the compiled function and stores it under the SHA-derived name.
            this.lua.setField(-10000, name);
        }
    }

    /**
     * Removes a script registered by {@code registerScript} by assigning nil to its SHA-derived
     * name in the table at pseudo-index {@code -10000}.
     *
     * @param code script whose SHA identifies the entry to clear
     */
    void unregisterScript(LuaCode code) {
        this.lua.pushNil();
        // setField pops the nil and stores it, clearing the entry.
        this.lua.setField(-10000, LuaTaskEngine.fName(code.sha()));
    }

    static {
        DENY_LIST = Set.of("dofile", "loadfile", "print");
        LEVEL_MAP = new Logger.Level[]{Logger.Level.TRACE, Logger.Level.DEBUG, Logger.Level.INFO, Logger.Level.WARN};
        ALLOW_LISTS = new HashSet<String>();
        ALLOW_LISTS.addAll(LIBRARIES_ALLOW_LIST);
        ALLOW_LISTS.addAll(REDIS_API_ALLOW_LIST);
        ALLOW_LISTS.addAll(LUA_BUILTINS_ALLOW_LIST);
        ALLOW_LISTS.addAll(LUA_BUILTINS_REMOVED_AFTER_INITIALIZATION_ALLOW_LIST);
    }

    /**
     * Mode a context is in. {@code USER} is the value new contexts start with; this class never
     * reads the field, so how the two modes differ is decided by code outside this file.
     */
    public static enum Mode {
        USER,
        // NOTE(review): presumably used while scripts are being loaded - confirm against callers.
        LOAD;

    }
}

