/*
 * Decompiled with CFR 0.152.
 */
package com.github.microwww.redis.protocal.operation;

import com.github.microwww.redis.RequestParams;
import com.github.microwww.redis.database.HashKey;
import com.github.microwww.redis.database.Member;
import com.github.microwww.redis.database.RedisDatabase;
import com.github.microwww.redis.database.SortedSetData;
import com.github.microwww.redis.protocal.AbstractOperation;
import com.github.microwww.redis.protocal.RedisRequest;
import com.github.microwww.redis.protocal.ScanIterator;
import com.github.microwww.redis.util.Assert;
import com.github.microwww.redis.util.SafeEncoder;
import java.io.IOException;
import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.NavigableSet;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.function.BinaryOperator;

/**
 * Handlers for the Redis sorted-set commands (ZADD, ZCARD, ZCOUNT, ZINCRBY, ZRANGE, ZRANGEBYSCORE,
 * ZRANK, ZREM, ZREMRANGEBYRANK, ZREMRANGEBYSCORE, ZREVRANGE, ZREVRANGEBYSCORE, ZREVRANK, ZSCORE,
 * ZUNIONSTORE, ZINTERSTORE, ZSCAN).
 *
 * <p>In every handler {@code request.getParams()[0]} is the key; the reply is written through
 * {@code request.getOutputProtocol()}.
 */
public class SortedSetOperation extends AbstractOperation {

    /**
     * ZADD key score member [score member ...].
     *
     * <p>Replies with the count returned by {@code SortedSetData.addOrReplace}.
     */
    public void zadd(RedisRequest request) throws IOException {
        request.expectArgumentsCountGE(3);
        RequestParams[] args = request.getParams();
        // After the key, arguments must come in complete (score, member) pairs; otherwise the
        // loop below would read past the end of args.
        Assert.isTrue(args.length % 2 == 1, "score member pairs count error");
        Member[] ms = new Member[args.length / 2];
        for (int i = 1, j = 0; i < args.length; i += 2, ++j) {
            BigDecimal score = args[i].byteArray2decimal();
            byte[] member = args[i + 1].getByteArray();
            ms[j] = new Member(member, score);
        }
        // Create the key only after every score parsed, so a bad argument leaves no empty set.
        int count = this.getOrCreate(request).addOrReplace(ms);
        request.getOutputProtocol().writer(count);
    }

    /** ZCARD key: replies with the number of members, or 0 when the key is absent. */
    public void zcard(RedisRequest request) throws IOException {
        request.expectArgumentsCountGE(1);
        int size = this.getData(request).map(e -> memberSet(e).size()).orElse(0);
        request.getOutputProtocol().writer(size);
    }

    /**
     * ZCOUNT key min max: counts members whose score lies between the two bounds.
     *
     * <p>A bound prefixed with {@code (} is exclusive (see {@link Interval}).
     */
    public void zcount(RedisRequest request) throws IOException {
        request.expectArgumentsCount(3);
        RequestParams[] args = request.getParams();
        Optional<SortedSetData> ss = this.getData(request);
        int size = 0;
        if (ss.isPresent()) {
            Interval min = new Interval(args[1].getByteArray());
            Interval max = new Interval(args[2].getByteArray());
            // NavigableSet.subSet throws IllegalArgumentException when from > to; an inverted
            // range simply contains nothing.
            if (min.val.compareTo(max.val) <= 0) {
                long count = memberSet(ss.get())
                        .subSet(Member.MIN(min.val), Member.MAX(max.val))
                        .stream()
                        .filter(min::filterEqual)
                        .filter(max::filterEqual)
                        .count();
                size = (int) count;
            }
        }
        request.getOutputProtocol().writer(size);
    }

    /** ZINCRBY key increment member: replies with the new score as a plain decimal string. */
    public void zincrby(RedisRequest request) throws IOException {
        request.expectArgumentsCount(3);
        RequestParams[] args = request.getParams();
        BigDecimal inc = args[1].byteArray2decimal();
        byte[] val = args[2].getByteArray();
        // Parse first, then create the key, so an invalid increment leaves no empty set.
        Member mem = this.getOrCreate(request).inc(val, inc);
        request.getOutputProtocol().writer(mem.getScore().toPlainString());
    }

