Browse Source

merge manyuser branch

master
breakwa11 5 years ago
parent
commit
5b6dec5940
7 changed files with 438 additions and 75 deletions
  1. 18
    3
      shadowsocks/common.py
  2. 9
    0
      shadowsocks/eventloop.py
  3. 286
    0
      shadowsocks/manager.py
  4. 10
    2
      shadowsocks/server.py
  5. 10
    3
      shadowsocks/shell.py
  6. 15
    12
      shadowsocks/tcprelay.py
  7. 90
    55
      shadowsocks/udprelay.py

+ 18
- 3
shadowsocks/common.py View File

@@ -21,7 +21,7 @@ from __future__ import absolute_import, division, print_function, \
import socket
import struct
import logging
import binascii

def compat_ord(s):
if type(s) == int:
@@ -140,7 +140,7 @@ def pack_addr(address):

def pre_parse_header(data):
datatype = ord(data[0])
if datatype == 0x80 :
if datatype == 0x80:
if len(data) <= 2:
return None
rand_data_size = ord(data[1])
@@ -151,7 +151,7 @@ def pre_parse_header(data):
data = data[rand_data_size + 2:]
elif datatype == 0x81:
data = data[1:]
elif datatype == 0x82 :
elif datatype == 0x82:
if len(data) <= 3:
return None
rand_data_size = struct.unpack('>H', data[1:3])[0]
@@ -160,6 +160,21 @@ def pre_parse_header(data):
'encryption method')
return None
data = data[rand_data_size + 3:]
elif datatype == 0x88:
if len(data) <= 7 + 7:
return None
data_size = struct.unpack('>H', data[1:3])[0]
ogn_data = data
data = data[:data_size]
crc = binascii.crc32(data) & 0xffffffff
if crc != 0xffffffff:
logging.warn('uncorrect CRC32, maybe wrong password or '
'encryption method')
return None
start_pos = 3 + ord(data[3])
data = data[start_pos:-4]
if data_size < len(ogn_data):
data += ogn_data[data_size:]
return data

def parse_header(data):

+ 9
- 0
shadowsocks/eventloop.py View File

@@ -98,6 +98,9 @@ class KqueueLoop(object):
self.unregister(fd)
self.register(fd, mode)

def close(self):
self.kqueue.close()


class SelectLoop(object):

@@ -135,6 +138,9 @@ class SelectLoop(object):
self.unregister(fd)
self.register(fd, mode)

def close(self):
pass


class EventLoop(object):
def __init__(self):
@@ -216,6 +222,9 @@ class EventLoop(object):
callback()
self._last_time = now

def __del__(self):
self._impl.close()


# from tornado
def errno_from_exception(e):

+ 286
- 0
shadowsocks/manager.py View File

@@ -0,0 +1,286 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright 2015 clowwindy
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

from __future__ import absolute_import, division, print_function, \
with_statement

import errno
import traceback
import socket
import logging
import json
import collections

from shadowsocks import common, eventloop, tcprelay, udprelay, asyncdns, shell


BUF_SIZE = 1506
STAT_SEND_LIMIT = 100


class Manager(object):

def __init__(self, config):
self._config = config
self._relays = {} # (tcprelay, udprelay)
self._loop = eventloop.EventLoop()
self._dns_resolver = asyncdns.DNSResolver()
self._dns_resolver.add_to_loop(self._loop)

self._statistics = collections.defaultdict(int)
self._control_client_addr = None
try:
manager_address = config['manager_address']
if ':' in manager_address:
addr = manager_address.rsplit(':', 1)
addr = addr[0], int(addr[1])
addrs = socket.getaddrinfo(addr[0], addr[1])
if addrs:
family = addrs[0][0]
else:
logging.error('invalid address: %s', manager_address)
exit(1)
else:
addr = manager_address
family = socket.AF_UNIX
self._control_socket = socket.socket(family,
socket.SOCK_DGRAM)
self._control_socket.bind(addr)
self._control_socket.setblocking(False)
except (OSError, IOError) as e:
logging.error(e)
logging.error('can not bind to manager address')
exit(1)
self._loop.add(self._control_socket,
eventloop.POLL_IN, self)
self._loop.add_periodic(self.handle_periodic)

