/*
 * Decompiled with CFR 0.152.
 */
package ai.pqcrypto.sdk;

import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import java.security.Security;
import java.security.cert.CertPathValidator;
import java.security.cert.CertificateFactory;
import java.security.cert.PKIXParameters;
import java.security.cert.TrustAnchor;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.bouncycastle.asn1.ASN1InputStream;
import org.bouncycastle.asn1.ASN1OctetString;
import org.bouncycastle.asn1.ASN1Primitive;
import org.bouncycastle.asn1.ASN1Sequence;
import org.bouncycastle.cert.X509CertificateHolder;
import org.bouncycastle.crypto.params.AsymmetricKeyParameter;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.openssl.PEMParser;
import org.bouncycastle.pqc.crypto.mldsa.MLDSAParameters;
import org.bouncycastle.pqc.crypto.mldsa.MLDSAPrivateKeyParameters;
import org.bouncycastle.tls.AlertDescription;
import org.bouncycastle.tls.Certificate;
import org.bouncycastle.tls.CertificateEntry;
import org.bouncycastle.tls.CertificateRequest;
import org.bouncycastle.tls.DefaultTlsClient;
import org.bouncycastle.tls.ProtocolVersion;
import org.bouncycastle.tls.SignatureAndHashAlgorithm;
import org.bouncycastle.tls.TlsAuthentication;
import org.bouncycastle.tls.TlsClientProtocol;
import org.bouncycastle.tls.TlsCredentials;
import org.bouncycastle.tls.TlsFatalAlert;
import org.bouncycastle.tls.TlsServerCertificate;
import org.bouncycastle.tls.crypto.TlsCryptoParameters;
import org.bouncycastle.tls.crypto.impl.bc.BcDefaultTlsCredentialedSigner;
import org.bouncycastle.tls.crypto.impl.bc.BcTlsCertificate;
import org.bouncycastle.tls.crypto.impl.bc.BcTlsCrypto;

public class MLDSATlsClient {
    /** Binary encoding of the client certificate, as returned by {@code loadCertBytes}. */
    private final byte[] clientCertBytes;
    /** Binary encoding of the CA certificate, as returned by {@code loadCertBytes}. */
    private final byte[] caCertBytes;
    /** Private key matching the client certificate, as returned by {@code loadPrivateKey}. */
    private final AsymmetricKeyParameter clientPrivateKey;
    /** ML-DSA parameter set chosen on the builder; drives the TLS signature algorithm selection. */
    private final MLDSAParameters mldsaParams;
    /** Timeout in milliseconds passed to {@code Socket.connect}. */
    private final int connectTimeout;
    /** Timeout in milliseconds passed to {@code Socket.setSoTimeout}. */
    private final int readTimeout;

    /**
     * Creates a client from the builder's settings, reading the certificate and
     * key files immediately.
     *
     * @param builder source of file paths, parameter set and timeouts
     * @throws Exception if any of the three files cannot be read or parsed
     */
    private MLDSATlsClient(Builder builder) throws Exception {
        // Plain value copies first; they cannot fail.
        mldsaParams = builder.mldsaParams;
        connectTimeout = builder.connectTimeout;
        readTimeout = builder.readTimeout;

        // File-backed material, loaded in the order: client cert, CA cert, private key.
        clientCertBytes = loadCertBytes(builder.clientCertPath);
        caCertBytes = loadCertBytes(builder.caCertPath);
        clientPrivateKey = loadPrivateKey(builder.clientKeyPath, builder.mldsaParams);
    }

    /**
     * Issues a {@code GET} without extra headers.
     *
     * @param host server host name or address
     * @param port server TCP port
     * @param path request target placed on the request line
     * @return the parsed response
     * @throws Exception if connecting, the handshake, or the exchange fails
     */
    public Response get(String host, int port, String path) throws Exception {
        Map<String, String> noHeaders = null;
        String noBody = null;
        return request("GET", host, port, path, noHeaders, noBody);
    }

    /**
     * Issues a {@code GET} with caller-supplied headers.
     *
     * @param host    server host name or address
     * @param port    server TCP port
     * @param path    request target placed on the request line
     * @param headers additional request headers, or {@code null} for none
     * @return the parsed response
     * @throws Exception if connecting, the handshake, or the exchange fails
     */
    public Response get(String host, int port, String path, Map<String, String> headers) throws Exception {
        String noBody = null;
        return request("GET", host, port, path, headers, noBody);
    }

