class JSON::JWK

def calculate_default_kid

# Assigns the value returned by `thumbprint` to the :kid member.
#
# This is best effort: any StandardError raised along the way is swallowed
# on purpose, and the method then returns nil instead of the thumbprint.
def calculate_default_kid
  begin
    self[:kid] = thumbprint
  rescue StandardError
    # ignore
  end
end

def content_type

# Media type string used for this key representation.
#
# @return [String] always 'application/jwk+json'
def content_type
  media_type = 'application/jwk+json'
  media_type
end

def ec?

# Tells whether the :kty member, converted to a Symbol, equals :EC.
#
# @return [Boolean] false when :kty is nil or names another key type
def ec?
  key_type = self[:kty]
  !key_type.nil? && key_type.to_sym == :EC
end

def initialize(params = {}, ex_params = {})

# Builds the key from one of several input shapes.
#
# * OpenSSL RSA / EC key: initialised from `params.to_jwk(ex_params)`
#   (`to_jwk` is defined outside this view).
# * Any other OpenSSL::PKey::PKey: rejected with UnknownAlgorithm.
# * String: treated as the `k` member of a key with kty :oct, then
#   `ex_params` is merged in.
# * Anything else: passed to the parent initialiser as-is, then
#   `ex_params` is merged in.
#
# Finally a default :kid is calculated when none is present.
# NOTE(review): `blank?` is not core Ruby -- presumably provided by
# ActiveSupport; confirm against the file's requires.
#
# @param params [OpenSSL::PKey::PKey, String, Hash] key material
# @param ex_params [Hash] extra members to attach to the key
# @raise [UnknownAlgorithm] for OpenSSL key types other than RSA and EC
def initialize(params = {}, ex_params = {})
  if params.is_a?(OpenSSL::PKey::RSA) || params.is_a?(OpenSSL::PKey::EC)
    super(params.to_jwk(ex_params))
  elsif params.is_a?(OpenSSL::PKey::PKey)
    # RSA and EC were handled above, so this is an unsupported key class.
    raise UnknownAlgorithm.new('Unknown Key Type')
  elsif params.is_a?(String)
    super(k: params, kty: :oct)
    merge!(ex_params)
  else
    super(params)
    merge!(ex_params)
  end

  calculate_default_kid if self[:kid].blank?
end

def normalize

# Reduces the key to the members selected for its key type, in a fixed
# order:
#
# * RSA: e, kty, n
# * EC:  crv, kty, x, y
# * oct: k, kty
#
# Members absent from the key appear with a nil value.
#
# @return [Hash{Symbol => Object}] the selected members
# @raise [UnknownAlgorithm] when the key is neither RSA, EC nor oct
def normalize
  member_names =
    if rsa?
      [:e, :kty, :n]
    elsif ec?
      [:crv, :kty, :x, :y]
    elsif oct?
      [:k, :kty]
    else
      raise UnknownAlgorithm.new('Unknown Key Type')
    end

  # Insertion order of the result follows the order of member_names.
  member_names.each_with_object({}) do |name, members|
    members[name] = self[name]
  end
end

def oct?

# Tells whether the :kty member, converted to a Symbol, equals :oct.
#
# @return [Boolean] false when :kty is nil or names another key type
def oct?
  key_type = self[:kty]
  !key_type.nil? && key_type.to_sym == :oct
end

def rsa?

# Tells whether the :kty member, converted to a Symbol, equals :RSA.
#
# @return [Boolean] false when :kty is nil or names another key type
def rsa?
  key_type = self[:kty]
  !key_type.nil? && key_type.to_sym == :RSA
end

def thumbprint(digest = OpenSSL::Digest::SHA256.new)

# Computes a digest over the JSON serialisation of `normalize` and returns
# it URL-safe Base64 encoded without padding.
#
# @param digest [OpenSSL::Digest, String, Symbol] a digest instance, or the
#   name of an algorithm handed to OpenSSL::Digest.new; defaults to a fresh
#   SHA-256 digest
# @return [String] unpadded URL-safe Base64 text
# @raise [UnknownAlgorithm] when digest is none of the accepted types
def thumbprint(digest = OpenSSL::Digest::SHA256.new)
  hasher =
    if digest.is_a?(OpenSSL::Digest)
      digest
    elsif digest.is_a?(String) || digest.is_a?(Symbol)
      OpenSSL::Digest.new(digest.to_s)
    else
      raise UnknownAlgorithm.new('Unknown Digest Algorithm')
    end

  serialized = normalize.to_json
  Base64.urlsafe_encode64(hasher.digest(serialized), padding: false)
