Redis-to-HTTP proxy https://rpjios.com
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

491 lines
12 KiB

  1. package main
  2. import (
  3. "encoding/base64"
  4. "encoding/json"
  5. "flag"
  6. "fmt"
  7. "log"
  8. "net"
  9. "net/http"
  10. "os"
  11. "path/filepath"
  12. "plugin"
  13. "reflect"
  14. "strconv"
  15. "strings"
  16. "sync"
  17. "time"
  18. "github.com/go-redis/redis/v7"
  19. "github.com/google/uuid"
  20. "github.com/gorilla/websocket"
  21. )
  // Built-in defaults for the command-line flags parsed in main.
const (
	// Redis server location (-redis-host, -redis-port).
	defaultRedisHost = "localhost"
	defaultRedisPort = 6379

	// HTTP listen address (-listen, -port).
	defaultListenHost = "localhost"
	defaultListenPort = 56545

	// Users JSON file (-users) and plugin search directory (-plugins).
	defaultUsersFile   = "./users.json"
	defaultPluginsPath = "./build/plugins"
)
  28. var loadedPlugins rhpPluginsT
  29. var g_usersFile string
  30. var usersMap map[string]string = nil
  31. var usersMapLock sync.Mutex = sync.Mutex{}
  32. var redisDefaultClient *redis.Client = nil
  33. var redisOptions = redis.Options{
  34. DB: 0,
  35. }
  36. type subscriber struct {
  37. Channel string
  38. Addr string
  39. User string
  40. }
  41. type subscribeHandler struct {
  42. Lock sync.Mutex
  43. Pending map[uuid.UUID]subscriber
  44. }
  45. var gSubscribeHandler *subscribeHandler = nil
  46. var wsUpgrader = websocket.Upgrader{
  47. ReadBufferSize: 1024,
  48. WriteBufferSize: 1024,
  49. CheckOrigin: func(_ *http.Request) bool { return true },
  50. }
  51. type wsClient struct {
  52. Conn *websocket.Conn
  53. Lock *sync.Mutex
  54. Sub *redis.PubSub
  55. }
  56. var wsClients = map[net.Addr]wsClient{}
  57. var wsClientsLock = sync.Mutex{}
  58. func checkAuth(req *http.Request) (string, error) {
  59. if authHeader, ok := req.Header["Authorization"]; ok {
  60. if len(authHeader) > 1 {
  61. log.Panic("too many headers!")
  62. }
  63. authComps := strings.Split(authHeader[0], " ")
  64. if len(authComps) != 2 || authComps[0] != "Basic" {
  65. return "", fmt.Errorf("bad authComps '%v'", authComps)
  66. }
  67. decAuthBytes, err := base64.StdEncoding.DecodeString(authComps[1])
  68. if err != nil {
  69. log.Println(authComps)
  70. return "", err
  71. }
  72. decComps := strings.Split(string(decAuthBytes), ":")
  73. if len(decComps) != 2 {
  74. return "", fmt.Errorf("bad decComps")
  75. }
  76. usersMapLock.Lock()
  77. defer usersMapLock.Unlock()
  78. if storedPwd, okUser := usersMap[decComps[0]]; okUser {
  79. if storedPwd == decComps[1] {
  80. return decComps[0], nil
  81. } else {
  82. return "", fmt.Errorf("bad pwd")
  83. }
  84. }
  85. return "", fmt.Errorf("bad user")
  86. }
  87. log.Println(req.Header)
  88. log.Println(req.Method)
  89. return "", fmt.Errorf("bad auth")
  90. }
  91. func (sh *subscribeHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
  92. w.Header().Set("Access-Control-Allow-Origin", "*")
  93. if req.Method == "OPTIONS" {
  94. w.Header().Set("Access-Control-Allow-Headers", "authorization")
  95. w.WriteHeader(http.StatusOK)
  96. return
  97. }
  98. authedUser, err := checkAuth(req)
  99. if err != nil {
  100. log.Printf("auth err: %v\n", err)
  101. w.WriteHeader(http.StatusBadRequest)
  102. return
  103. }
  104. dir, file := filepath.Split(req.URL.Path)
  105. var respStr string = ""
  106. if dir == "/sub/" {
  107. newSubID := uuid.New()
  108. newSub := subscriber{file, req.RemoteAddr, authedUser}
  109. sh.Lock.Lock()
  110. if sh.Pending == nil {
  111. sh.Pending = map[uuid.UUID]subscriber{}
  112. }
  113. sh.Pending[newSubID] = newSub
  114. sh.Lock.Unlock()
  115. log.Printf("new sub %v\n", newSub)
  116. respStr = newSubID.String()
  117. } else if dir == "/list/" && file != "" {
  118. query := req.URL.Query()
  119. listLookup := func(start int64, end int64) ([]string, error) {
  120. return redisDefaultClient.LRange(file, start, end).Result()
  121. }
  122. // allow plugins the opprotunity to handle the request before using the default handler
  123. // only one can handle any given request, so the first to do so affirmatively ends the request
  124. loadedPlugins.Lock.Lock()
  125. for name, plugin := range loadedPlugins.List {
  126. strResp, err := plugin.HandleListReq(dir, file, query, listLookup)
  127. if err == nil && len(strResp) > 0 {
  128. respStr = strResp
  129. log.Printf("using response for '%v%v' produced by plugin '%v'\n", dir, file, name)
  130. break
  131. }
  132. }
  133. loadedPlugins.Lock.Unlock()
  134. // default handler
  135. if respStr == "" {
  136. start := int64(0)
  137. end := int64(10)
  138. query := req.URL.Query()
  139. var err error = nil
  140. log.Printf("LIST -- %v -- %v\n", file, query)
  141. if startSpec, ok := query["start"]; ok {
  142. start, err = strconv.ParseInt(startSpec[0], 10, 64)
  143. }
  144. if err == nil && start >= 0 {
  145. if endSpec, ok := query["end"]; ok {
  146. end, err = strconv.ParseInt(endSpec[0], 10, 64)
  147. }
  148. if err == nil && end > start {
  149. lRes := redisDefaultClient.LRange(file, start, end)
  150. listRes, err := lRes.Result()
  151. if err == nil {
  152. listStr, err := json.Marshal(listRes)
  153. if err == nil {
  154. respStr = string(listStr)
  155. }
  156. }
  157. }
  158. }
  159. }
  160. }
  161. if respStr != "" {
  162. w.WriteHeader(http.StatusOK)
  163. fmt.Fprintf(w, respStr)
  164. } else {
  165. log.Printf("BAD REQ: %v\n", req)
  166. w.WriteHeader(http.StatusBadRequest)
  167. }
  168. }
  169. func readUntilClose(c *websocket.Conn) {
  170. log.Printf("readUntilClose for %v STARTING\n", c.RemoteAddr())
  171. for {
  172. if _, _, err := c.NextReader(); err != nil {
  173. remoteAddr := c.RemoteAddr()
  174. log.Printf("ws client %v disconnected\n", remoteAddr)
  175. c.Close()
  176. wsClientsLock.Lock()
  177. defer wsClientsLock.Unlock()
  178. wsClients[remoteAddr].Sub.Close()
  179. delete(wsClients, remoteAddr)
  180. break
  181. }
  182. }
  183. log.Printf("readUntilClose for %v EXITING\n", c.RemoteAddr())
  184. }
  185. func forwardAllOnto(wsc wsClient) {
  186. log.Printf("forAllOnto for %v STARTING\n", wsc.Conn.RemoteAddr())
  187. for fwd := range wsc.Sub.Channel() {
  188. payload := interface{}(fwd.Payload)
  189. var err error
  190. loadedPlugins.Lock.Lock()
  191. for pluginName, plugin := range loadedPlugins.List {
  192. payload, err = plugin.HandleMsg(payload)
  193. if err != nil {
  194. log.Panic(pluginName)
  195. }
  196. }
  197. loadedPlugins.Lock.Unlock()
  198. go func() {
  199. wsc.Lock.Lock()
  200. defer wsc.Lock.Unlock()
  201. wsc.Conn.WriteJSON(payload)
  202. }()
  203. }
  204. log.Printf("forAllOnto for %v EXITING\n", wsc.Conn.RemoteAddr())
  205. }
  206. func registerNewClient(wsConn *websocket.Conn, channel string) {
  207. clientAddr := wsConn.RemoteAddr()
  208. wsClientsLock.Lock()
  209. defer wsClientsLock.Unlock()
  210. if curClient, ok := wsClients[clientAddr]; ok {
  211. log.Printf("already have conn for %v! closing it\n", clientAddr)
  212. curClient.Lock.Lock()
  213. curClient.Conn.Close()
  214. curClient.Lock.Unlock()
  215. }
  216. wsClients[clientAddr] = wsClient{wsConn, new(sync.Mutex), redisDefaultClient.Subscribe(channel)}
  217. log.Printf("ws client %v connected\n", clientAddr)
  218. go readUntilClose(wsConn)
  219. go forwardAllOnto(wsClients[clientAddr])
  220. }
  221. func refreshHandler(w http.ResponseWriter, req *http.Request) {
  222. if req.Method != "OPTIONS" {
  223. return
  224. }
  225. if rhpAuthHeader, ok := req.Header["X-Rhp-Auth"]; ok {
  226. expectAuth := fmt.Sprintf("%d", time.Now().Unix()/10)
  227. if expectAuth == rhpAuthHeader[0] {
  228. log.Printf("valid refresh request from %v, running\n", req.RemoteAddr)
  229. loadUsers()
  230. }
  231. }
  232. }
  233. func websocketHandler(w http.ResponseWriter, req *http.Request) {
  234. wsConn, err := wsUpgrader.Upgrade(w, req, nil)
  235. if err != nil {
  236. log.Printf("websocketHandler upgrade failed: %v\n", err)
  237. return
  238. }
  239. okReqUUID, err := uuid.Parse(req.URL.RawQuery)
  240. if err != nil {
  241. log.Printf("bad ws query '%s'\n", req.URL.RawQuery)
  242. log.Println(req)
  243. return
  244. }
  245. gSubscribeHandler.Lock.Lock()
  246. defer gSubscribeHandler.Lock.Unlock()
  247. if pendingConn, ok := gSubscribeHandler.Pending[okReqUUID]; ok {
  248. if strings.Split(wsConn.RemoteAddr().String(), ":")[0] == strings.Split(pendingConn.Addr, ":")[0] {
  249. go registerNewClient(wsConn, pendingConn.Channel)
  250. delete(gSubscribeHandler.Pending, okReqUUID)
  251. } else {
  252. log.Printf("bad addr match %s vs %s\n", wsConn.RemoteAddr(), pendingConn.Addr)
  253. }
  254. } else {
  255. log.Printf("bad pending connection '%v'\n", okReqUUID)
  256. }
  257. }
  // parseJSON decodes the first JSON value in the file at path into intoObj,
