Skip to content

Commit

Permalink
cleanup oauth2
Browse files Browse the repository at this point in the history
  • Loading branch information
topi314 committed Jul 1, 2023
1 parent 97df606 commit 23abcc6
Show file tree
Hide file tree
Showing 6 changed files with 261 additions and 296 deletions.
143 changes: 79 additions & 64 deletions _examples/oauth2/example.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,105 +22,120 @@ var (
clientID = snowflake.GetEnv("client_id")
clientSecret = os.Getenv("client_secret")
baseURL = os.Getenv("base_url")
logger = log.Default()
httpClient = http.DefaultClient
client oauth2.Client
sessions map[string]oauth2.Session
)

func init() {
rand.Seed(time.Now().UnixNano())
}

func main() {
logger.SetLevel(log.LevelDebug)
logger.Info("starting example...")
logger.Infof("disgo %s", disgo.Version)
log.SetLevel(log.LevelDebug)
log.Info("starting example...")
log.Infof("disgo %s", disgo.Version)

client = oauth2.New(clientID, clientSecret, oauth2.WithLogger(logger), oauth2.WithRestClientConfigOpts(rest.WithHTTPClient(httpClient)))
s := &server{
client: oauth2.New(clientID, clientSecret,
oauth2.WithRestClientConfigOpts(
rest.WithHTTPClient(&http.Client{
Timeout: 5 * time.Second,
}),
),
),
sessions: map[string]oauth2.Session{},
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
}

mux := http.NewServeMux()
mux.HandleFunc("/", handleRoot)
mux.HandleFunc("/login", handleLogin)
mux.HandleFunc("/trylogin", handleTryLogin)
mux.HandleFunc("/", s.handleRoot)
mux.HandleFunc("/login", s.handleLogin)
mux.HandleFunc("/trylogin", s.handleTryLogin)
_ = http.ListenAndServe(":6969", mux)
}

