Skip to content

Commit

Permalink
支持自定义参数
Browse files Browse the repository at this point in the history
vnt-dev committed Apr 23, 2023
1 parent a901ebc commit c4e1635
Showing 3 changed files with 114 additions and 21 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@ edition = "2021"

[dependencies]
packet = {path = "./packet"}
clap = { version = "4.0.32", features = ["derive"] }
log = "0.4.17"
log4rs = "1.2.0"
dirs = "4.0.0"
87 changes: 82 additions & 5 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,43 @@
use std::net::UdpSocket;
use std::collections::HashSet;
use std::net::{Ipv4Addr, UdpSocket};
use std::thread;

use clap::Parser;

pub mod error;
pub mod proto;
pub mod protocol;
pub mod service;

/// 默认网关信息
const GATEWAY: Ipv4Addr = Ipv4Addr::new(10, 26, 0, 1);
const NETMASK: Ipv4Addr = Ipv4Addr::new(255, 255, 255, 0);

#[derive(Parser, Debug, Clone)]
pub struct StartArgs {
/// 指定端口
#[arg(long)]
port: Option<u16>,
/// token白名单,例如 --white-token 1234 --white-token 123
#[arg(long)]
white_token: Option<Vec<String>>,
/// 网关,例如 --gateway 10.10.0.1
#[arg(long)]
gateway: Option<String>,
/// 子网掩码,例如 --netmask 255.255.255.0
#[arg(long)]
netmask: Option<String>,
}

#[derive(Debug, Clone)]
pub struct ConfigInfo {
pub port: u16,
pub white_token: Option<HashSet<String>>,
pub gateway: Ipv4Addr,
pub broadcast: Ipv4Addr,
pub netmask: Ipv4Addr,
}

