/*
 * Decompiled with CFR 0.152.
 */
package com.amazon.redshift.plugin;

import com.amazon.redshift.CredentialsHolder;
import com.amazon.redshift.IPlugin;
import com.amazon.redshift.RedshiftProperty;
import com.amazon.redshift.httpclient.log.IamCustomLogFactory;
import com.amazon.redshift.logger.RedshiftLogger;
import com.amazon.redshift.plugin.IdpCredentialsProvider;
import com.amazon.redshift.plugin.utils.RequestUtils;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.SdkClientException;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.AnonymousAWSCredentials;
import com.amazonaws.auth.BasicSessionCredentials;
import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
import com.amazonaws.services.securitytoken.model.AssumeRoleWithSAMLRequest;
import com.amazonaws.services.securitytoken.model.AssumeRoleWithSAMLResult;
import com.amazonaws.services.securitytoken.model.Credentials;
import com.amazonaws.util.StringUtils;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Date;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.xpath.XPath;
import javax.xml.xpath.XPathConstants;
import javax.xml.xpath.XPathExpressionException;
import javax.xml.xpath.XPathFactory;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.logging.LogFactory;
import org.w3c.dom.Document;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import org.xml.sax.SAXException;

/**
 * Base class for identity-provider plugins that obtain a SAML assertion and exchange it for
 * temporary AWS credentials through STS {@code AssumeRoleWithSAML}.
 *
 * <p>Subclasses supply the assertion via {@link #getSamlAssertion()}; this class parses it,
 * picks a role, calls STS, extracts Redshift-specific SAML attributes into the credentials'
 * metadata, and caches the result in a JVM-wide map keyed by {@link #getCacheKey()} unless
 * caching is disabled.
 */
