lib/multiwoven/integrations/source/aws_athena/client.rb
# frozen_string_literal: true

module Multiwoven::Integrations::Source
  module AwsAthena
    include Multiwoven::Integrations::Core

    # Source connector backed by Amazon Athena.
    #
    # Connection config keys read here: access_key, secret_access_key, region,
    # schema (Athena database) and output_location (query result location).
    class Client < SourceConnector
      # Seconds to wait between GetQueryExecution polls while a query is still running.
      POLL_INTERVAL_SECONDS = 1

      # Verifies the credentials/region by listing Athena work groups.
      #
      # @param connection_config [Hash] connection settings (string or symbol keys)
      # @return a ConnectionStatus message: "succeeded", or "failed" carrying the error message
      def check_connection(connection_config)
        connection_config = connection_config.with_indifferent_access
        athena_client = create_connection(connection_config)
        athena_client.list_work_groups
        ConnectionStatus.new(status: ConnectionStatusType["succeeded"]).to_multiwoven_message
      rescue StandardError => e
        ConnectionStatus.new(status: ConnectionStatusType["failed"], message: e.message).to_multiwoven_message
      end

      # Builds a catalog containing one stream per table of the configured schema.
      #
      # @param connection_config [Hash] connection settings; :schema selects the database to describe
      # @return a Catalog message, or whatever handle_exception returns on failure
      def discover(connection_config)
        connection_config = connection_config.with_indifferent_access
        # Double single quotes so the schema name cannot terminate the SQL string literal early.
        schema = connection_config[:schema].to_s.gsub("'", "''")
        query = "SELECT table_name, column_name, data_type, is_nullable " \
                "FROM information_schema.columns " \
                "WHERE table_schema = '#{schema}' " \
                "ORDER BY table_name, ordinal_position;"
        db = create_connection(connection_config)
        results = query_execution(db, query)
        catalog = Catalog.new(streams: create_streams(results))
        catalog.to_multiwoven_message
      rescue StandardError => e
        handle_exception(e, { context: "AWS:ATHENA:DISCOVER:EXCEPTION", type: "error" })
      end

      # Runs the sync model's query and returns its rows as record messages.
      #
      # @param sync_config sync configuration providing the source connection, the model query,
      #   optional limit/offset, and the sync identifiers used for error reporting
      # @return [Array] RecordMessage multiwoven messages, or whatever handle_exception returns on failure
      def read(sync_config)
        connection_config = sync_config.source.connection_specification.with_indifferent_access
        query = sync_config.model.query
        # Only wrap the query for batching when a limit or an offset is actually set.
        unless sync_config.limit.nil? && sync_config.offset.nil?
          query = batched_query(query, sync_config.limit, sync_config.offset)
        end
        db = create_connection(connection_config)
        query(db, query)
      rescue StandardError => e
        handle_exception(e, {
                           context: "AWS:ATHENA:READ:EXCEPTION",
                           type: "error",
                           sync_id: sync_config.sync_id,
                           sync_run_id: sync_config.sync_run_id
                         })
      end

      private

      # Builds an Athena client and remembers the database/output location used by query_execution.
      #
      # Credentials and region are given to this client only. Writing them to the global
      # Aws.config would leak them into every AWS client created later in the process.
      def create_connection(connection_config)
        @database = connection_config[:schema]
        @output_location = connection_config[:output_location]
        Aws::Athena::Client.new(
          region: connection_config[:region],
          credentials: Aws::Credentials.new(connection_config[:access_key], connection_config[:secret_access_key])
        )
      end

      # Runs a SQL statement in Athena, waits for it to finish and returns every result row.
      #
      # @return [Array<Hash{String => String, nil}>] one hash per row, keyed by column name
      # @raise [RuntimeError] if the query ends in the FAILED or CANCELLED state
      def query_execution(db, query)
        response = db.start_query_execution(
          query_string: query,
          query_execution_context: { database: @database },
          result_configuration: { output_location: @output_location }
        )
        query_execution_id = response[:query_execution_id]
        wait_for_query(db, query_execution_id)
        fetch_all_results(db, query_execution_id)
      end

      # Polls GetQueryExecution until the query is terminal.
      # Returns on SUCCEEDED; raises with Athena's state_change_reason on FAILED/CANCELLED.
      # NOTE(review): the wait is unbounded on our side; it relies on Athena's own query timeout.
      def wait_for_query(db, query_execution_id)
        loop do
          status = db.get_query_execution(query_execution_id: query_execution_id).query_execution.status
          case status.state
          when "SUCCEEDED"
            return
          when "FAILED", "CANCELLED"
            raise "Athena query #{query_execution_id} #{status.state}: #{status.state_change_reason}"
          end
          # Back off between polls instead of busy-looping against the API.
          sleep(POLL_INTERVAL_SECONDS)
        end
      end

      # Follows next_token through every GetQueryResults page (each page holds at most
      # 1000 rows) and returns all rows as column_name => value hashes.
      def fetch_all_results(db, query_execution_id)
        params = { query_execution_id: query_execution_id }
        columns = nil
        records = []
        loop do
          page = db.get_query_results(params)
          first_page = columns.nil?
          # Column metadata is taken from the first page and reused for the following ones.
          columns ||= page.result_set.result_set_metadata.column_info.map(&:name)
          records.concat(transform_results(page, columns, skip_header: first_page))
          break if page.next_token.nil?

          params = params.merge(next_token: page.next_token)
        end
        records
      end

      # Converts one GetQueryResults response into an array of column_name => value hashes.
      #
      # @param results a GetQueryResults response page
      # @param columns [Array<String>, nil] column names; read from the page metadata when nil
      # @param skip_header [Boolean] drop a leading row that merely repeats the column names
      #   (Athena returns such a header row at the top of SELECT results)
      def transform_results(results, columns = nil, skip_header: true)
        columns ||= results.result_set.result_set_metadata.column_info.map(&:name)
        rows = results.result_set.rows.map { |row| row.data.map(&:var_char_value) }
        rows.shift if skip_header && rows.first == columns
        rows.map { |row| columns.zip(row).to_h }
      end

      # Wraps every result row of the query in a RecordMessage multiwoven message.
      def query(db, query)
        query_execution(db, query).map do |row|
          RecordMessage.new(data: row, emitted_at: Time.now.to_i).to_multiwoven_message
        end
      end

      # Builds one Stream per table from information_schema.columns rows.
      def create_streams(records)
        group_by_table(records).values.map do |table|
          Multiwoven::Integrations::Protocol::Stream.new(
            name: table[:tablename],
            action: StreamAction["fetch"],
            json_schema: convert_to_json_schema(table[:columns])
          )
        end
      end

      # Groups information_schema.columns rows by table name, keeping column order.
      #
      # @param records [Array<Hash>] rows with "table_name", "column_name", "data_type", "is_nullable"
      # @return [Hash{String => Hash}] table name => { tablename:, columns: [{ column_name:, data_type:, is_nullable: }] }
      def group_by_table(records)
        records.each_with_object({}) do |entry, tables|
          table_name = entry["table_name"]
          table = tables[table_name] ||= { tablename: table_name, columns: [] }
          table[:columns] << {
            column_name: entry["column_name"],
            data_type: entry["data_type"],
            is_nullable: entry["is_nullable"] == "YES"
          }
        end
      end
    end
  end
end