main.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. // Skeleton to part 9 of the Whispering Gophers code lab.
  2. //
  3. // This program extends part 8.
  4. //
  5. // It connects to the peer specified by -peer.
  6. // It accepts connections from peers and receives messages from them.
  7. // When it sees a peer with an address it hasn't seen before, it makes a
  8. // connection to that peer.
  9. // It adds an ID field containing a random string to each outgoing message.
  10. // When it receives a message with an ID it hasn't seen before, it broadcasts
  11. // that message to all connected peers.
  12. //
  13. package main
  14. import (
  15. "bufio"
  16. "encoding/json"
  17. "flag"
  18. "fmt"
  19. "log"
  20. "net"
  21. "os"
  22. "sync"
  23. "github.com/campoy/whispering-gophers/util"
  24. )
  25. var (
  26. peerAddr = flag.String("peer", "", "peer host:port")
  27. self string
  28. )
  29. type Message struct {
  30. // TODO: add ID field
  31. Addr string
  32. Body string
  33. }
  34. func main() {
  35. flag.Parse()
  36. l, err := util.Listen()
  37. if err != nil {
  38. log.Fatal(err)
  39. }
  40. self = l.Addr().String()
  41. log.Println("Listening on", self)
  42. go dial(*peerAddr)
  43. go readInput()
  44. for {
  45. c, err := l.Accept()
  46. if err != nil {
  47. log.Fatal(err)
  48. }
  49. go serve(c)
  50. }
  51. }
  52. var peers = &Peers{m: make(map[string]chan<- Message)}
  53. type Peers struct {
  54. m map[string]chan<- Message
  55. mu sync.RWMutex
  56. }
  57. // Add creates and returns a new channel for the given peer address.
  58. // If an address already exists in the registry, it returns nil.
  59. func (p *Peers) Add(addr string) <-chan Message {
  60. p.mu.Lock()
  61. defer p.mu.Unlock()
  62. if _, ok := p.m[addr]; ok {
  63. return nil
  64. }
  65. ch := make(chan Message)
  66. p.m[addr] = ch
  67. return ch
  68. }
  69. // Remove deletes the specified peer from the registry.
  70. func (p *Peers) Remove(addr string) {
  71. p.mu.Lock()
  72. defer p.mu.Unlock()
  73. delete(p.m, addr)
  74. }
  75. // List returns a slice of all active peer channels.
  76. func (p *Peers) List() []chan<- Message {
  77. p.mu.RLock()
  78. defer p.mu.RUnlock()
  79. l := make([]chan<- Message, 0, len(p.m))
  80. for _, ch := range p.m {
  81. l = append(l, ch)
  82. }
  83. return l
  84. }
  85. func broadcast(m Message) {
  86. for _, ch := range peers.List() {
  87. select {
  88. case ch <- m:
  89. default:
  90. // Okay to drop messages sometimes.
  91. }
  92. }
  93. }
  94. func serve(c net.Conn) {
  95. defer c.Close()
  96. d := json.NewDecoder(c)
  97. for {
  98. var m Message
  99. err := d.Decode(&m)
  100. if err != nil {
  101. log.Println(err)
  102. return
  103. }
  104. // TODO: If this message has seen before, ignore it.
  105. fmt.Printf("%#v\n", m)
  106. broadcast(m)
  107. go dial(m.Addr)
  108. }
  109. }
  110. func readInput() {
  111. s := bufio.NewScanner(os.Stdin)
  112. for s.Scan() {
  113. m := Message{
  114. // TODO: use util.RandomID to populate the ID field.
  115. Addr: self,
  116. Body: s.Text(),
  117. }
  118. // TODO: Mark the message ID as seen.
  119. broadcast(m)
  120. }
  121. if err := s.Err(); err != nil {
  122. log.Fatal(err)
  123. }
  124. }
  125. func dial(addr string) {
  126. if addr == self {
  127. return // Don't try to dial self.
  128. }
  129. ch := peers.Add(addr)
  130. if ch == nil {
  131. return // Peer already connected.
  132. }
  133. defer peers.Remove(addr)
  134. c, err := net.Dial("tcp", addr)
  135. if err != nil {
  136. log.Println(addr, err)
  137. return
  138. }
  139. defer c.Close()
  140. e := json.NewEncoder(c)
  141. for m := range ch {
  142. err := e.Encode(m)
  143. if err != nil {
  144. log.Println(addr, err)
  145. return
  146. }
  147. }
  148. }
  149. // TODO: Create a new map of seen message IDs and a mutex to protect it.
  150. // Seen returns true if the specified id has been seen before.
  151. // If not, it returns false and marks the given id as "seen".
  152. func Seen(id string) bool {
  153. // TODO: Get a write lock on the seen message IDs map and unlock it at before returning.
  154. // TODO: Check if the id has been seen before and return that later.
  155. // TODO: Mark the ID as seen in the map.
  156. }