public abstract class SamlCredentialsProvider
extends IdpCredentialsProvider
implements IPlugin {
    protected static final String KEY_IDP_HOST = "idp_host";
    private static final String KEY_IDP_PORT = "idp_port";
    private static final String KEY_DURATION = "duration";
    private static final String KEY_PREFERRED_ROLE = "preferred_role";

    /** Finds an IAM SAML-provider ARN inside one comma-separated piece of a Role attribute value. */
    private static final Pattern SAML_PROVIDER_PATTERN = Pattern.compile("arn:aws[-a-z]*:iam::\\d*:saml-provider/\\S+");
    /** Finds an IAM role ARN inside one comma-separated piece of a Role attribute value. */
    private static final Pattern ROLE_PATTERN = Pattern.compile("arn:aws[-a-z]*:iam::\\d*:role/\\S+");
    /** Self-closing {@code <input .../>} tags; DOTALL lets a single tag span several lines. */
    private static final Pattern INPUT_TAG_PATTERN = Pattern.compile("<input(.+?)/>", Pattern.DOTALL);
    /** Captures the double-quoted {@code action} attribute of the first {@code <form>} tag. */
    private static final Pattern FORM_ACTION_PATTERN = Pattern.compile("<form.*?action=\"([^\"]+)\"");
    /** Selects every AWS Role attribute value; local-name() makes it independent of namespace prefixes. */
    private static final String ROLE_ATTRIBUTE_XPATH = "//*[local-name()='Attribute'][@Name='https://aws.amazon.com/SAML/Attributes/Role']/*[local-name()='AttributeValue']/text()";

    protected String m_userName;
    protected String m_password;
    protected String m_idpHost;
    protected int m_idpPort = 443;
    protected int m_duration;
    protected String m_preferredRole;
    protected String m_dbUser;
    protected String m_dbGroups;
    protected String m_dbGroupsFilter;
    protected Boolean m_forceLowercase;
    protected Boolean m_autoCreate;
    protected String m_stsEndpoint;
    protected String m_region;
    protected Boolean m_disableCache = false;
    protected Boolean m_groupFederation = false;

    /**
     * JVM-wide credentials cache keyed by {@link #getCacheKey()}. It is shared by every provider
     * instance, while refreshes only lock their own instance, so the map itself must be
     * thread-safe.
     */
    private static final Map<String, CredentialsHolder> m_cache = new ConcurrentHashMap<String, CredentialsHolder>();

    /** Hand-off slot for credentials produced by {@link #refresh()} when caching is disabled. */
    private CredentialsHolder m_lastRefreshCredentials;

    private static final Class<?> CUSTOM_LOG_FACTORY_CLASS = IamCustomLogFactory.class;
    private static final String LOG_PROPERTIES_FILE_NAME = "log-factory.properties";
    private static final String LOG_PROPERTIES_FILE_PATH = "META-INF/services/org.apache.commons.logging.LogFactory";

    /**
     * Context class loader installed while the IdP and STS calls run. It substitutes
     * {@link IamCustomLogFactory} for any commons-logging {@code LogFactory} class, hides
     * {@code commons-logging.properties}, and serves this package's {@code log-factory.properties}
     * for the LogFactory service-descriptor lookup. All other loading is delegated to this class's
     * own loader.
     */
    private static final ClassLoader CONTEXT_CLASS_LOADER = new ClassLoader(SamlCredentialsProvider.class.getClassLoader()){

        @Override
        public Class<?> loadClass(String name) throws ClassNotFoundException {
            Class<?> clazz = this.getParent().loadClass(name);
            if (LogFactory.class.isAssignableFrom(clazz)) {
                return CUSTOM_LOG_FACTORY_CLASS;
            }
            return clazz;
        }

        @Override
        public Enumeration<URL> getResources(String name) throws IOException {
            if ("commons-logging.properties".equals(name)) {
                return Collections.enumeration(Collections.<URL>emptyList());
            }
            return super.getResources(name);
        }

        @Override
        public URL getResource(String name) {
            if (SamlCredentialsProvider.LOG_PROPERTIES_FILE_PATH.equals(name)) {
                return SamlCredentialsProvider.class.getResource(SamlCredentialsProvider.LOG_PROPERTIES_FILE_NAME);
            }
            return super.getResource(name);
        }
    };

    /**
     * Obtains a base64-encoded SAML response from the concrete identity provider.
     *
     * @return the base64-encoded assertion
     * @throws IOException on IdP communication or configuration failure
     */
    protected abstract String getSamlAssertion() throws IOException;

    /**
     * Records one connection property relevant to SAML authentication. Key matching is
     * case-insensitive; unrecognized keys are ignored.
     *
     * @throws NumberFormatException if {@code idp_port} or {@code duration} is not an integer
     */
    @Override
    public void addParameter(String key, String value) {
        if (RedshiftLogger.isEnable()) {
            this.m_log.logDebug("key: {0}", key);
        }
        if (RedshiftProperty.UID.getName().equalsIgnoreCase(key) || RedshiftProperty.USER.getName().equalsIgnoreCase(key)) {
            this.m_userName = value;
        } else if (RedshiftProperty.PWD.getName().equalsIgnoreCase(key) || RedshiftProperty.PASSWORD.getName().equalsIgnoreCase(key)) {
            this.m_password = value;
        } else if (KEY_IDP_HOST.equalsIgnoreCase(key)) {
            this.m_idpHost = value;
        } else if (KEY_IDP_PORT.equalsIgnoreCase(key)) {
            this.m_idpPort = Integer.parseInt(value);
        } else if (KEY_DURATION.equalsIgnoreCase(key)) {
            this.m_duration = Integer.parseInt(value);
        } else if (KEY_PREFERRED_ROLE.equalsIgnoreCase(key)) {
            this.m_preferredRole = value;
        } else if ("ssl_insecure".equalsIgnoreCase(key)) {
            this.m_sslInsecure = Boolean.parseBoolean(value);
        } else if (RedshiftProperty.DB_USER.getName().equalsIgnoreCase(key)) {
            this.m_dbUser = value;
        } else if (RedshiftProperty.DB_GROUPS.getName().equalsIgnoreCase(key)) {
            this.m_dbGroups = value;
        } else if (RedshiftProperty.DB_GROUPS_FILTER.getName().equalsIgnoreCase(key)) {
            this.m_dbGroupsFilter = value;
        } else if (RedshiftProperty.FORCE_LOWERCASE.getName().equalsIgnoreCase(key)) {
            this.m_forceLowercase = Boolean.valueOf(value);
        } else if (RedshiftProperty.USER_AUTOCREATE.getName().equalsIgnoreCase(key)) {
            this.m_autoCreate = Boolean.valueOf(value);
        } else if (RedshiftProperty.AWS_REGION.getName().equalsIgnoreCase(key)) {
            this.m_region = value;
        } else if (RedshiftProperty.STS_ENDPOINT_URL.getName().equalsIgnoreCase(key)) {
            this.m_stsEndpoint = value;
        } else if (RedshiftProperty.IAM_DISABLE_CACHE.getName().equalsIgnoreCase(key)) {
            this.m_disableCache = Boolean.valueOf(value);
        }
    }

    @Override
    public void setLogger(RedshiftLogger log) {
        this.m_log = log;
    }

    /** Returns this plugin's sub-type code; NOTE(review): meaning of {@code 1} is defined by IPlugin callers. */
    @Override
    public int getSubType() {
        return 1;
    }

    /**
     * Returns temporary AWS credentials, refreshing them through the IdP and STS when needed.
     *
     * <p>With caching enabled, the shared cache is consulted under {@link #getCacheKey()} and a
     * refresh happens only when the entry is missing or expired. With caching disabled, every call
     * refreshes and returns the newly obtained credentials. A configured db user is applied to the
     * returned credentials' metadata.
     *
     * @return the credentials; never {@code null}
     * @throws SdkClientException if credentials could not be obtained
     */
    public CredentialsHolder getCredentials() {
        CredentialsHolder credentials = null;
        if (!this.m_disableCache.booleanValue()) {
            credentials = m_cache.get(this.getCacheKey());
        }
        if (credentials == null || credentials.isExpired()) {
            if (RedshiftLogger.isEnable()) {
                this.m_log.logInfo("SAML getCredentials NOT from cache", new Object[0]);
            }
            synchronized (this) {
                this.refresh();
                if (this.m_disableCache.booleanValue()) {
                    // Take the uncached result and clear the slot so it is not retained.
                    credentials = this.m_lastRefreshCredentials;
                    this.m_lastRefreshCredentials = null;
                }
            }
        } else {
            credentials.setRefresh(false);
            if (RedshiftLogger.isEnable()) {
                this.m_log.logInfo("SAML getCredentials from cache", new Object[0]);
            }
        }
        if (!this.m_disableCache.booleanValue()) {
            // refresh() publishes into the cache rather than returning, so re-read the entry.
            credentials = m_cache.get(this.getCacheKey());
        }
        // Must run before any dereference so a missing entry yields SdkClientException, not an NPE.
        if (credentials == null) {
            throw new SdkClientException("Unable to load AWS credentials from ADFS");
        }
        if (!StringUtils.isNullOrEmpty(this.m_dbUser)) {
            credentials.getThisMetadata().setDbUser(this.m_dbUser);
        }
        if (RedshiftLogger.isEnable()) {
            Date now = new Date();
            this.m_log.logInfo(now + ": Using entry for SamlCredentialsProvider.getCredentials cache with expiration " + credentials.getExpiration(), new Object[0]);
        }
        return credentials;
    }

    /**
     * Fetches a fresh SAML assertion, assumes a role with it through STS, and stores the result:
     * in the shared cache when caching is enabled, otherwise in a per-instance hand-off slot.
     *
     * <p>The role is {@code preferred_role} when set; otherwise an arbitrary role from the
     * assertion is used.
     *
     * @throws SdkClientException wrapping any failure (IdP, parsing, role selection or STS)
     */
    public void refresh() {
        Thread currentThread = Thread.currentThread();
        ClassLoader previousLoader = currentThread.getContextClassLoader();
        currentThread.setContextClassLoader(CONTEXT_CLASS_LOADER);
        try {
            String samlAssertion = this.getSamlAssertion();
            if (RedshiftLogger.isEnable()) {
                this.m_log.logDebug("SamlCredentialsProvider: Received SAML assertion of length={0}", samlAssertion != null ? samlAssertion.length() : -1);
            }
            if (StringUtils.isNullOrEmpty(samlAssertion)) {
                throw new IOException("Empty SAML assertion received from identity provider");
            }
            Document doc = SamlCredentialsProvider.parse(Base64.decodeBase64(samlAssertion));
            Map<String, String> roles = SamlCredentialsProvider.extractRoles(doc);
            // The assertion is a bearer credential; never copy it into exception messages.
            if (roles.isEmpty()) {
                throw new SdkClientException("No role found in SamlAssertion");
            }
            String roleArn;
            String principal;
            if (this.m_preferredRole != null) {
                roleArn = this.m_preferredRole;
                principal = roles.get(this.m_preferredRole);
                if (principal == null) {
                    throw new SdkClientException("Preferred role not found in SamlAssertion: " + this.m_preferredRole);
                }
            } else {
                // NOTE(review): HashMap iteration order, so this is not necessarily the first role
                // listed in the assertion; set preferred_role for a deterministic choice.
                Map.Entry<String, String> entry = roles.entrySet().iterator().next();
                roleArn = entry.getKey();
                principal = entry.getValue();
            }
            AssumeRoleWithSAMLRequest samlRequest = new AssumeRoleWithSAMLRequest();
            samlRequest.setSAMLAssertion(samlAssertion);
            samlRequest.setRoleArn(roleArn);
            samlRequest.setPrincipalArn(principal);
            if (this.m_duration > 0) {
                samlRequest.setDurationSeconds(Integer.valueOf(this.m_duration));
            }
            // The STS call is signed with anonymous credentials; the assertion identifies the caller.
            AWSCredentialsProvider anonymousProvider = new AWSStaticCredentialsProvider(new AnonymousAWSCredentials());
            AWSSecurityTokenServiceClientBuilder builder = AWSSecurityTokenServiceClientBuilder.standard();
            // NOTE(review): an explicit null configuration presumably leaves the SDK default in
            // place — confirm whether a proxy/timeout configuration was intended here.
            ClientConfiguration config = null;
            builder.withClientConfiguration(config);
            AWSSecurityTokenService stsSvc = RequestUtils.buildSts(this.m_stsEndpoint, this.m_region, builder, anonymousProvider, this.m_log);
            AssumeRoleWithSAMLResult result = stsSvc.assumeRoleWithSAML(samlRequest);
            Credentials cred = result.getCredentials();
            Date expiration = cred.getExpiration();
            AWSCredentials sessionCredentials = new BasicSessionCredentials(cred.getAccessKeyId(), cred.getSecretAccessKey(), cred.getSessionToken());
            CredentialsHolder credentials = CredentialsHolder.newInstance(sessionCredentials, expiration);
            credentials.setMetadata(this.readMetadata(doc));
            credentials.setRefresh(true);
            if (!this.m_disableCache.booleanValue()) {
                m_cache.put(this.getCacheKey(), credentials);
            } else {
                this.m_lastRefreshCredentials = credentials;
            }
        }
        catch (Exception e) {
            if (RedshiftLogger.isEnable()) {
                this.m_log.logError(e);
            }
            throw new SdkClientException("SAML error: " + e.getMessage(), e);
        }
        finally {
            currentThread.setContextClassLoader(previousLoader);
        }
    }

    /**
     * Collects role/provider pairs from the AWS SAML Role attribute. Each attribute value is split
     * on commas and contributes an entry only if it contains both a role ARN and a SAML-provider
     * ARN, in either order. When several pieces match the same kind, the last one wins.
     *
     * @return map of role ARN to SAML-provider ARN; empty when none were found
     */
    private static Map<String, String> extractRoles(Document doc) throws XPathExpressionException {
        XPath xPath = XPathFactory.newInstance().newXPath();
        NodeList nodeList = (NodeList)xPath.compile(ROLE_ATTRIBUTE_XPATH).evaluate(doc, XPathConstants.NODESET);
        Map<String, String> roles = new HashMap<String, String>();
        if (nodeList == null) {
            return roles;
        }
        for (int i = 0; i < nodeList.getLength(); ++i) {
            String[] arns = nodeList.item(i).getNodeValue().split(",");
            if (arns.length < 2) {
                continue;
            }
            String provider = null;
            String role = null;
            for (String arn : arns) {
                Matcher providerMatcher = SAML_PROVIDER_PATTERN.matcher(arn);
                if (providerMatcher.find()) {
                    provider = providerMatcher.group(0);
                    continue;
                }
                Matcher roleMatcher = ROLE_PATTERN.matcher(arn);
                if (roleMatcher.find()) {
                    role = roleMatcher.group(0);
                }
            }
            if (!StringUtils.isNullOrEmpty(role) && !StringUtils.isNullOrEmpty(provider)) {
                roles.put(role, provider);
            }
        }
        return roles;
    }

    /** Subclass hook appended to {@link #getCacheKey()}; empty by default. */
    @Override
    public String getPluginSpecificCacheKey() {
        return "";
    }

    /**
     * Returns the raw (base64) SAML assertion from the identity provider, fetched under the
     * custom logging class loader.
     *
     * @throws SdkClientException wrapping any failure from {@link #getSamlAssertion()}
     */
    @Override
    public String getIdpToken() {
        Thread currentThread = Thread.currentThread();
        ClassLoader previousLoader = currentThread.getContextClassLoader();
        currentThread.setContextClassLoader(CONTEXT_CLASS_LOADER);
        try {
            String samlAssertion = this.getSamlAssertion();
            if (RedshiftLogger.isEnable()) {
                this.m_log.logDebug("SamlCredentialsProvider: Got SAML assertion of length={0}", samlAssertion != null ? samlAssertion.length() : -1);
            }
            return samlAssertion;
        }
        catch (Exception e) {
            if (RedshiftLogger.isEnable()) {
                this.m_log.logError(e);
            }
            throw new SdkClientException("SAML error: " + e.getMessage(), e);
        }
        finally {
            currentThread.setContextClassLoader(previousLoader);
        }
    }

    @Override
    public void setGroupFederation(boolean groupFederation) {
        this.m_groupFederation = groupFederation;
    }

    /**
     * Builds the cache key from the user name, password, IdP host/port, duration, preferred role
     * and the plugin-specific suffix.
     *
     * <p>NOTE(review): fields are concatenated without a delimiter, so distinct tuples could in
     * principle map to the same key.
     */
    @Override
    public String getCacheKey() {
        String pluginSpecificKey = this.getPluginSpecificCacheKey();
        return this.m_userName + this.m_password + this.m_idpHost + this.m_idpPort + this.m_duration + this.m_preferredRole + pluginSpecificKey;
    }

    /**
     * Extracts Redshift-related SAML attributes into credential metadata. The attributes are
     * AllowDbUserOverride, DbUser (falling back to the AWS RoleSessionName), AutoCreate, DbGroups
     * (after filtering, joined with commas) and ForceLowercase. Only the first value is used for
     * single-valued attributes.
     */
    private CredentialsHolder.IamMetadata readMetadata(Document doc) throws XPathExpressionException {
        CredentialsHolder.IamMetadata metadata = new CredentialsHolder.IamMetadata();
        XPath xPath = XPathFactory.newInstance().newXPath();
        List<String> attributeValues = SamlCredentialsProvider.getSamlAttributeValues(xPath, doc, "https://redshift.amazon.com/SAML/Attributes/AllowDbUserOverride");
        if (!attributeValues.isEmpty()) {
            metadata.setAllowDbUserOverride(Boolean.valueOf(attributeValues.get(0)));
        }
        attributeValues = SamlCredentialsProvider.getSamlAttributeValues(xPath, doc, "https://redshift.amazon.com/SAML/Attributes/DbUser");
        if (attributeValues.isEmpty()) {
            attributeValues = SamlCredentialsProvider.getSamlAttributeValues(xPath, doc, "https://aws.amazon.com/SAML/Attributes/RoleSessionName");
        }
        if (!attributeValues.isEmpty()) {
            metadata.setSamlDbUser(attributeValues.get(0));
        }
        attributeValues = SamlCredentialsProvider.getSamlAttributeValues(xPath, doc, "https://redshift.amazon.com/SAML/Attributes/AutoCreate");
        if (!attributeValues.isEmpty()) {
            metadata.setAutoCreate(Boolean.valueOf(attributeValues.get(0)));
        }
        attributeValues = SamlCredentialsProvider.getSamlAttributeValues(xPath, doc, "https://redshift.amazon.com/SAML/Attributes/DbGroups");
        if (!attributeValues.isEmpty()) {
            attributeValues = this.filterOutGroups(attributeValues);
            if (!attributeValues.isEmpty()) {
                StringBuilder sb = new StringBuilder();
                for (String value : attributeValues) {
                    if (sb.length() > 0) {
                        sb.append(',');
                    }
                    sb.append(value);
                }
                metadata.setDbGroups(sb.toString());
            }
        }
        attributeValues = SamlCredentialsProvider.getSamlAttributeValues(xPath, doc, "https://redshift.amazon.com/SAML/Attributes/ForceLowercase");
        if (!attributeValues.isEmpty()) {
            metadata.setForceLowercase(Boolean.valueOf(attributeValues.get(0)));
        }
        return metadata;
    }

    /**
     * Drops groups whose whole name matches the configured db-groups filter regex. Returns the
     * input unchanged when no filter is configured.
     */
    private List<String> filterOutGroups(List<String> attributeValues) {
        if (this.m_dbGroupsFilter == null) {
            return attributeValues;
        }
        Pattern groupsFilter = Pattern.compile(this.m_dbGroupsFilter);
        List<String> kept = new ArrayList<String>();
        for (String group : attributeValues) {
            if (RedshiftLogger.isEnable()) {
                this.m_log.logDebug("Check group {0} with regexp {1}", group, this.m_dbGroupsFilter);
            }
            if (groupsFilter.matcher(group).matches()) {
                continue;
            }
            if (RedshiftLogger.isEnable()) {
                this.m_log.logDebug("Add {0} to dbgroups", group);
            }
            kept.add(group);
        }
        return kept;
    }

    /**
     * Parses decoded SAML XML. DOCTYPE declarations, XInclude and external entities are disabled
     * because the input comes from a remote identity provider (XXE hardening).
     */
    private static Document parse(byte[] samlAssertion) throws IOException, SAXException, ParserConfigurationException {
        DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
        factory.setFeature("http://apache.org/xml/features/disallow-doctype-decl", true);
        factory.setXIncludeAware(false);
        factory.setExpandEntityReferences(false);
        factory.setFeature("http://xml.org/sax/features/external-parameter-entities", false);
        factory.setFeature("http://xml.org/sax/features/external-general-entities", false);
        DocumentBuilder db = factory.newDocumentBuilder();
        return db.parse(new ByteArrayInputStream(samlAssertion));
    }

    /**
     * Returns the text values of the SAML attribute with the given Name, in document order.
     * Element matching uses local-name() so prefixed elements (e.g. {@code saml2:Attribute}) are
     * found, consistent with the Role lookup.
     *
     * @param attributeName a trusted constant; it is interpolated into the XPath unescaped
     * @return the values, or an empty list when the attribute is absent
     */
    private static List<String> getSamlAttributeValues(XPath xPath, Document doc, String attributeName) throws XPathExpressionException {
        String expression = String.format("//*[local-name()='Attribute'][@Name='%s']/*[local-name()='AttributeValue']/text()", attributeName);
        NodeList nodeList = (NodeList)xPath.compile(expression).evaluate(doc, XPathConstants.NODESET);
        if (null == nodeList || nodeList.getLength() == 0) {
            return Collections.emptyList();
        }
        List<String> attributeValues = new ArrayList<String>(nodeList.getLength());
        for (int i = 0; i < nodeList.getLength(); ++i) {
            Node node = nodeList.item(i);
            attributeValues.add(node.getNodeValue());
        }
        return attributeValues;
    }

    /**
     * Returns the self-closing {@code <input .../>} tags in {@code body}, keeping only the first tag
     * for each distinct (case-insensitive) {@code name} and skipping tags with no name.
     */
    protected List<String> getInputTagsfromHTML(String body) {
        Set<String> seenNames = new HashSet<String>();
        List<String> inputTags = new ArrayList<String>();
        Matcher inputTagMatcher = INPUT_TAG_PATTERN.matcher(body);
        while (inputTagMatcher.find()) {
            String tag = inputTagMatcher.group(0);
            // Locale.ROOT avoids locale-specific case mapping (e.g. Turkish dotless i).
            String tagNameLower = this.getValueByKey(tag, "name").toLowerCase(Locale.ROOT);
            if (!tagNameLower.isEmpty() && seenNames.add(tagNameLower)) {
                inputTags.add(tag);
            }
        }
        return inputTags;
    }

    /** Returns the entity-decoded {@code action} of the first {@code <form>} tag, or {@code null}. */
    protected String getFormAction(String body) {
        Matcher m = FORM_ACTION_PATTERN.matcher(body);
        if (m.find()) {
            return this.escapeHtmlEntity(m.group(1));
        }
        return null;
    }

    /**
     * Returns the entity-decoded value of the first {@code key="value"} pair in {@code input},
     * or {@code ""} when absent. {@code key} is matched literally.
     */
    protected String getValueByKey(String input, String key) {
        Pattern keyValuePattern = Pattern.compile("(" + Pattern.quote(key) + ")\\s*=\\s*\"(.*?)\"");
        Matcher keyValueMatcher = keyValuePattern.matcher(input);
        if (keyValueMatcher.find()) {
            return this.escapeHtmlEntity(keyValueMatcher.group(2));
        }
        return "";
    }

    /**
     * Single-quote variant of {@link #getValueByKey}: returns the entity-decoded value of the first
     * {@code key='value'} pair, or {@code ""}. {@code key} is matched literally, as in
     * getValueByKey.
     */
    protected String getValueByKeyWithoutQuotesAndValueInSingleQuote(String input, String key) {
        Pattern keyValuePattern = Pattern.compile("(" + Pattern.quote(key) + ")\\s*=\\s*'(.*?)'");
        Matcher keyValueMatcher = keyValuePattern.matcher(input);
        if (keyValueMatcher.find()) {
            return this.escapeHtmlEntity(keyValueMatcher.group(2));
        }
        return "";
    }

    /**
     * Decodes (despite the name) the five XML entities {@code &amp; &apos; &quot; &lt; &gt;} in a
     * single left-to-right pass. Any other {@code &} sequence is copied unchanged.
     */
    protected String escapeHtmlEntity(String html) {
        StringBuilder sb = new StringBuilder(html.length());
        int i = 0;
        int length = html.length();
        while (i < length) {
            char c = html.charAt(i);
            if (c != '&') {
                sb.append(c);
                ++i;
                continue;
            }
            if (html.startsWith("&amp;", i)) {
                sb.append('&');
                i += 5;
                continue;
            }
            if (html.startsWith("&apos;", i)) {
                sb.append('\'');
                i += 6;
                continue;
            }
            if (html.startsWith("&quot;", i)) {
                sb.append('\"');
                i += 6;
                continue;
            }
            if (html.startsWith("&lt;", i)) {
                sb.append('<');
                i += 4;
                continue;
            }
            if (html.startsWith("&gt;", i)) {
                sb.append('>');
                i += 4;
                continue;
            }
            sb.append(c);
            ++i;
        }
        return sb.toString();
    }

    /**
     * Ensures user, password and IdP host are set.
     *
     * @throws IOException naming the first missing property
     */
    protected void checkRequiredParameters() throws IOException {
        if (StringUtils.isNullOrEmpty(this.m_userName)) {
            throw new IOException("Missing required property: " + RedshiftProperty.USER.getName());
        }
        if (StringUtils.isNullOrEmpty(this.m_password)) {
            throw new IOException("Missing required property: " + RedshiftProperty.PASSWORD.getName());
        }
        if (StringUtils.isNullOrEmpty(this.m_idpHost)) {
            throw new IOException("Missing required property: " + KEY_IDP_HOST);
        }
    }

    /** True when the input tag's {@code type} is exactly {@code "text"}. */
    protected boolean isText(String inputTag) {
        return "text".equals(this.getInputType(inputTag));
    }

    /** True when the input tag's {@code type} is exactly {@code "password"}. */
    protected boolean isPassword(String inputTag) {
        return "password".equals(this.getInputType(inputTag));
    }

    /** Reads the tag's {@code type}, trying double-quoted syntax first, then single-quoted. */
    private String getInputType(String inputTag) {
        String typeVal = this.getValueByKey(inputTag, "type");
        if (typeVal == null || typeVal.length() == 0) {
            typeVal = this.getValueByKeyWithoutQuotesAndValueInSingleQuote(inputTag, "type");
        }
        return typeVal;
    }
}

