-
Notifications
You must be signed in to change notification settings - Fork 0
/
ball_and_speed.py
60 lines (46 loc) · 1.73 KB
/
ball_and_speed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# Project: predict a target value from the speeds of the ball and the wheel.
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
import joblib
# Creation of our own synthetic dataset.
# Seed NumPy's global RNG so every run produces the same samples.
np.random.seed(42)

# Draw the two input features. The draw order (ball, wheel, noise) determines
# the generated values for a given seed, so it must not be rearranged.
num_samples = 100
ball_speed = np.random.uniform(low=0, high=20, size=num_samples)  # ball speeds in [0, 20) m/s
wheel_speed = np.random.uniform(low=0, high=20, size=num_samples)  # wheel speeds in [0, 20) m/s

# Example target: a linear mix of both speeds plus standard-normal noise.
noise = np.random.normal(loc=0, scale=1, size=num_samples)
target_variable = 0.5 * ball_speed + 0.3 * wheel_speed + noise

# Collect the columns in insertion order and wrap them in a DataFrame.
columns = {}
columns['ball_speed'] = ball_speed
columns['wheel_speed'] = wheel_speed
columns['target_variable'] = target_variable
data = pd.DataFrame(columns)

# Write the dataset to disk without the index column.
csv_path = 'ball_wheel_data.csv'
data.to_csv(csv_path, index=False)
print("Dataset created and saved as 'ball_wheel_data.csv'.")
# Read the dataset back from the CSV file.
data = pd.read_csv('ball_wheel_data.csv')

# Separate the two speed columns (features) from the regression target.
feature_columns = ['ball_speed', 'wheel_speed']
X = data[feature_columns]
y = data['target_variable']  # e.g., distance traveled

# Hold out 20% of the rows for testing; the fixed random_state keeps the
# split the same on every run.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# Fit an ordinary least-squares model on the training rows
# (fit() returns the estimator itself, so construction and fitting chain).
model = LinearRegression().fit(X_train, y_train)

# Score the held-out rows with mean squared error.
y_pred = model.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
print(f"Mean Squared Error: {mse}")

# Persist the fitted model, then reload it from disk before using it.
model_path = 'ball_wheel_model.pkl'
joblib.dump(model, model_path)
model = joblib.load(model_path)

# Predict for a single new observation: ball_speed=10, wheel_speed=5.
new_data = pd.DataFrame([[10, 5]], columns=feature_columns)
prediction = model.predict(new_data)
print(f"Prediction: {prediction}")