# SnapLogic - Data Integration
#
# Copyright (C) 2018, SnapLogic, Inc.  All rights reserved.
#
# This program is licensed under the terms of
# the SnapLogic Commercial Subscription agreement.
#
# 'SnapLogic' is a trademark of SnapLogic, Inc.

# https://docs-snaplogic.atlassian.net/wiki/spaces/SD/pages/543129601/

import datetime
import decimal
import json
import logging
import numpy
import os
import queue
import shutil
import socket
import ssl
import sys
import time
import traceback
import uuid
from io import StringIO
from logging.handlers import RotatingFileHandler
from multiprocessing import Process
from threading import Thread

# Text encoding used for everything sent to or received from the JCC socket.
UTF_8 = "utf-8"
# Maximum number of bytes requested from the socket per recv() call (1 MiB).
READ_BUF_SIZE = 1048576
# Capacity of each of the four document queues linking the pipeline threads.
QUEUE_SIZE = 10
# Backlog passed to listen() on the server socket.
SOCK_BACKLOG = 1024
# Rotating log file settings: size limit in bytes per file and number of backups kept.
LOG_SIZE_LIMIT = 20480
LOG_BACKUP = 1

# An empty host makes the server socket bind to all local interfaces.
REMOTE_PYTHON_EXECUTOR_HOST = ""
# The listening port is the first command-line argument; a missing or
# non-numeric argument raises as soon as this module is loaded.
REMOTE_PYTHON_EXECUTOR_PORT = int(sys.argv[1])
# Shared secret that the first document of every session must carry.
REMOTE_PYTHON_EXECUTOR_TOKEN = os.environ.get("REMOTE_PYTHON_EXECUTOR_TOKEN", "")
# Paths of the TLS certificate, its private key, and the log file.
CRT_PATH = os.environ.get("REMOTE_PYTHON_EXECUTOR_CRT", "")
KEY_PATH = os.environ.get("REMOTE_PYTHON_EXECUTOR_KEY", "")
LOG_PATH = os.environ.get("REMOTE_PYTHON_EXECUTOR_LOG", "")
# The token is blanked in the environment once it has been read.
# NOTE(review): presumably so that executed user scripts cannot read it via
# os.environ -- confirm.
os.environ["REMOTE_PYTHON_EXECUTOR_TOKEN"] = ""

# File logging: INFO and above, written to LOG_PATH with size-based rotation.
LOG_FORMATTER = logging.Formatter("%(asctime)s %(levelname)s %(funcName)s(%(lineno)d) %(message)s")
LOG_HANDLER = RotatingFileHandler(LOG_PATH, mode="a", maxBytes=LOG_SIZE_LIMIT, backupCount=LOG_BACKUP, encoding=None,
                                  delay=0)
LOG_HANDLER.setFormatter(LOG_FORMATTER)
LOG_HANDLER.setLevel(logging.INFO)
LOG = logging.getLogger("root")
LOG.setLevel(logging.INFO)
LOG.addHandler(LOG_HANDLER)


# Everything starts from here.
def start():
    """Accept TLS connections from the JCC and serve each one in its own process.

    Loads the certificate/key pair from CRT_PATH and KEY_PATH, listens on
    REMOTE_PYTHON_EXECUTOR_HOST:REMOTE_PYTHON_EXECUTOR_PORT and starts one
    daemon SnapLogicSession per accepted connection. The loop only ends on
    KeyboardInterrupt; any other failure while accepting or dispatching a
    connection is logged and the server keeps listening.
    """
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)  # pylint: disable=E1101
    context.load_cert_chain(CRT_PATH, KEY_PATH)

    # Start listening to the JCC.
    snaplogic_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    snaplogic_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    snaplogic_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    snaplogic_socket = context.wrap_socket(snaplogic_socket, server_side=True)

    try:
        snaplogic_socket.bind((REMOTE_PYTHON_EXECUTOR_HOST, REMOTE_PYTHON_EXECUTOR_PORT))
        snaplogic_socket.listen(SOCK_BACKLOG)

        while True:
            snaplogic_conn = None
            try:
                # Initialize connection with JCC.
                snaplogic_conn, snaplogic_ip_addr = snaplogic_socket.accept()
                LOG.info("Connecting to {:s}.".format(str(snaplogic_ip_addr)))
                # Start a new process for each Snap execution.
                # NOTE(review): handing the accepted socket to the child
                # presumably relies on the "fork" start method -- confirm.
                process = SnapLogicSession(snaplogic_conn)
                process.daemon = True
                process.start()
            except KeyboardInterrupt:
                break
            except Exception:  # pylint: disable=W0703
                # A single bad connection (failed handshake, spawn error, ...)
                # must not stop the server, but it must leave a trace.
                LOG.exception("ERROR: Failed to accept or dispatch a connection.")
            finally:
                # The child works on its own copy of the socket; release the
                # parent's copy so descriptors do not pile up in the server.
                if snaplogic_conn is not None:
                    snaplogic_conn.close()
    finally:
        snaplogic_socket.close()