port_password = config['port_password']
del config['port_password']
for port, password in port_password.items():
a_config = config.copy()
a_config['server_port'] = int(port)
a_config['password'] = password
self.add_port(a_config)

def add_port(self, config):
port = int(config['server_port'])
servers = self._relays.get(port, None)
if servers:
logging.error("server already exists at %s:%d" % (config['server'],
port))
return
logging.info("adding server at %s:%d" % (config['server'], port))
t = tcprelay.TCPRelay(config, self._dns_resolver, False,
self.stat_callback)
u = udprelay.UDPRelay(config, self._dns_resolver, False,
self.stat_callback)
t.add_to_loop(self._loop)
u.add_to_loop(self._loop)
self._relays[port] = (t, u)

def remove_port(self, config):
port = int(config['server_port'])
servers = self._relays.get(port, None)
if servers:
logging.info("removing server at %s:%d" % (config['server'], port))
t, u = servers
t.close(next_tick=False)
u.close(next_tick=False)
del self._relays[port]
else:
logging.error("server not exist at %s:%d" % (config['server'],
port))

def handle_event(self, sock, fd, event):
if sock == self._control_socket and event == eventloop.POLL_IN:
data, self._control_client_addr = sock.recvfrom(BUF_SIZE)
parsed = self._parse_command(data)
if parsed:
command, config = parsed
a_config = self._config.copy()
if config:
# let the command override the configuration file
a_config.update(config)
if 'server_port' not in a_config:
logging.error('can not find server_port in config')
else:
if command == 'add':
self.add_port(a_config)
self._send_control_data(b'ok')
elif command == 'remove':
self.remove_port(a_config)
self._send_control_data(b'ok')
elif command == 'ping':
self._send_control_data(b'pong')
else:
logging.error('unknown command %s', command)

def _parse_command(self, data):
# commands:
# add: {"server_port": 8000, "password": "foobar"}
# remove: {"server_port": 8000"}
data = common.to_str(data)
parts = data.split(':', 1)
if len(parts) < 2:
return data, None
command, config_json = parts
try:
config = shell.parse_json_in_str(config_json)
return command, config
except Exception as e:
logging.error(e)
return None

def stat_callback(self, port, data_len):
self._statistics[port] += data_len

def handle_periodic(self):
r = {}
i = 0

def send_data(data_dict):
if data_dict:
# use compact JSON format (without space)
data = common.to_bytes(json.dumps(data_dict,
separators=(',', ':')))
self._send_control_data(b'stat: ' + data)

for k, v in self._statistics.items():
r[k] = v
i += 1
# split the data into segments that fit in UDP packets
if i >= STAT_SEND_LIMIT:
send_data(r)
r.clear()
send_data(r)
self._statistics.clear()

def _send_control_data(self, data):
if self._control_client_addr:
try:
self._control_socket.sendto(data, self._control_client_addr)
except (socket.error, OSError, IOError) as e:
error_no = eventloop.errno_from_exception(e)
if error_no in (errno.EAGAIN, errno.EINPROGRESS,
errno.EWOULDBLOCK):
return
else:
shell.print_exception(e)
if self._config['verbose']:
traceback.print_exc()

def run(self):
self._loop.run()


def run(config):
Manager(config).run()


def test():
import time
import threading
import struct
from shadowsocks import encrypt

