lib/aws-sdk-s3/file_downloader.rb



# frozen_string_literal: true

require 'pathname'
require 'thread'
require 'set'
require 'tmpdir'

module Aws
  module S3
    # @api private
    class FileDownloader

      # 5 MiB. Objects at or below this size are fetched with a single
      # request, and it is the smallest chunk size computed automatically
      # for ranged downloads.
      MIN_CHUNK_SIZE = 5 * 1024 * 1024
      # Divisor used when deriving an automatic chunk size from the object
      # size, so that an automatic ranged download does not exceed this many
      # chunks.
      MAX_PARTS = 10_000
      # Number of download threads used when :thread_count is not given.
      THREAD_COUNT = 10

      # Builds a downloader bound to an S3 client.
      #
      # @param options [Hash]
      # @option options [Client] :client the client used for every request;
      #   when absent (or falsy) a new `Client` is constructed with defaults.
      def initialize(options = {})
        supplied_client = options[:client]
        @client = supplied_client || Client.new
      end

      # @return [Client]
      attr_reader :client

      # Downloads an object to `destination` using the strategy selected by
      # `:mode`.
      #
      # The mode may be given as a String or a Symbol (`'auto'` / `:auto`);
      # it is normalized to a String before dispatching, which matches the
      # symbol spellings advertised by the error message below.
      #
      # @param destination [String] path of the file to write.
      # @param options [Hash]
      # @option options [String] :bucket bucket name (required by S3).
      # @option options [String] :key object key (required by S3).
      # @option options [String] :version_id optional object version.
      # @option options [String, Symbol] :mode ('auto') one of `auto`,
      #   `single_request` or `get_range`.
      # @option options [Integer] :chunk_size size in bytes of each ranged
      #   request; mandatory in `get_range` mode.
      # @option options [Integer] :thread_count (THREAD_COUNT) number of
      #   download threads for multi-request downloads.
      # @option options [#call] :on_checksum_validated invoked with
      #   `(checksum_validated, response)` for each response that reports a
      #   validated checksum.
      # @option options [#call] :progress_callback invoked with
      #   `(bytes_received_per_part, part_sizes, total_size)`.
      # @raise [ArgumentError] when the mode is unknown, when `get_range` is
      #   requested without `:chunk_size`, or when `:on_checksum_validated`
      #   is not callable.
      # @return the value produced by the selected strategy (as passed
      #   through by `Aws::Plugins::UserAgent.metric`).
      def download(destination, options = {})
        @path = destination
        @mode = (options[:mode] || 'auto').to_s
        @thread_count = options[:thread_count] || THREAD_COUNT
        @chunk_size = options[:chunk_size]
        @params = {
          bucket: options[:bucket],
          key: options[:key]
        }
        @params[:version_id] = options[:version_id] if options[:version_id]
        @on_checksum_validated = options[:on_checksum_validated]
        @progress_callback = options[:progress_callback]

        validate!

        Aws::Plugins::UserAgent.metric('S3_TRANSFER') do
          case @mode
          when 'auto' then multipart_download
          when 'single_request' then single_request
          when 'get_range'
            unless @chunk_size
              msg = 'In :get_range mode, :chunk_size must be provided'
              raise ArgumentError, msg
            end

            # The object size and ETag are needed to slice the object into
            # ranges and to pin every ranged GET to the same object revision.
            resp = @client.head_object(@params)
            multithreaded_get_by_ranges(resp.content_length, resp.etag)
          else
            msg = "Invalid mode #{@mode} provided, "\
                  'mode should be :single_request, :get_range or :auto'
            raise ArgumentError, msg
          end
        end
      end

      private

      # Verifies the options captured by #download.
      #
      # @raise [ArgumentError] when an :on_checksum_validated value was
      #   supplied but does not respond to #call.
      # @return [nil]
      def validate!
        return unless @on_checksum_validated
        return if @on_checksum_validated.respond_to?(:call)

        raise ArgumentError, 'on_checksum_validated must be callable'
      end

      # Strategy for 'auto' mode: inspects the object with HEAD requests and
      # picks a single request, ranged GETs, or part-number GETs.
      #
      # A HEAD for part 1 reveals how many parts the object was uploaded
      # with. When it has more than one part, a second HEAD without a part
      # number is issued and its content length and ETag are used instead.
      def multipart_download
        first_part = @client.head_object(@params.merge(part_number: 1))
        parts_count = first_part.parts_count
        uploaded_in_parts = !parts_count.nil? && parts_count > 1

        # Re-query without partNumber only when the object has several parts.
        head = uploaded_in_parts ? @client.head_object(@params) : first_part

        return single_request if head.content_length <= MIN_CHUNK_SIZE

        if uploaded_in_parts
          compute_mode(head.content_length, parts_count, head.etag)
        else
          multithreaded_get_by_ranges(head.content_length, head.etag)
        end
      end

      # Chooses between ranged and part-number downloads for an object that
      # was uploaded in several parts.
      #
      # Ranged GETs are used when the chunk size is smaller than the average
      # part size; otherwise each uploaded part is fetched by part number.
      #
      # @param file_size [Integer] total object size in bytes.
      # @param count [Integer] number of uploaded parts.
      # @param etag [String] ETag forwarded to the chosen strategy.
      def compute_mode(file_size, count, etag)
        range_size = compute_chunk(file_size)
        average_part_size = (file_size.to_f / count.to_f).ceil

        return multithreaded_get_by_ranges(file_size, etag) if range_size < average_part_size

        multithreaded_get_by_parts(count, file_size, etag)
      end

      # Splits an object of `file_size` bytes into consecutive HTTP Range
      # header values of the form "bytes=<first>-<last>" (both inclusive).
      #
      # @param file_size [Integer] total object size in bytes.
      # @return [Array<String>] range strings in ascending byte order; empty
      #   when `file_size` is zero.
      def construct_chunks(file_size)
        step = compute_chunk(file_size)
        ranges = []
        first_byte = 0
        until first_byte >= file_size
          # The last range is clamped to the end of the object.
          next_byte = [first_byte + step, file_size].min
          ranges << "bytes=#{first_byte}-#{next_byte - 1}"
          first_byte = next_byte
        end
        ranges
      end

      # Resolves the chunk size (in bytes) used to slice an object.
      #
      # Without a user supplied :chunk_size, the size is the larger of
      # MIN_CHUNK_SIZE and `file_size / MAX_PARTS` (rounded up). A user
      # supplied size is validated and returned unchanged.
      #
      # @param file_size [Integer] total object size in bytes.
      # @raise [ArgumentError] when :chunk_size is not greater than zero
      #   (such a size would make the range-building loops never advance),
      #   or when it exceeds `file_size`.
      # @return [Integer]
      def compute_chunk(file_size)
        unless @chunk_size
          return [(file_size.to_f / MAX_PARTS).ceil, MIN_CHUNK_SIZE].max.to_i
        end

        unless @chunk_size > 0
          raise ArgumentError, ':chunk_size must be greater than zero.'
        end
        if @chunk_size > file_size
          raise ArgumentError, ":chunk_size shouldn't exceed total file size."
        end

        @chunk_size
      end

      # Groups work items into slices of at most @thread_count elements.
      #
      # @param chunks [Enumerable, Integer] the items to group; when `mode`
      #   is the String 'part_number' this is a count and the items are the
      #   part numbers 1..chunks.
      # @param mode [String]
      # @return [Array<Array>]
      def batches(chunks, mode)
        items = mode.eql?('part_number') ? (1..chunks) : chunks
        items.each_slice(@thread_count).to_a
      end

      # Downloads the object as a series of byte-range GET requests.
      #
      # The object is cut into consecutive ranges of the resolved chunk size
      # (the final one may be shorter). Every request carries `if_match` with
      # the given ETag.
      #
      # @param file_size [Integer] total object size in bytes.
      # @param etag [String] ETag sent as `if_match` on each request.
      def multithreaded_get_by_ranges(file_size, etag)
        step = compute_chunk(file_size)
        parts = []
        first_byte = 0
        until first_byte >= file_size
          next_byte = [first_byte + step, file_size].min
          parts << Part.new(
            # part numbers are 1-based
            part_number: parts.size + 1,
            size: next_byte - first_byte,
            params: @params.merge(
              range: "bytes=#{first_byte}-#{next_byte - 1}",
              if_match: etag
            )
          )
          first_byte = next_byte
        end
        download_in_threads(PartList.new(parts), file_size)
      end

      # Downloads the object one uploaded part at a time, addressing each
      # part by its part number.
      #
      # Part sizes are not known up front, so each Part is created without a
      # size. Every request carries `if_match` with the given ETag.
      #
      # @param n_parts [Integer] number of parts to fetch (1..n_parts).
      # @param total_size [Integer] total object size in bytes.
      # @param etag [String] ETag sent as `if_match` on each request.
      def multithreaded_get_by_parts(n_parts, total_size, etag)
        parts = 1.upto(n_parts).map do |number|
          request = @params.merge(part_number: number, if_match: etag)
          Part.new(part_number: number, params: request)
        end
        download_in_threads(PartList.new(parts), total_size)
      end

      # Drains `pending` with @thread_count worker threads; each worker
      # repeatedly takes a part, GETs it and writes it to the destination.
      #
      # When a progress callback is configured, every request is given an
      # :on_chunk_received hook that reports per-part progress. When a worker
      # hits an error it empties the queue so the remaining workers stop
      # picking up new parts, then re-raises; the error surfaces to the
      # caller when that worker's value is collected.
      #
      # @param pending [PartList] queue of parts still to download.
      # @param total_size [Integer] total object size, used for progress.
      # @return [Array] non-nil worker results (workers return nil on success).
      def download_in_threads(pending, total_size)
        progress = nil
        if @progress_callback
          progress = MultipartProgress.new(pending, total_size, @progress_callback)
        end

        workers = @thread_count.times.map do
          Thread.new do
            begin
              while (part = pending.shift)
                if progress
                  part.params[:on_chunk_received] = proc do |_chunk, bytes, total|
                    progress.call(part.part_number, bytes, total)
                  end
                end
                response = @client.get_object(part.params)
                write(response)
                if @on_checksum_validated && response.checksum_validated
                  @on_checksum_validated.call(response.checksum_validated, response)
                end
              end
              nil
            rescue StandardError
              # stop the other workers from starting any further parts
              pending.clear!
              raise
            end
          end
        end

        workers.map(&:value).compact
      end

      # Writes one response body into the destination file at the byte
      # offset announced by the response's Content-Range header.
      #
      # The file is opened in binary mode so the object bytes are written
      # verbatim (text mode may translate newlines on some platforms).
      # Because an offset is supplied, existing file content is not
      # truncated, which lets parts arrive in any order.
      #
      # @param resp [#content_range, #body] a GET response whose
      #   content_range looks like "bytes 0-1023/4096".
      # @return [Integer] number of bytes written.
      def write(resp)
        # "bytes 0-1023/4096" -> "0-1023"
        byte_range = resp.content_range.split(' ').last.split('/').first
        offset = byte_range.split('-').map(&:to_i).first
        File.binwrite(@path, resp.body.read, offset)
      end

      # Fetches the whole object with one GET, letting the client stream the
      # body to the destination path via :response_target.
      #
      # Progress reporting is attached when a progress callback is set, and
      # the checksum callback runs when the response reports a validated
      # checksum.
      #
      # @return the response of the GET request.
      def single_request
        request = @params.merge(response_target: @path)
        request[:on_chunk_received] = single_part_progress if @progress_callback
        response = @client.get_object(request)

        if @on_checksum_validated && response.checksum_validated
          @on_checksum_validated.call(response.checksum_validated, response)
        end

        response
      end

      # Builds the :on_chunk_received hook for single-request downloads.
      #
      # The hook adapts `(chunk, bytes_read, total_size)` into the progress
      # callback's `(bytes_per_part, part_sizes, total_size)` shape, treating
      # the whole object as one part.
      #
      # @return [Proc]
      def single_part_progress
        proc do |_chunk, received, expected|
          @progress_callback.call([received], [expected], expected)
        end
      end

      # One unit of download work.
      #
      # * part_number - 1-based index of the part.
      # * size        - byte length of the part; left unset when parts are
      #                 fetched by part number.
      # * params      - request parameters passed to Client#get_object.
      #
      # NOTE(review): instances are built from a keyword hash elsewhere in
      # this file, which presumably relies on Aws::Structure - confirm.
      class Part < Struct.new(:part_number, :size, :params)
        include Aws::Structure
      end

      # A queue of parts whose operations are each guarded by one mutex.
      #
      # @api private
      class PartList
        include Enumerable

        # @param parts [Array] initial parts; the array is used as the
        #   backing store and is mutated by #shift and #clear!.
        def initialize(parts = [])
          @lock = Mutex.new
          @items = parts
        end

        # Removes and returns the first part, or nil when none remain.
        def shift
          @lock.synchronize do
            @items.shift
          end
        end

        # @return [Integer] number of parts still queued.
        def size
          @lock.synchronize do
            @items.size
          end
        end

        # Drops every queued part.
        def clear!
          @lock.synchronize do
            @items.clear
          end
        end

        # Yields each queued part while holding the mutex.
        def each(&visitor)
          @lock.synchronize do
            @items.each(&visitor)
          end
        end
      end

      # Aggregates per-part progress and relays it to a user callback.
      #
      # @api private
      class MultipartProgress
        # @param parts [#size, #map] the parts to track; each element must
        #   respond to #size (which may be nil when not yet known).
        # @param total_size [Integer] total object size in bytes.
        # @param progress_callback [#call] receives
        #   `(bytes_received_per_part, part_sizes, total_size)`.
        def initialize(parts, total_size, progress_callback)
          @progress_callback = progress_callback
          @total_size = total_size
          @bytes_received = [0] * parts.size
          @part_sizes = parts.map(&:size)
        end

        # Records progress for one part and notifies the callback.
        #
        # @param part_number [Integer] 1-based part number.
        # @param bytes_received [Integer] bytes received so far for the part.
        # @param total [Integer] size of the part, used only when the size
        #   was not known beforehand.
        def call(part_number, bytes_received, total)
          slot = part_number - 1 # part numbers are 1-based
          @bytes_received[slot] = bytes_received
          # the part size may only become known with the first response
          @part_sizes[slot] ||= total
          @progress_callback.call(@bytes_received, @part_sizes, @total_size)
        end
      end
    end
  end
end