module DSPy::Metrics
# Combines several metrics into one that passes only when every metric passes.
#
# Every metric is invoked (none are skipped after a failure) so that the
# per-metric outcomes are all reported.
#
# @param metrics [Array<#call>] callables taking (example, prediction); each may
#   return a Hash (read via :passed / 'passed') or any truthy/falsy value.
# @return [Proc] proc taking (example, prediction) and returning a Hash with one
#   :metric_<i> entry per metric plus an aggregate :passed flag (true when
#   `metrics` is empty).
def self.composite_and(*metrics)
  proc do |example, prediction|
    summary = {}
    verdict = true

    metrics.each.with_index do |metric, position|
      slot = :"metric_#{position}"
      outcome = metric.call(example, prediction)

      unless outcome.is_a?(Hash)
        # Non-hash results are coerced to a strict boolean and wrapped.
        flag = !!outcome
        summary[slot] = { passed: flag }
        verdict = verdict && flag
        next
      end

      # Hash results are stored as-is; the pass flag may use either key type.
      summary[slot] = outcome
      verdict = verdict && (outcome[:passed] || outcome['passed'] || false)
    end

    summary[:passed] = verdict
    summary
  end
end
# Builds a metric that checks whether the prediction's `field` contains the
# example's `field` as a substring (both converted with `to_s`).
#
# @param field [Symbol, String] field to read from example and prediction via
#   `extract_field`.
# @param case_sensitive [Boolean] when false, both strings are downcased before
#   the substring check.
# @return [Proc] proc taking (example, prediction) and returning true/false;
#   returns false when either extracted value is nil.
def self.contains(field: :answer, case_sensitive: false)
  proc do |example, prediction|
    expected = extract_field(example, field)
    actual = extract_field(prediction, field)

    # `next`, not `return`: a bare `return` inside a proc tries to return from
    # `self.contains`, which has already returned when the metric is invoked,
    # and raises LocalJumpError.
    next false if expected.nil? || actual.nil?

    if case_sensitive
      actual.to_s.include?(expected.to_s)
    else
      actual.to_s.downcase.include?(expected.to_s.downcase)
    end
  end
end
# Builds a metric that checks whether the prediction's `field` equals the
# example's `field` after `to_s` conversion.
#
# @param field [Symbol, String] field to read from example and prediction via
#   `extract_field`.
# @param case_sensitive [Boolean] when false, both strings are downcased before
#   comparison.
# @return [Proc] proc taking (example, prediction) and returning true/false;
#   returns false when either extracted value is nil.
def self.exact_match(field: :answer, case_sensitive: true)
  proc do |example, prediction|
    expected = extract_field(example, field)
    actual = extract_field(prediction, field)

    # `next`, not `return`: a bare `return` inside a proc tries to return from
    # `self.exact_match`, which has already returned when the metric is
    # invoked, and raises LocalJumpError.
    next false if expected.nil? || actual.nil?

    if case_sensitive
      expected.to_s == actual.to_s
    else
      expected.to_s.downcase == actual.to_s.downcase
    end
  end
end
# Reads `field` from a Hash, from an object exposing a public reader named
# `field`, or from an object convertible with `to_h`.
#
# For Hash-like lookups the key is tried as given first, then as a String; the
# String fallback is used only when the first lookup is nil, so legitimately
# falsy values such as `false` are preserved.
#
# @param obj [Object] the container to read from (may be nil).
# @param field [Symbol, String] the key / reader name.
# @return [Object, nil] the value found, or nil when none applies.
def self.extract_field(obj, field)
  case obj
  when Hash
    value = obj[field]
    value.nil? ? obj[field.to_s] : value
  when ->(o) { o.respond_to?(field) }
    # respond_to? only reports public methods, so restrict the call to match.
    obj.public_send(field)
  when ->(o) { o.respond_to?(:to_h) }
    hash = obj.to_h
    value = hash[field]
    value.nil? ? hash[field.to_s] : value
  end
end
# Builds a metric that compares `field` of example and prediction numerically,
# passing when the absolute difference is within `tolerance`.
#
# @param field [Symbol, String] field to read via `extract_field`.
# @param tolerance [Numeric] maximum allowed absolute difference (inclusive).
# @return [Proc] proc taking (example, prediction) and returning a Hash:
#   - { passed:, difference:, expected:, actual:, tolerance: } on success;
#   - { passed: false, error: "Missing values" } when either value is nil;
#   - { passed: false, error: "Non-numeric values" } when conversion with
#     Kernel#Float (or the comparison) fails.
def self.numeric_difference(field: :answer, tolerance: 0.01)
  proc do |example, prediction|
    expected = extract_field(example, field)
    actual = extract_field(prediction, field)

    # Branch instead of a bare `return`: returning from inside a proc after
    # `self.numeric_difference` has returned raises LocalJumpError.
    if expected.nil? || actual.nil?
      { passed: false, error: "Missing values" }
    else
      begin
        expected_num = Float(expected)
        actual_num = Float(actual)
        difference = (expected_num - actual_num).abs
        {
          passed: difference <= tolerance,
          difference: difference,
          expected: expected_num,
          actual: actual_num,
          tolerance: tolerance
        }
      # Float() raises ArgumentError for unparsable strings but TypeError for
      # non-convertible objects (true, symbols, arrays, ...).
      rescue ArgumentError, TypeError
        { passed: false, error: "Non-numeric values" }
      end
    end
  end
end