logging.basicConfig(level=5,
format='%(asctime)s %(levelname)-8s %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
enc = []
eventloop.TIMEOUT_PRECISION = 1

def run_server():
config = {
'server': '127.0.0.1',
'local_port': 1081,
'port_password': {
'8381': 'foobar1',
'8382': 'foobar2'
},
'method': 'aes-256-cfb',
'manager_address': '127.0.0.1:6001',
'timeout': 60,
'fast_open': False,
'verbose': 2
}
manager = Manager(config)
enc.append(manager)
manager.run()

t = threading.Thread(target=run_server)
t.start()
time.sleep(1)
manager = enc[0]
cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
cli.connect(('127.0.0.1', 6001))

# test add and remove
time.sleep(1)
cli.send(b'add: {"server_port":7001, "password":"asdfadsfasdf"}')
time.sleep(1)
assert 7001 in manager._relays
data, addr = cli.recvfrom(1506)
assert b'ok' in data

cli.send(b'remove: {"server_port":8381}')
time.sleep(1)
assert 8381 not in manager._relays
data, addr = cli.recvfrom(1506)
assert b'ok' in data
logging.info('add and remove test passed')

# test statistics for TCP
header = common.pack_addr(b'google.com') + struct.pack('>H', 80)
data = encrypt.encrypt_all(b'asdfadsfasdf', 'aes-256-cfb', 1,
header + b'GET /\r\n\r\n')
tcp_cli = socket.socket()
tcp_cli.connect(('127.0.0.1', 7001))
tcp_cli.send(data)
tcp_cli.recv(4096)
tcp_cli.close()

data, addr = cli.recvfrom(1506)
data = common.to_str(data)
assert data.startswith('stat: ')
data = data.split('stat:')[1]
stats = shell.parse_json_in_str(data)
assert '7001' in stats
logging.info('TCP statistics test passed')

# test statistics for UDP
header = common.pack_addr(b'127.0.0.1') + struct.pack('>H', 80)
data = encrypt.encrypt_all(b'foobar2', 'aes-256-cfb', 1,
header + b'test')
udp_cli = socket.socket(type=socket.SOCK_DGRAM)
udp_cli.sendto(data, ('127.0.0.1', 8382))
tcp_cli.close()

data, addr = cli.recvfrom(1506)
data = common.to_str(data)
assert data.startswith('stat: ')
data = data.split('stat:')[1]
stats = json.loads(data)
assert '8382' in stats
logging.info('UDP statistics test passed')

manager._loop.stop()
t.join()


if __name__ == '__main__':
test()

+ 10
- 2
shadowsocks/server.py View File

@@ -24,7 +24,8 @@ import logging
import signal

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../'))
from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, asyncdns
from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, \
asyncdns, manager


def main():
@@ -48,10 +49,17 @@ def main():
else:
config['port_password'][str(server_port)] = config['password']

if config.get('manager_address', 0):
logging.info('entering manager mode')
manager.run(config)
return

tcp_servers = []
udp_servers = []
dns_resolver = asyncdns.DNSResolver()
for port, password in config['port_password'].items():
port_password = config['port_password']
del config['port_password']
for port, password in port_password.items():
a_config = config.copy()
ipv6_ok = False
logging.info("server start with password [%s] method [%s]" % (password, a_config['method']))

+ 10
- 3
shadowsocks/shell.py View File

@@ -136,7 +136,7 @@ def get_config(is_local):
else:
shortopts = 'hd:s:p:k:m:c:t:vq'
longopts = ['help', 'fast-open', 'pid-file=', 'log-file=', 'workers=',
'forbidden-ip=', 'user=', 'version']
'forbidden-ip=', 'user=', 'manager-address=', 'version']
try:
config_path = find_config()
optlist, args = getopt.getopt(sys.argv[1:], shortopts, longopts)
@@ -148,8 +148,7 @@ def get_config(is_local):
logging.info('loading config from %s' % config_path)
with open(config_path, 'rb') as f:
try:
config = json.loads(f.read().decode('utf8'),
object_hook=_decode_dict)
config = parse_json_in_str(f.read().decode('utf8'))
except ValueError as e:
logging.error('found an error in config.json: %s',
e.message)
@@ -181,6 +180,8 @@ def get_config(is_local):
config['fast_open'] = True
elif key == '--workers':
config['workers'] = int(value)
elif key == '--manager-address':
config['manager_address'] = value
elif key == '--user':
config['user'] = to_str(value)
elif key == '--forbidden-ip':
@@ -317,6 +318,7 @@ Proxy options:
--fast-open use TCP_FASTOPEN, requires Linux 3.7+
--workers WORKERS number of workers, available on Unix/Linux
--forbidden-ip IPLIST comma seperated IP list forbidden to connect
--manager-address ADDR optional server manager UDP address, see wiki

