lib/cw_card_utils/synergy_probability.rb



# frozen_string_literal: true

module CwCardUtils
  # Calculates the probability of drawing specific card combinations for
  # synergy analysis.
  #
  # Draws are modelled without replacement from a deck of +deck_size+ cards:
  # the chance that a k-card hand avoids c specific cards is
  # C(deck_size - c, k) / C(deck_size, k). Copy counts are read from
  # +deck.main+, whose entries must respond to +name+ and +count+.
  class SynergyProbability
    # @param deck [#main] deck whose +main+ enumerates cards with +name+ and +count+
    # @param deck_size [Integer] number of cards the hand is drawn from
    def initialize(deck, deck_size: 60)
      @deck = deck
      @deck_size = deck_size
    end

    # Probability of drawing at least one card whose name is in +target_names+.
    #
    # @param target_names [Object, Array] a single card name or a list of names
    # @param draws [#to_i] cards drawn; negative becomes 0, capped at deck_size
    # @return [Float] probability in 0.0..1.0
    def prob_single(target_names, draws)
      draws_clamped = clamp_draws(draws)
      total = binomial(@deck_size, draws_clamped)
      return 0.0 if total.zero?

      # P(at least one) = 1 - P(hand contains none of the targets). Subtract in
      # exact integers so the only rounding is the final division.
      missing = binomial(@deck_size - count_copies(Array(target_names)), draws_clamped)
      (total - missing).fdiv(total).clamp(0.0, 1.0)
    end

    # Probability of drawing at least one copy of EVERY card in +target_names+
    # (e.g. both halves of a two-card combo).
    #
    # Exact for any number of distinct targets. Duplicate names are ignored,
    # and an empty target list is trivially satisfied (1.0).
    #
    # @param target_names [Object, Array] a single card name or a list of names
    # @param draws [#to_i] cards drawn; negative becomes 0, capped at deck_size
    # @return [Float] probability in 0.0..1.0
    def prob_combo(target_names, draws)
      prob_all_of(Array(target_names).uniq, draws)
    end

    private

    # Exact probability that the hand holds every name in +names+ (already
    # unique), by inclusion–exclusion over the "target i is missing" events:
    #
    #   P(all) = sum over subsets S of (-1)^|S| * C(N - copies(S), k) / C(N, k)
    #
    # A term depends on S only through its total copy count, so the signed
    # subset counts are accumulated per copy total instead of enumerating all
    # 2^n subsets. The numerator stays an exact Integer until the one division.
    def prob_all_of(names, draws)
      draws_clamped = clamp_draws(draws)
      total = binomial(@deck_size, draws_clamped)
      return 0.0 if total.zero?

      # signed[c] = sum of (-1)^|S| over subsets S whose combined copies == c
      signed = names.reduce({ 0 => 1 }) do |acc, name|
        copies = copies_by_name[name]
        acc.each_with_object(Hash.new(0)) do |(copy_total, sign), nxt|
          nxt[copy_total] += sign          # subsets that leave +name+ out
          nxt[copy_total + copies] -= sign # subsets that include +name+
        end
      end

      numerator = signed.sum do |copy_total, sign|
        sign * binomial(@deck_size - copy_total, draws_clamped)
      end
      numerator.fdiv(total).clamp(0.0, 1.0)
    end

    # Total copies in the main deck of the given card names; duplicate names
    # are counted once.
    def count_copies(names)
      names.uniq.sum { |name| copies_by_name[name] }
    end

    # Binomial coefficient C(n, k): number of k-card hands from n cards.
    # Returns 0 when k is negative or exceeds n (this includes a negative n,
    # i.e. more copies excluded than deck_size holds).
    # Multiplicative form: after step i the accumulator equals
    # C(n - k + i, i), so every integer division is exact.
    def binomial(n, k)
      return 0 if k.negative? || k > n

      k = [k, n - k].min
      (1..k).reduce(1) { |acc, i| acc * (n - k + i) / i }
    end

    # Number of cards actually drawn: negative requests become 0 and requests
    # larger than the deck are capped at deck_size.
    def clamp_draws(draws)
      requested = draws.to_i
      requested.negative? ? 0 : [requested, @deck_size].min
    end

    # Copies per card name in the main deck, computed once per instance (the
    # deck passed in is treated as fixed). Unknown names read as 0.
    def copies_by_name
      @copies_by_name ||= @deck.main.each_with_object(Hash.new(0)) { |card, h| h[card.name] += card.count }
    end
  end
end