# Worker process serving a single JCC connection; one is started per Snap execution.
class SnapLogicSession(Process):
    """Process that serves one JCC connection with a five-stage thread pipeline.

    The stages are chained through bounded queues:
    Receiver -> Parser -> Executor -> Formatter -> Sender.
    Everything written to sys.stdout inside this process is captured in a
    StringIO buffer that the Sender returns to the JCC at the end of the
    stream.
    """

    # Upper bound, in seconds, on how long the session waits for the executor
    # thread to finish (it removes its temporary directory last) before the
    # process exits and its daemon threads are killed.
    EXECUTOR_JOIN_TIMEOUT = 5

    def __init__(self, conn):
        """Remember the accepted connection; all work happens in run()."""
        Process.__init__(self)
        self.conn = conn

    def run(self):
        """Run the pipeline until the Sender reports completion, then close the connection."""
        sys.stdout = stdout_buffer = StringIO()

        # Each stage puts 1 here when it fails; nothing in this class reads it.
        exception_queue = queue.Queue(0)
        # The Sender puts one item here when the session is over (success or failure).
        status_queue = queue.Queue(0)

        # Queue buffering incoming documents in text.
        in_text_queue = queue.Queue(QUEUE_SIZE)
        # Queue buffering incoming documents in object.
        in_obj_queue = queue.Queue(QUEUE_SIZE)
        # Queue buffering output documents in object.
        out_obj_queue = queue.Queue(QUEUE_SIZE)
        # Queue buffering output documents in text.
        out_text_queue = queue.Queue(QUEUE_SIZE)

        # Reads incoming documents from JCC and stores them in in_text_queue.
        receiver_thread = Receiver(self.conn, in_text_queue, exception_queue)
        # Parses documents from in_text_queue into objects in in_obj_queue.
        parser_thread = Parser(in_text_queue, in_obj_queue, exception_queue)
        # Executes the Python script on in_obj_queue, results go to out_obj_queue.
        executor_thread = Executor(out_obj_queue, in_obj_queue, exception_queue)
        # Formats objects from out_obj_queue into text in out_text_queue.
        formatter_thread = Formatter(out_obj_queue, out_text_queue, exception_queue)
        # Sends documents from out_text_queue back to JCC.
        sender_thread = Sender(self.conn, out_text_queue, exception_queue, status_queue, stdout_buffer)

        try:
            for worker in (receiver_thread, parser_thread, executor_thread, formatter_thread, sender_thread):
                # Daemon threads so that a stuck stage cannot keep the process alive.
                worker.daemon = True
                worker.start()

            # Wait until everything is completed (blocks instead of polling).
            status_queue.get()

            # Give the executor a bounded chance to remove its temporary
            # directory; daemon threads are killed when the process exits.
            executor_thread.join(self.EXECUTOR_JOIN_TIMEOUT)
        finally:
            # Close connection.
            self.conn.close()


# Read incoming documents from JCC and store in in_text_queue.
class Receiver(Thread):
    """Split the byte stream coming from the JCC into newline-delimited documents.

    Every complete, non-empty line is decoded as UTF-8 and put on
    in_text_queue. An empty line marks the end of the stream: None is queued
    and the thread stops. On any failure (socket error, undecodable bytes,
    connection closed before the end marker) the error is logged, None is
    queued so downstream stages terminate, and 1 is put on exception_queue.
    """

    def __init__(self, snaplogic_conn, in_text_queue, exception_queue):
        Thread.__init__(self)
        self.snaplogic_conn = snaplogic_conn
        self.in_text_queue = in_text_queue
        self.exception_queue = exception_queue

    def run(self):
        # Raw pieces of the line that is still incomplete. They are kept as
        # bytes and decoded only once the line is complete, because a
        # multi-byte UTF-8 character may be split across two recv() calls.
        fragments = []

        try:
            while True:
                # Read up to 1MB at a time.
                chunk = self.snaplogic_conn.recv(READ_BUF_SIZE)
                if not chunk:
                    # recv() returning nothing means the peer closed the
                    # connection; without this check the loop would spin forever.
                    raise ConnectionError("Connection closed before the end of stream was received.")

                fragments.append(chunk)
                if b"\n" not in chunk:
                    # Need to read more to get a complete document.
                    continue

                # Everything before the last newline is complete; what follows
                # it is the beginning of the next document.
                *lines, remainder = b"".join(fragments).split(b"\n")
                fragments = [remainder] if remainder else []

                for line in lines:
                    if not line:
                        # Empty line: end of stream.
                        self.in_text_queue.put(None)
                        return
                    self.in_text_queue.put(line.decode(UTF_8))
        except Exception:  # pylint: disable=W0703
            LOG.exception("ERROR: Receiver failed.")
            self.in_text_queue.put(None)
            self.exception_queue.put(1)