    /**
     * Issues a {@code POST} with a body and no extra headers.
     *
     * @param host server host name or address
     * @param port server TCP port
     * @param path request target placed on the request line
     * @param body request body, or {@code null} for none
     * @return the parsed response
     * @throws Exception if connecting, the handshake, or the exchange fails
     */
    public Response post(String host, int port, String path, String body) throws Exception {
        Map<String, String> noHeaders = null;
        return request("POST", host, port, path, noHeaders, body);
    }

    /**
     * Issues a {@code POST} with a body and caller-supplied headers.
     *
     * @param host    server host name or address
     * @param port    server TCP port
     * @param path    request target placed on the request line
     * @param headers additional request headers, or {@code null} for none
     * @param body    request body, or {@code null} for none
     * @return the parsed response
     * @throws Exception if connecting, the handshake, or the exchange fails
     */
    public Response post(String host, int port, String path, Map<String, String> headers, String body) throws Exception {
        return request("POST", host, port, path, headers, body);
    }

    /**
     * Sends one HTTP/1.1 request over a freshly opened TLS connection and
     * returns the parsed response. The connection is always closed afterwards
     * ({@code Connection: close} is sent with every request).
     *
     * <p>The request head is assembled and validated before any network
     * activity, so invalid arguments fail without opening a socket.
     *
     * @param method  HTTP method written on the request line
     * @param host    server host; used for the TCP connection and the {@code Host} header
     * @param port    server TCP port
     * @param path    request target written on the request line
     * @param headers additional headers, or {@code null} for none
     * @param body    request body encoded as UTF-8, or {@code null} for none; when
     *                present and no {@code Content-Type} header was supplied (any
     *                letter case), {@code application/json} is sent
     * @return the parsed response
     * @throws IllegalArgumentException if the method, host, path, or any header
     *                                  name/value contains a CR or LF character
     * @throws Exception if connecting, the TLS handshake, or the exchange fails
     */
    public Response request(String method, String host, int port, String path, Map<String, String> headers, String body) throws Exception {
        byte[] bodyBytes = body != null ? body.getBytes(StandardCharsets.UTF_8) : null;
        byte[] headBytes = buildRequestHead(method, host, path, headers, bodyBytes);

        Socket socket = new Socket();
        // The try starts right after construction so the socket is released even
        // when connect(), setSoTimeout() or stream acquisition throws.
        try {
            socket.connect(new InetSocketAddress(host, port), this.connectTimeout);
            socket.setSoTimeout(this.readTimeout);
            BcTlsCrypto crypto = new BcTlsCrypto(new SecureRandom());
            TlsClientProtocol protocol = new TlsClientProtocol(socket.getInputStream(), socket.getOutputStream());
            try {
                protocol.connect(this.createTlsClient(crypto, host));
                OutputStream output = protocol.getOutputStream();
                output.write(headBytes);
                if (bodyBytes != null) {
                    output.write(bodyBytes);
                }
                output.flush();
                return this.parseResponse(protocol.getInputStream());
            }
            finally {
                try {
                    protocol.close();
                }
                catch (Exception ignored) {
                    // Best-effort shutdown: a failure here must not mask the real outcome.
                }
            }
        }
        finally {
            try {
                socket.close();
            }
            catch (Exception ignored) {
                // Best-effort shutdown: a failure here must not mask the real outcome.
            }
        }
    }

    /** Builds the request line and header section (including the terminating blank line) as UTF-8 bytes. */
    private static byte[] buildRequestHead(String method, String host, String path, Map<String, String> headers, byte[] bodyBytes) {
        rejectLineBreaks("method", method);
        rejectLineBreaks("host", host);
        rejectLineBreaks("path", path);

        StringBuilder head = new StringBuilder();
        head.append(method).append(" ").append(path).append(" HTTP/1.1\r\n");
        head.append("Host: ").append(host).append("\r\n");
        head.append("Connection: close\r\n");

        boolean hasContentType = false;
        if (headers != null) {
            for (Map.Entry<String, String> h : headers.entrySet()) {
                rejectLineBreaks("header name", h.getKey());
                rejectLineBreaks("header value", h.getValue());
                head.append(h.getKey()).append(": ").append(h.getValue()).append("\r\n");
                // Header names are case-insensitive, so "content-type" counts as well.
                if ("Content-Type".equalsIgnoreCase(h.getKey())) {
                    hasContentType = true;
                }
            }
        }
        if (bodyBytes != null) {
            if (!hasContentType) {
                head.append("Content-Type: application/json\r\n");
            }
            head.append("Content-Length: ").append(bodyBytes.length).append("\r\n");
        }
        head.append("\r\n");
        return head.toString().getBytes(StandardCharsets.UTF_8);
    }

