/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.memorycontainer;

import java.time.Instant;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.ml.common.memorycontainer.MLMemoryContainer;
import org.opensearch.ml.common.memorycontainer.MemoryConfiguration;
import org.opensearch.ml.common.memorycontainer.MemoryStrategy;
import org.opensearch.ml.common.memorycontainer.MemoryStrategyType;
import org.opensearch.ml.common.settings.MLCommonsSettings;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.memorycontainer.MLCreateMemoryContainerInput;
import org.opensearch.ml.common.transport.memorycontainer.MLCreateMemoryContainerRequest;
import org.opensearch.ml.common.transport.memorycontainer.MLCreateMemoryContainerResponse;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.helper.MemoryContainerModelValidator;
import org.opensearch.ml.helper.MemoryContainerPipelineHelper;
import org.opensearch.ml.helper.MemoryContainerSharedIndexValidator;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.ml.utils.TenantAwareHelper;
import org.opensearch.remote.metadata.client.PutDataObjectRequest;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.remote.metadata.common.SdkClientUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.client.Client;

/**
 * Transport action that handles "create memory container" requests.
 *
 * <p>The request flows through these asynchronous steps, each chained via {@link ActionListener}:
 * <ol>
 *   <li>reject the request if the agentic memory feature is disabled or the tenant id is invalid;</li>
 *   <li>validate the memory configuration (strategies, LLM model, embedding model, shared index);</li>
 *   <li>initialize the memory container metadata index;</li>
 *   <li>build and index the {@link MLMemoryContainer} document;</li>
 *   <li>create the memory data indices the configuration asks for;</li>
 *   <li>answer with the generated container id.</li>
 * </ol>
 * Any failure along the chain is forwarded to the caller's listener.
 */