# Read documents in in_text_queue, parse into object and store in in_obj_queue.
class Parser(Thread):
    """Turn JSON text taken from in_text_queue into objects on in_obj_queue.

    A falsy item (the None sentinel) ends the stream: None is forwarded and
    the thread stops. If parsing fails the error is logged, None is forwarded
    and 1 is put on exception_queue.
    """

    def __init__(self, in_text_queue, in_obj_queue, exception_queue):
        super().__init__()
        self.in_text_queue = in_text_queue
        self.in_obj_queue = in_obj_queue
        self.exception_queue = exception_queue

    def run(self):
        try:
            raw_doc = self.in_text_queue.get()
            while raw_doc:
                self.in_obj_queue.put(json.loads(raw_doc))
                raw_doc = self.in_text_queue.get()
            # End of stream: pass the sentinel on to the next stage.
            self.in_obj_queue.put(None)
        except Exception:  # pylint: disable=W0703
            LOG.exception("ERROR: Parser failed.")
            self.in_obj_queue.put(None)
            self.exception_queue.put(1)


# Read documents in in_obj_queue, execute Python script and store result in out_obj_queue.
class Executor(Thread):
    """Run the script received from the JCC against every incoming document.

    The first document of a session must carry "token", "pipe_info" and
    "script". After the token check the script is exec'd into local_scope and
    its snaplogic_init() is called. Every document of type "data" is passed to
    snaplogic_process(); a document of type "finish" triggers
    snaplogic_final(). Results (a dict, a list, or None) are put on
    out_obj_queue one document at a time, followed by None at end of stream.

    SECURITY NOTE: the script is arbitrary code received over the network and
    is run with exec(); the token comparison below is the only gate visible
    in this class.
    """

    def __init__(self, out_obj_queue, in_obj_queue, exception_queue):
        Thread.__init__(self)
        self.snaplogic_is_init = False
        # Global namespace shared by the script and the exec'd helper statements.
        self.local_scope = {}
        # Per-session scratch directory, created on the first document.
        self.tmp_dir = None
        self.in_flight = False
        self.pipe = None
        self.in_obj_queue = in_obj_queue
        self.out_obj_queue = out_obj_queue
        self.exception_queue = exception_queue

    def run(self):
        """Process documents until the None sentinel arrives or an error occurs.

        On error a single document with the key
        "snaplogic_remote_python_execution_error" is emitted, followed by
        None, and 1 is put on exception_queue. The temporary directory is
        removed in every case.
        """
        # True when the failure was raised deliberately below; its message is
        # then reported verbatim instead of as a script error with traceback.
        throw_exception = False

        try:
            while True:
                snaplogic_row = self.in_obj_queue.get()
                if snaplogic_row is None:
                    self.out_obj_queue.put(None)
                    break

                self.local_scope["snaplogic_row"] = snaplogic_row
                snaplogic_init_result = None
                if self.snaplogic_is_init is False:

                    # Check token.
                    if snaplogic_row["token"] != REMOTE_PYTHON_EXECUTOR_TOKEN:
                        throw_exception = True
                        raise Exception("The token does not match.")

                    snaplogic_init_result = self._initialize(snaplogic_row)

                # Start from a clean slate so that a document whose type is
                # neither "data" nor "finish" cannot re-emit the previous result.
                self.local_scope["snaplogic_row_result"] = None

                # Execute process method on each document.
                if snaplogic_row["type"] == "data":
                    exec("snaplogic_row_result = snaplogic_process(snaplogic_row[\"data\"])",  # pylint: disable=W0122
                         self.local_scope)
                # Execute finish method.
                elif snaplogic_row["type"] == "finish":
                    exec("snaplogic_row_result = snaplogic_final()", self.local_scope)  # pylint: disable=W0122
                snaplogic_row_result = self.local_scope["snaplogic_row_result"]

                # Only list of dicts or dict are allowed as a return value.
                for returned in (snaplogic_init_result, snaplogic_row_result):
                    if returned is not None and not isinstance(returned, (list, dict)):
                        throw_exception = True
                        raise Exception("Return value can only be dict or list of dicts.")

                # Merge init result with process/final result.
                if snaplogic_init_result is not None:
                    snaplogic_row_result = self.combine_result(snaplogic_init_result, snaplogic_row_result)

                if snaplogic_row_result is None:
                    continue
                if isinstance(snaplogic_row_result, dict):
                    self.out_obj_queue.put(snaplogic_row_result)
                else:
                    for result in snaplogic_row_result:
                        self.out_obj_queue.put(result)

        except Exception as exception:  # pylint: disable=W0703
            if throw_exception:
                snaplogic_row_result = {"snaplogic_remote_python_execution_error": str(exception)}
            else:
                snaplogic_row_result = {"snaplogic_remote_python_execution_error": "Python script error: " + repr(
                    exception) + " " + traceback.format_exc()}
            LOG.exception("ERROR: Executor failed.")
            self.out_obj_queue.put(snaplogic_row_result)
            self.out_obj_queue.put(None)
            self.exception_queue.put(1)
        finally:
            # Remove tmp directory, whatever way the loop ended.
            if self.tmp_dir is not None and self.tmp_dir.startswith("/tmp/"):
                shutil.rmtree(self.tmp_dir, ignore_errors=True)

    def _initialize(self, snaplogic_row):
        """Set up the script namespace from the first document and return snaplogic_init()'s result."""
        # Extract pipeline information.
        snaplogic_pipe = snaplogic_row["pipe_info"]
        self.pipe = snaplogic_pipe
        self.local_scope["snaplogic_pipe"] = snaplogic_pipe

        # Expose the temporary directory to the script, both inside the
        # pipeline information and as a top-level name.
        self.tmp_dir = self._create_tmp_dir(snaplogic_pipe)
        snaplogic_pipe["tmp_root"] = self.tmp_dir
        self.local_scope["tmp_root"] = self.tmp_dir

        # Get script.
        exec(snaplogic_row["script"], self.local_scope)  # pylint: disable=W0122
        # Execute init method.
        exec("snaplogic_init_result = snaplogic_init()", self.local_scope)  # pylint: disable=W0122
        self.snaplogic_is_init = True
        return self.local_scope["snaplogic_init_result"]

    @staticmethod
    def _create_tmp_dir(snaplogic_pipe):
        """Create a fresh directory under /tmp for this session and return its path."""
        while True:
            candidate = "/tmp/snaplogic_{:s}_{:s}_{:s}/".format(
                snaplogic_pipe.get("instanceId", uuid.uuid4().hex),
                snaplogic_pipe.get("ruuid", uuid.uuid4().hex),
                uuid.uuid4().hex
            )
            try:
                # Creating the directory directly (instead of testing for its
                # existence first) leaves no window for another process to
                # claim the same path.
                os.makedirs(candidate)
            except FileExistsError:
                continue
            return candidate

    # Merge returning results from init/process/final.
    def combine_result(self, snaplogic_init_result, snaplogic_row_result):
        """Return the init result followed by the row result.

        With no row result the init result is returned as is. When both are a
        dict or a list, a new list holding the init document(s) first and the
        row document(s) after them is returned; neither input is modified.
        In any other case the row result is returned unchanged.
        """
        if snaplogic_row_result is None:
            return snaplogic_init_result
        if not isinstance(snaplogic_init_result, (dict, list)) or \
                not isinstance(snaplogic_row_result, (dict, list)):
            return snaplogic_row_result

        if isinstance(snaplogic_init_result, dict):
            init_docs = [snaplogic_init_result]
        else:
            init_docs = list(snaplogic_init_result)
        if isinstance(snaplogic_row_result, dict):
            row_docs = [snaplogic_row_result]
        else:
            row_docs = list(snaplogic_row_result)
        return init_docs + row_docs