General options:
-h, --help show this help message and exit
@@ -356,3 +358,8 @@ def _decode_dict(data):
value = _decode_dict(value)
rv[key] = value
return rv


def parse_json_in_str(data):
# parse json and convert everything from unicode to str
return json.loads(data, object_hook=_decode_dict)

+ 15
- 12
shadowsocks/tcprelay.py View File

@@ -23,6 +23,7 @@ import socket
import errno
import struct
import logging
import binascii
import traceback
import random

@@ -32,9 +33,6 @@ from shadowsocks.common import pre_parse_header, parse_header
# we clear at most TIMEOUTS_CLEAN_SIZE timeouts each time
TIMEOUTS_CLEAN_SIZE = 512

# we check timeouts every TIMEOUT_PRECISION seconds
TIMEOUT_PRECISION = 4

MSG_FASTOPEN = 0x20000000

# SOCKS command definition
@@ -153,10 +151,10 @@ class TCPRelayHandler(object):
logging.debug('chosen server: %s:%d', server, server_port)
return server, server_port

def _update_activity(self):
def _update_activity(self, data_len=0):
# tell the TCP Relay we have activities recently
# else it will think we are inactive and timed out
self._server.update_activity(self)
self._server.update_activity(self, data_len)

def _update_stream(self, stream, status):
# update a stream to a new waiting status
@@ -343,6 +341,8 @@ class TCPRelayHandler(object):
logging.error('unknown command %d', cmd)
self.destroy()
return
if False and ord(data[0]) != 0x88: # force new header
raise Exception('can not parse header')
data = pre_parse_header(data)
if data is None:
raise Exception('can not parse header')
@@ -379,7 +379,6 @@ class TCPRelayHandler(object):
self._log_error(e)
if self._config['verbose']:
traceback.print_exc()
# TODO use logging when debug completed
self.destroy()

def _create_remote_socket(self, ip, port):
@@ -397,7 +396,6 @@ class TCPRelayHandler(object):
common.to_str(sa[0]))
remote_sock = socket.socket(af, socktype, proto)
self._remote_sock = remote_sock

self._fd_to_handlers[remote_sock.fileno()] = self

if self._remote_udp:
@@ -410,7 +408,6 @@ class TCPRelayHandler(object):
remote_sock_v6.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 32)
remote_sock_v6.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 32)


remote_sock.setblocking(False)
if self._remote_udp:
pass
@@ -483,7 +480,6 @@ class TCPRelayHandler(object):
def _on_local_read(self):
# handle all local read events and dispatch them to methods for
# each stage
self._update_activity()
if not self._local_sock:
return
is_local = self._is_local
@@ -497,6 +493,7 @@ class TCPRelayHandler(object):
if not data:
self.destroy()
return
self._update_activity(len(data))
if not is_local:
data = self._encryptor.decrypt(data)
if not data:
@@ -520,7 +517,6 @@ class TCPRelayHandler(object):

def _on_remote_read(self, is_remote_sock):
# handle all remote read events
self._update_activity()
data = None
try:
if self._remote_udp:
@@ -547,6 +543,7 @@ class TCPRelayHandler(object):
self.destroy()
return
self._server.server_transfer_dl += len(data)
self._update_activity(len(data))
if self._is_local:
data = self._encryptor.decrypt(data)
else:
@@ -667,7 +664,7 @@ class TCPRelayHandler(object):


class TCPRelay(object):
def __init__(self, config, dns_resolver, is_local):
def __init__(self, config, dns_resolver, is_local, stat_callback=None):
self._config = config
self._is_local = is_local
self._dns_resolver = dns_resolver
@@ -709,6 +706,7 @@ class TCPRelay(object):
self._config['fast_open'] = False
server_socket.listen(1024)
self._server_socket = server_socket
self._stat_callback = stat_callback

