# lib/net/ssh/authentication/pageant.rb



if RUBY_VERSION < "1.9"
  require 'dl/import'
  require 'dl/struct'
elsif RUBY_VERSION < "2.1"
  require 'dl/import'
  require 'dl/types'
  require 'dl'
else
  require 'fiddle'
  require 'fiddle/types'
  require 'fiddle/import'

  # Compatibility shim: expose the Fiddle equivalents under the legacy DL
  # names so the code below can keep referring to DL.
  module DL
    CPtr ||= Fiddle::Pointer
    # RUBY_FREE is only aliased when the platform is not "java".
    RUBY_FREE ||= Fiddle::RUBY_FREE unless RUBY_PLATFORM == "java"
  end
end

require 'net/ssh/errors'

module Net; module SSH; module Authentication

  # This module encapsulates the implementation of a socket factory that
  # uses the PuTTY "pageant" utility to obtain information about SSH
  # identities.
  #
  # This code is a slightly modified version of the original implementation
  # by Guillaume Marçais (guillaume.marcais@free.fr). It is used and
  # relicensed by permission.
  module Pageant

    # From Putty pageant.c
    # Passed as the low-order maximum size when the request file mapping is
    # created in Socket#send_query.
    AGENT_MAX_MSGLEN = 8192
    # Stored in COPYDATASTRUCT#dwData for the WM_COPYDATA message sent to
    # the pageant window.
    AGENT_COPYDATA_ID = 0x804e50ba

    # The definition of the Windows methods and data structures used in
    # communicating with the pageant process.
    module Win # rubocop:disable Metrics/ModuleLength
      # Compatibility on initialization: select the import DSL matching the
      # running Ruby version, load user32, kernel32 and advapi32, and define
      # SIZEOF_DWORD, the byte size used below when allocating DWORD buffers.
      if RUBY_VERSION < "1.9"
        # Ruby < 1.9: DL::Importable, one dlload call per library.
        extend DL::Importable

        dlload 'user32'
        dlload 'kernel32'
        dlload 'advapi32'

        SIZEOF_DWORD = DL.sizeof('L')
      elsif RUBY_VERSION < "2.1"
        # Ruby 1.9 - 2.0: DL::Importer plus the Win32 type names.
        extend DL::Importer
        dlload 'user32','kernel32', 'advapi32'
        include DL::Win32Types

        SIZEOF_DWORD = DL::SIZEOF_LONG
      else
        # Ruby >= 2.1: Fiddle::Importer plus the Win32 type names.
        extend Fiddle::Importer
        dlload 'user32','kernel32', 'advapi32'
        include Fiddle::Win32Types
        SIZEOF_DWORD = Fiddle::SIZEOF_LONG
      end

      # Extra C type names used by the prototypes and structs declared below.
      # The first table is registered only when running on JRuby.
      if RUBY_ENGINE == "jruby"
        [
          ["HANDLE", "void *"],              # From winnt.h
          ["PHANDLE", "void *"],             # From winnt.h
          ["ULONG_PTR", "unsigned long*"]
        ].each { |alias_name, c_type| typealias(alias_name, c_type) }
      end

      # Registered on every engine, in this order.
      [
        ["LPCTSTR", "char *"],               # From winnt.h
        ["LPVOID", "void *"],                # From winnt.h
        ["LPCVOID", "const void *"],         # From windef.h
        ["LRESULT", "long"],                 # From windef.h
        ["WPARAM", "unsigned int *"],        # From windef.h
        ["LPARAM", "long *"],                # From windef.h
        ["PDWORD_PTR", "long *"],            # From basetsd.h
        ["USHORT", "unsigned short"]         # From windef.h
      ].each { |alias_name, c_type| typealias(alias_name, c_type) }

      # Numeric constants taken from winbase.h and winnt.h.
      INVALID_HANDLE_VALUE = -1
      NULL = nil
      PAGE_READWRITE = 0x0004
      FILE_MAP_WRITE = 2
      WM_COPYDATA = 74

      # From winuser.h
      SMTO_NORMAL = 0

      # Suffix appended to some imported function names in the extern
      # declarations below: "A" on JRuby, empty on every other engine.
      SUFFIX = RUBY_ENGINE == "jruby" ? "A" : ""

      # Prototypes of the Windows API functions imported into this module.
      # SUFFIX is interpolated into the names of FindWindow,
      # CreateFileMapping and SendMessageTimeout.

      # args: lpClassName, lpWindowName
      extern "HWND FindWindow#{SUFFIX}(LPCTSTR, LPCTSTR)"

      # args: none
      extern 'DWORD GetCurrentThreadId()'

      # args: hFile, (ignored), flProtect, dwMaximumSizeHigh,
      #           dwMaximumSizeLow, lpName
      # Both halves are double-quoted so that SUFFIX is interpolated and the
      # two pieces concatenate into a single well-formed prototype.
      extern "HANDLE CreateFileMapping#{SUFFIX}(HANDLE, void *, DWORD, " +
        "DWORD, DWORD, LPCTSTR)"

      # args: hFileMappingObject, dwDesiredAccess, dwFileOffsetHigh,
      #           dwfileOffsetLow, dwNumberOfBytesToMap
      extern 'LPVOID MapViewOfFile(HANDLE, DWORD, DWORD, DWORD, DWORD)'

      # args: lpBaseAddress
      extern 'BOOL UnmapViewOfFile(LPCVOID)'

      # args: hObject
      extern 'BOOL CloseHandle(HANDLE)'

      # args: hWnd, Msg, wParam, lParam, fuFlags, uTimeout, lpdwResult
      # Split across two double-quoted strings for the same reason as
      # CreateFileMapping above.
      extern "LRESULT SendMessageTimeout#{SUFFIX}(HWND, UINT, WPARAM, LPARAM, " +
        "UINT, UINT, PDWORD_PTR)"

      # args: none
      extern 'DWORD GetLastError()'

      # args: none
      extern 'HANDLE GetCurrentProcess()'

      # args: hProcessHandle, dwDesiredAccess, (out) phNewTokenHandle
      extern 'BOOL OpenProcessToken(HANDLE, DWORD, PHANDLE)'

      # args: hTokenHandle, uTokenInformationClass,
      #           (out) lpTokenInformation, dwTokenInformationLength
      #           (out) pdwInfoReturnLength
      extern 'BOOL GetTokenInformation(HANDLE, UINT, LPVOID, DWORD, ' +
        'PDWORD)'

      # args: (out) lpSecurityDescriptor, dwRevisionLevel
      extern 'BOOL InitializeSecurityDescriptor(LPVOID, DWORD)'

      # args: (out) lpSecurityDescriptor, lpOwnerSid, bOwnerDefaulted
      extern 'BOOL SetSecurityDescriptorOwner(LPVOID, LPVOID, BOOL)'

      # args: pSecurityDescriptor
      extern 'BOOL IsValidSecurityDescriptor(LPVOID)'

      # Constants needed for security attribute retrieval.

      # Access mask passed as the desired access when the process token is
      # opened.
      TOKEN_QUERY = 0x8

      # The value of TOKEN_USER from the TOKEN_INFORMATION_CLASS enum.
      TOKEN_USER_INFORMATION_CLASS = 1

      # Revision level handed to InitializeSecurityDescriptor.
      REVISION = 1

      # Struct layouts for the security attribute functions.

      # Holds the retrieved user access token.
      TOKEN_USER = struct [
        'void * SID',
        'DWORD ATTRIBUTES'
      ]

      # Wraps the security descriptor; this is what gets passed to the
      # function that creates the shared memory map.
      SECURITY_ATTRIBUTES = struct [
        'DWORD nLength',
        'LPVOID lpSecurityDescriptor',
        'BOOL bInheritHandle'
      ]

      # The security descriptor holds security information.
      SECURITY_DESCRIPTOR = struct [
        'UCHAR Revision',
        'UCHAR Sbz1',
        'USHORT Control',
        'LPVOID Owner',
        'LPVOID Group',
        'LPVOID Sacl',
        'LPVOID Dacl'
      ]

      # The COPYDATASTRUCT is used to send WM_COPYDATA messages. Only the
      # declared type of dwData differs between JRuby and other engines.
      dw_data_field = RUBY_ENGINE == "jruby" ? 'ULONG_PTR dwData' : 'uintptr_t dwData'
      COPYDATASTRUCT = struct [dw_data_field, 'DWORD cbData', 'LPVOID lpData']

      # Compatibility for security attribute retrieval.
      if RUBY_VERSION < "1.9"
        # Give each imported function (listed here with a lower-case first
        # letter) an alias whose first letter is upper-case, matching the
        # capitalization used on Ruby > 1.9, and expose that alias as a
        # module function.
        lowercase_names = %w(findWindow
                             getCurrentProcess
                             initializeSecurityDescriptor
                             setSecurityDescriptorOwner
                             isValidSecurityDescriptor
                             openProcessToken
                             getTokenInformation
                             getLastError
                             getCurrentThreadId
                             createFileMapping
                             mapViewOfFile
                             sendMessageTimeout
                             unmapViewOfFile
                             closeHandle)
        lowercase_names.each do |lower|
          capitalized = lower[0, 1].upcase + lower[1..-1]
          alias_method capitalized, lower
          module_function capitalized
        end

        # Allocates +size+ bytes through DL.malloc and returns the result.
        def self.malloc_ptr(size)
          DL.malloc(size)
        end

        # Returns whatever +data+ yields from its +to_ptr+ method.
        def self.get_ptr(data)
          data.to_ptr
        end

        # Stores +data+ at index 0 of +ptr+.
        def self.set_ptr_data(ptr, data)
          ptr[0] = data
        end
      elsif RUBY_ENGINE == "jruby"
        # Expose each "A"-suffixed import under its unsuffixed name as well,
        # so callers can use the same method names on every engine.
        %w(FindWindow CreateFileMapping SendMessageTimeout).each do |plain_name|
          alias_method plain_name, "#{plain_name}A"
          module_function plain_name
        end
        # :nodoc:
        # Binds +malloc+ and +free+ from FFI::Library::LIBC; malloc_ptr below
        # uses them to allocate memory and to supply the pointer's free
        # function.
        module LibC
          extend FFI::Library
          ffi_lib FFI::Library::LIBC
          attach_function :malloc, [:size_t], :pointer
          attach_function :free, [:pointer], :void
        end

        # Allocates +size+ bytes with LibC.malloc and wraps the result in a
        # Fiddle::Pointer of that size, passing LibC's +free+ method as the
        # pointer's third constructor argument.
        def self.malloc_ptr(size)
          Fiddle::Pointer.new(LibC.malloc(size), size, LibC.method(:free))
        end

        # Returns the value of +ptr.address+.
        #
        # The body previously read an undefined local (+data+) instead of
        # the +ptr+ parameter, so every call raised NameError.
        # NOTE(review): confirm that the object passed in on JRuby (the
        # security attributes struct) responds to +address+.
        def self.get_ptr(ptr)
          ptr.address
        end

        # Writes +data+ to +ptr+ by calling ptr.write_string_length with
        # +data+ and data.size.
        def self.set_ptr_data(ptr, data)
          ptr.write_string_length(data, data.size)
        end
      else
        # Allocates +size+ bytes via DL::CPtr.malloc, registering
        # DL::RUBY_FREE as the free function.
        def self.malloc_ptr(size)
          DL::CPtr.malloc(size, DL::RUBY_FREE)
        end

        # Converts +data+ with DL::CPtr.to_ptr.
        def self.get_ptr(data)
          DL::CPtr.to_ptr(data)
        end

        # Wraps +ptr+ in a DL::CPtr and copies +data+ into its first
        # data.size bytes.
        def self.set_ptr_data(ptr, data)
          destination = DL::CPtr.new(ptr)
          destination[0, data.size] = data
        end
      end

      # Builds a SECURITY_ATTRIBUTES struct whose security descriptor is
      # owned by the current user and returns it.
      #
      # Raises a RuntimeError (through raise_error_if_zero) when any of the
      # security descriptor calls returns 0.
      def self.get_security_attributes_for_user
        current_user = get_current_user

        # Initialise a descriptor, set its owner, then validate it.
        descriptor = malloc_ptr(Win::SECURITY_DESCRIPTOR.size)
        initialized = Win.InitializeSecurityDescriptor(descriptor, Win::REVISION)
        raise_error_if_zero(initialized)
        owner_set = Win.SetSecurityDescriptorOwner(descriptor,
                                                   get_sid_ptr(current_user), 0)
        raise_error_if_zero(owner_set)
        raise_error_if_zero(Win.IsValidSecurityDescriptor(descriptor))

        backing = to_struct_ptr(malloc_ptr(Win::SECURITY_ATTRIBUTES.size))
        attributes = Win::SECURITY_ATTRIBUTES.new(backing)
        attributes.nLength = Win::SECURITY_ATTRIBUTES.size
        # The struct field receives the descriptor's integer address.
        attributes.lpSecurityDescriptor = descriptor.to_i
        attributes.bInheritHandle = 1
        attributes
      end

      # Engine-specific helpers converting between raw pointers, handles,
      # DWORD values and the TOKEN_USER struct. Both branches define the
      # same set of methods (the JRuby branch additionally has ptr_to_s).
      if RUBY_ENGINE == "jruby"
        # Reads +size+ bytes from +ptr+ as a string, appending NUL bytes
        # until the result is at least +size+ characters long.
        def self.ptr_to_s(ptr, size)
          padded = ptr.to_s(size)
          padded << "\x00" until padded.size >= size
          padded
        end

        # Dereferences the handle out-parameter.
        def self.ptr_to_handle(phandle)
          phandle.ptr
        end

        # Reads the DWORD behind +ptr+ in two ways and raises "Error" when
        # the two readings disagree.
        def self.ptr_to_dword(ptr)
          via_pointer = ptr.ptr.to_i
          via_bytes = ptr_to_s(ptr, Win::SIZEOF_DWORD).unpack('L')[0]
          raise "Error" if via_pointer != via_bytes
          via_pointer
        end

        # Wraps the token information buffer in a TOKEN_USER struct.
        def self.to_token_user(ptoken_information)
          TOKEN_USER.new(ptoken_information.to_ptr)
        end

        # Converts +ptr+ to the form expected by a struct constructor.
        def self.to_struct_ptr(ptr)
          ptr.to_ptr
        end

        # Returns the first DWORD stored at the user's SID pointer.
        def self.get_sid(user)
          sid_bytes = ptr_to_s(user.to_ptr.ptr, Win::SIZEOF_DWORD)
          sid_bytes.unpack('L')[0]
        end

        # Returns the pointer stored at the start of the user struct.
        def self.get_sid_ptr(user)
          user.to_ptr.ptr
        end
      else
        # Dereferences the handle out-parameter and returns it as an
        # integer.
        def self.ptr_to_handle(phandle)
          phandle.ptr.to_i
        end

        # Decodes the first SIZEOF_DWORD bytes behind +ptr+ with the 'L'
        # unpack directive.
        def self.ptr_to_dword(ptr)
          ptr.to_s(Win::SIZEOF_DWORD).unpack('L').first
        end

        # Wraps the token information buffer in a TOKEN_USER struct.
        def self.to_token_user(ptoken_information)
          TOKEN_USER.new(ptoken_information)
        end

        # Pointers are already usable by struct constructors here.
        def self.to_struct_ptr(ptr)
          ptr
        end

        # Returns the SID member of the user struct.
        def self.get_sid(user)
          user.SID
        end

        # Returns the SID member of the user struct.
        def self.get_sid_ptr(user)
          user.SID
        end
      end

      # Opens the current process's token with TOKEN_QUERY access and
      # returns the token-user information retrieved from it.
      def self.get_current_user
        process_token = open_process_token(Win.GetCurrentProcess, Win::TOKEN_QUERY)
        get_token_information(process_token, Win::TOKEN_USER_INFORMATION_CLASS)
      end

      # Opens the access token of +process_handle+ with +desired_access+ and
      # returns the token handle (converted by ptr_to_handle).
      #
      # Raises a RuntimeError (through raise_error_if_zero) when
      # OpenProcessToken returns 0.
      def self.open_process_token(process_handle, desired_access)
        # The out-parameter receives a HANDLE, which is pointer-sized, and
        # ptr_to_handle reads it back as a pointer. SIZEOF_DWORD is derived
        # from the C long size, which can be smaller than a pointer, so
        # reserve two DWORDs to make sure the written handle always fits.
        ptoken_handle = malloc_ptr(2 * Win::SIZEOF_DWORD)

        raise_error_if_zero(
          Win.OpenProcessToken(process_handle, desired_access,
                               ptoken_handle))
        ptr_to_handle(ptoken_handle)
      end

      # Retrieves the +token_information_class+ data for +token_handle+ and
      # returns it converted by to_token_user.
      #
      # Raises a RuntimeError (through raise_error_if_zero) when the second
      # GetTokenInformation call returns 0.
      def self.get_token_information(token_handle,
                                     token_information_class)
        # Receives the size of the information to be returned.
        length_out = malloc_ptr(Win::SIZEOF_DWORD)

        # Sizing call: with no buffer it is expected to fail with an
        # insufficient-buffer error, which is fine because only the
        # required length is wanted here; the result is ignored.
        Win.GetTokenInformation(token_handle, token_information_class,
                                Win::NULL, 0, length_out)
        info_buffer = malloc_ptr(ptr_to_dword(length_out))

        # Real call: writes the requested information into info_buffer.
        filled = Win.GetTokenInformation(token_handle, token_information_class,
                                         info_buffer, info_buffer.size,
                                         length_out)
        raise_error_if_zero(filled)

        to_token_user(info_buffer)
      end

      # Raises a RuntimeError carrying the value of Win.GetLastError when
      # +result+ equals 0; otherwise returns nil.
      def self.raise_error_if_zero(result)
        raise "Windows error: #{Win.GetLastError}" if result == 0
      end

      # Returns a copy of +str+ with a single NUL byte appended, i.e. a
      # null-terminated string.
      def self.get_cstr(str)
        str + "\000"
      end
    end

    # This is the pseudo-socket implementation that mimics the interface of
    # a socket, translating each request into a Windows messaging call to
    # the pageant daemon. This allows pageant support to be implemented
    # simply by replacing the socket factory used by the Agent class.
    class Socket

      private_class_method :new

      # The factory method for creating a new Socket instance.
      def self.open
        new
      end

      # Create a new instance that communicates with the running pageant
      # instance. Raises Net::SSH::Exception if no pageant window is found.
      def initialize
        @win = Win.FindWindow("Pageant", "Pageant")

        if @win.to_i == 0
          raise Net::SSH::Exception,
            "pageant process not running"
        end

        @input_buffer = Net::SSH::Buffer.new
        @output_buffer = Net::SSH::Buffer.new
      end

      # Buffers +data+ and forwards every complete message (a 4-byte length
      # followed by that many bytes) to #send_query, appending each reply to
      # the output buffer. Any arguments after the first are ignored.
      # Returns the length of +data+.
      def send(data, *args)
        @input_buffer.append(data)

        ret = data.length

        while true
          return ret if @input_buffer.length < 4
          msg_length = @input_buffer.read_long + 4
          # Rewind so the length prefix is included in the message read below.
          @input_buffer.reset!

          return ret if @input_buffer.length < msg_length
          msg = @input_buffer.read!(msg_length)
          @output_buffer.append(send_query(msg))
        end
      end

      # Reads +n+ bytes from the buffered query results. If +n+ is +nil+,
      # returns all remaining data.
      def read(n = nil)
        @output_buffer.read(n)
      end

      # No-op: there is nothing to release between queries.
      def close
      end

      # Packages the given query string and sends it to the pageant
      # process via the Windows messaging subsystem, returning the raw
      # reply (4-byte length prefix included).
      #
      # Raises Net::SSH::Exception when the file mapping cannot be created
      # or mapped, or when the message cannot be delivered. The view and
      # the mapping handle are released in every case.
      def send_query(query)
        filemap = 0
        ptr = nil
        # lpdwResult is declared as PDWORD_PTR, i.e. it points at a
        # pointer-sized value. SIZEOF_DWORD is derived from the C long size,
        # which can be smaller than a pointer, so reserve two DWORDs to make
        # sure the written result always fits.
        id = Win.malloc_ptr(2 * Win::SIZEOF_DWORD)

        mapname = "PageantRequest%08x" % Win.GetCurrentThreadId()
        security_attributes = Win.get_ptr Win.get_security_attributes_for_user

        filemap = Win.CreateFileMapping(Win::INVALID_HANDLE_VALUE,
                                        security_attributes,
                                        Win::PAGE_READWRITE, 0,
                                        AGENT_MAX_MSGLEN, mapname)

        if filemap == 0 || filemap == Win::INVALID_HANDLE_VALUE
          raise Net::SSH::Exception,
            "Creation of file mapping failed with error: #{Win.GetLastError}"
        end

        ptr = Win.MapViewOfFile(filemap, Win::FILE_MAP_WRITE, 0, 0,
                                0)

        if ptr.nil? || ptr.null?
          raise Net::SSH::Exception, "Mapping of file failed"
        end

        Win.set_ptr_data(ptr, query)

        # using struct to achieve proper alignment and field size on 64-bit platform
        cds = Win::COPYDATASTRUCT.new(Win.malloc_ptr(Win::COPYDATASTRUCT.size))
        cds.dwData = AGENT_COPYDATA_ID
        cds.cbData = mapname.size + 1
        cds.lpData = Win.get_cstr(mapname)
        succ = Win.SendMessageTimeout(@win, Win::WM_COPYDATA, Win::NULL,
                                      cds.to_ptr, Win::SMTO_NORMAL, 5000, id)

        unless succ > 0
          raise Net::SSH::Exception, "Message failed with error: #{Win.GetLastError}"
        end

        # The reply starts with its own big-endian 4-byte length.
        retlen = 4 + ptr.to_s(4).unpack("N")[0]
        ptr.to_s(retlen)
      ensure
        Win.UnmapViewOfFile(ptr) unless ptr.nil? || ptr.null?
        Win.CloseHandle(filemap) if filemap != 0
      end
    end
  end

end; end; end