# Read documents in out_obj_queue, format into text and store in out_text_queue.
class Formatter(Thread):
    """Serialize result objects from out_obj_queue into JSON text on out_text_queue.

    A result that cannot be serialized is replaced by an error document, so
    one bad result does not stop the stream. None marks the end of the
    stream and is forwarded. On any unexpected failure the error is logged,
    None is forwarded so the Sender can finish, and 1 is put on
    exception_queue.
    """

    def __init__(self, out_obj_queue, out_text_queue, exception_queue):
        Thread.__init__(self)
        self.out_obj_queue = out_obj_queue
        self.out_text_queue = out_text_queue
        self.exception_queue = exception_queue

    def run(self):
        try:
            while True:
                snaplogic_row_result = self.out_obj_queue.get()
                if snaplogic_row_result is None:
                    self.out_text_queue.put(None)
                    break

                if not isinstance(snaplogic_row_result, list):
                    snaplogic_row_result = [snaplogic_row_result]
                for result in snaplogic_row_result:
                    try:
                        result = json.dumps(result, cls=SnapLogicJsonEncoder)
                    except (TypeError, ValueError):
                        # TypeError: unsupported type. ValueError: circular
                        # reference, or bytes that are not valid UTF-8.
                        result = "{\"snaplogic_remote_python_execution_error\": " \
                                 "\"The output is not JSON serializable. " \
                                 "Please convert the output to primitive type.\"}"
                    self.out_text_queue.put(result)
        except Exception:  # pylint: disable=W0703
            # Catch everything, like the other pipeline stages: if this thread
            # died without forwarding None the Sender would wait forever.
            LOG.exception("ERROR: Formatter failed.")
            self.out_text_queue.put(None)
            self.exception_queue.put(1)