def add_to_loop(self, loop):
if self._eventloop:
@@ -727,7 +725,10 @@ class TCPRelay(object):
self._timeouts[index] = None
del self._handler_to_timeouts[hash(handler)]

def update_activity(self, handler):
def update_activity(self, handler, data_len):
if data_len and self._stat_callback:
self._stat_callback(self._listen_port, data_len)

# set handler to active
now = int(time.time())
if now - handler.last_activity < eventloop.TIMEOUT_PRECISION:
@@ -828,3 +829,5 @@ class TCPRelay(object):
self._eventloop.remove_periodic(self.handle_periodic)
self._eventloop.remove(self._server_socket)
self._server_socket.close()
for handler in list(self._fd_to_handlers.values()):
handler.destroy()

+ 90
- 55
shadowsocks/udprelay.py View File

@@ -77,9 +77,6 @@ from shadowsocks.common import pre_parse_header, parse_header, pack_addr
# we clear at most TIMEOUTS_CLEAN_SIZE timeouts each time
TIMEOUTS_CLEAN_SIZE = 512

# we check timeouts every TIMEOUT_PRECISION seconds
TIMEOUT_PRECISION = 4

# for each handler, we have 2 stream directions:
# upstream: from client to server direction
# read local and write to remote
@@ -97,8 +94,9 @@ WAIT_STATUS_READWRITING = WAIT_STATUS_READING | WAIT_STATUS_WRITING

BUF_SIZE = 65536
DOUBLE_SEND_BEG_IDS = 16
POST_MTU_MIN = 1000
POST_MTU_MIN = 500
POST_MTU_MAX = 1400
SENDING_WINDOW_SIZE = 8192

STAGE_INIT = 0
STAGE_RSP_ID = 1
@@ -119,6 +117,14 @@ CMD_DISCONNECT = 8

CMD_VER_STR = "\x08"

RSP_STATE_EMPTY = ""
RSP_STATE_REJECT = "\x00"
RSP_STATE_CONNECTED = "\x01"
RSP_STATE_CONNECTEDREMOTE = "\x02"
RSP_STATE_ERROR = "\x03"
RSP_STATE_DISCONNECT = "\x04"
RSP_STATE_REDIRECT = "\x05"

class UDPLocalAddress(object):
def __init__(self, addr):
self.addr = addr
@@ -173,9 +179,6 @@ class SendingQueue(object):
while self.begin_id < begin_id:
self.begin_id += 1
del self.queue[self.begin_id]
#while len(self.queue) > 0 and self.queue[0][0] <= begin_id:
# del self.queue[0]
# self.begin_id += 1

class RecvQueue(object):
def __init__(self):
@@ -229,6 +232,38 @@ class RecvQueue(object):
missing.append(i - begin_id)
return (begin_id, missing)

class AddressMap(object):
def __init__(self):
self._queue = []
self._addr_map = {}

def add(self, addr):
if addr in self._addr_map:
self._addr_map[addr] = UDPLocalAddress(addr)
else:
self._addr_map[addr] = UDPLocalAddress(addr)
self._queue.append(addr)

def keys(self):
return self._queue

def get(self):
if self._queue:
while True:
if len(self._queue) == 1:
return self._queue[0]
index = random.randint(0, len(self._queue) - 1)
addr = self._queue[index]
if self._addr_map[addr].is_timeout():
self._queue[index] = self._queue[len(self._queue) - 1]
del self._queue[len(self._queue) - 1]
del self._addr_map[addr]
else:
break
return addr
else:
return None

