lib/tapioca/dsl/compilers/active_record_associations.rb



# typed: strict
# frozen_string_literal: true

return unless defined?(ActiveRecord::Base)

require "tapioca/dsl/helpers/active_record_constants_helper"

module Tapioca
  module Dsl
    module Compilers
      # `Tapioca::Dsl::Compilers::ActiveRecordAssociations` refines RBI files for subclasses of
      # [`ActiveRecord::Base`](https://api.rubyonrails.org/classes/ActiveRecord/Base.html).
      # This compiler is only responsible for defining the methods that would be created for the associations that
      # are defined in the Active Record model.
      #
      # For example, with the following model class:
      #
      # ~~~rb
      # class Post < ActiveRecord::Base
      #   belongs_to :category
      #   has_many :comments
      #   has_one :author, class_name: "User"
      #
      #   accepts_nested_attributes_for :category, :comments, :author
      # end
      # ~~~
      #
      # this compiler will produce the following methods in the RBI file
      # `post.rbi`:
      #
      # ~~~rbi
      # # post.rbi
      # # typed: true
      #
      # class Post
      #   include Post::GeneratedAssociationMethods
      #
      #   module Post::GeneratedAssociationMethods
      #     sig { returns(T.nilable(::User)) }
      #     def author; end
      #
      #     sig { params(value: T.nilable(::User)).void }
      #     def author=(value); end
      #
      #     sig { params(attributes: T.untyped).returns(T.untyped) }
      #     def author_attributes=(attributes); end
      #
      #     sig { params(args: T.untyped, blk: T.untyped).returns(::User) }
      #     def build_author(*args, &blk); end
      #
      #     sig { params(args: T.untyped, blk: T.untyped).returns(::Category) }
      #     def build_category(*args, &blk); end
      #
      #     sig { returns(T.nilable(::Category)) }
      #     def category; end
      #
      #     sig { params(value: T.nilable(::Category)).void }
      #     def category=(value); end
      #
      #     sig { params(attributes: T.untyped).returns(T.untyped) }
      #     def category_attributes=(attributes); end
      #
      #     sig { returns(T::Array[T.untyped]) }
      #     def comment_ids; end
      #
      #     sig { params(ids: T::Array[T.untyped]).returns(T::Array[T.untyped]) }
      #     def comment_ids=(ids); end
      #
      #     sig { returns(::ActiveRecord::Associations::CollectionProxy[::Comment]) }
      #     def comments; end
      #
      #     sig { params(value: T::Enumerable[::Comment]).void }
      #     def comments=(value); end
      #
      #     sig { params(attributes: T.untyped).returns(T.untyped) }
      #     def comments_attributes=(attributes); end
      #
      #     sig { params(args: T.untyped, blk: T.untyped).returns(::User) }
      #     def create_author(*args, &blk); end
      #
      #     sig { params(args: T.untyped, blk: T.untyped).returns(::User) }
      #     def create_author!(*args, &blk); end
      #
      #     sig { params(args: T.untyped, blk: T.untyped).returns(::Category) }
      #     def create_category(*args, &blk); end
      #
      #     sig { params(args: T.untyped, blk: T.untyped).returns(::Category) }
      #     def create_category!(*args, &blk); end
      #
      #     sig { returns(T.nilable(::User)) }
      #     def reload_author; end
      #
      #     sig { returns(T.nilable(::Category)) }
      #     def reload_category; end
      #
      #     sig { void }
      #     def reset_author; end
      #
      #     sig { void }
      #     def reset_category; end
      #   end
      # end
      # ~~~
      #: [ConstantType = singleton(ActiveRecord::Base)]
      class ActiveRecordAssociations < Compiler
        extend T::Sig
        include Helpers::ActiveRecordConstantsHelper

        # Raised by `validate_reflection!` when a through reflection has no source reflection.
        class SourceReflectionError < StandardError
        end

        # Raised by `validate_reflection!` when the target class of a reflection cannot be resolved;
        # carries the unresolved class name so the error report can mention it.
        class MissingConstantError < StandardError
          extend T::Sig

          # The class name that failed to resolve (also used as the exception message).
          #: String
          attr_reader :class_name

          #: (String class_name) -> void
          def initialize(class_name)
            super(class_name)
            @class_name = class_name
          end
        end

        # Generates the association methods module for the model and includes it into the model.
        # Nothing is emitted for models that declare no reflections.
        # @override
        #: -> void
        def decorate
          return if constant.reflections.empty?

          root.create_path(constant) do |model_scope|
            model_scope.create_module(AssociationMethodsModuleName) do |methods_module|
              populate_nested_attribute_writers(methods_module)
              populate_associations(methods_module)
            end

            model_scope.create_include(AssociationMethodsModuleName)
          end
        end

        class << self
          extend T::Sig

          # Every descendant of `ActiveRecord::Base` that is not marked as an abstract class.
          # @override
          #: -> T::Enumerable[Module]
          def gather_constants
            descendants_of(::ActiveRecord::Base).reject { |model| model.abstract_class? }
          end
        end

        private

        # Declares an untyped `<name>_attributes=` writer for every key of the model's
        # `nested_attributes_options`.
        #: (RBI::Scope mod) -> void
        def populate_nested_attribute_writers(mod)
          constant.nested_attributes_options.each_key do |name|
            mod.create_method(
              "#{name}_attributes=",
              parameters: [create_param("attributes", type: "T.untyped")],
              return_type: "T.untyped",
            )
          end
        end

        # Walks every reflection of the model and emits its methods; a reflection that cannot be
        # resolved is reported through `add_error` and the walk moves on to the next one.
        #: (RBI::Scope mod) -> void
        def populate_associations(mod)
          constant.reflections.each do |association_name, reflection|
            populate_association(mod, association_name, reflection)
          end
        end

        # Emits the methods for a single reflection, dispatching on whether it is a collection, and
        # converts validation failures into user-facing errors.
        #: (RBI::Scope mod, (String | Symbol) association_name, ReflectionType reflection) -> void
        def populate_association(mod, association_name, reflection)
          if reflection.collection?
            populate_collection_assoc_getter_setter(mod, association_name, reflection)
          else
            populate_single_assoc_getter_setter(mod, association_name, reflection)
          end
        rescue SourceReflectionError
          add_error(<<~MSG.strip)
            Cannot generate association `#{reflection.name}` on `#{constant}` since the source of the through association is missing.
          MSG
        rescue MissingConstantError => error
          add_error(<<~MSG.strip)
            Cannot generate association `#{declaration(reflection)}` on `#{constant}` since the constant `#{error.class_name}` does not exist.
          MSG
        end

        # Emits reader, writer, reload_ and reset_ methods for a singular association, the
        # `_changed?`/`_previously_changed?` predicates when the model's generated association methods
        # module defines them, and (for non-polymorphic reflections) the build_/create_/create_! constructors.
        #: (RBI::Scope klass, (String | Symbol) association_name, ReflectionType reflection) -> void
        def populate_single_assoc_getter_setter(klass, association_name, reflection)
          target_type = type_for(reflection)
          nilable_target_type = as_nilable_type(target_type)
          generated_methods = constant.generated_association_methods

          klass.create_method(association_name.to_s, return_type: nilable_target_type)
          klass.create_method(
            "#{association_name}=",
            parameters: [create_param("value", type: nilable_target_type)],
            return_type: "void",
          )
          klass.create_method("reload_#{association_name}", return_type: nilable_target_type)
          klass.create_method("reset_#{association_name}", return_type: "void")

          # Only declare predicates that Rails actually defined on this model.
          ["#{association_name}_changed?", "#{association_name}_previously_changed?"].each do |predicate|
            next unless generated_methods.method_defined?(predicate)

            klass.create_method(predicate, return_type: "T::Boolean")
          end

          # Constructors are not generated for polymorphic reflections.
          return if reflection.polymorphic?

          constructors = [
            "build_#{association_name}",
            "create_#{association_name}",
            "create_#{association_name}!",
          ]
          constructors.each do |constructor|
            # Fresh parameter nodes per method so no RBI node is shared between two parents.
            klass.create_method(
              constructor,
              parameters: [
                create_rest_param("args", type: "T.untyped"),
                create_block_param("blk", type: "T.untyped"),
              ],
              return_type: target_type,
            )
          end
        end

        # Emits the collection reader (annotated with a Rails guide link), the collection writer,
        # and the `<singular>_ids` reader/writer pair.
        #: (RBI::Scope klass, (String | Symbol) association_name, ReflectionType reflection) -> void
        def populate_collection_assoc_getter_setter(klass, association_name, reflection)
          element_type = type_for(reflection)
          proxy_type = relation_type_for(reflection)
          ids_method = "#{association_name.to_s.singularize}_ids"

          klass.create_method(
            association_name.to_s,
            comments: association_comments(reflection),
            return_type: proxy_type,
          )
          klass.create_method(
            "#{association_name}=",
            parameters: [create_param("value", type: "T::Enumerable[#{element_type}]")],
            return_type: "void",
          )
          klass.create_method(ids_method, return_type: "T::Array[T.untyped]")
          klass.create_method(
            "#{ids_method}=",
            parameters: [create_param("ids", type: "T::Array[T.untyped]")],
            return_type: "T::Array[T.untyped]",
          )
        end

        # Type name of the reflection's target: the qualified class name, or "T.untyped" when the
        # owning model's table does not exist or the association is polymorphic.
        # Raises SourceReflectionError / MissingConstantError (via `validate_reflection!`).
        #: (ReflectionType reflection) -> String
        def type_for(reflection)
          validate_reflection!(reflection)

          if constant.table_exists? && !polymorphic_association?(reflection)
            T.must(qualified_name_of(reflection.klass))
          else
            "T.untyped"
          end
        end

        # Ensures the reflection can be resolved, raising SourceReflectionError for a through
        # reflection without a source, and MissingConstantError when `.klass` raises NameError.
        #: (ReflectionType reflection) -> void
        def validate_reflection!(reflection)
          # The source reflection is checked first: `.klass` goes through it and would otherwise
          # fail with a cryptic error.
          raise SourceReflectionError if reflection.through_reflection? && !reflection.source_reflection

          # Resolving the target of a non-polymorphic reflection must not raise a NameError.
          reflection.klass unless reflection.polymorphic?
        rescue NameError
          # NOTE(review): NoMethodError is a subclass of NameError, so it is also reported here
          # as a missing constant — confirm this is intended.
          missing_name = if reflection.through_reflection?
            reflection.send(:delegate_reflection).class_name
          else
            reflection.class_name
          end

          raise MissingConstantError, missing_name
        end

        # Reconstructs the declaring macro call (e.g. "has_many :comments", with ", through: :x"
        # appended for through reflections); nil for reflection kinds not listed here.
        #: (ReflectionType reflection) -> String?
        def declaration(reflection)
          macro = case reflection
          when ActiveRecord::Reflection::HasOneReflection then "has_one"
          when ActiveRecord::Reflection::HasManyReflection then "has_many"
          when ActiveRecord::Reflection::HasAndBelongsToManyReflection then "has_and_belongs_to_many"
          when ActiveRecord::Reflection::BelongsToReflection then "belongs_to"
          when ActiveRecord::Reflection::ThroughReflection
            delegate = reflection.send(:delegate_reflection)
            return "#{declaration(delegate)}, through: :#{delegate.options[:through]}"
          end

          "#{macro} :#{reflection.name}" if macro
        end

        # Builds the comment attached to a collection reader: which model declared the association
        # and a link to the matching section of the Rails association guide. Empty when the
        # reflection kind is not one of those listed below.
        #: (ReflectionType reflection) -> Array[RBI::Comment]
        def association_comments(reflection)
          anchor_name = case reflection
          when ActiveRecord::Reflection::HasOneReflection then "the-has-one-association"
          when ActiveRecord::Reflection::HasManyReflection then "the-has-many-association"
          when ActiveRecord::Reflection::HasAndBelongsToManyReflection then "the-has-and-belongs-to-many-association"
          when ActiveRecord::Reflection::BelongsToReflection then "the-belongs-to-association"
          when ActiveRecord::Reflection::ThroughReflection
            delegate_declaration = T.must(declaration(reflection.send(:delegate_reflection)))
            if delegate_declaration.match?("has_one")
              "the-has-one-through-association"
            else
              "the-has-many-through-association"
            end
          end
          return [] unless anchor_name

          url = "https://guides.rubyonrails.org/association_basics.html##{anchor_name}"
          guide_label = anchor_name.sub(/^the-(.*)-association$/, '\1').gsub("-", "_")
          comment = <<~MSG
            This method is created by ActiveRecord on the `#{reflection.active_record.name}` class because it declared `#{declaration(reflection)}`.
            🔗 [Rails guide for `#{guide_label}` association](#{url})
          MSG

          [RBI::Comment.new(comment)]
        end

        # Return type of a collection reader. When the ActiveRecordRelations compiler is enabled the
        # target's generated collection proxy class is used; otherwise a parameterized
        # CollectionProxy. Untyped targets (missing owner table or polymorphic) fall back to the
        # generic CollectionProxy forms.
        #: (ReflectionType reflection) -> String
        def relation_type_for(reflection)
          validate_reflection!(reflection)

          relations_enabled = compiler_enabled?("ActiveRecordRelations")
          untyped_target = !constant.table_exists? || polymorphic_association?(reflection)

          if untyped_target
            if relations_enabled
              "ActiveRecord::Associations::CollectionProxy"
            else
              "ActiveRecord::Associations::CollectionProxy[T.untyped]"
            end
          elsif relations_enabled
            "#{qualified_name_of(reflection.klass)}::#{AssociationsCollectionProxyClassName}"
          else
            "::ActiveRecord::Associations::CollectionProxy[#{qualified_name_of(reflection.klass)}]"
          end
        end

        # Whether the reflection (following through reflections down to their source) is polymorphic.
        #: (ReflectionType reflection) -> bool
        def polymorphic_association?(reflection)
          return polymorphic_association?(reflection.source_reflection) if reflection.through_reflection?

          !!reflection.polymorphic?
        end
      end
    end
  end
end