# Read documents in out_text_queue and send back to JCC.
class Sender(Thread):
    """Write formatted documents from out_text_queue back to the JCC.

    Each document is sent as one UTF-8 line. When no document arrives within
    QUEUE_TIMEOUT seconds a heartbeat line is sent instead. The None sentinel
    ends the stream: the captured stdout is sent as a "snaplogic_stdout"
    document, then the line "done", and "S" is put on status_queue. On
    failure the error is logged, 1 is put on exception_queue and "S" on
    status_queue.
    """

    # Seconds to wait for the next document before sending a heartbeat.
    QUEUE_TIMEOUT = 10

    def __init__(self, snaplogic_conn, out_text_queue, exception_queue, status_queue, stdout):
        super().__init__()
        self.snaplogic_conn = snaplogic_conn
        self.out_text_queue = out_text_queue
        self.exception_queue = exception_queue
        self.status_queue = status_queue
        self.stdout = stdout

    def run(self):
        try:
            finished = False
            while not finished:
                try:
                    outgoing = self.out_text_queue.get(timeout=Sender.QUEUE_TIMEOUT)
                except queue.Empty:
                    # Nothing to send for a while: keep the connection alive.
                    self.snaplogic_conn.sendall("{\"snaplogic_flag\":\"heartbeat\"}\n".encode(UTF_8))
                    continue

                if outgoing is not None:
                    self.snaplogic_conn.sendall((outgoing + "\n").encode(UTF_8))
                    continue

                # End of stream: return everything the session printed, then
                # tell the JCC that the session is done.
                self.stdout.seek(0)
                captured = json.dumps({"snaplogic_stdout": str(self.stdout.read())}, cls=SnapLogicJsonEncoder)
                self.snaplogic_conn.sendall((captured + "\n").encode(UTF_8))
                self.snaplogic_conn.sendall("done\n".encode(UTF_8))
                self.status_queue.put("S")
                finished = True
        except Exception:  # pylint: disable=W0703
            LOG.exception("ERROR: Sender failed.")
            self.exception_queue.put(1)
            self.status_queue.put("S")


class SnapLogicJsonEncoder(json.JSONEncoder):
    """JSON encoder that also accepts a few non-primitive types in script results.

    Conversions: datetime.datetime -> str(obj), decimal.Decimal -> float,
    NumPy booleans/integers/floats -> bool/int/float, numpy.ndarray -> nested
    list, bytes -> text decoded as UTF-8. Anything else raises TypeError, as
    the base encoder does.
    """

    def default(self, obj):  # pylint: disable=E0202
        if isinstance(obj, datetime.datetime):
            return str(obj)
        if isinstance(obj, decimal.Decimal):
            return float(obj)
        if isinstance(obj, numpy.bool_):
            return bool(obj)
        # The abstract scalar types cover every sized integer/float type and,
        # unlike the numpy.float_ alias (removed in NumPy 2.0), exist in all
        # NumPy versions.
        if isinstance(obj, numpy.integer):
            return int(obj)
        if isinstance(obj, numpy.floating):
            return float(obj)
        if isinstance(obj, numpy.ndarray):
            return obj.tolist()
        if isinstance(obj, bytes):
            return obj.decode(UTF_8)
        return json.JSONEncoder.default(self, obj)


# Script entry point: the server is started only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    start()