class TCPRelayHandler(object):
def __init__(self, server, reqid_to_handlers, fd_to_handlers, loop,
local_sock, local_id, client_param, config,
@@ -254,7 +289,7 @@ class TCPRelayHandler(object):
self._upstream_status = WAIT_STATUS_READING
self._downstream_status = WAIT_STATUS_INIT
self._request_id = 0
self._client_address = {}
self._client_address = AddressMap()
self._remote_address = None
self._sendingqueue = SendingQueue()
self._recvqueue = RecvQueue()
@@ -282,7 +317,10 @@ class TCPRelayHandler(object):
return self._remote_address

def add_local_address(self, addr):
self._client_address[addr] = UDPLocalAddress(addr)
self._client_address.add(addr)

def get_local_address(self):
return self._client_address.get()

def _update_activity(self):
# tell the TCP Relay we have activities recently
@@ -367,8 +405,6 @@ class TCPRelayHandler(object):
return False
if uncomplete:
if sock == self._local_sock:
#if data is not None and retry < 10:
# self._data_to_write_to_local.append([(data, addr), retry])
self._update_stream(STREAM_DOWN, WAIT_STATUS_WRITING)
elif sock == self._remote_sock:
self._data_to_write_to_remote.append(data)
@@ -377,15 +413,12 @@ class TCPRelayHandler(object):
logging.error('write_all_to_sock:unknown socket')
else:
if sock == self._local_sock:
if self._sendingqueue.size() > 8192:
if self._sendingqueue.size() > SENDING_WINDOW_SIZE:
self._update_stream(STREAM_DOWN, WAIT_STATUS_WRITING)
else:
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
elif sock == self._remote_sock:
if self._sendingqueue.size() > 8192:
self._update_stream(STREAM_DOWN, WAIT_STATUS_WRITING)
else:
self._update_stream(STREAM_UP, WAIT_STATUS_READING)
self._update_stream(STREAM_UP, WAIT_STATUS_READING)
else:
logging.error('write_all_to_sock:unknown socket')
return True
@@ -439,12 +472,10 @@ class TCPRelayHandler(object):
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
self._stage = STAGE_STREAM

for it_addr in self._client_address:
addr = it_addr
break
addr = self.get_local_address()

for i in xrange(2):
rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, "\x02")
rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, RSP_STATE_CONNECTEDREMOTE)
self._write_to_sock(rsp_data, self._local_sock, addr)

return
@@ -508,13 +539,11 @@ class TCPRelayHandler(object):

pack_id = self._sendingqueue.append(data)
post_data = self._pack_post_data(CMD_POST, pack_id, data)
for it_addr in self._client_address:
addr = it_addr
break
addr = self.get_local_address()
self._write_to_sock(post_data, self._local_sock, addr)
#if pack_id <= DOUBLE_SEND_BEG_IDS:
# post_data = self._pack_post_data(CMD_POST, pack_id, data)
# self._write_to_sock(post_data, self._local_sock, addr)
if pack_id <= DOUBLE_SEND_BEG_IDS:
post_data = self._pack_post_data(CMD_POST, pack_id, data)
self._write_to_sock(post_data, self._local_sock, addr)

except Exception as e:
shell.print_exception(e)
@@ -620,14 +649,14 @@ class TCPRelayHandler(object):
for post_pack_id, post_data in send_list:
rsp_data = self._pack_post_data(CMD_POST, post_pack_id, post_data)
self._write_to_sock(rsp_data, self._local_sock, addr)
#if post_pack_id <= DOUBLE_SEND_BEG_IDS:
# rsp_data = self._pack_post_data(CMD_POST, post_pack_id, post_data)
# self._write_to_sock(rsp_data, self._local_sock, addr)
if post_pack_id <= DOUBLE_SEND_BEG_IDS:
rsp_data = self._pack_post_data(CMD_POST, post_pack_id, post_data)
self._write_to_sock(rsp_data, self._local_sock, addr)