// following json.Decoder.Decode semantics (intoObj should be a pointer). Any
// data after that first value is not read. An open or decode failure is
// reported on stderr and also returned.
func parseJSON(path string, intoObj interface{}) error {
	f, openErr := os.Open(path)
	if openErr != nil {
		fmt.Fprintf(os.Stderr, "parseJSON unable to open '%s': %v\n", path, openErr)
		return openErr
	}
	defer f.Close()
	if decErr := json.NewDecoder(f).Decode(intoObj); decErr != nil {
		fmt.Fprintf(os.Stderr, "parseJSON failed to decode: %v\n", decErr)
		return decErr
	}
	return nil
}
  273. func loadPlugin(path string) (*rhpPluginImpl, error) {
  274. ifaceType := reflect.TypeOf((*RhpPlugin)(nil)).Elem()
  275. pluginLoad, err := plugin.Open(path)
  276. if err != nil {
  277. return nil, err
  278. }
  279. // without stubbing the fields, reflect.ValueOf(...).Elem().FieldByName(...) below will return nil
  280. newPlugin := newRhpPluginImpl()
  281. // for each method declared in the interface, look for the same-named concrete defintion
  282. // in the loaded plugin. if that exists, find the field in the concrete implementation
  283. // instance (rhpPluginImpl) and set each function pointer accordingly
  284. for i := 0; i < ifaceType.NumMethod(); i++ {
  285. methodName := ifaceType.Method(i).Name
  286. pluginMethod, err := pluginLoad.Lookup(methodName)
  287. if err != nil {
  288. return nil, err
  289. }
  290. implValue := reflect.ValueOf(&newPlugin).Elem()
  291. if implValue.IsZero() {
  292. return nil, fmt.Errorf("unable to get value of concrete impl")
  293. }
  294. implElem := implValue.FieldByName(methodName)
  295. if implElem.IsZero() {
  296. return nil, fmt.Errorf("unable to set value on concrete impl")
  297. }
  298. // must .Convert to the target type (implElem.Interface()), else will panic with a strangely-worded error:
  299. // "reflect.Set: value of type T is not assignable to type T"
  300. // (not a typo: the 'from' and 'to' types in the error message will be exactly the same, because
  301. // indeed if we've made it this far the types will match, hence why .Convert() succeeds!)
  302. implElem.Set(reflect.ValueOf(pluginMethod).Convert(reflect.TypeOf(implElem.Interface())))
  303. }
  304. return &newPlugin, nil
  305. }
  306. func loadPlugins(fromPath string) (rhpPluginMapT, error) {
  307. retVal := rhpPluginMapT{}
  308. // we never return a non-nil error from within the walk function so as to allow .Walk() to continue;
  309. // there is the special return filepath.SkipDir, but it will cause Walk to skip remaining files,
  310. // which isn't what we want either. only in the case that `err` is already non-nil do we return non-nil.
  311. err := filepath.Walk(filepath.ToSlash(fromPath), func(path string, info os.FileInfo, err error) error {
  312. if err != nil {
  313. log.Printf("filepath.Walk errored on entry: '%s' -> %v", path, err)
  314. return err
  315. }
  316. if filepath.Ext(path) != ".so" {
  317. return nil
  318. }
  319. log.Printf("found %s, checking for compatibility...", filepath.Base(path))
  320. newPlugin, err := loadPlugin(path)
  321. if err != nil {
  322. log.Printf("failed to load %s: %v", path, err)
  323. return nil
  324. }
  325. pBaseName := strings.Replace(filepath.Base(path), ".so", "", 1)
  326. log.Printf("loaded compatible plugin %s@%s", pBaseName, newPlugin.Version())
  327. retVal[pBaseName] = newPlugin
  328. return nil
  329. })
  330. return retVal, err
  331. }
  332. func loadUsers() {
  333. usersMapLock.Lock()
  334. defer usersMapLock.Unlock()
  335. err := parseJSON(g_usersFile, &usersMap)
  336. if err != nil {
  337. log.Panic(err.Error())
  338. }
  339. log.Printf("found %d valid users\n", len(usersMap))
  340. }
  341. func main() {
  342. listenPort := flag.Uint("port", defaultListenPort, "http listen port")
  343. listenHost := flag.String("listen", defaultListenHost, "http listen host")
  344. redisPort := flag.Uint("redis-port", defaultRedisPort, "redis server port")
  345. redisHost := flag.String("redis-host", defaultRedisHost, "redis server host")
  346. pluginsPath := flag.String("plugins", defaultPluginsPath, "plugins path")
  347. usersFile := flag.String("users", defaultUsersFile, "users JSON file")
  348. flag.Parse()
  349. if listenHost == nil || listenPort == nil || *listenPort < 1024 || *listenPort > 65535 {
  350. log.Panic("listen spec")
  351. }
  352. if redisPort == nil || redisHost == nil || *redisPort < 0 || *redisPort > 65535 {
  353. log.Panic("redis spec")
  354. }
  355. redisAuth := os.Getenv("REDIS_LOCAL_PWD")
  356. if len(redisAuth) == 0 {
  357. log.Panic("Need auth")
  358. }
  359. redisOptions.Addr = fmt.Sprintf("%s:%d", *redisHost, *redisPort)
  360. redisOptions.Password = redisAuth
  361. rc := redis.NewClient(&redisOptions)
  362. _, err := rc.Ping().Result()
  363. if err != nil {
  364. log.Panic("Ping")
  365. }
  366. log.Printf("connected to redis://%s\n", redisOptions.Addr)
  367. redisDefaultClient = rc
  368. g_usersFile = *usersFile
  369. loadUsers()
  370. loadedPlugins.Lock.Lock()
  371. loadedPlugins.List, err = loadPlugins(*pluginsPath)
  372. loadedPlugins.Lock.Unlock()
  373. if err != nil {
  374. log.Fatalf("plugin load failed: %v", err)
  375. }
  376. gSubscribeHandler = new(subscribeHandler)
  377. http.Handle("/sub/", gSubscribeHandler)
  378. http.Handle("/list/", gSubscribeHandler)
  379. http.HandleFunc("/ws/sub", websocketHandler)
  380. http.HandleFunc("/refresh", refreshHandler)
  381. listenSpec := fmt.Sprintf("%s:%d", *listenHost, *listenPort)
  382. log.Printf("listening on %s\n", listenSpec)
  383. http.ListenAndServe(listenSpec, nil)
  384. }