Frederic G. MARAND 4 years ago
parent
commit
27bd4a91bc
3 changed files with 136 additions and 15 deletions
  1. 12 0
      .idea/runConfigurations/Test_part_8.xml
  2. 41 15
      part8/main.go
  3. 83 0
      part8/peers_test.go

+ 12 - 0
.idea/runConfigurations/Test_part_8.xml

@@ -0,0 +1,12 @@
+<component name="ProjectRunConfigurationManager">
+  <configuration default="false" name="Test part 8" type="GoTestRunConfiguration" factoryName="Go Test">
+    <module name="whispering_gophers" />
+    <working_directory value="$PROJECT_DIR$/part8" />
+    <framework value="gotest" />
+    <kind value="PACKAGE" />
+    <package value="code.osinet.fr/fgm/whispering_gophers/part8" />
+    <directory value="$PROJECT_DIR$/part7" />
+    <filePath value="$PROJECT_DIR$/" />
+    <method v="2" />
+  </configuration>
+</component>

+ 41 - 15
part8/main.go

@@ -19,12 +19,13 @@ import (
 	"os"
 	"sync"
 
-	"github.com/campoy/whispering-gophers/util"
+	"code.osinet.fr/fgm/whispering_gophers/util"
 )
 
 var (
-	peerAddr = flag.String("peer", "", "peer host:port")
-	self     string
+	listenAddr = flag.String("listen", "", "peer host:port")
+	peerAddr   = flag.String("peer", "", "peer host:port")
+	self       string
 )
 
 type Message struct {
@@ -35,7 +36,14 @@ type Message struct {
 func main() {
 	flag.Parse()
 
-	l, err := util.Listen()
+	var l net.Listener
+	var err error
+	// Create a new listener using util.Listen and put it in a variable named l.
+	if *listenAddr == "" {
+		l, err = util.ListenOnFirstUsableInterface()
+	} else {
+		l, err = net.Listen("tcp4", *listenAddr)
+	}
 	if err != nil {
 		log.Fatal(err)
 	}
@@ -54,13 +62,16 @@ func main() {
 	}
 }
 
-// TODO: create a global shared Peers instance
-
+// Create a global shared Peers instance
 type Peers struct {
 	m  map[string]chan<- Message
 	mu sync.RWMutex
 }
 
+var peers = Peers{
+	m: make(map[string]chan<- Message),
+}
+
 // Add creates and returns a new channel for the given peer address.
 // If an address already exists in the registry, it returns nil.
 func (p *Peers) Add(addr string) <-chan Message {
@@ -93,9 +104,15 @@ func (p *Peers) List() []chan<- Message {
 }
 
 func broadcast(m Message) {
-	for /* TODO: Range over the list of peers */ {
-		// TODO: Send a message to the channel, but don't block.
+	/* Range over the list of peers */
+	for i, peer := range peers.List() {
+		// Send a message to the channel, but don't block.
 		// Hint: Select is your friend.
+		select {
+		case peer <- m:
+		default:
+			log.Printf("Sending to peer %d would have blocked.\n", i)
+		}
 	}
 }
 
@@ -110,8 +127,8 @@ func serve(c net.Conn) {
 			return
 		}
 
-		// TODO: Launch dial in a new goroutine, to connect to the address in the message's Addr field.
-
+		// Launch dial in a new goroutine, to connect to the address in the message's Addr field.
+		go dial(m.Addr)
 		fmt.Printf("%#v\n", m)
 	}
 }
@@ -128,15 +145,24 @@ func readInput() {
 	if err := s.Err(); err != nil {
 		log.Fatal(err)
 	}
+	os.Exit(0)
 }
 
 func dial(addr string) {
-	// TODO: If dialing self, return.
+	// If dialing self, return.
+	if addr == self {
+		return
+	}
+
+	// Add the address to the peers map.
+	ch := peers.Add(addr)
+	// If you get a nil channel the peer is already connected, return.
+	if ch == nil {
+		return
+	}
 
-	// TODO: Add the address to the peers map.
-	// TODO: If you get a nil channel the peer is already connected, return.
-	// TODO: Remove the address from peers map when this function returns
-	//       (use defer).
+	// Remove the address from peers map when this function returns
+	defer peers.Remove(addr)
 
 	c, err := net.Dial("tcp", addr)
 	if err != nil {

+ 83 - 0
part8/peers_test.go

@@ -0,0 +1,83 @@
+package main
+
+import (
+	"testing"
+	"time"
+)
+
+func TestPeers(t *testing.T) {
+	peers := &Peers{m: make(map[string]chan<- Message)}
+	done := make(chan bool, 1)
+
+	var chA, chB <-chan Message
+	go func() {
+		defer func() { done <- true }()
+		if chA = peers.Add("a"); chA == nil {
+			t.Fatal(`peers.Add("a") returned nil, want channel`)
+		}
+	}()
+	go func() {
+		defer func() { done <- true }()
+		if chB = peers.Add("b"); chB == nil {
+			t.Fatal(`peers.Add("b") returned nil, want channel`)
+		}
+	}()
+	<-done
+	<-done
+	if chA == chB {
+		t.Fatal(`peers.Add("b") returned same channel as "a"!`)
+	}
+	if ch := peers.Add("a"); ch != nil {
+		t.Fatal(`second peers.Add("a") returned non-nil channel, want nil`)
+	}
+	if ch := peers.Add("b"); ch != nil {
+		t.Fatal(`second peers.Add("b") returned non-nil channel, want nil`)
+	}
+
+	list := peers.List()
+	if len(list) != 2 {
+		t.Fatalf("peers.List() returned a list of length %d, want 2", len(list))
+	}
+
+	go func() {
+		for _, ch := range list {
+			select {
+			case ch <- Message{Body: "foo"}:
+			case <-time.After(10 * time.Millisecond):
+			}
+		}
+		done <- true
+	}()
+	select {
+	case m := <-chA:
+		if m.Body != "foo" {
+			t.Fatalf("received message %q, want %q", m.Body, "foo")
+		}
+	case <-done:
+		t.Fatal(`didn't receive message on "a" channel`)
+	}
+	<-done
+
+	peers.Remove("a")
+
+	list = peers.List()
+	if len(list) != 1 {
+		t.Fatalf("peers.List() returned a list of length %d, want 1", len(list))
+	}
+
+	go func() {
+		select {
+		case list[0] <- Message{Body: "bar"}:
+		case <-time.After(10 * time.Millisecond):
+		}
+		done <- true
+	}()
+	select {
+	case m := <-chB:
+		if m.Body != "bar" {
+			t.Fatalf("received message %q, want %q", m.Body, "bar")
+		}
+	case <-done:
+		t.Fatal(`didn't receive message on "b" channel`)
+	}
+}