forked from dotnet/machinelearning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathOneDalUtils.cs
92 lines (86 loc) · 3.49 KB
/
OneDalUtils.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;
namespace Microsoft.ML.OneDal
{
    /// <summary>
    /// Helpers used by the oneDAL-backed trainers: deciding whether to dispatch to the oneDAL
    /// backend and flattening training data into plain label/feature lists.
    /// </summary>
    [BestFriend]
    internal static class OneDalUtils
    {
        /// <summary>
        /// Returns whether training should be dispatched to the oneDAL backend.
        /// </summary>
        /// <remarks>
        /// Dispatching is enabled only when the <c>MLNET_BACKEND</c> environment variable is exactly
        /// <c>ONEDAL</c> (ordinal, case-sensitive) and the process architecture is x64.
        /// On Windows (when not compiled for .NET Framework) the directory
        /// <c>runtimes/win-x64/native</c> under <c>AppContext.BaseDirectory</c> is placed first in the
        /// process <c>PATH</c>. It is only prepended when it is not already the first entry, so
        /// repeated calls do not keep growing <c>PATH</c>.
        /// </remarks>
        /// <returns>
        /// <see langword="true"/> if the oneDAL backend is requested and the process is x64;
        /// otherwise <see langword="false"/>.
        /// </returns>
        [BestFriend]
        internal static bool IsDispatchingEnabled()
        {
            bool backendRequested = string.Equals(
                Environment.GetEnvironmentVariable("MLNET_BACKEND"), "ONEDAL", StringComparison.Ordinal);
            if (!backendRequested ||
                RuntimeInformation.ProcessArchitecture != System.Runtime.InteropServices.Architecture.X64)
            {
                return false;
            }

            if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
            {
#if NETFRAMEWORK
                // PATH is not modified on .NET Framework: the user needs to add the native library
                // directory to PATH manually, otherwise loading the native library will probably fail
                // at runtime.
#else
                string nativeLibs = Path.Combine(AppContext.BaseDirectory, "runtimes", "win-x64", "native");
                string originalPath = Environment.GetEnvironmentVariable("PATH");
                if (string.IsNullOrEmpty(originalPath))
                {
                    // Avoid producing a dangling "dir;" when PATH is unset or empty.
                    Environment.SetEnvironmentVariable("PATH", nativeLibs);
                }
                else
                {
                    // Only prepend when nativeLibs is not already the first PATH entry. Otherwise every
                    // call would add another copy and PATH would grow without bound. Windows paths are
                    // compared case-insensitively.
                    int separatorIndex = originalPath.IndexOf(Path.PathSeparator);
                    string firstEntry = separatorIndex < 0 ? originalPath : originalPath.Substring(0, separatorIndex);
                    if (!string.Equals(firstEntry, nativeLibs, StringComparison.OrdinalIgnoreCase))
                    {
                        Environment.SetEnvironmentVariable("PATH", nativeLibs + Path.PathSeparator + originalPath);
                    }
                }
#endif
            }

            return true;
        }

        /// <summary>
        /// Reads every row produced by <paramref name="cursorFactory"/> into flat label and feature lists.
        /// </summary>
        /// <param name="channel">Channel used for assertions, checks and warnings.</param>
        /// <param name="cursorFactory">Factory creating the cursor over the training rows.</param>
        /// <param name="featuresList">
        /// Receives the feature vectors row after row. Sparse rows are expanded to dense form, and
        /// unspecified slots are written as 0. Each row contributes <paramref name="numberOfFeatures"/>
        /// values, assuming the vector length equals <paramref name="numberOfFeatures"/> (checked with
        /// <c>channel.Assert</c>).
        /// </param>
        /// <param name="labelsList">Receives one label per row, in the same order as the feature rows.</param>
        /// <param name="numberOfFeatures">Expected length of every feature vector.</param>
        /// <returns>The number of rows appended to the lists.</returns>
        /// <remarks>
        /// Fails through <c>channel.Check</c> when the cursor yields no rows. Rows that the cursor itself
        /// skips are not appended. If any were skipped, a warning reporting the skipped-row count is
        /// emitted.
        /// </remarks>
        [BestFriend]
        internal static long GetTrainData(IChannel channel, FloatLabelCursor.Factory cursorFactory, ref List<float> featuresList, ref List<float> labelsList, int numberOfFeatures)
        {
            long n = 0;
            using (var cursor = cursorFactory.Create())
            {
                while (cursor.MoveNext())
                {
                    // label
                    labelsList.Add(cursor.Label);

                    // features: always written densely so every row occupies numberOfFeatures slots.
                    var features = cursor.Features;
                    // Checked for both dense and sparse rows; a longer sparse vector would misalign rows.
                    channel.Assert(features.Length == numberOfFeatures);
                    var values = features.GetValues();
                    if (features.IsDense)
                    {
                        for (int j = 0; j < numberOfFeatures; ++j)
                        {
                            featuresList.Add(values[j]);
                        }
                    }
                    else
                    {
                        var indices = features.GetIndices();
                        int nextSlot = 0;
                        for (int j = 0; j < indices.Length; ++j)
                        {
                            // Zero-fill the gap between the previous explicit slot and this one.
                            for (int k = nextSlot; k < indices[j]; ++k)
                            {
                                featuresList.Add(0);
                            }
                            featuresList.Add(values[j]);
                            nextSlot = indices[j] + 1;
                        }
                        // Zero-fill the tail after the last explicit slot.
                        for (int j = nextSlot; j < numberOfFeatures; ++j)
                        {
                            featuresList.Add(0);
                        }
                    }
                    n++;
                }
                channel.Check(n > 0, "No training examples in dataset.");
                // Gate on the same counter that is reported, so rows skipped for any reason are surfaced.
                if (cursor.SkippedRowCount > 0)
                {
                    channel.Warning("Skipped {0} instances with missing features/labelColumn during training", cursor.SkippedRowCount);
                }
            }
            return n;
        }
    }
}