# frozen_string_literal: true

require 'pathname'
require 'thread'
require 'set'
require 'tmpdir'

module Aws
  module S3
    # @api private
    class FileDownloader

      # Minimum range size (5 MiB); objects at or below this size are
      # downloaded with a single GET instead of ranged requests.
      MIN_CHUNK_SIZE = 5 * 1024 * 1024

      # Caps the number of ranged requests, mirroring S3's 10,000-part limit.
      MAX_PARTS = 10_000

      # Default number of concurrent download threads.
      THREAD_COUNT = 10

      def initialize(options = {})
        @client = options[:client] || Client.new
      end

      # @return [Client]
      attr_reader :client

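      # Normally reached through higher-level helpers such as
      # Aws::S3::Object#download_file. A minimal direct-usage sketch; the
      # bucket, key, and destination path below are placeholder values:
      #
      #   downloader = Aws::S3::FileDownloader.new(client: Aws::S3::Client.new)
      #   downloader.download(
      #     '/tmp/example-object',
      #     bucket: 'example-bucket',
      #     key: 'example-key',
      #     mode: 'auto',      # or 'single_request' / 'get_range' (requires :chunk_size)
      #     thread_count: 10,
      #     progress_callback: proc { |bytes, _part_sizes, total| puts "#{bytes.sum} / #{total}" }
      #   )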
      def download(destination, options = {})
        @path = destination
        @mode = options[:mode] || 'auto'
        @thread_count = options[:thread_count] || THREAD_COUNT
        @chunk_size = options[:chunk_size]
        @params = {
          bucket: options[:bucket],
          key: options[:key]
        }
        @params[:version_id] = options[:version_id] if options[:version_id]
        @on_checksum_validated = options[:on_checksum_validated]
        @progress_callback = options[:progress_callback]

        validate!

        Aws::Plugins::UserAgent.metric('S3_TRANSFER') do
          case @mode
          when 'auto' then multipart_download
          when 'single_request' then single_request
          when 'get_range'
            if @chunk_size
              resp = @client.head_object(@params)
              multithreaded_get_by_ranges(resp.content_length, resp.etag)
            else
              msg = 'In :get_range mode, :chunk_size must be provided'
              raise ArgumentError, msg
            end
          else
            msg = "Invalid mode #{@mode} provided, "\
                  'mode should be :single_request, :get_range or :auto'
            raise ArgumentError, msg
          end
        end
      end

      private

      def validate!
        if @on_checksum_validated && !@on_checksum_validated.respond_to?(:call)
          raise ArgumentError, 'on_checksum_validated must be callable'
        end
      end

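      # Decides how to fetch the object. A HEAD request with part_number: 1
      # reveals whether the object was created by a multipart upload
      # (parts_count). Objects at or below MIN_CHUNK_SIZE fall back to a
      # single GET; larger objects are split across threads either by part
      # number or by byte range.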
      def multipart_download
        resp = @client.head_object(@params.merge(part_number: 1))
        count = resp.parts_count
        if count.nil? || count <= 1
          if resp.content_length <= MIN_CHUNK_SIZE
            single_request
          else
            multithreaded_get_by_ranges(resp.content_length, resp.etag)
          end
        else
          # HEAD again without part_number to get the full object size
          resp = @client.head_object(@params)
          if resp.content_length <= MIN_CHUNK_SIZE
            single_request
          else
            compute_mode(resp.content_length, count, resp.etag)
          end
        end
      end

      def compute_mode(file_size, count, etag)
        chunk_size = compute_chunk(file_size)
        part_size = (file_size.to_f / count.to_f).ceil
        if chunk_size < part_size
          multithreaded_get_by_ranges(file_size, etag)
        else
          multithreaded_get_by_parts(count, file_size, etag)
        end
      end

      def construct_chunks(file_size)
        offset = 0
        default_chunk_size = compute_chunk(file_size)
        chunks = []
        while offset < file_size
          progress = offset + default_chunk_size
          progress = file_size if progress > file_size
          chunks << "bytes=#{offset}-#{progress - 1}"
          offset = progress
        end
        chunks
      end

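      # The chunk size is the user-supplied :chunk_size, or else the larger
      # of MIN_CHUNK_SIZE and file_size / MAX_PARTS. Worked example: for a
      # 100 GiB object, (107_374_182_400.0 / 10_000).ceil is 10_737_419
      # bytes, which exceeds the 5 MiB floor, so that becomes the range size.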
      def compute_chunk(file_size)
        if @chunk_size && @chunk_size > file_size
          raise ArgumentError, ":chunk_size shouldn't exceed total file size."
        else
          @chunk_size || [
            (file_size.to_f / MAX_PARTS).ceil, MIN_CHUNK_SIZE
          ].max.to_i
        end
      end

      def batches(chunks, mode)
        chunks = (1..chunks) if mode.eql? 'part_number'
        chunks.each_slice(@thread_count).to_a
      end

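      # Builds one Part per byte range and downloads them concurrently.
      # For example, a 12 MiB object with a 5 MiB chunk size yields the
      # ranges "bytes=0-5242879", "bytes=5242880-10485759" and
      # "bytes=10485760-12582911". Each request carries if_match: etag so a
      # concurrent overwrite of the object fails the download rather than
      # mixing bytes from different versions.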
      def multithreaded_get_by_ranges(file_size, etag)
        offset = 0
        default_chunk_size = compute_chunk(file_size)
        chunks = []
        part_number = 1 # parts start at 1
        while offset < file_size
          progress = offset + default_chunk_size
          progress = file_size if progress > file_size
          range = "bytes=#{offset}-#{progress - 1}"
          chunks << Part.new(
            part_number: part_number,
            size: (progress - offset),
            params: @params.merge(range: range, if_match: etag)
          )
          part_number += 1
          offset = progress
        end
        download_in_threads(PartList.new(chunks), file_size)
      end

      def multithreaded_get_by_parts(n_parts, total_size, etag)
        parts = (1..n_parts).map do |part|
          Part.new(part_number: part, params: @params.merge(part_number: part, if_match: etag))
        end
        download_in_threads(PartList.new(parts), total_size)
      end

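      # Spins up @thread_count workers that pull parts off the shared,
      # mutex-guarded PartList until it is empty. The first error clears the
      # list so the other workers stop picking up new parts; the exception is
      # then re-raised when the threads are joined via Thread#value.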
      def download_in_threads(pending, total_size)
        threads = []
        progress = MultipartProgress.new(pending, total_size, @progress_callback) if @progress_callback
        @thread_count.times do
          thread = Thread.new do
            begin
              while (part = pending.shift)
                if progress
                  part.params[:on_chunk_received] =
                    proc do |_chunk, bytes, total|
                      progress.call(part.part_number, bytes, total)
                    end
                end
                resp = @client.get_object(part.params)
                write(resp)
                if @on_checksum_validated && resp.checksum_validated
                  @on_checksum_validated.call(resp.checksum_validated, resp)
                end
              end
              nil
            rescue => error
              # keep other threads from downloading other parts
              pending.clear!
              raise error
            end
          end
          threads << thread
        end
        threads.map(&:value).compact
      end

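      # Writes a ranged/part response into the destination file at the
      # correct offset, parsed from the Content-Range header. For
      # "bytes 5242880-10485759/12582912", the starting offset is 5242880.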
      def write(resp)
        range, _ = resp.content_range.split(' ').last.split('/')
        head, _ = range.split('-').map { |s| s.to_i }
        File.write(@path, resp.body.read, head)
      end

      def single_request
        params = @params.merge(response_target: @path)
        params[:on_chunk_received] = single_part_progress if @progress_callback
        resp = @client.get_object(params)
        return resp unless @on_checksum_validated

        @on_checksum_validated.call(resp.checksum_validated, resp) if resp.checksum_validated
        resp
      end

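      # Wraps single-request progress in the same (bytes_received, part_sizes,
      # total_size) shape used for multipart downloads, so one
      # :progress_callback works for every mode; single-element arrays stand
      # in for the per-part values.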
      def single_part_progress
        proc do |_chunk, bytes_read, total_size|
          @progress_callback.call([bytes_read], [total_size], total_size)
        end
      end

      class Part < Struct.new(:part_number, :size, :params)
        include Aws::Structure
      end

      # @api private
      class PartList
        include Enumerable

        def initialize(parts = [])
          @parts = parts
          @mutex = Mutex.new
        end

        def shift
          @mutex.synchronize { @parts.shift }
        end

        def size
          @mutex.synchronize { @parts.size }
        end

        def clear!
          @mutex.synchronize { @parts.clear }
        end

        def each(&block)
          @mutex.synchronize { @parts.each(&block) }
        end
      end

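      # Aggregates per-part progress and forwards it to the user-supplied
      # callback. For example, for a 12 MiB object split into 5 MiB ranges,
      # after part 1 completes and part 2 has received 1 MiB the callback
      # sees bytes_received = [5242880, 1048576, 0],
      # part_sizes = [5242880, 5242880, 2097152], total_size = 12582912.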
      # @api private
      class MultipartProgress
        def initialize(parts, total_size, progress_callback)
          @bytes_received = Array.new(parts.size, 0)
          @part_sizes = parts.map(&:size)
          @total_size = total_size
          @progress_callback = progress_callback
        end

        def call(part_number, bytes_received, total)
          # part numbers start at 1
          @bytes_received[part_number - 1] = bytes_received
          # part size may not be known until we get the first response
          @part_sizes[part_number - 1] ||= total
          @progress_callback.call(@bytes_received, @part_sizes, @total_size)
        end
      end
    end
  end
end