lib/multiwoven/integrations/source/google_vertex_model/client.rb
# frozen_string_literal: true

module Multiwoven::Integrations::Source
  module VertexModel
    include Multiwoven::Integrations::Core

    # Source connector that sends raw prediction requests to a Google Vertex AI endpoint.
    class Client < SourceConnector
      # Checks the connection by fetching the configured Vertex endpoint.
      #
      # @param connection_config [Hash] connection settings; this class reads the
      #   credentials_json, project_id, region and endpoint_id keys.
      # @return the multiwoven message of a "succeeded" ConnectionStatus, or of a
      #   "failed" ConnectionStatus carrying the error message when anything raises.
      def check_connection(connection_config)
        connection_config = connection_config.with_indifferent_access
        create_connection(connection_config)
        @client.get_endpoint(name: build_url(GOOGLE_VERTEX_MODEL_NAME, connection_config))
        ConnectionStatus.new(status: ConnectionStatusType["succeeded"]).to_multiwoven_message
      rescue StandardError => e
        ConnectionStatus.new(status: ConnectionStatusType["failed"], message: e.message).to_multiwoven_message
      end

      # Builds the catalog from the JSON spec at CATALOG_SPEC_PATH.
      #
      # @param _connection_config [Hash, nil] unused.
      # @return the catalog's multiwoven message; on error, whatever handle_exception returns.
      def discover(_connection_config = nil)
        catalog_json = read_json(CATALOG_SPEC_PATH)
        catalog = build_catalog(catalog_json)
        catalog.to_multiwoven_message
      rescue StandardError => e
        handle_exception(e, { context: "GOOGLE:VERTEX MODEL:DISCOVER:EXCEPTION", type: "error" })
      end

      # Runs the model with the payload carried in the sync config's model query.
      #
      # @param sync_config sync configuration; source.connection_specification and
      #   model.query are read from it.
      # @return the result of run_model; on error, whatever handle_exception returns.
      def read(sync_config)
        connection_config = sync_config.source.connection_specification
        connection_config = connection_config.with_indifferent_access
        # The server checks the ConnectorQueryType. If it is "ai_ml", the server
        # calculates the payload and passes it as the query in the sync config
        # model protocol. That query is then sent to the AI/ML model.
        payload = JSON.parse(sync_config.model.query)
        run_model(connection_config, payload)
      rescue StandardError => e
        handle_exception(e, { context: "GOOGLE:VERTEX MODEL:READ:EXCEPTION", type: "error" })
      end

      private

      # Creates the endpoint-service client (@client) and the prediction-service
      # client (@endpoint) for the given configuration.
      #
      # The settings are applied through the block accepted by Client.new, so they
      # stay local to these two instances. Client.configure would instead change the
      # class-level defaults, leaking one connector's region and credentials into
      # every client created afterwards.
      #
      # @param connection_config [Hash] indifferent-access connection settings.
      # @return the prediction-service client.
      def create_connection(connection_config)
        service_url = build_url(GOOGLE_VERTEX_ENDPOINT_SERVICE_URL, connection_config)
        credentials = connection_config["credentials_json"]
        apply_settings = proc do |config|
          config.endpoint = service_url
          config.credentials = credentials
        end

        @client = Google::Cloud::AIPlatform::V1::EndpointService::Client.new(&apply_settings)
        @endpoint = Google::Cloud::AIPlatform::V1::PredictionService::Client.new(&apply_settings)
      end

      # Sends the payload to the endpoint via raw_predict and wraps the response.
      #
      # @param connection_config [Hash] indifferent-access connection settings.
      # @param payload [Object] JSON-serializable request body.
      # @return [Array] the record messages from process_response; on error,
      #   whatever handle_exception returns.
      def run_model(connection_config, payload)
        create_connection(connection_config)
        http_body = Google::Api::HttpBody.new(data: JSON.generate(payload))
        response = @endpoint.raw_predict(
          endpoint: build_url(GOOGLE_VERTEX_MODEL_NAME, connection_config),
          http_body: http_body
        )
        process_response(response)
      rescue StandardError => e
        # Explicit hash, matching the other handle_exception calls in this class.
        handle_exception(e, { context: "GOOGLE:VERTEX MODEL:RUN_MODEL:EXCEPTION", type: "error" })
      end

      # Parses the response body as JSON and wraps it in a single record message.
      #
      # @param response object whose #data is a JSON string.
      # @return [Array] one-element array holding the record's multiwoven message.
      def process_response(response)
        data = JSON.parse(response.data)
        [RecordMessage.new(data: data, emitted_at: Time.now.to_i).to_multiwoven_message]
      end

      # Fills in one of the known format templates from the connection settings.
      #
      # @param url [String] GOOGLE_VERTEX_MODEL_NAME or GOOGLE_VERTEX_ENDPOINT_SERVICE_URL.
      # @param connection_config [Hash] indifferent-access connection settings.
      # @return [String, nil] the formatted string, or nil when url matches neither template.
      def build_url(url, connection_config)
        case url
        when GOOGLE_VERTEX_MODEL_NAME
          format(url,
                 project_id: connection_config[:project_id],
                 region: connection_config[:region],
                 endpoint_id: connection_config[:endpoint_id])
        when GOOGLE_VERTEX_ENDPOINT_SERVICE_URL
          format(url, region: connection_config[:region])
        end
      end
    end
  end
end