/*
 * Copyright (c) 2017 Yrom Wang <http://www.yrom.net>
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.instabug.bug.internal.video.customencoding;

import static android.media.MediaCodec.BUFFER_FLAG_END_OF_STREAM;
import static android.media.MediaCodec.BUFFER_FLAG_KEY_FRAME;
import static android.media.MediaCodec.INFO_OUTPUT_FORMAT_CHANGED;
import static android.os.Build.VERSION_CODES.N;

import android.annotation.SuppressLint;
import android.annotation.TargetApi;
import android.media.AudioFormat;
import android.media.AudioRecord;
import android.media.MediaCodec;
import android.media.MediaFormat;
import android.media.MediaRecorder;
import android.os.Build;
import android.os.Handler;
import android.os.HandlerThread;
import android.os.Looper;
import android.os.Message;
import android.os.SystemClock;

import androidx.annotation.NonNull;
import androidx.annotation.Nullable;

import com.instabug.library.Constants;
import com.instabug.library.util.InstabugSDKLogger;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.Locale;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * Captures PCM audio from the microphone with {@link AudioRecord} and feeds it into an
 * {@link AudioEncoder}.
 *
 * <p>All {@link AudioRecord} and encoder work runs on a private {@link HandlerThread}
 * (driven by {@link RecordHandler}). Encoder callbacks are re-posted to the {@link Looper} of
 * the thread that called {@link #prepare()} through {@link CallbackDelegate}.
 */
@TargetApi(Build.VERSION_CODES.JELLY_BEAN)
class MicRecorder implements Encoder {
    private static final String TAG = "MicRecorder";
    private static final int MSG_PREPARE = 0;
    private static final int MSG_FEED_INPUT = 1;
    private static final int MSG_DRAIN_OUTPUT = 2;
    private static final int MSG_RELEASE_OUTPUT = 3;
    private static final int MSG_STOP = 4;
    private static final int MSG_RELEASE = 5;
    private final AudioEncoder mEncoder;
    private final HandlerThread mRecordThread;
    @Nullable
    private RecordHandler mRecordHandler;
    @Nullable
    private AudioRecord mMic; // access in mRecordThread only!
    private final int mSampleRate;
    private final int mChannelConfig;
    private final int mFormat = AudioFormat.ENCODING_PCM_16BIT;
    private final AtomicBoolean mForceStop = new AtomicBoolean(false);
    @Nullable
    private BaseEncoder.Callback mCallback;
    @Nullable
    private CallbackDelegate mCallbackDelegate;
    // sample rate multiplied by channel count, i.e. interleaved samples per second
    private final int mChannelsSampleRate;
    // frame duration (us) keyed by the frame's sample count, so it is computed once per size
    private final LinkedHashMap<Integer, Long> mFramesUsCache = new LinkedHashMap<>(2);
    // end time (us) of the previously stamped frame; -1 until the first frame is stamped
    private long mLastFrameUs = -1; // access in mRecordThread only!

    MicRecorder(AudioEncodeConfig config) {
        mEncoder = new AudioEncoder(config);
        mSampleRate = config.getSamplingRate();
        mChannelsSampleRate = mSampleRate * config.getChannelCount();
        mChannelConfig = config.getChannelCount() == 2 ? AudioFormat.CHANNEL_IN_STEREO : AudioFormat.CHANNEL_IN_MONO;
        mRecordThread = new HandlerThread(TAG);
    }

    /**
     * Creates a microphone {@link AudioRecord} with twice the minimum buffer size.
     *
     * @return an initialized record, or {@code null} if the arguments are rejected, the record
     * fails to initialize, or the RECORD_AUDIO permission is missing (each case is logged)
     */
    @Nullable
    private AudioRecord createAudioRecord(int sampleRateInHz, int channelConfig, int audioFormat) {
        int minBytes = AudioRecord.getMinBufferSize(sampleRateInHz, channelConfig, audioFormat);
        if (minBytes <= 0) {
            InstabugSDKLogger.e(Constants.LOG_TAG, String.format(Locale.US, "Bad arguments: getMinBufferSize(%d, %d, %d)",
                    sampleRateInHz, channelConfig, audioFormat));
            return null;
        }
        AudioRecord record;
        try {
            record = new AudioRecord(MediaRecorder.AudioSource.MIC,
                    sampleRateInHz,
                    channelConfig,
                    audioFormat,
                    minBytes * 2);
        } catch (SecurityException e) {
            // RECORD_AUDIO permission not granted
            InstabugSDKLogger.e(Constants.LOG_TAG, "RECORD_AUDIO permission not granted " + e.getMessage());
            return null;
        } catch (IllegalArgumentException e) {
            InstabugSDKLogger.e(Constants.LOG_TAG, String.format(Locale.US, "Bad arguments to new AudioRecord %d, %d, %d",
                    sampleRateInHz, channelConfig, audioFormat));
            return null;
        }

        if (record.getState() == AudioRecord.STATE_UNINITIALIZED) {
            InstabugSDKLogger.e(Constants.LOG_TAG, String.format(Locale.US, "Bad arguments to new AudioRecord %d, %d, %d",
                    sampleRateInHz, channelConfig, audioFormat));
            // free the native resources held by the failed instance
            record.release();
            return null;
        }
        if (Build.VERSION.SDK_INT >= N) {
            InstabugSDKLogger.d(Constants.LOG_TAG, " size in frame " + record.getBufferSizeInFrames());
        }
        return record;
    }

