/*
 * Decompiled with CFR 0.152.
 */
package shadow.palantir.driver.com.palantir.dialogue.core;

import com.palantir.logsafe.Preconditions;
import com.palantir.logsafe.SafeArg;
import com.palantir.logsafe.logger.SafeLogger;
import com.palantir.logsafe.logger.SafeLoggerFactory;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import shadow.palantir.driver.com.codahale.metrics.Meter;
import shadow.palantir.driver.com.github.benmanes.caffeine.cache.Ticker;
import shadow.palantir.driver.com.google.common.annotations.VisibleForTesting;
import shadow.palantir.driver.com.google.common.collect.ImmutableList;
import shadow.palantir.driver.com.google.common.primitives.Ints;
import shadow.palantir.driver.com.google.common.util.concurrent.FutureCallback;
import shadow.palantir.driver.com.palantir.dialogue.Response;
import shadow.palantir.driver.com.palantir.dialogue.core.CoarseExponentialDecayReservoir;
import shadow.palantir.driver.com.palantir.dialogue.core.DialogueBalancedMetrics;
import shadow.palantir.driver.com.palantir.dialogue.core.DialogueInternalWeakReducingGauge;
import shadow.palantir.driver.com.palantir.dialogue.core.DialogueRoundrobinMetrics;
import shadow.palantir.driver.com.palantir.dialogue.core.Responses;
import shadow.palantir.driver.com.palantir.tritium.metrics.registry.MetricName;
import shadow.palantir.driver.com.palantir.tritium.metrics.registry.TaggedMetricRegistry;

/**
 * Tracks a load-balancing score for each of a fixed number of channels; lower scores are preferred.
 *
 * <p>A channel's score is its current number of in-flight requests plus the rounded value of its
 * {@link CoarseExponentialDecayReservoir} of recent failures. Channels are shuffled with the supplied
 * {@link Random} before being ordered by score with a stable sort, so channels with equal scores come
 * back in random order.
 *
 * <p>When there are at most {@code MAX_CHANNELS_FOR_PER_HOST_METRICS} channels, a score gauge and a
 * per-host meter are registered; above that limit both are skipped.
 */
final class BalancedScoreTracker {
    private static final SafeLogger log = SafeLoggerFactory.get(BalancedScoreTracker.class);
    private static final Comparator<ScoreSnapshot> BY_SCORE = Comparator.comparingInt(ScoreSnapshot::getScore);
    /** Decay duration handed to each channel's {@link CoarseExponentialDecayReservoir}. */
    private static final Duration FAILURE_MEMORY = Duration.ofSeconds(30L);
    /** Reservoir increment for a thrown failure, a server error, or a QoS status other than too-many-requests. */
    private static final double FAILURE_WEIGHT = 10.0;
    /** Above this many channels, per-host score gauges and meters are not registered. */
    private static final int MAX_CHANNELS_FOR_PER_HOST_METRICS = 10;

    private final ImmutableList<ChannelScoreInfo> channelStats;
    private final Random random;
    private final Ticker clock;

    /**
     * Creates a tracker with one {@link ChannelScoreInfo} per channel and registers per-host gauges.
     *
     * @param channelCount number of channels to track; must be at least 1
     * @param random source of randomness used to order channels with equal scores
     * @param ticker time source passed to every channel's failure reservoir
     * @param taggedMetrics registry that per-host gauges and meters are registered in
     * @param channelName channel name used to tag metrics and log lines
     * @throws IllegalStateException if {@code channelCount < 1}
     */
    BalancedScoreTracker(int channelCount, Random random, Ticker ticker, TaggedMetricRegistry taggedMetrics, String channelName) {
        Preconditions.checkState(channelCount >= 1, "At least one channel required");
        this.random = random;
        this.clock = ticker;
        this.channelStats = IntStream.range(0, channelCount)
                .mapToObj(index -> new ChannelScoreInfo(
                        index, this.clock, PerHostObservability.create(channelCount, taggedMetrics, channelName, index)))
                .collect(ImmutableList.toImmutableList());
        BalancedScoreTracker.registerGauges(taggedMetrics, channelName, this.channelStats);
    }