    /** Rejects values that would let a caller-supplied string start a new header line. */
    private static void rejectLineBreaks(String what, String value) {
        if (value != null && (value.indexOf('\r') >= 0 || value.indexOf('\n') >= 0)) {
            throw new IllegalArgumentException("HTTP " + what + " must not contain CR or LF");
        }
    }

    /**
     * Builds the TLS client handed to {@code TlsClientProtocol.connect}: TLS 1.3
     * only, three fixed cipher suites, client authentication with the configured
     * certificate and private key, and validation of the server certificate
     * chain against the configured CA certificate.
     *
     * <p>NOTE(review): {@code host} is currently unused. The server certificate
     * is not matched against the host name and no SNI extension is configured
     * here; both should be considered before this client is used against
     * servers that share a CA.
     *
     * @param crypto crypto provider used to wrap certificates and sign
     * @param host   server host the connection was opened to (unused, see note)
     * @return a client for a single handshake
     */
    private DefaultTlsClient createTlsClient(final BcTlsCrypto crypto, String host) {
        return new DefaultTlsClient(crypto){

            @Override
            public TlsAuthentication getAuthentication() {
                return new TlsAuthentication(){

                    @Override
                    public void notifyServerCertificate(TlsServerCertificate serverCertificate) throws IOException {
                        // Previously a no-op, which accepted any server certificate.
                        MLDSATlsClient.this.verifyServerCertificate(serverCertificate);
                    }

                    @Override
                    public TlsCredentials getClientCredentials(CertificateRequest request) throws IOException {
                        CertificateEntry[] entries = new CertificateEntry[]{new CertificateEntry(new BcTlsCertificate(crypto, MLDSATlsClient.this.clientCertBytes), null), new CertificateEntry(new BcTlsCertificate(crypto, MLDSATlsClient.this.caCertBytes), null)};
                        // The Certificate message echoes the context from the server's request;
                        // fall back to the empty context when none is available.
                        byte[] requestContext = request != null ? request.getCertificateRequestContext() : null;
                        Certificate cert = new Certificate(requestContext != null ? requestContext : new byte[0], entries);
                        SignatureAndHashAlgorithm sigAlg = MLDSATlsClient.this.getSignatureAlgorithm();
                        return new BcDefaultTlsCredentialedSigner(new TlsCryptoParameters(context), crypto, MLDSATlsClient.this.clientPrivateKey, cert, sigAlg);
                    }
                };
            }

            @Override
            protected int[] getSupportedCipherSuites() {
                // 0x1302, 0x1301, 0x1303: the TLS 1.3 AES-256-GCM, AES-128-GCM and
                // ChaCha20-Poly1305 suite identifiers, in preference order.
                return new int[]{4866, 4865, 4867};
            }

            @Override
            protected ProtocolVersion[] getSupportedVersions() {
                return new ProtocolVersion[]{ProtocolVersion.TLSv13};
            }
        };
    }