fn log_init() {
let home = dirs::home_dir().unwrap().join(".switch_server");
if !home.exists() {
@@ -30,21 +62,66 @@ fn log_init() {
}

fn main() {
let args = StartArgs::parse();
let port = args.port.unwrap_or(29871);
println!("端口:{}", port);
let white_token = if let Some(white_token) = args.white_token {
Some(HashSet::from_iter(white_token.into_iter()))
} else {
None
};
println!("白名单:{:?}", white_token);
let gateway = if let Some(gateway) = args.gateway {
gateway.parse::<Ipv4Addr>().expect("网关错误,必须为有效的ipv4地址")
} else {
GATEWAY
};
println!("网关:{:?}", gateway);
if gateway.is_broadcast() || gateway.is_unspecified() {
println!("网关错误");
return;
}
if !gateway.is_private() {
println!("Warning 网关不是一个私有地址:{:?}", gateway);
}
let netmask = if let Some(netmask) = args.netmask {
netmask.parse::<Ipv4Addr>().expect("子网掩码错误,必须为有效的ipv4地址")
} else {
NETMASK
};
println!("子网掩码:{:?}", netmask);
if netmask.is_broadcast() || netmask.is_unspecified() || !(!u32::from_be_bytes(netmask.octets()) + 1).is_power_of_two() {
println!("子网掩码错误");
return;
}

let broadcast = (!u32::from_be_bytes(netmask.octets()))
| u32::from_be_bytes(gateway.octets());
let broadcast = Ipv4Addr::from(broadcast);
let config = ConfigInfo {
port,
white_token,
gateway,
broadcast,
netmask,
};
log_init();
let udp = UdpSocket::bind("0.0.0.0:29871").unwrap();
log::info!("启动:{:?}",udp.local_addr().unwrap());
println!("启动成功:{:?}",udp.local_addr().unwrap());
log::info!("启动成功,udp:{:?}",udp.local_addr().unwrap());
println!("启动成功,udp:{:?}", udp.local_addr().unwrap());
log::info!("config:{:?}",config);
let num = if let Ok(num) = thread::available_parallelism() {
num.get() * 2
} else {
2
};
for _ in 0..num {
let udp = udp.try_clone().unwrap();
let config = config.clone();
thread::spawn(move || {
service::handle_loop(udp);
service::handle_loop(udp, config);
});
}

service::handle_loop(udp);
service::handle_loop(udp, config);
}
47 changes: 31 additions & 16 deletions src/service/udp_service.rs
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@ use packet::ip::ipv4::packet::IpV4Packet;
use parking_lot::Mutex;
use protobuf::Message;

use crate::ConfigInfo;
use crate::error::*;
use crate::proto::message;
use crate::proto::message::{DeviceList, RegistrationRequest, RegistrationResponse};
@@ -102,12 +103,12 @@ impl From<u8> for PeerDeviceStatus {
}
}

pub fn handle_loop(udp: UdpSocket) {
pub fn handle_loop(udp: UdpSocket, config: ConfigInfo) {
let mut buf = [0u8; 65536];
loop {
match udp.recv_from(&mut buf) {
Ok((len, addr)) => {
match handle(&udp, &mut buf[..len], addr) {
match handle(&udp, &mut buf[..len], addr, &config) {
Ok(_) => {}
Err(e) => {
log::error!("{:?}", e)
@@ -121,7 +122,7 @@ pub fn handle_loop(udp: UdpSocket) {
}
}

fn handle(udp: &UdpSocket, buf: &mut [u8], addr: SocketAddr) -> Result<()> {
fn handle(udp: &UdpSocket, buf: &mut [u8], addr: SocketAddr, config: &ConfigInfo) -> Result<()> {
let net_packet = NetPacket::new(buf)?;
if net_packet.protocol() == Protocol::Service
&& net_packet.transport_protocol()
@@ -131,6 +132,20 @@ fn handle(udp: &UdpSocket, buf: &mut [u8], addr: SocketAddr) -> Result<()> {
{
let request = RegistrationRequest::parse_from_bytes(net_packet.payload())?;
log::info!("register:{:?}",request);
if let Some(white_token) = &config.white_token {
if !white_token.contains(&request.token) {
log::info!("token不在白名单,white_token={:?},token={:?}",white_token,request.token);
let mut net_packet = NetPacket::new([0u8; 12])?;
net_packet.set_version(Version::V1);
net_packet.set_protocol(Protocol::Error);
net_packet
.set_transport_protocol(error_packet::Protocol::TokenError.into());
net_packet.first_set_ttl(MAX_TTL);
net_packet.set_source(config.gateway);
udp.send_to(net_packet.buffer(), addr)?;
return Ok(());
}
}
let mut response = RegistrationResponse::new();
match addr.ip() {
IpAddr::V4(ipv4) => {
@@ -142,9 +157,8 @@ fn handle(udp: &UdpSocket, buf: &mut [u8], addr: SocketAddr) -> Result<()> {
return Ok(());
}
}
//todo 暂时写死地址 考虑验证token,比如从数据库根据token读出网关
response.virtual_netmask = u32::from_be_bytes(NETMASK.octets());
response.virtual_gateway = u32::from_be_bytes(GATE_WAY.octets());
response.virtual_netmask = u32::from_be_bytes(config.netmask.octets());
response.virtual_gateway = u32::from_be_bytes(config.gateway.octets());
if let Some(v) = VIRTUAL_NETWORK.optionally_get_with(request.token.clone(), || {
Some(Arc::new(parking_lot::const_mutex(VirtualNetwork {
epoch: 0,
@@ -169,7 +183,10 @@ fn handle(udp: &UdpSocket, buf: &mut [u8], addr: SocketAddr) -> Result<()> {
.iter()
.map(|(_, device_info)| device_info.ip)
.collect();
for ip in response.virtual_gateway + 1..response.virtual_gateway + 128 {
for ip in (response.virtual_gateway & response.virtual_netmask) + 1..response.virtual_gateway | (!response.virtual_netmask) {
if ip == response.virtual_gateway {
continue;
}
if !set.contains(&ip) {
virtual_ip = ip;
break;
@@ -183,7 +200,7 @@ fn handle(udp: &UdpSocket, buf: &mut [u8], addr: SocketAddr) -> Result<()> {
net_packet
.set_transport_protocol(error_packet::Protocol::AddressExhausted.into());
net_packet.first_set_ttl(MAX_TTL);
net_packet.set_source(GATE_WAY);
net_packet.set_source(config.gateway);
udp.send_to(net_packet.buffer(), addr)?;
return Ok(());
}
@@ -226,7 +243,7 @@ fn handle(udp: &UdpSocket, buf: &mut [u8], addr: SocketAddr) -> Result<()> {
let mut net_packet = NetPacket::new(vec![0u8; 12 + bytes.len()])?;
net_packet.set_version(Version::V1);
net_packet.set_protocol(Protocol::Service);
net_packet.set_source(GATE_WAY);
net_packet.set_source(config.gateway);
net_packet.set_destination(Ipv4Addr::from(response.virtual_ip));
net_packet.set_transport_protocol(service_packet::Protocol::RegistrationResponse.into());
net_packet.first_set_ttl(MAX_TTL);
@@ -242,7 +259,7 @@ fn handle(udp: &UdpSocket, buf: &mut [u8], addr: SocketAddr) -> Result<()> {
.get(&(context.token.clone(), context.device_id.clone()))
.is_some()
{
handle_(udp, addr, net_packet, context)?;
handle_(udp, addr, net_packet, context, config)?;
return Ok(());
}
}
@@ -253,29 +270,27 @@ fn handle(udp: &UdpSocket, buf: &mut [u8], addr: SocketAddr) -> Result<()> {
net_packet.set_protocol(Protocol::Error);
net_packet.set_transport_protocol(error_packet::Protocol::Disconnect.into());
net_packet.first_set_ttl(MAX_TTL);
net_packet.set_source(GATE_WAY);
net_packet.set_source(config.gateway);
net_packet.set_destination(source);
udp.send_to(net_packet.buffer(), addr)?;
Ok(())
}

const GATE_WAY: Ipv4Addr = Ipv4Addr::new(10, 26, 0, 1);
const BROADCAST: Ipv4Addr = Ipv4Addr::new(10, 26, 0, 255);
const NETMASK: Ipv4Addr = Ipv4Addr::new(255, 255, 255, 0);

fn handle_(
udp: &UdpSocket,
addr: SocketAddr,
mut net_packet: NetPacket<&mut [u8]>,
context: Context,
config: &ConfigInfo,
) -> Result<()> {
let source = net_packet.source();
let destination = net_packet.destination();
if destination != GATE_WAY {
if destination != config.gateway {
// 转发
if net_packet.ttl() > 1 {
net_packet.set_ttl(net_packet.ttl() - 1);
if destination.is_broadcast() || (destination.octets()[3] == 255 && BROADCAST == destination) {
if destination.is_broadcast() || (destination.octets()[3] == 255 && config.broadcast == destination) {
//本地广播和直接广播
if let Some(v) = VIRTUAL_NETWORK.get(&context.token) {
if let Some(lock) = v.try_lock() {

0 comments on commit c4e1635

Please sign in to comment.