def handle_client(self, addr, cmd, request_id, data):
self.add_local_address(addr)
if cmd == CMD_DISCONNECT:
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "")
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY)
self._write_to_sock(rsp_data, self._local_sock, addr)
self.destroy()
self.destroy_local()
@@ -643,7 +672,7 @@ class TCPRelayHandler(object):
if self._stage == STAGE_RSP_ID:
if cmd == CMD_CONNECT:
for i in xrange(2):
rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT, "\x01")
rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT, RSP_STATE_CONNECTED)
self._write_to_sock(rsp_data, self._local_sock, addr)
elif cmd == CMD_CONNECT_REMOTE:
local_id = data[0:4]
@@ -660,35 +689,35 @@ class TCPRelayHandler(object):
logging.info('TCP connect %s:%d from %s:%d' % (remote_addr, remote_port, addr[0], addr[1]))
else:
# ileagal request
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "")
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY)
self._write_to_sock(rsp_data, self._local_sock, addr)
elif self._stage == STAGE_CONNECTING:
if cmd == CMD_CONNECT_REMOTE:
local_id = data[0:4]
if self._local_id == local_id:
for i in xrange(2):
rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, "\x02")
rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, RSP_STATE_CONNECTEDREMOTE)
self._write_to_sock(rsp_data, self._local_sock, addr)
else:
# ileagal request
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "")
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY)
self._write_to_sock(rsp_data, self._local_sock, addr)
elif self._stage == STAGE_STREAM:
if len(data) < 4:
# ileagal request
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "")
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY)
self._write_to_sock(rsp_data, self._local_sock, addr)
return
local_id = data[0:4]
if self._local_id != local_id:
# ileagal request
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "")
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY)
self._write_to_sock(rsp_data, self._local_sock, addr)
return
else:
data = data[4:]
if cmd == CMD_CONNECT_REMOTE:
rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, "\x02")
rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT_REMOTE, RSP_STATE_CONNECTEDREMOTE)
self._write_to_sock(rsp_data, self._local_sock, addr)
elif cmd == CMD_POST:
recv_id = struct.unpack(">I", data[0:4])[0]
@@ -701,7 +730,7 @@ class TCPRelayHandler(object):
self._recvqueue.insert(pack_id, data[16:])
self._sendingqueue.set_finish(recv_id, [])
elif cmd == CMD_DISCONNECT:
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "")
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY)
self._write_to_sock(rsp_data, self._local_sock, addr)
self.destroy()
self.destroy_local()
@@ -723,7 +752,7 @@ class TCPRelayHandler(object):
local_id = data[0:4]
if self._local_id != local_id:
# ileagal request
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "")
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY)
self._write_to_sock(rsp_data, self._local_sock, addr)
return
else:
@@ -732,13 +761,11 @@ class TCPRelayHandler(object):
pack_id = struct.unpack(">I", data[0:4])[0]
max_send_id = struct.unpack(">I", data[4:8])[0]
data = data[8:]
logging.info('handle_client STAGE_DESTROYED send %d %d' % (request_id, pack_id))
self.handle_stream_sync_status(addr, cmd, request_id, pack_id, max_send_id, data)
elif cmd == CMD_SYN_STATUS_64:
pack_id = struct.unpack(">Q", data[0:8])[0]
max_send_id = struct.unpack(">Q", data[8:16])[0]
data = data[16:]
logging.info('handle_client STAGE_DESTROYED send %d %d' % (request_id, pack_id))
self.handle_stream_sync_status(addr, cmd, request_id, pack_id, max_send_id, data)

def handle_event(self, sock, event):
@@ -808,11 +835,9 @@ class TCPRelayHandler(object):
def destroy_local(self):
if self._local_sock:
logging.debug('disconnect local')
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, "")
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, RSP_STATE_EMPTY)
addr = None
for it_addr in self._client_address:
addr = it_addr
break
addr = self.get_local_address()
self._write_to_sock(rsp_data, self._local_sock, addr)
self._local_sock = None
del self._reqid_to_handlers[self._request_id]
@@ -822,8 +847,9 @@ def client_key(source_addr, server_af):
# notice this is server af, not dest af
return '%s:%s:%d' % (source_addr[0], source_addr[1], server_af)