    /**
     * Validates the server's certificate chain against the configured CA
     * certificate using PKIX path validation from the "BC" provider.
     * Revocation checking is disabled because no CRL/OCSP source is configured.
     *
     * @param serverCertificate certificate information reported by the handshake
     * @throws IOException a {@code bad_certificate} fatal alert if the chain is
     *                     missing, cannot be parsed, or does not validate
     */
    private void verifyServerCertificate(TlsServerCertificate serverCertificate) throws IOException {
        Certificate chain = serverCertificate != null ? serverCertificate.getCertificate() : null;
        if (chain == null || chain.isEmpty()) {
            throw new TlsFatalAlert(AlertDescription.bad_certificate);
        }
        try {
            CertificateFactory factory = CertificateFactory.getInstance("X.509", "BC");
            X509Certificate trustAnchor = (X509Certificate)factory.generateCertificate(new ByteArrayInputStream(this.caCertBytes));
            byte[] anchorEncoding = trustAnchor.getEncoded();

            List<X509Certificate> path = new ArrayList<X509Certificate>();
            for (int i = 0; i < chain.getLength(); ++i) {
                byte[] encoded = chain.getCertificateAt(i).getEncoded();
                if (Arrays.equals(encoded, anchorEncoding)) {
                    // A PKIX path excludes the trust anchor itself; stop once it is reached.
                    break;
                }
                path.add((X509Certificate)factory.generateCertificate(new ByteArrayInputStream(encoded)));
            }
            if (path.isEmpty()) {
                // The server presented exactly the configured certificate; only its validity period remains to check.
                trustAnchor.checkValidity();
                return;
            }
            PKIXParameters parameters = new PKIXParameters(Collections.singleton(new TrustAnchor(trustAnchor, null)));
            parameters.setRevocationEnabled(false);
            CertPathValidator.getInstance("PKIX", "BC").validate(factory.generateCertPath(path), parameters);
        }
        catch (GeneralSecurityException e) {
            throw new TlsFatalAlert(AlertDescription.bad_certificate, e);
        }
    }

    /**
     * Maps the configured ML-DSA parameter set to the TLS signature algorithm
     * used for client authentication.
     *
     * @return {@code mldsa44} or {@code mldsa87} when the configured parameters
     *         are exactly those constants; {@code mldsa65} for anything else
     */
    private SignatureAndHashAlgorithm getSignatureAlgorithm() {
        MLDSAParameters selected = this.mldsaParams;
        // Identity comparison against the library's shared constants.
        return selected == MLDSAParameters.ml_dsa_44
            ? SignatureAndHashAlgorithm.mldsa44
            : selected == MLDSAParameters.ml_dsa_87
                ? SignatureAndHashAlgorithm.mldsa87
                : SignatureAndHashAlgorithm.mldsa65;
    }

    /**
     * Reads one HTTP response from {@code input}: status line, headers, then body.
     *
     * <p>The body is read as bytes and decoded as UTF-8 afterwards, so a
     * {@code Content-Length} (a byte count) is honoured correctly for
     * multi-byte content. Body framing, in order of precedence:
     * <ol>
     *   <li>{@code Transfer-Encoding} ending in {@code chunked}: chunks are
     *       decoded and trailers discarded;</li>
     *   <li>positive {@code Content-Length}: up to that many bytes are read
     *       (fewer only if the stream ends early);</li>
     *   <li>otherwise: everything until end of stream.</li>
     * </ol>
     * The decoded body is trimmed of leading and trailing whitespace, as before.
     * Line breaks inside a body read until end of stream are now preserved
     * as sent instead of being rewritten to {@code \n}.
     *
     * @param input stream positioned at the start of the response
     * @return the parsed response; header names keep the spelling sent by the
     *         server and a repeated header keeps its last value
     * @throws IOException if the stream is empty, the status line,
     *                     {@code Content-Length} or a chunk size is malformed,
     *                     a chunked body ends prematurely, or reading fails
     */
    private Response parseResponse(InputStream input) throws IOException {
        InputStream in = new BufferedInputStream(input);
        String statusLine = readHttpLine(in);
        if (statusLine == null) {
            throw new IOException("Empty response");
        }
        String[] statusParts = statusLine.split(" ", 3);
        if (statusParts.length < 2) {
            throw new IOException("Malformed status line: " + statusLine);
        }
        int statusCode;
        try {
            statusCode = Integer.parseInt(statusParts[1]);
        }
        catch (NumberFormatException e) {
            throw new IOException("Malformed status line: " + statusLine, e);
        }
        String statusMessage = statusParts.length > 2 ? statusParts[2] : "";

        LinkedHashMap<String, String> headers = new LinkedHashMap<String, String>();
        int contentLength = -1;
        boolean chunked = false;
        String line;
        while ((line = readHttpLine(in)) != null && !line.isEmpty()) {
            int idx = line.indexOf(':');
            if (idx <= 0) {
                // Not a "name: value" line; ignore it.
                continue;
            }
            String key = line.substring(0, idx).trim();
            String value = line.substring(idx + 1).trim();
            headers.put(key, value);
            if (key.equalsIgnoreCase("Content-Length")) {
                try {
                    contentLength = Integer.parseInt(value);
                }
                catch (NumberFormatException e) {
                    throw new IOException("Invalid Content-Length: " + value, e);
                }
            } else if (key.equalsIgnoreCase("Transfer-Encoding")) {
                // "chunked" is always the final coding when it is applied.
                int tail = value.length() - "chunked".length();
                chunked = tail >= 0 && value.regionMatches(true, tail, "chunked", 0, "chunked".length());
            }
        }

        byte[] bodyBytes;
        if (chunked) {
            bodyBytes = readChunkedBody(in);
        } else if (contentLength > 0) {
            // readNBytes loops until the count is reached or the stream ends, and grows
            // its buffer incrementally rather than trusting the declared length up front.
            bodyBytes = in.readNBytes(contentLength);
        } else {
            bodyBytes = in.readAllBytes();
        }
        String body = new String(bodyBytes, StandardCharsets.UTF_8);
        return new Response(statusCode, statusMessage, headers, body.trim());
    }