    /**
     * Sets the callback that receives encoder events.
     *
     * @throws ClassCastException if {@code callback} is not a {@link BaseEncoder.Callback}
     */
    @Override
    public void setCallback(Callback callback) {
        this.mCallback = (BaseEncoder.Callback) callback;
    }

    /** Sets the callback that receives encoder events. */
    public void setCallback(BaseEncoder.Callback callback) {
        this.mCallback = callback;
    }

    /**
     * Starts the record thread and asynchronously creates the mic and prepares the encoder.
     * Must be called on a thread with a {@link Looper}; callbacks are delivered on that looper.
     *
     * @throws NullPointerException if the calling thread has no {@link Looper}
     */
    @Override
    public void prepare() throws IOException {
        Looper myLooper = Looper.myLooper();
        if (myLooper == null) {
            throw new NullPointerException("Should prepare in HandlerThread");
        }
        // run callback in caller thread
        mCallbackDelegate = new CallbackDelegate(myLooper, mCallback);
        mRecordThread.start();
        mRecordHandler = new RecordHandler(mRecordThread.getLooper());
        mRecordHandler.sendEmptyMessage(MSG_PREPARE);
    }

    /**
     * Drops pending callbacks, stops feeding input, and asks the record thread to stop the mic
     * and the encoder.
     */
    @Override
    @SuppressLint("ERADICATE_PARAMETER_NOT_NULLABLE")
    public void stop() {
        // clear callback queue
        if (mCallbackDelegate != null) mCallbackDelegate.removeCallbacksAndMessages(null);
        mForceStop.set(true);
        if (mRecordHandler != null) mRecordHandler.sendEmptyMessage(MSG_STOP);
    }

    /**
     * Releases the mic and the encoder on the record thread, then terminates that thread.
     */
    @Override
    public void release() {
        if (mRecordHandler != null) {
            // The record thread quits its own looper once MSG_RELEASE has been handled.
            // Quitting from here would discard the still-pending MSG_RELEASE and leak
            // the AudioRecord and the encoder.
            mRecordHandler.sendEmptyMessage(MSG_RELEASE);
        } else {
            mRecordThread.quit();
        }
    }

    /** Hands an output buffer obtained from a callback back to the encoder (on the record thread). */
    void releaseOutputBuffer(int index) {
        if (mRecordHandler != null)
            Message.obtain(mRecordHandler, MSG_RELEASE_OUTPUT, index, 0).sendToTarget();
    }

    @Nullable
    ByteBuffer getOutputBuffer(int index) {
        return mEncoder.getOutputBuffer(index);
    }