    /** ZRANGE key start stop [WITHSCORES]. */
    public void zrange(RedisRequest request) throws IOException {
        this.rangeByRank(request, false);
    }

    /**
     * Shared body of ZRANGE / ZREVRANGE.
     *
     * <p>Scores are included when a fourth argument is present.
     *
     * @param desc {@code true} to use {@code SortedSetData.revRange}, otherwise {@code range}
     */
    private void rangeByRank(RedisRequest request, boolean desc) throws IOException {
        request.expectArgumentsCountGE(3);
        RequestParams[] args = request.getParams();
        Optional<SortedSetData> ss = this.getData(request);
        List<byte[]> list = new ArrayList<>();
        if (ss.isPresent()) {
            int start = args[1].byteArray2int();
            int stop = args[2].byteArray2int();
            List<Member> mem = desc ? ss.get().revRange(start, stop) : ss.get().range(start, stop);
            boolean withScores = args.length == 4;
            for (Member m : mem) {
                list.add(m.getMember());
                if (withScores) {
                    list.add(scoreBytes(m));
                }
            }
        }
        writeByteArrays(request, list);
    }

    /** ZRANGEBYSCORE key min max [WITHSCORES] [LIMIT offset count]. */
    public void zrangebyscore(RedisRequest request) throws IOException {
        this.rangeByScore(request, false);
    }

    /**
     * Shared body of ZRANGEBYSCORE / ZREVRANGEBYSCORE.
     *
     * @param desc passed to {@code SortedSetData.getSubSetData} to select reverse order
     */
    private void rangeByScore(RedisRequest request, boolean desc) throws IOException {
        request.expectArgumentsCountGE(3);
        RequestParams[] args = request.getParams();
        Optional<SortedSetData> ss = this.getData(request);
        List<byte[]> list = new ArrayList<>();
        if (ss.isPresent()) {
            Interval start = new Interval(args[1].getByteArray());
            Interval stop = new Interval(args[2].getByteArray());
            RangeByScore spm = new RangeByScore();
            // Each option handler returns the index of the last token it consumed.
            for (int i = 3; i < args.length; ++i) {
                String op = args[i].getByteArray2string();
                i = RangeByScoreParams.valueOf(op.toUpperCase(Locale.ROOT)).next(spm, args, i);
            }
            int offset = spm.getOffset();
            // Redis LIMIT semantics: a negative offset yields nothing, a negative count means
            // "no limit". Widen to long so offset + count cannot overflow int.
            long end = spm.getCount() < 0 ? Long.MAX_VALUE : (long) offset + spm.getCount();
            if (offset >= 0) {
                long index = 0;
                for (Object o : ss.get().getSubSetData(desc, start, stop)) {
                    Member e = (Member) o;
                    if (!start.filterEqual(e) || !stop.filterEqual(e)) {
                        continue;
                    }
                    if (index >= end) {
                        break; // limit reached, no need to walk the rest of the range
                    }
                    if (index >= offset) {
                        list.add(e.getMember());
                        if (spm.isWithScores()) {
                            list.add(scoreBytes(e));
                        }
                    }
                    ++index;
                }
            }
        }
        writeByteArrays(request, list);
    }

    /** ZRANK key member: replies with the 0-based ascending rank, or null when absent. */
    public void zrank(RedisRequest request) throws IOException {
        this.rank(request, false);
    }

    /**
     * Shared body of ZRANK / ZREVRANK.
     *
     * <p>The rank is the number of members ordered before {@code member}; with {@code desc} the
     * set is viewed through {@link NavigableSet#descendingSet()}.
     */
    private void rank(RedisRequest request, boolean desc) throws IOException {
        request.expectArgumentsCount(2);
        RequestParams[] args = request.getParams();
        Optional<SortedSetData> ss = this.getData(request);
        if (ss.isPresent()) {
            Optional<Member> member = ss.get().member(args[1].byteArray2hashKey());
            if (member.isPresent()) {
                NavigableSet<Member> members = memberSet(ss.get());
                if (desc) {
                    members = members.descendingSet();
                }
                request.getOutputProtocol().writer(members.headSet(member.get()).size());
                return;
            }
        }
        request.getOutputProtocol().writerNull();
    }