    /** Reads one line terminated by LF (an optional preceding CR is dropped); returns {@code null} at end of stream. */
    private static String readHttpLine(InputStream in) throws IOException {
        int b = in.read();
        if (b < 0) {
            return null;
        }
        ByteArrayOutputStream line = new ByteArrayOutputStream();
        while (b >= 0 && b != '\n') {
            line.write(b);
            b = in.read();
        }
        byte[] raw = line.toByteArray();
        int length = raw.length;
        if (length > 0 && raw[length - 1] == '\r') {
            --length;
        }
        return new String(raw, 0, length, StandardCharsets.UTF_8);
    }

    /** Decodes a chunked message body, discarding chunk extensions and trailer fields. */
    private static byte[] readChunkedBody(InputStream in) throws IOException {
        ByteArrayOutputStream body = new ByteArrayOutputStream();
        while (true) {
            String sizeLine = readHttpLine(in);
            if (sizeLine == null) {
                throw new IOException("Unexpected end of chunked response body");
            }
            int ext = sizeLine.indexOf(';');
            String hex = (ext >= 0 ? sizeLine.substring(0, ext) : sizeLine).trim();
            int size;
            try {
                size = Integer.parseInt(hex, 16);
            }
            catch (NumberFormatException e) {
                throw new IOException("Invalid chunk size: " + sizeLine, e);
            }
            if (size < 0) {
                throw new IOException("Invalid chunk size: " + sizeLine);
            }
            if (size == 0) {
                // Last chunk: skip the trailer section up to its blank line or end of stream.
                String trailer = readHttpLine(in);
                while (trailer != null && !trailer.isEmpty()) {
                    trailer = readHttpLine(in);
                }
                return body.toByteArray();
            }
            byte[] chunk = in.readNBytes(size);
            if (chunk.length < size) {
                throw new IOException("Unexpected end of chunked response body");
            }
            body.write(chunk, 0, chunk.length);
            // Consume the line break that follows the chunk data.
            readHttpLine(in);
        }
    }

    /**
     * Loads a certificate file and returns its binary encoding.
     *
     * <p>The file is first read with {@code PEMParser}; if the first object it
     * yields is an {@code X509CertificateHolder}, that holder's encoding is
     * returned. Otherwise the file is re-read as text, the
     * {@code CERTIFICATE} PEM markers and all whitespace are removed, and the
     * remainder is Base64-decoded.
     *
     * @param path location of the certificate file
     * @return the certificate bytes
     * @throws Exception if the file cannot be read or its content cannot be decoded
     */
    private static byte[] loadCertBytes(String path) throws Exception {
        try (PEMParser pemReader = new PEMParser(new FileReader(path))) {
            Object parsed = pemReader.readObject();
            if (parsed instanceof X509CertificateHolder) {
                X509CertificateHolder holder = (X509CertificateHolder) parsed;
                return holder.getEncoded();
            }
        }
        // Fallback: strip the PEM armour by hand and decode what is left.
        String stripped = Files.readString(Path.of(path))
            .replace("-----BEGIN CERTIFICATE-----", "")
            .replace("-----END CERTIFICATE-----", "")
            .replaceAll("\\s", "");
        return Base64.getDecoder().decode(stripped);
    }

