diff --git a/conf/example.yaml b/conf/example.yaml new file mode 100644 index 0000000..6d23451 --- /dev/null +++ b/conf/example.yaml @@ -0,0 +1,17 @@ +- public_key: pkey1 + private_key: pkey2 + endpoint: + ip: 1.1.1.1 + port: 51820 + ip: 10.1.0.1/24 + interface_name: test + namespace_name: test-netns2 + +- public_key: pkey3 + private_key: pkey4 + endpoint: + ip: 8.8.8.8 + port: 51820 + ip: 10.1.0.2/24 + interface_name: test + namespace_name: test-netns3 \ No newline at end of file diff --git a/src/config/mod.rs b/src/config/mod.rs new file mode 100644 index 0000000..2bf1675 --- /dev/null +++ b/src/config/mod.rs @@ -0,0 +1,86 @@ +use serde::{Serialize, Deserialize}; +use std::fs::File; +use std::io::Read; +use base64::prelude::*; + +#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] +pub struct VPNConfig { + public_key: String, + private_key: String, + endpoint: PeerEndpoint, + ip: String, + interface_name: String, + namespace_name: String, +} + +pub struct ConsumableVPNConfig { + pub public_key: [u8; 32], + pub private_key: [u8; 32], + pub endpoint: ConsumablePeerEndpoint, + pub ip: String, + pub prefix: u8, + pub interface_name: String, + pub namespace_name: String, +} + +#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] +pub struct PeerEndpoint { + ip: String, + port: u16, +} + +pub struct ConsumablePeerEndpoint { + pub ip: String, + //pub prefix: u8, + pub port: u16, +} + +impl VPNConfig { + pub fn get_consumable(self) -> ConsumableVPNConfig { + let c_ip: Vec<&str> = self.ip.split("/").collect(); + if c_ip.len() != 2 { + panic!("malformed ip, len:{}", c_ip.len()); + } + let priv_key_dirty = BASE64_STANDARD.decode(self.private_key.into_bytes()).unwrap(); + let pub_key_dirty = BASE64_STANDARD.decode(self.public_key.into_bytes()).unwrap(); + let mut priv_key: [u8; 32] = Default::default(); + let mut pub_key: [u8; 32] = Default::default(); + priv_key.copy_from_slice(&priv_key_dirty[0..32]); + pub_key.copy_from_slice(&pub_key_dirty[0..32]); + ConsumableVPNConfig { + public_key: pub_key, + private_key: priv_key, + endpoint: ConsumablePeerEndpoint { + ip: self.endpoint.ip, + port: self.endpoint.port, + }, + ip: c_ip[0].to_string(), + prefix: c_ip[1].parse::().unwrap(), + interface_name: self.interface_name, + namespace_name: self.namespace_name, + } + } +} + +pub fn get_vpn_conf(file: String) -> Vec { + let mut file = check_file(file); + let mut s = String::new(); + file.read_to_string(&mut s).unwrap(); + match serde_yaml::from_str(&s) { + Ok(result) => result, + Err(e) => { + log::error!("malformed: {}",e); + Vec::new() + }, + } +} + +pub fn check_file(file: String) -> File { + match File::open(file) { + Ok(f) => f, + Err(e) => { + log::error!("Cannot file conf file: {}", e); + panic!("{}", e); + } + } +} diff --git a/src/main.rs b/src/main.rs index 1613267..8970cb2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,19 +1,61 @@ mod namespace; mod manage_interfaces; mod wireguard_manager; +mod config; use std::io::Result; use rtnetlink::NetworkNamespace; use futures::executor::block_on; use base64::prelude::*; -fn main() { +fn main(){ + env_logger::Builder::from_default_env() + .format_timestamp_secs() + .filter(None, log::LevelFilter::Debug) + .init(); + let args: Vec = std::env::args().collect(); + match args.len() { + 3 => { + let cmd = &args[1]; + let param = &args[2]; + match &cmd[..] { + "-conf" => { + for vpn in config::get_vpn_conf(param.to_string()) { + create_namespace(vpn.get_consumable()); + } + } + _ => {}, + } + }, + _ => {}, + } +} + +pub fn create_namespace(vpn: config::ConsumableVPNConfig) { + let ns_name = vpn.namespace_name.clone(); + block_on(NetworkNamespace::add(ns_name.clone())).unwrap(); + namespace::bind_interface::run_in_namespace(|| { + manage_interfaces::set_interface_lo_up().unwrap(); + }, + &ns_name).unwrap(); + namespace::bind_interface::run_in_namespace(|| { + manage_interfaces::create_wireguard_interface(vpn.interface_name.clone(), + vpn.ip.clone(), + vpn.endpoint.ip.clone(), + vpn.prefix as u8, + vpn.public_key, + vpn.private_key).unwrap(); + }, + &ns_name).unwrap(); +} + +/*fn main2() { env_logger::Builder::from_default_env() .format_timestamp_secs() .filter(None, log::LevelFilter::Debug) .init(); - let priv_key_dirty = BASE64_STANDARD.decode(b"key1").unwrap(); - let pub_key_dirty = BASE64_STANDARD.decode(b"key2").unwrap(); + let priv_key_dirty = BASE64_STANDARD.decode(b"k1").unwrap(); + let pub_key_dirty = BASE64_STANDARD.decode(b"k2").unwrap(); let mut priv_key: [u8; 32] = Default::default(); let mut pub_key: [u8; 32] = Default::default(); priv_key.copy_from_slice(&priv_key_dirty[0..32]); @@ -27,9 +69,9 @@ fn main() { }, &ns_name).unwrap(); namespace::bind_interface::run_in_namespace(|| { - manage_interfaces::create_wireguard_interface(String::from("wgiface"), - String::from("local_ip"), - String::from("remote_ip"), + manage_interfaces::create_wireguard_interface(String::from("wgzurich"), + String::from("ip1"), + String::from("ip2"), 24, pub_key, priv_key).unwrap(); @@ -37,3 +79,4 @@ fn main() { &ns_name).unwrap(); //println!("{}",wireguard_manager::add_properties::set_params(pub_key, priv_key)) } +*/