diff --git a/api.go b/api.go index 402f14d..2477887 100644 --- a/api.go +++ b/api.go @@ -1,75 +1,75 @@ package main import ( - "net/http" - "encoding/json" + "encoding/json" + "net/http" - //"context" - "fmt" + //"context" + "fmt" - "github.com/go-chi/chi/v5" - "github.com/ggicci/httpin" + "github.com/ggicci/httpin" + "github.com/go-chi/chi/v5" "github.com/rs/zerolog/log" ) -const DEFAULT_RESULT_COUNT = 50; +const DEFAULT_RESULT_COUNT = 50 type GetTransactionPaginationInput struct { - ResultCount int `in:"query=result_count"` - PageNum int `in:"query=page_num"` + ResultCount int `in:"query=result_count"` + PageNum int `in:"query=page_num"` } func apiRouter() http.Handler { - r := chi.NewRouter() - //r.Use(ApiLoginRequired) - r.With( - httpin.NewInput(GetTransactionPaginationInput{}), - ).Get("/get_transactions", getTransactions) - r.Post("/new_transaction", newTransaction) - return r + r := chi.NewRouter() + //r.Use(ApiLoginRequired) + r.With( + httpin.NewInput(GetTransactionPaginationInput{}), + ).Get("/get_transactions", getTransactions) + r.Post("/new_transaction", newTransaction) + return r } func getTransactions(w http.ResponseWriter, req *http.Request) { - input := req.Context().Value(httpin.Input).(*GetTransactionPaginationInput) + input := req.Context().Value(httpin.Input).(*GetTransactionPaginationInput) - if input.ResultCount == 0 { - input.ResultCount = DEFAULT_RESULT_COUNT - } + if input.ResultCount == 0 { + input.ResultCount = DEFAULT_RESULT_COUNT + } - transactions := []Transaction{} + transactions := []Transaction{} - err := db_get_transactions(&transactions, input) + err := db_get_transactions(&transactions, input) - if err != nil { - - log.Fatal(). - Err(err). - Msg("Fatal error in getTransactions from db_get_transactions") - } + if err != nil { - for _, trns := range transactions { - //bytes, err := json.Marshal(trns) - bytes, err := json.MarshalIndent(trns, "", "\t") - if err != nil { - log.Fatal(). - Err(err). - Msg("Could not marshal json") - } - fmt.Fprintf(w, string(bytes)) - } + log.Fatal(). + Err(err). + Msg("Fatal error in getTransactions from db_get_transactions") + } + + for _, trns := range transactions { + //bytes, err := json.Marshal(trns) + bytes, err := json.MarshalIndent(trns, "", "\t") + if err != nil { + log.Fatal(). + Err(err). + Msg("Could not marshal json") + } + fmt.Fprintf(w, string(bytes)) + } } func newTransaction(w http.ResponseWriter, req *http.Request) { - decoder := json.NewDecoder(req.Body) - var t Transaction - err := decoder.Decode(&t) - if err != nil { - log.Fatal(). - Err(err). - Msg("Could not decode incoming post data") - } - //fmt.Fprintf(w, "New transaction created for Account: %d, with an Amount of: %s", - // t.Account, t.Amount) - db_new_transaction(t) + decoder := json.NewDecoder(req.Body) + var t Transaction + err := decoder.Decode(&t) + if err != nil { + log.Fatal(). + Err(err). + Msg("Could not decode incoming post data") + } + //fmt.Fprintf(w, "New transaction created for Account: %d, with an Amount of: %s", + // t.Account, t.Amount) + db_new_transaction(t) } diff --git a/db.go b/db.go index 6d40505..86eddd5 100644 --- a/db.go +++ b/db.go @@ -1,58 +1,58 @@ package main import ( - "fmt" + "fmt" - _ "github.com/lib/pq" - _ "github.com/mattn/go-sqlite3" - "github.com/jmoiron/sqlx" + "github.com/jmoiron/sqlx" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" "github.com/rs/zerolog/log" ) -func db_get_transactions(transactions *[]Transaction,r *GetTransactionPaginationInput) (error) { - db, err := sqlx.Connect(DB_TYPE, DB_CONNECTION_STRING) - if err != nil { - log.Fatal(). - Err(err). - Msg("Fatal error in db_get_transactions\nCannot connect to server") - } +func db_get_transactions(transactions *[]Transaction, r *GetTransactionPaginationInput) error { + db, err := sqlx.Connect(DB_TYPE, DB_CONNECTION_STRING) + if err != nil { + log.Fatal(). + Err(err). + Msg("Fatal error in db_get_transactions\nCannot connect to server") + } - defer db.Close() + defer db.Close() - err = db.Select(transactions, - "SELECT trns_id, trns_amount, trns_description, " + - "trns_account, trns_bucket, trns_date " + - fmt.Sprintf("FROM %stransactions ORDER BY trns_id DESC ", DB_SCHEMA) + - fmt.Sprintf("LIMIT %d OFFSET %d", - r.ResultCount, r.PageNum * r.ResultCount )) - if err != nil { - return err - } - return nil + err = db.Select(transactions, + "SELECT trns_id, trns_amount, trns_description, "+ + "trns_account, trns_bucket, trns_date "+ + fmt.Sprintf("FROM %stransactions ORDER BY trns_id DESC ", DB_SCHEMA)+ + fmt.Sprintf("LIMIT %d OFFSET %d", + r.ResultCount, r.PageNum*r.ResultCount)) + if err != nil { + return err + } + return nil } -func db_new_transaction(transaction Transaction) (error) { - db, err := sqlx.Connect(DB_TYPE, DB_CONNECTION_STRING) - if err != nil { - log.Info() - log.Fatal(). - Err(err). - Msg("Fatal error in db_get_transactions\nCannot connect to server") - } +func db_new_transaction(transaction Transaction) error { + db, err := sqlx.Connect(DB_TYPE, DB_CONNECTION_STRING) + if err != nil { + log.Info() + log.Fatal(). + Err(err). + Msg("Fatal error in db_get_transactions\nCannot connect to server") + } - defer db.Close() + defer db.Close() - log.Debug().Msgf("%#v", transaction) + log.Debug().Msgf("%#v", transaction) - _, err = db.NamedExec( - fmt.Sprintf("INSERT INTO %stransactions", DB_SCHEMA) + - "(trns_amount, trns_description, trns_account, trns_bucket, trns_date)" + - "VALUES (:trns_amount, :trns_description, :trns_account, :trns_bucket, :trns_date)", - transaction) - if err != nil { - log.Fatal(). - Err(err). - Msg("Could not exec insert db query") - } - return nil + _, err = db.NamedExec( + fmt.Sprintf("INSERT INTO %stransactions", DB_SCHEMA)+ + "(trns_amount, trns_description, trns_account, trns_bucket, trns_date)"+ + "VALUES (:trns_amount, :trns_description, :trns_account, :trns_bucket, :trns_date)", + transaction) + if err != nil { + log.Fatal(). + Err(err). + Msg("Could not exec insert db query") + } + return nil } diff --git a/main.go b/main.go index 7fc1c77..da3bde8 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "nickiel.net/recount_server/tests" + "nickiel.net/recount_server/web" "database/sql" "net/http" @@ -25,80 +26,79 @@ var DB_CONNECTION_STRING string = "user=rcntuser password=Devel@pmentPa$$w0rd ho // "json:"json_code_name,omitempty"" (omit empty) // if you use `json:"-"` it doesn't encode it type Transaction struct { - Id int `db:"trns_id" json:"Id"` - Amount string `db:"trns_amount" json:"Amount"` - Description sql.NullString `db:"trns_description" json:"Description"` - Account int `db:"trns_account" json:"Account"` - Bucket sql.NullInt64 `db:"trns_bucket" json:"Bucket"` - Date time.Time `db:"trns_date" json:"TransactionDate"` + Id int `db:"trns_id" json:"Id"` + Amount string `db:"trns_amount" json:"Amount"` + Description sql.NullString `db:"trns_description" json:"Description"` + Account int `db:"trns_account" json:"Account"` + Bucket sql.NullInt64 `db:"trns_bucket" json:"Bucket"` + Date time.Time `db:"trns_date" json:"TransactionDate"` } func hello(w http.ResponseWriter, req *http.Request) { - fmt.Fprintf(w, "hello\n") + fmt.Fprintf(w, "hello\n") } func headers(w http.ResponseWriter, req *http.Request) { - for name, headers := range req.Header { - for _, h := range headers { - fmt.Fprintf(w, "%v: %v\n", name, h) - } - } + for name, headers := range req.Header { + for _, h := range headers { + fmt.Fprintf(w, "%v: %v\n", name, h) + } + } } func main() { - zerolog.TimeFieldFormat = zerolog.TimeFormatUnix - zerolog.SetGlobalLevel(zerolog.InfoLevel) - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + zerolog.TimeFieldFormat = zerolog.TimeFormatUnix + zerolog.SetGlobalLevel(zerolog.InfoLevel) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) - var debugFlag = flag.Bool("d", false, "whether to enable debug mode") - var traceFlag = flag.Bool("t", false, "whether to trace logging") + var debugFlag = flag.Bool("d", false, "whether to enable debug mode") + var traceFlag = flag.Bool("t", false, "whether to trace logging") - flag.Parse() + flag.Parse() - if *traceFlag { - zerolog.SetGlobalLevel(zerolog.TraceLevel) - log.Debug().Msg("Enabling trace level debugging") - } + if *traceFlag { + zerolog.SetGlobalLevel(zerolog.TraceLevel) + log.Debug().Msg("Enabling trace level debugging") + } - if *debugFlag { - if !*traceFlag { - zerolog.SetGlobalLevel(zerolog.DebugLevel) - } - log.Debug().Msg("Is debugging") - DB_TYPE = "sqlite3" - DB_SCHEMA = "" - DB_CONNECTION_STRING = "test.db" - debug_mode.Init_testdb(DB_TYPE, DB_CONNECTION_STRING) - } + if *debugFlag { + if !*traceFlag { + zerolog.SetGlobalLevel(zerolog.DebugLevel) + } + log.Debug().Msg("Is debugging") + DB_TYPE = "sqlite3" + DB_SCHEMA = "" + DB_CONNECTION_STRING = "test.db" + debug_mode.Init_testdb(DB_TYPE, DB_CONNECTION_STRING) + } + debug_mode.SetLogLevel(zerolog.GlobalLevel()) - debug_mode.SetLogLevel(zerolog.GlobalLevel()) + log.Info().Msg("starting server") - log.Info().Msg("starting server") + r := chi.NewRouter() - r := chi.NewRouter() + // A good base middleware stack + r.Use(middleware.RequestID) + r.Use(middleware.RealIP) + r.Use(middleware.Recoverer) + r.Use(middleware.Logger) - // A good base middleware stack - r.Use(middleware.RequestID) - r.Use(middleware.RealIP) - r.Use(middleware.Recoverer) - r.Use(middleware.Logger) + // Set a timeout value on the request context (ctx), that will signal + // through ctx.Done() that the request has timed out and further + // processing should be stopped. + //r.Use(middleware.Timeout(60 * time.Second)) - // Set a timeout value on the request context (ctx), that will signal - // through ctx.Done() that the request has timed out and further - // processing should be stopped. - //r.Use(middleware.Timeout(60 * time.Second)) + r.Get("/headers", headers) + r.Mount("/", web.WebRouter()) + r.Mount("/api", apiRouter()) - r.Get("/", hello) - r.Get("/headers", headers) - r.Mount("/api", apiRouter()) + err := http.ListenAndServe(":8090", r) + if err != nil { + log.Fatal(). + Err(err). + Msg("Could not open server connection") + } - err := http.ListenAndServe(":8090", r) - if err != nil { - log.Fatal(). - Err(err). - Msg("Could not open server connection") - } - - //fmt.Println("Hello World") + //fmt.Println("Hello World") } diff --git a/tests/testdb.go b/tests/testdb.go index 67f9d94..237c6d5 100644 --- a/tests/testdb.go +++ b/tests/testdb.go @@ -1,9 +1,9 @@ package debug_mode import ( - "os" - "encoding/json" "database/sql" + "encoding/json" + "os" "time" "github.com/jmoiron/sqlx" @@ -13,50 +13,50 @@ import ( ) type Transaction struct { - Id int `db:"trns_id" json:"Id"` - Amount string `db:"trns_amount" json:"Amount"` - Description sql.NullString `db:"trns_description" json:"Description"` - Account int `db:"trns_account" json:"Account"` - Bucket sql.NullInt64 `db:"trns_bucket" json:"Bucket"` - Date time.Time `db:"trns_date" json:"TransactionDate"` + Id int `db:"trns_id" json:"Id"` + Amount string `db:"trns_amount" json:"Amount"` + Description sql.NullString `db:"trns_description" json:"Description"` + Account int `db:"trns_account" json:"Account"` + Bucket sql.NullInt64 `db:"trns_bucket" json:"Bucket"` + Date time.Time `db:"trns_date" json:"TransactionDate"` } func SetLogLevel(level zerolog.Level) { - zerolog.SetGlobalLevel(level) + zerolog.SetGlobalLevel(level) } func Init_testdb(DB_TYPE string, DB_CONNECTION_STRING string) { - cwd, err := os.Getwd() - if err != nil { - log.Fatal().Err(err).Msg("Could not get current working directory") - } else { - log.Trace().Msgf("Currect working directory is: %s", cwd) - } + cwd, err := os.Getwd() + if err != nil { + log.Fatal().Err(err).Msg("Could not get current working directory") + } else { + log.Trace().Msgf("Currect working directory is: %s", cwd) + } - _, err = os.Stat(cwd + DB_CONNECTION_STRING) - if err != nil { - log.Debug().Msg("Found existing test.db file. Attempting to delete") - err = os.Remove(DB_CONNECTION_STRING) - if err != nil { - log.Fatal().Err(err).Msg("Failed to delete testing db") - } else { - log.Debug().Msg("Deleted test.db file successfully") - } - } else { - log.Debug().Msg("No existing test.db file found") - } - - db, err := sqlx.Connect(DB_TYPE, DB_CONNECTION_STRING) - if err != nil { - log.Fatal(). - Err(err). - Msg("Couldn't open test db") - } + _, err = os.Stat(cwd + DB_CONNECTION_STRING) + if err != nil { + log.Debug().Msg("Found existing test.db file. Attempting to delete") + err = os.Remove(DB_CONNECTION_STRING) + if err != nil { + log.Fatal().Err(err).Msg("Failed to delete testing db") + } else { + log.Debug().Msg("Deleted test.db file successfully") + } + } else { + log.Debug().Msg("No existing test.db file found") + } - defer db.Close() + db, err := sqlx.Connect(DB_TYPE, DB_CONNECTION_STRING) + if err != nil { + log.Fatal(). + Err(err). + Msg("Couldn't open test db") + } - init_sql := ` + defer db.Close() + + init_sql := ` CREATE TABLE accounts ( acnt_id Integer PRIMARY KEY, acnt_dsply_name varchar(50) NOT NULL, @@ -116,18 +116,18 @@ INSERT INTO transactions (trns_amount, trns_description, trns_account, trns_buck ("50.00", "Money", 1, 1, "2023-11-10"); ` - tx := db.MustBegin() - tx.MustExec(init_sql) + tx := db.MustBegin() + tx.MustExec(init_sql) - err = tx.Commit() + err = tx.Commit() - if err != nil { - log.Fatal(). - Err(err). - Msg("Could not commit transaction") - } + if err != nil { + log.Fatal(). + Err(err). + Msg("Could not commit transaction") + } - jsonExample := `{ + jsonExample := `{ "Id": 3, "Amount": "100", "Description": { @@ -142,20 +142,18 @@ INSERT INTO transactions (trns_amount, trns_description, trns_account, trns_buck "TransactionDate": "2023-11-11T00:00:00Z" }` - var trns Transaction = Transaction{} - err = json.Unmarshal([]byte(jsonExample), &trns) + var trns Transaction = Transaction{} + err = json.Unmarshal([]byte(jsonExample), &trns) - if err != nil { - log.Fatal().Err(err).Msg("could not unmarshal") - } + if err != nil { + log.Fatal().Err(err).Msg("could not unmarshal") + } - _, err = db.NamedExec("INSERT INTO transactions" + - "(trns_amount, trns_description, trns_account, trns_bucket, trns_date)" + - "VALUES (:trns_amount, :trns_description, :trns_account, :trns_bucket, :trns_date)", - trns) - + _, err = db.NamedExec("INSERT INTO transactions"+ + "(trns_amount, trns_description, trns_account, trns_bucket, trns_date)"+ + "VALUES (:trns_amount, :trns_description, :trns_account, :trns_bucket, :trns_date)", + trns) - - log.Debug().Msg("Test database initialized") + log.Debug().Msg("Test database initialized") } diff --git a/web/router.go b/web/router.go new file mode 100644 index 0000000..5c4cd30 --- /dev/null +++ b/web/router.go @@ -0,0 +1,17 @@ +package web + +import ( + "net/http" + + "github.com/go-chi/chi/v5" +) + +func WebRouter() http.Handler { + r := chi.NewRouter() + r.Get("/", getIndex) + return r +} + +func getIndex(w http.ResponseWriter, req *http.Request) { + +}