    /**
     * Loads an ML-DSA private key from a PEM file with {@code PRIVATE KEY} markers.
     *
     * <p>Steps: remove the PEM markers and whitespace, Base64-decode, parse the
     * result as an ASN.1 sequence, take the octet string at index 2, and parse
     * that octet string's content as ASN.1 again. If the inner object is a
     * sequence, the octets of its first element are used; if it is an octet
     * string, its octets are used; any other type is rejected.
     *
     * <p>NOTE(review): the extracted octets are passed to
     * {@code MLDSAPrivateKeyParameters(params, octets, null)} and are treated
     * here as the key seed — confirm this matches the key files in use and the
     * constructor semantics of the Bouncy Castle version on the classpath.
     *
     * @param path   location of the PEM-encoded private key
     * @param params ML-DSA parameter set the key belongs to
     * @return the private key parameters
     * @throws IllegalArgumentException if the inner ASN.1 object is neither a
     *                                  sequence nor an octet string
     * @throws Exception if the file cannot be read or is not structured as expected
     */
    private static AsymmetricKeyParameter loadPrivateKey(String path, MLDSAParameters params) throws Exception {
        String b64 = Files.readString(Path.of(path))
            .replace("-----BEGIN PRIVATE KEY-----", "")
            .replace("-----END PRIVATE KEY-----", "")
            .replaceAll("\\s", "");
        byte[] der = Base64.getDecoder().decode(b64);

        ASN1Sequence outer;
        try (ASN1InputStream derIn = new ASN1InputStream(der)) {
            outer = (ASN1Sequence) derIn.readObject();
        }
        // Index 2 of the outer sequence holds the wrapped key material.
        ASN1OctetString wrapped = (ASN1OctetString) outer.getObjectAt(2);

        ASN1Primitive inner;
        try (ASN1InputStream innerIn = new ASN1InputStream(wrapped.getOctets())) {
            inner = innerIn.readObject();
        }

        byte[] seed;
        if (inner instanceof ASN1OctetString) {
            seed = ((ASN1OctetString) inner).getOctets();
        } else if (inner instanceof ASN1Sequence) {
            ASN1OctetString first = (ASN1OctetString) ((ASN1Sequence) inner).getObjectAt(0);
            seed = first.getOctets();
        } else {
            throw new IllegalArgumentException("Unsupported ML-DSA key format");
        }
        return new MLDSAPrivateKeyParameters(params, seed, null);
    }

    // Registers the Bouncy Castle JCA provider at the highest priority position
    // when no provider named "BC" is installed yet. Runs once, at class load.
    static {
        boolean bcAlreadyInstalled = Security.getProvider("BC") != null;
        if (!bcAlreadyInstalled) {
            Security.insertProviderAt(new BouncyCastleProvider(), 1);
        }
    }

    /**
     * Fluent builder for {@link MLDSATlsClient}. The client certificate, client
     * key and CA certificate paths are mandatory; the parameter set defaults to
     * ML-DSA-65, the connect timeout to 10000 ms and the read timeout to 30000 ms.
     *
     * <p>Invalid settings are rejected by the setter that receives them rather
     * than surfacing later from inside a request.
     */
    public static class Builder {
        private String clientCertPath;
        private String clientKeyPath;
        private String caCertPath;
        private MLDSAParameters mldsaParams = MLDSAParameters.ml_dsa_65;
        private int connectTimeout = 10000;
        private int readTimeout = 30000;

        /**
         * @param path location of the client certificate file
         * @return this builder
         */
        public Builder clientCert(String path) {
            this.clientCertPath = path;
            return this;
        }

        /**
         * @param path location of the client private key file
         * @return this builder
         */
        public Builder clientKey(String path) {
            this.clientKeyPath = path;
            return this;
        }

        /**
         * @param path location of the CA certificate file
         * @return this builder
         */
        public Builder caCert(String path) {
            this.caCertPath = path;
            return this;
        }

        /**
         * Selects the ML-DSA parameter set explicitly.
         *
         * @param params parameter set; must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code params} is {@code null}
         */
        public Builder mldsaParams(MLDSAParameters params) {
            this.mldsaParams = Objects.requireNonNull(params, "mldsaParams must not be null");
            return this;
        }

        /**
         * Selects ML-DSA-44.
         *
         * @return this builder
         */
        public Builder useMlDsa44() {
            this.mldsaParams = MLDSAParameters.ml_dsa_44;
            return this;
        }

