main.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. /*
  2. Package main contains command "duplicates", which detects documents containing
  3. a duplicated field in a MongoDB collection.
  4. It takes its configuration from environment variables: refer to file `example.env`
  5. for a sample.
  6. (c) 2024 Ouest Systèmes Informatiques
  7. Licensed under the Apache 2.0 license.
  8. */
  9. package main
  10. import (
  11. "context"
  12. "flag"
  13. "fmt"
  14. "io"
  15. "log"
  16. "os"
  17. "slices"
  18. "go.mongodb.org/mongo-driver/bson"
  19. "go.mongodb.org/mongo-driver/mongo"
  20. "go.mongodb.org/mongo-driver/mongo/options"
  21. "gopkg.in/yaml.v3"
  22. )
  23. const (
  24. defaultMongoDBURI = "mongodb://localhost:27017"
  25. defaultDatabase = "test"
  26. defaultCollection = "test"
  27. defaultField = "email"
  28. defaultCommand = "check"
  29. seedCommand = "seed"
  30. )
  31. type conf struct {
  32. dbURI string
  33. client *mongo.Client
  34. dbName string
  35. collName string
  36. command string
  37. field string
  38. }
  39. func configure(ctx context.Context, name string, args []string) (*conf, error) {
  40. var (
  41. conf conf
  42. err error
  43. ok bool
  44. )
  45. if conf.dbURI, ok = os.LookupEnv("MONGODB_URI"); !ok {
  46. conf.dbURI = defaultMongoDBURI
  47. }
  48. if conf.dbName, ok = os.LookupEnv("MONGODB_DB"); !ok {
  49. conf.dbName = defaultDatabase
  50. }
  51. if conf.collName, ok = os.LookupEnv("MONGODB_COLLECTION"); !ok {
  52. conf.collName = defaultCollection
  53. }
  54. if conf.field, ok = os.LookupEnv("MONGODB_FIELD"); !ok {
  55. conf.field = defaultField
  56. }
  57. conf.client, err = mongo.Connect(ctx, options.Client().ApplyURI(conf.dbURI))
  58. if err != nil {
  59. return nil, fmt.Errorf("failed to connect to MongoDB: %v", err)
  60. }
  61. fs := flag.NewFlagSet(name, flag.ContinueOnError)
  62. fs.StringVar(&conf.command, "command", defaultCommand, "sub-command to run")
  63. if err := fs.Parse(args); err != nil {
  64. return nil, fmt.Errorf("failed to parse arguments: %v", err)
  65. }
  66. if !slices.Contains([]string{defaultCommand, seedCommand}, conf.command) {
  67. return nil, fmt.Errorf("unknown command %q", conf.command)
  68. }
  69. return &conf, nil
  70. }
  71. func user(n int) string {
  72. return fmt.Sprintf("user%d@example.com", n)
  73. }
  74. func seed(ctx context.Context, coll *mongo.Collection, field string) error {
  75. // 1. Ensure empty collection on startup.
  76. if err := coll.Drop(ctx); err != nil {
  77. return fmt.Errorf("seed/dropping collection: %w", err)
  78. }
  79. // 2. Insert non duplicate elements
  80. for i := range 5 {
  81. if _, err := coll.InsertOne(ctx, bson.D{{Key: field, Value: user(i)}}); err != nil {
  82. return fmt.Errorf("seed/inserting initial doc %d: %w", i, err)
  83. }
  84. }
  85. // 3. Insert duplicate elements: 3*1, 2*2
  86. for _, i := range []int{1, 1, 2} {
  87. if _, err := coll.InsertOne(ctx, bson.D{{Key: field, Value: user(i)}}); err != nil {
  88. return fmt.Errorf("seed/inserting duplicate doc %d: %w", i, err)
  89. }
  90. }
  91. return nil
  92. }
  93. func check(ctx context.Context, coll *mongo.Collection, field string) (map[string]int, error) {
  94. docs, err := coll.Distinct(ctx, field, bson.D{}, nil)
  95. dups := make(map[string]int)
  96. if err != nil {
  97. return nil, fmt.Errorf("check/distinct: %w", err)
  98. }
  99. for _, doc := range docs {
  100. n, err := coll.CountDocuments(ctx, bson.D{{Key: field, Value: doc}}, nil)
  101. if err != nil {
  102. return nil, fmt.Errorf("check/counting: %w", err)
  103. }
  104. if n > 1 {
  105. dups[doc.(string)] = int(n)
  106. }
  107. }
  108. return dups, nil
  109. }
  110. // testableMain is extracted for testability
  111. func testableMain(ctx context.Context, w io.Writer, logger *log.Logger, name string, args []string) (exit int) {
  112. config, err := configure(ctx, name, args)
  113. if err != nil {
  114. exit = 1
  115. logger.Println(err)
  116. return
  117. }
  118. defer func() {
  119. if err := config.client.Disconnect(ctx); err != nil {
  120. exit = 2
  121. logger.Println(err)
  122. }
  123. }()
  124. coll := config.client.Database(config.dbName).Collection(config.collName)
  125. if config.command == seedCommand {
  126. if err := seed(ctx, coll, config.field); err != nil {
  127. exit = 3
  128. logger.Println(err)
  129. return
  130. }
  131. }
  132. dups, err := check(ctx, coll, config.field)
  133. if err != nil {
  134. exit = 4
  135. logger.Println(err)
  136. return
  137. }
  138. if err := yaml.NewEncoder(w).Encode(dups); err != nil {
  139. exit = 5
  140. logger.Println(err)
  141. return
  142. }
  143. // Allow a non-zero exit in the deferred disconnect.
  144. exit = 0
  145. return
  146. }
  147. func main() {
  148. ctx := context.Background()
  149. logger := log.Default()
  150. name, args := os.Args[0], os.Args[1:]
  151. out := os.Stdout
  152. os.Exit(testableMain(ctx, out, logger, name, args))
  153. }