# lib/json/jwk.rb



module JSON
  # A JSON Web Key (RFC 7517) stored as a Hash with indifferent (String/Symbol)
  # member access.
  #
  # Instances can be built from an OpenSSL RSA/EC key, from a raw String
  # (stored as a symmetric "oct" key), or from an existing hash of JWK members.
  # When no "kid" is given, it defaults to the key's thumbprint if one can be
  # computed.
  class JWK < ActiveSupport::HashWithIndifferentAccess
    # Raised for unsupported key types, EC curves, or digest algorithms.
    class UnknownAlgorithm < JWT::Exception; end

    # JWK "crv" values mapped to the OpenSSL curve names used to rebuild keys.
    EC_CURVE_NAMES = {
      'P-256': 'prime256v1',
      'P-384': 'secp384r1',
      'P-521': 'secp521r1',
      secp256k1: 'secp256k1'
    }.freeze
    private_constant :EC_CURVE_NAMES

    # @param params [OpenSSL::PKey::RSA, OpenSSL::PKey::EC, String, Hash]
    #   an RSA/EC key to export via +to_jwk+, a raw symmetric key (stored as
    #   "k" with kty "oct"), or a hash of JWK members
    # @param ex_params [Hash] extra members: handed to +to_jwk+ for OpenSSL
    #   keys, otherwise merged on top of the members built from +params+
    # @raise [UnknownAlgorithm] for any other OpenSSL::PKey::PKey subclass
    def initialize(params = {}, ex_params = {})
      case params
      when OpenSSL::PKey::RSA, OpenSSL::PKey::EC
        super params.to_jwk(ex_params)
      when OpenSSL::PKey::PKey
        raise UnknownAlgorithm, 'Unknown Key Type'
      when String
        super(k: params, kty: :oct)
        merge! ex_params
      else
        super params
        merge! ex_params
      end
      calculate_default_kid if self[:kid].blank?
    end

    # @return [String] the media type registered for a single JWK
    def content_type
      'application/jwk+json'
    end

    # Computes the JWK thumbprint (RFC 7638): the digest of the JSON encoding
    # of the required members returned by #normalize, base64url-encoded
    # without padding.
    #
    # @param digest [OpenSSL::Digest, String, Symbol] a digest instance, or an
    #   algorithm name such as 'SHA256'
    # @return [String]
    # @raise [UnknownAlgorithm] if +digest+ is of any other type, or if the
    #   key type is not RSA, EC or oct
    def thumbprint(digest = OpenSSL::Digest::SHA256.new)
      digest = case digest
               when OpenSSL::Digest
                 digest
               when String, Symbol
                 OpenSSL::Digest.new(digest.to_s)
               else
                 raise UnknownAlgorithm, 'Unknown Digest Algorithm'
               end
      Base64.urlsafe_encode64 digest.digest(normalize.to_json), padding: false
    end

    # @return [OpenSSL::PKey::RSA, OpenSSL::PKey::EC, Object] an OpenSSL key for
    #   RSA/EC JWKs, or the stored "k" value unchanged for "oct" JWKs
    # @raise [UnknownAlgorithm] for any other "kty"
    def to_key
      case
      when rsa?
        to_rsa_key
      when ec?
        to_ec_key
      when oct?
        self[:k]
      else
        raise UnknownAlgorithm, 'Unknown Key Type'
      end
    end

    # @return [Boolean] true when "kty" is RSA
    def rsa?
      self[:kty]&.to_sym == :RSA
    end

    # @return [Boolean] true when "kty" is EC
    def ec?
      self[:kty]&.to_sym == :EC
    end

    # @return [Boolean] true when "kty" is oct (symmetric key)
    def oct?
      self[:kty]&.to_sym == :oct
    end

    # Returns only the members RFC 7638 requires for the thumbprint of this key
    # type. Members are listed in lexicographic order, which the thumbprint's
    # JSON input depends on.
    #
    # @return [Hash]
    # @raise [UnknownAlgorithm] when "kty" is not RSA, EC or oct
    def normalize
      case
      when rsa?
        {
          e:   self[:e],
          kty: self[:kty],
          n:   self[:n]
        }
      when ec?
        {
          crv: self[:crv],
          kty: self[:kty],
          x:   self[:x],
          y:   self[:y]
        }
      when oct?
        {
          k:   self[:k],
          kty: self[:kty]
        }
      else
        raise UnknownAlgorithm, 'Unknown Key Type'
      end
    end

    private

    # Defaults "kid" to the thumbprint. This is best effort on purpose: any
    # StandardError (e.g. an unsupported "kty") just leaves "kid" unset.
    def calculate_default_kid
      self[:kid] = thumbprint
    rescue StandardError
      # deliberately ignored
    end

    # Rebuilds an OpenSSL RSA key from PKCS#1 DER (RFC 8017 appendix A.1).
    #
    # A private key is produced only when d, p, q, dp, dq and qi are all
    # present. Otherwise only the public part (n, e) is encoded.
    # NOTE(review): a JWK carrying "d" without the CRT parameters is therefore
    # loaded as a public key only; confirm callers expect that.
    def to_rsa_key
      e, n, d, p, q, dp, dq, qi = %i[e n d p q dp dq qi].map do |key|
        OpenSSL::BN.new(Base64.urlsafe_decode64(self[key]), 2) if self[key]
      end

      components = if d && p && q && dp && dq && qi
                     # RSAPrivateKey; version 0 means two-prime.
                     [0, n, e, d, p, q, dp, dq, qi]
                   else
                     # RSAPublicKey
                     [n, e]
                   end

      # Encode the sequence directly. Wrapping it in another Sequence only
      # worked because Constructive#to_der coerces its value via #to_a.
      der = OpenSSL::ASN1::Sequence(components.map { |c| OpenSSL::ASN1::Integer(c) }).to_der
      OpenSSL::PKey::RSA.new(der)
    end

    # Rebuilds an OpenSSL EC key. It encodes a SubjectPublicKeyInfo for public
    # JWKs, or an ECPrivateKey (RFC 5915) when "d" is present.
    #
    # @raise [UnknownAlgorithm] when "crv" is not a supported curve
    def to_ec_key
      curve_name = EC_CURVE_NAMES.fetch(self[:crv]&.to_sym) do
        raise UnknownAlgorithm, 'Unknown EC Curve'
      end
      x, y, d = %i[x y d].map do |key|
        Base64.urlsafe_decode64(self[key]) if self[key]
      end

      # Uncompressed point encoding: 0x04 || X || Y.
      point = OpenSSL::PKey::EC::Point.new(
        OpenSSL::PKey::EC::Group.new(curve_name),
        OpenSSL::BN.new("\x04".b + x + y, 2)
      )
      public_point = point.to_octet_string(:uncompressed)

      asn1 = if d
               # ECPrivateKey: version 1, the private scalar, [0] curve OID,
               # [1] public point. The scalar is used as decoded; a BN round
               # trip would strip leading zero bytes, and RFC 5915 asks for a
               # fixed-length octet string.
               OpenSSL::ASN1::Sequence([
                 OpenSSL::ASN1::Integer(1),
                 OpenSSL::ASN1::OctetString(d),
                 OpenSSL::ASN1::ObjectId(curve_name, 0, :EXPLICIT),
                 OpenSSL::ASN1::BitString(public_point, 1, :EXPLICIT)
               ])
             else
               # SubjectPublicKeyInfo with id-ecPublicKey and a named curve.
               OpenSSL::ASN1::Sequence([
                 OpenSSL::ASN1::Sequence([
                   OpenSSL::ASN1::ObjectId('id-ecPublicKey'),
                   OpenSSL::ASN1::ObjectId(curve_name)
                 ]),
                 OpenSSL::ASN1::BitString(public_point)
               ])
             end

      OpenSSL::PKey::EC.new(asn1.to_der)
    end
  end
end