    /**
     * Computes a fresh snapshot for every channel and returns them sorted by ascending score.
     *
     * <p>The channels are shuffled first and the sort is stable, so ties are broken randomly.
     *
     * @return a new array containing one snapshot per channel, lowest score first
     */
    ScoreSnapshot[] getSnapshotsInOrderOfIncreasingScore() {
        List<ChannelScoreInfo> shuffledMutableStats = BalancedScoreTracker.shuffleImmutableList(this.channelStats, this.random);
        ScoreSnapshot[] snapshotArray = new ScoreSnapshot[shuffledMutableStats.size()];
        for (int i = 0; i < snapshotArray.length; ++i) {
            snapshotArray[i] = shuffledMutableStats.get(i).computeScoreSnapshot();
        }
        Arrays.sort(snapshotArray, BY_SCORE);
        return snapshotArray;
    }

    /**
     * Returns the channel with the lowest current score, breaking ties randomly.
     *
     * <p>This picks the same channel as the first element of {@link #getSnapshotsInOrderOfIncreasingScore()}:
     * the same shuffle is applied, and because that sort is stable its head is the first minimum in shuffled
     * order. A linear scan finds that element without sorting.
     *
     * @return the best-scoring channel; never null because the constructor requires at least one channel
     */
    public ChannelScoreInfo getSingleBestChannelByScore() {
        List<ChannelScoreInfo> shuffled = BalancedScoreTracker.shuffleImmutableList(this.channelStats, this.random);
        ScoreSnapshot best = shuffled.get(0).computeScoreSnapshot();
        for (int i = 1; i < shuffled.size(); ++i) {
            ScoreSnapshot candidate = shuffled.get(i).computeScoreSnapshot();
            // Strictly-less keeps the earliest minimum, matching a stable sort's head.
            if (candidate.getScore() < best.getScore()) {
                best = candidate;
            }
        }
        return best.getDelegate();
    }

    /** Returns the current score of each channel, in channel-index order. */
    @VisibleForTesting
    IntStream getScoresForTesting() {
        return this.channelStats.stream().mapToInt(c -> c.computeScoreSnapshot().getScore());
    }

    /** Returns the per-channel state, in channel-index order. */
    ImmutableList<ChannelScoreInfo> channelStats() {
        return this.channelStats;
    }

    /** Returns a new mutable copy of {@code sourceList}, shuffled with {@code random}. */
    private static <T> List<T> shuffleImmutableList(ImmutableList<T> sourceList, Random random) {
        ArrayList<T> mutableList = new ArrayList<T>(sourceList);
        Collections.shuffle(mutableList, random);
        return mutableList;
    }

    @Override
    public String toString() {
        return "BalancedScoreTracker{" + this.channelStats + "}";
    }

    /**
     * Registers one score gauge per host, unless there are more than
     * {@code MAX_CHANNELS_FOR_PER_HOST_METRICS} channels.
     *
     * <p>If several objects contribute to the same gauge name, the reported value is their average.
     */
    private static void registerGauges(TaggedMetricRegistry taggedMetrics, String channelName, ImmutableList<ChannelScoreInfo> channels) {
        if (channels.size() > MAX_CHANNELS_FOR_PER_HOST_METRICS) {
            log.info("Not registering gauges as there are too many nodes {}", SafeArg.of("count", channels.size()));
            return;
        }
        // Loop-invariant: the metrics factory only depends on the registry.
        DialogueBalancedMetrics balancedMetrics = DialogueBalancedMetrics.of(taggedMetrics);
        for (int hostIndex = 0; hostIndex < channels.size(); ++hostIndex) {
            MetricName metricName = balancedMetrics.score().channelName(channelName).hostIndex(Integer.toString(hostIndex)).buildMetricName();
            DialogueInternalWeakReducingGauge.getOrCreate(taggedMetrics, metricName, c -> c.computeScoreSnapshot().getScore(), longStream -> {
                long[] longs = longStream.toArray();
                if (log.isInfoEnabled() && longs.length > 1 && LongStream.of(longs).distinct().count() > 1L) {
                    log.info("Multiple ({}) objects contribute to the same gauge, taking the average (beware this may be misleading) {} {}", SafeArg.of("count", longs.length), SafeArg.of("metricName", metricName), SafeArg.of("values", Arrays.toString(longs)));
                }
                return Arrays.stream(longs).average().orElse(0.0);
            }, channels.get(hostIndex));
        }
    }