    /** ZREM key member [member ...]: replies with the count returned by {@code removeAll}. */
    public void zrem(RedisRequest request) throws IOException {
        request.expectArgumentsCountGE(2);
        RequestParams[] args = request.getParams();
        Optional<SortedSetData> ss = this.getData(request);
        int count = 0;
        if (ss.isPresent()) {
            List<HashKey> list = new ArrayList<>(args.length - 1);
            for (int i = 1; i < args.length; ++i) {
                list.add(args[i].byteArray2hashKey());
            }
            count = ss.get().removeAll(list);
        }
        request.getOutputProtocol().writer(count);
    }

    /** ZREMRANGEBYRANK key start stop. */
    public void zremrangebyrank(RedisRequest request) throws IOException {
        request.expectArgumentsCount(3);
        RequestParams[] args = request.getParams();
        Optional<SortedSetData> ss = this.getData(request);
        int count = 0;
        if (ss.isPresent()) {
            count = ss.get().remRangeByRank(args[1].byteArray2int(), args[2].byteArray2int());
        }
        request.getOutputProtocol().writer(count);
    }

    /** ZREMRANGEBYSCORE key min max; bounds use the {@link Interval} syntax. */
    public void zremrangebyscore(RedisRequest request) throws IOException {
        request.expectArgumentsCount(3);
        RequestParams[] args = request.getParams();
        Optional<SortedSetData> ss = this.getData(request);
        int count = 0;
        if (ss.isPresent()) {
            Interval min = new Interval(args[1].getByteArray());
            Interval max = new Interval(args[2].getByteArray());
            count = ss.get().remRangeByScore(min, max);
        }
        request.getOutputProtocol().writer(count);
    }

    /** ZREVRANGE key start stop [WITHSCORES]. */
    public void zrevrange(RedisRequest request) throws IOException {
        this.rangeByRank(request, true);
    }

    /** ZREVRANGEBYSCORE key max min [WITHSCORES] [LIMIT offset count]. */
    public void zrevrangebyscore(RedisRequest request) throws IOException {
        this.rangeByScore(request, true);
    }

    /** ZREVRANK key member: replies with the 0-based descending rank, or null when absent. */
    public void zrevrank(RedisRequest request) throws IOException {
        this.rank(request, true);
    }

    /** ZSCORE key member: replies with the score as a plain string, or null when absent. */
    public void zscore(RedisRequest request) throws IOException {
        request.expectArgumentsCount(2);
        RequestParams[] args = request.getParams();
        Optional<SortedSetData> ss = this.getData(request);
        if (ss.isPresent()) {
            Optional<Member> member = ss.get().member(args[1].byteArray2hashKey());
            if (member.isPresent()) {
                request.getOutputProtocol().writer(member.get().getScore().toPlainString());
                return;
            }
        }
        request.getOutputProtocol().writerNull();
    }

    /** ZUNIONSTORE destination numkeys key [key ...] [WEIGHTS ...] [AGGREGATE SUM|MIN|MAX]. */
    public void zunionstore(RedisRequest request) throws IOException {
        // The destination is obtained only after the arguments parsed successfully.
        this.storeFromSortedSet(request,
                (db, param) -> this.getOrCreate(request).unionStore(request.getDatabase(), param));
    }

    /**
     * Parses the shared ZUNIONSTORE / ZINTERSTORE arguments into a {@link UnionStore}.
     *
     * <p>It then replies with the integer returned by {@code fun}.
     *
     * @param fun receives the request database and the parsed parameters
     */
    public void storeFromSortedSet(RedisRequest request, BiFunction<RedisDatabase, UnionStore, Integer> fun) throws IOException {
        request.expectArgumentsCountGE(3);
        RequestParams[] args = request.getParams();
        int i = 1;
        int num = args[i++].byteArray2int();
        // Written as a subtraction so a huge numkeys cannot overflow past the check.
        Assert.isTrue(num > 0 && num <= args.length - i, "num-keys count error");
        HashKey[] hks = new HashKey[num];
        for (int j = 0; j < num; ++j, ++i) {
            hks[j] = args[i].byteArray2hashKey();
        }
        UnionStore us = new UnionStore(hks);
        // Each option handler returns the index of the last token it consumed.
        for (; i < args.length; ++i) {
            String op = args[i].getByteArray2string();
            i = UnionStoreParam.valueOf(op.toUpperCase(Locale.ROOT)).next(us, args, i);
        }
        int count = fun.apply(request.getDatabase(), us);
        request.getOutputProtocol().writer(count);
    }

