Redis-to-HTTP proxy https://rpjios.com
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

487 lines
12 KiB

  1. package main
import (
	"crypto/subtle"
	"encoding/base64"
	"encoding/json"
	"flag"
	"fmt"
	"log"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"plugin"
	"reflect"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/go-redis/redis/v7"
	"github.com/google/uuid"
	"github.com/gorilla/websocket"
)
  22. const defaultRedisHost = "localhost"
  23. const defaultRedisPort = 6379
  24. const defaultListenHost = "localhost"
  25. const defaultListenPort = 56545
  26. const defaultUsersFile = "./users.json"
  27. const defaultPluginsPath = "./build/plugins"
  28. var loadedPlugins rhpPluginsT
  29. var g_usersFile string
  30. var usersMap map[string]string = nil
  31. var usersMapLock sync.Mutex = sync.Mutex{}
  32. var redisDefaultClient *redis.Client = nil
  33. var redisOptions = redis.Options{
  34. DB: 0,
  35. }
  36. type subscriber struct {
  37. Channel string
  38. Addr string
  39. User string
  40. }
  41. type subscribeHandler struct {
  42. Lock sync.Mutex
  43. Pending map[uuid.UUID]subscriber
  44. }
  45. var gSubscribeHandler *subscribeHandler = nil
  46. var wsUpgrader = websocket.Upgrader{
  47. ReadBufferSize: 1024,
  48. WriteBufferSize: 1024,
  49. CheckOrigin: func(_ *http.Request) bool { return true },
  50. }
  51. type wsClient struct {
  52. Conn *websocket.Conn
  53. Lock *sync.Mutex
  54. Sub *redis.PubSub
  55. }
  56. var wsClients = map[net.Addr]wsClient{}
  57. var wsClientsLock = sync.Mutex{}
  58. func checkAuth(req *http.Request) (string, error) {
  59. if authHeader, ok := req.Header["Authorization"]; ok {
  60. if len(authHeader) > 1 {
  61. log.Panic("too many headers!")
  62. }
  63. authComps := strings.Split(authHeader[0], " ")
  64. if len(authComps) != 2 || authComps[0] != "Basic" {
  65. return "", fmt.Errorf("bad authComps '%v'", authComps)
  66. }
  67. decAuthBytes, err := base64.StdEncoding.DecodeString(authComps[1])
  68. if err != nil {
  69. log.Println(authComps)
  70. return "", err
  71. }
  72. decComps := strings.Split(string(decAuthBytes), ":")
  73. if len(decComps) != 2 {
  74. return "", fmt.Errorf("bad decComps")
  75. }
  76. usersMapLock.Lock()
  77. defer usersMapLock.Unlock()
  78. if storedPwd, okUser := usersMap[decComps[0]]; okUser {
  79. if storedPwd == decComps[1] {
  80. return decComps[0], nil
  81. } else {
  82. return "", fmt.Errorf("bad pwd")
  83. }
  84. }
  85. return "", fmt.Errorf("bad user")
  86. }
  87. log.Println(req.Header)
  88. log.Println(req.Method)
  89. return "", fmt.Errorf("bad auth")
  90. }
  91. func (sh *subscribeHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
  92. w.Header().Set("Access-Control-Allow-Origin", "*")
  93. if req.Method == "OPTIONS" {
  94. w.Header().Set("Access-Control-Allow-Headers", "authorization")
  95. w.WriteHeader(http.StatusOK)
  96. return
  97. }
  98. authedUser, err := checkAuth(req)
  99. if err != nil {
  100. log.Printf("auth err: %v\n", err)
  101. w.WriteHeader(http.StatusBadRequest)
  102. return
  103. }
  104. dir, file := filepath.Split(req.URL.Path)
  105. var respStr string = ""
  106. if dir == "/sub/" {
  107. newSubID := uuid.New()
  108. newSub := subscriber{file, req.RemoteAddr, authedUser}
  109. sh.Lock.Lock()
  110. if sh.Pending == nil {
  111. sh.Pending = map[uuid.UUID]subscriber{}
  112. }
  113. sh.Pending[newSubID] = newSub
  114. sh.Lock.Unlock()
  115. log.Printf("new sub %v\n", newSub)
  116. respStr = newSubID.String()
  117. } else if dir == "/list/" && file != "" {
  118. query := req.URL.Query()
  119. listLookup := func(start int64, end int64) ([]string, error) {
  120. return redisDefaultClient.LRange(file, start, end).Result()
  121. }
  122. // allow plugins the opprotunity to handle the request before using the default handler
  123. // only one can handle any given request, so the first to do so affirmatively ends the request
  124. loadedPlugins.Lock.Lock()
  125. for name, plugin := range loadedPlugins.List {
  126. strResp, err := plugin.HandleListReq(dir, file, query, listLookup)
  127. if err == nil && len(strResp) > 0 {
  128. respStr = strResp
  129. log.Printf("using response for '%v%v' produced by plugin '%v'\n", dir, file, name)
  130. break
  131. }
  132. }
  133. loadedPlugins.Lock.Unlock()
  134. // default handler
  135. if respStr == "" {
  136. start := int64(0)
  137. end := int64(10)
  138. query := req.URL.Query()
  139. var err error = nil
  140. log.Printf("LIST -- %v -- %v\n", file, query)
  141. if startSpec, ok := query["start"]; ok {
  142. start, err = strconv.ParseInt(startSpec[0], 10, 64)
  143. }
  144. if err == nil && start >= 0 {
  145. if endSpec, ok := query["end"]; ok {
  146. end, err = strconv.ParseInt(endSpec[0], 10, 64)
  147. }
  148. if err == nil && end > start {
  149. lRes := redisDefaultClient.LRange(file, start, end)
  150. listRes, err := lRes.Result()
  151. if err == nil {
  152. listStr, err := json.Marshal(listRes)
  153. if err == nil {
  154. respStr = string(listStr)
  155. }
  156. }
  157. }
  158. }
  159. }
  160. }
  161. if respStr != "" {
  162. w.WriteHeader(http.StatusOK)
  163. fmt.Fprintf(w, respStr)
  164. } else {
  165. log.Printf("BAD REQ: %v\n", req)
  166. w.WriteHeader(http.StatusBadRequest)
  167. }
  168. }
  169. func readUntilClose(c *websocket.Conn) {
  170. for {
  171. if _, _, err := c.NextReader(); err != nil {
  172. remoteAddr := c.RemoteAddr()
  173. log.Printf("ws client %v disconnected\n", remoteAddr)
  174. c.Close()
  175. wsClientsLock.Lock()
  176. defer wsClientsLock.Unlock()
  177. wsClients[remoteAddr].Sub.Close()
  178. delete(wsClients, remoteAddr)
  179. break
  180. }
  181. }
  182. }
  183. func forwardAllOnto(wsc wsClient) {
  184. for fwd := range wsc.Sub.Channel() {
  185. payload := interface{}(fwd.Payload)
  186. var err error
  187. loadedPlugins.Lock.Lock()
  188. for pluginName, plugin := range loadedPlugins.List {
  189. payload, err = plugin.HandleMsg(payload)
  190. if err != nil {
  191. log.Panic(pluginName)
  192. }
  193. }
  194. loadedPlugins.Lock.Unlock()
  195. go func() {
  196. wsc.Lock.Lock()
  197. defer wsc.Lock.Unlock()
  198. wsc.Conn.WriteJSON(payload)
  199. }()
  200. }
  201. }
  202. func registerNewClient(wsConn *websocket.Conn, channel string) {
  203. clientAddr := wsConn.RemoteAddr()
  204. wsClientsLock.Lock()
  205. defer wsClientsLock.Unlock()
  206. if curClient, ok := wsClients[clientAddr]; ok {
  207. log.Printf("already have conn for %v! closing it\n", clientAddr)
  208. curClient.Lock.Lock()
  209. curClient.Conn.Close()
  210. curClient.Lock.Unlock()
  211. }
  212. wsClients[clientAddr] = wsClient{wsConn, new(sync.Mutex), redisDefaultClient.Subscribe(channel)}
  213. log.Printf("ws client %v connected\n", clientAddr)
  214. go readUntilClose(wsConn)
  215. go forwardAllOnto(wsClients[clientAddr])
  216. }
  217. func refreshHandler(w http.ResponseWriter, req *http.Request) {
  218. if req.Method != "OPTIONS" {
  219. return
  220. }
  221. if rhpAuthHeader, ok := req.Header["X-Rhp-Auth"]; ok {
  222. expectAuth := fmt.Sprintf("%d", time.Now().Unix()/10)
  223. if expectAuth == rhpAuthHeader[0] {
  224. log.Printf("valid refresh request from %v, running\n", req.RemoteAddr)
  225. loadUsers()
  226. }
  227. }
  228. }
  229. func websocketHandler(w http.ResponseWriter, req *http.Request) {
  230. wsConn, err := wsUpgrader.Upgrade(w, req, nil)
  231. if err != nil {
  232. log.Printf("websocketHandler upgrade failed: %v\n", err)
  233. return
  234. }
  235. okReqUUID, err := uuid.Parse(req.URL.RawQuery)
  236. if err != nil {
  237. log.Printf("bad ws query '%s'\n", req.URL.RawQuery)
  238. log.Println(req)
  239. return
  240. }
  241. gSubscribeHandler.Lock.Lock()
  242. defer gSubscribeHandler.Lock.Unlock()
  243. if pendingConn, ok := gSubscribeHandler.Pending[okReqUUID]; ok {
  244. if strings.Split(wsConn.RemoteAddr().String(), ":")[0] == strings.Split(pendingConn.Addr, ":")[0] {
  245. go registerNewClient(wsConn, pendingConn.Channel)
  246. delete(gSubscribeHandler.Pending, okReqUUID)
  247. } else {
  248. log.Printf("bad addr match %s vs %s\n", wsConn.RemoteAddr(), pendingConn.Addr)
  249. }
  250. } else {
  251. log.Printf("bad pending connection '%v'\n", okReqUUID)
  252. }
  253. }
  254. func parseJSON(path string, intoObj interface{}) error {
  255. file, err := os.Open(path)
  256. if err != nil {
  257. fmt.Fprintf(os.Stderr, "parseJSON unable to open '%s': %v\n", path, err)
  258. return err
  259. }
  260. defer file.Close()
  261. dec := json.NewDecoder(file)
  262. err = dec.Decode(intoObj)
  263. if err != nil {
  264. fmt.Fprintf(os.Stderr, "parseJSON failed to decode: %v\n", err)
  265. return err
  266. }
  267. return nil
  268. }
  269. func loadPlugin(path string) (*rhpPluginImpl, error) {
  270. ifaceType := reflect.TypeOf((*RhpPlugin)(nil)).Elem()
  271. pluginLoad, err := plugin.Open(path)
  272. if err != nil {
  273. return nil, err
  274. }
  275. // without stubbing the fields, reflect.ValueOf(...).Elem().FieldByName(...) below will return nil
  276. newPlugin := newRhpPluginImpl()
  277. // for each method declared in the interface, look for the same-named concrete defintion
  278. // in the loaded plugin. if that exists, find the field in the concrete implementation
  279. // instance (rhpPluginImpl) and set each function pointer accordingly
  280. for i := 0; i < ifaceType.NumMethod(); i++ {
  281. methodName := ifaceType.Method(i).Name
  282. pluginMethod, err := pluginLoad.Lookup(methodName)
  283. if err != nil {
  284. return nil, err
  285. }
  286. implValue := reflect.ValueOf(&newPlugin).Elem()
  287. if implValue.IsZero() {
  288. return nil, fmt.Errorf("unable to get value of concrete impl")
  289. }
  290. implElem := implValue.FieldByName(methodName)
  291. if implElem.IsZero() {
  292. return nil, fmt.Errorf("unable to set value on concrete impl")
  293. }
  294. // must .Convert to the target type (implElem.Interface()), else will panic with a strangely-worded error:
  295. // "reflect.Set: value of type T is not assignable to type T"
  296. // (not a typo: the 'from' and 'to' types in the error message will be exactly the same, because
  297. // indeed if we've made it this far the types will match, hence why .Convert() succeeds!)
  298. implElem.Set(reflect.ValueOf(pluginMethod).Convert(reflect.TypeOf(implElem.Interface())))
  299. }
  300. return &newPlugin, nil
  301. }
  302. func loadPlugins(fromPath string) (rhpPluginMapT, error) {
  303. retVal := rhpPluginMapT{}
  304. // we never return a non-nil error from within the walk function so as to allow .Walk() to continue;
  305. // there is the special return filepath.SkipDir, but it will cause Walk to skip remaining files,
  306. // which isn't what we want either. only in the case that `err` is already non-nil do we return non-nil.
  307. err := filepath.Walk(filepath.ToSlash(fromPath), func(path string, info os.FileInfo, err error) error {
  308. if err != nil {
  309. log.Printf("filepath.Walk errored on entry: '%s' -> %v", path, err)
  310. return err
  311. }
  312. if filepath.Ext(path) != ".so" {
  313. return nil
  314. }
  315. log.Printf("found %s, checking for compatibility...", filepath.Base(path))
  316. newPlugin, err := loadPlugin(path)
  317. if err != nil {
  318. log.Printf("failed to load %s: %v", path, err)
  319. return nil
  320. }
  321. pBaseName := strings.Replace(filepath.Base(path), ".so", "", 1)
  322. log.Printf("loaded compatible plugin %s@%s", pBaseName, newPlugin.Version())
  323. retVal[pBaseName] = newPlugin
  324. return nil
  325. })
  326. return retVal, err
  327. }
  328. func loadUsers() {
  329. usersMapLock.Lock()
  330. defer usersMapLock.Unlock()
  331. err := parseJSON(g_usersFile, &usersMap)
  332. if err != nil {
  333. log.Panic(err.Error())
  334. }
  335. log.Printf("found %d valid users\n", len(usersMap))
  336. }
  337. func main() {
  338. listenPort := flag.Uint("port", defaultListenPort, "http listen port")
  339. listenHost := flag.String("listen", defaultListenHost, "http listen host")
  340. redisPort := flag.Uint("redis-port", defaultRedisPort, "redis server port")
  341. redisHost := flag.String("redis-host", defaultRedisHost, "redis server host")
  342. pluginsPath := flag.String("plugins", defaultPluginsPath, "plugins path")
  343. usersFile := flag.String("users", defaultUsersFile, "users JSON file")
  344. flag.Parse()
  345. if listenHost == nil || listenPort == nil || *listenPort < 1024 || *listenPort > 65535 {
  346. log.Panic("listen spec")
  347. }
  348. if redisPort == nil || redisHost == nil || *redisPort < 0 || *redisPort > 65535 {
  349. log.Panic("redis spec")
  350. }
  351. redisAuth := os.Getenv("REDIS_LOCAL_PWD")
  352. if len(redisAuth) == 0 {
  353. log.Panic("Need auth")
  354. }
  355. redisOptions.Addr = fmt.Sprintf("%s:%d", *redisHost, *redisPort)
  356. redisOptions.Password = redisAuth
  357. rc := redis.NewClient(&redisOptions)
  358. _, err := rc.Ping().Result()
  359. if err != nil {
  360. log.Panic("Ping")
  361. }
  362. log.Printf("connected to redis://%s\n", redisOptions.Addr)
  363. redisDefaultClient = rc
  364. g_usersFile = *usersFile
  365. loadUsers()
  366. loadedPlugins.Lock.Lock()
  367. loadedPlugins.List, err = loadPlugins(*pluginsPath)
  368. loadedPlugins.Lock.Unlock()
  369. if err != nil {
  370. log.Fatalf("plugin load failed: %v", err)
  371. }
  372. gSubscribeHandler = new(subscribeHandler)
  373. http.Handle("/sub/", gSubscribeHandler)
  374. http.Handle("/list/", gSubscribeHandler)
  375. http.HandleFunc("/ws/sub", websocketHandler)
  376. http.HandleFunc("/refresh", refreshHandler)
  377. listenSpec := fmt.Sprintf("%s:%d", *listenHost, *listenPort)
  378. log.Printf("listening on %s\n", listenSpec)
  379. http.ListenAndServe(listenSpec, nil)
  380. }