public class TransportCreateMemoryContainerAction
    extends HandledTransportAction<MLCreateMemoryContainerRequest, MLCreateMemoryContainerResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportCreateMemoryContainerAction.class);

    /** Name under which this action is registered with the transport service. */
    private static final String ACTION_NAME = "cluster:admin/opensearch/ml/memory_containers/create";
    /** Index that stores the memory container documents. */
    private static final String MEMORY_CONTAINER_INDEX = ".plugins-ml-am-memory-container";
    /** Status string returned to the caller on success. */
    private static final String CREATED_STATUS = "created";

    private final MLIndicesHandler mlIndicesHandler;
    private final Client client;
    private final SdkClient sdkClient;
    private final ConnectorAccessControlHelper connectorAccessControlHelper;
    private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
    private final MLModelManager mlModelManager;

    @Inject
    public TransportCreateMemoryContainerAction(
        TransportService transportService,
        ActionFilters actionFilters,
        Client client,
        SdkClient sdkClient,
        MLIndicesHandler mlIndicesHandler,
        ConnectorAccessControlHelper connectorAccessControlHelper,
        MLFeatureEnabledSetting mlFeatureEnabledSetting,
        MLModelManager mlModelManager
    ) {
        super(ACTION_NAME, transportService, actionFilters, MLCreateMemoryContainerRequest::new);
        this.client = client;
        this.sdkClient = sdkClient;
        this.mlIndicesHandler = mlIndicesHandler;
        this.connectorAccessControlHelper = connectorAccessControlHelper;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
        this.mlModelManager = mlModelManager;
    }

    /**
     * Entry point of the action: gates the request, validates the configuration and then
     * starts the container creation chain.
     *
     * @param task     the task this request runs under (not used here)
     * @param request  the create request carrying the container input
     * @param listener notified with the creation response or with the first failure
     */
    @Override
    protected void doExecute(Task task, MLCreateMemoryContainerRequest request, ActionListener<MLCreateMemoryContainerResponse> listener) {
        if (!this.mlFeatureEnabledSetting.isAgenticMemoryEnabled()) {
            log.warn("Agentic memory feature is disabled. Request denied.");
            listener.onFailure(new OpenSearchStatusException(MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE, RestStatus.FORBIDDEN));
            return;
        }
        MLCreateMemoryContainerInput input = request.getMlCreateMemoryContainerInput();
        // validateTenantId reports the failure through the listener itself when it returns false.
        if (!TenantAwareHelper.validateTenantId(this.mlFeatureEnabledSetting, input.getTenantId(), listener)) {
            return;
        }
        User user = RestActionUtils.getUserContext(this.client);
        String tenantId = input.getTenantId();
        this.validateConfiguration(input.getConfiguration(), ActionListener.wrap(
            isValid -> this.mlIndicesHandler.initMemoryContainerIndex(ActionListener.wrap(
                created -> this.createContainer(input, user, tenantId, listener),
                listener::onFailure)),
            listener::onFailure));
    }

    /**
     * Builds the container document, indexes it, creates its data indices and finally answers the caller.
     * Synchronous exceptions raised while building or submitting the document are logged and
     * forwarded to the listener.
     */
    private void createContainer(MLCreateMemoryContainerInput input, User user, String tenantId, ActionListener<MLCreateMemoryContainerResponse> listener) {
        try {
            MLMemoryContainer memoryContainer = this.buildMemoryContainer(input, user, tenantId);
            // NOTE(review): if data index creation fails after the container document was indexed,
            // the document is not removed here - confirm whether cleanup is expected.
            this.indexMemoryContainer(memoryContainer, ActionListener.wrap(
                memoryContainerId -> this.createMemoryDataIndices(memoryContainer, user, ActionListener.wrap(
                    actualIndexName -> listener.onResponse(new MLCreateMemoryContainerResponse(memoryContainerId, CREATED_STATUS)),
                    listener::onFailure)),
                listener::onFailure));
        } catch (Exception e) {
            log.error("Failed to create memory container", e);
            listener.onFailure(e);
        }
    }

    /**
     * Assembles the {@link MLMemoryContainer} to persist. Strategies without an id get a generated one,
     * and strategies without an explicit enabled flag are enabled.
     *
     * @param input    the user supplied container definition
     * @param user     the requesting user, stored as owner (may be null)
     * @param tenantId the tenant the container belongs to
     * @return the container with creation and last-updated time both set to "now"
     */
    private MLMemoryContainer buildMemoryContainer(MLCreateMemoryContainerInput input, User user, String tenantId) {
        Instant now = Instant.now();
        MemoryConfiguration configuration = input.getConfiguration();
        if (configuration != null && configuration.getStrategies() != null) {
            for (MemoryStrategy strategy : configuration.getStrategies()) {
                if (strategy.getId() == null || strategy.getId().isBlank()) {
                    strategy.setId(MemoryStrategy.generateStrategyId((MemoryStrategyType) strategy.getType()));
                }
                if (strategy.getEnabled() == null) {
                    strategy.setEnabled(true);
                }
            }
        }
        return MLMemoryContainer.builder()
            .name(input.getName())
            .description(input.getDescription())
            .owner(user)
            .tenantId(tenantId)
            .createdTime(now)
            .lastUpdatedTime(now)
            .configuration(configuration)
            .backendRoles(input.getBackendRoles())
            .build();
    }

    /**
     * Creates the data indices required by the container's configuration.
     *
     * <p>The session index is created first unless sessions are disabled. Without strategies only the
     * working memory index follows and its name is reported; with strategies the working memory,
     * long-term memory and (optionally) history resources are created and the long-term index name is reported.
     *
     * @param container the container whose configuration drives index creation
     * @param user      the requesting user (currently unused, kept for signature stability)
     * @param listener  notified with the name of the working or long-term memory index, or with the failure
     */
    private void createMemoryDataIndices(MLMemoryContainer container, User user, ActionListener<String> listener) {
        MemoryConfiguration configuration = container.getConfiguration();
        if (configuration == null) {
            // Index names are derived from the configuration; fail explicitly instead of with a bare NPE.
            listener.onFailure(new IllegalStateException("Memory container configuration is missing"));
            return;
        }
        String sessionIndexName = configuration.getSessionIndexName();
        String workingMemoryIndexName = configuration.getWorkingMemoryIndexName();
        String longTermMemoryIndexName = configuration.getLongMemoryIndexName();
        String longTermMemoryHistoryIndexName = configuration.getLongMemoryHistoryIndexName();
        boolean hasStrategies = configuration.getStrategies() != null && !configuration.getStrategies().isEmpty();

        if (configuration.isDisableSession()) {
            this.createNonSessionIndices(container, listener, configuration, hasStrategies, workingMemoryIndexName, longTermMemoryIndexName, longTermMemoryHistoryIndexName);
            return;
        }
        this.mlIndicesHandler.createSessionMemoryDataIndex(sessionIndexName, configuration, ActionListener.wrap(
            sessionCreated -> this.createNonSessionIndices(container, listener, configuration, hasStrategies, workingMemoryIndexName, longTermMemoryIndexName, longTermMemoryHistoryIndexName),
            listener::onFailure));
    }

    /**
     * Creates every index except the session index: only working memory when no strategies are
     * configured, otherwise the full working/long-term/history chain.
     */
    private void createNonSessionIndices(MLMemoryContainer container, ActionListener<String> listener, MemoryConfiguration configuration, boolean hasStrategies, String workingMemoryIndexName, String longTermMemoryIndexName, String longTermMemoryHistoryIndexName) {
        if (hasStrategies) {
            this.createMemoryIndexes(container, listener, configuration, workingMemoryIndexName, longTermMemoryIndexName, longTermMemoryHistoryIndexName);
            return;
        }
        this.mlIndicesHandler.createWorkingMemoryDataIndex(workingMemoryIndexName, configuration, ActionListener.wrap(
            workingCreated -> listener.onResponse(workingMemoryIndexName),
            listener::onFailure));
    }

    /**
     * Creates the working memory index, then the long-term memory ingest pipeline, then - unless history
     * is disabled - the long-term memory history index. Responds with the long-term memory index name.
     */
    private void createMemoryIndexes(MLMemoryContainer container, ActionListener<String> listener, MemoryConfiguration configuration, String workingMemoryIndexName, String longTermMemoryIndexName, String longTermMemoryHistoryIndexName) {
        this.mlIndicesHandler.createWorkingMemoryDataIndex(workingMemoryIndexName, configuration, ActionListener.wrap(
            workingCreated -> this.createLongTermMemoryIngestPipeline(longTermMemoryIndexName, container.getConfiguration(), ActionListener.wrap(
                pipelineCreated -> {
                    if (configuration.isDisableHistory()) {
                        listener.onResponse(longTermMemoryIndexName);
                        return;
                    }
                    this.mlIndicesHandler.createLongTermMemoryHistoryIndex(longTermMemoryHistoryIndexName, configuration, ActionListener.wrap(
                        historyCreated -> listener.onResponse(longTermMemoryIndexName),
                        listener::onFailure));
                },
                listener::onFailure)),
            listener::onFailure));
    }

    /** Delegates long-term memory ingest pipeline creation to {@link MemoryContainerPipelineHelper}. */
    private void createLongTermMemoryIngestPipeline(String indexName, MemoryConfiguration memoryConfig, ActionListener<Boolean> listener) {
        MemoryContainerPipelineHelper.createLongTermMemoryIngestPipeline(indexName, memoryConfig, this.mlIndicesHandler, this.client, listener);
    }

    /**
     * Persists the container document through the SDK client with the thread context stashed.
     *
     * @param container the container document to store
     * @param listener  notified with the generated document id, or with the failure
     */
    private void indexMemoryContainer(MLMemoryContainer container, ActionListener<String> listener) {
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            PutDataObjectRequest putRequest = PutDataObjectRequest.builder()
                .tenantId(container.getTenantId())
                .index(MEMORY_CONTAINER_INDEX)
                .dataObject(container)
                .build();
            this.sdkClient.putDataObjectAsync(putRequest).whenComplete((response, throwable) -> {
                context.restore();
                if (throwable != null) {
                    Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable);
                    log.error("Failed to index memory container", cause);
                    listener.onFailure(cause);
                    return;
                }
                try {
                    IndexResponse indexResponse = response.indexResponse();
                    if (indexResponse == null) {
                        // Explicit check instead of `assert`: an AssertionError would escape the
                        // catch below and leave the listener without any notification.
                        log.error("Failed to create memory container - index response is missing");
                        listener.onFailure(new RuntimeException("Failed to create memory container"));
                        return;
                    }
                    if (indexResponse.getResult() != DocWriteResponse.Result.CREATED) {
                        log.error("Failed to create memory container - unexpected index response result: {}", indexResponse.getResult());
                        listener.onFailure(new RuntimeException("Failed to create memory container"));
                        return;
                    }
                    String generatedId = indexResponse.getId();
                    log.info("Successfully created memory container with ID: {}", generatedId);
                    listener.onResponse(generatedId);
                } catch (Exception e) {
                    listener.onFailure(e);
                }
            });
        } catch (Exception e) {
            log.error("Failed to save memory container", e);
            listener.onFailure(e);
        }
    }

    /**
     * Validates the memory configuration: first the synchronous strategy checks, then - chained
     * asynchronously - the LLM model, the embedding model and shared index compatibility.
     *
     * @param config   the configuration to validate; a null configuration is rejected
     * @param listener notified with {@code true} when every check passes, or with the first failure
     */
    private void validateConfiguration(MemoryConfiguration config, ActionListener<Boolean> listener) {
        if (config == null) {
            // Every later step dereferences the configuration; reject up front with a clear error.
            listener.onFailure(new IllegalArgumentException("Memory container configuration is required"));
            return;
        }
        try {
            MemoryConfiguration.validateStrategiesRequireModels(config);
            if (config.getStrategies() != null) {
                for (MemoryStrategy strategy : config.getStrategies()) {
                    MemoryStrategy.validate(strategy);
                }
            }
        } catch (IllegalArgumentException e) {
            log.error("Strategy validation failed: {}", e.getMessage());
            listener.onFailure(e);
            return;
        }
        MemoryContainerModelValidator.validateLlmModel(config.getLlmId(), this.mlModelManager, this.client, ActionListener.wrap(
            llmValid -> MemoryContainerModelValidator.validateEmbeddingModel(config.getEmbeddingModelId(), config.getEmbeddingModelType(), this.mlModelManager, this.client, ActionListener.wrap(
                embeddingValid -> MemoryContainerSharedIndexValidator.validateSharedIndexCompatibility(config, config.getLongMemoryIndexName(), this.client, ActionListener.wrap(
                    result -> listener.onResponse(true),
                    listener::onFailure)),
                listener::onFailure)),
            listener::onFailure));
    }
}