    /** ZINTERSTORE destination numkeys key [key ...] [WEIGHTS ...] [AGGREGATE SUM|MIN|MAX]. */
    public void zinterstore(RedisRequest request) throws IOException {
        // The destination is obtained only after the arguments parsed successfully.
        this.storeFromSortedSet(request,
                (db, param) -> this.getOrCreate(request).interStore(request.getDatabase(), param));
    }

    /** ZSCAN key cursor [...]: iterates members, emitting each member followed by its score. */
    public void zscan(RedisRequest request) throws IOException {
        Optional<SortedSetData> opt = this.getData(request);
        NavigableSet<Member> hk = opt.map(SortedSetOperation::memberSet)
                .orElse(Collections.emptyNavigableSet());
        Iterator<Member> iterator = hk.iterator();
        new ScanIterator(request, 1).skip(iterator)
                .continueWrite(iterator, e -> ((Member) e).getMember(), e -> scoreBytes((Member) e));
    }

    /** Delegates to {@code RedisDatabase.getOrCreate} with {@code SortedSetData::new} for params[0]. */
    private SortedSetData getOrCreate(RedisRequest request) {
        HashKey key = new HashKey(request.getParams()[0].getByteArray());
        return request.getDatabase().getOrCreate(key, SortedSetData::new);
    }

    /** Looks up params[0] in the request database as a {@link SortedSetData}. */
    private Optional<SortedSetData> getData(RedisRequest request) {
        HashKey hk = request.getParams()[0].byteArray2hashKey();
        return request.getDatabase().get(hk, SortedSetData.class);
    }

    /** Typed view of the member set backing {@code data}. */
    @SuppressWarnings("unchecked")
    private static NavigableSet<Member> memberSet(SortedSetData data) {
        return (NavigableSet<Member>) data.getData();
    }

    /** Encodes a member's score as its plain decimal string. */
    private static byte[] scoreBytes(Member m) {
        // toPlainString only produces ASCII; the charset is explicit so the platform default is irrelevant.
        return m.getScore().toPlainString().getBytes(StandardCharsets.UTF_8);
    }

    /** Writes {@code list} as a multi-bulk reply. */
    private static void writeByteArrays(RedisRequest request, List<byte[]> list) throws IOException {
        request.getOutputProtocol().writerMulti(list.toArray(new byte[0][]));
    }

    /**
     * Score combiners selectable via AGGREGATE.
     *
     * <p>{@code d2} must be non-null. When {@code d1} is null, {@code d2} is returned unchanged.
     */
    public static enum Aggregate implements BinaryOperator<BigDecimal>
    {
        SUM{

            @Override
            public BigDecimal apply(BigDecimal d1, BigDecimal d2) {
                Assert.isNotNull(d2, "Not null");
                if (d1 == null) {
                    return d2;
                }
                return d1.add(d2);
            }
        }
        ,
        MIN{

            @Override
            public BigDecimal apply(BigDecimal d1, BigDecimal d2) {
                Assert.isNotNull(d2, "Not null");
                if (d1 == null) {
                    return d2;
                }
                return d1.compareTo(d2) > 0 ? d2 : d1;
            }
        }
        ,
        MAX{

            @Override
            public BigDecimal apply(BigDecimal d1, BigDecimal d2) {
                Assert.isNotNull(d2, "Not null");
                if (d1 == null) {
                    return d2;
                }
                return d1.compareTo(d2) > 0 ? d1 : d2;
            }
        };

    }

