Skip to content

Commit

Permalink
Merge pull request shirou#452 from leafnunes/master
Browse files Browse the repository at this point in the history
prevent hang on pkg import if wmi.Query hangs
  • Loading branch information
shirou authored Nov 21, 2017
2 parents 384a551 + 65598d9 commit bfe3c2e
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 11 deletions.
15 changes: 11 additions & 4 deletions cpu/cpu_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package cpu

import (
"context"
"fmt"
"unsafe"

Expand Down Expand Up @@ -81,8 +82,9 @@ func Info() ([]InfoStat, error) {
var ret []InfoStat
var dst []Win32_Processor
q := wmi.CreateQuery(&dst, "")
err := wmi.Query(q, &dst)
if err != nil {
ctx, cancel := context.WithTimeout(context.Background(), common.Timeout)
defer cancel()
if err := common.WMIQueryWithContext(ctx, q, &dst); err != nil {
return ret, err
}

Expand Down Expand Up @@ -113,8 +115,11 @@ func Info() ([]InfoStat, error) {
// Name property is the key by which overall, per cpu and per core metric is known.
func PerfInfo() ([]Win32_PerfFormattedData_Counters_ProcessorInformation, error) {
var ret []Win32_PerfFormattedData_Counters_ProcessorInformation

q := wmi.CreateQuery(&ret, "")
err := wmi.Query(q, &ret)
ctx, cancel := context.WithTimeout(context.Background(), common.Timeout)
defer cancel()
err := common.WMIQueryWithContext(ctx, q, &ret)
return ret, err
}

Expand All @@ -123,7 +128,9 @@ func PerfInfo() ([]Win32_PerfFormattedData_Counters_ProcessorInformation, error)
func ProcInfo() ([]Win32_PerfFormattedData_PerfOS_System, error) {
var ret []Win32_PerfFormattedData_PerfOS_System
q := wmi.CreateQuery(&ret, "")
err := wmi.Query(q, &ret)
ctx, cancel := context.WithTimeout(context.Background(), common.Timeout)
defer cancel()
err := common.WMIQueryWithContext(ctx, q, &ret)
return ret, err
}

Expand Down
6 changes: 4 additions & 2 deletions disk/disk_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ package disk

import (
"bytes"
"context"
"unsafe"

"github.com/StackExchange/wmi"
"github.com/shirou/gopsutil/internal/common"
"golang.org/x/sys/windows"
)
Expand Down Expand Up @@ -132,7 +132,9 @@ func IOCounters(names ...string) (map[string]IOCountersStat, error) {
ret := make(map[string]IOCountersStat, 0)
var dst []Win32_PerfFormattedData

err := wmi.Query("SELECT * FROM Win32_PerfFormattedData_PerfDisk_LogicalDisk ", &dst)
ctx, cancel := context.WithTimeout(context.Background(), common.Timeout)
defer cancel()
err := common.WMIQueryWithContext(ctx, "SELECT * FROM Win32_PerfFormattedData_PerfDisk_LogicalDisk", &dst)
if err != nil {
return ret, err
}
Expand Down
5 changes: 4 additions & 1 deletion host/host_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package host

import (
"context"
"fmt"
"os"
"runtime"
Expand Down Expand Up @@ -109,7 +110,9 @@ func getMachineGuid() (string, error) {
func GetOSInfo() (Win32_OperatingSystem, error) {
var dst []Win32_OperatingSystem
q := wmi.CreateQuery(&dst, "")
err := wmi.Query(q, &dst)
ctx, cancel := context.WithTimeout(context.Background(), common.Timeout)
defer cancel()
err := common.WMIQueryWithContext(ctx, q, &dst)
if err != nil {
return Win32_OperatingSystem{}, err
}
Expand Down
19 changes: 18 additions & 1 deletion internal/common/common_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
package common

import (
"context"
"unsafe"

"github.com/StackExchange/wmi"
"golang.org/x/sys/windows"
)

Expand Down Expand Up @@ -49,7 +51,7 @@ var (
ModNt = windows.NewLazyDLL("ntdll.dll")
ModPdh = windows.NewLazyDLL("pdh.dll")
ModPsapi = windows.NewLazyDLL("psapi.dll")

ProcGetSystemTimes = Modkernel32.NewProc("GetSystemTimes")
ProcNtQuerySystemInformation = ModNt.NewProc("NtQuerySystemInformation")
PdhOpenQuery = ModPdh.NewProc("PdhOpenQuery")
Expand Down Expand Up @@ -110,3 +112,18 @@ func CreateCounter(query windows.Handle, pname, cname string) (*CounterInfo, err
Counter: counter,
}, nil
}

// WMIQueryWithContext - wraps wmi.Query with a timed-out context to avoid hanging
func WMIQueryWithContext(ctx context.Context, query string, dst interface{}, connectServerArgs ...interface{}) error {
errChan := make(chan error, 1)
go func() {
errChan <- wmi.Query(query, dst, connectServerArgs...)
}()

select {
case <-ctx.Done():
return ctx.Err()
case err := <-errChan:
return err
}
}
11 changes: 8 additions & 3 deletions process/process_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package process

import (
"context"
"fmt"
"strings"
"syscall"
Expand Down Expand Up @@ -130,8 +131,10 @@ func GetWin32Proc(pid int32) ([]Win32_Process, error) {
var dst []Win32_Process
query := fmt.Sprintf("WHERE ProcessId = %d", pid)
q := wmi.CreateQuery(&dst, query)

if err := wmi.Query(q, &dst); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), common.Timeout)
defer cancel()
err := common.WMIQueryWithContext(ctx, q, &dst)
if err != nil {
return []Win32_Process{}, fmt.Errorf("could not get win32Proc: %s", err)
}

Expand Down Expand Up @@ -333,7 +336,9 @@ func (p *Process) MemoryInfoEx() (*MemoryInfoExStat, error) {
func (p *Process) Children() ([]*Process, error) {
var dst []Win32_Process
query := wmi.CreateQuery(&dst, fmt.Sprintf("Where ParentProcessId = %d", p.Pid))
err := wmi.Query(query, &dst)
ctx, cancel := context.WithTimeout(context.Background(), common.Timeout)
defer cancel()
err := common.WMIQueryWithContext(ctx, query, &dst)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit bfe3c2e

Please sign in to comment.