    /**
     * Per-host logging and metrics hooks used by {@link ChannelScoreInfo}.
     *
     * <p>Log lines carry the channel name and host index as safe args.
     */
    public static abstract class PerHostObservability {
        private final SafeArg<String> channelName;
        private final SafeArg<Integer> hostIndex;

        PerHostObservability(String channelName, int hostIndex) {
            this.channelName = SafeArg.of("channelName", channelName);
            this.hostIndex = SafeArg.of("hostIndex", hostIndex);
        }

        /** Called when a request is made on this host; may be a no-op. */
        public abstract void markRequestMade();

        /** Debug-logs a failure recorded from a thrown exception, with the reservoir's current value. */
        void debugLogThrowableFailure(CoarseExponentialDecayReservoir reservoir, Throwable throwable) {
            if (log.isDebugEnabled()) {
                log.debug("Recorded recent failure (throwable)", this.channelName, this.hostIndex, SafeArg.of("recentFailures", reservoir.get()), throwable);
            }
        }

        /** Debug-logs a failure recorded from a response status code. */
        void debugLogStatusFailure(Response response) {
            if (log.isDebugEnabled()) {
                log.debug("Recorded recent failure (status)", this.channelName, this.hostIndex, SafeArg.of("status", response.code()));
            }
        }

        /** Trace-logs the inputs and result of a score computation. */
        void traceLogComputedScore(int inflight, double failures, int score) {
            if (log.isTraceEnabled()) {
                log.trace("Computed score ({} {}) {}", this.channelName, this.hostIndex, SafeArg.of("score", score), SafeArg.of("inflight", inflight), SafeArg.of("failures", failures));
            }
        }

        /**
         * Creates the observability hooks for one host.
         *
         * <p>Above {@code MAX_CHANNELS_FOR_PER_HOST_METRICS} channels, {@link #markRequestMade()} is a no-op.
         * Otherwise it marks a per-host meter from {@code DialogueRoundrobinMetrics}.
         */
        static PerHostObservability create(int numChannels, TaggedMetricRegistry taggedMetrics, String channelName, int index) {
            if (numChannels > MAX_CHANNELS_FOR_PER_HOST_METRICS) {
                return new PerHostObservability(channelName, index){

                    @Override
                    public void markRequestMade() {
                    }
                };
            }
            final Meter meter = DialogueRoundrobinMetrics.of(taggedMetrics).success().channelName(channelName).hostIndex(Integer.toString(index)).build();
            return new PerHostObservability(channelName, index){

                @Override
                public void markRequestMade() {
                    meter.mark();
                }
            };
        }
    }

    /** Immutable point-in-time view of one channel's score and in-flight count. */
    static final class ScoreSnapshot {
        private final int score;
        private final int inflight;
        private final ChannelScoreInfo delegate;

        ScoreSnapshot(int score, int inflight, ChannelScoreInfo delegate) {
            this.score = score;
            this.inflight = inflight;
            this.delegate = delegate;
        }

        int getScore() {
            return this.score;
        }

        int getInflight() {
            return this.inflight;
        }

        ChannelScoreInfo getDelegate() {
            return this.delegate;
        }