class UDPRelay(object):
def __init__(self, config, dns_resolver, is_local):
def __init__(self, config, dns_resolver, is_local, stat_callback=None):
self._config = config
if is_local:
self._listen_addr = config['local_address']
@@ -836,7 +862,7 @@ class UDPRelay(object):
self._remote_addr = None
self._remote_port = None
self._dns_resolver = dns_resolver
self._password = config['password']
self._password = common.to_bytes(config['password'])
self._method = config['method']
self._timeout = config['timeout']
self._is_local = is_local
@@ -877,6 +903,7 @@ class UDPRelay(object):
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 32)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 32)
self._server_socket = server_socket
self._stat_callback = stat_callback

def _get_a_server(self):
server = self._config['server']
@@ -937,6 +964,8 @@ class UDPRelay(object):
data, r_addr = server.recvfrom(BUF_SIZE)
if not data:
logging.debug('UDP handle_server: data is empty')
if self._stat_callback:
self._stat_callback(self._listen_port, len(data))
if self._is_local:
frag = common.ord(data[2])
if frag != 0:
@@ -976,7 +1005,7 @@ class UDPRelay(object):
break
# return req id
self._reqid_to_hd[req_id] = (data[2][0:4], None)
rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT, req_id, "\x01")
rsp_data = self._pack_rsp_data(CMD_RSP_CONNECT, req_id, RSP_STATE_CONNECTED)
data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data)
self.write_to_server_socket(data_to_send, r_addr)
elif data[0] == CMD_CONNECT_REMOTE:
@@ -994,7 +1023,7 @@ class UDPRelay(object):
self.update_activity(handle)
else:
# disconnect
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], "")
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], RSP_STATE_EMPTY)
data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data)
self.write_to_server_socket(data_to_send, r_addr)
else:
@@ -1002,16 +1031,19 @@ class UDPRelay(object):
self._reqid_to_hd[data[1]].handle_client(r_addr, *data)
else:
# disconnect
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], "")
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], RSP_STATE_EMPTY)
data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data)
self.write_to_server_socket(data_to_send, r_addr)
elif data[0] > CMD_CONNECT_REMOTE and data[0] <= CMD_DISCONNECT:
if data[1] in self._reqid_to_hd:
self.update_activity(self._reqid_to_hd[data[1]])
self._reqid_to_hd[data[1]].handle_client(r_addr, *data)
if type(self._reqid_to_hd[data[1]]) is tuple:
pass
else:
self.update_activity(self._reqid_to_hd[data[1]])
self._reqid_to_hd[data[1]].handle_client(r_addr, *data)
else:
# disconnect
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], "")
rsp_data = self._pack_rsp_data(CMD_DISCONNECT, data[1], RSP_STATE_EMPTY)
data_to_send = encrypt.encrypt_all(self._password, self._method, 1, rsp_data)
self.write_to_server_socket(data_to_send, r_addr)
return
@@ -1042,7 +1074,6 @@ class UDPRelay(object):

af, socktype, proto, canonname, sa = addrs[0]
key = client_key(r_addr, af)
logging.debug(key)
client = self._cache.get(key, None)
if not client:
# TODO async getaddrinfo
@@ -1083,6 +1114,8 @@ class UDPRelay(object):
if not data:
logging.debug('UDP handle_client: data is empty')
return
if self._stat_callback:
self._stat_callback(self._listen_port, len(data))
if not self._is_local:
addrlen = len(r_addr[0])
if addrlen > 255:
@@ -1101,7 +1134,7 @@ class UDPRelay(object):
header_result = parse_header(data)
if header_result is None:
return
connecttype, dest_addr, dest_port, header_length = header_result
#connecttype, dest_addr, dest_port, header_length = header_result
#logging.debug('UDP handle_client %s:%d to %s:%d' % (common.to_str(r_addr[0]), r_addr[1], dest_addr, dest_port))

response = b'\x00\x00\x00' + data
@@ -1250,3 +1283,5 @@ class UDPRelay(object):
self._eventloop.remove_periodic(self.handle_periodic)
self._eventloop.remove(self._server_socket)
self._server_socket.close()
for client in list(self._cache.values()):
client.close()

Loading…
Cancel
Save