func handleRoot(w http.ResponseWriter, r *http.Request) {
var body string
type server struct {
client *oauth2.Client
sessions map[string]oauth2.Session
rand *rand.Rand
}

func (s *server) handleRoot(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("token")
if err == nil {
session, ok := sessions[cookie.Value]
if ok {
var user *discord.OAuth2User
user, err = client.GetUser(session)
if err != nil {
writeError(w, "error while getting user data", err)
return
}
var userJSON []byte
userJSON, err = json.MarshalIndent(user, "<br />", "&ensp;")
if err != nil {
writeError(w, "error while formatting user data", err)
return
}

var connections []discord.Connection
connections, err = client.GetConnections(session)
if err != nil {
writeError(w, "error while getting connections data", err)
return
}
var connectionsJSON []byte
connectionsJSON, err = json.MarshalIndent(connections, "<br />", "&ensp;")
if err != nil {
writeError(w, "error while formatting connections data", err)
return
}
body = fmt.Sprintf("user:<br />%s<br />connections: <br />%s", userJSON, connectionsJSON)
}
if err != nil {
writeHTML(w, `<button><a href="/login">login</a></button>`)
}
if body == "" {
body = `<button><a href="/login">login</a></button>`
session, ok := s.sessions[cookie.Value]
if !ok {
writeHTML(w, `<button><a href="/login">login</a></button>`)
}
w.Header().Set("Content-Type", "text/html; charset=UTF-8")
w.WriteHeader(http.StatusOK)

_, _ = w.Write([]byte(body))
session, ok, err = s.client.VerifySession(session)
if err != nil {
writeError(w, "error while verifying or refresh session", err)
return
}
if ok {
s.sessions[cookie.Value] = session
}

user, err := s.client.GetUser(session)
if err != nil {
writeError(w, "error while getting user data", err)
return
}
connections, err := s.client.GetConnections(session)
if err != nil {
writeError(w, "error while getting connections data", err)
return
}

userJSON, err := json.MarshalIndent(user, "<br />", "&ensp;")
if err != nil {
writeError(w, "error while formatting user data", err)
return
}
connectionsJSON, err := json.MarshalIndent(connections, "<br />", "&ensp;")
if err != nil {
writeError(w, "error while formatting connections data", err)
return
}

writeHTML(w, fmt.Sprintf("user:<br />%s<br />connections: <br />%s", userJSON, connectionsJSON))
}

func handleLogin(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, client.GenerateAuthorizationURL(baseURL+"/trylogin", discord.PermissionsNone, 0, false, discord.OAuth2ScopeIdentify, discord.OAuth2ScopeGuilds, discord.OAuth2ScopeEmail, discord.OAuth2ScopeConnections, discord.OAuth2ScopeWebhookIncoming), http.StatusSeeOther)
func (s *server) handleLogin(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, s.client.GenerateAuthorizationURL(baseURL+"/trylogin", discord.PermissionsNone, 0, false, discord.OAuth2ScopeIdentify, discord.OAuth2ScopeGuilds, discord.OAuth2ScopeEmail, discord.OAuth2ScopeConnections, discord.OAuth2ScopeWebhookIncoming), http.StatusSeeOther)
}

func handleTryLogin(w http.ResponseWriter, r *http.Request) {
func (s *server) handleTryLogin(w http.ResponseWriter, r *http.Request) {
var (
query = r.URL.Query()
code = query.Get("code")
state = query.Get("state")
)
if code != "" && state != "" {
identifier := randStr(32)
session, _, err := client.StartSession(code, state)
identifier := s.randStr(32)
session, _, err := s.client.StartSession(code, state)
if err != nil {
writeError(w, "error while starting session", err)
return
}
sessions[identifier] = session
s.sessions[identifier] = session
http.SetCookie(w, &http.Cookie{Name: "token", Value: identifier})
}
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
}

func (s *server) randStr(n int) string {
b := make([]rune, n)
for i := range b {
b[i] = letters[s.rand.Intn(len(letters))]
}
return string(b)
}

func writeError(w http.ResponseWriter, text string, err error) {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(text + ": " + err.Error()))
}

func randStr(n int) string {
b := make([]rune, n)
for i := range b {
b[i] = letters[rand.Intn(len(letters))]
}
return string(b)
func writeHTML(w http.ResponseWriter, text string) {
w.Header().Set("Content-Type", "text/html; charset=UTF-8")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(text))
}
14 changes: 7 additions & 7 deletions oauth2/ttl_map.go → internal/ttlmap/ttl_map.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package oauth2
package ttlmap

import (
"sync"
Expand All @@ -10,8 +10,8 @@ type value struct {
insertedAt int64
}

func newTTLMap(maxTTL time.Duration) *ttlMap {
m := &ttlMap{
func New(maxTTL time.Duration) *Map {
m := &Map{
maxTTL: maxTTL,
m: map[string]value{},
}
Expand All @@ -35,19 +35,19 @@ func newTTLMap(maxTTL time.Duration) *ttlMap {
return m
}

type ttlMap struct {
type Map struct {
maxTTL time.Duration
m map[string]value
mu sync.Mutex
}

func (m *ttlMap) put(k string, v string) {
func (m *Map) Put(k string, v string) {
m.mu.Lock()
m.m[k] = value{v, time.Now().Unix()}
m.mu.Unlock()
}

func (m *ttlMap) get(k string) string {
func (m *Map) Get(k string) string {
m.mu.Lock()
v, ok := m.m[k]
m.mu.Unlock()
Expand All @@ -57,7 +57,7 @@ func (m *ttlMap) get(k string) string {
return ""
}

func (m *ttlMap) delete(k string) {
func (m *Map) Delete(k string) {
m.mu.Lock()
delete(m.m, k)
m.mu.Unlock()
Expand Down
Loading

0 comments on commit 23abcc6

Please sign in to comment.