123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170 |
- /*
- Package main contains command "duplicates", which detects documents containing
- a duplicated field in a MongoDB collection.
- It takes its configuration from environment variables: refer to file `example.env`
- for a sample.
- (c) 2024 Ouest Systèmes Informatiques
- Licensed under the Apache 2.0 license.
- */
- package main
- import (
- "context"
- "flag"
- "fmt"
- "io"
- "log"
- "os"
- "slices"
- "go.mongodb.org/mongo-driver/bson"
- "go.mongodb.org/mongo-driver/mongo"
- "go.mongodb.org/mongo-driver/mongo/options"
- "gopkg.in/yaml.v3"
- )
// Fallback settings used when the matching MONGODB_* environment variable
// is not set.
const (
	defaultMongoDBURI = "mongodb://localhost:27017"
	defaultDatabase   = "test"
	defaultCollection = "test"
	defaultField      = "email"
)

// Sub-commands accepted by the -command flag.
const (
	// defaultCommand is run when no -command flag is given.
	defaultCommand = "check"
	// seedCommand refills the collection with sample data before checking.
	seedCommand = "seed"
)
- type conf struct {
- dbURI string
- client *mongo.Client
- dbName string
- collName string
- command string
- field string
- }
- func configure(ctx context.Context, name string, args []string) (*conf, error) {
- var (
- conf conf
- err error
- ok bool
- )
- if conf.dbURI, ok = os.LookupEnv("MONGODB_URI"); !ok {
- conf.dbURI = defaultMongoDBURI
- }
- if conf.dbName, ok = os.LookupEnv("MONGODB_DB"); !ok {
- conf.dbName = defaultDatabase
- }
- if conf.collName, ok = os.LookupEnv("MONGODB_COLLECTION"); !ok {
- conf.collName = defaultCollection
- }
- if conf.field, ok = os.LookupEnv("MONGODB_FIELD"); !ok {
- conf.field = defaultField
- }
- conf.client, err = mongo.Connect(ctx, options.Client().ApplyURI(conf.dbURI))
- if err != nil {
- return nil, fmt.Errorf("failed to connect to MongoDB: %v", err)
- }
- fs := flag.NewFlagSet(name, flag.ContinueOnError)
- fs.StringVar(&conf.command, "command", defaultCommand, "sub-command to run")
- if err := fs.Parse(args); err != nil {
- return nil, fmt.Errorf("failed to parse arguments: %v", err)
- }
- if !slices.Contains([]string{defaultCommand, seedCommand}, conf.command) {
- return nil, fmt.Errorf("unknown command %q", conf.command)
- }
- return &conf, nil
- }
// user returns the sample e-mail address numbered n, as stored by seed.
func user(n int) string {
	const addressFormat = "user%d@example.com"
	return fmt.Sprintf(addressFormat, n)
}
- func seed(ctx context.Context, coll *mongo.Collection, field string) error {
- // 1. Ensure empty collection on startup.
- if err := coll.Drop(ctx); err != nil {
- return fmt.Errorf("seed/dropping collection: %w", err)
- }
- // 2. Insert non duplicate elements
- for i := range 5 {
- if _, err := coll.InsertOne(ctx, bson.D{{Key: field, Value: user(i)}}); err != nil {
- return fmt.Errorf("seed/inserting initial doc %d: %w", i, err)
- }
- }
- // 3. Insert duplicate elements: 3*1, 2*2
- for _, i := range []int{1, 1, 2} {
- if _, err := coll.InsertOne(ctx, bson.D{{Key: field, Value: user(i)}}); err != nil {
- return fmt.Errorf("seed/inserting duplicate doc %d: %w", i, err)
- }
- }
- return nil
- }
- func check(ctx context.Context, coll *mongo.Collection, field string) (map[string]int, error) {
- docs, err := coll.Distinct(ctx, field, bson.D{}, nil)
- dups := make(map[string]int)
- if err != nil {
- return nil, fmt.Errorf("check/distinct: %w", err)
- }
- for _, doc := range docs {
- n, err := coll.CountDocuments(ctx, bson.D{{Key: field, Value: doc}}, nil)
- if err != nil {
- return nil, fmt.Errorf("check/counting: %w", err)
- }
- if n > 1 {
- dups[doc.(string)] = int(n)
- }
- }
- return dups, nil
- }
- // testableMain is extracted for testability
- func testableMain(ctx context.Context, w io.Writer, logger *log.Logger, name string, args []string) (exit int) {
- config, err := configure(ctx, name, args)
- if err != nil {
- exit = 1
- logger.Println(err)
- return
- }
- defer func() {
- if err := config.client.Disconnect(ctx); err != nil {
- exit = 2
- logger.Println(err)
- }
- }()
- coll := config.client.Database(config.dbName).Collection(config.collName)
- if config.command == seedCommand {
- if err := seed(ctx, coll, config.field); err != nil {
- exit = 3
- logger.Println(err)
- return
- }
- }
- dups, err := check(ctx, coll, config.field)
- if err != nil {
- exit = 4
- logger.Println(err)
- return
- }
- if err := yaml.NewEncoder(w).Encode(dups); err != nil {
- exit = 5
- logger.Println(err)
- return
- }
- // Allow a non-zero exit in the deferred disconnect.
- exit = 0
- return
- }
- func main() {
- ctx := context.Background()
- logger := log.Default()
- name, args := os.Args[0], os.Args[1:]
- out := os.Stdout
- os.Exit(testableMain(ctx, out, logger, name, args))
- }
|