        @Override
        public String toString() {
            return "ScoreSnapshot{score=" + this.score + ", delegate=" + this.delegate + "}";
        }
    }

    /**
     * Mutable per-channel state: an in-flight request counter and a decaying reservoir of recent failures.
     *
     * <p>Register an instance as the callback of a request's future. Every completion, whether success or
     * failure, decrements the in-flight counter incremented by {@link #startRequest()}.
     */
    public static final class ChannelScoreInfo
    implements FutureCallback<Response> {
        private final int hostIndex;
        private final PerHostObservability observability;
        private final AtomicInteger inflight = new AtomicInteger(0);
        private final CoarseExponentialDecayReservoir recentFailuresReservoir;

        ChannelScoreInfo(int hostIndex, Ticker clock, PerHostObservability observability) {
            this.hostIndex = hostIndex;
            this.recentFailuresReservoir = new CoarseExponentialDecayReservoir(clock::read, FAILURE_MEMORY);
            this.observability = observability;
        }

        /** Increments the in-flight count; pair with a completion callback or {@link #undoStartRequest()}. */
        public void startRequest() {
            this.inflight.incrementAndGet();
        }

        public PerHostObservability observability() {
            return this.observability;
        }

        /** Reverts a {@link #startRequest()} for a request that was not actually sent. */
        public void undoStartRequest() {
            this.inflight.decrementAndGet();
        }

        /**
         * Records a completed response.
         *
         * <p>Server errors and QoS statuses other than too-many-requests add {@code FAILURE_WEIGHT} to the
         * failure reservoir. Client errors and the remaining QoS statuses add one hundredth of that. Any
         * other response only decrements the in-flight count.
         */
        @Override
        public void onSuccess(Response response) {
            this.inflight.decrementAndGet();
            if (ChannelScoreInfo.isGlobalQosStatus(response) || Responses.isServerErrorRange(response)) {
                this.recentFailuresReservoir.update(FAILURE_WEIGHT);
                this.observability.debugLogStatusFailure(response);
            } else if (Responses.isClientError(response) || Responses.isQosStatus(response)) {
                // Weighted 100x less than server errors (FAILURE_WEIGHT / 100 == 0.1).
                this.recentFailuresReservoir.update(FAILURE_WEIGHT / 100);
                this.observability.debugLogStatusFailure(response);
            }
        }

        /** True for QoS statuses except too-many-requests. */
        private static boolean isGlobalQosStatus(Response response) {
            return Responses.isQosStatus(response) && !Responses.isTooManyRequests(response);
        }

        public int channelIndex() {
            return this.hostIndex;
        }

        /** Records a request that failed with an exception; weighted like a server error. */
        @Override
        public void onFailure(Throwable throwable) {
            this.inflight.decrementAndGet();
            this.recentFailuresReservoir.update(FAILURE_WEIGHT);
            this.observability.debugLogThrowableFailure(this.recentFailuresReservoir, throwable);
        }

        /**
         * Computes {@code inflight + round(recentFailures)}, saturated to the {@code int} range.
         *
         * <p>The sum is done in {@code long} so that a saturated failure term cannot wrap the score negative.
         * A wrapped score would make an unhealthy channel look like the best choice.
         */
        private ScoreSnapshot computeScoreSnapshot() {
            int requestsInflight = this.inflight.get();
            double failureReservoir = this.recentFailuresReservoir.get();
            int failureScore = Ints.saturatedCast(Math.round(failureReservoir));
            int score = Ints.saturatedCast((long) requestsInflight + failureScore);
            this.observability.traceLogComputedScore(requestsInflight, failureReservoir, score);
            return new ScoreSnapshot(score, requestsInflight, this);
        }

        @Override
        public String toString() {
            return "ChannelScoreInfo{hostIndex=" + this.hostIndex + ", inflight=" + this.inflight + ", recentFailures=" + this.recentFailuresReservoir + "}";
        }
    }
}