    /**
     * Reads one frame of mic data into encoder input buffer {@code index} and queues it.
     * If the mic is already stopped, the buffer is queued empty with the end-of-stream flag.
     *
     * <p>NOTE: input is only requested while at most one output buffer is still held by the
     * consumer (see {@code RecordHandler#pollInputIfNeed}).
     *
     * @throws NullPointerException if the mic has already been released
     */
    private void feedAudioEncoder(int index) {
        if (index < 0 || mForceStop.get()) return;
        final AudioRecord record = mMic;
        if (record == null) {
            throw new NullPointerException("maybe release");
        }

        final boolean eos = record.getRecordingState() == AudioRecord.RECORDSTATE_STOPPED;
        final ByteBuffer frame = mEncoder.getInputBuffer(index);
        final int offset = 0;
        int read = 0;
        if (frame != null && !eos) {
            // negative values are AudioRecord error codes; queue an empty buffer instead
            read = Math.max(0, record.read(frame, frame.limit()));
        }

        // read is in bytes; the timestamp helper expects bits
        long pstTs = calculateFrameTimestamp(read << 3);
        int flags = 0;
        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.LOLLIPOP) {
            flags = BUFFER_FLAG_KEY_FRAME;
        }
        if (eos) {
            flags = BUFFER_FLAG_END_OF_STREAM;
        }
        // feed frame to encoder
        mEncoder.queueInputBuffer(index, offset, read, pstTs, flags);
    }

    /**
     * Computes the presentation time, in microseconds, of the frame just read from the mic.
     * One sample is 16 bits ({@link AudioFormat#ENCODING_PCM_16BIT}).
     *
     * <p>Consecutive frames are stamped back to back (previous start + previous duration). If
     * reading has fallen behind the clock by two frame durations or more, the timestamp is
     * re-anchored to the clock.
     *
     * @param totalBits size of the frame in bits
     * @return presentation time in microseconds on the {@link SystemClock#elapsedRealtime()} timebase
     */
    private long calculateFrameTimestamp(int totalBits) {
        int samples = totalBits >> 4;
        Long cachedFrameUs = mFramesUsCache.get(samples);
        long frameUs;
        if (cachedFrameUs != null) {
            frameUs = cachedFrameUs;
        } else {
            // long arithmetic: samples * 1_000_000 overflows an int once samples > 2147
            frameUs = samples * 1_000_000L / mChannelsSampleRate;
            mFramesUsCache.put(samples, frameUs);
        }
        // elapsedRealtime() is in milliseconds; convert to microseconds
        long timeUs = SystemClock.elapsedRealtime() * 1000L;
        // accounts the delay of polling the audio sample data
        timeUs -= frameUs;

        // the first frame is anchored to the clock; later frames follow the previous one
        long currentUs = mLastFrameUs == -1 ? timeUs : mLastFrameUs;

        // maybe too late to acquire sample data
        if (timeUs - currentUs >= (frameUs << 1)) {
            // reset
            currentUs = timeUs;
        }
        mLastFrameUs = currentUs + frameUs;
        return currentUs;
    }

    /**
     * Re-posts encoder events to the looper it was created with (the thread that called
     * {@link #prepare()}), invoking the callback only if one was set.
     */
    private static class CallbackDelegate extends Handler {
        @Nullable
        private final BaseEncoder.Callback mCallback;

        CallbackDelegate(@NonNull Looper l, @Nullable BaseEncoder.Callback callback) {
            super(l);
            this.mCallback = callback;
        }

        void onError(final Encoder encoder, final Exception exception) {
            Message.obtain(this, () -> {
                if (mCallback != null) {
                    mCallback.onError(encoder, exception);
                }
            }).sendToTarget();
        }

        void onOutputFormatChanged(final BaseEncoder encoder, final MediaFormat format) {
            Message.obtain(this, () -> {
                if (mCallback != null) {
                    mCallback.onOutputFormatChanged(encoder, format);
                }
            }).sendToTarget();
        }

        void onOutputBufferAvailable(final BaseEncoder encoder, final int index, final MediaCodec.BufferInfo info) {
            Message.obtain(this, () -> {
                if (mCallback != null) {
                    mCallback.onOutputBufferAvailable(encoder, index, info);
                }
            }).sendToTarget();
        }
    }

    /** Runs the mic / encoder state machine on {@link #mRecordThread}. */
    private class RecordHandler extends Handler {

        // BufferInfo objects that were not handed out and can be reused for the next dequeue
        private final LinkedList<MediaCodec.BufferInfo> mCachedInfos = new LinkedList<>();
        // output buffer indices handed to the callback and not yet released
        private final LinkedList<Integer> mMuxingOutputBufferIndices = new LinkedList<>();
        private final int mPollRate = 2048_000 / mSampleRate; // poll per 2048 samples (ms)

        RecordHandler(Looper l) {
            super(l);
        }

        @Override
        public void handleMessage(Message msg) {
            switch (msg.what) {
                case MSG_PREPARE:
                    AudioRecord r = createAudioRecord(mSampleRate, mChannelConfig, mFormat);
                    if (r == null) {
                        InstabugSDKLogger.e(Constants.LOG_TAG, "create audio record failure");
                        if (mCallbackDelegate != null) {
                            mCallbackDelegate.onError(MicRecorder.this, new IllegalArgumentException());
                        }
                        break;
                    }
                    r.startRecording();
                    mMic = r;
                    try {
                        mEncoder.prepare();
                    } catch (Exception e) {
                        if (mCallbackDelegate != null) {
                            mCallbackDelegate.onError(MicRecorder.this, e);
                        }
                        break;
                    }
                    // fall through: once prepared, start feeding input immediately
                case MSG_FEED_INPUT:
                    if (!mForceStop.get()) {
                        int index = pollInput();
                        if (index >= 0) {
                            feedAudioEncoder(index);
                            // tell encoder to eat the fresh meat!
                            if (!mForceStop.get()) sendEmptyMessage(MSG_DRAIN_OUTPUT);
                        } else {
                            // try later...
                            sendEmptyMessageDelayed(MSG_FEED_INPUT, mPollRate);
                        }
                    }
                    break;
                case MSG_DRAIN_OUTPUT:
                    offerOutput();
                    pollInputIfNeed();
                    break;
                case MSG_RELEASE_OUTPUT:
                    releaseOutput(msg.arg1);
                    pollInputIfNeed();
                    break;
                case MSG_STOP:
                    if (mMic != null) {
                        mMic.stop();
                    }
                    mEncoder.stop();
                    break;
                case MSG_RELEASE:
                    try {
                        if (mMic != null) {
                            mMic.release();
                            mMic = null;
                        }
                        mEncoder.release();
                    } finally {
                        // MicRecorder.release() leaves quitting to this handler so that this
                        // message is not discarded; any later messages are dropped.
                        getLooper().quit();
                    }
                    break;
                default:
                    break;
            }
        }

        /**
         * Returns an output buffer to the encoder. Failures are logged rather than thrown: the
         * consumer may release a buffer after MSG_STOP, when the codec no longer accepts it,
         * and an uncaught exception here would kill the record thread.
         */
        private void releaseOutput(int index) {
            try {
                mEncoder.releaseOutputBuffer(index);
            } catch (Exception exception) {
                InstabugSDKLogger.e(
                        Constants.LOG_TAG,
                        "Something went wrong while calling releaseOutputBuffer. " + exception.getMessage(),
                        exception
                );
            }
            mMuxingOutputBufferIndices.poll(); // Nobody care what it exactly is.
        }

        /**
         * Dequeues every output buffer currently available and forwards each to the callback.
         * Reports format changes. Stops at the first negative index (no output ready, or after a
         * format change) and recycles its BufferInfo. Exceptions are logged, not thrown.
         */
        private void offerOutput() {
            try {
                while (!mForceStop.get()) {
                    MediaCodec.BufferInfo info = mCachedInfos.poll();
                    if (info == null) {
                        info = new MediaCodec.BufferInfo();
                    }
                    int index = mEncoder.getEncoder().dequeueOutputBuffer(info, 1);
                    if (index == INFO_OUTPUT_FORMAT_CHANGED) {
                        if (mCallbackDelegate != null) {
                            mCallbackDelegate.onOutputFormatChanged(mEncoder, mEncoder.getEncoder().getOutputFormat());
                        }
                    }
                    if (index < 0) {
                        info.set(0, 0, 0, 0);
                        mCachedInfos.offer(info);
                        break;
                    }
                    mMuxingOutputBufferIndices.offer(index);
                    if (mCallbackDelegate != null) {
                        mCallbackDelegate.onOutputBufferAvailable(mEncoder, index, info);
                    }
                }
            } catch (Exception exception) {
                InstabugSDKLogger.e(
                        Constants.LOG_TAG,
                        "Something went wrong while calling offerOutput. " + exception.getMessage(),
                        exception
                );
            }
        }

        /** Returns a free encoder input buffer index without waiting, or -1 on failure. */
        private int pollInput() {
            try {
                return mEncoder.getEncoder().dequeueInputBuffer(0);
            } catch (Exception exception) {
                InstabugSDKLogger.e(
                        Constants.LOG_TAG,
                        "Something went wrong while calling dequeueInputBuffer. " + exception.getMessage(),
                        exception
                );
            }
            return -1;
        }

        /**
         * Schedules an immediate MSG_FEED_INPUT (replacing any pending one) when at most one
         * output buffer is outstanding and recording has not been stopped.
         */
        private void pollInputIfNeed() {
            if (mMuxingOutputBufferIndices.size() <= 1 && !mForceStop.get()) {
                // need fresh data, right now!
                removeMessages(MSG_FEED_INPUT);
                sendEmptyMessageDelayed(MSG_FEED_INPUT, 0);
            }
        }
    }

}