end

def to_ec_key

# Converts this key into an OpenSSL::PKey::EC instance.
#
# The :crv member selects the OpenSSL curve name; :x and :y (URL-safe
# Base64) give the public point, and :d, when present, the private value.
# With :d the result is built from a private-key structure, otherwise from
# a public-key structure.
#
# @return [OpenSSL::PKey::EC]
# @raise [UnknownAlgorithm] when :crv is missing or not a supported curve
def to_ec_key
  openssl_curve_names = {
    :'P-256'   => 'prime256v1',
    :'P-384'   => 'secp384r1',
    :'P-521'   => 'secp521r1',
    :secp256k1 => 'secp256k1'
  }
  curve_name = openssl_curve_names.fetch(self[:crv]&.to_sym) do
    raise UnknownAlgorithm.new('Unknown EC Curve')
  end

  x, y, d = [:x, :y, :d].map do |member|
    Base64.urlsafe_decode64(self[member]) if self[member]
  end

  group = OpenSSL::PKey::EC::Group.new(curve_name)
  # 0x04 followed by the x and y bytes, assembled via hex
  # (presumably the uncompressed point encoding -- the point is serialised
  # back with :uncompressed below).
  to_hex = ->(bytes) { bytes.unpack('H*').first }
  encoded_point = ['04' + to_hex.call(x) + to_hex.call(y)].pack('H*')
  point = OpenSSL::PKey::EC::Point.new(group, OpenSSL::BN.new(encoded_point, 2))
  public_octets = point.to_octet_string(:uncompressed)

  structure =
    if d
      # Private key: version 1, private value, then the curve and the public
      # point as explicitly tagged [0] and [1] elements.
      OpenSSL::ASN1::Sequence([
        OpenSSL::ASN1::Integer(1),
        OpenSSL::ASN1::OctetString(OpenSSL::BN.new(d, 2).to_s(2)),
        OpenSSL::ASN1::ObjectId(curve_name, 0, :EXPLICIT),
        OpenSSL::ASN1::BitString(public_octets, 1, :EXPLICIT)
      ])
    else
      # Public key: algorithm identifier (id-ecPublicKey + curve) and point.
      OpenSSL::ASN1::Sequence([
        OpenSSL::ASN1::Sequence([
          OpenSSL::ASN1::ObjectId("id-ecPublicKey"),
          OpenSSL::ASN1::ObjectId(curve_name)
        ]),
        OpenSSL::ASN1::BitString(public_octets)
      ])
    end

  OpenSSL::PKey::EC.new(structure.to_der)
end

def to_key

# Returns the key in the form matching its type: the result of
# `to_rsa_key` for RSA keys, of `to_ec_key` for EC keys, and the raw :k
# member for oct keys.
#
# @return [Object] the converted key, or the stored :k value
# @raise [UnknownAlgorithm] when the key is neither RSA, EC nor oct
def to_key
  return to_rsa_key if rsa?
  return to_ec_key if ec?
  return self[:k] if oct?

  raise UnknownAlgorithm.new('Unknown Key Type')
end

def to_rsa_key

# Converts this key into an OpenSSL::PKey::RSA instance.
#
# Each of :e, :n, :d, :p, :q, :dp, :dq and :qi that is present is decoded
# from URL-safe Base64 into a big number. A private key is produced only
# when all six of d, p, q, dp, dq and qi are present; otherwise the result
# is built from n and e alone.
#
# @return [OpenSSL::PKey::RSA]
def to_rsa_key
  member_names = [:e, :n, :d, :p, :q, :dp, :dq, :qi]
  e, n, d, p, q, dp, dq, qi = member_names.map do |name|
    OpenSSL::BN.new(Base64.urlsafe_decode64(self[name]), 2) if self[name]
  end

  private_members = [d, p, q, dp, dq, qi]
  components =
    if private_members.all?
      # Private key: leading 0, then n, e and the six private members.
      [0, n, e, *private_members]
    else
      # Public key
      [n, e]
    end
  data_sequence = OpenSSL::ASN1::Sequence(
    components.map { |component| OpenSSL::ASN1::Integer(component) }
  )

  # The sequence is wrapped once more before encoding.
  # NOTE(review): presumably the outer Sequence just enumerates the inner
  # one's elements, making this equivalent to encoding data_sequence
  # directly -- confirm before simplifying.
  asn1 = OpenSSL::ASN1::Sequence(data_sequence)
  OpenSSL::PKey::RSA.new(asn1.to_der)
end