# frozen_string_literal: true

# lib/aws-sdk-s3/multipart_file_uploader.rb

require 'pathname'
require 'set'

module Aws
  module S3
    # @api private
    class MultipartFileUploader

      MIN_PART_SIZE = 5 * 1024 * 1024 # 5MB

      FILE_TOO_SMALL = "unable to multipart upload files smaller than 5MB"

      MAX_PARTS = 10_000

      THREAD_COUNT = 10

      CREATE_OPTIONS = Set.new(
        Client.api.operation(:create_multipart_upload).input.shape.member_names
      )

      COMPLETE_OPTIONS = Set.new(
        Client.api.operation(:complete_multipart_upload).input.shape.member_names
      )

      UPLOAD_PART_OPTIONS = Set.new(
        Client.api.operation(:upload_part).input.shape.member_names
      )

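      # Input members of UploadPart that are bound to x-amz-checksum-*
      # headers; these carry checksum values rather than general options.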
      CHECKSUM_KEYS = Set.new(
        Client.api.operation(:upload_part).input.shape.members.map do |n, s|
          n if s.location == 'header' && s.location_name.start_with?('x-amz-checksum-')
        end.compact
      )

      # @option options [Client] :client
      # @option options [Integer] :thread_count (THREAD_COUNT)
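      #
      # @example Constructing an uploader (illustrative; values are placeholders):
      #   Aws::S3::MultipartFileUploader.new(client: Aws::S3::Client.new, thread_count: 4)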
      def initialize(options = {})
        @client = options[:client] || Client.new
        @thread_count = options[:thread_count] || THREAD_COUNT
      end

      # @return [Client]
      attr_reader :client

      # @param [String, Pathname, File, Tempfile] source The file to upload.
      # @option options [required, String] :bucket The bucket to upload to.
      # @option options [required, String] :key The key for the object.
      # @option options [Proc] :progress_callback
      #   A Proc that will be called as each chunk of the upload is sent.
      #   It will be invoked with two arrays: the bytes sent so far for
      #   each part and the total size of each part.
      # @return [Seahorse::Client::Response] the CompleteMultipartUpload
      #   response.
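      #
      # @example Illustrative usage (bucket, key, and path are placeholders):
      #   uploader = Aws::S3::MultipartFileUploader.new(client: Aws::S3::Client.new)
      #   uploader.upload('/path/to/large-file.bin',
      #     bucket: 'example-bucket',
      #     key: 'large-file.bin',
      #     progress_callback: proc { |bytes, totals|
      #       puts format('%.1f%%', 100.0 * bytes.sum / totals.sum)
      #     }
      #   )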
      def upload(source, options = {})
        if File.size(source) < MIN_PART_SIZE
          raise ArgumentError, FILE_TOO_SMALL
        else
          upload_id = initiate_upload(options)
          parts = upload_parts(upload_id, source, options)
          complete_upload(upload_id, parts, options)
        end
      end

      private

      def initiate_upload(options)
        @client.create_multipart_upload(create_opts(options)).upload_id
      end

      def complete_upload(upload_id, parts, options)
        @client.complete_multipart_upload(
          **complete_opts(options).merge(
            upload_id: upload_id,
            multipart_upload: { parts: parts }
          )
        )
      end

      def upload_parts(upload_id, source, options)
        pending = PartList.new(compute_parts(upload_id, source, options))
        completed = PartList.new
        errors = upload_in_threads(pending, completed, options)
        if errors.empty?
          completed.to_a.sort_by { |part| part[:part_number] }
        else
          abort_upload(upload_id, options, errors)
        end
      end

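      # Abort the upload so orphaned parts do not accumulate (and accrue
      # storage charges); if the abort itself fails, raise an error that
      # carries both the abort failure and the original part errors.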
      def abort_upload(upload_id, options, errors)
        @client.abort_multipart_upload(
          bucket: options[:bucket],
          key: options[:key],
          upload_id: upload_id
        )
        msg = "multipart upload failed: #{errors.map(&:message).join('; ')}"
        raise MultipartUploadError.new(msg, errors)
      rescue MultipartUploadError => error
        raise error
      rescue => error
        msg = "failed to abort multipart upload: #{error.message}. "\
          "Multipart upload failed: #{errors.map(&:message).join('; ')}"
        raise MultipartUploadError.new(msg, errors + [error])
      end

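      # Build one upload_part request per chunk of the source file; every
      # part is default_part_size bytes except the last, which gets the
      # remainder.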
      def compute_parts(upload_id, source, options)
        size = File.size(source)
        default_part_size = compute_default_part_size(size)
        offset = 0
        part_number = 1
        parts = []
        while offset < size
          parts << upload_part_opts(options).merge(
            upload_id: upload_id,
            part_number: part_number,
            body: FilePart.new(
              source: source,
              offset: offset,
              size: part_size(size, default_part_size, offset)
            )
          )
          part_number += 1
          offset += default_part_size
        end
        parts
      end

      def checksum_key?(key)
        CHECKSUM_KEYS.include?(key)
      end

      def has_checksum_key?(keys)
        keys.any? { |key| checksum_key?(key) }
      end

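      # When the caller supplies a checksum option, request a FULL_OBJECT
      # checksum type so S3 validates a checksum of the complete object
      # rather than a composite of per-part checksums.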
      def create_opts(options)
        opts = { checksum_algorithm: Aws::Plugins::ChecksumAlgorithm::DEFAULT_CHECKSUM }
        opts[:checksum_type] = 'FULL_OBJECT' if has_checksum_key?(options.keys)
        CREATE_OPTIONS.inject(opts) do |hash, key|
          hash[key] = options[key] if options.key?(key)
          hash
        end
      end

      def complete_opts(options)
        opts = {}
        opts[:checksum_type] = 'FULL_OBJECT' if has_checksum_key?(options.keys)
        COMPLETE_OPTIONS.inject(opts) do |hash, key|
          hash[key] = options[key] if options.key?(key)
          hash
        end
      end

      def upload_part_opts(options)
        UPLOAD_PART_OPTIONS.inject({}) do |hash, key|
          if options.key?(key)
            # don't pass through checksum calculations
            hash[key] = options[key] unless checksum_key?(key)
          end
          hash
        end
      end

      def upload_in_threads(pending, completed, options)
        threads = []
        if (callback = options[:progress_callback])
          progress = MultipartProgress.new(pending, callback)
        end
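        # Start a fixed pool of worker threads; each pulls parts off the
        # shared queue until it is empty.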
        options.fetch(:thread_count, @thread_count).times do
          thread = Thread.new do
            begin
              while (part = pending.shift)
                if progress
                  part[:on_chunk_sent] =
                    proc do |_chunk, bytes, _total|
                      progress.call(part[:part_number], bytes)
                    end
                end
                resp = @client.upload_part(part)
                part[:body].close
                completed_part = {
                  etag: resp.etag,
                  part_number: part[:part_number]
                }
                algorithm = resp.context.params[:checksum_algorithm]
                k = "checksum_#{algorithm.downcase}".to_sym
                completed_part[k] = resp.send(k)
                completed.push(completed_part)
              end
              nil
            rescue => error
              # keep other threads from uploading other parts
              pending.clear!
              error
            end
          end
          threads << thread
        end
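        # Thread#value joins each thread; workers return nil on success or
        # the rescued error, so compact leaves only the failures.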
        threads.map(&:value).compact
      end

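      # Pick the smallest part size that keeps the upload within MAX_PARTS,
      # never going below MIN_PART_SIZE. For example, a 100 GiB source
      # yields ceil(107_374_182_400 / 10_000) = 10_737_419 bytes (~10.2 MiB)
      # per part, while anything up to 5 MiB * 10_000 (~48.8 GiB) uses the
      # 5 MiB minimum.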
      def compute_default_part_size(source_size)
        [(source_size.to_f / MAX_PARTS).ceil, MIN_PART_SIZE].max.to_i
      end

      def part_size(total_size, part_size, offset)
        if offset + part_size > total_size
          total_size - offset
        else
          part_size
        end
      end

      # @api private
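      # A Mutex-guarded list shared across worker threads; it serves as
      # both the queue of pending parts and the accumulator of completed
      # parts.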
      class PartList

        def initialize(parts = [])
          @parts = parts
          @mutex = Mutex.new
        end

        def push(part)
          @mutex.synchronize { @parts.push(part) }
        end

        def shift
          @mutex.synchronize { @parts.shift }
        end

        def clear!
          @mutex.synchronize { @parts.clear }
        end

        def size
          @mutex.synchronize { @parts.size }
        end

        def part_sizes
          @mutex.synchronize { @parts.map { |p| p[:body].size } }
        end

        def to_a
          @mutex.synchronize { @parts.dup }
        end

      end

      # @api private
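      # Tracks bytes sent per part and reports aggregate progress to the
      # user-supplied callback.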
      class MultipartProgress
        def initialize(parts, progress_callback)
          @bytes_sent = Array.new(parts.size, 0)
          @total_sizes = parts.part_sizes
          @progress_callback = progress_callback
        end

        def call(part_number, bytes_read)
          # part numbers start at 1
          @bytes_sent[part_number - 1] = bytes_read
          @progress_callback.call(@bytes_sent, @total_sizes)
        end
      end
    end
  end
end