module JSON::JWK::Set::Fetcher

def self.fetch(jwks_uri, kid:, auto_detect: true, **options)

# Fetches the JWK Set document at +jwks_uri+ (going through +cache+) and
# looks up the key identified by +kid+.
#
# The raw response body is stored under a cache key built from the MD5 hex
# digest of the URI and +kid+. The entry is deleted again whenever it turns
# out to be unusable: the body is not valid JSON, the parsed document is not
# a Hash holding a 'keys' Array, or the set has no entry for +kid+ --
# presumably so that a later call requests the document again instead of
# being served the same body.
#
# @param jwks_uri [String] location of the JWK Set document
# @param kid [Object] key ID to look up; stringified for the cache key
# @param auto_detect [Boolean] when true return only the entry for +kid+,
#   otherwise return the whole set
# @param options [Hash] forwarded unchanged to +cache.fetch+ and +cache.delete+
# @return [Object] the entry for +kid+ when +auto_detect+ is true, else the set
# @raise [JSON::ParserError] when the body is not valid JSON
# @raise [UnexpectedFormat] when the parsed document is not a Hash with a
#   'keys' Array
# @raise [KidNotFound] when +auto_detect+ is true and no entry matches +kid+
def self.fetch(jwks_uri, kid:, auto_detect: true, **options)
  # MD5 is used only to turn the URI into a fixed-length cache key segment.
  cache_key = [
    'json:jwk:set',
    OpenSSL::Digest::MD5.hexdigest(jwks_uri),
    kid
  ].map(&:to_s).join(':')

  raw_jwks = cache.fetch(cache_key, options) do
    http_client.get(jwks_uri).body
  end

  parsed_jwks =
    begin
      JSON.parse(raw_jwks)
    rescue JSON::ParserError
      # The unparsable body has already been handed to the cache; drop it,
      # matching the eviction done for the other failure paths below, and
      # let the original error reach the caller unchanged.
      cache.delete(cache_key, options)
      raise
    end

  unless parsed_jwks.is_a?(Hash) && parsed_jwks['keys'].is_a?(Array)
    cache.delete(cache_key, options)
    raise UnexpectedFormat
  end

  jwks = Set.new(parsed_jwks)
  jwk = jwks[kid]
  # A set without the requested kid is evicted even when the whole set is
  # returned to the caller (auto_detect: false).
  cache.delete(cache_key, options) if jwk.blank?
  return jwks unless auto_detect

  jwk || raise(KidNotFound)
end