diff --git a/swbemservices.go b/swbemservices.go new file mode 100644 index 0000000..9765a53 --- /dev/null +++ b/swbemservices.go @@ -0,0 +1,260 @@ +// +build windows + +package wmi + +import ( + "fmt" + "reflect" + "runtime" + "sync" + + "github.com/go-ole/go-ole" + "github.com/go-ole/go-ole/oleutil" +) + +// SWbemServices is used to access wmi. See https://msdn.microsoft.com/en-us/library/aa393719(v=vs.85).aspx +type SWbemServices struct { + //TODO: track namespace. Not sure if we can re connect to a different namespace using the same instance + cWMIClient *Client //This could also be an embedded struct, but then we would need to branch on Client vs SWbemServices in the Query method + sWbemLocatorIUnknown *ole.IUnknown + sWbemLocatorIDispatch *ole.IDispatch + queries chan *queryRequest + closeError chan error + lQueryorClose sync.Mutex +} + +type queryRequest struct { + query string + dst interface{} + args []interface{} + finished chan error +} + +// InitializeSWbemServices will return a new SWbemServices object that can be used to query WMI +func InitializeSWbemServices(c *Client, connectServerArgs ...interface{}) (*SWbemServices, error) { + //fmt.Println("InitializeSWbemServices: Starting") + //TODO: implement connectServerArgs as optional argument for init with connectServer call + s := new(SWbemServices) + s.cWMIClient = c + s.queries = make(chan *queryRequest) + initError := make(chan error) + go s.process(initError) + + err, ok := <-initError + if ok { + return nil, err //Send error to caller + } + //fmt.Println("InitializeSWbemServices: Finished") + return s, nil +} + +// Close will clear and release all of the SWbemServices resources +func (s *SWbemServices) Close() error { + s.lQueryorClose.Lock() + if s == nil || s.sWbemLocatorIDispatch == nil { + s.lQueryorClose.Unlock() + return fmt.Errorf("SWbemServices is not Initialized") + } + if s.queries == nil { + s.lQueryorClose.Unlock() + return fmt.Errorf("SWbemServices has been closed") + } + //fmt.Println("Close: sending close request") + var result error + ce := make(chan error) + s.closeError = ce //Race condition if multiple callers to close. May need to lock here + close(s.queries) //Tell background to shut things down + s.lQueryorClose.Unlock() + err, ok := <-ce + if ok { + result = err + } + //fmt.Println("Close: finished") + return result +} + +func (s *SWbemServices) process(initError chan error) { + //fmt.Println("process: starting background thread initialization") + //All OLE/WMI calls must happen on the same initialized thead, so lock this goroutine + runtime.LockOSThread() + defer runtime.LockOSThread() + + err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED) + if err != nil { + oleCode := err.(*ole.OleError).Code() + if oleCode != ole.S_OK && oleCode != S_FALSE { + initError <- fmt.Errorf("ole.CoInitializeEx error: %v", err) + return + } + } + defer ole.CoUninitialize() + + unknown, err := oleutil.CreateObject("WbemScripting.SWbemLocator") + if err != nil { + initError <- fmt.Errorf("CreateObject SWbemLocator error: %v", err) + return + } else if unknown == nil { + initError <- ErrNilCreateObject + return + } + defer unknown.Release() + s.sWbemLocatorIUnknown = unknown + + dispatch, err := s.sWbemLocatorIUnknown.QueryInterface(ole.IID_IDispatch) + if err != nil { + initError <- fmt.Errorf("SWbemLocator QueryInterface error: %v", err) + return + } + defer dispatch.Release() + s.sWbemLocatorIDispatch = dispatch + + // we can't do the ConnectServer call outside the loop unless we find a way to track and re-init the connectServerArgs + //fmt.Println("process: initialized. closing initError") + close(initError) + //fmt.Println("process: waiting for queries") + for q := range s.queries { + //fmt.Printf("process: new query: len(query)=%d\n", len(q.query)) + errQuery := s.queryBackground(q) + //fmt.Println("process: s.queryBackground finished") + if errQuery != nil { + q.finished <- errQuery + } + close(q.finished) + } + //fmt.Println("process: queries channel closed") + s.queries = nil //set channel to nil so we know it is closed + //TODO: I think the Release/Clear calls can panic if things are in a bad state. + //TODO: May need to recover from panics and send error to method caller instead. + close(s.closeError) +} + +// Query runs the WQL query using a SWbemServices instance and appends the values to dst. +// +// dst must have type *[]S or *[]*S, for some struct type S. Fields selected in +// the query must have the same name in dst. Supported types are all signed and +// unsigned integers, time.Time, string, bool, or a pointer to one of those. +// Array types are not supported. +// +// By default, the local machine and default namespace are used. These can be +// changed using connectServerArgs. See +// http://msdn.microsoft.com/en-us/library/aa393720.aspx for details. +func (s *SWbemServices) Query(query string, dst interface{}, connectServerArgs ...interface{}) error { + s.lQueryorClose.Lock() + if s == nil || s.sWbemLocatorIDispatch == nil { + s.lQueryorClose.Unlock() + return fmt.Errorf("SWbemServices is not Initialized") + } + if s.queries == nil { + s.lQueryorClose.Unlock() + return fmt.Errorf("SWbemServices has been closed") + } + + //fmt.Println("Query: Sending query request") + qr := queryRequest{ + query: query, + dst: dst, + args: connectServerArgs, + finished: make(chan error), + } + s.queries <- &qr + s.lQueryorClose.Unlock() + err, ok := <-qr.finished + if ok { + //fmt.Println("Query: Finished with error") + return err //Send error to caller + } + //fmt.Println("Query: Finished") + return nil +} + +func (s *SWbemServices) queryBackground(q *queryRequest) error { + if s == nil || s.sWbemLocatorIDispatch == nil { + return fmt.Errorf("SWbemServices is not Initialized") + } + wmi := s.sWbemLocatorIDispatch //Should just rename in the code, but this will help as we break things apart + //fmt.Println("queryBackground: Starting") + + dv := reflect.ValueOf(q.dst) + if dv.Kind() != reflect.Ptr || dv.IsNil() { + return ErrInvalidEntityType + } + dv = dv.Elem() + mat, elemType := checkMultiArg(dv) + if mat == multiArgTypeInvalid { + return ErrInvalidEntityType + } + + // service is a SWbemServices + serviceRaw, err := oleutil.CallMethod(wmi, "ConnectServer", q.args...) + if err != nil { + return err + } + service := serviceRaw.ToIDispatch() + defer serviceRaw.Clear() + + // result is a SWBemObjectSet + resultRaw, err := oleutil.CallMethod(service, "ExecQuery", q.query) + if err != nil { + return err + } + result := resultRaw.ToIDispatch() + defer resultRaw.Clear() + + count, err := oleInt64(result, "Count") + if err != nil { + return err + } + + enumProperty, err := result.GetProperty("_NewEnum") + if err != nil { + return err + } + defer enumProperty.Clear() + + enum, err := enumProperty.ToIUnknown().IEnumVARIANT(ole.IID_IEnumVariant) + if err != nil { + return err + } + if enum == nil { + return fmt.Errorf("can't get IEnumVARIANT, enum is nil") + } + defer enum.Release() + + // Initialize a slice with Count capacity + dv.Set(reflect.MakeSlice(dv.Type(), 0, int(count))) + + var errFieldMismatch error + for itemRaw, length, err := enum.Next(1); length > 0; itemRaw, length, err = enum.Next(1) { + if err != nil { + return err + } + + err := func() error { + // item is a SWbemObject, but really a Win32_Process + item := itemRaw.ToIDispatch() + defer item.Release() + + ev := reflect.New(elemType) + if err = s.cWMIClient.loadEntity(ev.Interface(), item); err != nil { + if _, ok := err.(*ErrFieldMismatch); ok { + // We continue loading entities even in the face of field mismatch errors. + // If we encounter any other error, that other error is returned. Otherwise, + // an ErrFieldMismatch is returned. + errFieldMismatch = err + } else { + return err + } + } + if mat != multiArgTypeStructPtr { + ev = ev.Elem() + } + dv.Set(reflect.Append(dv, ev)) + return nil + }() + if err != nil { + return err + } + } + //fmt.Println("queryBackground: Finished") + return errFieldMismatch +} diff --git a/swbemservices_test.go b/swbemservices_test.go new file mode 100644 index 0000000..1d619af --- /dev/null +++ b/swbemservices_test.go @@ -0,0 +1,151 @@ +// +build windows + +package wmi + +import ( + "fmt" + "runtime" + "testing" + "time" +) + +func TestWbemQuery(t *testing.T) { + s, err := InitializeSWbemServices(DefaultClient) + if err != nil { + t.Fatalf("InitializeSWbemServices: %s", err) + } + + var dst []Win32_Process + q := CreateQuery(&dst, "WHERE name='lsass.exe'") + errQuery := s.Query(q, &dst) + if errQuery != nil { + t.Fatalf("Query1: %s", errQuery) + } + count := len(dst) + if count < 1 { + t.Fatal("Query1: no results found for lsass.exe") + } + //fmt.Printf("dst[0].ProcessID=%d\n", dst[0].ProcessId) + + q2 := CreateQuery(&dst, "WHERE name='svchost.exe'") + errQuery = s.Query(q2, &dst) + if errQuery != nil { + t.Fatalf("Query2: %s", errQuery) + } + count = len(dst) + if count < 1 { + t.Fatal("Query2: no results found for svchost.exe") + } + //for index, item := range dst { + // fmt.Printf("dst[%d].ProcessID=%d\n", index, item.ProcessId) + //} + errClose := s.Close() + if errClose != nil { + t.Fatalf("Close: %s", errClose) + } +} + +func TestWbemQueryNamespace(t *testing.T) { + s, err := InitializeSWbemServices(DefaultClient) + if err != nil { + t.Fatalf("InitializeSWbemServices: %s", err) + } + var dst []MSFT_NetAdapter + q := CreateQuery(&dst, "") + errQuery := s.Query(q, &dst, nil, "root\\StandardCimv2") + if errQuery != nil { + t.Fatalf("Query: %s", errQuery) + } + count := len(dst) + if count < 1 { + t.Fatal("Query: no results found for MSFT_NetAdapter in root\\StandardCimv2") + } + errClose := s.Close() + if errClose != nil { + t.Fatalf("Close: %s", errClose) + } +} + +// Run using: go test -run TestWbemMemory -timeout 60m +func TestWbemMemory(t *testing.T) { + s, err := InitializeSWbemServices(DefaultClient) + if err != nil { + t.Fatalf("InitializeSWbemServices: %s", err) + } + start := time.Now() + limit := 500000 + fmt.Printf("Benchmark Iterations: %d (Memory should stabilize around 7MB after ~3000)\n", limit) + var privateMB, allocMB, allocTotalMB float64 + for i := 0; i < limit; i++ { + privateMB, allocMB, allocTotalMB = WbemGetMemoryUsageMB(s) + if i%100 == 0 { + privateMB, allocMB, allocTotalMB = WbemGetMemoryUsageMB(s) + fmt.Printf("Time: %4ds Count: %5d Private Memory: %5.1fMB MemStats.Alloc: %4.1fMB MemStats.TotalAlloc: %5.1fMB\n", time.Now().Sub(start)/time.Second, i, privateMB, allocMB, allocTotalMB) + } + } + errClose := s.Close() + if errClose != nil { + t.Fatalf("Close: %s", err) + } + fmt.Printf("Final Time: %4ds Private Memory: %5.1fMB MemStats.Alloc: %4.1fMB MemStats.TotalAlloc: %5.1fMB\n", time.Now().Sub(start)/time.Second, privateMB, allocMB, allocTotalMB) +} + +func WbemGetMemoryUsageMB(s *SWbemServices) (float64, float64, float64) { + runtime.ReadMemStats(&mMemoryUsageMB) + errGetMemoryUsageMB = s.Query(qGetMemoryUsageMB, &dstGetMemoryUsageMB) + if errGetMemoryUsageMB != nil { + fmt.Println("ERROR GetMemoryUsage", errGetMemoryUsageMB) + return 0, 0, 0 + } + return float64(dstGetMemoryUsageMB[0].WorkingSetPrivate) / MB, float64(mMemoryUsageMB.Alloc) / MB, float64(mMemoryUsageMB.TotalAlloc) / MB +} + +//Run all benchmarks (should run for at least 60s to get a stable number): +//go test -run=NONE -bench=Version -benchtime=120s + +//Individual benchmarks: +//go test -run=NONE -bench=NewVersion -benchtime=120s +func BenchmarkNewVersion(b *testing.B) { + s, err := InitializeSWbemServices(DefaultClient) + if err != nil { + b.Fatalf("InitializeSWbemServices: %s", err) + } + var dst []Win32_OperatingSystem + q := CreateQuery(&dst, "") + for n := 0; n < b.N; n++ { + errQuery := s.Query(q, &dst) + if errQuery != nil { + b.Fatalf("Query%d: %s", n, errQuery) + } + count := len(dst) + if count < 1 { + b.Fatalf("Query%d: no results found for Win32_OperatingSystem", n) + } + } + errClose := s.Close() + if errClose != nil { + b.Fatalf("Close: %s", errClose) + } +} + +//go test -run=NONE -bench=OldVersion -benchtime=120s +func BenchmarkOldVersion(b *testing.B) { + var dst []Win32_OperatingSystem + q := CreateQuery(&dst, "") + for n := 0; n < b.N; n++ { + errQuery := Query(q, &dst) + if errQuery != nil { + b.Fatalf("Query%d: %s", n, errQuery) + } + count := len(dst) + if count < 1 { + b.Fatalf("Query%d: no results found for Win32_OperatingSystem", n) + } + } +} + +type MSFT_NetAdapter struct { + Name string + InterfaceIndex int + DriverDescription string +} diff --git a/wmi.go b/wmi.go index 5d0c7ab..9c688b0 100644 --- a/wmi.go +++ b/wmi.go @@ -72,7 +72,10 @@ func QueryNamespace(query string, dst interface{}, namespace string) error { // // Query is a wrapper around DefaultClient.Query. func Query(query string, dst interface{}, connectServerArgs ...interface{}) error { - return DefaultClient.Query(query, dst, connectServerArgs...) + if DefaultClient.SWbemServicesClient == nil { + return DefaultClient.Query(query, dst, connectServerArgs...) + } + return DefaultClient.SWbemServicesClient.Query(query, dst, connectServerArgs...) } // A Client is an WMI query client. @@ -99,6 +102,11 @@ type Client struct { // Setting this to true allows custom queries to be used with full // struct definitions instead of having to define multiple structs. AllowMissingFields bool + + // SWbemServiceClient is an optional SWbemServices object that can be + // initialized and then reused across multiple queries. If it is null + // then the method will initialize a new temporary client each time. + SWbemServicesClient *SWbemServices } // DefaultClient is the default Client and is used by Query, QueryNamespace