# lib/common/auth/auth_token_provider.rb



# frozen_string_literal: true

require "async"
require "async/actor"
require "async/http/internet"
require "async/semaphore"
require "json"
require "logger"
require "grpc"
require_relative "token_provider"
require_relative "token_fetcher"
require_relative "../../error"

# Shared logger writing to stdout; only WARN and more severe messages are emitted.
LOGGER = Logger.new($stdout, level: Logger::WARN)

# Top-level namespace for the StatelyDB client library
module StatelyDB
  module Common
    # A module for Stately Cloud auth code
    module Auth
      # AuthTokenProvider is an implementation of the TokenProvider abstract base class
      # which vends tokens from the StatelyDB auth API.
      # It will default to using the value of `STATELY_ACCESS_KEY` if
      # no credentials are explicitly passed and will raise an error if no credentials are found.
      #
      # All token state lives in an {Actor} wrapped by Async::Actor; every public
      # method here simply forwards to that actor.
      class AuthTokenProvider < TokenProvider
        # @param [String] endpoint The endpoint of the auth server
        # @param [String] access_key The StatelyDB access key credential
        # @param [Float] base_retry_backoff_secs The base retry backoff in seconds
        # @raise [StatelyDB::Error] if access_key is nil (not passed and STATELY_ACCESS_KEY unset)
        def initialize(
          endpoint: "https://api.stately.cloud",
          access_key: ENV.fetch("STATELY_ACCESS_KEY", nil),
          base_retry_backoff_secs: 1
        )
          super()
          @actor = Async::Actor.new(Actor.new(endpoint:, access_key:, base_retry_backoff_secs:))
          # The first token fetch is async and must run on the actor's event loop, which is
          # not available inside Actor#initialize, so it is kicked off via a separate init call.
          @actor.init
        end

        # Close the token provider and kill any background operations
        # This just invokes the close method on the actor which should do the cleanup
        # @return [void]
        def close
          @actor.close
        end

        # Get the current access token
        # @param [Boolean] force Whether to discard any cached token and fetch a new one
        # @return [String] The current access token
        def get_token(force: false)
          @actor.get_token(force: force)
        end

        # Actor for managing the token refresh
        # This is designed to be used with Async::Actor and run on a dedicated thread.
        class Actor
          # @param [String] endpoint The endpoint of the OAuth server
          # @param [String] access_key The StatelyDB access key credential
          # @param [Float] base_retry_backoff_secs The base retry backoff in seconds
          # @raise [StatelyDB::Error] if access_key is nil
          def initialize(endpoint:, access_key:, base_retry_backoff_secs:)
            super()

            if access_key.nil?
              raise StatelyDB::Error.new(
                "Unable to find an access key in the STATELY_ACCESS_KEY " \
                "environment variable. Either pass your credentials in " \
                "the options when creating a client or set this environment variable.",
                code: GRPC::Core::StatusCodes::UNAUTHENTICATED,
                stately_code: "Unauthenticated"
              )
            end

            @token_fetcher = StatelyDB::Common::Auth::StatelyAccessTokenFetcher.new(
              endpoint: endpoint,
              access_key: access_key,
              base_retry_backoff_secs: base_retry_backoff_secs
            )
            @token_state = nil
            # Async::Condition shared by concurrent callers while a refresh is in flight
            @pending_refresh = nil
            # Handle of the background task that will proactively refresh the token
            @scheduled = nil
          end

          # Initialize the actor. This runs on the actor thread which means
          # we can dispatch async operations here.
          # @return [void]
          def init
            # disable the async lib logger. We do our own error handling and propagation
            Console.logger.disable(Async::Task)
            refresh_token
          end

          # Close the token provider and kill any background operations
          # @return [void]
          def close
            @scheduled&.stop
            # drop the stopped handle so nothing mistakes it for a live scheduled refresh
            @scheduled = nil
            @token_fetcher&.close
          end

          # Get the current access token
          # @param [Boolean] force Whether to force a refresh of the token
          # @return [String] The current access token
          # @raise [StandardError] whatever the token fetch raised, if a refresh was needed and failed
          def get_token(force: false)
            if force
              @token_state = nil
            else
              token, ok = valid_access_token
              return token if ok
            end

            refresh_token.wait
          end

          # Get the current access token and whether it is valid
          # @return [Array] The current access token and whether it is valid
          def valid_access_token
            return "", false if @token_state.nil?
            # Once the expiry second has been reached the token is treated as expired,
            # so we never hand out a token that is expiring right now.
            return "", false if @token_state.expires_at_unix_secs <= Time.now.to_i

            [@token_state.token, true]
          end

          # Refresh the access token.
          # Concurrent calls are deduplicated: only the first caller performs the fetch,
          # later callers wait for and share its result (or its error).
          # @return [::Async::Task] A task that will resolve to the new access token
          def refresh_token
            Async do
              if @pending_refresh.nil?
                # Capture the condition locally so the signal below always targets the
                # condition this task created, independent of the ivar's later value.
                condition = Async::Condition.new
                @pending_refresh = condition
                begin
                  new_access_token = refresh_token_impl
                  # now broadcast the new token to any waiters
                  condition.signal(new_access_token)
                  new_access_token
                rescue StandardError => e
                  # hand the error to waiters so they can raise it too
                  condition.signal(e)
                  raise
                ensure
                  # delete the condition to restart the process
                  @pending_refresh = nil
                end
              else
                res = @pending_refresh.wait
                # if the refresh result is an error, re-raise it.
                # otherwise return the token
                raise res if res.is_a?(StandardError)

                res
              end
            end
          end

          # Refresh the access token implementation
          # @return [String] The new access token
          def refresh_token_impl
            Sync do
              token_result = @token_fetcher.fetch
              new_expires_in_secs = token_result.expires_in_secs
              new_expires_at_unix_secs = Time.now.to_i + new_expires_in_secs

              # only update the token state if the new expiry is later than the current one
              if @token_state.nil? || new_expires_at_unix_secs > @token_state.expires_at_unix_secs
                @token_state = TokenState.new(token: token_result.token, expires_at_unix_secs: new_expires_at_unix_secs)
              else
                # otherwise use the existing expiry time for scheduling the refresh
                new_expires_in_secs = @token_state.expires_at_unix_secs - Time.now.to_i
              end

              schedule_refresh(new_expires_in_secs)

              @token_state.token
            end
          end

          private

          # Start a background task that refreshes the token ahead of expiry,
          # unless one is already scheduled.
          # @param [Numeric] expires_in_secs Seconds until the current token expires
          # @return [void]
          def schedule_refresh(expires_in_secs)
            return if @scheduled

            # Calculate a random multiplier between 0.9 and 0.95 to apply to the expiry
            # so that we refresh in the background ahead of expiration, but avoid
            # multiple processes hammering the service at the same time.
            jitter = (Random.rand * 0.05) + 0.9
            delay_secs = expires_in_secs * jitter

            # do this on the fiber scheduler (the root scheduler) to avoid infinite recursion
            @scheduled = Fiber.scheduler.async do
              # Kernel.sleep is non-blocking if Ruby 3.1+ and Async 2+
              # https://github.com/socketry/async/issues/305#issuecomment-1945188193
              sleep(delay_secs)
              # Clear the handle BEFORE refreshing: the refresh calls schedule_refresh for its
              # successor, which must not see this finishing task as an existing schedule.
              @scheduled = nil
              refresh_token
            end
          end
        end

        # Persistent state for the token provider
        #
        # @!attribute [r] token
        #   @return [String] The token string.
        # @!attribute [r] expires_at_unix_secs
        #   @return [Integer] The expiration time in unix seconds.
        class TokenState
          attr_reader :token, :expires_at_unix_secs

          # Create a new TokenState
          # @param [String] token The access token
          # @param [Integer] expires_at_unix_secs The unix timestamp when the token expires
          def initialize(token:, expires_at_unix_secs:)
            @token = token
            @expires_at_unix_secs = expires_at_unix_secs
          end
        end
      end
    end
  end
end