    /**
     * Option parsers for ZUNIONSTORE / ZINTERSTORE.
     *
     * <p>{@code next} receives the index {@code i} of the option keyword. It returns the index of
     * the last token the option consumed; the caller then advances by one.
     */
    public static enum UnionStoreParam {
        WEIGHTS{

            @Override
            public int next(UnionStore params, RequestParams[] args, int i) {
                int n = params.getHashKeys().length;
                Assert.isTrue(args.length > i + n, "WEIGHTS  count error");
                int[] w = new int[n];
                // One weight per source key, immediately following the WEIGHTS keyword.
                for (int j = 0; j < n; ++j) {
                    w[j] = args[i + 1 + j].byteArray2int();
                }
                params.setWeights(w);
                return i + n;
            }
        }
        ,
        AGGREGATE{

            @Override
            public int next(UnionStore params, RequestParams[] args, int i) {
                Assert.isTrue(args.length > i + 1, "AGGREGATE type error");
                Aggregate agg = Aggregate.valueOf(args[i + 1].getByteArray2string().toUpperCase(Locale.ROOT));
                params.setType(agg);
                return i + 1;
            }
        };


        public abstract int next(UnionStore var1, RequestParams[] var2, int var3);
    }

    /**
     * Parsed ZUNIONSTORE / ZINTERSTORE parameters.
     *
     * <p>Weights default to 1 per source key; the aggregate defaults to SUM.
     */
    public static class UnionStore {
        private final HashKey[] hashKeys;
        private int[] weights;
        private Aggregate type = Aggregate.SUM;

        public UnionStore(HashKey[] hashKeys) {
            this.hashKeys = hashKeys;
            this.weights = new int[this.getHashKeys().length];
            Arrays.fill(this.weights, 1);
        }

        public HashKey[] getHashKeys() {
            return this.hashKeys;
        }

        public int[] getWeights() {
            return this.weights;
        }

        public void setWeights(int[] weights) {
            this.weights = weights;
        }

        public Aggregate getType() {
            return this.type;
        }

        public void setType(Aggregate type) {
            this.type = type;
        }
    }

    /**
     * Option parsers for ZRANGEBYSCORE / ZREVRANGEBYSCORE.
     *
     * <p>{@code next} receives the index {@code i} of the option keyword. It returns the index of
     * the last token the option consumed; the caller then advances by one.
     */
    public static enum RangeByScoreParams {
        WITHSCORES{

            @Override
            public int next(RangeByScore params, RequestParams[] args, int i) {
                params.withScores = true;
                // The keyword is the only token consumed; returning i + 1 would skip the next option.
                return i;
            }
        }
        ,
        LIMIT{

            @Override
            public int next(RangeByScore params, RequestParams[] args, int i) {
                Assert.isTrue(args.length > i + 2, "LIMIT offset count error");
                params.offset = args[i + 1].byteArray2int();
                params.count = args[i + 2].byteArray2int();
                return i + 2;
            }
        };


        public abstract int next(RangeByScore var1, RequestParams[] var2, int var3);
    }

    /**
     * Parsed ZRANGEBYSCORE options.
     *
     * <p>Defaults: no scores, offset 0, count {@code Integer.MAX_VALUE}.
     */
    public static class RangeByScore {
        private boolean withScores = false;
        private int offset = 0;
        private int count = Integer.MAX_VALUE;

        public boolean isWithScores() {
            return this.withScores;
        }

        public void setWithScores(boolean withScores) {
            this.withScores = withScores;
        }

        public int getOffset() {
            return this.offset;
        }

        public void setOffset(int offset) {
            Assert.isTrue(offset >= 0, "Offset >= 0");
            this.offset = offset;
        }

        public int getCount() {
            return this.count;
        }

        public void setCount(int count) {
            Assert.isTrue(count > 0, "count > 0");
            this.count = count;
        }
    }

    /**
     * A score bound parsed from text such as {@code 1.5} (inclusive) or {@code (1.5} (exclusive).
     *
     * <p>NOTE(review): {@code -inf}/{@code +inf} are not accepted. {@code new BigDecimal} throws
     * {@link NumberFormatException} for them.
     */
    public static class Interval {
        public final boolean open;
        public final BigDecimal val;

        public Interval(byte[] bytes) {
            String code = SafeEncoder.encode(bytes).trim();
            this.open = code.startsWith("(");
            if (this.open) {
                code = code.substring(1);
            }
            this.val = new BigDecimal(code);
        }

        /**
         * Returns {@code false} only when this bound is exclusive and {@code e}'s score equals it.
         */
        public boolean filterEqual(Member e) {
            if (this.open) {
                return this.val.compareTo(e.getScore()) != 0;
            }
            return true;
        }
    }
}

