import json
import base64
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa, padding

# This function decodes a JWT and verifies its RS256 signature. If `key` (a JWK)
# is provided, it is used for signature verification; otherwise the JWK carried
# in the token's own payload (its "key" claim) is used instead.
# It returns a tuple of (decoded_header, decoded_payload, verify_succeeded).
def decode_jwt(token, key=None):
    """Decode a compact-serialized JWT and verify its RS256 signature.

    Args:
        token: The JWT string ``header.payload.signature``, each segment
            base64url-encoded.
        key: Optional JWK (a dict, or a JSON string of one) used to verify the
            signature (the refresh flow). When ``None``, the JWK embedded in
            the payload's ``"key"`` claim is used instead (the registration
            flow).

    Returns:
        A tuple ``(decoded_header, decoded_payload, verify_succeeded)``. On any
        failure -- malformed token, a header or payload that is not a non-empty
        JSON object, a missing or unusable key, or a bad signature -- returns
        ``(None, None, False)``.

    Note:
        When the key is taken from the payload itself, a successful
        verification only proves the sender holds the matching private key;
        it does not by itself establish who that sender is.
    """
    try:
        header, payload, signature = token.split('.')
        decoded_header = decode_base64_json(header)
        decoded_payload = decode_base64_json(payload)

        # Both the JOSE header and the claims set must be (non-empty) JSON
        # objects. Checking only truthiness would let e.g. a JSON array through
        # and hand callers a list where they expect a dict.
        if not (isinstance(decoded_header, dict) and decoded_header):
            return None, None, False
        if not (isinstance(decoded_payload, dict) and decoded_payload):
            return None, None, False

        # Refresh: verify with the caller-supplied key.
        # Registration: verify with the key sent inside the JWT itself.
        if key is None:
            key = decoded_payload.get('key')
            if key is None:
                return None, None, False

        public_key = serialization.load_pem_public_key(jwk_to_pem(key))
        # Raises if the signature does not match the header/payload.
        verify_rs256_signature(header, payload, signature, public_key)
        return decoded_header, decoded_payload, True
    except Exception:
        # Deliberate catch-all: any decoding, key-loading, or verification
        # error is reported to the caller as a failed verification.
        return None, None, False

def jwk_to_pem(jwk_data):
    """Convert an RSA JSON Web Key into a PEM-encoded public key.

    Args:
        jwk_data: The JWK, either already parsed (a mapping) or as a JSON
            string.

    Returns:
        The public key serialized as PEM ``SubjectPublicKeyInfo`` bytes.

    Raises:
        ValueError: If the JWK's ``kty`` member is not ``"RSA"``.
        KeyError: If the ``n`` or ``e`` member is missing.
    """
    jwk = jwk_data
    if isinstance(jwk, str):
        jwk = json.loads(jwk)

    kty = jwk.get("kty")
    if kty != "RSA":
        raise ValueError(f"Unsupported key type: {kty}")

    # Modulus and exponent are big-endian base64url-encoded integers.
    modulus = int.from_bytes(decode_base64url(jwk["n"]), 'big')
    exponent = int.from_bytes(decode_base64url(jwk["e"]), 'big')

    # Note the argument order: RSAPublicNumbers takes (e, n).
    rsa_key = rsa.RSAPublicNumbers(exponent, modulus).public_key()
    return rsa_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

def verify_rs256_signature(encoded_header, encoded_payload, signature, public_key):
    """Verify an RS256 (RSASSA-PKCS1-v1_5 with SHA-256) JWT signature.

    The signed message is the UTF-8 encoding of the still-encoded segments
    joined as ``"<header>.<payload>"``.

    Args:
        encoded_header: The base64url header segment, exactly as received.
        encoded_payload: The base64url payload segment, exactly as received.
        signature: The base64url signature segment.
        public_key: An RSA public key object exposing ``verify``.

    Returns:
        ``None``. Success is signalled by not raising.

    Raises:
        Whatever ``public_key.verify`` raises when the signature does not
        match, plus any error from decoding the signature segment.
    """
    signing_input = f"{encoded_header}.{encoded_payload}".encode('utf-8')
    raw_signature = decode_base64(signature)
    public_key.verify(raw_signature, signing_input, padding.PKCS1v15(), hashes.SHA256())

def add_base64_padding(encoded_data):
    """Right-pad *encoded_data* with ``=`` so its length is a multiple of 4.

    JWT segments are typically sent without base64 padding, but Python's
    ``base64`` decoders require it. Input that is already aligned is
    returned unchanged.
    """
    missing = -len(encoded_data) % 4
    if missing:
        encoded_data += '=' * missing
    return encoded_data

def decode_base64url(encoded_data):
    """Decode base64url text, tolerating missing ``=`` padding.

    Args:
        encoded_data: Text in the URL-safe base64 alphabet (``-`` and ``_``),
            padded or unpadded.

    Returns:
        The decoded ``bytes``.

    Raises:
        ValueError (including ``binascii.Error``): If the input is not valid
        base64 once padded.
    """
    # urlsafe_b64decode already maps "-"/"_" to "+"/"/", so there is no need
    # to hand-translate the alphabet. This also matches decode_base64.
    return base64.urlsafe_b64decode(add_base64_padding(encoded_data))

def decode_base64(encoded_data):
    """Decode base64url text that may lack trailing ``=`` padding.

    Despite the name, this decodes the URL-safe alphabet (``-`` and ``_``),
    as used by JWT segments. Returns the decoded ``bytes``.
    """
    return base64.urlsafe_b64decode(add_base64_padding(encoded_data))

def decode_base64_json(encoded_data):
    """Decode a base64url segment and parse the resulting bytes as JSON.

    Returns whatever JSON value the segment contains; it is not guaranteed to
    be a dict. Errors from base64 decoding or JSON parsing propagate to the
    caller.
    """
    raw_bytes = decode_base64(encoded_data)
    parsed = json.loads(raw_bytes)
    return parsed