        /**
         * Selects ML-DSA-65 (the default).
         *
         * @return this builder
         */
        public Builder useMlDsa65() {
            this.mldsaParams = MLDSAParameters.ml_dsa_65;
            return this;
        }

        /**
         * Selects ML-DSA-87.
         *
         * @return this builder
         */
        public Builder useMlDsa87() {
            this.mldsaParams = MLDSAParameters.ml_dsa_87;
            return this;
        }

        /**
         * Sets the TCP connect timeout. The value is handed to
         * {@code Socket.connect}, which treats zero as "no timeout".
         *
         * @param ms timeout in milliseconds; must not be negative
         * @return this builder
         * @throws IllegalArgumentException if {@code ms} is negative
         */
        public Builder connectTimeout(int ms) {
            if (ms < 0) {
                throw new IllegalArgumentException("connectTimeout must not be negative: " + ms);
            }
            this.connectTimeout = ms;
            return this;
        }

        /**
         * Sets the socket read timeout. The value is handed to
         * {@code Socket.setSoTimeout}, which treats zero as "no timeout".
         *
         * @param ms timeout in milliseconds; must not be negative
         * @return this builder
         * @throws IllegalArgumentException if {@code ms} is negative
         */
        public Builder readTimeout(int ms) {
            if (ms < 0) {
                throw new IllegalArgumentException("readTimeout must not be negative: " + ms);
            }
            this.readTimeout = ms;
            return this;
        }

        /**
         * Creates the client, loading the certificate and key files.
         *
         * @return a configured client
         * @throws NullPointerException if a mandatory path has not been set
         * @throws Exception if a certificate or key file cannot be read or parsed
         */
        public MLDSATlsClient build() throws Exception {
            Objects.requireNonNull(this.clientCertPath, "clientCert is required");
            Objects.requireNonNull(this.clientKeyPath, "clientKey is required");
            Objects.requireNonNull(this.caCertPath, "caCert is required");
            return new MLDSATlsClient(this);
        }
    }

    /**
     * Immutable view of a parsed HTTP response: status code, reason phrase,
     * headers in the order they were received, and the body text.
     */
    public static class Response {
        private final int statusCode;
        private final String statusMessage;
        private final Map<String, String> headers;
        private final String body;

        /**
         * @param statusCode    numeric HTTP status
         * @param statusMessage reason phrase from the status line
         * @param headers       response headers; copied, so later changes to the
         *                      argument do not affect this object; {@code null}
         *                      is treated as no headers
         * @param body          response body text, may be {@code null}
         */
        Response(int statusCode, String statusMessage, Map<String, String> headers, String body) {
            this.statusCode = statusCode;
            this.statusMessage = statusMessage;
            // Snapshot into an order-preserving, read-only map so the response cannot change after construction.
            this.headers = headers == null
                ? Collections.<String, String>emptyMap()
                : Collections.unmodifiableMap(new LinkedHashMap<String, String>(headers));
            this.body = body;
        }

        /** @return the numeric HTTP status */
        public int getStatusCode() {
            return this.statusCode;
        }

        /** @return the reason phrase from the status line */
        public String getStatusMessage() {
            return this.statusMessage;
        }

        /** @return an unmodifiable map of the headers, keyed by the names as received; never {@code null} */
        public Map<String, String> getHeaders() {
            return this.headers;
        }

        /**
         * Looks up a header by name, ignoring letter case.
         *
         * @param name header name to look for
         * @return the value of the first header whose name matches, or
         *         {@code null} if there is none or {@code name} is {@code null}
         */
        public String getHeader(String name) {
            if (name == null) {
                return null;
            }
            for (Map.Entry<String, String> entry : this.headers.entrySet()) {
                if (name.equalsIgnoreCase(entry.getKey())) {
                    return entry.getValue();
                }
            }
            return null;
        }

        /** @return the body text, possibly {@code null} */
        public String getBody() {
            return this.body;
        }

        /** @return {@code true} when the status code is in the range 200-299 */
        public boolean isSuccess() {
            return this.statusCode >= 200 && this.statusCode < 300;
        }

        @Override
        public String toString() {
            return "Response{status=" + this.statusCode + " " + this.statusMessage + ", bodyLength=" + (this.body != null ? this.body.length() : 0) + "}";
        }
    }
}

